PaddlePaddle / Paddle
Commit c0876cf6
Authored Feb 28, 2018 by Yibing Liu

    update due to upstream's change

Parents: ee88855d, c02f773a
60 changed files with 1898 additions and 452 deletions (+1898, -452)
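The substance of this merge is a CUPTI-based GPU tracer (device_tracer.*, cmake/cupti.cmake, profiler.proto) that lets the profiler attribute CUDA kernel time to operator annotations, plus a TensorFlow VGG16 cluster benchmark and cleanups to the Fluid book tests. For orientation, the sketch below shows how the profiling entry point is typically driven from Python; it is a minimal sketch assuming the 2018-era fluid API (fluid.layers.data/fc, profiler.profiler), not code taken from this diff.

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler

# Minimal sketch, assuming the 2018-era fluid API: build a tiny network and
# run a few iterations inside the profiler context, so the new CUPTI-based
# device tracer can attach kernel timings to operator annotations.
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)

data = fluid.layers.data(name='x', shape=[784], dtype='float32')
hidden = fluid.layers.fc(input=data, size=10, act='relu')
loss = fluid.layers.mean(x=hidden)
exe.run(fluid.default_startup_program())

# state='GPU' includes the device-side events collected through CUPTI.
with profiler.profiler('GPU', 'total') as prof:
    for _ in range(10):
        x = np.random.random((32, 784)).astype('float32')
        exe.run(fluid.default_main_program(), feed={'x': x}, fetch_list=[loss])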
Changed files:

CMakeLists.txt                                                          +1   -0
benchmark/cluster/vgg16/Dockerfile                                      +26  -9
benchmark/cluster/vgg16/fluid_trainer.yaml                              +1   -1
benchmark/cluster/vgg16/tf_k8s                                          +82  -0
benchmark/cluster/vgg16/tf_pserver.yaml                                 +56  -0
benchmark/cluster/vgg16/tf_trainer.yaml                                 +58  -0
benchmark/cluster/vgg16/vgg16_fluid.py                                  +24  -11
benchmark/cluster/vgg16/vgg16_tf.py                                     +362 -0
cmake/configure.cmake                                                   +9   -1
cmake/cuda.cmake                                                        +2   -1
cmake/cupti.cmake                                                       +41  -0
paddle/fluid/framework/CMakeLists.txt                                   +2   -2
paddle/fluid/framework/framework.proto                                  +2   -0
paddle/fluid/framework/lod_tensor.cc                                    +7   -1
paddle/fluid/framework/op_desc.h                                        +2   -0
paddle/fluid/inference/io.cc                                            +7   -20
paddle/fluid/inference/tests/book/CMakeLists.txt                        +1   -1
paddle/fluid/inference/tests/book/test_inference_label_semantic_roles.cc    +36 -10
paddle/fluid/inference/tests/book/test_inference_understand_sentiment.cc    +6  -1
paddle/fluid/inference/tests/book/test_inference_word2vec.cc            +5   -5
paddle/fluid/inference/tests/test_helper.h                              +2   -2
paddle/fluid/operators/CMakeLists.txt                                   +23  -28
paddle/fluid/operators/bipartite_match_op.cc                            +55  -2
paddle/fluid/platform/CMakeLists.txt                                    +4   -1
paddle/fluid/platform/device_tracer.cc                                  +285 -0
paddle/fluid/platform/device_tracer.h                                   +72  -0
paddle/fluid/platform/dynload/CMakeLists.txt                            +6   -2
paddle/fluid/platform/dynload/cupti.cc                                  +35  -0
paddle/fluid/platform/dynload/cupti.h                                   +86  -0
paddle/fluid/platform/dynload/dynamic_loader.cc                         +16  -0
paddle/fluid/platform/dynload/dynamic_loader.h                          +2   -0
paddle/fluid/platform/profiler.cc                                       +28  -4
paddle/fluid/platform/profiler.h                                        +7   -2
paddle/fluid/platform/profiler.proto                                    +30  -0
paddle/fluid/pybind/pybind.cc                                           +1   -0
python/paddle/fluid/framework.py                                        +1   -0
python/paddle/fluid/io.py                                               +57  -41
python/paddle/fluid/layers/detection.py                                 +16  -3
python/paddle/fluid/profiler.py                                         +8   -3
python/paddle/fluid/tests/book/notest_rnn_encoder_decoer.py             +28  -26
python/paddle/fluid/tests/book/test_fit_a_line.py                       +20  -17
python/paddle/fluid/tests/book/test_image_classification.py             +20  -16
python/paddle/fluid/tests/book/test_label_semantic_roles.py             +57  -47
python/paddle/fluid/tests/book/test_recognize_digits.py                 +43  -29
python/paddle/fluid/tests/book/test_recommender_system.py               +47  -45
python/paddle/fluid/tests/book/test_understand_sentiment.py             +40  -29
python/paddle/fluid/tests/book/test_word2vec.py                         +74  -61
python/paddle/fluid/tests/book_distribute/notest_dist_fit_a_line.py     +1   -2
python/paddle/fluid/tests/book_distribute/notest_dist_image_classification.py   +1 -0
python/paddle/fluid/tests/book_distribute/notest_dist_label_semantic_roles.py   +1 -0
python/paddle/fluid/tests/book_distribute/notest_dist_word2vec.py       +3   -1
python/paddle/fluid/tests/book_distribute/notest_machine_translation.py +1   -0
python/paddle/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py  +1 -5
python/paddle/fluid/tests/book_distribute/notest_recommender_system_dist.py     +1 -0
python/paddle/fluid/tests/book_distribute/notest_understand_sentiment_conv_dist.py   +1 -0
python/paddle/fluid/tests/book_distribute/notest_understand_sentiment_dynamic_lstm.py +1 -0
python/paddle/fluid/tests/unittests/CMakeLists.txt                      +2   -0
python/paddle/fluid/tests/unittests/test_bipartite_match_op.py          +41  -3
python/paddle/fluid/tests/unittests/test_nvprof.py                      +46  -0
python/paddle/fluid/tests/unittests/test_profiler.py                    +5   -20
CMakeLists.txt
@@ -146,6 +146,7 @@ include(external/cares)
 include(external/grpc)
 include(cudnn)      # set cudnn libraries, must before configure
+include(cupti)
 include(configure)  # add paddle env configuration
 include(generic)    # simplify cmake module
 include(package)    # set paddle packages
benchmark/cluster/vgg16/Dockerfile
-#FROM python:2.7.14
 FROM nvidia/cuda:8.0-cudnn5-runtime-ubuntu16.04
-RUN apt-get update && apt-get install -y python
-RUN pip install -U kubernetes opencv-python && apt-get update -y && apt-get install -y iputils-ping libgtk2.0-dev
-# NOTE: By default CI built wheel packages turn WITH_DISTRIBUTE=OFF,
-#       so we must build one with distribute support to install in this image.
+
+# you can get mirror list here:
+#   https://launchpad.net/ubuntu/+archivemirrors
+ARG UBUNTU_MIRROR
+RUN /bin/bash -c 'if [[ -n ${UBUNTU_MIRROR} ]]; then sed -i 's#http://archive.ubuntu.com/ubuntu#${UBUNTU_MIRROR}#g' /etc/apt/sources.list; fi'
+
+RUN apt-get update && apt-get install -y python python-dev python-pip iputils-ping libgtk2.0-dev
+RUN pip install -U kubernetes opencv-python
+
+RUN pip install paddlepaddle
+# if network is slowly, you may need to add proxy here.
+# ENV https_proxy=
+RUN sh -c 'echo "import paddle.v2 as paddle\npaddle.dataset.cifar.train10()" | python'
+RUN pip uninstall -y paddlepaddle
+# unset proxy if it is setted.
+# ENV https_proxy=""
+
+# NOTE: By default CI built wheel packages turn WITH_DISTRIBUTE=OFF,
+#       so we must build one with distribute support to install in this image.
 ADD *.whl /
-RUN pip install /*.whl && rm -f /*.whl && \
-    chmod +x /usr/bin/paddle_k8s
+RUN pip install /*.whl && rm -f /*.whl
 ENV LD_LIBRARY_PATH=/usr/local/lib
+
+# tf k8s
+RUN pip install tensorflow==1.4.0
+ADD tf_k8s /usr/bin
+RUN chmod +x /usr/bin/tf_k8s
+ADD vgg16_tf.py /workspace/
+
+# below lines may change a lot for debugging
 ADD https://raw.githubusercontent.com/PaddlePaddle/cloud/develop/docker/paddle_k8s /usr/bin
 ADD https://raw.githubusercontent.com/PaddlePaddle/cloud/develop/docker/k8s_tools.py /root
+RUN chmod +x /usr/bin/paddle_k8s
 ADD vgg16_fluid.py vgg16_v2.py /workspace/
benchmark/cluster/vgg16/fluid_trainer.yaml
@@ -11,7 +11,7 @@ spec:
         paddle-job: vgg16job
     spec:
       imagePullSecrets:
-        - name: job-registry-secret
+      - name: job-registry-secret
      hostNetwork: true
      containers:
      - name: trainer
benchmark/cluster/vgg16/tf_k8s  (new file, mode 100644)
#!/bin/bash
check_trainer_ret() {
  ret=$1
  stdbuf -oL echo "job returned $ret...setting pod return message..."
  stdbuf -oL echo "==============================="

  if [ $ret -eq 136 ] ; then
    echo "Error Arithmetic Operation(Floating Point Exception)" > /dev/termination-log
  elif [ $ret -eq 139 ] ; then
    echo "Segmentation Fault" > /dev/termination-log
  elif [ $ret -eq 1 ] ; then
    echo "General Error" > /dev/termination-log
  elif [ $ret -eq 134 ] ; then
    echo "Program Abort" > /dev/termination-log
  fi
  stdbuf -oL echo "termination log written..."
  exit $ret
}

g_pservers=""
g_trainers=""

wait_running_pods(){
  pserver_label="tf-job-pserver=${JOB_NAME}"
  trainer_label="tf-job-trainer=${JOB_NAME}"

  stdbuf -oL python /root/k8s_tools.py wait_pods_running ${pserver_label} ${PSERVERS_NUM}
  stdbuf -oL python /root/k8s_tools.py wait_pods_running ${trainer_label} ${TRAINERS_NUM}

  g_pservers=$(python /root/k8s_tools.py fetch_endpoints ${pserver_label} ${PORT})
  g_trainers=$(python /root/k8s_tools.py fetch_endpoints ${trainer_label} ${PORT})
}

start_tf_pserver(){
  wait_running_pods

  label="tf-job-pserver=${JOB_NAME}"
  pserver_id=$(python /root/k8s_tools.py fetch_id ${label})

  cmd="${ENTRY} --ps_hosts=${g_pservers} --worker_hosts=${g_trainers} \
  --job_name=${TF_JOB_NAME} --task_index=${pserver_id}"

  stdbuf -oL sh -c "cd ${TRAINER_PACKAGE} && ${cmd}"
}

start_tf_trainer(){
  wait_running_pods

  label="tf-job-trainer=${JOB_NAME}"
  trainer_id=$(python /root/k8s_tools.py fetch_id ${label})

  cmd="${ENTRY} --ps_hosts=${g_pservers} --worker_hosts=${g_trainers} \
  --job_name=${TF_JOB_NAME} --task_index=${trainer_id} --batch_size=${BATCH_SIZE}"

  stdbuf -oL sh -c "cd ${TRAINER_PACKAGE} && ${cmd}"
  check_trainer_ret $?
}

start_tf(){
  if [[ "${TF_JOB_NAME}" == "worker" ]]; then
    start_tf_trainer
  else
    start_tf_pserver
  fi
}

usage() {
    echo "usage: tf_k8s [<args>]:"
    echo "        start_tf         Start tensorflow jobs"
}

case "$1" in
    start_tf)
        start_tf
        ;;
    --help)
        usage
        ;;
    *)
        usage
        ;;
esac
benchmark/cluster/vgg16/tf_pserver.yaml  (new file, mode 100644)
apiVersion: extensions/v1beta1
kind: ReplicaSet
metadata:
  name: vgg16job-tf-pserver
spec:
  replicas: 10
  template:
    metadata:
      labels:
        tf-job-pserver: vgg16job-tf
    spec:
      hostNetwork: true
      imagePullSecrets:
      - name: job-registry-secret
      containers:
      - name: pserver
        image: "registry.baidu.com/paddlepaddle/fluid_benchmark_tf:vgg16"
        imagePullPolicy: Always
        command: ["tf_k8s", "start_tf"]
        ports:
        - name: jobport-30236
          containerPort: 30236
        env:
        - name: PORT
          value: "32036"
        - name: ENTRY
          value: "python vgg16_tf.py"
        - name: JOB_NAME
          value: vgg16job-tf
        - name: PSERVERS_NUM
          value: "10"
        - name: TF_JOB_NAME
          value: "ps"
        - name: TRAINERS_NUM
          value: "20"
        - name: BATCH_SIZE
          value: "128"
        - name: TRAINER_PACKAGE
          value: "/workspace"
        - name: NUM_PASSES
          value: "1"
        - name: NAMESPACE
          valueFrom:
            fieldRef:
              fieldPath: "metadata.namespace"
        - name: POD_IP
          valueFrom:
            fieldRef:
              fieldPath: "status.podIP"
        resources:
          requests:
            memory: 10Gi
            cpu: 4
          limits:
            memory: 10Gi
            cpu: 4
benchmark/cluster/vgg16/tf_trainer.yaml  (new file, mode 100644)
apiVersion: batch/v1
kind: Job
metadata:
  name: vgg16job-tf-trainer
spec:
  parallelism: 20
  completions: 20
  template:
    metadata:
      labels:
        tf-job-trainer: vgg16job-tf
    spec:
      imagePullSecrets:
      - name: job-registry-secret
      hostNetwork: true
      containers:
      - name: trainer
        image: "registry.baidu.com/paddlepaddle/fluid_benchmark_tf:vgg16"
        imagePullPolicy: Always
        command: ["tf_k8s", "start_tf"]
        ports:
        - name: jobport-30236
          containerPort: 30236
        env:
        - name: PORT
          value: "32036"
        - name: JOB_NAME
          value: vgg16job-tf
        - name: TF_JOB_NAME
          value: "worker"
        - name: ENTRY
          value: "python vgg16_tf.py"
        - name: PSERVERS_NUM
          value: "10"
        - name: BATCH_SIZE
          value: "128"
        - name: TRAINERS_NUM
          value: "20"
        - name: TRAINER_PACKAGE
          value: "/workspace"
        - name: NUM_PASSES
          value: "1"
        - name: NAMESPACE
          valueFrom:
            fieldRef:
              fieldPath: "metadata.namespace"
        - name: POD_IP
          valueFrom:
            fieldRef:
              fieldPath: "status.podIP"
        resources:
          requests:
            memory: 40Gi
            cpu: 2
          limits:
            memory: 40Gi
            cpu: 2
      restartPolicy: Never
benchmark/cluster/vgg16/vgg16_fluid.py
@@ -68,6 +68,21 @@ parser.add_argument(
     type=str2bool,
     default=True,
     help='Whether to run as local mode.')
+parser.add_argument(
+    "--ps_hosts",
+    type=str,
+    default="",
+    help="Comma-separated list of hostname:port pairs")
+parser.add_argument(
+    "--trainer_hosts",
+    type=str,
+    default="",
+    help="Comma-separated list of hostname:port pairs")
+
+# Flags for defining the tf.train.Server
+parser.add_argument(
+    "--task_index", type=int, default=0, help="Index of task within the job")

 args = parser.parse_args()
@@ -180,8 +195,9 @@ def main():
                 iters += 1
                 num_samples += len(data)
                 print(
-                    "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, spent %f"
-                    % (pass_id, iters, loss, acc, time.time() - ts)
+                    "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, Speed = %.2f img/s"
+                    % (pass_id, iters, loss, acc,
+                       len(data) / (time.time() - ts))
                 )  # The accuracy is the accumulation of batches, but not the current batch.

         pass_elapsed = time.time() - start_time
@@ -209,27 +225,24 @@ def main():
             batch_size=args.batch_size)
         train_loop(exe, fluid.default_main_program())
     else:
-        pserver_ips = os.getenv("PADDLE_INIT_PSERVERS")  # all pserver endpoints
-        eplist = []
-        for ip in pserver_ips.split(","):
-            eplist.append(':'.join([ip, "6174"]))
-        pserver_endpoints = ",".join(eplist)
-        print("pserver endpoints: ", pserver_endpoints)
         trainers = int(os.getenv("TRAINERS"))  # total trainer count
         print("trainers total: ", trainers)
-        current_endpoint = os.getenv("POD_IP") + ":6174"  # current pserver endpoint
         training_role = os.getenv("TRAINING_ROLE", "TRAINER")  # get the training role: trainer/pserver
+
         t = fluid.DistributeTranspiler()
         t.transpile(
             optimize_ops,
             params_grads,
-            pservers=pserver_endpoints,
+            trainer_id=args.task_index,
+            pservers=args.ps_hosts,
             trainers=trainers)

         if training_role == "PSERVER":
+            current_endpoint = os.getenv("POD_IP") + ":" + os.getenv(
+                "PADDLE_INIT_PORT")
             if not current_endpoint:
                 print("need env SERVER_ENDPOINT")
                 exit(1)
benchmark/cluster/vgg16/vgg16_tf.py  (new file, mode 100644)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VGG16 benchmark in TensorFlow
You can get distribution example template structure here:
https://medium.com/clusterone/how-to-write-distributed-tensorflow-code-with-an-example-on-tensorport-70bf3306adcb
https://www.tensorflow.org/deploy/distributed
"""

import tensorflow as tf
import paddle.v2 as paddle
import numpy as np
import argparse
import time

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    '--batch_size', type=int, default=128, help="Batch size for training.")
parser.add_argument(
    '--learning_rate',
    type=float,
    default=1e-3,
    help="Learning rate for training.")
parser.add_argument('--num_passes', type=int, default=50, help="No. of passes.")
parser.add_argument(
    '--device',
    type=str,
    default='CPU',
    choices=['CPU', 'GPU'],
    help="The device type.")
parser.add_argument(
    '--data_format',
    type=str,
    default='NHWC',
    choices=['NCHW', 'NHWC'],
    help='The data order, NCHW=[batch, channels, height, width].'
    'Only support NHWC right now.')
parser.add_argument(
    '--data_set',
    type=str,
    default='cifar10',
    choices=['cifar10', 'flowers'],
    help='Optional dataset for benchmark.')
parser.add_argument(
    "--ps_hosts",
    type=str,
    default="",
    help="Comma-separated list of hostname:port pairs")
parser.add_argument(
    "--worker_hosts",
    type=str,
    default="",
    help="Comma-separated list of hostname:port pairs")
parser.add_argument(
    "--job_name", type=str, default="", help="One of 'worker', 'ps'")
# Flags for defining the tf.train.Server
parser.add_argument(
    "--task_index", type=int, default=0, help="Index of task within the job")

args = parser.parse_args()


class VGG16Model(object):
    def __init__(self):
        self.parameters = []

    def batch_norm_relu(self, inputs, is_training):
        """Performs a batch normalization followed by a ReLU."""
        # We set fused=True for a significant speed boost. See
        # https://www.tensorflow.org/speed/speed_guide#common_fused_ops
        inputs = tf.layers.batch_normalization(
            inputs=inputs,
            axis=1 if args.data_format == 'NCHW' else -1,
            momentum=0.9,
            epsilon=1e-05,
            center=True,
            scale=True,
            training=is_training,
            fused=True)
        inputs = tf.nn.relu(inputs)
        return inputs

    def conv_bn_layer(self,
                      name,
                      images,
                      kernel_shape,
                      is_training,
                      drop_rate=0.0):
        with tf.name_scope(name) as scope:
            kernel = tf.Variable(
                tf.truncated_normal(
                    kernel_shape, dtype=tf.float32, stddev=1e-1),
                name='weights')
            conv = tf.nn.conv2d(
                images,
                kernel, [1, 1, 1, 1],
                data_format=args.data_format,
                padding='SAME')
            biases = tf.Variable(
                tf.constant(
                    0.0, shape=[kernel_shape[-1]], dtype=tf.float32),
                trainable=True,
                name='biases')
            out = tf.nn.bias_add(conv, biases)
            out = self.batch_norm_relu(out, is_training)
            out = tf.layers.dropout(out, rate=drop_rate, training=is_training)
            return out

    def fc_layer(self, name, inputs, shape):
        with tf.name_scope(name) as scope:
            fc_w = tf.Variable(
                tf.truncated_normal(
                    shape, dtype=tf.float32, stddev=1e-1),
                name='weights')
            fc_b = tf.Variable(
                tf.constant(
                    0.0, shape=[shape[-1]], dtype=tf.float32),
                trainable=True,
                name='biases')
            out = tf.nn.bias_add(tf.matmul(inputs, fc_w), fc_b)
            return out

    def network(self, images, class_dim, is_training):
        """ VGG16 model structure.

        TODO(kuke): enable this network to support the 'NCHW' data format
        """

        # conv1
        conv1_1 = self.conv_bn_layer(
            'conv1_1', images, [3, 3, 3, 64], is_training, drop_rate=0.3)
        conv1_2 = self.conv_bn_layer(
            'conv1_2', conv1_1, [3, 3, 64, 64], is_training, drop_rate=0.0)
        # pool1
        pool1 = tf.nn.max_pool(
            conv1_2,
            ksize=[1, 2, 2, 1],
            strides=[1, 2, 2, 1],
            padding='SAME',
            name='pool1')
        # conv2
        conv2_1 = self.conv_bn_layer(
            'conv2_1', pool1, [3, 3, 64, 128], is_training, drop_rate=0.4)
        conv2_2 = self.conv_bn_layer(
            'conv2_2', conv2_1, [3, 3, 128, 128], is_training, drop_rate=0.0)
        # pool2
        pool2 = tf.nn.max_pool(
            conv2_2,
            ksize=[1, 2, 2, 1],
            strides=[1, 2, 2, 1],
            padding='SAME',
            name='pool2')
        # conv3
        conv3_1 = self.conv_bn_layer(
            'conv3_1', pool2, [3, 3, 128, 256], is_training, drop_rate=0.4)
        conv3_2 = self.conv_bn_layer(
            'conv3_2', conv3_1, [3, 3, 256, 256], is_training, drop_rate=0.4)
        conv3_3 = self.conv_bn_layer(
            'conv3_3', conv3_2, [3, 3, 256, 256], is_training, drop_rate=0.0)
        # pool3
        pool3 = tf.nn.max_pool(
            conv3_3,
            ksize=[1, 2, 2, 1],
            strides=[1, 2, 2, 1],
            padding='SAME',
            name='pool3')
        # conv4
        conv4_1 = self.conv_bn_layer(
            'conv4_1', pool3, [3, 3, 256, 512], is_training, drop_rate=0.4)
        conv4_2 = self.conv_bn_layer(
            'conv4_2', conv4_1, [3, 3, 512, 512], is_training, drop_rate=0.4)
        conv4_3 = self.conv_bn_layer(
            'conv4_3', conv4_2, [3, 3, 512, 512], is_training, drop_rate=0.0)
        # pool4
        pool4 = tf.nn.max_pool(
            conv4_3,
            ksize=[1, 2, 2, 1],
            strides=[1, 2, 2, 1],
            padding='SAME',
            name='pool4')
        # conv5
        conv5_1 = self.conv_bn_layer(
            'conv5_1', pool4, [3, 3, 512, 512], is_training, drop_rate=0.4)
        conv5_2 = self.conv_bn_layer(
            'conv5_2', conv5_1, [3, 3, 512, 512], is_training, drop_rate=0.4)
        conv5_3 = self.conv_bn_layer(
            'conv5_3', conv5_2, [3, 3, 512, 512], is_training, drop_rate=0.0)
        # pool5
        pool5 = tf.nn.max_pool(
            conv5_3,
            ksize=[1, 2, 2, 1],
            strides=[1, 2, 2, 1],
            padding='SAME',
            name='pool4')
        # flatten
        shape = int(np.prod(pool5.get_shape()[1:]))
        pool5_flat = tf.reshape(pool5, [-1, shape])
        # fc1
        drop = tf.layers.dropout(pool5_flat, rate=0.5, training=is_training)
        fc1 = self.fc_layer('fc1', drop, [shape, 512])
        # fc2
        bn = self.batch_norm_relu(fc1, is_training)
        drop = tf.layers.dropout(bn, rate=0.5, training=is_training)
        fc2 = self.fc_layer('fc2', drop, [512, 512])

        fc3 = self.fc_layer('fc3', fc2, [512, class_dim])

        return fc3


def run_benchmark(cluster_spec, server):
    """Run benchmark on cifar10 or flowers."""

    if args.data_set == "cifar10":
        class_dim = 10
        raw_shape = (3, 32, 32)
        dat_shape = (None, 32, 32, 3) if args.data_format == 'NHWC' else (
            None, 3, 32, 32)
    else:
        class_dim = 102
        raw_shape = (3, 224, 224)
        dat_shape = (None, 224, 224, 3) if args.data_format == 'NHWC' else (
            None, 3, 224, 224)

    device = tf.train.replica_device_setter(
        worker_device="/job:worker/task:{}".format(args.task_index),
        cluster=cluster_spec)

    with tf.device(device):
        images = tf.placeholder(tf.float32, shape=dat_shape)
        labels = tf.placeholder(tf.int64, shape=(None, ))
        is_training = tf.placeholder('bool')
        onehot_labels = tf.one_hot(labels, depth=class_dim)

        vgg16 = VGG16Model()
        logits = vgg16.network(images, class_dim, is_training)
        loss = tf.losses.softmax_cross_entropy(
            onehot_labels=onehot_labels, logits=logits)
        avg_loss = tf.reduce_mean(loss)

        correct = tf.equal(tf.argmax(logits, 1), labels)
        accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

        optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        global_step = tf.Variable(0, name='global_step', trainable=False)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(avg_loss, global_step=global_step)

        summary_op = tf.summary.merge_all()
        init_op = tf.global_variables_initializer()

    # data reader
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.cifar.train10()
            if args.data_set == 'cifar10' else paddle.dataset.flowers.train(),
            buf_size=5120),
        batch_size=args.batch_size)
    test_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.cifar.test10()
            if args.data_set == 'cifar10' else paddle.dataset.flowers.test(),
            buf_size=5120),
        batch_size=args.batch_size)

    # test
    def test():
        test_accs = []
        for batch_id, data in enumerate(test_reader()):
            test_images = np.array(
                map(lambda x: np.transpose(x[0].reshape(raw_shape),
                    axes=[1, 2, 0]) if args.data_format == 'NHWC' else x[0],
                    data)).astype("float32")
            test_labels = np.array(map(lambda x: x[1], data)).astype('int64')
            test_accs.append(
                accuracy.eval(feed_dict={
                    images: test_images,
                    labels: test_labels,
                    is_training: False
                }))
        return np.mean(test_accs)

    config = tf.ConfigProto(
        intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
    config.gpu_options.allow_growth = True

    hooks = [tf.train.StopAtStepHook(last_step=1000000)]

    with tf.train.MonitoredTrainingSession(
            master=server.target,
            is_chief=(args.task_index == 0),
            hooks=hooks) as sess:
        iters, num_samples, start_time = 0, 0, 0.0
        for pass_id in range(args.num_passes):
            # train
            num_samples = 0
            start_time = time.time()
            for batch_id, data in enumerate(train_reader()):
                train_images = np.array(
                    map(lambda x: np.transpose(x[0].reshape(raw_shape),
                        axes=[1, 2, 0]) if args.data_format == 'NHWC' else x[0],
                        data)).astype("float32")
                train_labels = np.array(map(lambda x: x[1], data)).astype(
                    'int64')
                iter_begin_time = time.time()
                _, loss, acc = sess.run([train_op, avg_loss, accuracy],
                                        feed_dict={
                                            images: train_images,
                                            labels: train_labels,
                                            is_training: True
                                        })
                iters += 1
                print(
                    "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, Speed=%.2f imgs/sec"
                    % (pass_id, iters, loss, acc,
                       len(data) / (time.time() - iter_begin_time)))
                num_samples += len(data)

            train_elapsed = time.time() - start_time
            # test
            pass_test_acc = test()
            print("Pass = %d, Train speed = %f imgs/s, Test accuracy = %f\n" %
                  (pass_id, num_samples / train_elapsed, pass_test_acc))


def print_arguments():
    print('-----------  Configuration Arguments -----------')
    for arg, value in sorted(vars(args).iteritems()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------------')


if __name__ == '__main__':
    print_arguments()

    ps_hosts = args.ps_hosts.split(",")
    worker_hosts = args.worker_hosts.split(",")

    # Create a cluster from the parameter server and worker hosts.
    cluster_spec = tf.train.ClusterSpec({
        "ps": ps_hosts,
        "worker": worker_hosts
    })

    # Create and start a server for the local task.
    server = tf.train.Server(
        cluster_spec, job_name=args.job_name, task_index=args.task_index)

    if args.job_name == "ps":
        print("start pserver")
        server.join()
    elif args.job_name == "worker":
        print("start worker")
        run_benchmark(cluster_spec, server)
cmake/configure.cmake
@@ -59,6 +59,7 @@ endif(NOT WITH_GOLANG)
 if(NOT WITH_GPU)
     add_definitions(-DHPPL_STUB_FUNC)
+    add_definitions("-DCUPTI_LIB_PATH=\"\"")

     list(APPEND CMAKE_CXX_SOURCE_FILE_EXTENSIONS cu)
 else()
@@ -73,7 +74,14 @@ else()
     if(NOT CUDNN_FOUND)
         message(FATAL_ERROR "Paddle needs cudnn to compile")
     endif()
+    if(CUPTI_FOUND)
+        include_directories(${CUPTI_INCLUDE_DIR})
+        add_definitions(-DPADDLE_WITH_CUPTI)
+        add_definitions("-DCUPTI_LIB_PATH=\"${CUPTI_LIBRARY_PATH}\"")
+    else()
+        add_definitions("-DCUPTI_LIB_PATH=\"\"")
+        message(STATUS "Cannot find CUPTI, GPU Profiling is incorrect.")
+    endif()
     set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} "-Xcompiler ${SIMD_FLAG}")

     # Include cuda and cudnn
cmake/cuda.cmake
@@ -155,7 +155,8 @@ endif()
 include_directories(${CUDA_INCLUDE_DIRS})
 list(APPEND EXTERNAL_LIBS ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY})
 if(NOT WITH_DSO)
-  list(APPEND EXTERNAL_LIBS ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY} ${NCCL_LIBRARY})
+  # TODO(panyx0718): CUPTI only allows DSO?
+  list(APPEND EXTERNAL_LIBS ${CUDNN_LIBRARY} ${CUPTI_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY} ${NCCL_LIBRARY})
 endif(NOT WITH_DSO)

 # setting nvcc arch flags
cmake/cupti.cmake  (new file, mode 100644)
if(NOT WITH_GPU)
    return()
endif()

set(CUPTI_ROOT "/usr" CACHE PATH "CUPTI ROOT")
find_path(CUPTI_INCLUDE_DIR cupti.h
    PATHS ${CUPTI_ROOT} ${CUPTI_ROOT}/include
          $ENV{CUPTI_ROOT} $ENV{CUPTI_ROOT}/include
          ${CUDA_TOOLKIT_ROOT_DIR}/extras/CUPTI/include
    NO_DEFAULT_PATH
)

get_filename_component(__libpath_hist ${CUDA_CUDART_LIBRARY} PATH)

set(TARGET_ARCH "x86_64")
if(NOT ${CMAKE_SYSTEM_PROCESSOR})
    set(TARGET_ARCH ${CMAKE_SYSTEM_PROCESSOR})
endif()

list(APPEND CUPTI_CHECK_LIBRARY_DIRS
    ${CUPTI_ROOT}
    ${CUPTI_ROOT}/lib64
    ${CUPTI_ROOT}/lib
    ${CUPTI_ROOT}/lib/${TARGET_ARCH}-linux-gnu
    $ENV{CUPTI_ROOT}
    $ENV{CUPTI_ROOT}/lib64
    $ENV{CUPTI_ROOT}/lib
    /usr/lib
    ${CUDA_TOOLKIT_ROOT_DIR}/extras/CUPTI/lib64)
find_library(CUPTI_LIBRARY NAMES libcupti.so libcupti.dylib # libcupti_static.a
    PATHS ${CUPTI_CHECK_LIBRARY_DIRS} ${CUPTI_INCLUDE_DIR} ${__libpath_hist}
          NO_DEFAULT_PATH
    DOC "Path to cuPTI library.")

get_filename_component(CUPTI_LIBRARY_PATH ${CUPTI_LIBRARY} DIRECTORY)
if(CUPTI_INCLUDE_DIR AND CUPTI_LIBRARY)
    set(CUPTI_FOUND ON)
else()
    set(CUPTI_FOUND OFF)
endif()
paddle/fluid/framework/CMakeLists.txt
@@ -56,7 +56,7 @@ cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
 cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
 cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context)
 cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
-    shape_inference data_transform lod_tensor)
+    shape_inference data_transform lod_tensor profiler)
 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry init)

 cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog)
@@ -80,7 +80,7 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
 cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)

 cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
-framework_proto backward glog lod_rank_table profiler feed_fetch_method)
+framework_proto backward glog lod_rank_table feed_fetch_method)

 cc_library(prune SRCS prune.cc DEPS framework_proto)
 cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
paddle/fluid/framework/framework.proto
@@ -167,4 +167,6 @@ message BlockDesc {
 // Please refer to
 // https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/program.md
 // for more details.
+// TODO(panyx0718): A model can have multiple programs. Need a
+// way to distinguish them. Maybe ID or name?
 message ProgramDesc { repeated BlockDesc blocks = 1; }
paddle/fluid/framework/lod_tensor.cc
@@ -31,8 +31,14 @@ std::ostream &operator<<(std::ostream &os, const LoD &lod) {
   os << "{";
   for (auto &v : lod) {
     os << "{";
+    bool is_first = true;
     for (auto &i : v) {
-      os << i << ",";
+      if (is_first) {
+        os << i;
+        is_first = false;
+      } else {
+        os << ", " << i;
+      }
     }
     os << "}";
   }
paddle/fluid/framework/op_desc.h
@@ -125,6 +125,8 @@ class OpDesc {
   BlockDesc *Block() { return this->block_; }

+  const BlockDesc &BlockRef() const { return *this->block_; }
+
   void SetBlock(BlockDesc *block) { this->block_ = block; }

  private:
paddle/fluid/inference/io.cc
@@ -32,23 +32,11 @@ void ReadBinaryFile(const std::string& filename, std::string& contents) {
   inputfs.close();
 }

-bool IsParameter(const framework::VarDesc* var,
-                 const framework::ProgramDesc& main_program) {
-  if (var->Persistable()) {
-    // There are many unreachable variables in the program
-    for (size_t i = 0; i < main_program.Size(); ++i) {
-      const framework::BlockDesc& block = main_program.Block(i);
-      for (auto* op : block.AllOps()) {
-        if (op->Type() == framework::kFeedOpType) {
-          continue;
-        }
-        for (auto input_argument_name : op->InputArgumentNames()) {
-          if (input_argument_name == var->Name()) {
-            return true;
-          }
-        }
-      }
-    }
+bool IsPersistable(const framework::VarDesc* var) {
+  if (var->Persistable() &&
+      var->GetType() != framework::proto::VarType::FEED_MINIBATCH &&
+      var->GetType() != framework::proto::VarType::FETCH_LIST) {
+    return true;
   }
   return false;
 }
@@ -65,8 +53,8 @@ void LoadPersistables(framework::Executor& executor,
   std::vector<std::string> paramlist;
   for (auto* var : global_block.AllVars()) {
-    if (IsParameter(var, main_program)) {
-      VLOG(3) << "parameter's name: " << var->Name();
+    if (IsPersistable(var)) {
+      VLOG(3) << "persistable variable's name: " << var->Name();
       framework::VarDesc* new_var = load_block->Var(var->Name());
       new_var->SetShape(var->GetShape());
@@ -101,7 +89,6 @@ void LoadPersistables(framework::Executor& executor,
   executor.Run(*load_program, &scope, 0, true, true);
-
   VLOG(3) << "Ran loading successfully";
   delete load_program;
 }
paddle/fluid/inference/tests/book/CMakeLists.txt
@@ -30,5 +30,5 @@ inference_test(label_semantic_roles)
 inference_test(recognize_digits ARGS mlp conv)
 inference_test(recommender_system)
 #inference_test(rnn_encoder_decoder)
-inference_test(understand_sentiment)
+inference_test(understand_sentiment ARGS conv)
 inference_test(word2vec)
paddle/fluid/inference/tests/book/test_inference_label_semantic_roles.cc
@@ -32,16 +32,42 @@ TEST(inference, label_semantic_roles) {
   paddle::framework::LoDTensor word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1,
       ctx_p2, mark;
   paddle::framework::LoD lod{{0, 4, 10}};
-  SetupLoDTensor(word, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
-  SetupLoDTensor(predicate, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
-  SetupLoDTensor(ctx_n2, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
-  SetupLoDTensor(ctx_n1, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
-  SetupLoDTensor(ctx_0, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
-  SetupLoDTensor(ctx_p1, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
-  SetupLoDTensor(ctx_p2, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
-  SetupLoDTensor(mark, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
+  int64_t word_dict_len = 44068;
+  int64_t predicate_dict_len = 3162;
+  int64_t mark_dict_len = 2;
+
+  SetupLoDTensor(word, lod, static_cast<int64_t>(0),
+                 static_cast<int64_t>(word_dict_len - 1));
+  SetupLoDTensor(predicate, lod, static_cast<int64_t>(0),
+                 static_cast<int64_t>(predicate_dict_len - 1));
+  SetupLoDTensor(ctx_n2, lod, static_cast<int64_t>(0),
+                 static_cast<int64_t>(word_dict_len - 1));
+  SetupLoDTensor(ctx_n1, lod, static_cast<int64_t>(0),
+                 static_cast<int64_t>(word_dict_len - 1));
+  SetupLoDTensor(ctx_0, lod, static_cast<int64_t>(0),
+                 static_cast<int64_t>(word_dict_len - 1));
+  SetupLoDTensor(ctx_p1, lod, static_cast<int64_t>(0),
+                 static_cast<int64_t>(word_dict_len - 1));
+  SetupLoDTensor(ctx_p2, lod, static_cast<int64_t>(0),
+                 static_cast<int64_t>(word_dict_len - 1));
+  SetupLoDTensor(mark, lod, static_cast<int64_t>(0),
+                 static_cast<int64_t>(mark_dict_len - 1));

   std::vector<paddle::framework::LoDTensor*> cpu_feeds;
   cpu_feeds.push_back(&word);
paddle/fluid/inference/tests/book/test_inference_understand_sentiment.cc
@@ -31,7 +31,12 @@ TEST(inference, understand_sentiment) {
   paddle::framework::LoDTensor words;
   paddle::framework::LoD lod{{0, 4, 10}};
-  SetupLoDTensor(words, lod, static_cast<int64_t>(0), static_cast<int64_t>(10));
+  int64_t word_dict_len = 5147;
+
+  SetupLoDTensor(words, lod, static_cast<int64_t>(0),
+                 static_cast<int64_t>(word_dict_len - 1));

   std::vector<paddle::framework::LoDTensor*> cpu_feeds;
   cpu_feeds.push_back(&words);
paddle/fluid/inference/tests/book/test_inference_word2vec.cc
@@ -31,12 +31,12 @@ TEST(inference, word2vec) {
   paddle::framework::LoDTensor first_word, second_word, third_word, fourth_word;
   paddle::framework::LoD lod{{0, 1}};
-  int64_t dict_size = 2072;  // Hard-coding the size of dictionary
+  int64_t dict_size = 2073;  // The size of dictionary

-  SetupLoDTensor(first_word, lod, static_cast<int64_t>(0), dict_size);
-  SetupLoDTensor(second_word, lod, static_cast<int64_t>(0), dict_size);
-  SetupLoDTensor(third_word, lod, static_cast<int64_t>(0), dict_size);
-  SetupLoDTensor(fourth_word, lod, static_cast<int64_t>(0), dict_size);
+  SetupLoDTensor(first_word, lod, static_cast<int64_t>(0), dict_size - 1);
+  SetupLoDTensor(second_word, lod, static_cast<int64_t>(0), dict_size - 1);
+  SetupLoDTensor(third_word, lod, static_cast<int64_t>(0), dict_size - 1);
+  SetupLoDTensor(fourth_word, lod, static_cast<int64_t>(0), dict_size - 1);

   std::vector<paddle::framework::LoDTensor*> cpu_feeds;
   cpu_feeds.push_back(&first_word);
paddle/fluid/inference/tests/test_helper.h
@@ -101,8 +101,8 @@ void TestInference(const std::string& dirname,
   if (IsCombined) {
     // All parameters are saved in a single file.
     // Hard-coding the file names of program and parameters in unittest.
-    // Users are free to specify different filename
-    // (provided: the filenames are changed in the python api as well: io.py)
+    // The file names should be consistent with that used in Python API
+    // `fluid.io.save_inference_model`.
     std::string prog_filename = "__model_combined__";
     std::string param_filename = "__params_combined__";
     inference_program = paddle::inference::Load(executor,
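The filenames referenced here ("__model_combined__", "__params_combined__") come from the combined-save path that this commit adds to python/paddle/fluid/io.py. Below is a minimal sketch of the Python export step that these C++ tests consume; the network and directory name are illustrative, and the keyword selecting the single-file combined layout varied across revisions, so only the stable default call is shown.

import paddle.fluid as fluid

# Minimal sketch of the export step consumed by the C++ inference tests.
# By default each parameter is saved to its own file; the combined variant
# added around this commit writes the program to "__model_combined__" and
# all parameters to "__params_combined__" instead.
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

fluid.io.save_inference_model(
    dirname='fit_a_line.inference.model',  # illustrative output directory
    feeded_var_names=['x'],
    target_vars=[y_predict],
    executor=exe)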
paddle/fluid/operators/CMakeLists.txt
@@ -11,6 +11,8 @@ function(op_library TARGET)
   set(cc_srcs)
   set(cu_srcs)
   set(cu_cc_srcs)
+  set(cudnn_cu_cc_srcs)
+  set(CUDNN_FILE)
   set(op_common_deps operator op_registry math_function)
   set(options "")
   set(oneValueArgs "")
@@ -30,10 +32,16 @@ function(op_library TARGET)
     if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
       list(APPEND cu_srcs ${TARGET}.cu)
     endif()
+    string(REPLACE "_op" "_cudnn_op" CUDNN_FILE "${TARGET}")
+    if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc)
+      list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc)
+    endif()
   else()
     foreach(src ${op_library_SRCS})
       if (${src} MATCHES ".*\\.cu$")
         list(APPEND cu_srcs ${src})
+      elseif(${src} MATCHES ".*_cudnn_op.cu.cc$")
+        list(APPEND cudnn_cu_cc_srcs ${src})
       elseif(${src} MATCHES ".*\\.cu.cc$")
         list(APPEND cu_cc_srcs ${src})
       elseif(${src} MATCHES ".*\\.cc$")
@@ -54,7 +62,7 @@ function(op_library TARGET)
     set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE)
   endif()
   if (WITH_GPU)
-    nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
+    nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
       ${op_common_deps})
   else()
     cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${op_library_DEPS}
@@ -98,6 +106,12 @@ function(op_library TARGET)
     set(pybind_flag 1)
   endif()

+  # pybind USE_OP_DEVICE_KERNEL for CUDNN
+  list(LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len)
+  if (WITH_GPU AND ${cudnn_cu_cc_srcs_len} GREATER 0)
+    file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n")
+  endif()
+
   # pybind USE_OP
   if (${pybind_flag} EQUAL 0)
     file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
@@ -152,43 +166,24 @@ op_library(lstm_op DEPS sequence2batch lstm_compute)
 op_library(lstmp_op DEPS sequence2batch lstm_compute)
 op_library(gru_op DEPS sequence2batch gru_compute)
 op_library(recurrent_op DEPS executor)
-op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale math_function)
+op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
 op_library(cos_sim_op DEPS cos_sim_functor)
 op_library(parallel_do_op DEPS executor)
 op_library(create_reader_op DEPS reader)

-# Regist multiple Kernel to pybind
 if (WITH_GPU)
-  op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS
-    vol2col depthwise_conv)
-  op_library(edit_distance_op SRCS edit_distance_op.cc edit_distance_op.cu DEPS math_function)
-  op_library(pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling)
-  op_library(conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc
-    conv_transpose_cudnn_op.cu.cc DEPS vol2col)
-  file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(conv2d, CUDNN);\n")
-  file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(pool2d, CUDNN);\n")
-  file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(conv2d_transpose, CUDNN);\n")
+  op_library(conv_op DEPS vol2col depthwise_conv)
 else()
-  op_library(conv_op SRCS conv_op.cc DEPS vol2col)
-  op_library(pool_op SRCS pool_op.cc DEPS pooling)
-  op_library(conv_transpose_op SRCS conv_transpose_op.cc DEPS vol2col)
+  op_library(conv_op DEPS vol2col)
 endif()
+op_library(pool_op DEPS pooling)
+op_library(conv_transpose_op DEPS vol2col)

 cc_library(batch_size_like SRCS batch_size_like.cc DEPS op_registry)

-op_library(fill_constant_batch_size_like_op
-  SRCS fill_constant_batch_size_like_op.cc fill_constant_batch_size_like_op.cu.cc
-  DEPS batch_size_like)
-op_library(uniform_random_batch_size_like_op
-  SRCS uniform_random_batch_size_like_op.cc
-  DEPS batch_size_like uniform_random_op)
-op_library(gaussian_random_batch_size_like_op
-  SRCS gaussian_random_batch_size_like_op.cc
-  DEPS batch_size_like gaussian_random_op)
+op_library(fill_constant_batch_size_like_op DEPS batch_size_like)
+op_library(uniform_random_batch_size_like_op DEPS batch_size_like uniform_random_op)
+op_library(gaussian_random_batch_size_like_op DEPS batch_size_like gaussian_random_op)

 # FIXME(typhoonzero): save/load depends lodtensor serialization functions
 op_library(save_op DEPS lod_tensor)
paddle/fluid/operators/bipartite_match_op.cc
@@ -94,6 +94,38 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
     }
   }

+  void ArgMaxMatch(const Tensor& dist, int* match_indices, T* match_dist,
+                   T overlap_threshold) const {
+    constexpr T kEPS = static_cast<T>(1e-6);
+    int64_t row = dist.dims()[0];
+    int64_t col = dist.dims()[1];
+    auto* dist_data = dist.data<T>();
+    for (int64_t j = 0; j < col; ++j) {
+      if (match_indices[j] != -1) {
+        // the j-th column has been matched to one entity.
+        continue;
+      }
+      int max_row_idx = -1;
+      T max_dist = -1;
+      for (int i = 0; i < row; ++i) {
+        T dist = dist_data[i * col + j];
+        if (dist < kEPS) {
+          // distance is 0 between m-th row and j-th column
+          continue;
+        }
+        if (dist >= overlap_threshold && dist > max_dist) {
+          max_row_idx = i;
+          max_dist = dist;
+        }
+      }
+      if (max_row_idx != -1) {
+        PADDLE_ENFORCE_EQ(match_indices[j], -1);
+        match_indices[j] = max_row_idx;
+        match_dist[j] = max_dist;
+      }
+    }
+  }
+
   void Compute(const framework::ExecutionContext& context) const override {
     auto* dist_mat = context.Input<LoDTensor>("DistMat");
     auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices");
@@ -120,13 +152,21 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
     int* indices = match_indices->data<int>();
     T* dist = match_dist->data<T>();
+    auto type = context.Attr<std::string>("match_type");
+    auto threshold = context.Attr<float>("dist_threshold");
     if (n == 1) {
       BipartiteMatch(*dist_mat, indices, dist);
+      if (type == "per_prediction") {
+        ArgMaxMatch(*dist_mat, indices, dist, threshold);
+      }
     } else {
       auto lod = dist_mat->lod().back();
       for (size_t i = 0; i < lod.size() - 1; ++i) {
         Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]);
         BipartiteMatch(one_ins, indices + i * col, dist + i * col);
+        if (type == "per_prediction") {
+          ArgMaxMatch(one_ins, indices + i * col, dist + i * col, threshold);
+        }
       }
     }
   }
@@ -147,6 +187,19 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
         "This tensor can contain LoD information to represent a batch of "
         "inputs. One instance of this batch can contain different numbers of "
         "entities.");
+    AddAttr<std::string>(
+        "match_type",
+        "(string, default: per_prediction) "
+        "The type of matching method, should be 'bipartite' or "
+        "'per_prediction', 'bipartite' by default.")
+        .SetDefault("bipartite")
+        .InEnum({"bipartite", "per_prediction"});
+    AddAttr<float>(
+        "dist_threshold",
+        "(float, default: 0.5) "
+        "If `match_type` is 'per_prediction', this threshold is to determine "
+        "the extra matching bboxes based on the maximum distance.")
+        .SetDefault(0.5);
     AddOutput("ColToRowMatchIndices",
               "(Tensor) A 2-D Tensor with shape [N, M] in int type. "
               "N is the batch size. If ColToRowMatchIndices[i][j] is -1, it "
@@ -168,10 +221,10 @@ distance matrix. For input 2D matrix, the bipartite matching algorithm can
 find the matched column for each row, also can find the matched row for
 each column. And this operator only calculate matched indices from column
-to row. For each instance, the number of matched indices is the number of
-of columns of the input ditance matrix.
+to row. For each instance, the number of matched indices is the number
+of columns of the input distance matrix.

 There are two outputs to save matched indices and distance.
-A simple description, this algothrim matched the best (maximum distance)
+A simple description, this algorithm matched the best (maximum distance)
 row entity to the column entity and the matched indices are not duplicated
 in each row of ColToRowMatchIndices. If the column entity is not matched
 any row entity, set -1 in ColToRowMatchIndices.
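The new match_type and dist_threshold attributes are exposed to Python through python/paddle/fluid/layers/detection.py in the same commit. A minimal graph-construction sketch with the wrapper follows; the wrapper's exact signature and its two return values (matched indices and matched distances) are assumed from the attribute and output names above rather than quoted from the diff.

import paddle.fluid as fluid

# Minimal sketch, assuming the detection.py wrapper of this revision:
# `dist` is an [n, 4] LoD tensor of pairwise distances; with
# match_type='per_prediction', columns left unmatched by the bipartite pass
# are additionally argmax-matched when their distance reaches the threshold.
dist = fluid.layers.data(
    name='dist', shape=[4], dtype='float32', lod_level=1)
matched_indices, matched_dist = fluid.layers.bipartite_match(
    dist, match_type='per_prediction', dist_threshold=0.5)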
paddle/fluid/platform/CMakeLists.txt
+proto_library(profiler_proto SRCS profiler.proto)
+
 if(WITH_GPU)
     cc_library(enforce SRCS enforce.cc DEPS)
 else()
@@ -37,7 +39,8 @@ nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda)
 nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place device_context)
 nv_test(nccl_test SRCS nccl_test.cu DEPS dynload_cuda gpu_info device_context)

-cc_library(profiler SRCS profiler.cc DEPS device_context)
+cc_library(device_tracer SRCS device_tracer.cc DEPS profiler_proto ${GPU_CTX_DEPS})
+cc_library(profiler SRCS profiler.cc DEPS device_context device_tracer)
 cc_test(profiler_test SRCS profiler_test.cc DEPS profiler)

 nv_test(float16_gpu_test SRCS float16_test.cu)
paddle/fluid/platform/device_tracer.cc  (new file, mode 100644)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/platform/device_tracer.h"
#include <map>
#include <mutex>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace platform {
namespace {

thread_local const char *cur_annotation = nullptr;
std::once_flag tracer_once_flag;
DeviceTracer *tracer = nullptr;
}  // namespace

#ifdef PADDLE_WITH_CUPTI

namespace {
// TODO(panyx0718): Revisit the buffer size here.
uint64_t kBufSize = 32 * 1024;
uint64_t kAlignSize = 8;

#define ALIGN_BUFFER(buffer, align)                                 \
  (((uintptr_t)(buffer) & ((align)-1))                              \
       ? ((buffer) + (align) - ((uintptr_t)(buffer) & ((align)-1))) \
       : (buffer))

#define CUPTI_CALL(call)                                                   \
  do {                                                                     \
    CUptiResult _status = call;                                            \
    if (_status != CUPTI_SUCCESS) {                                        \
      const char *errstr;                                                  \
      dynload::cuptiGetResultString(_status, &errstr);                     \
      fprintf(stderr, "%s:%d: error: function %s failed with error %s.\n", \
              __FILE__, __LINE__, #call, errstr);                          \
      exit(-1);                                                            \
    }                                                                      \
  } while (0)

void EnableActivity() {
  // Device activity record is created when CUDA initializes, so we
  // want to enable it before cuInit() or any CUDA runtime call.
  CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MEMCPY));
  CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_KERNEL));
  CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_DEVICE));
  CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MEMSET));
  CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_OVERHEAD));
  // We don't track these activities for now.
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONTEXT));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_DRIVER));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_RUNTIME));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_NAME));
  // CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MARKER));
}

void DisableActivity() {
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MEMCPY));
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_KERNEL));
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_DEVICE));
  // Disable all other activity record kinds.
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_CONTEXT));
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_DRIVER));
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_RUNTIME));
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MEMSET));
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_NAME));
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MARKER));
  CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_OVERHEAD));
}

void CUPTIAPI bufferRequested(uint8_t **buffer, size_t *size,
                              size_t *maxNumRecords) {
  uint8_t *buf = (uint8_t *)malloc(kBufSize + kAlignSize);
  *size = kBufSize;
  *buffer = ALIGN_BUFFER(buf, kAlignSize);
  *maxNumRecords = 0;
}

void CUPTIAPI bufferCompleted(CUcontext ctx, uint32_t streamId,
                              uint8_t *buffer, size_t size, size_t validSize) {
  CUptiResult status;
  CUpti_Activity *record = NULL;
  if (validSize > 0) {
    do {
      status = dynload::cuptiActivityGetNextRecord(buffer, validSize, &record);
      if (status == CUPTI_SUCCESS) {
        switch (record->kind) {
          case CUPTI_ACTIVITY_KIND_KERNEL:
          case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: {
            auto *kernel =
                reinterpret_cast<const CUpti_ActivityKernel3 *>(record);
            tracer->AddKernelRecords(kernel->start, kernel->end,
                                     kernel->deviceId, kernel->streamId,
                                     kernel->correlationId);
            break;
          }
          default: { break; }
        }
      } else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) {
        // Seems not an error in this case.
        break;
      } else {
        CUPTI_CALL(status);
      }
    } while (1);

    size_t dropped;
    CUPTI_CALL(
        dynload::cuptiActivityGetNumDroppedRecords(ctx, streamId, &dropped));
    if (dropped != 0) {
      fprintf(stderr, "Dropped %u activity records\n", (unsigned int)dropped);
    }
  }
  free(buffer);
}
}  // namespace

class DeviceTracerImpl : public DeviceTracer {
 public:
  DeviceTracerImpl() : enabled_(false) {}

  void AddAnnotation(uint64_t id, const std::string &anno) {
    std::lock_guard<std::mutex> l(trace_mu_);
    correlations_[id] = anno;
  }

  void AddKernelRecords(uint64_t start, uint64_t end, uint32_t device_id,
                        uint32_t stream_id, uint32_t correlation_id) {
    std::lock_guard<std::mutex> l(trace_mu_);
    kernel_records_.push_back(
        KernelRecord{start, end, device_id, stream_id, correlation_id});
  }

  bool IsEnabled() {
    std::lock_guard<std::mutex> l(trace_mu_);
    return enabled_;
  }

  void Enable() {
    std::lock_guard<std::mutex> l(trace_mu_);
    if (enabled_) {
      fprintf(stderr, "DeviceTracer already enabled\n");
      return;
    }
    EnableActivity();

    // Register callbacks for buffer requests and completed by CUPTI.
    CUPTI_CALL(dynload::cuptiActivityRegisterCallbacks(bufferRequested,
                                                       bufferCompleted));

    CUptiResult ret;
    ret = dynload::cuptiSubscribe(
        &subscriber_, static_cast<CUpti_CallbackFunc>(ApiCallback), this);
    if (ret == CUPTI_ERROR_MAX_LIMIT_REACHED) {
      fprintf(stderr, "CUPTI subscriber limit reached.\n");
    } else if (ret != CUPTI_SUCCESS) {
      fprintf(stderr, "Failed to create CUPTI subscriber.\n");
    }
    CUPTI_CALL(
        dynload::cuptiEnableCallback(1, subscriber_, CUPTI_CB_DOMAIN_DRIVER_API,
                                     CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel));

    CUPTI_CALL(dynload::cuptiGetTimestamp(&start_ns_));
    enabled_ = true;
  }

  proto::Profile GenProfile() {
    std::lock_guard<std::mutex> l(trace_mu_);
    proto::Profile profile_pb;
    profile_pb.set_start_ns(start_ns_);
    profile_pb.set_end_ns(end_ns_);
    std::map<std::string, std::vector<uint64_t>> event_times;
    for (const KernelRecord &r : kernel_records_) {
      if (correlations_.find(r.correlation_id) == correlations_.end()) {
        fprintf(stderr, "cannot relate a kernel activity\n");
        continue;
      }
      auto *event = profile_pb.add_events();
      event->set_name(correlations_.at(r.correlation_id));
      event->set_start_ns(r.start_ns);
      event->set_end_ns(r.end_ns);
      event->set_stream_id(r.stream_id);
      event->set_device_id(r.device_id);
      event_times[event->name()].push_back(r.end_ns - r.start_ns);
    }
    for (const auto &et : event_times) {
      fprintf(stderr, "%s: total: %fms invoked cuda kernels: %lu\n",
              et.first.c_str(),
              std::accumulate(et.second.begin(), et.second.end(), 0) /
                  1000000.0,
              et.second.size());
    }
    return profile_pb;
  }

  void Disable() {
    // flush might cause additional calls to DeviceTracker.
    dynload::cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED);
    std::lock_guard<std::mutex> l(trace_mu_);
    DisableActivity();
    dynload::cuptiUnsubscribe(subscriber_);
    CUPTI_CALL(dynload::cuptiGetTimestamp(&end_ns_));
    PADDLE_ENFORCE(dynload::cuptiFinalize());
    enabled_ = false;
  }

 private:
  static void CUPTIAPI ApiCallback(void *userdata, CUpti_CallbackDomain domain,
                                   CUpti_CallbackId cbid, const void *cbdata) {
    auto *cbInfo = reinterpret_cast<const CUpti_CallbackData *>(cbdata);
    DeviceTracer *tracer = reinterpret_cast<DeviceTracer *>(userdata);
    if ((domain == CUPTI_CB_DOMAIN_DRIVER_API) &&
        (cbid == CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel)) {
      if (cbInfo->callbackSite == CUPTI_API_ENTER) {
        const std::string anno =
            cur_annotation ? cur_annotation : cbInfo->symbolName;
        tracer->AddAnnotation(cbInfo->correlationId, anno);
      }
    } else {
      VLOG(1) << "Unhandled API Callback for " << domain << " " << cbid;
    }
  }

  std::mutex trace_mu_;
  bool enabled_;
  uint64_t start_ns_;
  uint64_t end_ns_;
  std::vector<KernelRecord> kernel_records_;
  std::unordered_map<uint32_t, std::string> correlations_;
  CUpti_SubscriberHandle subscriber_;
};

#endif  // PADDLE_WITH_CUPTI

class DeviceTracerDummy : public DeviceTracer {
 public:
  DeviceTracerDummy() {}

  void AddAnnotation(uint64_t id, const std::string &anno) {}

  void AddKernelRecords(uint64_t start, uint64_t end, uint32_t device_id,
                        uint32_t stream_id, uint32_t correlation_id) {}

  bool IsEnabled() { return false; }

  void Enable() {}

  proto::Profile GenProfile() { return proto::Profile(); }

  void Disable() {}
};

void CreateTracer(DeviceTracer **t) {
#ifdef PADDLE_WITH_CUPTI
  *t = new DeviceTracerImpl();
#else
  *t = new DeviceTracerDummy();
#endif  // PADDLE_WITH_CUPTI
}

DeviceTracer *GetDeviceTracer() {
  std::call_once(tracer_once_flag, CreateTracer, &tracer);
  return tracer;
}

void SetCurAnnotation(const char *anno) { cur_annotation = anno; }

void ClearCurAnnotation() { cur_annotation = nullptr; }

}  // namespace platform
}  // namespace paddle
paddle/fluid/platform/device_tracer.h
0 → 100644
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include "paddle/fluid/platform/dynload/cupti.h"
#include "paddle/fluid/platform/profiler.pb.h"

namespace paddle {
namespace platform {

///////////////////////
// WARN: Under Development. Don't depend on it yet.
//////////////////////

// DeviceTracer performs the following tasks:
// 1. Register cuda callbacks for various events: kernel, memcpy, etc.
// 2. Collect cuda statistics: start/end ts, memory, etc.
// 3. Generate a protobuf for further analysis.
class DeviceTracer {
 public:
  struct KernelRecord {
    uint64_t start_ns;
    uint64_t end_ns;
    uint32_t device_id;
    uint32_t stream_id;
    uint32_t correlation_id;
  };

  virtual ~DeviceTracer() {}

  // Needs to be called once before use.
  virtual void Enable() = 0;
  // Needs to be called once after use.
  virtual void Disable() = 0;

  // Add a pair to correlate internal cuda id with high level
  // annotation (string). So cuda statistics can be represented by
  // human-readable annotations.
  virtual void AddAnnotation(uint64_t id, const std::string& anno) = 0;

  // Add a cuda kernel stats. `correlation_id` will be mapped to annotation
  // added before for human readability.
  virtual void AddKernelRecords(uint64_t start, uint64_t end,
                                uint32_t device_id, uint32_t stream_id,
                                uint32_t correlation_id) = 0;

  // Generate a proto after done (Disabled).
  virtual proto::Profile GenProfile() = 0;

  virtual bool IsEnabled() = 0;
};

// Get a DeviceTracer.
DeviceTracer* GetDeviceTracer();

// Set a name for the cuda kernel operation being launched by the thread.
void SetCurAnnotation(const char* anno);
// Clear the name after the operation is done.
void ClearCurAnnotation();

}  // namespace platform
}  // namespace paddle
paddle/fluid/platform/dynload/CMakeLists.txt
 cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce)
-nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc nccl.cc
-           DEPS dynamic_loader)
+list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc nccl.cc)
+if (CUPTI_FOUND)
+  list(APPEND CUDA_SRCS cupti.cc)
+endif(CUPTI_FOUND)
+nv_library(dynload_cuda SRCS ${CUDA_SRCS} DEPS dynamic_loader)
 cc_library(dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc)
paddle/fluid/platform/dynload/cupti.cc
0 → 100644
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_CUPTI
#include "paddle/fluid/platform/dynload/cupti.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
namespace dynload {

std::once_flag cupti_dso_flag;
void *cupti_dso_handle = nullptr;

#define DEFINE_WRAP(__name) DynLoad__##__name __name

CUPTI_ROUTINE_EACH(DEFINE_WRAP);

}  // namespace dynload
}  // namespace platform
}  // namespace paddle

#endif  // PADDLE_WITH_CUPTI
paddle/fluid/platform/dynload/cupti.h
0 → 100644
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_CUPTI
#include <cuda.h>
#include <cupti.h>
#include <dlfcn.h>
#include <mutex>
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
namespace paddle {
namespace platform {
namespace dynload {

extern std::once_flag cupti_dso_flag;
extern void *cupti_dso_handle;
/**
* The following macro definition can generate structs
* (for each function) to dynamic load cupti routine
* via operator overloading.
*
* note: default dynamic linked libs
*/
#ifdef PADDLE_USE_DSO
#define DECLARE_DYNAMIC_LOAD_CUPTI_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
inline CUptiResult CUPTIAPI operator()(Args... args) { \
typedef CUptiResult CUPTIAPI (*cuptiFunc)(Args...); \
std::call_once(cupti_dso_flag, \
paddle::platform::dynload::GetCUPTIDsoHandle, \
&cupti_dso_handle); \
void *p_##__name = dlsym(cupti_dso_handle, #__name); \
return reinterpret_cast<cuptiFunc>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
#else
#define DECLARE_DYNAMIC_LOAD_CUPTI_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
inline CUptiResult CUPTIAPI operator()(Args... args) { \
return __name(args...); \
} \
}; \
extern DynLoad__##__name __name
#endif
#define CUPTI_ROUTINE_EACH(__macro) \
__macro(cuptiActivityEnable); \
__macro(cuptiActivityDisable); \
__macro(cuptiActivityRegisterCallbacks); \
__macro(cuptiActivityGetAttribute); \
__macro(cuptiActivitySetAttribute); \
__macro(cuptiGetTimestamp); \
__macro(cuptiActivityGetNextRecord); \
__macro(cuptiGetResultString); \
__macro(cuptiActivityGetNumDroppedRecords); \
__macro(cuptiActivityFlushAll); \
__macro(cuptiFinalize); \
__macro(cuptiSubscribe); \
__macro(cuptiUnsubscribe); \
__macro(cuptiEnableCallback);
CUPTI_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUPTI_WRAP);

#undef DECLARE_DYNAMIC_LOAD_CUPTI_WRAP

}  // namespace dynload
}  // namespace platform
}  // namespace paddle

#endif  // PADDLE_WITH_CUPTI
paddle/fluid/platform/dynload/dynamic_loader.cc
...
...
@@ -40,10 +40,14 @@ DEFINE_string(nccl_dir, "",
"libcurand. For instance, /usr/local/cuda/lib64. If default, "
"dlopen will search cuda from LD_LIBRARY_PATH"
);
DEFINE_string(cupti_dir, "", "Specify path for loading cupti.so.");

namespace paddle {
namespace platform {
namespace dynload {

static const char* cupti_lib_path = CUPTI_LIB_PATH;

static inline std::string join(const std::string& part1,
                               const std::string& part2) {
  // directory separator
...
...
@@ -143,6 +147,18 @@ void GetCUDNNDsoHandle(void** dso_handle) {
#endif
}
void GetCUPTIDsoHandle(void** dso_handle) {
  std::string cupti_path = cupti_lib_path;
  if (!FLAGS_cupti_dir.empty()) {
    cupti_path = FLAGS_cupti_dir;
  }
#if defined(__APPLE__) || defined(__OSX__)
  GetDsoHandleFromSearchPath(cupti_path, "libcupti.dylib", dso_handle, false);
#else
  GetDsoHandleFromSearchPath(cupti_path, "libcupti.so", dso_handle, false);
#endif
}

void GetCurandDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
  GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.dylib", dso_handle);
...
...
paddle/fluid/platform/dynload/dynamic_loader.h
...
...
@@ -34,6 +34,8 @@ void GetCublasDsoHandle(void** dso_handle);
*/
void GetCUDNNDsoHandle(void** dso_handle);

void GetCUPTIDsoHandle(void** dso_handle);
/**
* @brief load the DSO of CURAND
*
...
...
paddle/fluid/platform/profiler.cc
...
...
@@ -15,7 +15,13 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler.h"
#include <iomanip>
#include <map>
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#endif // PADDLE_WITH_CUDA
#include "glog/logging.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/platform/device_tracer.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace platform {
...
...
@@ -132,10 +138,13 @@ RecordEvent::RecordEvent(const std::string& name,
   dev_ctx_ = dev_ctx;
   name_ = name;
   PushEvent(name_, dev_ctx_);
+  // Maybe need the same push/pop behavior.
+  SetCurAnnotation(name_.c_str());
 }

 RecordEvent::~RecordEvent() {
   if (g_state == ProfilerState::kDisabled) return;
+  ClearCurAnnotation();
   PopEvent(name_, dev_ctx_);
 }
...
...
@@ -147,7 +156,14 @@ void EnableProfiler(ProfilerState state) {
"The profiling state should be disabled when calling "
,
"EnableProfiler."
);
g_state
=
state
;
g_profiler_place
=
(
g_state
==
ProfilerState
::
kCUDA
)
?
"CUDA"
:
"CPU"
;
if
(
g_state
==
ProfilerState
::
kCUDA
)
{
g_profiler_place
=
"CUDA"
;
}
else
if
(
g_state
==
ProfilerState
::
kCPU
)
{
g_profiler_place
=
"CPU"
;
}
else
{
g_profiler_place
=
"All"
;
GetDeviceTracer
()
->
Enable
();
}
#ifdef PADDLE_WITH_CUDA
if
(
g_state
==
ProfilerState
::
kCUDA
)
{
// Generate some dummy evenets first to reduce the startup overhead.
...
...
@@ -190,6 +206,12 @@ void DisableProfiler(EventSortingKey sorted_key) {
   Mark("_stop_profiler_", nullptr);
   g_state = ProfilerState::kDisabled;
+  DeviceTracer* tracer = GetDeviceTracer();
+  if (g_profiler_place == "All" && tracer && tracer->IsEnabled()) {
+    tracer->Disable();
+    tracer->GenProfile();
+  }
   std::vector<std::vector<Event>> all_events = GetAllEvents();
   ParseEvents(all_events, sorted_key);
   ResetProfiler();
...
...
@@ -254,9 +276,11 @@ void ParseEvents(std::vector<std::vector<Event>>& events,
     }
     if (rit != pushed_events.rend()) {
-      double event_time = (g_profiler_place == "CUDA")
-                              ? rit->CudaElapsedMs(events[i][j])
-                              : rit->CpuElapsedMs(events[i][j]);
+      double event_time =
+          (g_profiler_place == "CUDA" || g_profiler_place == "All")
+              ? rit->CudaElapsedMs(events[i][j])
+              : rit->CpuElapsedMs(events[i][j]);
       std::string event_name =
           "thread" + std::to_string(rit->thread_id()) + "::" + rit->name();
       max_name_width = std::max(max_name_width, event_name.size());
...
...
paddle/fluid/platform/profiler.h
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include <mutex>
#include <vector>
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.pb.h"
namespace paddle {
namespace platform {
...
...
@@ -93,6 +94,7 @@ enum ProfilerState {
   kDisabled,  // disabled state
   kCPU,       // CPU profiling state
   kCUDA,      // GPU profiling state
+  kAll,       // Profile both CPU and GPU. (Currently experimental).
 };
void Mark(const std::string& name, const DeviceContext* dev_ctx);
...
...
@@ -102,7 +104,7 @@ void PushEvent(const std::string& name, const DeviceContext* dev_ctx);
 void PopEvent(const std::string& name, const DeviceContext* dev_ctx);

 struct RecordEvent {
-  explicit RecordEvent(const std::string& name, const DeviceContext* dev_ctx);
+  RecordEvent(const std::string& name, const DeviceContext* dev_ctx);

   ~RecordEvent();
...
...
@@ -110,9 +112,12 @@ struct RecordEvent {
   const DeviceContext* dev_ctx_;
   // Event name
   std::string name_;
+  // Need to distinguish name by op type, block_id, program_id and perhaps
+  // different kernel invocations within an op.
+  std::string full_name_;
 };

-// Return the event list of all threads. Asummed the returned value calls
+// Return the event list of all threads. Assumed the returned value calls
 // event_lists, event_lists[i][j] represents the j-th Event of i-th thread.
 std::vector<std::vector<Event>> GetAllEvents();
...
...
paddle/fluid/platform/profiler.proto
0 → 100644
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";

package paddle.platform.proto;

message Event {
  optional string name = 1;
  optional uint64 start_ns = 2;
  optional uint64 end_ns = 3;
  optional uint32 device_id = 5;
  optional uint32 stream_id = 6;
}

message Profile {
  repeated Event events = 1;
  optional uint64 start_ns = 2;
  optional uint64 end_ns = 3;
}
\ No newline at end of file
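Since the schema is plain proto2, the generated bindings can be exercised directly. Below is a minimal Python sketch; the module name profiler_pb2 is an assumption (it is whatever protoc emits for profiler.proto), and the field values are made up for illustration.

import profiler_pb2  # assumed name of the protoc-generated module

profile = profiler_pb2.Profile()
profile.start_ns = 0
profile.end_ns = 5000

event = profile.events.add()
event.name = "mul_op"  # the annotation attached via SetCurAnnotation
event.start_ns = 1000
event.end_ns = 4000
event.device_id = 0
event.stream_id = 0

# Round-trip through the wire format, e.g. to persist GenProfile() output.
data = profile.SerializeToString()
parsed = profiler_pb2.Profile()
parsed.ParseFromString(data)
print(parsed.events[0].name, parsed.events[0].end_ns - parsed.events[0].start_ns)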
paddle/fluid/pybind/pybind.cc
...
...
@@ -459,6 +459,7 @@ All parameter, weight, gradient are variables in Paddle.
       .value("kDisabled", platform::ProfilerState::kDisabled)
       .value("kCPU", platform::ProfilerState::kCPU)
       .value("kCUDA", platform::ProfilerState::kCUDA)
+      .value("kAll", platform::ProfilerState::kAll)
       .export_values();

   py::enum_<platform::EventSortingKey>(m, "EventSortingKey", py::arithmetic())
...
...
python/paddle/fluid/framework.py
...
...
@@ -784,6 +784,7 @@ class Block(object):
            elif type(v) == Variable:
                var = Variable(
                    self,
                    type=v.type,
                    name=new_name,
                    error_clip=error_clip,
                    stop_gradient=stop_gradient)
...
...
python/paddle/fluid/io.py
...
...
@@ -68,7 +68,7 @@ def save_vars(executor,
               main_program=None,
               vars=None,
               predicate=None,
-              save_file_name=None):
+              filename=None):
"""
Save variables to directory by executor.
...
...
@@ -80,8 +80,8 @@ def save_vars(executor,
as a bool. If it returns true, the corresponding input variable will be saved.
:param vars: variables need to be saved. If vars is specified, program & predicate
will be ignored
-    :param save_file_name: The name of a single file that all vars are saved to.
+    :param filename: The name of a single file that all vars are saved to.
         If it is None, save variables to separate files.
:return: None
"""
...
...
@@ -95,7 +95,7 @@ def save_vars(executor,
             executor,
             dirname=dirname,
             vars=filter(predicate, main_program.list_vars()),
-            save_file_name=save_file_name)
+            filename=filename)
     else:
         save_program = Program()
         save_block = save_program.global_block()
...
...
@@ -103,7 +103,7 @@ def save_vars(executor,
         save_var_map = {}
         for each_var in vars:
             new_var = _clone_var_in_block_(save_block, each_var)
-            if save_file_name is None:
+            if filename is None:
                 save_block.append_op(
                     type='save',
                     inputs={'X': [new_var]},
...
@@ -112,7 +112,7 @@ def save_vars(executor,
else
:
save_var_map
[
new_var
.
name
]
=
new_var
if
save_file_
name
is
not
None
:
if
file
name
is
not
None
:
save_var_list
=
[]
for
name
in
sorted
(
save_var_map
.
keys
()):
save_var_list
.
append
(
save_var_map
[
name
])
...
...
@@ -121,12 +121,12 @@ def save_vars(executor,
type
=
'save_combine'
,
inputs
=
{
'X'
:
save_var_list
},
outputs
=
{},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
dirname
,
save_file_
name
)})
attrs
=
{
'file_path'
:
os
.
path
.
join
(
dirname
,
file
name
)})
executor
.
run
(
save_program
)
def
save_params
(
executor
,
dirname
,
main_program
=
None
,
save_file_
name
=
None
):
def
save_params
(
executor
,
dirname
,
main_program
=
None
,
file
name
=
None
):
"""
Save all parameters to directory with executor.
"""
...
...
@@ -136,11 +136,10 @@ def save_params(executor, dirname, main_program=None, save_file_name=None):
         main_program=main_program,
         vars=None,
         predicate=is_parameter,
-        save_file_name=save_file_name)
+        filename=filename)


-def save_persistables(executor, dirname, main_program=None, save_file_name=None):
+def save_persistables(executor, dirname, main_program=None, filename=None):
"""
Save all persistables to directory with executor.
"""
...
...
@@ -150,7 +149,7 @@ def save_persistables(executor, dirname, main_program=None,
         main_program=main_program,
         vars=None,
         predicate=is_persistable,
-        save_file_name=save_file_name)
+        filename=filename)


def load_vars(executor,
...
@@ -158,7 +157,7 @@ def load_vars(executor,
main_program
=
None
,
vars
=
None
,
predicate
=
None
,
load_file_
name
=
None
):
file
name
=
None
):
"""
Load variables from directory by executor.
...
...
@@ -170,8 +169,8 @@ def load_vars(executor,
as a bool. If it returns true, the corresponding input variable will be loaded.
:param vars: variables need to be loaded. If vars is specified, program &
predicate will be ignored
-    :param load_file_name: The name of the single file that all vars are loaded from.
+    :param filename: The name of the single file that all vars are loaded from.
         If it is None, load variables from separate files.
:return: None
"""
...
...
@@ -185,7 +184,7 @@ def load_vars(executor,
             executor,
             dirname=dirname,
             vars=filter(predicate, main_program.list_vars()),
-            load_file_name=load_file_name)
+            filename=filename)
     else:
         load_prog = Program()
         load_block = load_prog.global_block()
...
...
@@ -194,7 +193,7 @@ def load_vars(executor,
         for each_var in vars:
             assert isinstance(each_var, Variable)
             new_var = _clone_var_in_block_(load_block, each_var)
-            if load_file_name is None:
+            if filename is None:
                 load_block.append_op(
                     type='load',
                     inputs={},
...
@@ -203,7 +202,7 @@ def load_vars(executor,
             else:
                 load_var_map[new_var.name] = new_var

-        if load_file_name is not None:
+        if filename is not None:
             load_var_list = []
             for name in sorted(load_var_map.keys()):
                 load_var_list.append(load_var_map[name])
...
...
@@ -212,12 +211,12 @@ def load_vars(executor,
                 type='load_combine',
                 inputs={},
                 outputs={"Out": load_var_list},
-                attrs={'file_path': os.path.join(dirname, load_file_name)})
+                attrs={'file_path': os.path.join(dirname, filename)})

         executor.run(load_prog)


-def load_params(executor, dirname, main_program=None, load_file_name=None):
+def load_params(executor, dirname, main_program=None, filename=None):
"""
load all parameters from directory by executor.
"""
...
...
@@ -226,11 +225,10 @@ def load_params(executor, dirname, main_program=None, load_file_name=None):
         dirname=dirname,
         main_program=main_program,
         predicate=is_parameter,
-        load_file_name=load_file_name)
+        filename=filename)


-def load_persistables(executor, dirname, main_program=None, load_file_name=None):
+def load_persistables(executor, dirname, main_program=None, filename=None):
"""
load all persistables from directory by executor.
"""
...
...
@@ -239,7 +237,7 @@ def load_persistables(executor, dirname, main_program=None,
         dirname=dirname,
         main_program=main_program,
         predicate=is_persistable,
-        load_file_name=load_file_name)
+        filename=filename)


def get_inference_program(target_vars, main_program=None):
...
...
@@ -299,7 +297,8 @@ def save_inference_model(dirname,
                          target_vars,
                          executor,
                          main_program=None,
-                         save_file_name=None):
+                         model_filename=None,
+                         params_filename=None):
"""
Build a model especially for inference,
and save it to directory by the executor.
...
...
@@ -310,8 +309,11 @@ def save_inference_model(dirname,
:param executor: executor that save inference model
:param main_program: original program, which will be pruned to build the inference model.
Default default_main_program().
-    :param save_file_name: The name of a single file that all parameters are saved to.
-        If it is None, save parameters to separate files.
+    :param model_filename: The name of file to save inference program.
+        If not specified, default filename `__model__` will be used.
+    :param params_filename: The name of file to save parameters.
+        It is used for the case that all parameters are saved in a single binary file.
+        If not specified, parameters are considered saved in separate files.
"""
...
...
@@ -342,15 +344,19 @@ def save_inference_model(dirname,
     prepend_feed_ops(inference_program, feeded_var_names)
     append_fetch_ops(inference_program, fetch_var_names)

-    if save_file_name == None:
-        model_file_name = dirname + "/__model__"
-    else:
-        model_file_name = dirname + "/__model_combined__"
+    if model_filename is not None:
+        model_filename = os.path.basename(model_filename)
+    else:
+        model_filename = "__model__"
+    model_filename = os.path.join(dirname, model_filename)

+    if params_filename is not None:
+        params_filename = os.path.basename(params_filename)

-    with open(model_file_name, "wb") as f:
+    with open(model_filename, "wb") as f:
         f.write(inference_program.desc.serialize_to_string())

-    save_persistables(executor, dirname, inference_program, save_file_name)
+    save_persistables(executor, dirname, inference_program, params_filename)
def get_feed_targets_names(program):
...
...
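To illustrate the new keyword arguments, here is a rough usage sketch (not part of this commit): `prediction` and `exe` are placeholders for a trained network's output variable and an executor, and the combined filenames mirror the ones used by the updated recognize-digits test below.

import paddle.fluid as fluid

def save_both_ways(prediction, exe):
    # `prediction` is the network output variable; `exe` is a fluid.Executor.
    # Default: the program goes to "__model__", each parameter to its own file.
    fluid.io.save_inference_model("./digits_model", ["img"], [prediction], exe)
    # Combined: the program and all parameters each land in one named file.
    fluid.io.save_inference_model(
        "./digits_model", ["img"], [prediction], exe,
        model_filename="__model_combined__",
        params_filename="__params_combined__")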
@@ -371,15 +377,21 @@ def get_fetch_targets_names(program):
     return fetch_targets_names


-def load_inference_model(dirname, executor, load_file_name=None):
+def load_inference_model(dirname,
+                         executor,
+                         model_filename=None,
+                         params_filename=None):
"""
Load inference model from a directory
:param dirname: directory path
:param executor: executor that load inference model
-    :param load_file_name: The name of the single file that all parameters are loaded from.
-        If it is None, load parameters from separate files.
+    :param model_filename: The name of file to load inference program.
+        If not specified, default filename `__model__` will be used.
+    :param params_filename: The name of file to load parameters.
+        It is used for the case that all parameters are saved in a single binary file.
+        If not specified, parameters are considered saved in separate files.
:return: [program, feed_target_names, fetch_targets]
program: program especially for inference.
feed_target_names: Names of variables that need to feed data
...
...
@@ -388,16 +400,20 @@ def load_inference_model(dirname, executor, load_file_name=None):
     if not os.path.isdir(dirname):
         raise ValueError("There is no directory named '%s'", dirname)

-    if load_file_name == None:
-        model_file_name = dirname + "/__model__"
-    else:
-        model_file_name = dirname + "/__model_combined__"
+    if model_filename is not None:
+        model_filename = os.path.basename(model_filename)
+    else:
+        model_filename = "__model__"
+    model_filename = os.path.join(dirname, model_filename)

+    if params_filename is not None:
+        params_filename = os.path.basename(params_filename)

-    with open(model_file_name, "rb") as f:
+    with open(model_filename, "rb") as f:
         program_desc_str = f.read()

     program = Program.parse_from_string(program_desc_str)
-    load_persistables(executor, dirname, program, load_file_name)
+    load_persistables(executor, dirname, program, params_filename)

     feed_target_names = get_feed_targets_names(program)
     fetch_target_names = get_fetch_targets_names(program)
...
...
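The matching load call, again as a sketch with placeholder paths; passing the same two filenames restores a model saved in combined form, while omitting them falls back to `__model__` plus separate parameter files.

import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
[program, feed_names, fetch_targets] = fluid.io.load_inference_model(
    "./digits_model", exe,
    model_filename="__model_combined__",
    params_filename="__params_combined__")
print(feed_names)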
python/paddle/fluid/layers/detection.py
...
...
@@ -132,7 +132,10 @@ def detection_output(scores,
     return nmsed_outs


-def bipartite_match(dist_matrix, name=None):
+def bipartite_match(dist_matrix,
+                    match_type=None,
+                    dist_threshold=None,
+                    name=None):
"""
**Bipartite matching operator**
...
...
@@ -164,6 +167,11 @@ def bipartite_match(dist_matrix, name=None):
This tensor can contain LoD information to represent a batch of
inputs. One instance of this batch can contain different numbers of
entities.
        match_type(string|None): The type of matching method, should be
           'bipartite' or 'per_prediction', 'bipartite' by default.
        dist_threshold(float|None): If `match_type` is 'per_prediction',
            this threshold is to determine the extra matching bboxes based
            on the maximum distance, 0.5 by default.
Returns:
match_indices(Variable): A 2-D Tensor with shape [N, M] in int type.
N is the batch size. If match_indices[i][j] is -1, it
...
...
@@ -183,6 +191,10 @@ def bipartite_match(dist_matrix, name=None):
     helper.append_op(
         type='bipartite_match',
         inputs={'DistMat': dist_matrix},
+        attrs={
+            'match_type': match_type,
+            'dist_threshold': dist_threshold,
+        },
         outputs={
             'ColToRowMatchIndices': match_indices,
             'ColToRowMatchDist': match_distance
...
@@ -333,7 +345,7 @@ def ssd_loss(location,
loc_loss_weight (float): Weight for localization loss, 1.0 by default.
conf_loss_weight (float): Weight for confidence loss, 1.0 by default.
match_type (str): The type of matching method during training, should
-            be 'bipartite' or 'per_prediction'.
+            be 'bipartite' or 'per_prediction', 'per_prediction' by default.
mining_type (str): The hard example mining type, should be 'hard_example'
or 'max_negative', now only support `max_negative`.
...
...
@@ -381,7 +393,8 @@ def ssd_loss(location,
     # 1.1 Compute IOU similarity between ground-truth boxes and prior boxes.
     iou = iou_similarity(x=gt_box, y=prior_box)
     # 1.2 Compute matched bounding box by bipartite matching algorithm.
-    matched_indices, matched_dist = bipartite_match(iou)
+    matched_indices, matched_dist = bipartite_match(iou, match_type,
+                                                    overlap_threshold)
     # 2. Compute confidence for mining hard examples
     # 2.1. Get the target label based on matched indices
...
...
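A small sketch of the updated layer API, assuming the detection layers are re-exported under fluid.layers as the ssd_loss code above implies; shapes and the threshold value are illustrative.

import paddle.fluid as fluid

gt_box = fluid.layers.data(
    name='gt_box', shape=[4], dtype='float32', lod_level=1)
prior_box = fluid.layers.data(name='prior_box', shape=[4], dtype='float32')
iou = fluid.layers.iou_similarity(x=gt_box, y=prior_box)

# Default one-to-one bipartite matching.
matched_indices, matched_dist = fluid.layers.bipartite_match(iou)

# Additionally match predictions whose maximum overlap exceeds the threshold.
matched_indices, matched_dist = fluid.layers.bipartite_match(
    iou, match_type='per_prediction', dist_threshold=0.5)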
python/paddle/fluid/profiler.py
...
...
@@ -97,9 +97,14 @@ def profiler(state, sorted_key=None):
The `ave` means sorting by the average execution time.
"""
-    if state not in ['CPU', 'GPU']:
-        raise ValueError("The state must be 'CPU' or 'GPU'.")
-    prof_state = core.ProfilerState.kCUDA if state == "GPU" else core.ProfilerState.kCPU
+    if state not in ['CPU', 'GPU', "All"]:
+        raise ValueError("The state must be 'CPU' or 'GPU' or 'All'.")
+    if state == "GPU":
+        prof_state = core.ProfilerState.kCUDA
+    elif state == "CPU":
+        prof_state = core.ProfilerState.kCPU
+    else:
+        prof_state = core.ProfilerState.kAll
     core.enable_profiler(prof_state)
     yield
...
...
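With the new state, profiling both sides becomes the following sketch; the executor, program, feeder, and reader are placeholders, and sorted_key='ave' (average execution time) comes from the docstring above.

import paddle.fluid.profiler as profiler

def profiled_training(exe, main_program, feeder, train_reader):
    # 'All' profiles CPU ops and GPU kernels together, and also switches on
    # the experimental DeviceTracer introduced by this change.
    with profiler.profiler('All', sorted_key='ave'):
        for data in train_reader():
            exe.run(main_program, feed=feeder.feed(data))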
python/paddle/fluid/tests/book/notest_rnn_encoder_decoer.py
...
...
@@ -228,32 +228,34 @@ def infer(use_cuda, save_dirname=None):
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    inference_scope = fluid.core.Scope()
    with fluid.scope_guard(inference_scope):
        # Use fluid.io.load_inference_model to obtain the inference program desc,
        # the feed_target_names (the names of variables that will be feeded
        # data using feed operators), and the fetch_targets (variables that
        # we want to obtain data from using fetch operators).
        [inference_program, feed_target_names,
         fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)

        lod = [0, 4, 10]
        word_data = create_random_lodtensor(lod, place, low=0, high=1)
        trg_word = create_random_lodtensor(lod, place, low=0, high=1)

        # Construct feed as a dictionary of {feed_target_name: feed_target_data}
        # and results will contain a list of data corresponding to fetch_targets.
        assert feed_target_names[0] == 'source_sequence'
        assert feed_target_names[1] == 'target_sequence'
        results = exe.run(inference_program,
                          feed={
                              feed_target_names[0]: word_data,
                              feed_target_names[1]: trg_word,
                          },
                          fetch_list=fetch_targets,
                          return_numpy=False)
        print(results[0].lod())
        np_data = np.array(results[0])
        print("Inference shape: ", np_data.shape)
        print("Inference results: ", np_data)
def main(use_cuda):
...
...
python/paddle/fluid/tests/book/test_fit_a_line.py
...
...
@@ -72,23 +72,26 @@ def infer(use_cuda, save_dirname=None):
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    inference_scope = fluid.core.Scope()
    with fluid.scope_guard(inference_scope):
        # Use fluid.io.load_inference_model to obtain the inference program desc,
        # the feed_target_names (the names of variables that will be feeded
        # data using feed operators), and the fetch_targets (variables that
        # we want to obtain data from using fetch operators).
        [inference_program, feed_target_names,
         fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)

        # The input's dimension should be 2-D and the second dim is 13
        # The input data should be >= 0
        batch_size = 10
        tensor_x = numpy.random.uniform(0, 10,
                                        [batch_size, 13]).astype("float32")
        assert feed_target_names[0] == 'x'
        results = exe.run(inference_program,
                          feed={feed_target_names[0]: tensor_x},
                          fetch_list=fetch_targets)
        print("infer shape: ", results[0].shape)
        print("infer results: ", results[0])
def main(use_cuda):
...
...
python/paddle/fluid/tests/book/test_image_classification.py
...
...
@@ -174,22 +174,26 @@ def infer(use_cuda, save_dirname=None):
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    inference_scope = fluid.core.Scope()
    with fluid.scope_guard(inference_scope):
        # Use fluid.io.load_inference_model to obtain the inference program desc,
        # the feed_target_names (the names of variables that will be feeded
        # data using feed operators), and the fetch_targets (variables that
        # we want to obtain data from using fetch operators).
        [inference_program, feed_target_names,
         fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)

        # The input's dimension of conv should be 4-D or 5-D.
        # Use normalized image pixels as input data, which should be in the range [0, 1.0].
        batch_size = 1
        tensor_img = numpy.random.rand(batch_size, 3, 32, 32).astype("float32")

        # Construct feed as a dictionary of {feed_target_name: feed_target_data}
        # and results will contain a list of data corresponding to fetch_targets.
        results = exe.run(inference_program,
                          feed={feed_target_names[0]: tensor_img},
                          fetch_list=fetch_targets)
        print("infer results: ", results[0])
def main(net_type, use_cuda):
...
...
python/paddle/fluid/tests/book/test_label_semantic_roles.py
...
...
@@ -26,7 +26,7 @@ import unittest
 word_dict, verb_dict, label_dict = conll05.get_dict()
 word_dict_len = len(word_dict)
 label_dict_len = len(label_dict)
-pred_len = len(verb_dict)
+pred_dict_len = len(verb_dict)

 mark_dict_len = 2
 word_dim = 32
...
...
@@ -53,7 +53,7 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark,
     # 8 features
     predicate_embedding = fluid.layers.embedding(
         input=predicate,
-        size=[pred_len, word_dim],
+        size=[pred_dict_len, word_dim],
         dtype='float32',
         is_sparse=IS_SPARSE,
         param_attr='vemb')
...
...
@@ -234,6 +234,7 @@ def train(use_cuda, save_dirname=None):
         # Set the threshold low to speed up the CI test
         if float(pass_precision) > 0.05:
             if save_dirname is not None:
+                # TODO(liuyiqun): Change the target to crf_decode
                 fluid.io.save_inference_model(save_dirname, [
                     'word_data', 'verb_data', 'ctx_n2_data', 'ctx_n1_data',
                     'ctx_0_data', 'ctx_p1_data',
...
...
@@ -251,51 +252,60 @@ def infer(use_cuda, save_dirname=None):
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    inference_scope = fluid.core.Scope()
    with fluid.scope_guard(inference_scope):
        # Use fluid.io.load_inference_model to obtain the inference program desc,
        # the feed_target_names (the names of variables that will be feeded
        # data using feed operators), and the fetch_targets (variables that
        # we want to obtain data from using fetch operators).
        [inference_program, feed_target_names,
         fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)

        lod = [0, 4, 10]
        word = create_random_lodtensor(lod, place, low=0, high=word_dict_len - 1)
        pred = create_random_lodtensor(lod, place, low=0, high=pred_dict_len - 1)
        ctx_n2 = create_random_lodtensor(lod, place, low=0, high=word_dict_len - 1)
        ctx_n1 = create_random_lodtensor(lod, place, low=0, high=word_dict_len - 1)
        ctx_0 = create_random_lodtensor(lod, place, low=0, high=word_dict_len - 1)
        ctx_p1 = create_random_lodtensor(lod, place, low=0, high=word_dict_len - 1)
        ctx_p2 = create_random_lodtensor(lod, place, low=0, high=word_dict_len - 1)
        mark = create_random_lodtensor(lod, place, low=0, high=mark_dict_len - 1)

        # Construct feed as a dictionary of {feed_target_name: feed_target_data}
        # and results will contain a list of data corresponding to fetch_targets.
        assert feed_target_names[0] == 'word_data'
        assert feed_target_names[1] == 'verb_data'
        assert feed_target_names[2] == 'ctx_n2_data'
        assert feed_target_names[3] == 'ctx_n1_data'
        assert feed_target_names[4] == 'ctx_0_data'
        assert feed_target_names[5] == 'ctx_p1_data'
        assert feed_target_names[6] == 'ctx_p2_data'
        assert feed_target_names[7] == 'mark_data'

        results = exe.run(inference_program,
                          feed={
                              feed_target_names[0]: word,
                              feed_target_names[1]: pred,
                              feed_target_names[2]: ctx_n2,
                              feed_target_names[3]: ctx_n1,
                              feed_target_names[4]: ctx_0,
                              feed_target_names[5]: ctx_p1,
                              feed_target_names[6]: ctx_p2,
                              feed_target_names[7]: mark
                          },
                          fetch_list=fetch_targets,
                          return_numpy=False)
        print(results[0].lod())
        np_data = np.array(results[0])
        print("Inference Shape: ", np_data.shape)
def main(use_cuda):
...
...
python/paddle/fluid/tests/book/test_recognize_digits.py
...
...
@@ -78,7 +78,12 @@ def conv_net(img, label):
    return loss_net(conv_pool_2, label)


def train(nn_type,
          use_cuda,
          parallel,
          save_dirname=None,
          model_filename=None,
          params_filename=None):
    if use_cuda and not fluid.core.is_compiled_with_cuda():
        return
    img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
...
...
@@ -146,7 +151,8 @@ def train(nn_type, use_cuda, parallel, save_dirname, save_param_filename):
                fluid.io.save_inference_model(
                    save_dirname, ["img"], [prediction],
                    exe,
                    model_filename=model_filename,
                    params_filename=params_filename)
                return
        else:
            print(
...
...
@@ -158,54 +164,62 @@ def train(nn_type, use_cuda, parallel, save_dirname, save_param_filename):
        raise AssertionError("Loss of recognize digits is too large")


def infer(use_cuda,
          save_dirname=None,
          model_filename=None,
          params_filename=None):
    if save_dirname is None:
        return

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    inference_scope = fluid.core.Scope()
    with fluid.scope_guard(inference_scope):
        # Use fluid.io.load_inference_model to obtain the inference program desc,
        # the feed_target_names (the names of variables that will be feeded
        # data using feed operators), and the fetch_targets (variables that
        # we want to obtain data from using fetch operators).
        [inference_program, feed_target_names,
         fetch_targets] = fluid.io.load_inference_model(
             save_dirname, exe, model_filename, params_filename)

        # The input's dimension of conv should be 4-D or 5-D.
        # Use normalized image pixels as input data, which should be in the range [-1.0, 1.0].
        batch_size = 1
        tensor_img = numpy.random.uniform(
            -1.0, 1.0, [batch_size, 1, 28, 28]).astype("float32")

        # Construct feed as a dictionary of {feed_target_name: feed_target_data}
        # and results will contain a list of data corresponding to fetch_targets.
        results = exe.run(inference_program,
                          feed={feed_target_names[0]: tensor_img},
                          fetch_list=fetch_targets)
        print("infer results: ", results[0])
def main(use_cuda, parallel, nn_type, combine):
    save_dirname = None
    model_filename = None
    params_filename = None
    if not use_cuda and not parallel:
        save_dirname = "recognize_digits_" + nn_type + ".inference.model"
        if combine == True:
            model_filename = "__model_combined__"
            params_filename = "__params_combined__"

    train(
        nn_type=nn_type,
        use_cuda=use_cuda,
        parallel=parallel,
        save_dirname=save_dirname,
        model_filename=model_filename,
        params_filename=params_filename)
    infer(
        use_cuda=use_cuda,
        save_dirname=save_dirname,
        model_filename=model_filename,
        params_filename=params_filename)
class TestRecognizeDigits(unittest.TestCase):
...
...
python/paddle/fluid/tests/book/test_recommender_system.py
...
...
@@ -251,13 +251,6 @@ def infer(use_cuda, save_dirname=None):
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    def create_lod_tensor(data, lod=None):
        tensor = fluid.LoDTensor()
        if lod is None:
...
...
@@ -275,44 +268,53 @@ def infer(use_cuda, save_dirname=None):
            tensor.set(flattened_data, place)
        return tensor

    inference_scope = fluid.core.Scope()
    with fluid.scope_guard(inference_scope):
        # Use fluid.io.load_inference_model to obtain the inference program desc,
        # the feed_target_names (the names of variables that will be feeded
        # data using feed operators), and the fetch_targets (variables that
        # we want to obtain data from using fetch operators).
        [inference_program, feed_target_names,
         fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)

        # Use the first data from paddle.dataset.movielens.test() as input
        assert feed_target_names[0] == "user_id"
        user_id = create_lod_tensor([[1]])

        assert feed_target_names[1] == "gender_id"
        gender_id = create_lod_tensor([[1]])

        assert feed_target_names[2] == "age_id"
        age_id = create_lod_tensor([[0]])

        assert feed_target_names[3] == "job_id"
        job_id = create_lod_tensor([[10]])

        assert feed_target_names[4] == "movie_id"
        movie_id = create_lod_tensor([[783]])

        assert feed_target_names[5] == "category_id"
        category_id = create_lod_tensor([[10], [8], [9]], [[0, 3]])

        assert feed_target_names[6] == "movie_title"
        movie_title = create_lod_tensor([[1069], [4140], [2923], [710], [988]],
                                        [[0, 5]])

        # Construct feed as a dictionary of {feed_target_name: feed_target_data}
        # and results will contain a list of data corresponding to fetch_targets.
        results = exe.run(inference_program,
                          feed={
                              feed_target_names[0]: user_id,
                              feed_target_names[1]: gender_id,
                              feed_target_names[2]: age_id,
                              feed_target_names[3]: job_id,
                              feed_target_names[4]: movie_id,
                              feed_target_names[5]: category_id,
                              feed_target_names[6]: movie_title
                          },
                          fetch_list=fetch_targets,
                          return_numpy=False)
        print("inferred score: ", np.array(results[0]))
def main(use_cuda):
...
...
python/paddle/fluid/tests/book/test_understand_sentiment.py
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -193,36 +193,39 @@ def train(word_dict, net_method, use_cuda, parallel=False, save_dirname=None):
            net_method.__name__))


def infer(word_dict, use_cuda, save_dirname=None):
    if save_dirname is None:
        return

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    inference_scope = fluid.core.Scope()
    with fluid.scope_guard(inference_scope):
        # Use fluid.io.load_inference_model to obtain the inference program desc,
        # the feed_target_names (the names of variables that will be feeded
        # data using feed operators), and the fetch_targets (variables that
        # we want to obtain data from using fetch operators).
        [inference_program, feed_target_names,
         fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)

        word_dict_len = len(word_dict)

        lod = [0, 4, 10]
        tensor_words = create_random_lodtensor(
            lod, place, low=0, high=word_dict_len - 1)

        # Construct feed as a dictionary of {feed_target_name: feed_target_data}
        # and results will contain a list of data corresponding to fetch_targets.
        assert feed_target_names[0] == "words"
        results = exe.run(inference_program,
                          feed={feed_target_names[0]: tensor_words},
                          fetch_list=fetch_targets,
                          return_numpy=False)
        print(results[0].lod())
        np_data = np.array(results[0])
        print("Inference Shape: ", np_data.shape)
        print("Inference results: ", np_data)
def main(word_dict, net_method, use_cuda, parallel=False, save_dirname=None):
...
...
@@ -258,7 +261,7 @@ class TestUnderstandSentiment(unittest.TestCase):
                 self.word_dict,
                 net_method=convolution_net,
                 use_cuda=False,
-                save_dirname="understand_sentiment.inference.model")
+                save_dirname="understand_sentiment_conv.inference.model")
    def test_conv_cpu_parallel(self):
        with self.new_program_scope():
...
...
@@ -271,7 +274,11 @@ class TestUnderstandSentiment(unittest.TestCase):
     @unittest.skip(reason="make CI faster")
     def test_stacked_lstm_cpu(self):
         with self.new_program_scope():
-            main(self.word_dict, net_method=stacked_lstm_net, use_cuda=False)
+            main(
+                self.word_dict,
+                net_method=stacked_lstm_net,
+                use_cuda=False,
+                save_dirname="understand_sentiment_stacked_lstm.inference.model")

     def test_stacked_lstm_cpu_parallel(self):
         with self.new_program_scope():
...
...
@@ -287,7 +294,7 @@ class TestUnderstandSentiment(unittest.TestCase):
                 self.word_dict,
                 net_method=convolution_net,
                 use_cuda=True,
-                save_dirname="understand_sentiment.inference.model")
+                save_dirname="understand_sentiment_conv.inference.model")

     def test_conv_gpu_parallel(self):
         with self.new_program_scope():
...
...
@@ -300,7 +307,11 @@ class TestUnderstandSentiment(unittest.TestCase):
     @unittest.skip(reason="make CI faster")
     def test_stacked_lstm_gpu(self):
         with self.new_program_scope():
-            main(self.word_dict, net_method=stacked_lstm_net, use_cuda=True)
+            main(
+                self.word_dict,
+                net_method=stacked_lstm_net,
+                use_cuda=True,
+                save_dirname="understand_sentiment_stacked_lstm.inference.model")

     def test_stacked_lstm_gpu_parallel(self):
         with self.new_program_scope():
...
...
python/paddle/fluid/tests/book/test_word2vec.py
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
...
...
@@ -21,6 +22,7 @@ import sys
def create_random_lodtensor(lod, place, low, high):
    # The range of data elements is [low, high]
    data = np.random.random_integers(low, high, [lod[-1], 1]).astype("int64")
    res = fluid.LoDTensor()
    res.set(data, place)
...
...
@@ -28,54 +30,7 @@ def create_random_lodtensor(lod, place, low, high):
    return res


def train(use_cuda, is_sparse, is_parallel, save_dirname):
    PASS_NUM = 100
    EMBED_SIZE = 32
    HIDDEN_SIZE = 256
...
...
@@ -130,7 +85,7 @@ def train(use_cuda, is_sparse, parallel, save_dirname):
    forth_word = fluid.layers.data(name='forthw', shape=[1], dtype='int64')
    next_word = fluid.layers.data(name='nextw', shape=[1], dtype='int64')

    if not is_parallel:
        avg_cost, predict_word = __network__(
            [first_word, second_word, third_word, forth_word, next_word])
    else:
...
...
@@ -176,11 +131,67 @@ def train(use_cuda, is_sparse, parallel, save_dirname):
        raise AssertionError("Cost is too large {0:2.2}".format(avg_cost_np[0]))


def infer(use_cuda, save_dirname=None):
    if save_dirname is None:
        return

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    inference_scope = fluid.core.Scope()
    with fluid.scope_guard(inference_scope):
        # Use fluid.io.load_inference_model to obtain the inference program desc,
        # the feed_target_names (the names of variables that will be feeded
        # data using feed operators), and the fetch_targets (variables that
        # we want to obtain data from using fetch operators).
        [inference_program, feed_target_names,
         fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)

        word_dict = paddle.dataset.imikolov.build_dict()
        dict_size = len(word_dict)

        # Setup inputs, by creating 4 words, the lod of which should be [0, 1]
        lod = [0, 1]
        first_word = create_random_lodtensor(lod, place, low=0, high=dict_size - 1)
        second_word = create_random_lodtensor(lod, place, low=0, high=dict_size - 1)
        third_word = create_random_lodtensor(lod, place, low=0, high=dict_size - 1)
        fourth_word = create_random_lodtensor(lod, place, low=0, high=dict_size - 1)

        assert feed_target_names[0] == 'firstw'
        assert feed_target_names[1] == 'secondw'
        assert feed_target_names[2] == 'thirdw'
        assert feed_target_names[3] == 'forthw'

        # Construct feed as a dictionary of {feed_target_name: feed_target_data}
        # and results will contain a list of data corresponding to fetch_targets.
        results = exe.run(inference_program,
                          feed={
                              feed_target_names[0]: first_word,
                              feed_target_names[1]: second_word,
                              feed_target_names[2]: third_word,
                              feed_target_names[3]: fourth_word
                          },
                          fetch_list=fetch_targets,
                          return_numpy=False)
        print(results[0].lod())
        np_data = np.array(results[0])
        print("Inference Shape: ", np_data.shape)


def main(use_cuda, is_sparse, is_parallel):
    if use_cuda and not fluid.core.is_compiled_with_cuda():
        return

    if not is_parallel:
        save_dirname = "word2vec.inference.model"
    else:
        save_dirname = None

    train(use_cuda, is_sparse, is_parallel, save_dirname)
    infer(use_cuda, save_dirname)
...
...
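The infer path above leans on a `create_random_lodtensor` helper defined earlier in the file. A minimal sketch of such a helper, assuming the standard `fluid.core.LoDTensor` setters of this era, might be:

    import numpy as np
    import paddle.fluid as fluid

    def create_random_lodtensor(lod, place, low, high):
        # Draw lod[-1] random word ids in [low, high] (inclusive) and wrap
        # them in a LoDTensor carrying the given single-level LoD.
        data = np.random.randint(low, high + 1, size=[lod[-1], 1]).astype("int64")
        res = fluid.core.LoDTensor()
        res.set(data, place)
        res.set_lod([lod])
        return res

Note that the new calls pass `high=dict_size - 1` instead of `dict_size`, which keeps the sampled ids inside the embedding table's valid range.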
@@ -193,10 +204,10 @@ class W2VTest(unittest.TestCase):
    pass

-def inject_test_method(use_cuda, is_sparse, parallel):
+def inject_test_method(use_cuda, is_sparse, is_parallel):
    fn_name = "test_{0}_{1}_{2}".format("cuda" if use_cuda else "cpu",
                                        "sparse" if is_sparse else "dense",
-                                       "parallel" if parallel else "normal")
+                                       "parallel" if is_parallel else "normal")

    def __impl__(*args, **kwargs):
        prog = fluid.Program()
...
@@ -204,10 +215,12 @@ def inject_test_method(use_cuda, is_sparse, parallel):
        scope = fluid.core.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(prog, startup_prog):
-                main(use_cuda=use_cuda, is_sparse=is_sparse, parallel=parallel)
+                main(use_cuda=use_cuda, is_sparse=is_sparse,
+                     is_parallel=is_parallel)

    # run only 2 cases: use_cuda is either True or False
-    if is_sparse == False and parallel == False:
+    if use_cuda and is_sparse:
        fn = __impl__
    else:
        # skip the other test when on CI server
...
@@ -219,8 +232,8 @@ def inject_test_method(use_cuda, is_sparse, parallel):
for use_cuda in (False, True):
    for is_sparse in (False, True):
-        for parallel in (False, True):
-            inject_test_method(use_cuda, is_sparse, parallel)
+        for is_parallel in (False, True):
+            inject_test_method(use_cuda, is_sparse, is_parallel)

if __name__ == '__main__':
    unittest.main()
python/paddle/fluid/tests/book_distribute/notest_dist_fit_a_line.py  View file @ c0876cf6
...
@@ -48,6 +48,7 @@ current_endpoint = os.getenv("SERVER_ENDPOINT")
# run as trainer or parameter server
training_role = os.getenv("TRAINING_ROLE", "TRAINER")  # get the training role: trainer/pserver
+t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
if training_role == "PSERVER":
...
@@ -65,8 +66,6 @@ else:
PASS_NUM = 100
for pass_id in range(PASS_NUM):
-    fluid.io.save_persistables(exe, "./fit_a_line.model/")
-    fluid.io.load_persistables(exe, "./fit_a_line.model/")
    for data in train_reader():
        avg_loss_value = exe.run(trainer_prog,
                                 feed=feeder.feed(data),
...
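All of the book_distribute scripts touched in this commit follow the same transpile-then-dispatch shape. A minimal skeleton of that pattern, folded into a hypothetical `dispatch()` helper with the scripts' usual names passed in as parameters (and `get_startup_program` assumed from the transpiler API of the time), might look like:

    import paddle.fluid as fluid

    def dispatch(training_role, current_endpoint, pserver_endpoints,
                 optimize_ops, params_grads, exe, feeder, train_reader):
        # Rewrite the single-machine program for distributed execution,
        # then branch on the role this process was launched with.
        t = fluid.DistributeTranspiler()
        t.transpile(optimize_ops, params_grads,
                    pservers=pserver_endpoints, trainers=2)
        if training_role == "PSERVER":
            # Parameter server: build and run its program (blocks, serving
            # parameter updates to the trainers).
            pserver_prog = t.get_pserver_program(current_endpoint)
            pserver_startup = t.get_startup_program(current_endpoint,
                                                    pserver_prog)
            exe.run(pserver_startup)
            exe.run(pserver_prog)
        elif training_role == "TRAINER":
            # Trainer: run the rewritten trainer program over the data.
            trainer_prog = t.get_trainer_program()
            exe.run(fluid.default_startup_program())
            for data in train_reader():
                exe.run(trainer_prog, feed=feeder.feed(data))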
python/paddle/fluid/tests/book_distribute/notest_dist_image_classification.py  View file @ c0876cf6
...
@@ -138,6 +138,7 @@ current_endpoint = os.getenv("SERVER_ENDPOINT")
# run as trainer or parameter server
training_role = os.getenv("TRAINING_ROLE", "TRAINER")  # get the training role: trainer/pserver
+t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=TRAINERS)
...
python/paddle/fluid/tests/book_distribute/notest_dist_label_semantic_roles.py  View file @ c0876cf6
...
@@ -191,6 +191,7 @@ def main():
# run as trainer or parameter server
training_role = os.getenv("TRAINING_ROLE", "TRAINER")  # get the training role: trainer/pserver
+t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
...
python/paddle/fluid/tests/book_distribute/notest_dist_word2vec.py  View file @ c0876cf6
...
@@ -82,6 +82,7 @@ current_endpoint = os.getenv("SERVER_ENDPOINT")
# run as trainer or parameter server
training_role = os.getenv("TRAINING_ROLE", "TRAINER")  # get the training role: trainer/pserver
+t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=TRAINERS)
if training_role == "PSERVER":
...
@@ -97,9 +98,10 @@ elif training_role == "TRAINER":
        feed_list=[first_word, second_word, third_word, forth_word, next_word],
        place=place)
    exe.run(fluid.default_startup_program())
+    trainer_prog = t.get_trainer_program()
    for pass_id in range(PASS_NUM):
        for data in train_reader():
-            avg_cost_np = exe.run(t.get_trainer_program(),
+            avg_cost_np = exe.run(trainer_prog,
                                  feed=feeder.feed(data),
                                  fetch_list=[avg_cost])
            print("avg_cost_np", avg_cost_np)
...
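The second hunk here is worth a note: `t.get_trainer_program()` used to be invoked inside the per-batch loop, so every `exe.run` call re-fetched the trainer program; hoisting it into `trainer_prog` right after startup keeps the whole pass running on one cached program object.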
python/paddle/fluid/tests/book_distribute/notest_machine_translation.py  View file @ c0876cf6
...
@@ -115,6 +115,7 @@ def main():
# run as trainer or parameter server
training_role = os.getenv("TRAINING_ROLE", "TRAINER")  # get the training role: trainer/pserver
+t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
...
python/paddle/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py  View file @ c0876cf6
...
@@ -64,11 +64,7 @@ if not current_endpoint:
t = fluid.DistributeTranspiler()
-t.transpile(optimize_ops,
-            params_grads,
-            0,
-            pservers=pserver_endpoints,
-            trainers=trainers)
+t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=trainers)
if training_role == "PSERVER":
    pserver_prog = t.get_pserver_program(current_endpoint)
...
python/paddle/fluid/tests/book_distribute/notest_recommender_system_dist.py  View file @ c0876cf6
...
@@ -171,6 +171,7 @@ def main():
current_endpoint = os.getenv("SERVER_ENDPOINT")
# run as trainer or parameter server
training_role = os.getenv("TRAINING_ROLE", "TRAINER")
+t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
...
python/paddle/fluid/tests/book_distribute/notest_understand_sentiment_conv_dist.py  View file @ c0876cf6
...
@@ -90,6 +90,7 @@ def main():
# run as trainer or parameter server
training_role = os.getenv("TRAINING_ROLE", "TRAINER")  # get the training role: trainer/pserver
+t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
...
python/paddle/fluid/tests/book_distribute/notest_understand_sentiment_dynamic_lstm.py  View file @ c0876cf6
...
@@ -102,6 +102,7 @@ def main():
# run as trainer or parameter server
training_role = os.getenv("TRAINING_ROLE", "TRAINER")  # get the training role: trainer/pserver
+t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
...
python/paddle/fluid/tests/unittests/CMakeLists.txt  View file @ c0876cf6
...
@@ -41,6 +41,7 @@ list(REMOVE_ITEM TEST_OPS test_while_op)
list(REMOVE_ITEM TEST_OPS test_lod_array_length_op)
list(REMOVE_ITEM TEST_OPS test_reorder_lod_tensor)
list(REMOVE_ITEM TEST_OPS test_profiler)
+list(REMOVE_ITEM TEST_OPS test_nvprof)
list(REMOVE_ITEM TEST_OPS test_normalization_wrapper)
list(REMOVE_ITEM TEST_OPS test_executor_and_mul)
list(REMOVE_ITEM TEST_OPS test_assign_value_op)
...
@@ -75,6 +76,7 @@ py_test_modules(test_while_op MODULES test_while_op)
py_test_modules(test_lod_array_length_op MODULES test_lod_array_length_op)
py_test_modules(test_reorder_lod_tensor MODULES test_reorder_lod_tensor)
py_test_modules(test_profiler MODULES test_profiler)
+py_test_modules(test_nvprof MODULES test_nvprof)
py_test_modules(test_normalization_wrapper MODULES test_normalization_wrapper)
py_test_modules(test_executor_and_mul MODULES test_executor_and_mul)
py_test_modules(test_assign_value_op MODULES test_assign_value_op)
...
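The CMake change mirrors how test_profiler is already handled: the new test_nvprof is first removed from the bulk TEST_OPS list and then re-added through py_test_modules, giving it its own explicit ctest registration instead of the default bulk one.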
python/paddle/fluid/tests/unittests/test_bipartite_match_op.py  View file @ c0876cf6
...
@@ -46,7 +46,20 @@ def bipartite_match(distance, match_indices, match_dist):
            idx += 1

-def batch_bipartite_match(distance, lod):
+def argmax_match(distance, match_indices, match_dist, threshold):
+    r, c = distance.shape
+    for j in xrange(c):
+        if match_indices[j] != -1:
+            continue
+        col_dist = distance[:, j]
+        indices = np.argwhere(col_dist >= threshold).flatten()
+        if len(indices) < 1:
+            continue
+        match_indices[j] = indices[np.argmax(col_dist[indices])]
+        match_dist[j] = col_dist[match_indices[j]]
+
+
+def batch_bipartite_match(distance, lod, match_type=None, dist_threshold=None):
    """Bipartite Matching algorithm for batch input.

    Args:
        distance (numpy.array): The distance of two entries with shape [M, N].
...
@@ -59,6 +72,9 @@ def batch_bipartite_match(distance, lod):
    for i in range(len(lod) - 1):
        bipartite_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :],
                        match_dist[i, :])
+        if match_type == 'per_prediction':
+            argmax_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :],
+                         match_dist[i, :], dist_threshold)
    return match_indices, match_dist
...
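A tiny, hand-checkable run of the reference matcher above (assuming `bipartite_match` and `batch_bipartite_match` from this test file are in scope):

    import numpy as np

    distance = np.array([[0.9, 0.1],
                         [0.2, 0.8],
                         [0.7, 0.6]]).astype('float32')
    lod = [0, 3]  # a single sequence covering all three rows
    match_indices, match_dist = batch_bipartite_match(distance, lod)
    # Greedy matching takes the globally largest entries first:
    # column 0 -> row 0 (0.9), then column 1 -> row 1 (0.8); row 2 stays
    # unmatched because both columns are already taken.
    print(match_indices)  # expected: [[0 1]]
    print(match_dist)     # expected: approximately [[0.9 0.8]]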
@@ -71,8 +87,8 @@ class TestBipartiteMatchOpWithLoD(OpTest):
        self.inputs = {'DistMat': (dist, lod)}
        self.outputs = {
-            'ColToRowMatchIndices': (match_indices),
-            'ColToRowMatchDist': (match_dist),
+            'ColToRowMatchIndices': match_indices,
+            'ColToRowMatchDist': match_dist,
        }

    def test_check_output(self):
...
@@ -96,5 +112,27 @@ class TestBipartiteMatchOpWithoutLoD(OpTest):
        self.check_output()

+class TestBipartiteMatchOpWithPerPredictionType(OpTest):
+    def setUp(self):
+        self.op_type = 'bipartite_match'
+        lod = [[0, 5, 11, 23]]
+        dist = np.random.random((23, 237)).astype('float32')
+        match_indices, match_dist = batch_bipartite_match(
+            dist, lod[0], 'per_prediction', 0.5)
+        self.inputs = {'DistMat': (dist, lod)}
+        self.outputs = {
+            'ColToRowMatchIndices': match_indices,
+            'ColToRowMatchDist': match_dist,
+        }
+        self.attrs = {
+            'match_type': 'per_prediction',
+            'dist_threshold': 0.5,
+        }
+
+    def test_check_output(self):
+        self.check_output()

if __name__ == '__main__':
    unittest.main()
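To see what the new `per_prediction` path contributes on its own, here is a hand-checked call to `argmax_match` with values chosen for illustration (the helper is Python 2 code, using `xrange`):

    import numpy as np

    distance = np.array([[0.9, 0.4, 0.2],
                         [0.3, 0.8, 0.45]]).astype('float32')
    match_indices = np.array([0, 1, -1])   # column 2 left unmatched earlier
    match_dist = np.array([0.9, 0.8, 0.0]).astype('float32')
    argmax_match(distance, match_indices, match_dist, 0.4)
    # Column 2's best row at or above the 0.4 threshold is row 1 (0.45),
    # so the fallback fills it in; matched columns 0 and 1 are untouched.
    print(match_indices)  # expected: [0 1 1]
    print(match_dist)     # expected: approximately [0.9 0.8 0.45]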
python/paddle/fluid/tests/unittests/test_nvprof.py  0 → 100644  View file @ c0876cf6
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import os
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.layers as layers
import paddle.fluid.core as core


class TestNVProf(unittest.TestCase):
    def test_nvprof(self):
        if not fluid.core.is_compiled_with_cuda():
            return
        epoc = 8
        dshape = [4, 3, 28, 28]
        data = layers.data(name='data', shape=[3, 28, 28], dtype='float32')
        conv = layers.conv2d(data, 20, 3, stride=[1, 1], padding=[1, 1])

        place = fluid.CUDAPlace(0)
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        output_file = 'cuda_profiler.txt'
        with profiler.cuda_profiler(output_file, 'csv') as nvprof:
            for i in range(epoc):
                input = np.random.random(dshape).astype('float32')
                exe.run(fluid.default_main_program(), feed={'data': input})
        os.remove(output_file)


if __name__ == '__main__':
    unittest.main()
python/paddle/fluid/tests/unittests/test_profiler.py  View file @ c0876cf6
...
@@ -22,27 +22,9 @@ import paddle.fluid.core as core

class TestProfiler(unittest.TestCase):
-    def test_nvprof(self):
-        if not fluid.core.is_compiled_with_cuda():
-            return
-        epoc = 8
-        dshape = [4, 3, 28, 28]
-        data = layers.data(name='data', shape=[3, 28, 28], dtype='float32')
-        conv = layers.conv2d(data, 20, 3, stride=[1, 1], padding=[1, 1])
-
-        place = fluid.CUDAPlace(0)
-        exe = fluid.Executor(place)
-        exe.run(fluid.default_startup_program())
-
-        output_file = 'cuda_profiler.txt'
-        with profiler.cuda_profiler(output_file, 'csv') as nvprof:
-            for i in range(epoc):
-                input = np.random.random(dshape).astype('float32')
-                exe.run(fluid.default_main_program(), feed={'data': input})
-        os.remove(output_file)
-
    def net_profiler(self, state):
-        if state == 'GPU' and not core.is_compiled_with_cuda():
+        enable_if_gpu = state == 'GPU' or state == "All"
+        if enable_if_gpu and not core.is_compiled_with_cuda():
            return
        startup_program = fluid.Program()
        main_program = fluid.Program()
...
@@ -85,6 +67,9 @@ class TestProfiler(unittest.TestCase):
    def test_cuda_profiler(self):
        self.net_profiler('GPU')

+    def test_all_profiler(self):
+        self.net_profiler('All')

if __name__ == '__main__':
    unittest.main()
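For reference, driving the profiler outside a unit test looks roughly like the following; the 'All' state and the 'total' sort key are assumptions read off the new test case rather than a quoted API listing:

    import numpy as np
    import paddle.fluid as fluid
    import paddle.fluid.profiler as profiler

    # Build a small net, matching the shape used in the tests above.
    data = fluid.layers.data(name='data', shape=[3, 28, 28], dtype='float32')
    conv = fluid.layers.conv2d(data, 20, 3, stride=[1, 1], padding=[1, 1])

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # 'All' is assumed to collect both CPU and GPU timelines, mirroring the
    # new test_all_profiler case; 'total' sorts the report by total time.
    with profiler.profiler('All', 'total'):
        for i in range(8):
            input = np.random.random([4, 3, 28, 28]).astype('float32')
            exe.run(fluid.default_main_program(), feed={'data': input})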