Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e63013a8
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e63013a8
编写于
4月 18, 2018
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into feature/add_reduce_op_handle
上级
1eeb2e00
61f4baa1
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
304 addition
and
20 deletion
+304
-20
paddle/fluid/framework/details/reduce_op_handle.cc
paddle/fluid/framework/details/reduce_op_handle.cc
+0
-1
paddle/fluid/operators/conv_mkldnn_op.cc
paddle/fluid/operators/conv_mkldnn_op.cc
+18
-14
paddle/fluid/operators/softmax_mkldnn_op.cc
paddle/fluid/operators/softmax_mkldnn_op.cc
+9
-0
paddle/fluid/operators/softmax_op.cc
paddle/fluid/operators/softmax_op.cc
+3
-0
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+2
-0
python/paddle/fluid/inference_transpiler.py
python/paddle/fluid/inference_transpiler.py
+240
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+7
-2
python/paddle/fluid/tests/book/test_image_classification.py
python/paddle/fluid/tests/book/test_image_classification.py
+25
-3
未找到文件。
paddle/fluid/framework/details/reduce_op_handle.cc
浏览文件 @
e63013a8
...
...
@@ -91,7 +91,6 @@ void ReduceOpHandle::RunImpl() {
if
(
paddle
::
platform
::
is_cpu_place
(
pre_place
))
{
ReduceLoDTensor
func
(
lod_tensors
,
trg
);
VisitDataType
(
ToDataType
(
lod_tensors
[
0
].
type
()),
func
);
}
else
if
(
paddle
::
platform
::
is_gpu_place
(
pre_place
))
{
#ifdef PADDLE_WITH_CUDA
auto
out_p
=
out_var_handles
[
0
]
->
place_
;
...
...
paddle/fluid/operators/conv_mkldnn_op.cc
浏览文件 @
e63013a8
...
...
@@ -72,10 +72,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto
dst_md
=
platform
::
MKLDNNMemDesc
(
dst_tz
,
mkldnn
::
memory
::
data_type
::
f32
,
mkldnn
::
memory
::
format
::
nchw
);
auto
src_memory
=
mkldnn
::
memory
({
src_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
input_data
));
auto
weights_memory
=
mkldnn
::
memory
({
weights_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
filter_data
));
auto
src_memory
=
mkldnn
::
memory
({
src_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
input_data
)));
auto
weights_memory
=
mkldnn
::
memory
({
weights_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
filter_data
)));
auto
dst_memory
=
mkldnn
::
memory
({
dst_md
,
mkldnn_engine
},
output_data
);
std
::
shared_ptr
<
mkldnn
::
convolution_forward
::
primitive_desc
>
conv_pd
=
...
...
@@ -180,9 +182,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
dst_tz
,
mkldnn
::
memory
::
data_type
::
f32
,
mkldnn
::
memory
::
format
::
nchw
);
// create memory
auto
diff_dst_memory
=
mkldnn
::
memory
(
{
diff_weights_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
output_grad_data
));
auto
diff_dst_memory
=
mkldnn
::
memory
(
{
diff_weights_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
output_grad_data
)
));
// Retrieve conv_pd from device context
auto
conv_pd
=
std
::
static_pointer_cast
<
mkldnn
::
convolution_forward
::
primitive_desc
>
(
...
...
@@ -202,8 +204,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto
diff_weights_memory
=
mkldnn
::
memory
({
diff_weights_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
filter_grad_data
));
auto
src_memory
=
mkldnn
::
memory
({
src_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
input_data
));
auto
src_memory
=
mkldnn
::
memory
({
src_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
input_data
)));
// create backward conv primitive for weights
auto
conv_bwd_weights_prim
=
mkldnn
::
convolution_backward_weights
(
...
...
@@ -222,11 +225,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
strides
,
paddings
,
*
conv_pd
,
mkldnn_engine
);
// create memory
auto
diff_src_memory
=
mkldnn
::
memory
({
diff_src_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
input_grad_data
));
auto
weights_memory
=
mkldnn
::
memory
(
{
weights_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
filter_data
));
auto
diff_src_memory
=
mkldnn
::
memory
(
{
diff_src_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
input_grad_data
)));
auto
weights_memory
=
mkldnn
::
memory
({
weights_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
filter_data
)));
// create backward conv primitive for data
auto
conv_bwd_data_prim
=
mkldnn
::
convolution_backward_data
(
...
...
paddle/fluid/operators/softmax_mkldnn_op.cc
浏览文件 @
e63013a8
...
...
@@ -73,6 +73,15 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
softmax_dst_memory
);
std
::
vector
<
primitive
>
pipeline
{
softmax
};
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
const
bool
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
if
(
!
is_test
)
{
T
threshold
=
exp
(
-
64
);
for
(
size_t
i
=
0
;
i
<
dst_tz
[
0
]
*
dst_tz
[
1
];
++
i
)
{
output_data
[
i
]
=
output_data
[
i
]
<
threshold
?
threshold
:
output_data
[
i
];
}
}
}
};
...
...
paddle/fluid/operators/softmax_op.cc
浏览文件 @
e63013a8
...
...
@@ -97,6 +97,9 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr
<
bool
>
(
"use_mkldnn"
,
"(bool, default false) Only used in mkldnn kernel"
)
.
SetDefault
(
false
);
AddAttr
<
bool
>
(
"is_test"
,
"Disable epsilon adding to softmax results. Used by MKLDNN."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
Softmax Operator.
...
...
python/paddle/fluid/__init__.py
浏览文件 @
e63013a8
...
...
@@ -37,6 +37,7 @@ from distribute_transpiler import DistributeTranspiler
from
distribute_transpiler_simple
import
SimpleDistributeTranspiler
from
concurrency
import
(
Go
,
make_channel
,
channel_send
,
channel_recv
,
channel_close
,
Select
)
from
inference_transpiler
import
InferenceTranspiler
import
clip
from
memory_optimization_transpiler
import
memory_optimize
,
release_memory
import
profiler
...
...
@@ -66,6 +67,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + [
'clip'
,
'SimpleDistributeTranspiler'
,
'DistributeTranspiler'
,
'InferenceTranspiler'
,
'memory_optimize'
,
'release_memory'
,
'profiler'
,
...
...
python/paddle/fluid/inference_transpiler.py
0 → 100644
浏览文件 @
e63013a8
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
from
framework
import
Program
from
executor
import
global_scope
from
.
import
core
class InferenceTranspiler:
    '''
    Rewrite an inference program in place so that it runs faster.

    The only transformation currently implemented is folding batch_norm
    operators into the conv2d operator that immediately precedes them
    (see fuse_batch_norm).
    '''

    def transpile(self, program, place, scope=None):
        '''
        Transpile the program. Support only fuse batch normalization now.

        :param program: program to transpile
        :type program: Program
        :param place: inference place
        :type place: Place
        :param scope: inference scope; global_scope() is used when None
        :type scope: Scope or None
        :raises TypeError: if program, place or scope has an unexpected type
        '''
        if not isinstance(program, Program):
            raise TypeError("program should be as Program type")
        if not isinstance(place, core.CPUPlace) and not isinstance(
                place, core.CUDAPlace):
            raise TypeError("place should be as CPUPlace/CUDAPlace type")
        if scope is None:
            scope = global_scope()
        if not isinstance(scope, core.Scope):
            raise TypeError("scope should be as Scope type or None")
        self.fuse_batch_norm(program, place, scope)

    def fuse_batch_norm(self, program, place, scope):
        '''
        Transpile the program by fused batch normalization.

        The batch normalization followed the convolution or fully connected layer
        can be integrated with them. Doing so will give us a forward acceleration,
        especially in environments like mobile or embedded.

        For input X:
        - Conv process:        X = input * W + bias
        - Batch norm process:  X' = (X - mean) / std
        - Scale Process:       Y = a * X' + b

        After fuse into one operation:

        Y = (input * W + bias - mean) / std * a + b
          = input * a * W / std + ((bias - mean) / std * a + b)

        The operator transformation is:
        - before:
          - conv->batch_norm->any_other_op (bias == 0)
          - conv->elementwise_add->batch_norm->any_other_op (bias != 0)
        - after:
          - conv->elementwise_add->any_other_op

        The transpile stages are:
        1. insert elementwise_add op when bias == 0.
        2. fuse the batch_norm's parameters to conv and elementwise_add operators.
        3. remove batch_norm ops which are not used in any other ops.
        4. adjust the input of any_other_op to be the output of elementwise_add operator.
        5. remove unused variables.

        :param program: program to transpile
        :type program: Program
        :param place: inference place
        :type place: Place
        :param scope: inference scope
        :type scope: Scope
        '''
        self.scope = scope
        self.place = place
        self.block = program.block(0)
        self.input_map = {}  # store the input names should be adjusted

        i = 0
        while i < len(self.block.ops):
            current_op = self.block.ops[i]
            # TODO(luotao1): consider only conv2d now. fc would be dealt later.
            if current_op.type in ['conv2d']:
                # TODO(luotao1): consider single chain network now.
                # For branch network, we couldn't use block.ops[i + 1] as
                # the judgment condition.
                # A conv2d may be the last op of the block (or be followed by
                # a single op only), so guard every look-ahead access.
                next_op = self._op_at(i + 1)
                # conv2d without bias
                if next_op is not None and next_op.type == 'batch_norm':
                    # insert bias op
                    bias_op = self._insert_bias_op(i + 1, current_op, next_op)
                    # fuse batch_norm
                    self._fuse_param(current_op, next_op, bias_op, 0)
                    # remove batch_norm_op (shifted to i + 2 by the insertion)
                    self.block.remove_op(i + 2)
                    # skip over the elementwise_add that now follows the conv
                    i = i + 1
                # conv2d with bias, the next_op.type is elementwise_add
                elif next_op is not None and next_op.type == 'elementwise_add':
                    next_next_op = self._op_at(i + 2)
                    if (next_next_op is not None and
                            next_next_op.type == 'batch_norm'):
                        # fuse batch_norm
                        self._fuse_param(current_op, next_next_op, next_op, 1)
                        # remove batch_norm_op
                        self.block.remove_op(i + 2)
                        i = i + 1
            i = i + 1

        self._adjust_input()
        self._remove_unused_var()
        # TODO(luotao): use clone() method to flush the program.desc in force,
        # since some large program.desc will not be flushed immediately.
        # And a better solution will be considered later.
        # The clone is called only for that side effect; its result is unused.
        program.clone()

    # ====================== private transpiler functions =====================
    def _op_at(self, index):
        '''
        Return the operator at `index` in the current block, or None when
        `index` is past the end of the operator list.
        '''
        ops = self.block.ops
        return ops[index] if index < len(ops) else None

    def _insert_bias_op(self, index, current_op, bn_op):
        '''
        Construct elementwise_add operator for adding bias
        and insert it into program.

        :param index: insert location of bias_op
        :type index: Int
        :param current_op: current operator (conv or fc)
        :type current_op: Operator
        :param bn_op: batch norm operator
        :type bn_op: Operator
        :return: bias_op
        :rtype: Operator
        '''
        # The input of bias_op is current_op's output and Bias of bn_op
        # The output of bias_op is bn_op's output
        x_var = self.block.var(current_op.output("Output")[0])
        y_var = self.block.var(bn_op.input("Bias")[0])
        out_var = self.block.var(bn_op.output("Y")[0])

        bias_op = self.block.insert_op(
            index,
            type="elementwise_add",
            inputs={"X": x_var,
                    "Y": y_var},
            outputs={"Out": out_var},
            attrs={"axis": 1})  # dim_start=1
        return bias_op

    def _fuse_param(self, current_op, bn_op, bias_op, with_bias):
        '''
        fuse the batch_norm_op' parameters to current_op (conv or fc)

        :param current_op: current operator (conv or fc)
        :type current_op: Operator
        :param bn_op: batch norm operator
        :type bn_op: Operator
        :param bias_op: elementwise_add operator for adding bias
        :type bias_op: Operator
        :param with_bias: If current operator has bias, with_bias = 1; otherwise 0.
        :type with_bias: Int
        '''

        def _update_param(op, old_param_name, new_param):
            # For the sake of remaining the original variables the same as before,
            # create new variables in scope to store the new parameters.
            old_param_name = old_param_name[0]
            old_var = self.block.vars[old_param_name]
            new_param_name = old_param_name + '_fuse_bn'
            new_var = self.block.create_parameter(
                name=new_param_name.encode('ascii'),
                type=old_var.type,
                dtype=old_var.dtype,
                shape=old_var.shape)
            op.rename_input(old_param_name, new_param_name)
            self.scope.var(new_param_name)

            tensor = self.scope.find_var(new_param_name).get_tensor()
            tensor.set(np.array(new_param), self.place)

        def _load_param(param_name):
            # param_name is the list returned by op.input(...); only its
            # first entry is read.
            return np.array(self.scope.find_var(param_name[0]).get_tensor())

        bias_bn = _load_param(bn_op.input("Bias"))  # Bias
        scale_bn = _load_param(bn_op.input("Scale"))  # Scale
        mean_bn = _load_param(bn_op.input("Mean"))  # Mean
        var_bn = _load_param(bn_op.input("Variance"))  # Variance

        # TODO(luotao1): consider only conv2d now. fc would be dealt later.
        current_param = _load_param(current_op.input("Filter"))
        # NOTE(review): 1e-5 is hard-coded here; presumably it mirrors the
        # batch_norm op's epsilon -- confirm against the op's attribute.
        std_bn = np.float32(np.sqrt(np.add(var_bn, 1e-5)))
        tmp = np.float32(np.divide(scale_bn, std_bn))

        # add bias of batch_norm_op to conv2d
        if with_bias:
            bias = _load_param(bias_op.input("Y"))
        else:
            bias = np.zeros(bias_bn.shape)
        bias = np.float32(
            np.add(np.multiply(np.subtract(bias, mean_bn), tmp), bias_bn))

        # re-compute weight of conv2d: scale each leading-axis slice of the
        # filter by the matching entry of scale / std
        tmp = tmp.reshape(tmp.shape[0], -1)
        dst_param = current_param.reshape((tmp.shape[0], -1))
        dst_param = np.float32(np.multiply(dst_param, tmp))
        dst_param = dst_param.reshape(current_param.shape)

        # update parameters
        _update_param(current_op, current_op.input("Filter"), dst_param)
        _update_param(bias_op, bias_op.input("Y"), bias)

        # collect the renamed input
        self.input_map[bn_op.output("Y")[0]] = bias_op.output("Out")[0]

    def _adjust_input(self):
        '''
        Rename every operator input recorded in self.input_map to the name
        it is mapped to.
        '''
        for i in range(len(self.block.ops)):
            current_op = self.block.ops[i]
            for input_arg in current_op.input_arg_names:
                if input_arg in self.input_map:
                    current_op.rename_input(input_arg,
                                            self.input_map[input_arg])

    def _remove_unused_var(self):
        '''
        remove unused variables in program
        '''
        # Collect every name that some operator reads or writes; a set gives
        # constant-time membership tests below.
        used_args = set()
        for current_op in self.block.ops:
            used_args.update(current_op.input_arg_names)
            used_args.update(current_op.output_arg_names)

        # Snapshot the names first: removing entries while iterating over a
        # live view of the variable dict is an error on Python 3.
        for var in list(self.block.vars.keys()):
            if var not in used_args:
                self.block.remove_var(var)
python/paddle/fluid/layers/nn.py
浏览文件 @
e63013a8
...
...
@@ -88,6 +88,7 @@ def fc(input,
bias_attr
=
None
,
use_mkldnn
=
False
,
act
=
None
,
is_test
=
False
,
name
=
None
):
"""
**Fully Connected Layer**
...
...
@@ -134,6 +135,7 @@ def fc(input,
bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias
of this layer. If it is set to None, no bias will be added to the output units.
act (str, default None): Activation to be applied to the output of this layer.
is_test(bool): A flag indicating whether execution is in test phase.
use_mkldnn(bool): Use mkldnn kernel or not, it is valid only when the mkldnn
library is installed. Default: False
name (str, default None): The name of this layer.
...
...
@@ -177,8 +179,11 @@ def fc(input,
inputs
=
{
"Input"
:
input
,
"W"
:
w
},
outputs
=
{
"Out"
:
tmp
},
attrs
=
{
"use_mkldnn"
:
use_mkldnn
,
"bias_attr"
:
bias_attr
})
attrs
=
{
"use_mkldnn"
:
use_mkldnn
,
"is_test"
:
is_test
,
"bias_attr"
:
bias_attr
})
return
helper
.
append_activation
(
tmp
)
else
:
for
input_var
,
param_attr
in
helper
.
iter_inputs_and_params
():
...
...
python/paddle/fluid/tests/book/test_image_classification.py
浏览文件 @
e63013a8
...
...
@@ -22,10 +22,17 @@ import sys
import
numpy
import
unittest
import
os
import
numpy
as
np
def
resnet_cifar10
(
input
,
depth
=
32
):
def
conv_bn_layer
(
input
,
ch_out
,
filter_size
,
stride
,
padding
,
act
=
'relu'
):
def
conv_bn_layer
(
input
,
ch_out
,
filter_size
,
stride
,
padding
,
act
=
'relu'
,
bias_attr
=
False
):
tmp
=
fluid
.
layers
.
conv2d
(
input
=
input
,
filter_size
=
filter_size
,
...
...
@@ -33,7 +40,7 @@ def resnet_cifar10(input, depth=32):
stride
=
stride
,
padding
=
padding
,
act
=
None
,
bias_attr
=
False
)
bias_attr
=
bias_attr
)
return
fluid
.
layers
.
batch_norm
(
input
=
tmp
,
act
=
act
)
def
shortcut
(
input
,
ch_in
,
ch_out
,
stride
):
...
...
@@ -44,7 +51,7 @@ def resnet_cifar10(input, depth=32):
def
basicblock
(
input
,
ch_in
,
ch_out
,
stride
):
tmp
=
conv_bn_layer
(
input
,
ch_out
,
3
,
stride
,
1
)
tmp
=
conv_bn_layer
(
tmp
,
ch_out
,
3
,
1
,
1
,
act
=
None
)
tmp
=
conv_bn_layer
(
tmp
,
ch_out
,
3
,
1
,
1
,
act
=
None
,
bias_attr
=
True
)
short
=
shortcut
(
input
,
ch_in
,
ch_out
,
stride
)
return
fluid
.
layers
.
elementwise_add
(
x
=
tmp
,
y
=
short
,
act
=
'relu'
)
...
...
@@ -219,11 +226,26 @@ def infer(use_cuda, save_dirname=None):
batch_size
=
1
tensor_img
=
numpy
.
random
.
rand
(
batch_size
,
3
,
32
,
32
).
astype
(
"float32"
)
# Use inference_transpiler to speedup
inference_transpiler_program
=
inference_program
.
clone
()
t
=
fluid
.
InferenceTranspiler
()
t
.
transpile
(
inference_transpiler_program
,
place
)
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
results
=
exe
.
run
(
inference_program
,
feed
=
{
feed_target_names
[
0
]:
tensor_img
},
fetch_list
=
fetch_targets
)
transpiler_results
=
exe
.
run
(
inference_transpiler_program
,
feed
=
{
feed_target_names
[
0
]:
tensor_img
},
fetch_list
=
fetch_targets
)
assert
len
(
results
[
0
])
==
len
(
transpiler_results
[
0
])
for
i
in
range
(
len
(
results
[
0
])):
np
.
testing
.
assert_almost_equal
(
results
[
0
][
i
],
transpiler_results
[
0
][
i
],
decimal
=
6
)
print
(
"infer results: "
,
results
[
0
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录