NLP - AI 写诗

文章目录

    • 代码
      • 测试已训练好的模型
      • Tokenize
      • 加载数据集
      • 训练


代码


测试已训练好的模型

由于模型较大,up主上传时使用了分卷压缩,这里我们先合并解压

# 将四个文件合并为1个
$ cat save.zip.00* > save.zip

# 解压;这一步也可以手动
$ unzip save.zip

import torch

model_path = 'xx/Chinese_Poetry_Generate-main/save.model'
model = torch.load(model_path)

generate('秋', row=4, col=7)

然后根据需要,一步步添加下方的方法。

这里我得到的效果为:


> generate('秋', row=4, col=7)
0 [CLS] 秋 树 摇 风 满 远 林 , 登 临 此 兴 可 同 吟 。 云 生 半 夜 天 风 捲 , 月 过 东 篱 露 气 深 。
1 [CLS] 秋 阴 凝 未 尽 人 意 , 老 去 人 间 已 非 客 。 一 叶 舟 中 无 所 依 , 秋 风 满 前 山 色 黑 。
2 [CLS] 秋 霜 萧 萧 一 叶 秋 , 秋 风 萧 萧 吹 远 愁 。 高 堂 夜 静 月 照 影 , 独 立 空 台 愁 白 头 。

generate('莫', row=4, col=5)
0 [CLS] 莫 把 年 华 念 , 吾 衰 鬓 未 斑 。 病 身 今 老 矣 , 衰 病 旧 依 然 。
1 [CLS] 莫 信 江 山 胜 , 还 因 古 塞 名 。 云 收 天 半 暮 , 霜 坠 雁 馀 清 。
2 [CLS] 莫 辞 寒 食 过 , 为 惜 雨 晴 时 。 风 景 今 犹 昨 , 江 山 几 更 非 。

Tokenize

from transformers import AutoTokenizer

#加载编码器
tokenizer = AutoTokenizer.from_pretrained('uer/gpt2-chinese-cluecorpussmall')

print(tokenizer)

#编码试算
tokenizer.batch_encode_plus([
    '欲出未出光辣达,千山万山如火发.须臾走向天上来,逐却残星赶却月.',
    '满目江山四望幽,白云高卷嶂烟收.日回禽影穿疏木,风递猿声入小楼.远岫似屏横碧落,断帆如叶截中流.'
])

PreTrainedTokenizerFast(name_or_path='uer/gpt2-chinese-cluecorpussmall', vocab_size=21128, model_max_len=1024, is_fast=True, padding_side='right', truncation_side='right', 
special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

{'input_ids': [[101, 3617, 1139, 3313, 1139, 1045, 6793, 6809, 117, 1283, 2255, 674, 2255, 1963, 4125, 1355, 119, 7557, 5640, 6624, 1403, 1921, 677, 3341, 117, 6852, 1316, 3655, 3215, 6628, 1316, 3299, 119, 102], [101, 4007, 4680, 3736, 2255, 1724, 3307, 2406, 117, 4635, 756, 7770, 1318, 2322, 4170, 3119, 119, 3189, 1726, 4896, 2512, 4959, 4541, 3312, 117, 7599, 6853, 4351, 1898, 1057, 2207, 3517, 119, 6823, 2273, 849, 2242, 3566, 4819, 5862, 117, 3171, 2359, 1963, 1383, 2779, 704, 3837, 119, 102]], 
'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 
'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

加载数据集

import torch


#简单数据集
class Dataset(torch.utils.data.Dataset):

    def __init__(self):
        with open('chinese_poems.txt') as f:
            lines = f.readlines()
        lines = [i.strip() for i in lines]

        self.lines = lines

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, i):
        return self.lines[i]


dataset = Dataset()

len(dataset), dataset[0]

(304752, '欲出未出光辣达,千山万山如火发.须臾走向天上来,逐却残星赶却月.')

import torch
import os
import pandas as pd


#更多数据数据集
class Dataset(torch.utils.data.Dataset):

    def __init__(self):
        data = []
        for i in os.listdir('more_datas'):
            if i == '.ipynb_checkpoints':
                continue
            data.append(pd.read_csv('more_datas/%s' % i))

        data = pd.concat(data).reset_index()

        data = data['内容']

        data = data.str.strip()

        #移除一些标点符号
        data = data.str.replace('[《》“”「」]', '', regex=True)

        #正则过滤
        select = data.str.match('^[\w,。?、!:;]+$', na=False)
        data = data[select]

        #标点符号合并
        data = data.str.replace('[?!;]', '。', regex=True)
        data = data.str.replace('[、:]', ',', regex=True)

        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data.iloc[i]


dataset = Dataset()

len(dataset), dataset[0]

(839587,
 '不饮强须饮,今日是重阳。向来健者安在,世事两茫茫。叔子去人远矣,正复何关人事,堕泪忽成行。叔子泪自堕,湮没使人伤。燕何归,鸿欲断,蝶休忙。渊明自无可奈,冷眼菊花黄。看取龙山落日,又见骑台荒草,谁弱复谁强。酒亦有何好,暂醉得相忘。')

def collate_fn(data):
    data = tokenizer.batch_encode_plus(data,
                                       padding=True,
                                       truncation=True,
                                       max_length=512,
                                       return_tensors='pt')

    data['labels'] = data['input_ids'].clone()

    return data


#数据加载器
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=8,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
)

for i, data in enumerate(loader):
    break

for k, v in data.items():
    print(k, v.shape)

len(loader)

input_ids torch.Size([8, 130])
token_type_ids torch.Size([8, 130])
attention_mask torch.Size([8, 130])
labels torch.Size([8, 130])
104948

from transformers import AutoModelForCausalLM, GPT2Model

#加载模型
model = AutoModelForCausalLM.from_pretrained(
    'uer/gpt2-chinese-cluecorpussmall')

#统计参数量
print(sum(i.numel() for i in model.parameters()) / 10000)

with torch.no_grad():
    out = model(**data)

out['loss'], out['logits'].shape

10206.8736

(tensor(9.5407), torch.Size([8, 130, 21128]))

def generate(text, row, col):

    def generate_loop(data):
        with torch.no_grad():
            out = model(**data)

        #取最后一个字
        #[5, b, 50257]
        out = out['logits']
        #[5, 50257]
        out = out[:, -1]

        #第50大的值,以此为分界线,小于该值的全部赋值为负无穷
        #[5, 50257] -> [5, 50]
        topk_value = torch.topk(out, 50).values
        #[5, 50] -> [5] -> [5, 1]
        topk_value = topk_value[:, -1].unsqueeze(dim=1)

        #赋值
        #[5, 50257]
        out = out.masked_fill(out < topk_value, -float('inf'))

        #不允许写特殊符号
        out[:, tokenizer.sep_token_id] = -float('inf')
        out[:, tokenizer.unk_token_id] = -float('inf')
        out[:, tokenizer.pad_token_id] = -float('inf')
        for i in ',。':
            out[:, tokenizer.get_vocab()[i]] = -float('inf')

        #根据概率采样,无放回,所以不可能重复
        #[5, 50257] -> [5, 1]
        out = out.softmax(dim=1)
        out = out.multinomial(num_samples=1)

        #强制添加标点符号
        c = data['input_ids'].shape[1] / (col + 1)
        if c % 1 == 0:
            if c % 2 == 0:
                out[:, 0] = tokenizer.get_vocab()['。']
            else:
                out[:, 0] = tokenizer.get_vocab()[',']

        data['input_ids'] = torch.cat([data['input_ids'], out], dim=1)
        data['attention_mask'] = torch.ones_like(data['input_ids'])
        data['token_type_ids'] = torch.zeros_like(data['input_ids'])
        data['labels'] = data['input_ids'].clone()

        if data['input_ids'].shape[1] >= row * col + row + 1:
            return data

        return generate_loop(data)

    #重复3遍
    data = tokenizer.batch_encode_plus([text] * 3, return_tensors='pt')
    data['input_ids'] = data['input_ids'][:, :-1]
    data['attention_mask'] = torch.ones_like(data['input_ids'])
    data['token_type_ids'] = torch.zeros_like(data['input_ids'])
    data['labels'] = data['input_ids'].clone()

    data = generate_loop(data)

    for i in range(3):
        print(i, tokenizer.decode(data['input_ids'][i]))


generate('秋高气爽', row=4, col=5)

训练

from transformers import AdamW
from transformers.optimization import get_scheduler


#训练
def train():
    global model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    optimizer = AdamW(model.parameters(), lr=5e-5)
    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,
                              num_training_steps=len(loader),
                              optimizer=optimizer)

    model.train()
    for i, data in enumerate(loader):
        for k in data.keys():
            data[k] = data[k].to(device)
        out = model(**data)
        loss = out['loss']

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

        optimizer.zero_grad()
        model.zero_grad()

        if i % 1000 == 0:
            labels = data['labels'][:, 1:]
            out = out['logits'].argmax(dim=2)[:, :-1]

            select = labels != 0
            labels = labels[select]
            out = out[select]
            del select

            accuracy = (labels == out).sum().item() / labels.numel()

            lr = optimizer.state_dict()['param_groups'][0]['lr']

            print(i, loss.item(), lr, accuracy)

    model = model.to('cpu')
    torch.save(model, 'save.model')


#train()

2022-12-08

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/602840.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

网贷大数据查询要怎么保证准确性?

相信现在不少人都听说过什么是网贷大数据&#xff0c;但还有很多人都会将它跟征信混为一谈&#xff0c;其实两者有本质上的区别&#xff0c;那网贷大数据查询要怎么保证准确性呢?本文将为大家总结几点&#xff0c;感兴趣的朋友不妨去看看。 想要保证网贷大数据查询的准确度&am…

差动绕组电流互感器过电压保护器ACTB

安科瑞薛瑶瑶18701709087/17343930412 电流互感器在运行中如果二次绕组开路或一次绕组流过异常电流&#xff0c;都会在二次侧产生数千伏甚至上万伏的过电压。这不仅会使CT和二次设备损坏&#xff0c;也严重威胁运行人员的生命安全&#xff0c;并造成重大经济损失。采用电流互感…

SpringBoot多数据源配置

&#x1f353; 简介&#xff1a;java系列技术分享(&#x1f449;持续更新中…&#x1f525;) &#x1f353; 初衷:一起学习、一起进步、坚持不懈 &#x1f353; 如果文章内容有误与您的想法不一致,欢迎大家在评论区指正&#x1f64f; &#x1f353; 希望这篇文章对你有所帮助,欢…

Git知识点总结

目录 1、版本控制 1.1什么是版本控制 1.2常见的版本控制工具 1.3版本控制分类 2、集中版本控制 SVN 3、分布式版本控制 Git 2、Git与SVN的主要区别 3、软件下载 安装&#xff1a;无脑下一步即可&#xff01;安装完毕就可以使用了&#xff01; 4、启动Git 4.1常用的Li…

CentOS 7 :虚拟机网络环境配置+ 安装gcc(新手进)

虚拟机安装完centos的系统却发现无法正常联网&#xff0c;咋破&#xff01; 几个简单的步骤&#xff1a; 一、检查和设置虚拟机网络适配器 这里笔者使用的桥接模式&#xff0c;朋友们可以有不同的选项设置 二、查看宿主机的网络 以笔者的为例&#xff0c;宿主机采用wlan上网模…

在python中对Requests的理解

离上次写文章已经有小半个月了&#xff0c;但是&#xff1a; 没有动态的日子里&#xff0c;都在努力生活❤️&#xff1b;发表动态的日子里&#xff0c;都在热爱生活。&#x1f339; 目录 一、python集成工具的分类&#xff1a;1.解释Requests2. Requests3. Response对象的属性…

mvc 异步请求、异步连接、异步表单

》》》 利用Jquery ajax 》》》 mvc 异步表单 c# MVC 添加异步 jquery.unobtrusive-ajax.min.js 方法 具–>Nuget程序包管理器–>程序包管理器控制台 在控制台输入&#xff1a;PM>Install-Package Microsoft.jQuery.Unobtrusive.Ajax –version 3.0.0 回车执行即可在…

5分钟了解Flutter线程Isolate的运用以及Isolate到底是怎样执行的

5分钟了解Flutter线程Isolate的运用以及Isolate到底是什么 Isolate在dart是什么flutter线程内存隔离Isolate的使用第一种&#xff0c;无参数使用Isolate.run 第二种&#xff0c;有参数使用compute:使用Isolate.spawn Isolate与外面线程通讯Isolate以文件形式加载到内存运行 Iso…

led显示屏用什么胶水封装比较好?

led显示屏用什么胶水封装比较好&#xff1f; LED显示屏通常使用特定的胶水进行封装&#xff0c;以确保其稳定性和耐用性。常见的用于LED显示屏封装的胶水类型包括有机硅灌封胶、环氧树脂灌封胶等。 有机硅灌封胶具有优异的耐高温、防水、绝缘和密封性能&#xff0c;非常适合用…

使用MATLAB/Simulink点亮STM32开发板LED灯

使用MATLAB/Simulink点亮STM32开发板LED灯-笔记 一、STM32CubeMX新建工程二、Simulink 新建工程三、MDK导入生成的代码 一、STM32CubeMX新建工程 1. 打开 STM32CubeMX 软件&#xff0c;点击“新建工程”&#xff0c;选择中对应的型号 2. RCC 设置&#xff0c;选择 HSE(外部高…

python菜鸟级安装手册-上篇

python安装教程 电脑-右键-属性&#xff0c;确认系统类型和版本号&#xff0c;比如本案例系统是64位 win10 点击python官网&#xff0c;进行下载 适用于 Windows 的 Python 版本 |Python.org 选择第一个安装程序64位即可满足需要&#xff0c; 嵌入式程序包是压缩包版本&…

MySQL中的ON DUPLICATE KEY UPDATE和REPLACE

在 MySQL 中&#xff0c;ON DUPLICATE KEY UPDATE 和 REPLACE 语句都可以用来处理插入数据时主键或唯一键冲突的情况&#xff0c;但它们在处理冲突的方式上有所不同。它们有以下区别&#xff1a; 行为方式&#xff1a; ON DUPLICATE KEY UPDATE&#xff1a;当插入的数据行存在冲…

【竞技宝】欧冠:多特淘汰大巴黎进决赛,姆巴佩迷失

多特蒙德在本赛季欧冠半决赛第二回合较量中,跟大巴黎队狭路相逢。赛前,大部分球迷和媒体都看好坐拥姆巴佩的大巴黎队,可以靠着主场作战的优势,逆转多特蒙德晋级欧冠决赛。大巴黎队主场作战确实创造出不少得分机会,只可惜球队运气有些差,射门都打在了多特蒙德横梁上。反观多特蒙…

双翻斗雨量计学习

双翻斗雨量计用户手册&#xff08;脉冲型&#xff09; 本仪器由雨量计壳体、承雨口、漏斗、翻斗支撑、上漏斗雨量调节支架、上漏斗、汇集漏斗、计数翻斗雨量调节支架、计数翻斗、干簧管安装架、轴承螺钉、出水漏斗、腿部支架、干簧管、水平泡、调节支撑板、控制盒、调平装置、接…

IaC实战指南:DevOps的自动化基石

基础设施即代码&#xff08;Infrastructure as Code&#xff0c;IaC&#xff09;是指利用脚本、配置或编程语言创建和维护基础设施的一组实践和流程。通过IaC&#xff0c;我们可以轻松测试各个组件、实现所需的功能并在最小化停机时间的前提下进行扩展。更值得一提的是&#xf…

算法基础01一快速排序,归并排序,二分

一.排序 1.快速 排序 基于分治 确定分界点 左 右 中间 随机划分区间 左半边<x >x在右半边递归处理左右两端 #include<iostream>using namespace std;const int N 1e6 10;int n; int q[N]; void quick_sort(int q[],int l,int r) {if(l>r)return;//边界&…

k8s 资源文件参数介绍

Kubernetes资源文件yaml参数介绍 yaml 介绍 yaml 是一个类似 XML、JSON 的标记性语言。它强调以数据为中心&#xff0c;并不是以标识语言为重点例如 SpringBoot 的配置文件 application.yml 也是一个 yaml 格式的文件 语法格式 通过缩进表示层级关系不能使用tab进行缩进&am…

怎么快速分享视频文件?用二维码看视频的方法

怎样不通过传输下载分享视频内容呢&#xff1f;以前分享视频内容&#xff0c;大多会通过微信、QQ、邮箱、网盘等形式来传递。但是这种方式需要下载后才可以观看&#xff0c;不仅占用手机内存&#xff0c;而且效率也比较低&#xff0c;所以现在很多人会采用视频生成二维码的方式…

Could not resolve placeholder ‘xx.xxx.host’ in value “xxx“问题解决

Could not resolve placeholder ‘xx.xxx.host’ in value "xxx"问题解决 众多原因其中之一 springboot 项目&#xff0c;idea 配置apollo 时&#xff0c;运行指定了配置文件 uat 所以使用本地配置文件启动 时&#xff0c;一直去找uat 配置文件&#xff0c;结果自…

树莓派4b测量光照强度

1.BH1750光照强度连接图 2. BH1750工作原理 BH1750的通讯过程 第1步:发送上电命令。 发送的过程和第2步基本一致,把测量命令(0x10)改成上电命令(0x01)。第2步:发送测量命令。 下面图片上的例子,ADDR引脚是接GND的,发送的测量命令是“连续高分辨率测量(0x10)”。 发送数据…
最新文章