圆月山庄资源网 Design By www.vgjia.com
model.py:
#!/usr/bin/python # -*- coding: utf-8 -*- import torch from torch import nn import numpy as np from torch.autograd import Variable import torch.nn.functional as F class TextRNN(nn.Module): """文本分类,RNN模型""" def __init__(self): super(TextRNN, self).__init__() # 三个待输入的数据 self.embedding = nn.Embedding(5000, 64) # 进行词嵌入 # self.rnn = nn.LSTM(input_size=64, hidden_size=128, num_layers=2, bidirectional=True) self.rnn = nn.GRU(input_size=64, hidden_size=128, num_layers=2, bidirectional=True) self.f1 = nn.Sequential(nn.Linear(256,128), nn.Dropout(0.8), nn.ReLU()) self.f2 = nn.Sequential(nn.Linear(128,10), nn.Softmax()) def forward(self, x): x = self.embedding(x) x,_ = self.rnn(x) x = F.dropout(x,p=0.8) x = self.f1(x[:,-1,:]) return self.f2(x) class TextCNN(nn.Module): def __init__(self): super(TextCNN, self).__init__() self.embedding = nn.Embedding(5000,64) self.conv = nn.Conv1d(64,256,5) self.f1 = nn.Sequential(nn.Linear(256*596, 128), nn.ReLU()) self.f2 = nn.Sequential(nn.Linear(128, 10), nn.Softmax()) def forward(self, x): x = self.embedding(x) x = x.detach().numpy() x = np.transpose(x,[0,2,1]) x = torch.Tensor(x) x = Variable(x) x = self.conv(x) x = x.view(-1,256*596) x = self.f1(x) return self.f2(x)
train.py:
# coding: utf-8 from __future__ import print_function import torch from torch import nn from torch import optim from torch.autograd import Variable import os import numpy as np from model import TextRNN,TextCNN from cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab base_dir = 'cnews' train_dir = os.path.join(base_dir, 'cnews.train.txt') test_dir = os.path.join(base_dir, 'cnews.test.txt') val_dir = os.path.join(base_dir, 'cnews.val.txt') vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt') def train(): x_train, y_train = process_file(train_dir, word_to_id, cat_to_id,600)#获取训练数据每个字的id和对应标签的oe-hot形式 x_val, y_val = process_file(val_dir, word_to_id, cat_to_id,600) #使用LSTM或者CNN model = TextRNN() # model = TextCNN() #选择损失函数 Loss = nn.MultiLabelSoftMarginLoss() # Loss = nn.BCELoss() # Loss = nn.MSELoss() optimizer = optim.Adam(model.parameters(),lr=0.001) best_val_acc = 0 for epoch in range(1000): batch_train = batch_iter(x_train, y_train,100) for x_batch, y_batch in batch_train: x = np.array(x_batch) y = np.array(y_batch) x = torch.LongTensor(x) y = torch.Tensor(y) # y = torch.LongTensor(y) x = Variable(x) y = Variable(y) out = model(x) loss = Loss(out,y) optimizer.zero_grad() loss.backward() optimizer.step() accracy = np.mean((torch.argmax(out,1)==torch.argmax(y,1)).numpy()) #对模型进行验证 if (epoch+1)%20 == 0: batch_val = batch_iter(x_val, y_val, 100) for x_batch, y_batch in batch_train: x = np.array(x_batch) y = np.array(y_batch) x = torch.LongTensor(x) y = torch.Tensor(y) # y = torch.LongTensor(y) x = Variable(x) y = Variable(y) out = model(x) loss = Loss(out, y) optimizer.zero_grad() loss.backward() optimizer.step() accracy = np.mean((torch.argmax(out, 1) == torch.argmax(y, 1)).numpy()) if accracy > best_val_acc: torch.save(model.state_dict(),'model_params.pkl') best_val_acc = accracy print(accracy) if __name__ == '__main__': #获取文本的类别及其对应id的字典 categories, cat_to_id = read_category() #获取训练文本中所有出现过的字及其所对应的id words, word_to_id = read_vocab(vocab_dir) #获取字数 vocab_size = len(words) train()
test.py:
# coding: utf-8 from __future__ import print_function import os import tensorflow.contrib.keras as kr import torch from torch import nn from cnews_loader import read_category, read_vocab from model import TextRNN from torch.autograd import Variable import numpy as np try: bool(type(unicode)) except NameError: unicode = str base_dir = 'cnews' vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt') class TextCNN(nn.Module): def __init__(self): super(TextCNN, self).__init__() self.embedding = nn.Embedding(5000,64) self.conv = nn.Conv1d(64,256,5) self.f1 = nn.Sequential(nn.Linear(152576, 128), nn.ReLU()) self.f2 = nn.Sequential(nn.Linear(128, 10), nn.Softmax()) def forward(self, x): x = self.embedding(x) x = x.detach().numpy() x = np.transpose(x,[0,2,1]) x = torch.Tensor(x) x = Variable(x) x = self.conv(x) x = x.view(-1,152576) x = self.f1(x) return self.f2(x) class CnnModel: def __init__(self): self.categories, self.cat_to_id = read_category() self.words, self.word_to_id = read_vocab(vocab_dir) self.model = TextCNN() self.model.load_state_dict(torch.load('model_params.pkl')) def predict(self, message): # 支持不论在python2还是python3下训练的模型都可以在2或者3的环境下运行 content = unicode(message) data = [self.word_to_id[x] for x in content if x in self.word_to_id] data = kr.preprocessing.sequence.pad_sequences([data],600) data = torch.LongTensor(data) y_pred_cls = self.model(data) class_index = torch.argmax(y_pred_cls[0]).item() return self.categories[class_index] class RnnModel: def __init__(self): self.categories, self.cat_to_id = read_category() self.words, self.word_to_id = read_vocab(vocab_dir) self.model = TextRNN() self.model.load_state_dict(torch.load('model_rnn_params.pkl')) def predict(self, message): # 支持不论在python2还是python3下训练的模型都可以在2或者3的环境下运行 content = unicode(message) data = [self.word_to_id[x] for x in content if x in self.word_to_id] data = kr.preprocessing.sequence.pad_sequences([data], 600) data = torch.LongTensor(data) y_pred_cls = self.model(data) class_index = torch.argmax(y_pred_cls[0]).item() return self.categories[class_index] if __name__ == '__main__': model = CnnModel() # model = RnnModel() test_demo = ['湖人助教力助科比恢复手感 他也是阿泰的精神导师新浪体育讯记者戴高乐报道 上赛季,科比的右手食指遭遇重创,他的投篮手感也因此大受影响。不过很快科比就调整了自己的投篮手型,并通过这一方式让自己的投篮命中率回升。而在这科比背后,有一位特别助教对科比帮助很大,他就是查克·珀森。珀森上赛季担任湖人的特别助教,除了帮助科比调整投篮手型之外,他的另一个重要任务就是担任阿泰的精神导师。来到湖人队之后,阿泰收敛起了暴躁的脾气,成为湖人夺冠路上不可或缺的一员,珀森的“心灵按摩”功不可没。经历了上赛季的成功之后,珀森本赛季被“升职”成为湖人队的全职助教,每场比赛,他都会坐在球场边,帮助禅师杰克逊一起指挥湖人球员在场上拼杀。对于珀森的工作,禅师非常欣赏,“查克非常善于分析问题,”菲尔·杰克逊说,“他总是在寻找问题的答案,同时也在找造成这一问题的原因,这是我们都非常乐于看到的。我会在平时把防守中出现的一些问题交给他,然后他会通过组织球员练习找到解决的办法。他在球员时代曾是一名很好的外线投手,不过现在他与内线球员的配合也相当不错。', '弗老大被裁美国媒体看热闹“特权”在中国像蠢蛋弗老大要走了。虽然他只在首钢男篮效力了13天,而且表现毫无亮点,大大地让球迷和俱乐部失望了,但就像中国人常说的“好聚好散”,队友还是友好地与他告别,俱乐部与他和平分手,球迷还请他留下了在北京的最后一次签名。相比之下,弗老大的同胞美国人却没那么“宽容”。他们嘲讽这位NBA前巨星的英雄迟暮,批评他在CBA的业余表现,还惊讶于中国人的“大方”。今天,北京首钢俱乐部将与弗朗西斯继续商讨解约一事。从昨日的进展来看,双方可以做到“买卖不成人意在”,但回到美国后,恐怕等待弗朗西斯的就没有这么轻松的环境了。进展@北京昨日与队友告别 最后一次为球迷签名弗朗西斯在13天里为首钢队打了4场比赛,3场的得分为0,只有一场得了2分。昨天是他来到北京的第14天,虽然他与首钢还未正式解约,但双方都明白“缘分已尽”。下午,弗朗西斯来到首钢俱乐部与队友们告别。弗朗西斯走到队友身边,依次与他们握手拥抱。“你们都对我很好,安排的条件也很好,我很喜欢这支球队,想融入你们,但我现在真的很不适应。希望你们'] for i in test_demo: print(i,":",model.predict(i))
cnews_loader.py:
# coding: utf-8 import sys from collections import Counter import numpy as np import tensorflow.contrib.keras as kr if sys.version_info[0] > 2: is_py3 = True else: reload(sys) sys.setdefaultencoding("utf-8") is_py3 = False def native_word(word, encoding='utf-8'): """如果在python2下面使用python3训练的模型,可考虑调用此函数转化一下字符编码""" if not is_py3: return word.encode(encoding) else: return word def native_content(content): if not is_py3: return content.decode('utf-8') else: return content def open_file(filename, mode='r'): """ 常用文件操作,可在python2和python3间切换. mode: 'r' or 'w' for read or write """ if is_py3: return open(filename, mode, encoding='utf-8', errors='ignore') else: return open(filename, mode) def read_file(filename): """读取文件数据""" contents, labels = [], [] with open_file(filename) as f: for line in f: try: label, content = line.strip().split('\t') if content: contents.append(list(native_content(content))) labels.append(native_content(label)) except: pass return contents, labels def build_vocab(train_dir, vocab_dir, vocab_size=5000): """根据训练集构建词汇表,存储""" data_train, _ = read_file(train_dir) all_data = [] for content in data_train: all_data.extend(content) counter = Counter(all_data) count_pairs = counter.most_common(vocab_size - 1) words, _ = list(zip(*count_pairs)) # 添加一个 <PAD> 来将所有文本pad为同一长度 words = ['<PAD>'] + list(words) open_file(vocab_dir, mode='w').write('\n'.join(words) + '\n') def read_vocab(vocab_dir): """读取词汇表""" # words = open_file(vocab_dir).read().strip().split('\n') with open_file(vocab_dir) as fp: # 如果是py2 则每个值都转化为unicode words = [native_content(_.strip()) for _ in fp.readlines()] word_to_id = dict(zip(words, range(len(words)))) return words, word_to_id def read_category(): """读取分类目录,固定""" categories = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐'] categories = [native_content(x) for x in categories] cat_to_id = dict(zip(categories, range(len(categories)))) return categories, cat_to_id def to_words(content, words): """将id表示的内容转换为文字""" return ''.join(words[x] for x in content) def process_file(filename, word_to_id, cat_to_id, max_length=600): """将文件转换为id表示""" contents, labels = read_file(filename)#读取训练数据的每一句话及其所对应的类别 data_id, label_id = [], [] for i in range(len(contents)): data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id])#将每句话id化 label_id.append(cat_to_id[labels[i]])#每句话对应的类别的id # # # 使用keras提供的pad_sequences来将文本pad为固定长度 x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length) y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id)) # 将标签转换为one-hot表示 # return x_pad, y_pad def batch_iter(x, y, batch_size=64): """生成批次数据""" data_len = len(x) num_batch = int((data_len - 1) / batch_size) + 1 indices = np.random.permutation(np.arange(data_len)) x_shuffle = x[indices] y_shuffle = y[indices] for i in range(num_batch): start_id = i * batch_size end_id = min((i + 1) * batch_size, data_len) yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]
以上这篇pytorch实现用CNN和LSTM对文本进行分类方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
圆月山庄资源网 Design By www.vgjia.com
广告合作:本站广告合作请联系QQ:858582 申请时备注:广告合作(否则不回)
免责声明:本站文章均来自网站采集或用户投稿,网站不提供任何软件下载或自行开发的软件! 如有用户或公司发现本站内容信息存在侵权行为,请邮件告知! 858582#qq.com
免责声明:本站文章均来自网站采集或用户投稿,网站不提供任何软件下载或自行开发的软件! 如有用户或公司发现本站内容信息存在侵权行为,请邮件告知! 858582#qq.com
圆月山庄资源网 Design By www.vgjia.com
暂无评论...
《魔兽世界》大逃杀!60人新游玩模式《强袭风暴》3月21日上线
暴雪近日发布了《魔兽世界》10.2.6 更新内容,新游玩模式《强袭风暴》即将于3月21 日在亚服上线,届时玩家将前往阿拉希高地展开一场 60 人大逃杀对战。
艾泽拉斯的冒险者已经征服了艾泽拉斯的大地及遥远的彼岸。他们在对抗世界上最致命的敌人时展现出过人的手腕,并且成功阻止终结宇宙等级的威胁。当他们在为即将于《魔兽世界》资料片《地心之战》中来袭的萨拉塔斯势力做战斗准备时,他们还需要在熟悉的阿拉希高地面对一个全新的敌人──那就是彼此。在《巨龙崛起》10.2.6 更新的《强袭风暴》中,玩家将会进入一个全新的海盗主题大逃杀式限时活动,其中包含极高的风险和史诗级的奖励。
《强袭风暴》不是普通的战场,作为一个独立于主游戏之外的活动,玩家可以用大逃杀的风格来体验《魔兽世界》,不分职业、不分装备(除了你在赛局中捡到的),光是技巧和战略的强弱之分就能决定出谁才是能坚持到最后的赢家。本次活动将会开放单人和双人模式,玩家在加入海盗主题的预赛大厅区域前,可以从强袭风暴角色画面新增好友。游玩游戏将可以累计名望轨迹,《巨龙崛起》和《魔兽世界:巫妖王之怒 经典版》的玩家都可以获得奖励。
更新日志
2024年11月03日
2024年11月03日
- 明达年度发烧碟MasterSuperiorAudiophile2021[DSF]
- 英文DJ 《致命的温柔》24K德国HD金碟DTS 2CD[WAV+分轨][1.7G]
- 张学友1997《不老的传说》宝丽金首版 [WAV+CUE][971M]
- 张韶涵2024 《不负韶华》开盘母带[低速原抓WAV+CUE][1.1G]
- lol全球总决赛lcs三号种子是谁 S14全球总决赛lcs三号种子队伍介绍
- lol全球总决赛lck三号种子是谁 S14全球总决赛lck三号种子队伍
- 群星.2005-三里屯音乐之男孩女孩的情人节【太合麦田】【WAV+CUE】
- 崔健.2005-给你一点颜色【东西音乐】【WAV+CUE】
- 南台湾小姑娘.1998-心爱,等一下【大旗】【WAV+CUE】
- 【新世纪】群星-美丽人生(CestLaVie)(6CD)[WAV+CUE]
- ProteanQuartet-Tempusomniavincit(2024)[24-WAV]
- SirEdwardElgarconductsElgar[FLAC+CUE]
- 田震《20世纪中华歌坛名人百集珍藏版》[WAV+CUE][1G]
- BEYOND《大地》24K金蝶限量编号[低速原抓WAV+CUE][986M]
- 陈奕迅《准备中 SACD》[日本限量版] [WAV+CUE][1.2G]