tensorflow-分布式
Posted wyply115
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow-分布式相关的知识,希望对你有一定的参考价值。
概述
分布式Tensorflow是由高性能的gRPC框架作为底层技术来支持的。这是一个通信框架gRPC(google remote procedure call),是一个高性能、跨平台的RPC框架。RPC协议,即远程过程调用协议,是指通过网络从远程计算机程序上请求服务。
- 分布式架构
注:参数作业所在的服务器称为参数服务器(parameter server),负责管理
参数的存储和更新;工作节点的服务器主要从事计算的任务,如运行操作,
worker节点中需要一个主节点来进行会话初始化,创建文件等操作,其他节点等
待进行计算。
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("job_name", " ", "启动服务的类型ps or worker")
tf.app.flags.DEFINE_integer("task_index", 0, "指定ps或者worker当中的那一台服务器以task:0 ,task:1")
def main(argv):
# 定义全集计数的op ,给钩子列表当中的训练步数使用
global_step = tf.contrib.framework.get_or_create_global_step()
# 指定集群描述对象, ps , worker,端口指定为没有占用的即可,ip写实际服务器的ip地址
cluster = tf.train.ClusterSpec("ps": ["ip:2223"], "worker": ["ip:2222"])
# 创建不同的服务, ps, worker
server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
# 根据不同服务做不同的事情 ps:去更新保存参数 worker:指定设备去运行模型计算
if FLAGS.job_name == "ps":
# 参数服务器什么都不用干,只需要等待worker传递参数
server.join()
else:
worker_device = "/job:worker/task:0/cpu:0/"
# 可以指定设备去运行
with tf.device(tf.train.replica_device_setter(
worker_device=worker_device,
cluster=cluster
)):
# 简单做一个矩阵乘法运算
x = tf.Variable([[1, 2, 3, 4]])
w = tf.Variable([[2], [2], [2], [2]])
mat = tf.matmul(x, w)
# 创建分布式会话
with tf.train.MonitoredTrainingSession(
master= "grpc://ip:2222", # 指定主worker
is_chief= (FLAGS.task_index == 0),# 判断是否是主worker
config=tf.ConfigProto(log_device_placement=True),# 打印设备信息
hooks=[tf.train.StopAtStepHook(last_step=200)] # 最大步数
) as mon_sess:
while not mon_sess.should_stop():# 当没有异常时打印运算信息
print(mon_sess.run(mat))
if __name__ == "__main__":
tf.app.run()
代码复制一份放入ps服务器,并启动。然后在worker服务器也启动代码即可。
矩阵运算的结果是20,结果输出200个[20]。
以上是关于tensorflow-分布式的主要内容,如果未能解决你的问题,请参考以下文章
TensorFlow——分布式的TensorFlow运行环境