import tensorflow as tf FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string("job_name", " ", "启动服务的类型ps or worker") tf.app.flags.DEFINE_integer("task_index", 0, "指定ps或者worker当中的哪一台服务器以task:0,task:1") def main(argv): # 定义一个全局计数的op,给钩子列表中的训练步数使用 global_step = tf.contrib.framework.get_or_create_global_step() # 指定集群描述对象,ps worker,多台worker或者ps的定位规则,第一台:/job:worker/task:0,第二台:/job:worker/task:1,ps也是如此 cluster = tf.train.ClusterSpec({"ps":["192.168.0.4:2222",], "worker":["192.168.109.128:2323",]}) # 创建不同的服务 ps worker,job_name指定是ps还是worker,task_index,指定启动哪台服务器 server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) # 根据不同的服务器做不同的事情,ps保存参数,worker指定设备运行模型计算 if FLAGS.job_name == ‘ps‘: # 参数服务器只需接受参数 server.join() else: worker_device = "/job:worker/task:0/cpu:0/" # 指定设备去运行 with tf.device(tf.train.replica_device_setter(worker_device=worker_device, cluster=cluster)): # 演示一个矩阵乘法运算 x = tf.Variable([[1, 2, 3, 4]]) w = tf.Variable([[2], [4], [5], [7]]) mat = tf.matmul(x, w) # 创建分布式会话 with tf.train.MonitoredTrainingSession( master="grpc://192.168.0.1:2222", # 指定是否是主work is_chief=(FLAGS.task_index==0), # 判断书否是主worker cOnfig=tf.ConfigProto(log_device_placement =True), # 打印设备信息 hooks=[tf.train.StopAtStepHook(last_step=1000)] # 指定训练步数,指定步数需要定义一个全局计数的op ) as mon_sess: while not mon_sess.should_stop(): # should_stops是否异常停止 mon_sess.run(mat) if __name__ == "__main__": tf.app.run()
第十五节 分布式系统