用 Pytorch 实现简单循环神经网络

一、歌词生成项目 想要在 nlp 方面深入,于是选择训练生成一个 RNN 网络,主要目标是自动生成歌词。在这里受到了 最浅显易懂的 PyTorch 深度学习入门 的启发,并利用 up 主提供的 源码 中的数据集。 相关代码的编写也有参考该 up 主的部分,但均为在理解内容的基础上自行编写的。另外也有对该 up 主代码中的疏漏进行修改的地方。 二、数据的获取 (1)编写数据集 原作者用爬虫获取的歌词数据被保存在 lyrics.txt 文件中。我们要将数据按可供训练的模型加载。具体来说 我们希望每一次获取数据,都能得到输入和目标输出(对本项目来说就是两段有一个文字偏差的序列) 并且将文字数字化,即 nlp 的 tokenize 为了实现批量训练,需要每次获取定长的序列 为了实现第一点,我们要继承 dataset;实现第二点,需要根据数据建立字符表;实现第三点,需要定长截取歌词句子的一部分。 另外,为了减少每次加载数据所用的时间,还需要将数据集的信息持久化。 我们的 LyricsDataset 具体实现如下。首先,我们在构造函数中通过传入的路径加载数据,判断是否已经存在处理过后的数据,如存在则加载;如不存在则读入原始数据并处理 class LyricsDataset(Dataset): def __init__(self, root_path, seq_size): self.seq_size = seq_size processed_name = "/processed/lyrics.pth" raw_name = "/raw/lyrics.txt" if os.path.exists(root_path + processed_name): print("find processed data") self.__load_processed_data(root_path + processed_name) else: print("processed data not found, will process raw data") self....

一月 17, 2023 · 4 分钟 · 642 字 · Wokron