Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ef35c4ed
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ef35c4ed
编写于
2月 27, 2018
作者:
G
gongweibao
提交者:
GitHub
2月 27, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Tensorflow benchmark (#8522)
Tensorflow benchmark
上级
1ac31d3d
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
609 addition
and
21 deletion
+609
-21
benchmark/cluster/vgg16/Dockerfile
benchmark/cluster/vgg16/Dockerfile
+26
-9
benchmark/cluster/vgg16/fluid_trainer.yaml
benchmark/cluster/vgg16/fluid_trainer.yaml
+1
-1
benchmark/cluster/vgg16/tf_k8s
benchmark/cluster/vgg16/tf_k8s
+82
-0
benchmark/cluster/vgg16/tf_pserver.yaml
benchmark/cluster/vgg16/tf_pserver.yaml
+56
-0
benchmark/cluster/vgg16/tf_trainer.yaml
benchmark/cluster/vgg16/tf_trainer.yaml
+58
-0
benchmark/cluster/vgg16/vgg16_fluid.py
benchmark/cluster/vgg16/vgg16_fluid.py
+24
-11
benchmark/cluster/vgg16/vgg16_tf.py
benchmark/cluster/vgg16/vgg16_tf.py
+362
-0
未找到文件。
benchmark/cluster/vgg16/Dockerfile
浏览文件 @
ef35c4ed
#FROM python:2.7.14
FROM
nvidia/cuda:8.0-cudnn5-runtime-ubuntu16.04
FROM
nvidia/cuda:8.0-cudnn5-runtime-ubuntu16.04
RUN
apt-get update
&&
apt-get
install
-y
python
RUN
pip
install
-U
kubernetes opencv-python
&&
apt-get update
-y
&&
apt-get
install
-y
iputils-ping libgtk2.0-dev
# you can get mirror list here:
# NOTE: By default CI built wheel packages turn WITH_DISTRIBUTE=OFF,
# https://launchpad.net/ubuntu/+archivemirrors
# so we must build one with distribute support to install in this image.
ARG
UBUNTU_MIRROR
RUN
/bin/bash
-c
'if [[ -n ${UBUNTU_MIRROR} ]]; then sed -i '
s#http://archive.ubuntu.com/ubuntu#
${
UBUNTU_MIRROR
}
#g' /etc/apt/sources.list; fi'
RUN
apt-get update
&&
apt-get
install
-y
python python-dev python-pip iputils-ping libgtk2.0-dev
RUN
pip
install
-U
kubernetes opencv-python
RUN
pip
install
paddlepaddle
RUN
pip
install
paddlepaddle
# if network is slowly, you may need to add proxy here.
# ENV https_proxy=
RUN
sh
-c
'echo "import paddle.v2 as paddle\npaddle.dataset.cifar.train10()" | python'
RUN
sh
-c
'echo "import paddle.v2 as paddle\npaddle.dataset.cifar.train10()" | python'
RUN
pip uninstall
-y
paddlepaddle
RUN
pip uninstall
-y
paddlepaddle
# unset proxy if it is setted.
# ENV https_proxy=""
# NOTE: By default CI built wheel packages turn WITH_DISTRIBUTE=OFF,
# so we must build one with distribute support to install in this image.
ADD
*.whl /
RUN
pip
install
/
*
.whl
&&
rm
-f
/
*
.whl
ENV
LD_LIBRARY_PATH=/usr/local/lib
# tf k8s
RUN
pip
install
tensorflow
==
1.4.0
ADD
tf_k8s /usr/bin
RUN
chmod
+x /usr/bin/tf_k8s
ADD
vgg16_tf.py /workspace/
# below lines may change a lot for debugging
# below lines may change a lot for debugging
ADD
https://raw.githubusercontent.com/PaddlePaddle/cloud/develop/docker/paddle_k8s /usr/bin
ADD
https://raw.githubusercontent.com/PaddlePaddle/cloud/develop/docker/paddle_k8s /usr/bin
ADD
https://raw.githubusercontent.com/PaddlePaddle/cloud/develop/docker/k8s_tools.py /root
ADD
https://raw.githubusercontent.com/PaddlePaddle/cloud/develop/docker/k8s_tools.py /root
ADD
*.whl /
RUN
chmod
+x /usr/bin/paddle_k8s
RUN
pip
install
/
*
.whl
&&
rm
-f
/
*
.whl
&&
\
chmod
+x /usr/bin/paddle_k8s
ENV
LD_LIBRARY_PATH=/usr/local/lib
ADD
vgg16_fluid.py vgg16_v2.py /workspace/
ADD
vgg16_fluid.py vgg16_v2.py /workspace/
benchmark/cluster/vgg16/fluid_trainer.yaml
浏览文件 @
ef35c4ed
...
@@ -11,7 +11,7 @@ spec:
...
@@ -11,7 +11,7 @@ spec:
paddle-job
:
vgg16job
paddle-job
:
vgg16job
spec
:
spec
:
imagePullSecrets
:
imagePullSecrets
:
-
name
:
job-registry-secret
-
name
:
job-registry-secret
hostNetwork
:
true
hostNetwork
:
true
containers
:
containers
:
-
name
:
trainer
-
name
:
trainer
...
...
benchmark/cluster/vgg16/tf_k8s
0 → 100644
浏览文件 @
ef35c4ed
#!/bin/bash
check_trainer_ret
()
{
ret
=
$1
stdbuf
-oL
echo
"job returned
$ret
...setting pod return message..."
stdbuf
-oL
echo
"==============================="
if
[
$ret
-eq
136
]
;
then
echo
"Error Arithmetic Operation(Floating Point Exception)"
>
/dev/termination-log
elif
[
$ret
-eq
139
]
;
then
echo
"Segmentation Fault"
>
/dev/termination-log
elif
[
$ret
-eq
1
]
;
then
echo
"General Error"
>
/dev/termination-log
elif
[
$ret
-eq
134
]
;
then
echo
"Program Abort"
>
/dev/termination-log
fi
stdbuf
-oL
echo
"termination log wroted..."
exit
$ret
}
g_pservers
=
""
g_trainers
=
""
wait_running_pods
(){
pserver_label
=
"tf-job-pserver=
${
JOB_NAME
}
"
trainer_label
=
"tf-job-trainer=
${
JOB_NAME
}
"
stdbuf
-oL
python /root/k8s_tools.py wait_pods_running
${
pserver_label
}
${
PSERVERS_NUM
}
stdbuf
-oL
python /root/k8s_tools.py wait_pods_running
${
trainer_label
}
${
TRAINERS_NUM
}
g_pservers
=
$(
python /root/k8s_tools.py fetch_endpoints
${
pserver_label
}
${
PORT
}
)
g_trainers
=
$(
python /root/k8s_tools.py fetch_endpoints
${
trainer_label
}
${
PORT
}
)
}
start_tf_pserver
(){
wait_running_pods
label
=
"tf-job-pserver=
${
JOB_NAME
}
"
pserver_id
=
$(
python /root/k8s_tools.py fetch_id
${
label
}
)
cmd
=
"
${
ENTRY
}
--ps_hosts=
${
g_pservers
}
--worker_hosts=
${
g_trainers
}
\
--job_name=
${
TF_JOB_NAME
}
--task_index=
${
pserver_id
}
"
stdbuf
-oL
sh
-c
"cd
${
TRAINER_PACKAGE
}
&&
${
cmd
}
"
}
start_tf_trainer
(){
wait_running_pods
label
=
"tf-job-trainer=
${
JOB_NAME
}
"
trainer_id
=
$(
python /root/k8s_tools.py fetch_id
${
label
}
)
cmd
=
"
${
ENTRY
}
--ps_hosts=
${
g_pservers
}
--worker_hosts=
${
g_trainers
}
\
--job_name=
${
TF_JOB_NAME
}
--task_index=
${
trainer_id
}
--batch_size=
${
BATCH_SIZE
}
"
stdbuf
-oL
sh
-c
"cd
${
TRAINER_PACKAGE
}
&&
${
cmd
}
"
check_trainer_ret
$?
}
start_tf
(){
if
[[
"
${
TF_JOB_NAME
}
"
==
"worker"
]]
;
then
start_tf_trainer
else
start_tf_pserver
fi
}
usage
()
{
echo
"usage: tf_k8s [<args>]:"
echo
" start_tf Start tensorflow jobs"
}
case
"
$1
"
in
start_tf
)
start_tf
;;
--help
)
usage
;;
*
)
usage
;;
esac
benchmark/cluster/vgg16/tf_pserver.yaml
0 → 100644
浏览文件 @
ef35c4ed
apiVersion
:
extensions/v1beta1
kind
:
ReplicaSet
metadata
:
name
:
vgg16job-tf-pserver
spec
:
replicas
:
10
template
:
metadata
:
labels
:
tf-job-pserver
:
vgg16job-tf
spec
:
hostNetwork
:
true
imagePullSecrets
:
-
name
:
job-registry-secret
containers
:
-
name
:
pserver
image
:
"
registry.baidu.com/paddlepaddle/fluid_benchmark_tf:vgg16"
imagePullPolicy
:
Always
command
:
[
"
tf_k8s"
,
"
start_tf"
]
ports
:
-
name
:
jobport-30236
containerPort
:
30236
env
:
-
name
:
PORT
value
:
"
32036"
-
name
:
ENTRY
value
:
"
python
vgg16_tf.py"
-
name
:
JOB_NAME
value
:
vgg16job-tf
-
name
:
PSERVERS_NUM
value
:
"
10"
-
name
:
TF_JOB_NAME
value
:
"
ps"
-
name
:
TRAINERS_NUM
value
:
"
20"
-
name
:
BATCH_SIZE
value
:
"
128"
-
name
:
TRAINER_PACKAGE
value
:
"
/workspace"
-
name
:
NUM_PASSES
value
:
"
1"
-
name
:
NAMESPACE
valueFrom
:
fieldRef
:
fieldPath
:
"
metadata.namespace"
-
name
:
POD_IP
valueFrom
:
fieldRef
:
fieldPath
:
"
status.podIP"
resources
:
requests
:
memory
:
10Gi
cpu
:
4
limits
:
memory
:
10Gi
cpu
:
4
benchmark/cluster/vgg16/tf_trainer.yaml
0 → 100644
浏览文件 @
ef35c4ed
apiVersion
:
batch/v1
kind
:
Job
metadata
:
name
:
vgg16job-tf-trainer
spec
:
parallelism
:
20
completions
:
20
template
:
metadata
:
labels
:
tf-job-trainer
:
vgg16job-tf
spec
:
imagePullSecrets
:
-
name
:
job-registry-secret
hostNetwork
:
true
containers
:
-
name
:
trainer
image
:
"
registry.baidu.com/paddlepaddle/fluid_benchmark_tf:vgg16"
imagePullPolicy
:
Always
command
:
[
"
tf_k8s"
,
"
start_tf"
]
ports
:
-
name
:
jobport-30236
containerPort
:
30236
env
:
-
name
:
PORT
value
:
"
32036"
-
name
:
JOB_NAME
value
:
vgg16job-tf
-
name
:
TF_JOB_NAME
value
:
"
worker"
-
name
:
ENTRY
value
:
"
python
vgg16_tf.py"
-
name
:
PSERVERS_NUM
value
:
"
10"
-
name
:
BATCH_SIZE
value
:
"
128"
-
name
:
TRAINERS_NUM
value
:
"
20"
-
name
:
TRAINER_PACKAGE
value
:
"
/workspace"
-
name
:
NUM_PASSES
value
:
"
1"
-
name
:
NAMESPACE
valueFrom
:
fieldRef
:
fieldPath
:
"
metadata.namespace"
-
name
:
POD_IP
valueFrom
:
fieldRef
:
fieldPath
:
"
status.podIP"
resources
:
requests
:
memory
:
40Gi
cpu
:
2
limits
:
memory
:
40Gi
cpu
:
2
restartPolicy
:
Never
benchmark/cluster/vgg16/vgg16_fluid.py
浏览文件 @
ef35c4ed
...
@@ -68,6 +68,21 @@ parser.add_argument(
...
@@ -68,6 +68,21 @@ parser.add_argument(
type
=
str2bool
,
type
=
str2bool
,
default
=
True
,
default
=
True
,
help
=
'Whether to run as local mode.'
)
help
=
'Whether to run as local mode.'
)
parser
.
add_argument
(
"--ps_hosts"
,
type
=
str
,
default
=
""
,
help
=
"Comma-separated list of hostname:port pairs"
)
parser
.
add_argument
(
"--trainer_hosts"
,
type
=
str
,
default
=
""
,
help
=
"Comma-separated list of hostname:port pairs"
)
# Flags for defining the tf.train.Server
parser
.
add_argument
(
"--task_index"
,
type
=
int
,
default
=
0
,
help
=
"Index of task within the job"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
@@ -180,8 +195,9 @@ def main():
...
@@ -180,8 +195,9 @@ def main():
iters
+=
1
iters
+=
1
num_samples
+=
len
(
data
)
num_samples
+=
len
(
data
)
print
(
print
(
"Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, spent %f"
"Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, Speed = %.2f img/s"
%
(
pass_id
,
iters
,
loss
,
acc
,
time
.
time
()
-
ts
)
%
(
pass_id
,
iters
,
loss
,
acc
,
len
(
data
)
/
(
time
.
time
()
-
ts
))
)
# The accuracy is the accumulation of batches, but not the current batch.
)
# The accuracy is the accumulation of batches, but not the current batch.
pass_elapsed
=
time
.
time
()
-
start_time
pass_elapsed
=
time
.
time
()
-
start_time
...
@@ -209,27 +225,24 @@ def main():
...
@@ -209,27 +225,24 @@ def main():
batch_size
=
args
.
batch_size
)
batch_size
=
args
.
batch_size
)
train_loop
(
exe
,
fluid
.
default_main_program
())
train_loop
(
exe
,
fluid
.
default_main_program
())
else
:
else
:
pserver_ips
=
os
.
getenv
(
"PADDLE_INIT_PSERVERS"
)
# all pserver endpoints
eplist
=
[]
for
ip
in
pserver_ips
.
split
(
","
):
eplist
.
append
(
':'
.
join
([
ip
,
"6174"
]))
pserver_endpoints
=
","
.
join
(
eplist
)
print
(
"pserver endpoints: "
,
pserver_endpoints
)
trainers
=
int
(
os
.
getenv
(
"TRAINERS"
))
# total trainer count
trainers
=
int
(
os
.
getenv
(
"TRAINERS"
))
# total trainer count
print
(
"trainers total: "
,
trainers
)
print
(
"trainers total: "
,
trainers
)
current_endpoint
=
os
.
getenv
(
"POD_IP"
)
+
":6174"
# current pserver endpoint
training_role
=
os
.
getenv
(
training_role
=
os
.
getenv
(
"TRAINING_ROLE"
,
"TRAINING_ROLE"
,
"TRAINER"
)
# get the training role: trainer/pserver
"TRAINER"
)
# get the training role: trainer/pserver
t
=
fluid
.
DistributeTranspiler
()
t
=
fluid
.
DistributeTranspiler
()
t
.
transpile
(
t
.
transpile
(
optimize_ops
,
optimize_ops
,
params_grads
,
params_grads
,
pservers
=
pserver_endpoints
,
trainer_id
=
args
.
task_index
,
pservers
=
args
.
ps_hosts
,
trainers
=
trainers
)
trainers
=
trainers
)
if
training_role
==
"PSERVER"
:
if
training_role
==
"PSERVER"
:
current_endpoint
=
os
.
getenv
(
"POD_IP"
)
+
":"
+
os
.
getenv
(
"PADDLE_INIT_PORT"
)
if
not
current_endpoint
:
if
not
current_endpoint
:
print
(
"need env SERVER_ENDPOINT"
)
print
(
"need env SERVER_ENDPOINT"
)
exit
(
1
)
exit
(
1
)
...
...
benchmark/cluster/vgg16/vgg16_tf.py
0 → 100644
浏览文件 @
ef35c4ed
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VGG16 benchmark in TensorFlow
You can get distribution example template structure here:
https://medium.com/clusterone/how-to-write-distributed-tensorflow-code-with-an-example-on-tensorport-70bf3306adcb
https://www.tensorflow.org/deploy/distributed
"""
import
tensorflow
as
tf
import
paddle.v2
as
paddle
import
numpy
as
np
import
argparse
import
time
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
128
,
help
=
"Batch size for training."
)
parser
.
add_argument
(
'--learning_rate'
,
type
=
float
,
default
=
1e-3
,
help
=
"Learning rate for training."
)
parser
.
add_argument
(
'--num_passes'
,
type
=
int
,
default
=
50
,
help
=
"No. of passes."
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
'CPU'
,
choices
=
[
'CPU'
,
'GPU'
],
help
=
"The device type."
)
parser
.
add_argument
(
'--data_format'
,
type
=
str
,
default
=
'NHWC'
,
choices
=
[
'NCHW'
,
'NHWC'
],
help
=
'The data order, NCHW=[batch, channels, height, width].'
'Only support NHWC right now.'
)
parser
.
add_argument
(
'--data_set'
,
type
=
str
,
default
=
'cifar10'
,
choices
=
[
'cifar10'
,
'flowers'
],
help
=
'Optional dataset for benchmark.'
)
parser
.
add_argument
(
"--ps_hosts"
,
type
=
str
,
default
=
""
,
help
=
"Comma-separated list of hostname:port pairs"
)
parser
.
add_argument
(
"--worker_hosts"
,
type
=
str
,
default
=
""
,
help
=
"Comma-separated list of hostname:port pairs"
)
parser
.
add_argument
(
"--job_name"
,
type
=
str
,
default
=
""
,
help
=
"One of 'worker', 'ps'"
)
# Flags for defining the tf.train.Server
parser
.
add_argument
(
"--task_index"
,
type
=
int
,
default
=
0
,
help
=
"Index of task within the job"
)
args
=
parser
.
parse_args
()
class
VGG16Model
(
object
):
def
__init__
(
self
):
self
.
parameters
=
[]
def
batch_norm_relu
(
self
,
inputs
,
is_training
):
"""Performs a batch normalization followed by a ReLU."""
# We set fused=True for a significant speed boost. See
# https://www.tensorflow.org/speed/speed_guide#common_fused_ops
inputs
=
tf
.
layers
.
batch_normalization
(
inputs
=
inputs
,
axis
=
1
if
args
.
data_format
==
'NCHW'
else
-
1
,
momentum
=
0.9
,
epsilon
=
1e-05
,
center
=
True
,
scale
=
True
,
training
=
is_training
,
fused
=
True
)
inputs
=
tf
.
nn
.
relu
(
inputs
)
return
inputs
def
conv_bn_layer
(
self
,
name
,
images
,
kernel_shape
,
is_training
,
drop_rate
=
0.0
):
with
tf
.
name_scope
(
name
)
as
scope
:
kernel
=
tf
.
Variable
(
tf
.
truncated_normal
(
kernel_shape
,
dtype
=
tf
.
float32
,
stddev
=
1e-1
),
name
=
'weights'
)
conv
=
tf
.
nn
.
conv2d
(
images
,
kernel
,
[
1
,
1
,
1
,
1
],
data_format
=
args
.
data_format
,
padding
=
'SAME'
)
biases
=
tf
.
Variable
(
tf
.
constant
(
0.0
,
shape
=
[
kernel_shape
[
-
1
]],
dtype
=
tf
.
float32
),
trainable
=
True
,
name
=
'biases'
)
out
=
tf
.
nn
.
bias_add
(
conv
,
biases
)
out
=
self
.
batch_norm_relu
(
out
,
is_training
)
out
=
tf
.
layers
.
dropout
(
out
,
rate
=
drop_rate
,
training
=
is_training
)
return
out
def
fc_layer
(
self
,
name
,
inputs
,
shape
):
with
tf
.
name_scope
(
name
)
as
scope
:
fc_w
=
tf
.
Variable
(
tf
.
truncated_normal
(
shape
,
dtype
=
tf
.
float32
,
stddev
=
1e-1
),
name
=
'weights'
)
fc_b
=
tf
.
Variable
(
tf
.
constant
(
0.0
,
shape
=
[
shape
[
-
1
]],
dtype
=
tf
.
float32
),
trainable
=
True
,
name
=
'biases'
)
out
=
tf
.
nn
.
bias_add
(
tf
.
matmul
(
inputs
,
fc_w
),
fc_b
)
return
out
def
network
(
self
,
images
,
class_dim
,
is_training
):
""" VGG16 model structure.
TODO(kuke): enable this network to support the 'NCHW' data format
"""
# conv1
conv1_1
=
self
.
conv_bn_layer
(
'conv1_1'
,
images
,
[
3
,
3
,
3
,
64
],
is_training
,
drop_rate
=
0.3
)
conv1_2
=
self
.
conv_bn_layer
(
'conv1_2'
,
conv1_1
,
[
3
,
3
,
64
,
64
],
is_training
,
drop_rate
=
0.0
)
# pool1
pool1
=
tf
.
nn
.
max_pool
(
conv1_2
,
ksize
=
[
1
,
2
,
2
,
1
],
strides
=
[
1
,
2
,
2
,
1
],
padding
=
'SAME'
,
name
=
'pool1'
)
# conv2
conv2_1
=
self
.
conv_bn_layer
(
'conv2_1'
,
pool1
,
[
3
,
3
,
64
,
128
],
is_training
,
drop_rate
=
0.4
)
conv2_2
=
self
.
conv_bn_layer
(
'conv2_2'
,
conv2_1
,
[
3
,
3
,
128
,
128
],
is_training
,
drop_rate
=
0.0
)
# pool2
pool2
=
tf
.
nn
.
max_pool
(
conv2_2
,
ksize
=
[
1
,
2
,
2
,
1
],
strides
=
[
1
,
2
,
2
,
1
],
padding
=
'SAME'
,
name
=
'pool2'
)
# conv3
conv3_1
=
self
.
conv_bn_layer
(
'conv3_1'
,
pool2
,
[
3
,
3
,
128
,
256
],
is_training
,
drop_rate
=
0.4
)
conv3_2
=
self
.
conv_bn_layer
(
'conv3_2'
,
conv3_1
,
[
3
,
3
,
256
,
256
],
is_training
,
drop_rate
=
0.4
)
conv3_3
=
self
.
conv_bn_layer
(
'conv3_3'
,
conv3_2
,
[
3
,
3
,
256
,
256
],
is_training
,
drop_rate
=
0.0
)
# pool3
pool3
=
tf
.
nn
.
max_pool
(
conv3_3
,
ksize
=
[
1
,
2
,
2
,
1
],
strides
=
[
1
,
2
,
2
,
1
],
padding
=
'SAME'
,
name
=
'pool3'
)
# conv4
conv4_1
=
self
.
conv_bn_layer
(
'conv4_1'
,
pool3
,
[
3
,
3
,
256
,
512
],
is_training
,
drop_rate
=
0.4
)
conv4_2
=
self
.
conv_bn_layer
(
'conv4_2'
,
conv4_1
,
[
3
,
3
,
512
,
512
],
is_training
,
drop_rate
=
0.4
)
conv4_3
=
self
.
conv_bn_layer
(
'conv4_3'
,
conv4_2
,
[
3
,
3
,
512
,
512
],
is_training
,
drop_rate
=
0.0
)
# pool4
pool4
=
tf
.
nn
.
max_pool
(
conv4_3
,
ksize
=
[
1
,
2
,
2
,
1
],
strides
=
[
1
,
2
,
2
,
1
],
padding
=
'SAME'
,
name
=
'pool4'
)
# conv5
conv5_1
=
self
.
conv_bn_layer
(
'conv5_1'
,
pool4
,
[
3
,
3
,
512
,
512
],
is_training
,
drop_rate
=
0.4
)
conv5_2
=
self
.
conv_bn_layer
(
'conv5_2'
,
conv5_1
,
[
3
,
3
,
512
,
512
],
is_training
,
drop_rate
=
0.4
)
conv5_3
=
self
.
conv_bn_layer
(
'conv5_3'
,
conv5_2
,
[
3
,
3
,
512
,
512
],
is_training
,
drop_rate
=
0.0
)
# pool5
pool5
=
tf
.
nn
.
max_pool
(
conv5_3
,
ksize
=
[
1
,
2
,
2
,
1
],
strides
=
[
1
,
2
,
2
,
1
],
padding
=
'SAME'
,
name
=
'pool4'
)
# flatten
shape
=
int
(
np
.
prod
(
pool5
.
get_shape
()[
1
:]))
pool5_flat
=
tf
.
reshape
(
pool5
,
[
-
1
,
shape
])
# fc1
drop
=
tf
.
layers
.
dropout
(
pool5_flat
,
rate
=
0.5
,
training
=
is_training
)
fc1
=
self
.
fc_layer
(
'fc1'
,
drop
,
[
shape
,
512
])
# fc2
bn
=
self
.
batch_norm_relu
(
fc1
,
is_training
)
drop
=
tf
.
layers
.
dropout
(
bn
,
rate
=
0.5
,
training
=
is_training
)
fc2
=
self
.
fc_layer
(
'fc2'
,
drop
,
[
512
,
512
])
fc3
=
self
.
fc_layer
(
'fc3'
,
fc2
,
[
512
,
class_dim
])
return
fc3
def
run_benchmark
(
cluster_spec
,
server
):
"""Run benchmark on cifar10 or flowers."""
if
args
.
data_set
==
"cifar10"
:
class_dim
=
10
raw_shape
=
(
3
,
32
,
32
)
dat_shape
=
(
None
,
32
,
32
,
3
)
if
args
.
data_format
==
'NHWC'
else
(
None
,
3
,
32
,
32
)
else
:
class_dim
=
102
raw_shape
=
(
3
,
224
,
224
)
dat_shape
=
(
None
,
224
,
224
,
3
)
if
args
.
data_format
==
'NHWC'
else
(
None
,
3
,
224
,
224
)
device
=
tf
.
train
.
replica_device_setter
(
worker_device
=
"/job:worker/task:{}"
.
format
(
args
.
task_index
),
cluster
=
cluster_spec
)
with
tf
.
device
(
device
):
images
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
dat_shape
)
labels
=
tf
.
placeholder
(
tf
.
int64
,
shape
=
(
None
,
))
is_training
=
tf
.
placeholder
(
'bool'
)
onehot_labels
=
tf
.
one_hot
(
labels
,
depth
=
class_dim
)
vgg16
=
VGG16Model
()
logits
=
vgg16
.
network
(
images
,
class_dim
,
is_training
)
loss
=
tf
.
losses
.
softmax_cross_entropy
(
onehot_labels
=
onehot_labels
,
logits
=
logits
)
avg_loss
=
tf
.
reduce_mean
(
loss
)
correct
=
tf
.
equal
(
tf
.
argmax
(
logits
,
1
),
labels
)
accuracy
=
tf
.
reduce_mean
(
tf
.
cast
(
correct
,
tf
.
float32
))
optimizer
=
tf
.
train
.
AdamOptimizer
(
learning_rate
=
args
.
learning_rate
)
update_ops
=
tf
.
get_collection
(
tf
.
GraphKeys
.
UPDATE_OPS
)
global_step
=
tf
.
Variable
(
0
,
name
=
'global_step'
,
trainable
=
False
)
with
tf
.
control_dependencies
(
update_ops
):
train_op
=
optimizer
.
minimize
(
avg_loss
,
global_step
=
global_step
)
summary_op
=
tf
.
summary
.
merge_all
()
init_op
=
tf
.
global_variables_initializer
()
# data reader
train_reader
=
paddle
.
batch
(
paddle
.
reader
.
shuffle
(
paddle
.
dataset
.
cifar
.
train10
()
if
args
.
data_set
==
'cifar10'
else
paddle
.
dataset
.
flowers
.
train
(),
buf_size
=
5120
),
batch_size
=
args
.
batch_size
)
test_reader
=
paddle
.
batch
(
paddle
.
reader
.
shuffle
(
paddle
.
dataset
.
cifar
.
test10
()
if
args
.
data_set
==
'cifar10'
else
paddle
.
dataset
.
flowers
.
test
(),
buf_size
=
5120
),
batch_size
=
args
.
batch_size
)
# test
def
test
():
test_accs
=
[]
for
batch_id
,
data
in
enumerate
(
test_reader
()):
test_images
=
np
.
array
(
map
(
lambda
x
:
np
.
transpose
(
x
[
0
].
reshape
(
raw_shape
),
axes
=
[
1
,
2
,
0
])
if
args
.
data_format
==
'NHWC'
else
x
[
0
],
data
)).
astype
(
"float32"
)
test_labels
=
np
.
array
(
map
(
lambda
x
:
x
[
1
],
data
)).
astype
(
'int64'
)
test_accs
.
append
(
accuracy
.
eval
(
feed_dict
=
{
images
:
test_images
,
labels
:
test_labels
,
is_training
:
False
}))
return
np
.
mean
(
test_accs
)
config
=
tf
.
ConfigProto
(
intra_op_parallelism_threads
=
1
,
inter_op_parallelism_threads
=
1
)
config
.
gpu_options
.
allow_growth
=
True
hooks
=
[
tf
.
train
.
StopAtStepHook
(
last_step
=
1000000
)]
with
tf
.
train
.
MonitoredTrainingSession
(
master
=
server
.
target
,
is_chief
=
(
args
.
task_index
==
0
),
hooks
=
hooks
)
as
sess
:
iters
,
num_samples
,
start_time
=
0
,
0
,
0.0
for
pass_id
in
range
(
args
.
num_passes
):
# train
num_samples
=
0
start_time
=
time
.
time
()
for
batch_id
,
data
in
enumerate
(
train_reader
()):
train_images
=
np
.
array
(
map
(
lambda
x
:
np
.
transpose
(
x
[
0
].
reshape
(
raw_shape
),
axes
=
[
1
,
2
,
0
])
if
args
.
data_format
==
'NHWC'
else
x
[
0
],
data
)).
astype
(
"float32"
)
train_labels
=
np
.
array
(
map
(
lambda
x
:
x
[
1
],
data
)).
astype
(
'int64'
)
iter_begin_time
=
time
.
time
()
_
,
loss
,
acc
=
sess
.
run
([
train_op
,
avg_loss
,
accuracy
],
feed_dict
=
{
images
:
train_images
,
labels
:
train_labels
,
is_training
:
True
})
iters
+=
1
print
(
"Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, Speed=%.2f imgs/sec"
%
(
pass_id
,
iters
,
loss
,
acc
,
len
(
data
)
/
(
time
.
time
()
-
iter_begin_time
)))
num_samples
+=
len
(
data
)
train_elapsed
=
time
.
time
()
-
start_time
# test
pass_test_acc
=
test
()
print
(
"Pass = %d, Train speed = %f imgs/s, Test accuracy = %f
\n
"
%
(
pass_id
,
num_samples
/
train_elapsed
,
pass_test_acc
))
def
print_arguments
():
print
(
'----------- Configuration Arguments -----------'
)
for
arg
,
value
in
sorted
(
vars
(
args
).
iteritems
()):
print
(
'%s: %s'
%
(
arg
,
value
))
print
(
'------------------------------------------------'
)
if
__name__
==
'__main__'
:
print_arguments
()
ps_hosts
=
args
.
ps_hosts
.
split
(
","
)
worker_hosts
=
args
.
worker_hosts
.
split
(
","
)
# Create a cluster from the parameter server and worker hosts.
cluster_spec
=
tf
.
train
.
ClusterSpec
({
"ps"
:
ps_hosts
,
"worker"
:
worker_hosts
})
# Create and start a server for the local task.
server
=
tf
.
train
.
Server
(
cluster_spec
,
job_name
=
args
.
job_name
,
task_index
=
args
.
task_index
)
if
args
.
job_name
==
"ps"
:
print
(
"start pserver"
)
server
.
join
()
elif
args
.
job_name
==
"worker"
:
print
(
"start worker"
)
run_benchmark
(
cluster_spec
,
server
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录