未验证 提交 c02f773a 编写于 作者: Y Yancey 提交者: GitHub

Fix dist demo var type error (#8600)

* Fix dist demo error

* revert trainer_id
上级 decaad5c
...@@ -784,6 +784,7 @@ class Block(object): ...@@ -784,6 +784,7 @@ class Block(object):
elif type(v) == Variable: elif type(v) == Variable:
var = Variable( var = Variable(
self, self,
type=v.type,
name=new_name, name=new_name,
error_clip=error_clip, error_clip=error_clip,
stop_gradient=stop_gradient) stop_gradient=stop_gradient)
......
...@@ -48,6 +48,7 @@ current_endpoint = os.getenv("SERVER_ENDPOINT") ...@@ -48,6 +48,7 @@ current_endpoint = os.getenv("SERVER_ENDPOINT")
# run as trainer or parameter server # run as trainer or parameter server
training_role = os.getenv("TRAINING_ROLE", training_role = os.getenv("TRAINING_ROLE",
"TRAINER") # get the training role: trainer/pserver "TRAINER") # get the training role: trainer/pserver
t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2) t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
if training_role == "PSERVER": if training_role == "PSERVER":
...@@ -65,8 +66,6 @@ else: ...@@ -65,8 +66,6 @@ else:
PASS_NUM = 100 PASS_NUM = 100
for pass_id in range(PASS_NUM): for pass_id in range(PASS_NUM):
fluid.io.save_persistables(exe, "./fit_a_line.model/")
fluid.io.load_persistables(exe, "./fit_a_line.model/")
for data in train_reader(): for data in train_reader():
avg_loss_value = exe.run(trainer_prog, avg_loss_value = exe.run(trainer_prog,
feed=feeder.feed(data), feed=feeder.feed(data),
......
...@@ -138,6 +138,7 @@ current_endpoint = os.getenv("SERVER_ENDPOINT") ...@@ -138,6 +138,7 @@ current_endpoint = os.getenv("SERVER_ENDPOINT")
# run as trainer or parameter server # run as trainer or parameter server
training_role = os.getenv("TRAINING_ROLE", training_role = os.getenv("TRAINING_ROLE",
"TRAINER") # get the training role: trainer/pserver "TRAINER") # get the training role: trainer/pserver
t.transpile( t.transpile(
optimize_ops, params_grads, pservers=pserver_endpoints, trainers=TRAINERS) optimize_ops, params_grads, pservers=pserver_endpoints, trainers=TRAINERS)
......
...@@ -191,6 +191,7 @@ def main(): ...@@ -191,6 +191,7 @@ def main():
# run as trainer or parameter server # run as trainer or parameter server
training_role = os.getenv( training_role = os.getenv(
"TRAINING_ROLE", "TRAINER") # get the training role: trainer/pserver "TRAINING_ROLE", "TRAINER") # get the training role: trainer/pserver
t.transpile( t.transpile(
optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2) optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
......
...@@ -82,6 +82,7 @@ current_endpoint = os.getenv("SERVER_ENDPOINT") ...@@ -82,6 +82,7 @@ current_endpoint = os.getenv("SERVER_ENDPOINT")
# run as trainer or parameter server # run as trainer or parameter server
training_role = os.getenv("TRAINING_ROLE", training_role = os.getenv("TRAINING_ROLE",
"TRAINER") # get the training role: trainer/pserver "TRAINER") # get the training role: trainer/pserver
t.transpile( t.transpile(
optimize_ops, params_grads, pservers=pserver_endpoints, trainers=TRAINERS) optimize_ops, params_grads, pservers=pserver_endpoints, trainers=TRAINERS)
if training_role == "PSERVER": if training_role == "PSERVER":
...@@ -97,9 +98,10 @@ elif training_role == "TRAINER": ...@@ -97,9 +98,10 @@ elif training_role == "TRAINER":
feed_list=[first_word, second_word, third_word, forth_word, next_word], feed_list=[first_word, second_word, third_word, forth_word, next_word],
place=place) place=place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
trainer_prog = t.get_trainer_program()
for pass_id in range(PASS_NUM): for pass_id in range(PASS_NUM):
for data in train_reader(): for data in train_reader():
avg_cost_np = exe.run(t.get_trainer_program(), avg_cost_np = exe.run(trainer_prog,
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=[avg_cost]) fetch_list=[avg_cost])
print("avg_cost_np", avg_cost_np) print("avg_cost_np", avg_cost_np)
......
...@@ -115,6 +115,7 @@ def main(): ...@@ -115,6 +115,7 @@ def main():
# run as trainer or parameter server # run as trainer or parameter server
training_role = os.getenv( training_role = os.getenv(
"TRAINING_ROLE", "TRAINER") # get the training role: trainer/pserver "TRAINING_ROLE", "TRAINER") # get the training role: trainer/pserver
t.transpile( t.transpile(
optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2) optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
......
...@@ -64,11 +64,7 @@ if not current_endpoint: ...@@ -64,11 +64,7 @@ if not current_endpoint:
t = fluid.DistributeTranspiler() t = fluid.DistributeTranspiler()
t.transpile( t.transpile(
optimize_ops, optimize_ops, params_grads, pservers=pserver_endpoints, trainers=trainers)
params_grads,
0,
pservers=pserver_endpoints,
trainers=trainers)
if training_role == "PSERVER": if training_role == "PSERVER":
pserver_prog = t.get_pserver_program(current_endpoint) pserver_prog = t.get_pserver_program(current_endpoint)
......
...@@ -171,6 +171,7 @@ def main(): ...@@ -171,6 +171,7 @@ def main():
current_endpoint = os.getenv("SERVER_ENDPOINT") current_endpoint = os.getenv("SERVER_ENDPOINT")
# run as trainer or parameter server # run as trainer or parameter server
training_role = os.getenv("TRAINING_ROLE", "TRAINER") training_role = os.getenv("TRAINING_ROLE", "TRAINER")
t.transpile( t.transpile(
optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2) optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
......
...@@ -90,6 +90,7 @@ def main(): ...@@ -90,6 +90,7 @@ def main():
# run as trainer or parameter server # run as trainer or parameter server
training_role = os.getenv( training_role = os.getenv(
"TRAINING_ROLE", "TRAINER") # get the training role: trainer/pserver "TRAINING_ROLE", "TRAINER") # get the training role: trainer/pserver
t.transpile( t.transpile(
optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2) optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
......
...@@ -102,6 +102,7 @@ def main(): ...@@ -102,6 +102,7 @@ def main():
# run as trainer or parameter server # run as trainer or parameter server
training_role = os.getenv( training_role = os.getenv(
"TRAINING_ROLE", "TRAINER") # get the training role: trainer/pserver "TRAINING_ROLE", "TRAINER") # get the training role: trainer/pserver
t.transpile( t.transpile(
optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2) optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册