提交 e63013a8 编写于 作者: C chengduoZH

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/add_reduce_op_handle

......@@ -91,7 +91,6 @@ void ReduceOpHandle::RunImpl() {
if (paddle::platform::is_cpu_place(pre_place)) {
ReduceLoDTensor func(lod_tensors, trg);
VisitDataType(ToDataType(lod_tensors[0].type()), func);
} else if (paddle::platform::is_gpu_place(pre_place)) {
#ifdef PADDLE_WITH_CUDA
auto out_p = out_var_handles[0]->place_;
......
......@@ -72,10 +72,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto dst_md = platform::MKLDNNMemDesc(
dst_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
auto src_memory = mkldnn::memory({src_md, mkldnn_engine},
reinterpret_cast<void*>(input_data));
auto weights_memory = mkldnn::memory({weights_md, mkldnn_engine},
reinterpret_cast<void*>(filter_data));
auto src_memory =
mkldnn::memory({src_md, mkldnn_engine},
reinterpret_cast<void*>(const_cast<T*>(input_data)));
auto weights_memory =
mkldnn::memory({weights_md, mkldnn_engine},
reinterpret_cast<void*>(const_cast<T*>(filter_data)));
auto dst_memory = mkldnn::memory({dst_md, mkldnn_engine}, output_data);
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
......@@ -180,9 +182,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
dst_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
// create memory
auto diff_dst_memory =
mkldnn::memory({diff_weights_md, mkldnn_engine},
reinterpret_cast<void*>(output_grad_data));
auto diff_dst_memory = mkldnn::memory(
{diff_weights_md, mkldnn_engine},
reinterpret_cast<void*>(const_cast<T*>(output_grad_data)));
// Retrieve conv_pd from device context
auto conv_pd =
std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
......@@ -202,8 +204,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto diff_weights_memory =
mkldnn::memory({diff_weights_md, mkldnn_engine},
reinterpret_cast<void*>(filter_grad_data));
auto src_memory = mkldnn::memory({src_md, mkldnn_engine},
reinterpret_cast<void*>(input_data));
auto src_memory =
mkldnn::memory({src_md, mkldnn_engine},
reinterpret_cast<void*>(const_cast<T*>(input_data)));
// create backward conv primitive for weights
auto conv_bwd_weights_prim = mkldnn::convolution_backward_weights(
......@@ -222,11 +225,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
strides, paddings, *conv_pd, mkldnn_engine);
// create memory
auto diff_src_memory =
mkldnn::memory({diff_src_md, mkldnn_engine},
reinterpret_cast<void*>(input_grad_data));
auto weights_memory = mkldnn::memory(
{weights_md, mkldnn_engine}, reinterpret_cast<void*>(filter_data));
auto diff_src_memory = mkldnn::memory(
{diff_src_md, mkldnn_engine},
reinterpret_cast<void*>(const_cast<T*>(input_grad_data)));
auto weights_memory =
mkldnn::memory({weights_md, mkldnn_engine},
reinterpret_cast<void*>(const_cast<T*>(filter_data)));
// create backward conv primitive for data
auto conv_bwd_data_prim = mkldnn::convolution_backward_data(
......
......@@ -73,6 +73,15 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
softmax_dst_memory);
std::vector<primitive> pipeline{softmax};
stream(stream::kind::eager).submit(pipeline).wait();
const bool is_test = ctx.Attr<bool>("is_test");
if (!is_test) {
T threshold = exp(-64);
for (size_t i = 0; i < dst_tz[0] * dst_tz[1]; ++i) {
output_data[i] =
output_data[i] < threshold ? threshold : output_data[i];
}
}
}
};
......
......@@ -97,6 +97,9 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("is_test",
"Disable epsilon adding to softmax results. Used by MKLDNN.")
.SetDefault(false);
AddComment(R"DOC(
Softmax Operator.
......
......@@ -37,6 +37,7 @@ from distribute_transpiler import DistributeTranspiler
from distribute_transpiler_simple import SimpleDistributeTranspiler
from concurrency import (Go, make_channel, channel_send, channel_recv,
channel_close, Select)
from inference_transpiler import InferenceTranspiler
import clip
from memory_optimization_transpiler import memory_optimize, release_memory
import profiler
......@@ -66,6 +67,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + [
'clip',
'SimpleDistributeTranspiler',
'DistributeTranspiler',
'InferenceTranspiler',
'memory_optimize',
'release_memory',
'profiler',
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from framework import Program
from executor import global_scope
from . import core
class InferenceTranspiler:
def transpile(self, program, place, scope=None):
'''
Transpile the program. Support only fuse batch normalization now.
:param program: program to transpile
:type program: Program
:param place: inference place
:type place: Place
:param scope: inference scope
:type scope: Scope or None
'''
if not isinstance(program, Program):
raise TypeError("program should be as Program type")
if not isinstance(place, core.CPUPlace) and not isinstance(
place, core.CUDAPlace):
raise TypeError("place should be as CPUPlace/CUDAPlace type")
if scope is None:
scope = global_scope()
if not isinstance(scope, core.Scope):
raise TypeError("scope should be as Scope type or None")
self.fuse_batch_norm(program, place, scope)
def fuse_batch_norm(self, program, place, scope):
'''
Transpile the program by fused batch normalization.
The batch normalization followed the convolution or fully connected layer
can be integrated with them. Doing so will give us a forward acceleration,
especially in environments like mobile or embedded.
For input X:
- Conv process: X = input * W + bias
- Batch norm process: X' = (X - mean) / std
- Scale Process: Y = a * X' + b
After fuse into one operation:
Y = (input * W + bias - mean) / std * a + b
= input * a * W / std + ((bias - mean) / std * a + b)
The operator transformation is:
- before:
- conv->batch_norm->any_other_op (bias == 0)
- conv->elementwise_add->batch_norm->any_other_op (bias != 0)
- after:
- conv->elementwise_add->any_other_op
The transpile stages are:
1. insert elementwise_add op when bias == 0.
2. fuse the batch_norm's parameters to conv and elementwise_add operators.
3. remove batch_norm ops which are not used in any other ops.
4. adjust the input of any_other_op to be the output of elementwise_add operator.
5. remove unused variables.
:param program: program to transpile
:type program: Program
:param place: inference place
:type place: Place
:param scope: inference scope
:type scope: Scope
'''
self.scope = scope
self.place = place
self.block = program.block(0)
self.input_map = {} # store the input names should be adjusted
i = 0
while i < len(self.block.ops):
current_op = self.block.ops[i]
# TODO(luotao1): consider only conv2d now. fc would be delt later.
if current_op.type in ['conv2d']:
# TODO(luotao1): consider single chain network now.
# For branch network, we counldn't use block.ops[i + 1] as
# the judgment condition.
next_op = self.block.ops[i + 1]
# conv2d without bias
if (next_op.type == 'batch_norm'):
# insert bias op
bias_op = self._insert_bias_op(i + 1, current_op, next_op)
# fuse batch_norm
self._fuse_param(current_op, next_op, bias_op, 0)
# remove batch_norm_op
self.block.remove_op(i + 2)
i = i + 1
# conv2d with bias, the next_op.type is elementwise_add
elif (next_op.type == 'elementwise_add'):
next_next_op = self.block.ops[i + 2]
if (next_next_op.type == 'batch_norm'):
# fuse batch_norm
self._fuse_param(current_op, next_next_op, next_op, 1)
# remove batch_norm_op
self.block.remove_op(i + 2)
i = i + 1
i = i + 1
self._adjust_input()
self._remove_unused_var()
# TODO(luotao): use clone() method to flush the program.desc in force,
# since some large program.desc will not be flushed immediately.
# And a better solution will be considered later.
program = program.clone()
# ====================== private transpiler functions =====================
def _insert_bias_op(self, index, current_op, bn_op):
'''
Construct elementwise_add operator for adding bias
and insert it into program.
:param index: insert location of bias_op
:type index: Int
:param current_op: current operator (conv or fc)
:type current_op: Operator
:param bn_op: batch norm operator
:type bn_op: Operator
:return: bias_op
:rtype: Operator
'''
# The input of bias_op is current_op's output and Bias of bn_op
# The output of bias_op is bn_op's output
x_var = self.block.var(current_op.output("Output")[0])
y_var = self.block.var(bn_op.input("Bias")[0])
out_var = self.block.var(bn_op.output("Y")[0])
bias_op = self.block.insert_op(
index,
type="elementwise_add",
inputs={"X": x_var,
"Y": y_var},
outputs={"Out": out_var},
attrs={"axis": 1}) # dim_start=1
return bias_op
def _fuse_param(self, current_op, bn_op, bias_op, with_bias):
'''
fuse the batch_norm_op' parameters to current_op (conv or fc)
:param current_op: current operator (conv or fc)
:type current_op: Operator
:param bn_op: batch norm operator
:type bn_op: Operator
:param bias_op: elementwise_add operator for adding bias
:type bias_op: Operator
:param with_bias: If current operator has bias, with_bias = 1; otherwise 0.
:type with_bias: Int
'''
def _update_param(op, old_param_name, new_param):
# For the sake of remaining the original variables the same as before,
# create new variables in scope to store the new parameters.
old_param_name = old_param_name[0]
old_var = self.block.vars[old_param_name]
new_param_name = old_param_name + '_fuse_bn'
new_var = self.block.create_parameter(
name=new_param_name.encode('ascii'),
type=old_var.type,
dtype=old_var.dtype,
shape=old_var.shape)
op.rename_input(old_param_name, new_param_name)
self.scope.var(new_param_name)
tensor = self.scope.find_var(new_param_name).get_tensor()
tensor.set(np.array(new_param), self.place)
def _load_param(param_name):
return np.array(self.scope.find_var(param_name[0]).get_tensor())
bias_bn = _load_param(bn_op.input("Bias")) #Bias
scale_bn = _load_param(bn_op.input("Scale")) #Scale
mean_bn = _load_param(bn_op.input("Mean")) #Mean
var_bn = _load_param(bn_op.input("Variance")) #Variance
# TODO(luotao1): consider only conv2d now. fc would be delt later.
current_param = _load_param(current_op.input("Filter"))
std_bn = np.float32(np.sqrt(np.add(var_bn, 1e-5)))
tmp = np.float32(np.divide(scale_bn, std_bn))
# add bias of batch_norm_op to conv2d
if with_bias:
bias = _load_param(bias_op.input("Y"))
else:
bias = np.zeros(bias_bn.shape)
bias = np.float32(
np.add(np.multiply(np.subtract(bias, mean_bn), tmp), bias_bn))
# re-compute weight of conv2d
tmp = tmp.reshape(tmp.shape[0], -1)
dst_param = current_param.reshape((tmp.shape[0], -1))
dst_param = np.float32(np.multiply(dst_param, tmp))
dst_param = dst_param.reshape(current_param.shape)
# update parameters
_update_param(current_op, current_op.input("Filter"), dst_param)
_update_param(bias_op, bias_op.input("Y"), bias)
# collect the renamed input
self.input_map[bn_op.output("Y")[0]] = bias_op.output("Out")[0]
def _adjust_input(self):
for i in range(len(self.block.ops)):
current_op = self.block.ops[i]
for input_arg in current_op.input_arg_names:
if input_arg in self.input_map:
current_op.rename_input(input_arg,
self.input_map[input_arg])
def _remove_unused_var(self):
'''
remove unused varibles in program
'''
args = []
for i in range(len(self.block.ops)):
current_op = self.block.ops[i]
args += current_op.input_arg_names
args += current_op.output_arg_names
args = list(set(args)) # unique the input and output arguments
for var in self.block.vars.keys():
if var not in args:
self.block.remove_var(var)
......@@ -88,6 +88,7 @@ def fc(input,
bias_attr=None,
use_mkldnn=False,
act=None,
is_test=False,
name=None):
"""
**Fully Connected Layer**
......@@ -134,6 +135,7 @@ def fc(input,
bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias
of this layer. If it is set to None, no bias will be added to the output units.
act (str, default None): Activation to be applied to the output of this layer.
is_test(bool): A flag indicating whether execution is in test phase.
use_mkldnn(bool): Use mkldnn kernel or not, it is valid only when the mkldnn
library is installed. Default: False
name (str, default None): The name of this layer.
......@@ -177,8 +179,11 @@ def fc(input,
inputs={"Input": input,
"W": w},
outputs={"Out": tmp},
attrs={"use_mkldnn": use_mkldnn,
"bias_attr": bias_attr})
attrs={
"use_mkldnn": use_mkldnn,
"is_test": is_test,
"bias_attr": bias_attr
})
return helper.append_activation(tmp)
else:
for input_var, param_attr in helper.iter_inputs_and_params():
......
......@@ -22,10 +22,17 @@ import sys
import numpy
import unittest
import os
import numpy as np
def resnet_cifar10(input, depth=32):
def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu'):
def conv_bn_layer(input,
ch_out,
filter_size,
stride,
padding,
act='relu',
bias_attr=False):
tmp = fluid.layers.conv2d(
input=input,
filter_size=filter_size,
......@@ -33,7 +40,7 @@ def resnet_cifar10(input, depth=32):
stride=stride,
padding=padding,
act=None,
bias_attr=False)
bias_attr=bias_attr)
return fluid.layers.batch_norm(input=tmp, act=act)
def shortcut(input, ch_in, ch_out, stride):
......@@ -44,7 +51,7 @@ def resnet_cifar10(input, depth=32):
def basicblock(input, ch_in, ch_out, stride):
tmp = conv_bn_layer(input, ch_out, 3, stride, 1)
tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None)
tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None, bias_attr=True)
short = shortcut(input, ch_in, ch_out, stride)
return fluid.layers.elementwise_add(x=tmp, y=short, act='relu')
......@@ -219,11 +226,26 @@ def infer(use_cuda, save_dirname=None):
batch_size = 1
tensor_img = numpy.random.rand(batch_size, 3, 32, 32).astype("float32")
# Use inference_transpiler to speedup
inference_transpiler_program = inference_program.clone()
t = fluid.InferenceTranspiler()
t.transpile(inference_transpiler_program, place)
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
results = exe.run(inference_program,
feed={feed_target_names[0]: tensor_img},
fetch_list=fetch_targets)
transpiler_results = exe.run(inference_transpiler_program,
feed={feed_target_names[0]: tensor_img},
fetch_list=fetch_targets)
assert len(results[0]) == len(transpiler_results[0])
for i in range(len(results[0])):
np.testing.assert_almost_equal(
results[0][i], transpiler_results[0][i], decimal=6)
print("infer results: ", results[0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册