Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
e63013a8
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
e63013a8
编写于
4月 18, 2018
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into feature/add_reduce_op_handle
上级
1eeb2e00
61f4baa1
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
304 addition
and
20 deletion
+304
-20
paddle/fluid/framework/details/reduce_op_handle.cc
paddle/fluid/framework/details/reduce_op_handle.cc
+0
-1
paddle/fluid/operators/conv_mkldnn_op.cc
paddle/fluid/operators/conv_mkldnn_op.cc
+18
-14
paddle/fluid/operators/softmax_mkldnn_op.cc
paddle/fluid/operators/softmax_mkldnn_op.cc
+9
-0
paddle/fluid/operators/softmax_op.cc
paddle/fluid/operators/softmax_op.cc
+3
-0
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+2
-0
python/paddle/fluid/inference_transpiler.py
python/paddle/fluid/inference_transpiler.py
+240
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+7
-2
python/paddle/fluid/tests/book/test_image_classification.py
python/paddle/fluid/tests/book/test_image_classification.py
+25
-3
未找到文件。
paddle/fluid/framework/details/reduce_op_handle.cc
浏览文件 @
e63013a8
...
...
@@ -91,7 +91,6 @@ void ReduceOpHandle::RunImpl() {
if
(
paddle
::
platform
::
is_cpu_place
(
pre_place
))
{
ReduceLoDTensor
func
(
lod_tensors
,
trg
);
VisitDataType
(
ToDataType
(
lod_tensors
[
0
].
type
()),
func
);
}
else
if
(
paddle
::
platform
::
is_gpu_place
(
pre_place
))
{
#ifdef PADDLE_WITH_CUDA
auto
out_p
=
out_var_handles
[
0
]
->
place_
;
...
...
paddle/fluid/operators/conv_mkldnn_op.cc
浏览文件 @
e63013a8
...
...
@@ -72,10 +72,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto
dst_md
=
platform
::
MKLDNNMemDesc
(
dst_tz
,
mkldnn
::
memory
::
data_type
::
f32
,
mkldnn
::
memory
::
format
::
nchw
);
auto
src_memory
=
mkldnn
::
memory
({
src_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
input_data
));
auto
weights_memory
=
mkldnn
::
memory
({
weights_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
filter_data
));
auto
src_memory
=
mkldnn
::
memory
({
src_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
input_data
)));
auto
weights_memory
=
mkldnn
::
memory
({
weights_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
filter_data
)));
auto
dst_memory
=
mkldnn
::
memory
({
dst_md
,
mkldnn_engine
},
output_data
);
std
::
shared_ptr
<
mkldnn
::
convolution_forward
::
primitive_desc
>
conv_pd
=
...
...
@@ -180,9 +182,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
dst_tz
,
mkldnn
::
memory
::
data_type
::
f32
,
mkldnn
::
memory
::
format
::
nchw
);
// create memory
auto
diff_dst_memory
=
mkldnn
::
memory
(
{
diff_weights_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
output_grad_data
));
auto
diff_dst_memory
=
mkldnn
::
memory
(
{
diff_weights_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
output_grad_data
)
));
// Retrieve conv_pd from device context
auto
conv_pd
=
std
::
static_pointer_cast
<
mkldnn
::
convolution_forward
::
primitive_desc
>
(
...
...
@@ -202,8 +204,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto
diff_weights_memory
=
mkldnn
::
memory
({
diff_weights_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
filter_grad_data
));
auto
src_memory
=
mkldnn
::
memory
({
src_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
input_data
));
auto
src_memory
=
mkldnn
::
memory
({
src_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
input_data
)));
// create backward conv primitive for weights
auto
conv_bwd_weights_prim
=
mkldnn
::
convolution_backward_weights
(
...
...
@@ -222,11 +225,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
strides
,
paddings
,
*
conv_pd
,
mkldnn_engine
);
// create memory
auto
diff_src_memory
=
mkldnn
::
memory
({
diff_src_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
input_grad_data
));
auto
weights_memory
=
mkldnn
::
memory
(
{
weights_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
filter_data
));
auto
diff_src_memory
=
mkldnn
::
memory
(
{
diff_src_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
input_grad_data
)));
auto
weights_memory
=
mkldnn
::
memory
({
weights_md
,
mkldnn_engine
},
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
filter_data
)));
// create backward conv primitive for data
auto
conv_bwd_data_prim
=
mkldnn
::
convolution_backward_data
(
...
...
paddle/fluid/operators/softmax_mkldnn_op.cc
浏览文件 @
e63013a8
...
...
@@ -73,6 +73,15 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
softmax_dst_memory
);
std
::
vector
<
primitive
>
pipeline
{
softmax
};
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
const
bool
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
if
(
!
is_test
)
{
T
threshold
=
exp
(
-
64
);
for
(
size_t
i
=
0
;
i
<
dst_tz
[
0
]
*
dst_tz
[
1
];
++
i
)
{
output_data
[
i
]
=
output_data
[
i
]
<
threshold
?
threshold
:
output_data
[
i
];
}
}
}
};
...
...
paddle/fluid/operators/softmax_op.cc
浏览文件 @
e63013a8
...
...
@@ -97,6 +97,9 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr
<
bool
>
(
"use_mkldnn"
,
"(bool, default false) Only used in mkldnn kernel"
)
.
SetDefault
(
false
);
AddAttr
<
bool
>
(
"is_test"
,
"Disable epsilon adding to softmax results. Used by MKLDNN."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
Softmax Operator.
...
...
python/paddle/fluid/__init__.py
浏览文件 @
e63013a8
...
...
@@ -37,6 +37,7 @@ from distribute_transpiler import DistributeTranspiler
from
distribute_transpiler_simple
import
SimpleDistributeTranspiler
from
concurrency
import
(
Go
,
make_channel
,
channel_send
,
channel_recv
,
channel_close
,
Select
)
from
inference_transpiler
import
InferenceTranspiler
import
clip
from
memory_optimization_transpiler
import
memory_optimize
,
release_memory
import
profiler
...
...
@@ -66,6 +67,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + [
'clip'
,
'SimpleDistributeTranspiler'
,
'DistributeTranspiler'
,
'InferenceTranspiler'
,
'memory_optimize'
,
'release_memory'
,
'profiler'
,
...
...
python/paddle/fluid/inference_transpiler.py
0 → 100644
浏览文件 @
e63013a8
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
from
framework
import
Program
from
executor
import
global_scope
from
.
import
core
class
InferenceTranspiler
:
def
transpile
(
self
,
program
,
place
,
scope
=
None
):
'''
Transpile the program. Support only fuse batch normalization now.
:param program: program to transpile
:type program: Program
:param place: inference place
:type place: Place
:param scope: inference scope
:type scope: Scope or None
'''
if
not
isinstance
(
program
,
Program
):
raise
TypeError
(
"program should be as Program type"
)
if
not
isinstance
(
place
,
core
.
CPUPlace
)
and
not
isinstance
(
place
,
core
.
CUDAPlace
):
raise
TypeError
(
"place should be as CPUPlace/CUDAPlace type"
)
if
scope
is
None
:
scope
=
global_scope
()
if
not
isinstance
(
scope
,
core
.
Scope
):
raise
TypeError
(
"scope should be as Scope type or None"
)
self
.
fuse_batch_norm
(
program
,
place
,
scope
)
def
fuse_batch_norm
(
self
,
program
,
place
,
scope
):
'''
Transpile the program by fused batch normalization.
The batch normalization followed the convolution or fully connected layer
can be integrated with them. Doing so will give us a forward acceleration,
especially in environments like mobile or embedded.
For input X:
- Conv process: X = input * W + bias
- Batch norm process: X' = (X - mean) / std
- Scale Process: Y = a * X' + b
After fuse into one operation:
Y = (input * W + bias - mean) / std * a + b
= input * a * W / std + ((bias - mean) / std * a + b)
The operator transformation is:
- before:
- conv->batch_norm->any_other_op (bias == 0)
- conv->elementwise_add->batch_norm->any_other_op (bias != 0)
- after:
- conv->elementwise_add->any_other_op
The transpile stages are:
1. insert elementwise_add op when bias == 0.
2. fuse the batch_norm's parameters to conv and elementwise_add operators.
3. remove batch_norm ops which are not used in any other ops.
4. adjust the input of any_other_op to be the output of elementwise_add operator.
5. remove unused variables.
:param program: program to transpile
:type program: Program
:param place: inference place
:type place: Place
:param scope: inference scope
:type scope: Scope
'''
self
.
scope
=
scope
self
.
place
=
place
self
.
block
=
program
.
block
(
0
)
self
.
input_map
=
{}
# store the input names should be adjusted
i
=
0
while
i
<
len
(
self
.
block
.
ops
):
current_op
=
self
.
block
.
ops
[
i
]
# TODO(luotao1): consider only conv2d now. fc would be delt later.
if
current_op
.
type
in
[
'conv2d'
]:
# TODO(luotao1): consider single chain network now.
# For branch network, we counldn't use block.ops[i + 1] as
# the judgment condition.
next_op
=
self
.
block
.
ops
[
i
+
1
]
# conv2d without bias
if
(
next_op
.
type
==
'batch_norm'
):
# insert bias op
bias_op
=
self
.
_insert_bias_op
(
i
+
1
,
current_op
,
next_op
)
# fuse batch_norm
self
.
_fuse_param
(
current_op
,
next_op
,
bias_op
,
0
)
# remove batch_norm_op
self
.
block
.
remove_op
(
i
+
2
)
i
=
i
+
1
# conv2d with bias, the next_op.type is elementwise_add
elif
(
next_op
.
type
==
'elementwise_add'
):
next_next_op
=
self
.
block
.
ops
[
i
+
2
]
if
(
next_next_op
.
type
==
'batch_norm'
):
# fuse batch_norm
self
.
_fuse_param
(
current_op
,
next_next_op
,
next_op
,
1
)
# remove batch_norm_op
self
.
block
.
remove_op
(
i
+
2
)
i
=
i
+
1
i
=
i
+
1
self
.
_adjust_input
()
self
.
_remove_unused_var
()
# TODO(luotao): use clone() method to flush the program.desc in force,
# since some large program.desc will not be flushed immediately.
# And a better solution will be considered later.
program
=
program
.
clone
()
# ====================== private transpiler functions =====================
def
_insert_bias_op
(
self
,
index
,
current_op
,
bn_op
):
'''
Construct elementwise_add operator for adding bias
and insert it into program.
:param index: insert location of bias_op
:type index: Int
:param current_op: current operator (conv or fc)
:type current_op: Operator
:param bn_op: batch norm operator
:type bn_op: Operator
:return: bias_op
:rtype: Operator
'''
# The input of bias_op is current_op's output and Bias of bn_op
# The output of bias_op is bn_op's output
x_var
=
self
.
block
.
var
(
current_op
.
output
(
"Output"
)[
0
])
y_var
=
self
.
block
.
var
(
bn_op
.
input
(
"Bias"
)[
0
])
out_var
=
self
.
block
.
var
(
bn_op
.
output
(
"Y"
)[
0
])
bias_op
=
self
.
block
.
insert_op
(
index
,
type
=
"elementwise_add"
,
inputs
=
{
"X"
:
x_var
,
"Y"
:
y_var
},
outputs
=
{
"Out"
:
out_var
},
attrs
=
{
"axis"
:
1
})
# dim_start=1
return
bias_op
def
_fuse_param
(
self
,
current_op
,
bn_op
,
bias_op
,
with_bias
):
'''
fuse the batch_norm_op' parameters to current_op (conv or fc)
:param current_op: current operator (conv or fc)
:type current_op: Operator
:param bn_op: batch norm operator
:type bn_op: Operator
:param bias_op: elementwise_add operator for adding bias
:type bias_op: Operator
:param with_bias: If current operator has bias, with_bias = 1; otherwise 0.
:type with_bias: Int
'''
def
_update_param
(
op
,
old_param_name
,
new_param
):
# For the sake of remaining the original variables the same as before,
# create new variables in scope to store the new parameters.
old_param_name
=
old_param_name
[
0
]
old_var
=
self
.
block
.
vars
[
old_param_name
]
new_param_name
=
old_param_name
+
'_fuse_bn'
new_var
=
self
.
block
.
create_parameter
(
name
=
new_param_name
.
encode
(
'ascii'
),
type
=
old_var
.
type
,
dtype
=
old_var
.
dtype
,
shape
=
old_var
.
shape
)
op
.
rename_input
(
old_param_name
,
new_param_name
)
self
.
scope
.
var
(
new_param_name
)
tensor
=
self
.
scope
.
find_var
(
new_param_name
).
get_tensor
()
tensor
.
set
(
np
.
array
(
new_param
),
self
.
place
)
def
_load_param
(
param_name
):
return
np
.
array
(
self
.
scope
.
find_var
(
param_name
[
0
]).
get_tensor
())
bias_bn
=
_load_param
(
bn_op
.
input
(
"Bias"
))
#Bias
scale_bn
=
_load_param
(
bn_op
.
input
(
"Scale"
))
#Scale
mean_bn
=
_load_param
(
bn_op
.
input
(
"Mean"
))
#Mean
var_bn
=
_load_param
(
bn_op
.
input
(
"Variance"
))
#Variance
# TODO(luotao1): consider only conv2d now. fc would be delt later.
current_param
=
_load_param
(
current_op
.
input
(
"Filter"
))
std_bn
=
np
.
float32
(
np
.
sqrt
(
np
.
add
(
var_bn
,
1e-5
)))
tmp
=
np
.
float32
(
np
.
divide
(
scale_bn
,
std_bn
))
# add bias of batch_norm_op to conv2d
if
with_bias
:
bias
=
_load_param
(
bias_op
.
input
(
"Y"
))
else
:
bias
=
np
.
zeros
(
bias_bn
.
shape
)
bias
=
np
.
float32
(
np
.
add
(
np
.
multiply
(
np
.
subtract
(
bias
,
mean_bn
),
tmp
),
bias_bn
))
# re-compute weight of conv2d
tmp
=
tmp
.
reshape
(
tmp
.
shape
[
0
],
-
1
)
dst_param
=
current_param
.
reshape
((
tmp
.
shape
[
0
],
-
1
))
dst_param
=
np
.
float32
(
np
.
multiply
(
dst_param
,
tmp
))
dst_param
=
dst_param
.
reshape
(
current_param
.
shape
)
# update parameters
_update_param
(
current_op
,
current_op
.
input
(
"Filter"
),
dst_param
)
_update_param
(
bias_op
,
bias_op
.
input
(
"Y"
),
bias
)
# collect the renamed input
self
.
input_map
[
bn_op
.
output
(
"Y"
)[
0
]]
=
bias_op
.
output
(
"Out"
)[
0
]
def
_adjust_input
(
self
):
for
i
in
range
(
len
(
self
.
block
.
ops
)):
current_op
=
self
.
block
.
ops
[
i
]
for
input_arg
in
current_op
.
input_arg_names
:
if
input_arg
in
self
.
input_map
:
current_op
.
rename_input
(
input_arg
,
self
.
input_map
[
input_arg
])
def
_remove_unused_var
(
self
):
'''
remove unused varibles in program
'''
args
=
[]
for
i
in
range
(
len
(
self
.
block
.
ops
)):
current_op
=
self
.
block
.
ops
[
i
]
args
+=
current_op
.
input_arg_names
args
+=
current_op
.
output_arg_names
args
=
list
(
set
(
args
))
# unique the input and output arguments
for
var
in
self
.
block
.
vars
.
keys
():
if
var
not
in
args
:
self
.
block
.
remove_var
(
var
)
python/paddle/fluid/layers/nn.py
浏览文件 @
e63013a8
...
...
@@ -88,6 +88,7 @@ def fc(input,
bias_attr
=
None
,
use_mkldnn
=
False
,
act
=
None
,
is_test
=
False
,
name
=
None
):
"""
**Fully Connected Layer**
...
...
@@ -134,6 +135,7 @@ def fc(input,
bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias
of this layer. If it is set to None, no bias will be added to the output units.
act (str, default None): Activation to be applied to the output of this layer.
is_test(bool): A flag indicating whether execution is in test phase.
use_mkldnn(bool): Use mkldnn kernel or not, it is valid only when the mkldnn
library is installed. Default: False
name (str, default None): The name of this layer.
...
...
@@ -177,8 +179,11 @@ def fc(input,
inputs
=
{
"Input"
:
input
,
"W"
:
w
},
outputs
=
{
"Out"
:
tmp
},
attrs
=
{
"use_mkldnn"
:
use_mkldnn
,
"bias_attr"
:
bias_attr
})
attrs
=
{
"use_mkldnn"
:
use_mkldnn
,
"is_test"
:
is_test
,
"bias_attr"
:
bias_attr
})
return
helper
.
append_activation
(
tmp
)
else
:
for
input_var
,
param_attr
in
helper
.
iter_inputs_and_params
():
...
...
python/paddle/fluid/tests/book/test_image_classification.py
浏览文件 @
e63013a8
...
...
@@ -22,10 +22,17 @@ import sys
import
numpy
import
unittest
import
os
import
numpy
as
np
def
resnet_cifar10
(
input
,
depth
=
32
):
def
conv_bn_layer
(
input
,
ch_out
,
filter_size
,
stride
,
padding
,
act
=
'relu'
):
def
conv_bn_layer
(
input
,
ch_out
,
filter_size
,
stride
,
padding
,
act
=
'relu'
,
bias_attr
=
False
):
tmp
=
fluid
.
layers
.
conv2d
(
input
=
input
,
filter_size
=
filter_size
,
...
...
@@ -33,7 +40,7 @@ def resnet_cifar10(input, depth=32):
stride
=
stride
,
padding
=
padding
,
act
=
None
,
bias_attr
=
False
)
bias_attr
=
bias_attr
)
return
fluid
.
layers
.
batch_norm
(
input
=
tmp
,
act
=
act
)
def
shortcut
(
input
,
ch_in
,
ch_out
,
stride
):
...
...
@@ -44,7 +51,7 @@ def resnet_cifar10(input, depth=32):
def
basicblock
(
input
,
ch_in
,
ch_out
,
stride
):
tmp
=
conv_bn_layer
(
input
,
ch_out
,
3
,
stride
,
1
)
tmp
=
conv_bn_layer
(
tmp
,
ch_out
,
3
,
1
,
1
,
act
=
None
)
tmp
=
conv_bn_layer
(
tmp
,
ch_out
,
3
,
1
,
1
,
act
=
None
,
bias_attr
=
True
)
short
=
shortcut
(
input
,
ch_in
,
ch_out
,
stride
)
return
fluid
.
layers
.
elementwise_add
(
x
=
tmp
,
y
=
short
,
act
=
'relu'
)
...
...
@@ -219,11 +226,26 @@ def infer(use_cuda, save_dirname=None):
batch_size
=
1
tensor_img
=
numpy
.
random
.
rand
(
batch_size
,
3
,
32
,
32
).
astype
(
"float32"
)
# Use inference_transpiler to speedup
inference_transpiler_program
=
inference_program
.
clone
()
t
=
fluid
.
InferenceTranspiler
()
t
.
transpile
(
inference_transpiler_program
,
place
)
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
results
=
exe
.
run
(
inference_program
,
feed
=
{
feed_target_names
[
0
]:
tensor_img
},
fetch_list
=
fetch_targets
)
transpiler_results
=
exe
.
run
(
inference_transpiler_program
,
feed
=
{
feed_target_names
[
0
]:
tensor_img
},
fetch_list
=
fetch_targets
)
assert
len
(
results
[
0
])
==
len
(
transpiler_results
[
0
])
for
i
in
range
(
len
(
results
[
0
])):
np
.
testing
.
assert_almost_equal
(
results
[
0
][
i
],
transpiler_results
[
0
][
i
],
decimal
=
6
)
print
(
"infer results: "
,
results
[
0
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录