未验证 提交 1f3caaa8 编写于 作者: K kavyasrinet 提交者: GitHub

Make notest_dist_image_classification consistent with distributed...

Make notest_dist_image_classification  consistent with distributed implementation in others. (#7899)

* Make this file consistent with others

* fixed style
上级 06e22637
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
from __future__ import print_function from __future__ import print_function
import sys
import paddle.v2 as paddle import paddle.v2 as paddle
import paddle.v2.fluid as fluid import paddle.v2.fluid as fluid
import os import os
...@@ -106,10 +104,10 @@ if len(sys.argv) >= 2: ...@@ -106,10 +104,10 @@ if len(sys.argv) >= 2:
net_type = sys.argv[1] net_type = sys.argv[1]
if net_type == "vgg": if net_type == "vgg":
print("train vgg net") print("training vgg net")
net = vgg16_bn_drop(images) net = vgg16_bn_drop(images)
elif net_type == "resnet": elif net_type == "resnet":
print("train resnet") print("training resnet")
net = resnet_cifar10(images, 32) net = resnet_cifar10(images, 32)
else: else:
raise ValueError("%s network is not supported" % net_type) raise ValueError("%s network is not supported" % net_type)
...@@ -129,6 +127,7 @@ train_reader = paddle.batch( ...@@ -129,6 +127,7 @@ train_reader = paddle.batch(
batch_size=BATCH_SIZE) batch_size=BATCH_SIZE)
place = fluid.CPUPlace() place = fluid.CPUPlace()
feeder = fluid.DataFeeder(place=place, feed_list=[images, label])
exe = fluid.Executor(place) exe = fluid.Executor(place)
t = fluid.DistributeTranspiler() t = fluid.DistributeTranspiler()
...@@ -146,17 +145,14 @@ if training_role == "PSERVER": ...@@ -146,17 +145,14 @@ if training_role == "PSERVER":
if not current_endpoint: if not current_endpoint:
print("need env SERVER_ENDPOINT") print("need env SERVER_ENDPOINT")
exit(1) exit(1)
print("start pserver at:", current_endpoint)
pserver_prog = t.get_pserver_program(current_endpoint) pserver_prog = t.get_pserver_program(current_endpoint)
pserver_startup = t.get_startup_program(current_endpoint, pserver_prog) pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
exe.run(pserver_startup) exe.run(pserver_startup)
exe.run(pserver_prog) exe.run(pserver_prog)
print("pserver run end")
elif training_role == "TRAINER": elif training_role == "TRAINER":
print("start trainer")
trainer_prog = t.get_trainer_program() trainer_prog = t.get_trainer_program()
feeder = fluid.DataFeeder(place=place, feed_list=[images, label])
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
for pass_id in range(PASS_NUM): for pass_id in range(PASS_NUM):
accuracy.reset(exe) accuracy.reset(exe)
for data in train_reader(): for data in train_reader():
...@@ -164,9 +160,10 @@ elif training_role == "TRAINER": ...@@ -164,9 +160,10 @@ elif training_role == "TRAINER":
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=[avg_cost] + accuracy.metrics) fetch_list=[avg_cost] + accuracy.metrics)
pass_acc = accuracy.eval(exe) pass_acc = accuracy.eval(exe)
print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str( print("pass_id:" + str(pass_id) + "loss:" + str(loss) + " pass_acc:"
pass_acc)) + str(pass_acc))
# this model is slow, so if we can train two mini batch, we think it works properly. # this model is slow, so if we can train two mini batches,
# we think it works properly.
print("trainer run end") print("trainer run end")
else: else:
print("environment var TRAINER_ROLE should be TRAINER os PSERVER") print("environment var TRAINER_ROLE should be TRAINER os PSERVER")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册