# als_tf
# Posted by kayy
# tags:
# 篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了als_tf相关的知识,希望对你有一定的参考价值。
# -*- coding: utf-8 -*-
import tensorflow as tf
def als_tf(user_item_matrix, feature_num, train_times, reg,
           learning_rate_init=0.03, decay_steps=10, decay_rate=0.99,
           log_every=10):
    """Factorize a user-item matrix into low-rank user and item features.

    Despite the name, this is not alternating least squares: both factor
    matrices are updated jointly by Adagrad on every step. The objective is
    the squared reconstruction error summed over *every* matrix entry
    (zeros included) plus an L2 penalty on both factor matrices.

    Args:
        user_item_matrix: 2-D array-like with a ``.shape`` of
            ``(user_num, item_num)``; it is fed into a float32 placeholder.
        feature_num: Number of latent features per user / item.
        train_times: Number of optimisation steps to run.
        reg: L2 regularisation weight applied to both factor matrices.
        learning_rate_init: Initial Adagrad learning rate.
        decay_steps: Steps between staircase learning-rate decays.
        decay_rate: Multiplicative factor applied at each decay.
        log_every: Print the loss every ``log_every`` steps; 0 disables it.

    Returns:
        A ``(user_feature, item_feature)`` tuple of NumPy arrays with shapes
        ``(user_num, feature_num)`` and ``(item_num, feature_num)``; their
        product ``user_feature @ item_feature.T`` approximates the input.
    """
    user_num, item_num = user_item_matrix.shape

    # Build everything in a private graph so repeated calls do not pile
    # variables/ops into the process-wide default graph (which would also
    # make the initializer re-initialize variables from earlier calls).
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, [user_num, item_num])
        user_feature = tf.Variable(
            tf.random_normal([user_num, feature_num], 0, 0.1),
            name='user_feature')
        item_feature = tf.Variable(
            tf.random_normal([item_num, feature_num], 0, 0.1),
            name='item_feature')

        reconstruction = tf.matmul(user_feature, item_feature,
                                   transpose_b=True)
        # float() accepts Python and NumPy scalars alike before the value
        # meets a float32 tensor.
        loss = (tf.reduce_sum((reconstruction - x) ** 2)
                + float(reg) * (tf.reduce_sum(user_feature ** 2)
                                + tf.reduce_sum(item_feature ** 2)))

        global_step = tf.Variable(0, trainable=False, name='global_step')
        learning_rate = tf.train.exponential_decay(
            learning_rate_init, global_step, decay_steps, decay_rate,
            staircase=True)
        train_step = tf.train.AdagradOptimizer(learning_rate).minimize(
            loss, global_step=global_step)
        # Replaces the deprecated tf.initialize_all_variables().
        init_op = tf.global_variables_initializer()

    # The input never changes between steps, so build the feed once.
    train_feed_dict = {x: user_item_matrix}
    with tf.Session(graph=graph) as sess:
        sess.run(init_op)
        for i in range(train_times):
            sess.run(train_step, feed_dict=train_feed_dict)
            if log_every and i % log_every == 0:
                print(i, "train_times, loss: ",
                      sess.run(loss, feed_dict=train_feed_dict))
        return tuple(sess.run([user_feature, item_feature]))
import numpy as np

if __name__ == '__main__':
    # Demo: factorize a random 30x50 binary interaction matrix into 5 latent
    # features over 100 steps with L2 weight 0.5.
    # np.random.randint's upper bound is exclusive: (0, 2) yields 0/1 entries,
    # whereas (0, 1) would yield an all-zero matrix with nothing to learn.
    x = np.random.randint(0, 2, [30, 50])
    als_tf(x, 5, 100, 0.5)
# 以上是关于als_tf的主要内容,如果未能解决你的问题,请参考以下文章