How to do sparse input text classification (DNN) using TensorFlow
Posted by 游园惊梦 (https://github.com/chengh
You can get the complete example code from
https://github.com/chenghuige/hasky/tree/master/applications
It covers:
- How to parse a libsvm dataset file into tfrecords
- Reading tfrecords and doing DNN/logistic regression classification/regression
- Training + evaluation
- Viewing the training process (loss and metric tracking) in TensorBoard
- How to use melt.train_flow to handle everything else (optimizer, learning rate, model saving, logging …)
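The first step, parsing a libsvm line into a label plus sparse (index, value) pairs, can be sketched in plain Python. This is a minimal illustration, not melt's libsvm_decode: `parse_libsvm_line` is a hypothetical helper, it assumes the usual `<label> <index>:<value> ...` line layout, and it stops before writing anything to tfrecords.

```python
def parse_libsvm_line(line):
    """Parse one libsvm-format line: '<label> <index>:<value> <index>:<value> ...'.

    Hypothetical helper for illustration; returns (label, indices, values).
    """
    parts = line.split()
    label = float(parts[0])
    indices, values = [], []
    for token in parts[1:]:
        # Each feature is written as index:value
        index, value = token.split(":", 1)
        indices.append(int(index))
        values.append(float(value))
    return label, indices, values


label, indices, values = parse_libsvm_line("1 3:0.5 10:1.0 42:2.0")
print(label, indices, values)  # 1.0 [3, 10, 42] [0.5, 1.0, 2.0]
```

Only the nonzero features are stored, which is what makes the input sparse.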
The main related code:
- melt.tfrecords/libsvm_decode: parses the libsvm file
- melt.models.mlp: the DNN model, with this forward signature:
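```python
# TensorFlow 1.x imports needed by the default arguments
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.python.ops import init_ops

def forward(inputs,
            num_outputs,
            input_dim=None,
            hiddens=[200],
            activation_fn=tf.nn.relu,
            weights_initializer=initializers.xavier_initializer(),
            weights_regularizer=None,
            biases_initializer=init_ops.zeros_initializer(),
            biases_regularizer=None,
            reuse=None,
            scope=None):
    ...  # body omitted here; see melt.models.mlp in the repository
```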
text-classfication/model.py shows how to use it.
For a sparse input dataset you must specify both num_outputs and input_dim:
- For a 10-class classification problem, num_outputs=10.
- For regression, num_outputs=1.
- input_dim must equal the number of input features in your dataset.
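One way to pick these two values is to scan the dataset once. The sketch below is a hypothetical helper (`infer_dims`), not part of melt. It assumes libsvm-style lines, keeps feature indices as they appear in the file (libsvm indices often start at 1, so input_dim is taken as the largest index plus one), and assumes every class label occurs at least once in the data.

```python
def infer_dims(lines, regression=False):
    """Infer (input_dim, num_outputs) from libsvm-format lines.

    Hypothetical helper: input_dim = largest feature index + 1, so every
    index in the file fits; num_outputs = number of distinct labels
    (assumes each class appears at least once), or 1 for regression.
    """
    max_index = -1
    labels = set()
    for line in lines:
        parts = line.split()
        if not parts:
            continue  # skip blank lines
        labels.add(float(parts[0]))
        for token in parts[1:]:
            max_index = max(max_index, int(token.split(":", 1)[0]))
    input_dim = max_index + 1
    # Regression predicts one value; classification needs one output per class
    num_outputs = 1 if regression else len(labels)
    return input_dim, num_outputs


lines = ["0 1:0.5 7:1.0", "2 3:1.0", "1 9:2.0"]
print(infer_dims(lines))                   # (10, 3)
print(infer_dims(lines, regression=True))  # (10, 1)
```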
You may change hiddens. The default, [200], means a single hidden layer of size 200.
You can use more layers: [200, 100, 100] means 3 hidden layers of sizes 200, 100 and 100.
You can also set hiddens to an empty list [], which means you only do logistic regression.
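Assuming the model is a plain stack of fully connected layers (input, then each entry of hiddens, then the output layer), the hiddens setting translates into weight-matrix shapes as sketched below. `layer_shapes` is a hypothetical helper for illustration only; with hiddens=[] a single input_dim x num_outputs layer remains, which is why the empty list amounts to logistic regression.

```python
def layer_shapes(input_dim, num_outputs, hiddens):
    """Weight-matrix shapes of a fully connected stack:
    input_dim -> hiddens[0] -> ... -> hiddens[-1] -> num_outputs.
    Hypothetical helper, not part of melt."""
    sizes = [input_dim] + list(hiddens) + [num_outputs]
    return [(sizes[i], sizes[i + 1]) for i in range(len(sizes) - 1)]


print(layer_shapes(10000, 10, [200]))            # [(10000, 200), (200, 10)]
print(layer_shapes(10000, 10, [200, 100, 100]))  # [(10000, 200), (200, 100), (100, 100), (100, 10)]
print(layer_shapes(10000, 10, []))               # [(10000, 10)]
```

Only the first matrix is multiplied with the sparse input, and it is also the largest one (input_dim rows), which is where sparse handling matters.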
What's the difference between melt.layers.fully_connected and tf.contrib.layers.fully_connected?
They are similar, but melt's version also handles sparse input. The main difference is here:
we use melt.matmul.
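```python
import tensorflow as tf

def matmul(X, w):
    if isinstance(X, tf.Tensor):
        # Dense input: an ordinary matrix multiply
        return tf.matmul(X, w)
    else:
        # Sparse input: X[0] holds the feature indices, X[1] the feature values
        # (both SparseTensors); summing the value-weighted rows of w gives X * w
        return tf.nn.embedding_lookup_sparse(w, X[0], X[1], combiner='sum')
```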
Source: <https://github.com/chenghuige/tensorflow-example/blob/master/util/melt/ops/ops.py>
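To see why the sparse branch gives the same result as a dense matrix multiply, here is a pure-Python sketch for a single example. It mimics what `combiner='sum'` computes (look up the row of w for each feature index, scale it by the feature value, add the rows up) without using TensorFlow; `dense_matmul` and `sparse_matmul` are hypothetical names used only for this illustration.

```python
def dense_matmul(x, w):
    """Row vector x (length input_dim) times matrix w (input_dim rows, out_dim columns)."""
    out_dim = len(w[0])
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(out_dim)]


def sparse_matmul(indices, values, w):
    """Same product from the nonzero features only: for each (index, value),
    take row w[index], scale it by value, and sum the scaled rows."""
    out_dim = len(w[0])
    result = [0.0] * out_dim
    for index, value in zip(indices, values):
        row = w[index]
        for j in range(out_dim):
            result[j] += value * row[j]
    return result


w = [[1, 2], [3, 4], [5, 6], [7, 8]]  # input_dim=4, out_dim=2
x = [0.0, 2.0, 0.0, 1.0]              # dense input with two nonzero features
print(dense_matmul(x, w))                    # [13.0, 16.0]
print(sparse_matmul([1, 3], [2.0, 1.0], w))  # [13.0, 16.0]
```

The sparse version touches only the rows of w for nonzero features, so its cost depends on the number of nonzero features rather than on input_dim.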
[Figure: TensorBoard view of the training process (loss and metric tracking)]