提交 b38452df 编写于 作者: T typhoonzero

fix styles

上级 cb34f6a2
...@@ -12,4 +12,5 @@ Check the logs for the distributed training progress and analyze the performance ...@@ -12,4 +12,5 @@ Check the logs for the distributed training progress and analyze the performance
## Enable verbos logs ## Enable verbos logs
Edit `pserver.yaml` and `trainer.yaml` and add an environment variable `GLOG_v=3` to see what happend in detail. Edit `pserver.yaml` and `trainer.yaml` and add an environment variable `GLOG_v=3` to see what happend in detail.
\ No newline at end of file
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/bin/env python #!/bin/env python
import os import os
import sys import sys
...@@ -33,6 +47,7 @@ def wait_pods_running(label_selector, desired): ...@@ -33,6 +47,7 @@ def wait_pods_running(label_selector, desired):
print 'current cnt: %d sleep for 5 seconds...' % count print 'current cnt: %d sleep for 5 seconds...' % count
time.sleep(5) time.sleep(5)
def count_pods_by_phase(label_selector, phase): def count_pods_by_phase(label_selector, phase):
pod_list = fetch_pods_info(label_selector) pod_list = fetch_pods_info(label_selector)
filtered_pod_list = filter(lambda x: x[0] == phase, pod_list) filtered_pod_list = filter(lambda x: x[0] == phase, pod_list)
...@@ -45,12 +60,14 @@ def fetch_pserver_ips(): ...@@ -45,12 +60,14 @@ def fetch_pserver_ips():
pserver_ips = [item[1] for item in pod_list] pserver_ips = [item[1] for item in pod_list]
return ",".join(pserver_ips) return ",".join(pserver_ips)
def fetch_master_ip(): def fetch_master_ip():
label_selector = "paddle-job-master=%s" % PADDLE_JOB_NAME label_selector = "paddle-job-master=%s" % PADDLE_JOB_NAME
pod_list = fetch_pods_info(label_selector) pod_list = fetch_pods_info(label_selector)
master_ips = [item[1] for item in pod_list] master_ips = [item[1] for item in pod_list]
return master_ips[0] return master_ips[0]
def fetch_trainer_id(): def fetch_trainer_id():
label_selector = "paddle-job=%s" % PADDLE_JOB_NAME label_selector = "paddle-job=%s" % PADDLE_JOB_NAME
pod_list = fetch_pods_info(label_selector) pod_list = fetch_pods_info(label_selector)
...@@ -75,4 +92,3 @@ if __name__ == "__main__": ...@@ -75,4 +92,3 @@ if __name__ == "__main__":
print count_pods_by_phase(sys.argv[2], sys.argv[3]) print count_pods_by_phase(sys.argv[2], sys.argv[3])
elif command == "wait_pods_running": elif command == "wait_pods_running":
wait_pods_running(sys.argv[2], sys.argv[3]) wait_pods_running(sys.argv[2], sys.argv[3])
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.v2 as paddle import paddle.v2 as paddle
paddle.dataset.cifar.train10() paddle.dataset.cifar.train10()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VGG16 benchmark in Fluid""" """VGG16 benchmark in Fluid"""
from __future__ import print_function from __future__ import print_function
...@@ -11,6 +25,7 @@ import argparse ...@@ -11,6 +25,7 @@ import argparse
import functools import functools
import os import os
def str2bool(v): def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'): if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True return True
...@@ -19,6 +34,7 @@ def str2bool(v): ...@@ -19,6 +34,7 @@ def str2bool(v):
else: else:
raise argparse.ArgumentTypeError('Boolean value expected.') raise argparse.ArgumentTypeError('Boolean value expected.')
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument( parser.add_argument(
'--batch_size', type=int, default=128, help="Batch size for training.") '--batch_size', type=int, default=128, help="Batch size for training.")
...@@ -122,7 +138,6 @@ def main(): ...@@ -122,7 +138,6 @@ def main():
place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(0) place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(0)
exe = fluid.Executor(place) exe = fluid.Executor(place)
# test # test
def test(exe): def test(exe):
accuracy.reset(exe) accuracy.reset(exe)
...@@ -148,20 +163,21 @@ def main(): ...@@ -148,20 +163,21 @@ def main():
accuracy.reset(exe) accuracy.reset(exe)
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
ts = time.time() ts = time.time()
img_data = np.array(map(lambda x: x[0].reshape(data_shape), img_data = np.array(
data)).astype("float32") map(lambda x: x[0].reshape(data_shape), data)).astype(
"float32")
y_data = np.array(map(lambda x: x[1], data)).astype("int64") y_data = np.array(map(lambda x: x[1], data)).astype("int64")
y_data = y_data.reshape([-1, 1]) y_data = y_data.reshape([-1, 1])
loss, acc = exe.run(trainer_prog, loss, acc = exe.run(trainer_prog,
feed={"pixel": img_data, feed={"pixel": img_data,
"label": y_data}, "label": y_data},
fetch_list=[avg_cost] + accuracy.metrics) fetch_list=[avg_cost] + accuracy.metrics)
iters += 1 iters += 1
num_samples += len(data) num_samples += len(data)
print( print(
"Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, spent %f" % "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, spent %f"
(pass_id, iters, loss, acc, time.time() - ts) % (pass_id, iters, loss, acc, time.time() - ts)
) # The accuracy is the accumulation of batches, but not the current batch. ) # The accuracy is the accumulation of batches, but not the current batch.
pass_elapsed = time.time() - start_time pass_elapsed = time.time() - start_time
...@@ -170,7 +186,7 @@ def main(): ...@@ -170,7 +186,7 @@ def main():
print( print(
"Pass = %d, Training performance = %f imgs/s, Train accuracy = %f, Test accuracy = %f\n" "Pass = %d, Training performance = %f imgs/s, Train accuracy = %f, Test accuracy = %f\n"
% (pass_id, num_samples / pass_elapsed, pass_train_acc, % (pass_id, num_samples / pass_elapsed, pass_train_acc,
pass_test_acc)) pass_test_acc))
if args.local: if args.local:
# Parameter initialization # Parameter initialization
...@@ -179,8 +195,8 @@ def main(): ...@@ -179,8 +195,8 @@ def main():
# data reader # data reader
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
paddle.dataset.cifar.train10() paddle.dataset.cifar.train10() if args.data_set == 'cifar10'
if args.data_set == 'cifar10' else paddle.dataset.flowers.train(), else paddle.dataset.flowers.train(),
buf_size=5120), buf_size=5120),
batch_size=args.batch_size) batch_size=args.batch_size)
test_reader = paddle.batch( test_reader = paddle.batch(
...@@ -196,19 +212,25 @@ def main(): ...@@ -196,19 +212,25 @@ def main():
pserver_endpoints = ",".join(eplist) pserver_endpoints = ",".join(eplist)
print("pserver endpoints: ", pserver_endpoints) print("pserver endpoints: ", pserver_endpoints)
trainers = int(os.getenv("TRAINERS")) # total trainer count trainers = int(os.getenv("TRAINERS")) # total trainer count
current_endpoint = os.getenv("POD_IP") + ":6174" # current pserver endpoint current_endpoint = os.getenv(
training_role = os.getenv("TRAINING_ROLE", "POD_IP") + ":6174" # current pserver endpoint
"TRAINER") # get the training role: trainer/pserver training_role = os.getenv(
"TRAINING_ROLE",
"TRAINER") # get the training role: trainer/pserver
t = fluid.DistributeTranspiler() t = fluid.DistributeTranspiler()
t.transpile( t.transpile(
optimize_ops, params_grads, pservers=pserver_endpoints, trainers=trainers) optimize_ops,
params_grads,
pservers=pserver_endpoints,
trainers=trainers)
if training_role == "PSERVER": if training_role == "PSERVER":
if not current_endpoint: if not current_endpoint:
print("need env SERVER_ENDPOINT") print("need env SERVER_ENDPOINT")
exit(1) exit(1)
pserver_prog = t.get_pserver_program(current_endpoint) pserver_prog = t.get_pserver_program(current_endpoint)
pserver_startup = t.get_startup_program(current_endpoint, pserver_prog) pserver_startup = t.get_startup_program(current_endpoint,
pserver_prog)
print("starting server side startup") print("starting server side startup")
exe.run(pserver_startup) exe.run(pserver_startup)
print("starting parameter server...") print("starting parameter server...")
...@@ -220,13 +242,13 @@ def main(): ...@@ -220,13 +242,13 @@ def main():
# data reader # data reader
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
paddle.dataset.cifar.train10() paddle.dataset.cifar.train10() if args.data_set == 'cifar10'
if args.data_set == 'cifar10' else paddle.dataset.flowers.train(), else paddle.dataset.flowers.train(),
buf_size=5120), buf_size=5120),
batch_size=args.batch_size) batch_size=args.batch_size)
test_reader = paddle.batch( test_reader = paddle.batch(
paddle.dataset.cifar.test10() paddle.dataset.cifar.test10() if args.data_set == 'cifar10' else
if args.data_set == 'cifar10' else paddle.dataset.flowers.test(), paddle.dataset.flowers.test(),
batch_size=args.batch_size) batch_size=args.batch_size)
trainer_prog = t.get_trainer_program() trainer_prog = t.get_trainer_program()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册