PyTorch:通过生成的模型运行看不见的文本
Posted
技术标签:
【中文标题】PyTorch:通过生成的模型运行看不见的文本【英文标题】:PyTorch: Running unseen text through generated model 【发布时间】:2020-11-16 16:22:12 【问题描述】:我正在尝试实现一个 PyTorch 项目,发现 here。
import os
from process_file import process_doc
import random
import torch
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
import numpy as np
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import classification_report, confusion_matrix
from full_model import Classifier
import torch.nn as nn
import torch.optim as optim
import time
def get_batch(doc, ref_type='headline'):
sent, ls, out, sids = [], [], [], []
sent.append(doc.headline)
ls.append(len(doc.headline))
for sid in doc.sentences:
if SPEECH:
out.append(out_map[doc.sent_to_speech.get(sid, 'NA')])
else:
out.append(out_map[doc.sent_to_event.get(sid)])
sent.append(doc.sentences[sid])
ls.append(len(doc.sentences[sid]))
sids.append(sid)
ls = torch.LongTensor(ls)
out = torch.LongTensor(out)
return sent, ls, out, sids
def train(epoch, data):
start_time = time.time()
total_loss = 0
global prev_best_macro
for ind, doc in enumerate(data):
model.train()
optimizer.zero_grad()
sent, ls, out, _ = get_batch(doc)
if has_cuda:
ls = ls.cuda()
out = out.cuda()
_output, _, _, _ = model.forward(sent, ls)
loss = criterion(_output, out)
total_loss += loss.item()
loss.backward()
optimizer.step()
del sent, ls, out
if has_cuda:
torch.cuda.empty_cache()
print("--Training--\nEpoch: ", epoch, "Loss: ", total_loss, "Time Elapsed: ", time.time()-start_time)
perf = evaluate(validate_data)
# print(perf)
if prev_best_macro < perf:
prev_best_macro = perf
print ("-------------------Test start-----------------------")
_ = evaluate(test_data, True)
print ("-------------------Test end-----------------------")
torch.save(model.state_dict(), 'discourse_lstm_model.pt')
def evaluate(data, is_test=False):
y_true, y_pred = [], []
model.eval()
for doc in data:
sent, ls, out, sids = get_batch(doc)
if has_cuda:
ls = ls.cuda()
#out = out.cuda()
_output, _, _, _ = model.forward(sent, ls)
_output = _output.squeeze()
_, predict = torch.max(_output, 1)
y_pred += list(predict.cpu().numpy() if has_cuda else predict.numpy())
temp_true = list(out.numpy())
y_true += temp_true
print("MACRO: ", precision_recall_fscore_support(y_true, y_pred, average='macro'))
print("MICRO: ", precision_recall_fscore_support(y_true, y_pred, average='micro'))
if is_test:
print("Classification Report \n", classification_report(y_true, y_pred))
print("Confusion Matrix \n", confusion_matrix(y_true, y_pred))
return precision_recall_fscore_support(y_true, y_pred, average='macro')[2]
if __name__ == '__main__':
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
# parser.add_argument('--drop', help='DROP', default=6, type=float)
# parser.add_argument('--learn_rate', help='LEARNING RATE', default=0, type=float)
# parser.add_argument('--loss_wt', help='LOSS WEIGHTS', default=0, type=str)
parser.add_argument('--seed', help='SEED', default=0, type=int)
args = parser.parse_args()
has_cuda = torch.cuda.is_available()
SPEECH = 0
if SPEECH:
out_map = 'NA':0, 'Speech':1
else:
out_map = 'NA':0,'Main':1,'Main_Consequence':2, 'Cause_Specific':3, 'Cause_General':4, 'Distant_Historical':5,
'Distant_Anecdotal':6, 'Distant_Evaluation':7, 'Distant_Expectations_Consequences':8
train_data = []
validate_data = []
test_data = []
for domain in ["Business", "Politics", "Crime", "Disaster", "kbp"]:
subdir = "../data/train/"+domain
files = os.listdir(subdir)
for file in files:
if '.txt' in file:
doc = process_doc(os.path.join(subdir, file), domain) #'../data/Business/nyt_corpus_data_2007_04_27_1843240.txt'
#print(doc.sent_to_event)
train_data.append(doc)
subdir = "../data/test/"+domain
files = os.listdir(subdir)
for file in files:
if '.txt' in file:
doc = process_doc(os.path.join(subdir, file), domain) #'../data/Business/nyt_corpus_data_2007_04_27_1843240.txt'
#print(doc.sent_to_event)
test_data.append(doc)
subdir = "../data/validation"
files = os.listdir(subdir)
for file in files:
if '.txt' in file:
doc = process_doc(os.path.join(subdir, file), 'VAL') #'../data/Business/nyt_corpus_data_2007_04_27_1843240.txt'
#print(doc.sent_to_event)
validate_data.append(doc)
print(len(train_data), len(validate_data), len(test_data))
seed = args.seed
np.random.seed(seed)
torch.manual_seed(seed)
if has_cuda:
torch.cuda.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
prev_best_macro = 0.
model = Classifier('num_layers': 1, 'hidden_dim': 512, 'bidirectional': True, 'embedding_dim': 1024,
'dropout': 0.5, 'out_dim': len(out_map))
if has_cuda:
model = model.cuda()
model.init_weights()
criterion = nn.CrossEntropyLoss()
print("Model Created")
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.Adam(params, lr=5e-5, betas=[0.9, 0.999], eps=1e-8, weight_decay=0)
try:
for epoch in range(15):
print("---------------------------Started Training Epoch = 0--------------------------".format(epoch+1))
train(epoch, train_data)
except KeyboardInterrupt:
print ("----------------- INTERRUPTED -----------------")
evaluate(validate_data)
evaluate(test_data)
运行此代码,我成功输出了一个在约 400 篇文章的语料库上训练的 .pt 模型,每篇文章都根据其部分内容(数据来自 Github 存储库)进行注释。
现在,我想使用这个模型来注释一篇新的、看不见的文章,但我不知道该怎么做。我感觉分类代码已经在上面的 sn-p 中实现,我非常感谢任何关于如何使用此代码对看不见的文章进行分类的帮助/指导。提前非常感谢!
【问题讨论】:
只需将subdir = "../data/validation"
更改为您的文章所在的文件夹。注意文章必须是.txt
。
另外,请参阅How to create a Minimal, Reproducible Example
【参考方案1】:
嗯,您的代码已经进行了训练和测试。您只需要一段额外的代码来加载经过训练的模型和测试数据以进行推理。
应该是这样的:
# load trained model:
model = Classifier('num_layers': 1, 'hidden_dim': 512, 'bidirectional': True, 'embedding_dim': 1024,
'dropout': 0.5, 'out_dim': len(out_map))
model.load_state_dict(torch.load("PATH/TO/SAVED/MODEL.pt")
# load data:
subdir = "PATH/TO/DOCS/"
files = os.listdir(subdir)
validate_data = []
for file in files:
if '.txt' in file:
doc = process_doc(os.path.join(subdir, file), 'VAL')
validate_data.append(doc)
print(len(train_data), len(validate_data), len(test_data))
# use the evaluate function with is_test=True for inference.
evaluate(data, is_test=True)
【讨论】:
非常感谢!只是为了澄清一下,这段代码在哪里插入到上面的 sn-p 中/如何更改上面的代码以合并您的 sn-p? 您应该将它集成到您的代码中!如何做到这一点是您的设计选择。 您可以在if __name__ == '__main__'
下添加这些额外的行。您可以使用arg_parse
从命令行中选择推理模式。
1) 如果你设置了is_test=False
,那么你应该没有这个问题。 2) 在 ***,我们很乐意帮助您解决编码问题。但是您不能要求我们为您完成编码部分。在提出what-should-I-do-here?
类型的问题之前,请花一些时间了解您的代码库。
我的错,您确实必须更改evaluate()
函数,因此它不需要.ann
文件以上是关于PyTorch:通过生成的模型运行看不见的文本的主要内容,如果未能解决你的问题,请参考以下文章