未验证 提交 55d3951b 编写于 作者: W Wu Yi 提交者: GitHub

Benchmark/Integrate benchmark scripts (#10707)

* wip integrate benchmark scripts

* testing nlp models

* k8s script to start dist benchmark job

* update script

* done support all models

* add README.md

* update by comment

* clean up

* follow comments
上级 530556dd
# Fluid Benchmark
This directory contains several models configurations and tools that used to run
Fluid benchmarks for local and distributed training.
## Run the Benchmark
To start, run the following command to get the full help message:
```bash
python fluid_benchmark.py --help
```
Currently supported `--model` argument include:
* mnist
* resnet
* you can chose to use different dataset using `--data_set cifar10` or
`--data_set flowers`.
* vgg
* stacked_dynamic_lstm
* machine_translation
* Run the following command to start a benchmark job locally:
```bash
python fluid_benchmark.py --model mnist --parallel 1 --device GPU --with_test
```
You can choose to use GPU/CPU training. With GPU training, you can specify
`--parallel 1` to run multi GPU training.
* Run distributed training with parameter servers:
* start parameter servers:
```bash
PADDLE_TRAINING_ROLE=PSERVER PADDLE_PSERVER_PORT=7164 PADDLE_PSERVER_IPS=127.0.0.1 PADDLE_TRAINERS=1 PADDLE_CURRENT_IP=127.0.0.1 PADDLE_TRAINER_ID=0 python fluid_benchmark.py --model mnist --parallel 0 --device GPU --update_method pserver
```
* start trainers:
```bash
PADDLE_TRAINING_ROLE=PSERVER PADDLE_PSERVER_PORT=7164 PADDLE_PSERVER_IPS=127.0.0.1 PADDLE_TRAINERS=1 PADDLE_CURRENT_IP=127.0.0.1 PADDLE_TRAINER_ID=0 python fluid_benchmark.py --model mnist --parallel 0 --device GPU --update_method pserver
```
* Run distributed training using NCCL2
```bash
PADDLE_PSERVER_PORT=7164 PADDLE_TRAINER_IPS=192.168.0.2,192.168.0.3 PADDLE_CURRENT_IP=127.0.0.1 PADDLE_TRAINER_ID=0 python fluid_benchmark.py --model mnist --parallel 0 --device GPU --update_method nccl2
```
## Run Distributed Benchmark on Kubernetes Cluster
We provide a script `kube_gen_job.py` to generate Kubernetes yaml files to submit
distributed benchmark jobs to your cluster. To generate a job yaml, just run:
```bash
python kube_gen_job.py --jobname myjob --pscpu 4 --cpu 8 --gpu 8 --psmemory 20 --memory 40 --pservers 4 --trainers 4 --entry "python fluid_benchmark.py --model mnist --parallel 1 --device GPU --update_method pserver --with_test" --disttype pserver
```
Then the yaml files are generated under directory `myjob`, you can run:
```bash
kubectl create -f myjob/
```
The job shall start.
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import cProfile
import time
import os
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.profiler as profiler
import paddle.fluid.transpiler.distribute_transpiler as distribute_transpiler
BENCHMARK_MODELS = [
"machine_translation", "resnet", "vgg", "mnist", "stacked_dynamic_lstm"
]
def parse_args():
parser = argparse.ArgumentParser('Fluid model benchmarks.')
parser.add_argument(
'--model',
type=str,
choices=BENCHMARK_MODELS,
default='resnet',
help='The model to run benchmark with.')
parser.add_argument(
'--batch_size', type=int, default=32, help='The minibatch size.')
parser.add_argument(
'--learning_rate',
type=float,
default=0.001,
help='The minibatch size.')
# TODO(wuyi): add "--use_fake_data" option back.
parser.add_argument(
'--skip_batch_num',
type=int,
default=5,
help='The first num of minibatch num to skip, for better performance test'
)
parser.add_argument(
'--iterations', type=int, default=80, help='The number of minibatches.')
parser.add_argument(
'--pass_num', type=int, default=100, help='The number of passes.')
parser.add_argument(
'--data_format',
type=str,
default='NCHW',
choices=['NCHW', 'NHWC'],
help='The data data_format, now only support NCHW.')
parser.add_argument(
'--device',
type=str,
default='GPU',
choices=['CPU', 'GPU'],
help='The device type.')
parser.add_argument(
'--gpus',
type=int,
default=1,
help='If gpus > 1, will use ParallelExecutor to run, else use Executor.')
parser.add_argument(
'--data_set',
type=str,
default='flowers',
choices=['cifar10', 'flowers'],
help='Optional dataset for benchmark.')
parser.add_argument(
'--infer_only', action='store_true', help='If set, run forward only.')
parser.add_argument(
'--use_cprof', action='store_true', help='If set, use cProfile.')
parser.add_argument(
'--use_nvprof',
action='store_true',
help='If set, use nvprof for CUDA.')
parser.add_argument(
'--no_test',
action='store_false',
help='If set, test the testset during training.')
parser.add_argument(
'--memory_optimize',
action='store_true',
help='If set, optimize runtime memory before start.')
parser.add_argument(
'--update_method',
type=str,
default='local',
choices=['local', 'pserver', 'nccl2'],
help='Choose parameter update method, can be local, pserver, nccl2.')
args = parser.parse_args()
return args
def append_nccl2_prepare():
if os.getenv("PADDLE_TRAINER_ID", None) != None:
# append gen_nccl_id at the end of startup program
trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
port = os.getenv("PADDLE_PSERVER_PORT")
worker_ips = os.getenv("PADDLE_TRAINER_IPS")
worker_endpoints = []
for ip in worker_ips.split(","):
worker_endpoints.append(':'.join([ip, port]))
num_trainers = len(worker_endpoints)
current_endpoint = os.getenv("PADDLE_CURRENT_IP") + ":" + port
worker_endpoints.remove(current_endpoint)
nccl_id_var = fluid.default_startup_program().global_block().create_var(
name="NCCLID",
persistable=True,
type=fluid.core.VarDesc.VarType.RAW)
fluid.default_startup_program().global_block().append_op(
type="gen_nccl_id",
inputs={},
outputs={"NCCLID": nccl_id_var},
attrs={
"endpoint": current_endpoint,
"endpoint_list": worker_endpoints,
"trainer_id": trainer_id
})
return nccl_id_var, num_trainers, trainer_id
else:
raise Exception(
"must set PADDLE_TRAINER_ID env variables for dist train.")
def dist_transpile():
if "PADDLE_TRAINING_ROLE" not in os.environ:
return None, None
# the port of all pservers, needed by both trainer and pserver
port = os.getenv("PADDLE_PSERVER_PORT", "6174")
# comma separated ips of all pservers, needed by trainer and
# pserver
pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "")
eplist = []
for ip in pserver_ips.split(","):
eplist.append(':'.join([ip, port]))
pserver_endpoints = ",".join(eplist)
# total number of workers/trainers in the job, needed by
# trainer and pserver
trainers = int(os.getenv("PADDLE_TRAINERS"))
# the IP of the local machine, needed by pserver only
current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
# the unique trainer id, starting from 0, needed by trainer
# only
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
# the role, should be either PSERVER or TRAINER
training_role = os.getenv("PADDLE_TRAINING_ROLE")
t = distribute_transpiler.DistributeTranspiler()
t.transpile(trainer_id, pservers=pserver_endpoints, trainers=trainers)
if training_role == "PSERVER":
pserver_program = t.get_pserver_program(current_endpoint)
pserver_startup_program = t.get_startup_program(current_endpoint,
pserver_program)
return pserver_program, pserver_startup_program
elif training_role == "TRAINER":
train_program = t.get_trainer_program()
return train_program, fluid.default_startup_program()
else:
raise ValueError(
'TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
)
def test(exe, inference_program, test_reader, feeder, batch_acc):
accuracy_evaluator = fluid.metrics.Accuracy()
for batch_id, data in enumerate(test_reader()):
acc = exe.run(inference_program,
feed=feeder.feed(data),
fetch_list=[batch_acc])
accuracy_evaluator.update(value=np.array(acc), weight=len(data))
return accuracy_evaluator.eval()
# TODO(wuyi): replace train, train_parallel, test functions with new trainer
# API once it is ready.
def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
args, train_prog, startup_prog):
if os.getenv("PADDLE_TRAINING_ROLE") == "PSERVER":
place = core.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)
exe.run(train_prog)
return
place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(startup_prog)
feed_var_list = [
var for var in train_prog.global_block().vars.itervalues()
if var.is_data
]
feeder = fluid.DataFeeder(feed_var_list, place)
iters, num_samples, start_time = 0, 0, time.time()
for pass_id in range(args.pass_num):
train_losses = []
for batch_id, data in enumerate(train_reader()):
if iters == args.skip_batch_num:
start_time = time.time()
num_samples = 0
if iters == args.iterations:
break
loss = exe.run(train_prog,
feed=feeder.feed(data),
fetch_list=[avg_loss])
iters += 1
num_samples += len(data)
train_losses.append(loss)
print("Pass: %d, Iter: %d, Loss: %f\n" %
(pass_id, iters, np.mean(train_losses)))
train_elapsed = time.time() - start_time
examples_per_sec = num_samples / train_elapsed
print('\nTotal examples: %d, total time: %.5f, %.5f examples/sec\n' %
(num_samples, train_elapsed, examples_per_sec))
print("Pass: %d, Loss: %f" % (pass_id, np.mean(train_losses)))
# evaluation
if not args.no_test and batch_acc != None:
pass_test_acc = test(exe, infer_prog, test_reader, feeder,
batch_acc)
print(", Test Accuracy: %f" % pass_test_acc)
print("\n")
# TODO(wuyi): add warmup passes to get better perf data.
exit(0)
# TODO(wuyi): replace train, train_parallel, test functions with new trainer
# API once it is ready.
def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
batch_acc, args, train_prog, startup_prog, nccl_id_var,
num_trainers, trainer_id):
place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(0)
startup_exe = fluid.Executor(place)
startup_exe.run(startup_prog)
strategy = fluid.ExecutionStrategy()
strategy.num_threads = 1
strategy.allow_op_delay = False
exe = fluid.ParallelExecutor(
True,
avg_loss.name,
exec_strategy=strategy,
num_trainers=num_trainers,
trainer_id=trainer_id)
feed_var_list = [
var for var in train_prog.global_block().vars.itervalues()
if var.is_data
]
feeder = fluid.DataFeeder(feed_var_list, place)
for pass_id in range(args.pass_num):
num_samples = 0
iters = 0
start_time = time.time()
for batch_id, data in enumerate(train_reader()):
if iters == args.skip_batch_num:
start_time = time.time()
num_samples = 0
if iters == args.iterations:
break
loss, = exe.run([avg_loss.name], feed=feeder.feed(data))
if args.update_method == "pserver":
exe.bcast_params()
num_samples += len(data)
iters += 1
if batch_id % 1 == 0:
print("Pass %d, batch %d, loss %s" %
(pass_id, batch_id, np.array(loss)))
train_elapsed = time.time() - start_time
examples_per_sec = num_samples / train_elapsed
print('\nTotal examples: %d, total time: %.5f, %.5f examples/sed\n' %
(num_samples, train_elapsed, examples_per_sec))
if not args.no_test and batch_acc != None:
test_acc = test(startup_exe, infer_prog, test_reader, feeder,
batch_acc)
print("Pass: %d, Test Accuracy: %f\n" % (pass_id, test_acc))
exit(0)
def print_arguments(args):
vars(args)['use_nvprof'] = (vars(args)['use_nvprof'] and
vars(args)['device'] == 'GPU')
print('----------- resnet Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
def main():
args = parse_args()
print_arguments(args)
nccl_id_var, num_trainers, trainer_id = None, 1, 0
if args.use_cprof:
pr = cProfile.Profile()
pr.enable()
model_def = __import__("models.%s" % args.model, fromlist=["models"])
train_args = list(model_def.get_model(args))
train_args.append(args)
# Run optimizer.minimize(avg_loss)
train_args[2].minimize(train_args[0])
if args.memory_optimize:
fluid.memory_optimize(fluid.default_main_program())
if args.update_method == "pserver":
train_prog, startup_prog = dist_transpile()
if not train_prog:
raise Exception(
"Must configure correct environments to run dist train.")
train_args.extend([train_prog, startup_prog])
if args.gpus > 1 and os.getenv("PADDLE_TRAINING_ROLE") == "TRAINER":
train_args.extend([nccl_id_var, num_trainers, trainer_id])
train_parallel(*train_args)
train(*train_args)
exit(0)
# for other update methods, use default programs
train_args.append(fluid.default_main_program())
train_args.append(fluid.default_startup_program())
if args.update_method == "nccl2":
nccl_id_var, num_trainers, trainer_id = append_nccl2_prepare()
if args.gpus == 1:
# NOTE: parallel executor use profiler interanlly
if args.use_nvprof and args.device == 'GPU':
with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof:
train(*train_args)
else:
train(*train_args)
else:
if args.device == "CPU":
raise Exception("Only support GPU perf with parallel exe")
train_args.extend([nccl_id_var, num_trainers, trainer_id])
train_parallel(*train_args)
if __name__ == "__main__":
main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
import copy
import argparse
import random
import os
from kube_templates import pserver, trainer, envs
def parse_args():
parser = argparse.ArgumentParser(description='Generate dist job yamls.')
parser.add_argument(
'--jobname', default="paddlejob", help='unique job name')
parser.add_argument(
'--cpu', default=1, type=int, help='CPU cores per trainer node')
parser.add_argument(
'--pscpu', default=1, type=int, help='CPU cores per pserver node')
parser.add_argument(
'--gpu', default=0, type=int, help='num of GPUs per node')
parser.add_argument(
'--image',
default="bootstrapper:5000/fluid_benchmark:gpu",
help='num of GPUs per node')
parser.add_argument(
'--pservers', default=1, type=int, help='num of pservers')
parser.add_argument(
'--trainers', default=1, type=int, help='num of trainers')
parser.add_argument('--memory', default=1, type=int, help='trainer memory')
parser.add_argument(
'--psmemory', default=1, type=int, help='pserver memory')
parser.add_argument(
'--port', default=30236, type=int, help='num of trainers')
parser.add_argument(
'--entry', default="python train.py", help='command to run')
parser.add_argument(
'--fluid', default=1, type=int, help='whether is fluid job')
parser.add_argument(
'--rdma', action='store_ture', help='whether mount rdma libs')
parser.add_argument(
'--disttype',
default="pserver",
type=str,
choices=['pserver', 'nccl2', 'local'],
help='pserver or nccl2 or local')
args = parser.parse_args()
return args
def gen_job():
ps = pserver
tn = trainer
args = parse_args()
ps_container = ps["spec"]["template"]["spec"]["containers"][0]
tn_container = tn["spec"]["template"]["spec"]["containers"][0]
if args.fluid == 1:
ps_container["command"] = \
["paddle_k8s", "start_fluid"]
tn_container["command"] = \
["paddle_k8s", "start_fluid"]
ps["metadata"]["name"] = args.jobname + "-pserver"
ps["spec"]["template"]["metadata"]["labels"][
"paddle-job-pserver"] = args.jobname
tn["metadata"]["name"] = args.jobname + "-trainer"
tn["spec"]["template"]["metadata"]["labels"]["paddle-job"] = args.jobname
ps_container["image"] = args.image
tn_container["image"] = args.image
ps_container["resources"]["requests"]["cpu"] = str(args.pscpu)
ps_container["resources"]["requests"]["memory"] = str(args.psmemory) + "Gi"
ps_container["resources"]["limits"]["cpu"] = str(args.pscpu)
ps_container["resources"]["limits"]["memory"] = str(args.psmemory) + "Gi"
tn_container["resources"]["requests"]["cpu"] = str(args.cpu)
tn_container["resources"]["requests"]["memory"] = str(args.memory) + "Gi"
tn_container["resources"]["limits"]["cpu"] = str(args.cpu)
tn_container["resources"]["limits"]["memory"] = str(args.memory) + "Gi"
if args.gpu > 0:
tn_container["resources"]["requests"][
"alpha.kubernetes.io/nvidia-gpu"] = str(args.gpu)
tn_container["resources"]["limits"][
"alpha.kubernetes.io/nvidia-gpu"] = str(args.gpu)
ps["spec"]["replicas"] = int(args.pservers)
tn["spec"]["parallelism"] = int(args.trainers)
tn["spec"]["completions"] = int(args.trainers)
ps_container["ports"][0]["name"] = "jobport-" + str(args.port)
ps_container["ports"][0]["containerPort"] = args.port
spreadport = random.randint(40000, 60000)
tn_container["ports"][0]["name"] = "spr-" + str(spreadport)
tn_container["ports"][0]["containerPort"] = spreadport
envs.append({"name": "PADDLE_JOB_NAME", "value": args.jobname})
envs.append({"name": "TRAINERS", "value": str(args.trainers)})
envs.append({"name": "PSERVERS", "value": str(args.pservers)})
envs.append({"name": "ENTRY", "value": args.entry})
envs.append({"name": "PADDLE_INIT_PORT", "value": str(args.port)})
# NOTE: these directories below are cluster specific, please modify
# this settings before you run on your own cluster.
envs.append({
"name": "LD_LIBRARY_PATH",
"value":
"/usr/local/lib:/usr/local/nvidia/lib64:/usr/local/rdma/lib64:/usr/lib64/mlnx_ofed/valgrind"
})
volumes = [{
"name": "nvidia-driver",
"hostPath": {
"path": "/usr/local/nvidia/lib64"
}
}]
volumeMounts = [{
"mountPath": "/usr/local/nvidia/lib64",
"name": "nvidia-driver"
}]
if args.rdma:
volumes.extend([{
"name": "ibetc",
"hostPath": {
"path": "/etc/libibverbs.d"
}
}, {
"name": "iblibs",
"hostPath": {
"path": "/usr/local/rdma"
}
}, {
"name": "valgrind",
"hostPath": {
"path": "/usr/lib64/mlnx_ofed/valgrind"
}
}])
volumeMounts.extend([{
"mountPath": "/etc/libibverbs.d",
"name": "ibetc"
}, {
"mountPath": "/usr/local/rdma",
"name": "iblibs"
}, {
"mountPath": "/usr/lib64/mlnx_ofed/valgrind",
"name": "valgrind"
}])
# append shm for NCCL2
volumes.append({"name": "dshm", "emptyDir": {"medium": "Memory"}})
volumeMounts.append({"mountPath": "/dev/shm", "name": "dshm"})
tn["spec"]["template"]["spec"]["volumes"] = volumes
tn_container["volumeMounts"] = volumeMounts
ps_container["env"] = envs
ps_container["env"].append({"name": "TRAINING_ROLE", "value": "PSERVER"})
tn_container["env"] = envs
if args.disttype == "pserver":
tn_container["env"].append({
"name": "TRAINING_ROLE",
"value": "TRAINER"
})
elif args.disttype == "nccl2" or args.disttype == "local":
# NCCL2 have no training role, set to plain WORKER
tn_container["env"].append({"name": "TRAINING_ROLE", "value": "WORKER"})
os.mkdir(args.jobname)
if args.disttype == "pserver":
with open("%s/pserver.yaml" % args.jobname, "w") as fn:
yaml.dump(ps, fn)
with open("%s/trainer.yaml" % args.jobname, "w") as fn:
yaml.dump(tn, fn)
if __name__ == "__main__":
gen_job()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pserver import pserver
from trainer import trainer
__all__ = ["pserver", "trainer", "envs"]
envs = [
# envs that don't need to change
{
"name": "GLOG_v",
"value": "0"
},
{
"name": "GLOG_logtostderr",
"value": "1"
},
{
"name": "TOPOLOGY",
"value": ""
},
{
"name": "TRAINER_PACKAGE",
"value": "/workspace"
},
{
"name": "PADDLE_INIT_NICS",
"value": "eth2"
},
{
"name": "NAMESPACE",
"valueFrom": {
"fieldRef": {
"fieldPath": "metadata.namespace"
}
}
},
{
"name": "POD_IP",
"valueFrom": {
"fieldRef": {
"fieldPath": "status.podIP"
}
}
}
]
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pserver = {
"apiVersion": "extensions/v1beta1",
"kind": "ReplicaSet",
"metadata": {
"name": "jobname-pserver"
},
"spec": {
"replicas": 1,
"template": {
"metadata": {
"labels": {
"paddle-job-pserver": "jobname"
}
},
"spec": {
"hostNetwork": True,
"imagePullSecrets": [{
"name": "job-registry-secret"
}],
"containers": [{
"name": "pserver",
"image": "",
"imagePullPolicy": "Always",
"ports": [{
"name": "jobport-1",
"containerPort": 1
}],
"env": [],
"command": ["paddle_k8s", "start_pserver"],
"resources": {
"requests": {
"memory": "10Gi",
"cpu": "4"
},
"limits": {
"memory": "10Gi",
"cpu": "4"
}
}
}]
}
}
}
}
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
trainer = {
"apiVersion": "batch/v1",
"kind": "Job",
"metadata": {
"name": "jobname-pserver"
},
"spec": {
"parallelism": 4,
"completions": 4,
"template": {
"metadata": {
"labels": {
"paddle-job": "jobname"
}
},
"spec": {
"hostNetwork": True,
"imagePullSecrets": [{
"name": "job-registry-secret"
}],
"restartPolicy": "Never",
"containers": [{
"name": "trainer",
"image": "",
"imagePullPolicy": "Always",
# to let container set rlimit
"securityContext": {
"privileged": True
# TODO(wuyi): use below specific cap instead of privileged,
# using privileged will cause all GPU device are visible
# in the container.
# "capabilities": {
# "add": ["SYS_RESOURCE"]
# }
},
"ports": [{
"name": "jobport-1",
"containerPort": 1
}],
"env": [],
"command": ["paddle_k8s", "start_trainer", "v2"],
"resources": {
"requests": {
"memory": "10Gi",
"cpu": "4",
},
"limits": {
"memory": "10Gi",
"cpu": "4",
}
}
}]
}
}
}
}
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = [
"machine_translation", "resnet", "vgg", "mnist", "stacked_dynamic_lstm"
]
......@@ -27,74 +27,6 @@ import paddle.fluid.core as core
import paddle.fluid.framework as framework
from paddle.fluid.executor import Executor
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--embedding_dim",
type=int,
default=512,
help="The dimension of embedding table. (default: %(default)d)")
parser.add_argument(
"--encoder_size",
type=int,
default=512,
help="The size of encoder bi-rnn unit. (default: %(default)d)")
parser.add_argument(
"--decoder_size",
type=int,
default=512,
help="The size of decoder rnn unit. (default: %(default)d)")
parser.add_argument(
"--batch_size",
type=int,
default=16,
help="The sequence number of a mini-batch data. (default: %(default)d)")
parser.add_argument(
'--skip_batch_num',
type=int,
default=5,
help='The first num of minibatch num to skip, for better performance test')
parser.add_argument(
'--iterations', type=int, default=80, help='The number of minibatches.')
parser.add_argument(
"--dict_size",
type=int,
default=30000,
help="The dictionary capacity. Dictionaries of source sequence and "
"target dictionary have same capacity. (default: %(default)d)")
parser.add_argument(
"--pass_num",
type=int,
default=2,
help="The pass number to train. (default: %(default)d)")
parser.add_argument(
"--learning_rate",
type=float,
default=0.0002,
help="Learning rate used to train the model. (default: %(default)f)")
parser.add_argument(
"--infer_only", action='store_true', help="If set, run forward only.")
parser.add_argument(
"--beam_size",
type=int,
default=3,
help="The width for beam searching. (default: %(default)d)")
parser.add_argument(
'--device',
type=str,
default='GPU',
choices=['CPU', 'GPU'],
help="The device type.")
parser.add_argument(
"--max_length",
type=int,
default=250,
help="The maximum length of sequence when doing generation. "
"(default: %(default)d)")
parser.add_argument(
'--with_test',
action='store_true',
help='If set, test the testset during training.')
def lstm_step(x_t, hidden_t_prev, cell_t_prev, size):
def linear(inputs):
......@@ -264,116 +196,37 @@ def lodtensor_to_ndarray(lod_tensor):
return ndarray
def train():
def get_model(args):
embedding_dim = 512
encoder_size = 512
decoder_size = 512
dict_size = 30000
beam_size = 3
max_length = 250
avg_cost, feeding_list = seq_to_seq_net(
args.embedding_dim,
args.encoder_size,
args.decoder_size,
args.dict_size,
args.dict_size,
embedding_dim,
encoder_size,
decoder_size,
dict_size,
dict_size,
False,
beam_size=args.beam_size,
max_length=args.max_length)
beam_size=beam_size,
max_length=max_length)
# clone from default main program
inference_program = fluid.default_main_program().clone()
optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate)
optimizer.minimize(avg_cost)
fluid.memory_optimize(fluid.default_main_program())
train_batch_generator = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.wmt14.train(args.dict_size), buf_size=1000),
paddle.dataset.wmt14.train(dict_size), buf_size=1000),
batch_size=args.batch_size)
test_batch_generator = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.wmt14.test(args.dict_size), buf_size=1000),
paddle.dataset.wmt14.test(dict_size), buf_size=1000),
batch_size=args.batch_size)
place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(0)
exe = Executor(place)
exe.run(framework.default_startup_program())
def do_validation():
total_loss = 0.0
count = 0
for batch_id, data in enumerate(test_batch_generator()):
src_seq = to_lodtensor(map(lambda x: x[0], data), place)[0]
trg_seq = to_lodtensor(map(lambda x: x[1], data), place)[0]
lbl_seq = to_lodtensor(map(lambda x: x[2], data), place)[0]
fetch_outs = exe.run(inference_program,
feed={
feeding_list[0]: src_seq,
feeding_list[1]: trg_seq,
feeding_list[2]: lbl_seq
},
fetch_list=[avg_cost],
return_numpy=False)
total_loss += lodtensor_to_ndarray(fetch_outs[0])[0]
count += 1
return total_loss / count
iters, num_samples, start_time = 0, 0, time.time()
for pass_id in xrange(args.pass_num):
train_accs = []
train_losses = []
for batch_id, data in enumerate(train_batch_generator()):
if iters == args.skip_batch_num:
start_time = time.time()
num_samples = 0
if iters == args.iterations:
break
src_seq, word_num = to_lodtensor(map(lambda x: x[0], data), place)
num_samples += word_num
trg_seq, word_num = to_lodtensor(map(lambda x: x[1], data), place)
num_samples += word_num
lbl_seq, _ = to_lodtensor(map(lambda x: x[2], data), place)
fetch_outs = exe.run(framework.default_main_program(),
feed={
feeding_list[0]: src_seq,
feeding_list[1]: trg_seq,
feeding_list[2]: lbl_seq
},
fetch_list=[avg_cost])
iters += 1
loss = np.array(fetch_outs[0])
print(
"Pass = %d, Iter = %d, Loss = %f" % (pass_id, iters, loss)
) # The accuracy is the accumulation of batches, but not the current batch.
train_elapsed = time.time() - start_time
examples_per_sec = num_samples / train_elapsed
print('\nTotal examples: %d, total time: %.5f, %.5f examples/sed\n' %
(num_samples, train_elapsed, examples_per_sec))
# evaluation
if args.with_test:
test_loss = do_validation()
exit(0)
def infer():
pass
def print_arguments(args):
print('----------- seq2seq Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
if args.infer_only:
infer()
else:
train()
return avg_cost, inference_program, optimizer, train_batch_generator, \
test_batch_generator, None
......@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
import argparse
import time
import cProfile
import paddle
import paddle.fluid as fluid
......@@ -31,42 +32,6 @@ DTYPE = "float32"
# fluid.default_startup_program().random_seed = SEED
def parse_args():
parser = argparse.ArgumentParser("mnist model benchmark.")
parser.add_argument(
'--batch_size', type=int, default=128, help='The minibatch size.')
parser.add_argument(
'--skip_batch_num',
type=int,
default=5,
help='The first num of minibatch num to skip, for better performance test'
)
parser.add_argument(
'--iterations', type=int, default=35, help='The number of minibatches.')
parser.add_argument(
'--pass_num', type=int, default=5, help='The number of passes.')
parser.add_argument(
'--device',
type=str,
default='GPU',
choices=['CPU', 'GPU'],
help='The device type.')
parser.add_argument(
'--infer_only', action='store_true', help='If set, run forward only.')
parser.add_argument(
'--use_cprof', action='store_true', help='If set, use cProfile.')
parser.add_argument(
'--use_nvprof',
action='store_true',
help='If set, use nvprof for CUDA.')
parser.add_argument(
'--with_test',
action='store_true',
help='If set, test the testset during training.')
args = parser.parse_args()
return args
def cnn_model(data):
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=data,
......@@ -99,36 +64,13 @@ def cnn_model(data):
return predict
def eval_test(exe, batch_acc, batch_size_tensor, inference_program):
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=args.batch_size)
test_pass_acc = fluid.average.WeightedAverage()
for batch_id, data in enumerate(test_reader()):
img_data = np.array(map(lambda x: x[0].reshape([1, 28, 28]),
data)).astype(DTYPE)
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
y_data = y_data.reshape([len(y_data), 1])
acc, weight = exe.run(inference_program,
feed={"pixel": img_data,
"label": y_data},
fetch_list=[batch_acc, batch_size_tensor])
test_pass_acc.add(value=acc, weight=weight)
pass_acc = test_pass_acc.eval()
return pass_acc
def run_benchmark(model, args):
if args.use_cprof:
pr = cProfile.Profile()
pr.enable()
start_time = time.time()
def get_model(args):
# Input data
images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE)
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# Train program
predict = model(images)
predict = cnn_model(images)
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
......@@ -143,86 +85,10 @@ def run_benchmark(model, args):
# Optimization
opt = fluid.optimizer.AdamOptimizer(
learning_rate=0.001, beta1=0.9, beta2=0.999)
opt.minimize(avg_cost)
fluid.memory_optimize(fluid.default_main_program())
# Initialize executor
place = fluid.CPUPlace() if args.device == 'CPU' else fluid.CUDAPlace(0)
exe = fluid.Executor(place)
# Parameter initialization
exe.run(fluid.default_startup_program())
# Reader
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=args.batch_size)
accuracy = fluid.metrics.Accuracy()
train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)
iters, num_samples, start_time = 0, 0, time.time()
for pass_id in range(args.pass_num):
accuracy.reset()
train_accs = []
train_losses = []
for batch_id, data in enumerate(train_reader()):
if iters == args.skip_batch_num:
start_time = time.time()
num_samples = 0
if iters == args.iterations:
break
img_data = np.array(
map(lambda x: x[0].reshape([1, 28, 28]), data)).astype(DTYPE)
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
y_data = y_data.reshape([len(y_data), 1])
outs = train_exe.run(
feed={"pixel": img_data,
"label": y_data},
fetch_list=[
avg_cost.name, batch_acc.name, batch_size_tensor.name
]
) # The accuracy is the accumulation of batches, but not the current batch.
accuracy.update(
value=np.array(np.mean(outs[1])),
weight=np.mean(np.array(outs[2])))
iters += 1
num_samples += len(y_data)
loss = np.mean(np.array(outs[0]))
acc = np.mean(np.array(outs[1]))
train_losses.append(loss)
train_accs.append(acc)
print("Pass: %d, Iter: %d, Loss: %f, Accuracy: %f" %
(pass_id, iters, loss, acc))
print("Pass: %d, Loss: %f, Train Accuray: %f\n" %
(pass_id, np.mean(train_losses), np.mean(train_accs)))
train_elapsed = time.time() - start_time
examples_per_sec = num_samples / train_elapsed
print('\nTotal examples: %d, total time: %.5f, %.5f examples/sed\n' %
(num_samples, train_elapsed, examples_per_sec))
# evaluation
if args.with_test:
test_avg_acc = eval_test(exe, batch_acc, batch_size_tensor,
inference_program)
exit(0)
def print_arguments(args):
vars(args)['use_nvprof'] = (vars(args)['use_nvprof'] and
vars(args)['device'] == 'GPU')
print('----------- mnist Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
if __name__ == '__main__':
args = parse_args()
print_arguments(args)
if args.use_nvprof and args.device == 'GPU':
with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof:
run_benchmark(cnn_model, args)
else:
run_benchmark(cnn_model, args)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=args.batch_size)
return avg_cost, inference_program, opt, train_reader, test_reader, batch_acc
......@@ -16,7 +16,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import functools
import numpy as np
import time
......@@ -29,64 +28,6 @@ import paddle.fluid.core as core
import paddle.fluid.profiler as profiler
def parse_args():
parser = argparse.ArgumentParser('Convolution model benchmark.')
parser.add_argument(
'--model',
type=str,
choices=['resnet_imagenet', 'resnet_cifar10'],
default='resnet_imagenet',
help='The model architecture.')
parser.add_argument(
'--batch_size', type=int, default=32, help='The minibatch size.')
parser.add_argument(
'--use_fake_data',
action='store_true',
help='use real data or fake data')
parser.add_argument(
'--skip_batch_num',
type=int,
default=5,
help='The first num of minibatch num to skip, for better performance test'
)
parser.add_argument(
'--iterations', type=int, default=80, help='The number of minibatches.')
parser.add_argument(
'--pass_num', type=int, default=100, help='The number of passes.')
parser.add_argument(
'--data_format',
type=str,
default='NCHW',
choices=['NCHW', 'NHWC'],
help='The data data_format, now only support NCHW.')
parser.add_argument(
'--device',
type=str,
default='GPU',
choices=['CPU', 'GPU'],
help='The device type.')
parser.add_argument(
'--data_set',
type=str,
default='flowers',
choices=['cifar10', 'flowers'],
help='Optional dataset for benchmark.')
parser.add_argument(
'--infer_only', action='store_true', help='If set, run forward only.')
parser.add_argument(
'--use_cprof', action='store_true', help='If set, use cProfile.')
parser.add_argument(
'--use_nvprof',
action='store_true',
help='If set, use nvprof for CUDA.')
parser.add_argument(
'--with_test',
action='store_true',
help='If set, test the testset during training.')
args = parser.parse_args()
return args
def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu'):
conv1 = fluid.layers.conv2d(
input=input,
......@@ -100,7 +41,7 @@ def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu'):
def shortcut(input, ch_out, stride):
ch_in = input.shape[1] if args.data_format == 'NCHW' else input.shape[-1]
ch_in = input.shape[1] # if args.data_format == 'NCHW' else input.shape[-1]
if ch_in != ch_out:
return conv_bn_layer(input, ch_out, 1, stride, 0, None)
else:
......@@ -172,23 +113,22 @@ def resnet_cifar10(input, class_dim, depth=32, data_format='NCHW'):
return out
def run_benchmark(model, args):
if args.use_cprof:
pr = cProfile.Profile()
pr.enable()
def get_model(args):
model = resnet_cifar10
if args.data_set == "cifar10":
class_dim = 10
if args.data_format == 'NCHW':
dshape = [3, 32, 32]
else:
dshape = [32, 32, 3]
model = resnet_cifar10
else:
class_dim = 102
if args.data_format == 'NCHW':
dshape = [3, 224, 224]
else:
dshape = [224, 224, 3]
model = resnet_imagenet
input = fluid.layers.data(name='data', shape=dshape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
......@@ -206,9 +146,6 @@ def run_benchmark(model, args):
target_vars=[batch_acc, batch_size_tensor])
optimizer = fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
opts = optimizer.minimize(avg_cost)
fluid.memory_optimize(fluid.default_main_program())
train_reader = paddle.batch(
paddle.reader.shuffle(
......@@ -221,97 +158,4 @@ def run_benchmark(model, args):
if args.data_set == 'cifar10' else paddle.dataset.flowers.test(),
batch_size=args.batch_size)
def test(exe):
test_accuracy = fluid.average.WeightedAverage()
for batch_id, data in enumerate(test_reader()):
img_data = np.array(map(lambda x: x[0].reshape(dshape),
data)).astype("float32")
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
y_data = y_data.reshape([-1, 1])
acc, weight = exe.run(inference_program,
feed={"data": img_data,
"label": y_data},
fetch_list=[batch_acc, batch_size_tensor])
test_accuracy.add(value=acc, weight=weight)
return test_accuracy.eval()
place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
accuracy = fluid.average.WeightedAverage()
train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)
if args.use_fake_data:
data = train_reader().next()
image = np.array(map(lambda x: x[0].reshape(dshape), data)).astype(
'float32')
label = np.array(map(lambda x: x[1], data)).astype('int64')
label = label.reshape([-1, 1])
iters, num_samples, start_time = 0, 0, time.time()
for pass_id in range(args.pass_num):
accuracy.reset()
train_accs = []
train_losses = []
for batch_id, data in enumerate(train_reader()):
if iters == args.skip_batch_num:
start_time = time.time()
num_samples = 0
if iters == args.iterations:
break
if not args.use_fake_data:
image = np.array(map(lambda x: x[0].reshape(dshape),
data)).astype('float32')
label = np.array(map(lambda x: x[1], data)).astype('int64')
label = label.reshape([-1, 1])
loss, acc, weight = train_exe.run(
feed={'data': image,
'label': label},
fetch_list=[
avg_cost.name, batch_acc.name, batch_size_tensor.name
])
iters += 1
num_samples += len(label)
accuracy.add(value=np.array(np.mean(acc)), weight=np.mean(weight))
loss = np.mean(np.array(loss))
acc = np.mean(np.array(acc))
train_losses.append(loss)
train_accs.append(acc)
print("Pass: %d, Iter: %d, Loss: %f, Accuracy: %f" %
(pass_id, iters, loss, acc))
print("Pass: %d, Loss: %f, Train Accuray: %f\n" %
(pass_id, np.mean(train_losses), np.mean(train_accs)))
train_elapsed = time.time() - start_time
examples_per_sec = num_samples / train_elapsed
print('\nTotal examples: %d, total time: %.5f, %.5f examples/sed\n' %
(num_samples, train_elapsed, examples_per_sec))
# evaluation
if args.with_test:
pass_test_acc = test(exe)
exit(0)
def print_arguments(args):
vars(args)['use_nvprof'] = (vars(args)['use_nvprof'] and
vars(args)['device'] == 'GPU')
print('----------- resnet Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
if __name__ == '__main__':
model_map = {
'resnet_imagenet': resnet_imagenet,
'resnet_cifar10': resnet_cifar10
}
args = parse_args()
print_arguments(args)
if args.data_format == 'NHWC':
raise ValueError('Only support NCHW data_format now.')
if args.use_nvprof and args.device == 'GPU':
with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof:
run_benchmark(model_map[args.model], args)
else:
run_benchmark(model_map[args.model], args)
return avg_cost, inference_program, optimizer, train_reader, test_reader, batch_acc
......@@ -29,57 +29,6 @@ import paddle.fluid as fluid
import paddle.batch as batch
import paddle.fluid.profiler as profiler
def parse_args():
parser = argparse.ArgumentParser("Understand Sentiment by Dynamic RNN.")
parser.add_argument(
'--batch_size',
type=int,
default=32,
help='The sequence number of a batch data. (default: %(default)d)')
parser.add_argument(
'--skip_batch_num',
type=int,
default=5,
help='The first num of minibatch num to skip, for better performance test'
)
parser.add_argument(
'--iterations', type=int, default=80, help='The number of minibatches.')
parser.add_argument(
'--emb_dim',
type=int,
default=512,
help='Dimension of embedding table. (default: %(default)d)')
parser.add_argument(
'--hidden_dim',
type=int,
default=512,
help='Hidden size of lstm unit. (default: %(default)d)')
parser.add_argument(
'--pass_num',
type=int,
default=100,
help='Epoch number to train. (default: %(default)d)')
parser.add_argument(
'--device',
type=str,
default='CPU',
choices=['CPU', 'GPU'],
help='The device type.')
parser.add_argument(
'--crop_size',
type=int,
default=int(os.environ.get('CROP_SIZE', '1500')),
help='The max sentence length of input. Since this model use plain RNN,'
' Gradient could be explored if sentence is too long')
parser.add_argument(
'--with_test',
action='store_true',
help='If set, test the testset during training.')
args = parser.parse_args()
return args
word_dict = imdb.word_dict()
......@@ -94,14 +43,15 @@ def crop_sentence(reader, crop_size):
return __impl__
def main():
args = parse_args()
lstm_size = args.hidden_dim
def get_model(args):
lstm_size = 512
emb_dim = 512
crop_size = 1500
data = fluid.layers.data(
name="words", shape=[1], lod_level=1, dtype='int64')
sentence = fluid.layers.embedding(
input=data, size=[len(word_dict), args.emb_dim])
input=data, size=[len(word_dict), emb_dim])
sentence = fluid.layers.fc(input=sentence, size=lstm_size, act='tanh')
......@@ -161,51 +111,17 @@ def main():
target_vars=[batch_acc, batch_size_tensor])
adam = fluid.optimizer.Adam()
adam.minimize(loss)
fluid.memory_optimize(fluid.default_main_program())
place = fluid.CPUPlace() if args.device == 'CPU' else fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
train_reader = batch(
paddle.reader.shuffle(
crop_sentence(imdb.train(word_dict), args.crop_size),
buf_size=25000),
crop_sentence(imdb.train(word_dict), crop_size), buf_size=25000),
batch_size=args.batch_size)
test_reader = batch(
paddle.reader.shuffle(
crop_sentence(imdb.test(word_dict), crop_size), buf_size=25000),
batch_size=args.batch_size)
iters, num_samples, start_time = 0, 0, time.time()
for pass_id in range(args.pass_num):
train_accs = []
train_losses = []
for batch_id, data in enumerate(train_reader()):
if iters == args.skip_batch_num:
start_time = time.time()
num_samples = 0
if iters == args.iterations:
break
tensor_words = to_lodtensor([x[0] for x in data], place)
label = numpy.array([x[1] for x in data]).astype("int64")
label = label.reshape((-1, 1))
loss_np, acc, weight = exe.run(
fluid.default_main_program(),
feed={"words": tensor_words,
"label": label},
fetch_list=[loss, batch_acc, batch_size_tensor])
iters += 1
for x in data:
num_samples += len(x[0])
print(
"Pass = %d, Iter = %d, Loss = %f, Accuracy = %f" %
(pass_id, iters, loss_np, acc)
) # The accuracy is the accumulation of batches, but not the current batch.
train_elapsed = time.time() - start_time
examples_per_sec = num_samples / train_elapsed
print('\nTotal examples: %d, total time: %.5f, %.5f examples/sed\n' %
(num_samples, train_elapsed, examples_per_sec))
exit(0)
return loss, inference_program, adam, train_reader, test_reader, batch_acc
def to_lodtensor(data, place):
......@@ -221,16 +137,3 @@ def to_lodtensor(data, place):
res.set(flattened_data, place)
res.set_lod([lod])
return res
def print_arguments(args):
print('----------- lstm Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
if __name__ == '__main__':
args = parse_args()
print_arguments(args)
main()
......@@ -23,46 +23,6 @@ import paddle.fluid.core as core
import argparse
import functools
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--batch_size', type=int, default=128, help="Batch size for training.")
parser.add_argument(
'--skip_batch_num',
type=int,
default=5,
help='The first num of minibatch num to skip, for better performance test')
parser.add_argument(
'--iterations', type=int, default=80, help='The number of minibatches.')
parser.add_argument(
'--learning_rate',
type=float,
default=1e-3,
help="Learning rate for training.")
parser.add_argument('--pass_num', type=int, default=50, help="No. of passes.")
parser.add_argument(
'--device',
type=str,
default='GPU',
choices=['CPU', 'GPU'],
help="The device type.")
parser.add_argument(
'--data_format',
type=str,
default='NCHW',
choices=['NCHW', 'NHWC'],
help='The data order, now only support NCHW.')
parser.add_argument(
'--data_set',
type=str,
default='cifar10',
choices=['cifar10', 'flowers'],
help='Optional dataset for benchmark.')
parser.add_argument(
'--with_test',
action='store_true',
help='If set, test the testset during training.')
args = parser.parse_args()
def vgg16_bn_drop(input):
def conv_block(input, num_filter, groups, dropouts):
......@@ -91,7 +51,7 @@ def vgg16_bn_drop(input):
return fc2
def main():
def get_model(args):
if args.data_set == "cifar10":
classdim = 10
if args.data_format == 'NCHW':
......@@ -128,16 +88,6 @@ def main():
# Optimization
optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate)
opts = optimizer.minimize(avg_cost)
fluid.memory_optimize(fluid.default_main_program())
# Initialize executor
place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(0)
exe = fluid.Executor(place)
# Parameter initialization
exe.run(fluid.default_startup_program())
# data reader
train_reader = paddle.batch(
......@@ -151,78 +101,4 @@ def main():
if args.data_set == 'cifar10' else paddle.dataset.flowers.test(),
batch_size=args.batch_size)
# test
def test(exe):
test_accuracy = fluid.average.WeightedAverage()
for batch_id, data in enumerate(test_reader()):
img_data = np.array(map(lambda x: x[0].reshape(data_shape),
data)).astype("float32")
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
y_data = y_data.reshape([-1, 1])
acc, weight = exe.run(inference_program,
feed={"pixel": img_data,
"label": y_data},
fetch_list=[batch_acc, batch_size_tensor])
test_accuracy.add(value=acc, weight=weight)
return test_accuracy.eval()
iters, num_samples, start_time = 0, 0, time.time()
accuracy = fluid.average.WeightedAverage()
train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)
for pass_id in range(args.pass_num):
accuracy.reset()
train_accs = []
train_losses = []
for batch_id, data in enumerate(train_reader()):
if iters == args.skip_batch_num:
start_time = time.time()
num_samples = 0
if iters == args.iterations:
break
img_data = np.array(map(lambda x: x[0].reshape(data_shape),
data)).astype("float32")
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
y_data = y_data.reshape([-1, 1])
loss, acc, weight = train_exe.run(
feed={"pixel": img_data,
"label": y_data},
fetch_list=[
avg_cost.name, batch_acc.name, batch_size_tensor.name
])
accuracy.add(value=np.array(np.mean(acc)), weight=np.mean(weight))
iters += 1
num_samples += len(y_data)
loss = np.mean(np.array(loss))
acc = np.mean(np.array(acc))
print(
"Pass = %d, Iter = %d, Loss = %f, Accuracy = %f" %
(pass_id, iters, loss, acc)
) # The accuracy is the accumulation of batches, but not the current batch.
# pass_train_acc = accuracy.eval()
train_losses.append(loss)
train_accs.append(acc)
print("Pass: %d, Loss: %f, Train Accuray: %f\n" %
(pass_id, np.mean(train_losses), np.mean(train_accs)))
train_elapsed = time.time() - start_time
examples_per_sec = num_samples / train_elapsed
print('\nTotal examples: %d, total time: %.5f, %.5f examples/sed\n' %
(num_samples, train_elapsed, examples_per_sec))
# evaluation
if args.with_test:
pass_test_acc = test(exe)
exit(0)
def print_arguments():
print('----------- vgg Configuration Arguments -----------')
for arg, value in sorted(vars(args).iteritems()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
if __name__ == "__main__":
print_arguments()
main()
return avg_cost, inference_program, optimizer, train_reader, test_reader, batch_acc
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册