Unverified commit c02f773a, authored by Yancey, committed by GitHub

Fix dist demo var type error (#8600)

* Fix dist demo error

* revert trainer_id

Parent: decaad5c
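
In brief, the commit does two things: in the fluid framework, `Block` now passes `type=v.type` through when cloning a `Variable`, so the copy keeps the source variable's type instead of silently taking the default; and across the distributed book demos it normalizes the `t.transpile(...)` calls, reverting the positional trainer-id argument, and tidies the training loop in one demo.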
@@ -784,6 +784,7 @@ class Block(object):
         elif type(v) == Variable:
             var = Variable(
                 self,
+                type=v.type,
                 name=new_name,
                 error_clip=error_clip,
                 stop_gradient=stop_gradient)
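
Why this one-line change matters, sketched below with stand-in classes (illustrative only, not Paddle's real `Variable`/`Block` API): a clone that does not forward `type` silently resets the new variable to the default type, which is wrong for anything that was, say, a selected-rows variable.

from enum import Enum


class VarType(Enum):
    LOD_TENSOR = 0      # the default variable type
    SELECTED_ROWS = 1   # e.g. sparse gradients handled by the transpiler


class Var(object):
    # `type` mirrors the keyword used in the hunk above.
    def __init__(self, name, type=VarType.LOD_TENSOR):
        self.name = name
        self.type = type


def clone_without_type(v, new_name):
    # BUG: `type` is not forwarded, so the clone falls back to LOD_TENSOR.
    return Var(name=new_name)


def clone_with_type(v, new_name):
    # FIX: forward type=v.type, mirroring the added `type=v.type` line.
    return Var(name=new_name, type=v.type)


src = Var("w@GRAD", type=VarType.SELECTED_ROWS)
assert clone_without_type(src, "w@GRAD.clone").type == VarType.LOD_TENSOR  # wrong
assert clone_with_type(src, "w@GRAD.clone").type == VarType.SELECTED_ROWS  # right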

@@ -48,6 +48,7 @@ current_endpoint = os.getenv("SERVER_ENDPOINT")
 # run as trainer or parameter server
 training_role = os.getenv("TRAINING_ROLE",
                           "TRAINER")  # get the training role: trainer/pserver
+t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
 if training_role == "PSERVER":
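
For context on what these demo hunks plug into: each demo is one script launched once per node, with environment variables selecting the role, and the transpile call added here runs on every node before the role branch, since both trainers and pservers derive their programs from the transpiled result. A minimal sketch of that launch pattern (env var names other than TRAINING_ROLE and SERVER_ENDPOINT are assumptions, not from the diff):

import os

# One process per node; environment variables decide what this process does.
# PSERVER_ENDPOINTS is an assumed name; the diff only shows the derived
# variables pserver_endpoints / current_endpoint / training_role.
pserver_endpoints = os.getenv("PSERVER_ENDPOINTS", "127.0.0.1:6174")
current_endpoint = os.getenv("SERVER_ENDPOINT")        # this pserver's address
training_role = os.getenv("TRAINING_ROLE", "TRAINER")  # TRAINER or PSERVER

if training_role == "PSERVER":
    print("serve parameters on", current_endpoint)
else:
    print("train against pservers at", pserver_endpoints)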

@@ -65,8 +66,6 @@ else:
     PASS_NUM = 100
     for pass_id in range(PASS_NUM):
-        fluid.io.save_persistables(exe, "./fit_a_line.model/")
-        fluid.io.load_persistables(exe, "./fit_a_line.model/")
         for data in train_reader():
             avg_loss_value = exe.run(trainer_prog,
                                      feed=feeder.feed(data),
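
The removed pair wrote the persistable variables to ./fit_a_line.model/ and immediately read them back at the top of every pass; the commit drops that per-pass save/load round-trip from the demo loop (the commit message does not spell out why, but the round-trip adds nothing to what the example demonstrates).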

@@ -138,6 +138,7 @@ current_endpoint = os.getenv("SERVER_ENDPOINT")
 # run as trainer or parameter server
 training_role = os.getenv("TRAINING_ROLE",
                           "TRAINER")  # get the training role: trainer/pserver
+t.transpile(
+    optimize_ops, params_grads, pservers=pserver_endpoints, trainers=TRAINERS)

@@ -191,6 +191,7 @@ def main():
     # run as trainer or parameter server
     training_role = os.getenv(
         "TRAINING_ROLE", "TRAINER")  # get the training role: trainer/pserver
+    t.transpile(
+        optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)

@@ -82,6 +82,7 @@ current_endpoint = os.getenv("SERVER_ENDPOINT")
 # run as trainer or parameter server
 training_role = os.getenv("TRAINING_ROLE",
                           "TRAINER")  # get the training role: trainer/pserver
+t.transpile(
+    optimize_ops, params_grads, pservers=pserver_endpoints, trainers=TRAINERS)
 if training_role == "PSERVER":

@@ -97,9 +98,10 @@ elif training_role == "TRAINER":
         feed_list=[first_word, second_word, third_word, forth_word, next_word],
         place=place)
     exe.run(fluid.default_startup_program())
+    trainer_prog = t.get_trainer_program()
     for pass_id in range(PASS_NUM):
         for data in train_reader():
-            avg_cost_np = exe.run(t.get_trainer_program(),
+            avg_cost_np = exe.run(trainer_prog,
                                   feed=feeder.feed(data),
                                   fetch_list=[avg_cost])
             print("avg_cost_np", avg_cost_np)

@@ -115,6 +115,7 @@ def main():
     # run as trainer or parameter server
     training_role = os.getenv(
         "TRAINING_ROLE", "TRAINER")  # get the training role: trainer/pserver
+    t.transpile(
+        optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)

@@ -64,11 +64,7 @@ if not current_endpoint:
 t = fluid.DistributeTranspiler()
 t.transpile(
-    optimize_ops,
-    params_grads,
-    0,
-    pservers=pserver_endpoints,
-    trainers=trainers)
+    optimize_ops, params_grads, pservers=pserver_endpoints, trainers=trainers)
 if training_role == "PSERVER":
     pserver_prog = t.get_pserver_program(current_endpoint)
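
This hunk is the clearest instance of the "revert trainer_id" note in the commit message: the bare positional `0` (presumably the trainer-id argument) is dropped from the argument list, and the remaining keyword arguments are collapsed onto a single line, matching the call shape added in the other demos above.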

@@ -171,6 +171,7 @@ def main():
     current_endpoint = os.getenv("SERVER_ENDPOINT")
     # run as trainer or parameter server
     training_role = os.getenv("TRAINING_ROLE", "TRAINER")
+    t.transpile(
+        optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)

@@ -90,6 +90,7 @@ def main():
     # run as trainer or parameter server
     training_role = os.getenv(
         "TRAINING_ROLE", "TRAINER")  # get the training role: trainer/pserver
+    t.transpile(
+        optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)

@@ -102,6 +102,7 @@ def main():
     # run as trainer or parameter server
     training_role = os.getenv(
         "TRAINING_ROLE", "TRAINER")  # get the training role: trainer/pserver
+    t.transpile(
+        optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)