理解 TensorFlow 分布式计算的最简单例子

时间:2022-05-20 14:01:51

目的是让每个 worker 反复对保存在参数服务器(ps)上的共享变量 count 执行加 1 操作。

import tensorflow as tf

# Command-line flags describing this process's place in the cluster.
flags = tf.app.flags
flags.DEFINE_string('job_name', '', '')      # role of this process: 'ps' or 'worker'
flags.DEFINE_string('ps_hosts', '', '')      # comma-separated parameter-server addresses
flags.DEFINE_string('worker_hosts', '', '')  # comma-separated worker addresses
flags.DEFINE_integer('task_index', 0, '')    # index of this task within its job
FLAGS = flags.FLAGS

# Each host flag is a comma-separated list of addresses; split them into
# the per-job task lists that make up the cluster.
ps_hosts = FLAGS.ps_hosts.split(',')
worker_hosts = FLAGS.worker_hosts.split(',')

# Build the job -> addresses mapping once and share it between the cluster
# spec (used for device placement) and the in-process server.
cluster_layout = {'ps': ps_hosts, 'worker': worker_hosts}
cluster_spec = tf.train.ClusterSpec(cluster_layout)
server = tf.train.Server(
    cluster_layout,
    job_name=FLAGS.job_name,
    task_index=FLAGS.task_index,
)

# A parameter-server process only serves requests: join() blocks, so the
# graph-building and training code below is reached by workers only.
if FLAGS.job_name == 'ps':
    server.join()

# Build the graph under a replica device setter configured with this worker
# task as the worker device and the cluster spec defined above.
this_worker = "/job:worker/task:%d" % FLAGS.task_index
placement = tf.train.replica_device_setter(
    worker_device=this_worker,
    cluster=cluster_spec,
)
with tf.device(placement):
    # Integer counter starting at 0, the op that adds 1 to it, and the op
    # that initialises all global variables.
    count = tf.Variable(0)
    increment_count = count.assign_add(1)
    init = tf.global_variables_initializer()

# The worker with task index 0 is the chief; it is handed `init` as the
# Supervisor's init_op. Saver, summary op and global step are all disabled
# by passing None.
is_chief = FLAGS.task_index == 0
sv = tf.train.Supervisor(
    logdir="./checkpoint/",
    is_chief=is_chief,
    init_op=init,
    saver=None,
    summary_op=None,
    global_step=None,
    save_model_secs=60,
)

# Run 200000 increments of the shared counter from this worker.
#
# managed_session() returns a session whose variables are already
# initialised: the chief runs the Supervisor's init_op and every other worker
# waits for it. The initializer must therefore NOT be run again here --
# doing so from a worker that joins later would reset the shared `count`
# back to 0 while other workers are still incrementing it.
with sv.managed_session(server.target) as sess:
    for step in range(1, 200001):
        # Leave early if the Supervisor has been asked to stop (e.g. after
        # an error in one of its service threads).
        if sv.should_stop():
            break
        # `result` is the counter value after this worker's increment; with
        # several workers running it advances faster than `step`.
        result = sess.run(increment_count)
        if step % 10000 == 0:
            print(result)
        if result == 2:
            print("!!!!!!!!")
    print("Finished!")

# Explicit stop kept from the original script; harmless if the Supervisor
# has already been stopped.
sv.stop()

运行方法(共需启动三个进程:一个 ps 和两个 worker)
在 10.100.208.71 上执行(启动 worker 1):

python test_dis_hw.py --job_name=worker --ps_hosts=10.100.208.67:1111 --worker_hosts=10.100.208.67:2222,10.100.208.71:2223 --task_index=1

在 10.100.208.67 上执行(启动 worker 0,即 chief):

python test_dis_hw.py --job_name=worker --ps_hosts=10.100.208.67:1111 --worker_hosts=10.100.208.67:2222,10.100.208.71:2223 --task_index=0

在 10.100.208.67 上执行(启动参数服务器 ps):

python test_dis_hw.py --job_name=ps --ps_hosts=10.100.208.67:1111 --worker_hosts=10.100.208.67:2222,10.100.208.71:2223 --task_index=0