BiLSTM+CRF(Keras)
Posted by cyandn
tags:
篇首语：本文由小常识网（cha138.com）小编为大家整理，主要介绍了 BiLSTM+CRF（Keras）相关的知识，希望对你有一定的参考价值。
数据集为玻森命名实体数据。
目前代码流程已经跑通，后续再进行优化。
项目地址:https://github.com/cyandn/practice/tree/master/NER
步骤:
数据预处理:
def _tag_line(line, sentence_end):
    """Convert one raw corpus line into ``char/TAG`` text.

    Entities are marked in the raw text as ``{{type:entity}}``. Every entity
    character is tagged ``B_type`` (first) or ``I_type`` (rest); every other
    character is tagged ``O``. A newline is emitted after each character found
    in ``sentence_end`` and at the end of the line if tokens are still pending.
    """
    # Drop every whitespace character, then the literal two-character
    # sequence backslash + "n" that may be embedded in the text.
    line = ''.join(line.split()).replace('\\n', '')
    pieces = []
    pending = False  # True while the current output line holds tokens
    i = 0
    while i < len(line):
        char = line[i]
        if char in sentence_end:
            pieces.append(char + '/O\n')
            pending = False
            i += 1
            continue
        if line.startswith('{{', i):
            close = line.find('}}', i + 2)
            if close != -1:
                # Split on the FIRST colon only, so entities that themselves
                # contain a colon (e.g. a time "12:30") are kept whole.
                etype, sep, entity = line[i + 2:close].partition(':')
                if sep and etype and entity:
                    pieces.append(entity[0] + '/B_' + etype + ' ')
                    pieces.extend(
                        item + '/I_' + etype + ' ' for item in entity[1:])
                    pending = True
                    i = close + 2
                    continue
            # Unterminated or malformed marker: fall through and tag the
            # brace as an ordinary character instead of raising IndexError.
        pieces.append(char + '/O ')
        pending = True
        i += 1
    if pending:
        # Keep text from different source lines in separate sentences.
        pieces.append('\n')
    return ''.join(pieces)


def data_process(src_path='data/BosonNLP_NER_6C.txt',
                 dst_path='data/BosonNLP_NER_6C_process.txt'):
    """Rewrite the raw entity-annotated corpus as one tagged sentence per line.

    Args:
        src_path: UTF-8 text file with ``{{type:entity}}`` annotations.
        dst_path: output file; each line holds whitespace-separated
            ``char/TAG`` tokens.
    """
    # NOTE(review): the listing this was restored from shows ASCII , ? ; !
    # which look like full-width marks normalised by the hosting page; both
    # forms are accepted here -- confirm against the original repository.
    # The original list also held the two-character string '……', which can
    # never equal a single character, so it is omitted.
    sentence_end = frozenset(',?;!' + '，。？；！')
    with open(dst_path, 'w', encoding='utf8') as fw:
        with open(src_path, encoding='utf8') as fr:
            for line in fr:
                fw.write(_tag_line(line, sentence_end))
加载数据:
def load_data(self):
    """Load the tagged corpus, vectorise it and split it into train/val/test.

    Reads ``data/BosonNLP_NER_6C_process.txt``, where each line is a sequence
    of whitespace-separated ``char/LABEL`` tokens. Samples and labels are
    appended to ``self.total_sample`` / ``self.total_label`` (which must
    already exist as lists), converted to integer ids, padded to the longest
    sequence and split. Results are stored in ``self.X_train``,
    ``self.y_train``, ``self.X_val``, ``self.y_val``, ``self.X_test``,
    ``self.y_test`` and ``self.vocabulary``.
    """
    maxlen = 0
    with open('data/BosonNLP_NER_6C_process.txt', encoding='utf8') as f:
        for line in f:
            word_list = line.strip().split()
            if not word_list:
                # A blank line would make the zip(*...) unpacking below fail.
                continue
            # rsplit on the LAST slash so a "/" character token still parses.
            one_sample, one_label = zip(
                *[word.rsplit('/', 1) for word in word_list])
            maxlen = max(maxlen, len(one_sample))
            self.total_sample.append(' '.join(one_sample))
            self.total_label.append(
                [config.classes[label] for label in one_label])

    # filters='' is required: the default filter set silently removes ASCII
    # punctuation tokens, which would make the id sequence shorter than the
    # label sequence and shift every following label onto the wrong token.
    tok = Tokenizer(filters='')
    tok.fit_on_texts(self.total_sample)
    # +1 because Tokenizer ids start at 1; id 0 is left for padding.
    self.vocabulary = len(tok.word_index) + 1
    self.total_sample = tok.texts_to_sequences(self.total_sample)
    self.total_sample = np.array(pad_sequences(
        self.total_sample, maxlen=maxlen, padding='post', truncating='post'))
    # Labels are padded with 0 as well, then given a trailing axis of size 1.
    # NOTE(review): assumes label id 0 in config.classes is acceptable as the
    # padding label -- confirm against config.
    self.total_label = np.array(pad_sequences(
        self.total_label, maxlen=maxlen, padding='post',
        truncating='post'))[:, :, None]

    print('total_sample shape:', self.total_sample.shape)
    print('total_label shape:', self.total_label.shape)

    # First carve off the test set, then split the remainder into train/val.
    X_train, self.X_test, y_train, self.y_test = train_test_split(
        self.total_sample, self.total_label,
        test_size=config.proportion['test'], random_state=666)
    self.X_train, self.X_val, self.y_train, self.y_val = train_test_split(
        X_train, y_train,
        test_size=config.proportion['val'], random_state=666)

    print('X_train shape:', self.X_train.shape)
    print('y_train shape:', self.y_train.shape)
    print('X_val shape:', self.X_val.shape)
    print('y_val shape:', self.y_val.shape)
    print('X_test shape:', self.X_test.shape)
    print('y_test shape:', self.y_test.shape)

    # The full arrays are no longer needed once the splits exist.
    del self.total_sample
    del self.total_label
构建模型:
def build_model(self):
    """Build and compile the Embedding -> BiLSTM -> CRF tagger.

    Uses ``self.vocabulary`` as the embedding input size and
    ``len(config.classes)`` as the number of CRF output labels. The compiled
    model is stored in ``self.model``.
    """
    model = Sequential()
    # mask_zero=True makes downstream layers ignore positions with id 0.
    model.add(Embedding(self.vocabulary, 100, mask_zero=True))
    model.add(Bidirectional(LSTM(64, return_sequences=True)))
    # sparse_target=True: targets are integer label ids, not one-hot vectors.
    model.add(CRF(len(config.classes), sparse_target=True))
    model.summary()

    opt = Adam(lr=config.hyperparameter['learning_rate'])
    model.compile(opt, loss=crf_loss, metrics=[crf_viterbi_accuracy])
    self.model = model
训练:
def train(self):
    """Fit ``self.model`` on the training split, checkpointing the best epochs.

    Checkpoints are written to ``<cwd>/saved_models``; TensorBoard logging and
    learning-rate reduction on plateau are enabled. Batch size and epoch count
    come from ``config.hyperparameter``.
    """
    save_dir = os.path.join(os.getcwd(), 'saved_models')
    # The placeholders are filled in by ModelCheckpoint for every save, giving
    # names such as "007_0.9512.h5"; without the braces every epoch would
    # write to the same literal file name.
    model_name = '{epoch:03d}_{val_crf_viterbi_accuracy:.4f}.h5'
    os.makedirs(save_dir, exist_ok=True)

    tensorboard = TensorBoard()
    checkpoint = ModelCheckpoint(os.path.join(save_dir, model_name),
                                 monitor='val_crf_viterbi_accuracy',
                                 save_best_only=True)
    lr_reduce = ReduceLROnPlateau(
        monitor='val_crf_viterbi_accuracy', factor=0.2, patience=10)

    self.model.fit(self.X_train, self.y_train,
                   batch_size=config.hyperparameter['batch_size'],
                   epochs=config.hyperparameter['epochs'],
                   callbacks=[tensorboard, checkpoint, lr_reduce],
                   # A tuple is the documented form for validation_data.
                   validation_data=(self.X_val, self.y_val))
评估（在测试集上）：
def evaluate(self):
    """Load the latest checkpoint from ``saved_models`` and score the test set.

    The loaded model is stored in ``self.best_model``; test loss and accuracy
    are printed.

    Raises:
        FileNotFoundError: if ``saved_models`` contains no ``.h5`` checkpoint.
    """
    # Ignore stray files (logs, hidden files) that would otherwise be picked.
    checkpoints = [name for name in os.listdir('saved_models')
                   if name.endswith('.h5')]
    if not checkpoints:
        raise FileNotFoundError('no .h5 checkpoint found in saved_models')
    # Assumes checkpoint names begin with a zero-padded epoch number, so the
    # lexicographically greatest name is the most recently saved model.
    best_model_name = max(checkpoints)

    # The CRF layer, loss and metric are custom objects and must be supplied
    # explicitly for deserialisation.
    self.best_model = load_model(
        os.path.join('saved_models', best_model_name),
        custom_objects={
            'CRF': CRF,
            'crf_loss': crf_loss,
            'crf_viterbi_accuracy': crf_viterbi_accuracy,
        })

    scores = self.best_model.evaluate(self.X_test, self.y_test)
    print('test loss:', scores[0])
    print('test accuracy:', scores[1])
参考:
https://zhuanlan.zhihu.com/p/44042528
https://blog.csdn.net/buppt/article/details/81180361
https://github.com/stephen-v/zh-NER-keras
http://www.voidcn.com/article/p-pykfinyn-bro.html
以上是关于 BiLSTM+CRF（Keras）的主要内容，如果未能解决你的问题，请参考以下文章：