tensorflow加载embedding模型进行可视化
Posted aijianiula
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow加载embedding模型进行可视化相关的知识,希望对你有一定的参考价值。
1.功能
采用python的gensim模块训练的word2vec模型,然后采用tensorflow读取模型可视化embedding向量
ps:采用C++版本训练的w2v模型,python的gensim模块读不了。
2.python训练word2vec模型代码
import multiprocessing from gensim.models.word2vec import Word2Vec, LineSentence print(‘开始训练‘) train_file = "/tmp/train_data" model = Word2Vec(LineSentence(train_file), size=128, workers=multiprocessing.cpu_count(), iter=10) print(‘结束‘) model.init_sims(replace=True) model.save(‘/tmp/emb.bin‘)
3.tensorflow读取模型可视化
import numpy as np import tensorflow as tf import os from gensim.models.word2vec import Word2Vec from tensorflow.contrib.tensorboard.plugins import projector log_dir = ‘/tmp/embedding_log‘ if not os.path.exists(log_dir): os.mkdir(log_dir) # load model model_file = ‘/tmp/emb.bin‘ word2vec = Word2Vec.load(model_file) # create a list of vectors embedding = np.empty((len(word2vec.vocab.keys()), word2vec.vector_size), dtype=np.float32) for i, word in enumerate(word2vec.vocab.keys()): embedding[i] = word2vec[word] # setup a TensorFlow session tf.reset_default_graph() sess = tf.InteractiveSession() X = tf.Variable([0.0], name=‘embedding‘) place = tf.placeholder(tf.float32, shape=embedding.shape) set_x = tf.assign(X, place, validate_shape=False) sess.run(tf.global_variables_initializer()) sess.run(set_x, feed_dict={place: embedding}) # write labels with open(os.path.join(log_dir, ‘metadata.tsv‘), ‘w‘) as f: for word in word2vec.vocab.keys(): f.write(word + ‘ ‘) # create a TensorFlow summary writer summary_writer = tf.summary.FileWriter(log_dir, sess.graph) config = projector.ProjectorConfig() embedding_conf = config.embeddings.add() embedding_conf.tensor_name = ‘embedding:0‘ embedding_conf.metadata_path = os.path.join(log_dir, ‘metadata.tsv‘) projector.visualize_embeddings(summary_writer, config) # save the model saver = tf.train.Saver() saver.save(sess, os.path.join(log_dir, "model.ckpt")) print("完成!")
以上是关于tensorflow加载embedding模型进行可视化的主要内容,如果未能解决你的问题,请参考以下文章
6.3 tensorflow2实现FNN推荐系统——Python实战
Tensorflow - tensorflow.models.embeddings 中没有名为“embeddings”的模块