Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
07642119
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
07642119
编写于
5月 17, 2018
作者:
Q
qiaolongfei
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into add-mkldnn-to-paddle-lib
上级
de3c5175
dbbeccc4
变更
18
显示空白变更内容
内联
并排
Showing
18 changed file
with
266 addition
and
74 deletion
+266
-74
benchmark/fluid/mnist.py
benchmark/fluid/mnist.py
+10
-6
benchmark/fluid/resnet.py
benchmark/fluid/resnet.py
+8
-4
benchmark/fluid/vgg.py
benchmark/fluid/vgg.py
+8
-4
cmake/inference_lib.cmake
cmake/inference_lib.cmake
+6
-0
doc/fluid/design/concepts/functions_operators_layers.md
doc/fluid/design/concepts/functions_operators_layers.md
+1
-1
paddle/fluid/framework/details/op_handle_base.h
paddle/fluid/framework/details/op_handle_base.h
+8
-0
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+1
-1
paddle/fluid/inference/tensorrt/convert/op_converter.h
paddle/fluid/inference/tensorrt/convert/op_converter.h
+1
-1
paddle/fluid/operators/smooth_l1_loss_op.cc
paddle/fluid/operators/smooth_l1_loss_op.cc
+23
-2
paddle/scripts/paddle_build.sh
paddle/scripts/paddle_build.sh
+33
-6
paddle/scripts/paddle_docker_build.sh
paddle/scripts/paddle_docker_build.sh
+1
-0
python/paddle/fluid/data_feeder.py
python/paddle/fluid/data_feeder.py
+2
-2
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+21
-17
python/paddle/fluid/tests/book/test_label_semantic_roles.py
python/paddle/fluid/tests/book/test_label_semantic_roles.py
+6
-21
python/paddle/fluid/tests/test_data_feeder.py
python/paddle/fluid/tests/test_data_feeder.py
+54
-7
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+2
-2
python/paddle/fluid/trainer.py
python/paddle/fluid/trainer.py
+33
-0
tools/test_runner.py
tools/test_runner.py
+48
-0
未找到文件。
benchmark/fluid/mnist.py
浏览文件 @
07642119
...
...
@@ -159,6 +159,7 @@ def run_benchmark(model, args):
paddle
.
dataset
.
mnist
.
train
(),
batch_size
=
args
.
batch_size
)
accuracy
=
fluid
.
metrics
.
Accuracy
()
train_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
True
,
loss_name
=
avg_cost
.
name
)
iters
,
num_samples
,
start_time
=
0
,
0
,
time
.
time
()
for
pass_id
in
range
(
args
.
pass_num
):
accuracy
.
reset
()
...
...
@@ -175,17 +176,20 @@ def run_benchmark(model, args):
y_data
=
np
.
array
(
map
(
lambda
x
:
x
[
1
],
data
)).
astype
(
"int64"
)
y_data
=
y_data
.
reshape
([
len
(
y_data
),
1
])
outs
=
exe
.
run
(
fluid
.
default_main_program
(),
outs
=
train_exe
.
run
(
feed
=
{
"pixel"
:
img_data
,
"label"
:
y_data
},
fetch_list
=
[
avg_cost
,
batch_acc
,
batch_size_tensor
]
fetch_list
=
[
avg_cost
.
name
,
batch_acc
.
name
,
batch_size_tensor
.
name
]
)
# The accuracy is the accumulation of batches, but not the current batch.
accuracy
.
update
(
value
=
outs
[
1
],
weight
=
outs
[
2
])
accuracy
.
update
(
value
=
np
.
array
(
np
.
mean
(
outs
[
1
])),
weight
=
np
.
mean
(
np
.
array
(
outs
[
2
])))
iters
+=
1
num_samples
+=
len
(
y_data
)
loss
=
np
.
array
(
outs
[
0
]
)
acc
=
np
.
array
(
outs
[
1
]
)
loss
=
np
.
mean
(
np
.
array
(
outs
[
0
])
)
acc
=
np
.
mean
(
np
.
array
(
outs
[
1
])
)
train_losses
.
append
(
loss
)
train_accs
.
append
(
acc
)
print
(
"Pass: %d, Iter: %d, Loss: %f, Accuracy: %f"
%
...
...
benchmark/fluid/resnet.py
浏览文件 @
07642119
...
...
@@ -241,6 +241,7 @@ def run_benchmark(model, args):
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
accuracy
=
fluid
.
average
.
WeightedAverage
()
train_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
True
,
loss_name
=
avg_cost
.
name
)
if
args
.
use_fake_data
:
data
=
train_reader
().
next
()
image
=
np
.
array
(
map
(
lambda
x
:
x
[
0
].
reshape
(
dshape
),
data
)).
astype
(
...
...
@@ -264,14 +265,17 @@ def run_benchmark(model, args):
data
)).
astype
(
'float32'
)
label
=
np
.
array
(
map
(
lambda
x
:
x
[
1
],
data
)).
astype
(
'int64'
)
label
=
label
.
reshape
([
-
1
,
1
])
loss
,
acc
,
weight
=
exe
.
run
(
fluid
.
default_main_program
(),
loss
,
acc
,
weight
=
train_exe
.
run
(
feed
=
{
'data'
:
image
,
'label'
:
label
},
fetch_list
=
[
avg_cost
,
batch_acc
,
batch_size_tensor
])
fetch_list
=
[
avg_cost
.
name
,
batch_acc
.
name
,
batch_size_tensor
.
name
])
iters
+=
1
num_samples
+=
len
(
label
)
accuracy
.
add
(
value
=
acc
,
weight
=
weight
)
accuracy
.
add
(
value
=
np
.
array
(
np
.
mean
(
acc
)),
weight
=
np
.
mean
(
weight
))
loss
=
np
.
mean
(
np
.
array
(
loss
))
acc
=
np
.
mean
(
np
.
array
(
acc
))
train_losses
.
append
(
loss
)
train_accs
.
append
(
acc
)
print
(
"Pass: %d, Iter: %d, Loss: %f, Accuracy: %f"
%
...
...
benchmark/fluid/vgg.py
浏览文件 @
07642119
...
...
@@ -169,6 +169,7 @@ def main():
iters
,
num_samples
,
start_time
=
0
,
0
,
time
.
time
()
accuracy
=
fluid
.
average
.
WeightedAverage
()
train_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
True
,
loss_name
=
avg_cost
.
name
)
for
pass_id
in
range
(
args
.
pass_num
):
accuracy
.
reset
()
train_accs
=
[]
...
...
@@ -184,14 +185,17 @@ def main():
y_data
=
np
.
array
(
map
(
lambda
x
:
x
[
1
],
data
)).
astype
(
"int64"
)
y_data
=
y_data
.
reshape
([
-
1
,
1
])
loss
,
acc
,
weight
=
exe
.
run
(
fluid
.
default_main_program
(),
loss
,
acc
,
weight
=
train_exe
.
run
(
feed
=
{
"pixel"
:
img_data
,
"label"
:
y_data
},
fetch_list
=
[
avg_cost
,
batch_acc
,
batch_size_tensor
])
accuracy
.
add
(
value
=
acc
,
weight
=
weight
)
fetch_list
=
[
avg_cost
.
name
,
batch_acc
.
name
,
batch_size_tensor
.
name
])
accuracy
.
add
(
value
=
np
.
array
(
np
.
mean
(
acc
)),
weight
=
np
.
mean
(
weight
))
iters
+=
1
num_samples
+=
len
(
y_data
)
loss
=
np
.
mean
(
np
.
array
(
loss
))
acc
=
np
.
mean
(
np
.
array
(
acc
))
print
(
"Pass = %d, Iter = %d, Loss = %f, Accuracy = %f"
%
(
pass_id
,
iters
,
loss
,
acc
)
...
...
cmake/inference_lib.cmake
浏览文件 @
07642119
...
...
@@ -156,4 +156,10 @@ copy(string_lib
DSTS
${
dst_dir
}
/
${
module
}
${
dst_dir
}
/
${
module
}
/tinyformat
)
set
(
module
"pybind"
)
copy
(
pybind_lib
SRCS
${
CMAKE_CURRENT_BINARY_DIR
}
/paddle/fluid/
${
module
}
/pybind.h
DSTS
${
dst_dir
}
/
${
module
}
)
add_custom_target
(
inference_lib_dist DEPENDS
${
inference_lib_dist_dep
}
)
doc/fluid/design/concepts/functions_operators_layers.md
浏览文件 @
07642119
...
...
@@ -40,7 +40,7 @@ template <typename T>
class
FCOp
:
public
OperatorBase
{
public:
void
Run
(...)
{
add
(
mul
(
Input
<
T
>
(
"X"
),
Input
<
T
>
(
"W"
)),
Input
<
T
>
(
"b"
);
add
(
mul
(
Input
<
T
>
(
"X"
),
Input
<
T
>
(
"W"
)),
Input
<
T
>
(
"b"
)
)
;
}
};
REGISTER_OP
(
FCOp
,
"fc"
);
...
...
paddle/fluid/framework/details/op_handle_base.h
浏览文件 @
07642119
...
...
@@ -70,6 +70,14 @@ class OpHandleBase {
const
std
::
vector
<
VarHandleBase
*>
&
Inputs
()
const
{
return
inputs_
;
}
size_t
NoDupInputSize
()
const
{
std
::
unordered_set
<
VarHandleBase
*>
res
;
for
(
auto
*
var
:
inputs_
)
{
res
.
emplace
(
var
);
}
return
res
.
size
();
}
const
std
::
vector
<
VarHandleBase
*>
&
Outputs
()
const
{
return
outputs_
;
}
protected:
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
07642119
...
...
@@ -174,7 +174,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
void
ThreadedSSAGraphExecutor
::
InsertPendingOp
(
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
*
pending_ops
,
OpHandleBase
*
op_instance
)
const
{
pending_ops
->
insert
({
op_instance
,
op_instance
->
Inputs
().
s
ize
()});
pending_ops
->
insert
({
op_instance
,
op_instance
->
NoDupInputS
ize
()});
}
void
ThreadedSSAGraphExecutor
::
InsertPendingVar
(
...
...
paddle/fluid/inference/tensorrt/convert/op_converter.h
浏览文件 @
07642119
...
...
@@ -49,7 +49,7 @@ class OpConverter {
// convert fluid block to tensorrt network
void
ConvertBlock
(
const
framework
::
proto
::
BlockDesc
&
block
,
TensorRTEngine
*
engine
)
{
for
(
size_
t
i
=
0
;
i
<
block
.
ops_size
();
i
++
)
{
for
(
in
t
i
=
0
;
i
<
block
.
ops_size
();
i
++
)
{
const
auto
&
op
=
block
.
ops
(
i
);
OpConverter
::
Run
(
op
,
engine
);
}
...
...
paddle/fluid/operators/smooth_l1_loss_op.cc
浏览文件 @
07642119
...
...
@@ -105,7 +105,7 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
in_dims
=
ctx
->
GetInputDim
(
"
X
"
);
auto
in_dims
=
ctx
->
GetInputDim
(
"
Diff
"
);
auto
out_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
));
PADDLE_ENFORCE_GE
(
out_dims
.
size
(),
2
,
...
...
@@ -127,12 +127,33 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
}
};
class
SmoothL1LossGradMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
op
=
new
framework
::
OpDesc
();
op
->
SetType
(
"smooth_l1_loss_grad"
);
op
->
SetInput
(
"InsideWeight"
,
Input
(
"InsideWeight"
));
op
->
SetInput
(
"OutsideWeight"
,
Input
(
"OutsideWeight"
));
op
->
SetInput
(
"Diff"
,
Output
(
"Diff"
));
op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
op
->
SetAttrMap
(
Attrs
());
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Y"
),
InputGrad
(
"Y"
));
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
op
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
smooth_l1_loss
,
ops
::
SmoothL1LossOp
,
ops
::
SmoothL1LossOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
ops
::
SmoothL1LossGradMaker
);
REGISTER_OPERATOR
(
smooth_l1_loss_grad
,
ops
::
SmoothL1LossGradOp
);
REGISTER_OP_CPU_KERNEL
(
smooth_l1_loss
,
...
...
paddle/scripts/paddle_build.sh
浏览文件 @
07642119
...
...
@@ -20,19 +20,15 @@
#=================================================
function
print_usage
()
{
RED
=
'\033[0;31m'
BLUE
=
'\033[0;34m'
BOLD
=
'\033[1m'
NONE
=
'\033[0m'
echo
-e
"
\n
${
RED
}
Usage
${
NONE
}
:
${
BOLD
}
$
0
${
NONE
}
[OPTION]"
${
BOLD
}$
{
SCRIPT_NAME
}
${
NONE
}
[OPTION]"
echo
-e
"
\n
${
RED
}
Options
${
NONE
}
:
${
BLUE
}
build
${
NONE
}
: run build for x86 platform
${
BLUE
}
build_android
${
NONE
}
: run build for android platform
${
BLUE
}
build_ios
${
NONE
}
: run build for ios platform
${
BLUE
}
test
${
NONE
}
: run all unit tests
${
BLUE
}
single_test
${
NONE
}
: run a single unit test
${
BLUE
}
bind_test
${
NONE
}
: parallel tests bind to different GPU
${
BLUE
}
doc
${
NONE
}
: generate paddle documents
${
BLUE
}
html
${
NONE
}
: convert C++ source code into HTML
...
...
@@ -45,7 +41,15 @@ function print_usage() {
}
function
init
()
{
RED
=
'\033[0;31m'
BLUE
=
'\033[0;34m'
BOLD
=
'\033[1m'
NONE
=
'\033[0m'
PADDLE_ROOT
=
"
$(
cd
"
$(
dirname
"
${
BASH_SOURCE
[0]
}
"
)
/../../"
&&
pwd
)
"
if
[
-z
"
${
SCRIPT_NAME
}
"
]
;
then
SCRIPT_NAME
=
$0
fi
}
function
cmake_gen
()
{
...
...
@@ -309,6 +313,25 @@ EOF
fi
}
function
single_test
()
{
TEST_NAME
=
$1
if
[
-z
"
${
TEST_NAME
}
"
]
;
then
echo
-e
"
${
RED
}
Usage:
${
NONE
}
"
echo
-e
"
${
BOLD
}${
SCRIPT_NAME
}${
NONE
}
${
BLUE
}
single_test
${
NONE
}
[test_name]"
exit
1
fi
mkdir
-p
${
PADDLE_ROOT
}
/build
cd
${
PADDLE_ROOT
}
/build
if
[
${
WITH_TESTING
:-
ON
}
==
"ON"
]
;
then
cat
<<
EOF
========================================
Running
${
TEST_NAME
}
...
========================================
EOF
ctest
--output-on-failure
-R
${
TEST_NAME
}
fi
}
function
bind_test
()
{
# the number of process to run tests
NUM_PROC
=
6
...
...
@@ -480,6 +503,7 @@ function main() {
build
)
cmake_gen
${
PYTHON_ABI
:-
""
}
build
gen_dockerfile
;;
build_android
)
build_android
...
...
@@ -490,6 +514,9 @@ function main() {
test
)
run_test
;;
single_test
)
single_test
$2
;;
bind_test
)
bind_test
;;
...
...
paddle/scripts/paddle_docker_build.sh
浏览文件 @
07642119
...
...
@@ -63,6 +63,7 @@ EOL
${
DOCKER_CMD
}
run
-it
\
--name
$CONTAINER_ID
\
${
DOCKER_ENV
}
\
-e
SCRIPT_NAME
=
$0
\
-v
$PADDLE_ROOT
:/paddle
\
-v
${
HOME
}
/.ccache:/root/.ccache
\
-w
/paddle
\
...
...
python/paddle/fluid/data_feeder.py
浏览文件 @
07642119
...
...
@@ -54,9 +54,9 @@ class DataToLoDTensorConverter(object):
self
.
data
.
append
(
data
)
else
:
cur_lod_len
=
len
(
data
)
lod
[
-
1
].
append
(
lod
[
-
1
][
-
1
]
+
cur_lod_len
)
lod
[
0
].
append
(
lod
[
0
][
-
1
]
+
cur_lod_len
)
for
each_data
in
data
:
self
.
_feed_impl_
(
each_data
,
lod
[
:
-
1
],
lod_level
-
1
)
self
.
_feed_impl_
(
each_data
,
lod
[
1
:
],
lod_level
-
1
)
def
done
(
self
):
arr
=
numpy
.
array
(
self
.
data
,
dtype
=
self
.
dtype
).
reshape
(
self
.
shape
)
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
07642119
...
...
@@ -1329,6 +1329,8 @@ def sequence_pool(input, pool_type):
sqrt : out.data = [2.82, 6.93, 4.24], where 2.82=(1+3)/sqrt(2),
6.93=(2+4+6)/sqrt(3), 4.24=(5+1)/sqrt(2)
max : out.data = [3, 6, 5], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1)
last : out.data = [3, 6, 1], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1)
first : out.data = [1, 2, 5], where 1=first(1,3), 2=first(2,4,6), 5=first(5,1)
Args:
input(variable): The input variable which is a LoDTensor.
...
...
@@ -1348,6 +1350,8 @@ def sequence_pool(input, pool_type):
sum_x = fluid.layers.sequence_pool(input=x, pool_type='sum')
sqrt_x = fluid.layers.sequence_pool(input=x, pool_type='sqrt')
max_x = fluid.layers.sequence_pool(input=x, pool_type='max')
last_x = fluid.layers.sequence_pool(input=x, pool_type='last')
first_x = fluid.layers.sequence_pool(input=x, pool_type='first')
"""
helper
=
LayerHelper
(
'sequence_pool'
,
**
locals
())
dtype
=
helper
.
input_dtype
()
...
...
@@ -3263,35 +3267,35 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None):
"""
**Smooth L1 Loss Operator. **
This operator computes the smooth
l
1 loss for X and Y.
This operator computes the smooth
L
1 loss for X and Y.
The operator takes the first dimension of X and Y as batch size.
For each instance, it computes the smooth
l
1 loss element by element first
For each instance, it computes the smooth
L
1 loss element by element first
and then sums all the losses. So the shape of Out is [batch_size, 1].
Args:
x (Variable): A tensor with rank at least 2. The input value of smooth
l
1 loss op with shape [batch_size, dim1, ..., dimN].
L
1 loss op with shape [batch_size, dim1, ..., dimN].
y (Variable): A tensor with rank at least 2. The target value of smooth
l
1 loss op with same shape as x.
L
1 loss op with same shape as x.
inside_weight (Variable|None): A tensor with rank at least 2. This
input is optional and should have same shape with x. If provided,
the result of (x - y) will be multiplied by this tensor element by
element.
outside_weight (Variable|None): A tensor with rank at least 2. This
input is optional and should have same shape with x. If provided,
the out smooth
l
1 loss will be multiplied by this tensor element
the out smooth
L
1 loss will be multiplied by this tensor element
by element.
sigma (float|None): Hyper parameter of smooth
l
1 loss op. A float scalar
sigma (float|None): Hyper parameter of smooth
L
1 loss op. A float scalar
with default value 1.0.
Returns:
Variable: A tensor with rank be 2. The output smooth
l
1 loss with
Variable: A tensor with rank be 2. The output smooth
L
1 loss with
shape [batch_size, 1].
Examples:
.. code-block:: python
data = fluid.layers.data(name='data', shape=[128], dtype='float32')
label = fluid.layers.data(name='label', shape=[100], dtype='
int64
')
label = fluid.layers.data(name='label', shape=[100], dtype='
float32
')
fc = fluid.layers.fc(input=data, size=100)
out = fluid.layers.smooth_l1(x=fc, y=label)
"""
...
...
python/paddle/fluid/tests/book/test_label_semantic_roles.py
浏览文件 @
07642119
...
...
@@ -182,12 +182,6 @@ def train(use_cuda, save_dirname=None, is_local=True):
crf_decode
=
fluid
.
layers
.
crf_decoding
(
input
=
feature_out
,
param_attr
=
fluid
.
ParamAttr
(
name
=
'crfw'
))
chunk_evaluator
=
fluid
.
evaluator
.
ChunkEvaluator
(
input
=
crf_decode
,
label
=
target
,
chunk_scheme
=
"IOB"
,
num_chunk_types
=
int
(
math
.
ceil
((
label_dict_len
-
1
)
/
2.0
)))
train_data
=
paddle
.
batch
(
paddle
.
reader
.
shuffle
(
paddle
.
dataset
.
conll05
.
test
(),
buf_size
=
8192
),
...
...
@@ -203,7 +197,6 @@ def train(use_cuda, save_dirname=None, is_local=True):
def
train_loop
(
main_program
):
exe
.
run
(
fluid
.
default_startup_program
())
embedding_param
=
fluid
.
global_scope
().
find_var
(
embedding_name
).
get_tensor
()
embedding_param
.
set
(
...
...
@@ -213,27 +206,19 @@ def train(use_cuda, save_dirname=None, is_local=True):
start_time
=
time
.
time
()
batch_id
=
0
for
pass_id
in
xrange
(
PASS_NUM
):
chunk_evaluator
.
reset
(
exe
)
for
data
in
train_data
():
cost
,
precision
,
recall
,
f1_score
=
exe
.
run
(
main_program
,
cost
=
exe
.
run
(
main_program
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
[
avg_cost
]
+
chunk_evaluator
.
metrics
)
pass_precision
,
pass_recall
,
pass_f1_score
=
chunk_evaluator
.
eval
(
exe
)
fetch_list
=
[
avg_cost
])
cost
=
cost
[
0
]
if
batch_id
%
10
==
0
:
print
(
"avg_cost:"
+
str
(
cost
)
+
" precision:"
+
str
(
precision
)
+
" recall:"
+
str
(
recall
)
+
" f1_score:"
+
str
(
f1_score
)
+
" pass_precision:"
+
str
(
pass_precision
)
+
" pass_recall:"
+
str
(
pass_recall
)
+
" pass_f1_score:"
+
str
(
pass_f1_score
))
print
(
"avg_cost:"
+
str
(
cost
))
if
batch_id
!=
0
:
print
(
"second per batch: "
+
str
((
time
.
time
(
)
-
start_time
)
/
batch_id
))
# Set the threshold low to speed up the CI test
if
float
(
pass_precision
)
>
0.01
:
if
float
(
cost
)
<
60.0
:
if
save_dirname
is
not
None
:
# TODO(liuyiqun): Change the target to crf_decode
fluid
.
io
.
save_inference_model
(
save_dirname
,
[
...
...
python/paddle/fluid/tests/test_data_feeder.py
浏览文件 @
07642119
...
...
@@ -13,15 +13,62 @@
# limitations under the License.
import
paddle.fluid
as
fluid
import
unittest
def
test_converter
():
class
TestDataFeeder
(
unittest
.
TestCase
):
def
test_lod_level_0_converter
(
self
):
img
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
1
,
28
,
28
])
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
feeder
=
fluid
.
DataFeeder
([
img
,
label
],
fluid
.
CPUPlace
())
result
=
feeder
.
feed
([[[
0
]
*
784
,
[
9
]],
[[
1
]
*
784
,
[
1
]]
])
result
=
feeder
.
feed
([([
0
]
*
784
,
[
9
]),
([
1
]
*
784
,
[
1
])
])
print
(
result
)
self
.
assertEqual
(
result
[
'image'
].
shape
(),
[
2
,
1
,
28
,
28
])
self
.
assertEqual
(
result
[
'label'
].
shape
(),
[
2
,
1
])
self
.
assertEqual
(
result
[
'image'
].
lod
(),
[])
self
.
assertEqual
(
result
[
'label'
].
lod
(),
[])
def
test_lod_level_1_converter
(
self
):
# lod_level = 1
# each sentence has a different number of words
sentences
=
fluid
.
layers
.
data
(
name
=
'sentences'
,
shape
=
[
1
],
dtype
=
'int64'
,
lod_level
=
1
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
feeder
=
fluid
.
DataFeeder
([
sentences
,
label
],
fluid
.
CPUPlace
())
# lod = [[0, 3, 5, 9]]
# data = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
# label = [1] * len(data)
result
=
feeder
.
feed
(
[([
1
,
2
,
3
],
[
1
]),
([
4
,
5
],
[
1
]),
([
6
,
7
,
8
,
9
],
[
1
])])
print
(
result
)
self
.
assertEqual
(
result
[
'sentences'
].
shape
(),
[
9
,
1
])
self
.
assertEqual
(
result
[
'label'
].
shape
(),
[
3
,
1
])
self
.
assertEqual
(
result
[
'sentences'
].
lod
(),
[[
0
,
3
,
5
,
9
]])
self
.
assertEqual
(
result
[
'label'
].
lod
(),
[])
def
test_lod_level_2_converter
(
self
):
# lod_level = 2
# paragraphs -> sentences -> words
paragraphs
=
fluid
.
layers
.
data
(
name
=
'paragraphs'
,
shape
=
[
1
],
dtype
=
'int64'
,
lod_level
=
2
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
feeder
=
fluid
.
DataFeeder
([
paragraphs
,
label
],
fluid
.
CPUPlace
())
# lod = [[0, 2, 3], [0, 3, 5, 9]]
# data = [[[1, 2, 3], [4, 5]], [[6, 7, 8, 9]]]
# label = [1] * len(data)
result
=
feeder
.
feed
(
[([[
1
,
2
,
3
],
[
4
,
5
]],
[
1
]),
([[
6
,
7
,
8
,
9
]],
[
1
])])
print
(
result
)
self
.
assertEqual
(
result
[
'paragraphs'
].
shape
(),
[
9
,
1
])
self
.
assertEqual
(
result
[
'label'
].
shape
(),
[
2
,
1
])
self
.
assertEqual
(
result
[
'paragraphs'
].
lod
(),
[[
0
,
2
,
3
],
[
0
,
3
,
5
,
9
]])
self
.
assertEqual
(
result
[
'label'
].
lod
(),
[])
if
__name__
==
'__main__'
:
test_converter
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
07642119
...
...
@@ -28,11 +28,11 @@ function(py_test_modules TARGET_NAME)
if
(
WITH_TESTING
)
set
(
options
""
)
set
(
oneValueArgs
""
)
set
(
multiValueArgs MODULES DEPS
ARGS
ENVS
)
set
(
multiValueArgs MODULES DEPS ENVS
)
cmake_parse_arguments
(
py_test_modules
"
${
options
}
"
"
${
oneValueArgs
}
"
"
${
multiValueArgs
}
"
${
ARGN
}
)
add_test
(
NAME
${
TARGET_NAME
}
COMMAND env PYTHONPATH=
${
PADDLE_BINARY_DIR
}
/python
${
py_test_modules_ENVS
}
${
PYTHON_EXECUTABLE
}
-u -m unittest --verbose
${
py_test_modules_MODULES
}
${
py_test_modules_ARG
S
}
${
PYTHON_EXECUTABLE
}
${
PADDLE_SOURCE_DIR
}
/tools/test_runner.py
${
py_test_modules_MODULE
S
}
WORKING_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
)
endif
()
endfunction
()
...
...
python/paddle/fluid/trainer.py
浏览文件 @
07642119
...
...
@@ -131,7 +131,40 @@ class Trainer(object):
# load params from param_path into scope
io
.
load_persistables
(
exe
,
dirname
=
param_path
)
def
_transpile_nccl2_dist
(
self
):
# PADDLE_TRAINER_IPS
if
"PADDLE_TRAINER_IPS"
not
in
os
.
environ
:
self
.
nccl_id_var
=
None
else
:
self
.
trainer_id
=
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
))
port
=
os
.
getenv
(
"PADDLE_PSERVER_PORT"
)
worker_ips
=
os
.
getenv
(
"PADDLE_TRAINER_IPS"
)
worker_endpoints
=
[]
for
ip
in
worker_ips
.
split
(
","
):
worker_endpoints
.
append
(
':'
.
join
([
ip
,
port
]))
self
.
num_trainers
=
len
(
worker_endpoints
)
current_endpoint
=
os
.
getenv
(
"POD_IP"
)
+
":"
+
port
worker_endpoints
.
remove
(
current_endpoint
)
# TODO(wuyi): use self.nccl_id_var, self.num_trainers and self.trainer_id
# in ParallelExecutor to start
# distributed training using NCCL2
self
.
nccl_id_var
=
self
.
startup_program
.
global_block
().
create_var
(
name
=
"NCCLID"
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
self
.
startup_program
.
global_block
().
append_op
(
type
=
"gen_nccl_id"
,
inputs
=
{},
outputs
=
{
"NCCLID"
:
self
.
nccl_id_var
},
attrs
=
{
"endpoint"
:
current_endpoint
,
"endpoint_list"
:
worker_endpoints
,
"trainer_id"
:
self
.
trainer_id
})
def
_dist_transpile_if_necessary
(
self
,
optimize_ops
,
params_grads
):
self
.
_transpile_nccl2_dist
()
if
self
.
nccl_id_var
!=
None
:
return
if
"PADDLE_TRAINING_ROLE"
not
in
os
.
environ
:
return
...
...
tools/test_runner.py
0 → 100644
浏览文件 @
07642119
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
os
import
sys
import
paddle.fluid
as
fluid
import
importlib
import
cStringIO
def
main
():
sys
.
path
.
append
(
os
.
getcwd
())
some_test_failed
=
False
for
module_name
in
sys
.
argv
[
1
:]:
buffer
=
cStringIO
.
StringIO
()
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
scope
=
fluid
.
core
.
Scope
()
with
fluid
.
program_guard
(
main
,
startup
):
with
fluid
.
scope_guard
(
scope
):
with
fluid
.
unique_name
.
guard
():
test_loader
=
unittest
.
TestLoader
()
module
=
importlib
.
import_module
(
module_name
)
tests
=
test_loader
.
loadTestsFromModule
(
module
)
res
=
unittest
.
TextTestRunner
(
stream
=
buffer
).
run
(
tests
)
if
not
res
.
wasSuccessful
():
some_test_failed
=
True
print
>>
sys
.
stderr
,
module_name
,
'failed
\n
'
,
buffer
.
getvalue
(
)
if
some_test_failed
:
exit
(
1
)
if
__name__
==
'__main__'
:
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录