提交 cd32ddac 编写于 作者: M Michał Gallus 提交者: Tao Luo

Fuse Convolution and Eltwise Add into MKLDNN's Conv+Bias (#12669)

* Fuse Convolution and Eltwise Add into Conv+Bias

* Reduce bias branching at conv_mkldnn_op

* Add MKLDNN build checks for Conv Bias

* Conv-bias: check if bias input exist befor assignment

* Conv-bias: Remove Bias dim check from infershape

It was causing conv3d test to crash upon\ncalling HasInput(Bias)
上级 896a37b6
......@@ -126,6 +126,15 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_bias_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto user_bias_pd = user_bias_memory_p->get_primitive_desc();
auto bias_pd = conv_pd_->bias_primitive_desc();
return this->AcquireMemory(bias_pd, user_bias_pd, user_bias_memory_p,
"@bias_mem_p", pipeline);
}
std::shared_ptr<mkldnn::convolution_forward> AcquireConvolution(
std::shared_ptr<mkldnn::memory> src_memory_p,
std::shared_ptr<mkldnn::memory> weights_memory_p,
......@@ -147,6 +156,28 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
return conv_p;
}
std::shared_ptr<mkldnn::convolution_forward> AcquireConvolution(
std::shared_ptr<mkldnn::memory> src_memory_p,
std::shared_ptr<mkldnn::memory> weights_memory_p,
std::shared_ptr<mkldnn::memory> bias_memory_p,
std::shared_ptr<mkldnn::memory> dst_memory_p) {
auto prim_key = key_ + "@conv_p";
auto conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution primitive in device context");
if (conv_p == nullptr) {
conv_p = std::make_shared<mkldnn::convolution_forward>(
*conv_pd_, *(src_memory_p), *(weights_memory_p.get()),
*(bias_memory_p.get()), *(dst_memory_p.get()));
dev_ctx_.SetBlob(prim_key, conv_p);
} else {
is_reusing_ = true;
}
return conv_p;
}
std::shared_ptr<mkldnn::convolution_backward_weights>
AcquireConvolutionBackwardWeights(
std::shared_ptr<mkldnn::memory> src_memory_p,
......@@ -229,6 +260,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto* input = ctx.Input<Tensor>("Input");
auto* filter = ctx.Input<Tensor>("Filter");
auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
auto* output = ctx.Output<Tensor>("Output");
PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN &&
......@@ -237,6 +269,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN &&
filter->format() != memory::format::format_undef,
"Wrong layout/format set for Filter tensor");
PADDLE_ENFORCE(input->dims().size() == 4,
"Input must be with 4 dimensions, i.e. NCHW");
PADDLE_ENFORCE(filter->dims().size() == 4,
"Filter must be with 4 dimensions, i.e. OIHW");
if (bias) {
PADDLE_ENFORCE(bias->layout() == DataLayout::kMKLDNN &&
bias->format() != memory::format::format_undef,
"Wrong layout/format set for Bias tensor");
PADDLE_ENFORCE(bias->dims().size() == 1,
"Bias must only have 1 dimension, i.e. X");
}
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
......@@ -253,11 +296,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const T* filter_data = filter->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE(input->dims().size() == 4,
"Input must be with 4 dimensions, i.e. NCHW");
PADDLE_ENFORCE(filter->dims().size() == 4,
"Filter must be with 4 dimensions, i.e. OIHW");
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> weights_tz =
paddle::framework::vectorize2int(filter->dims());
......@@ -288,13 +326,23 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
std::vector<int> bias_tz; // TODO(mgallus): avoid empty vector creation.
// Currently used whenever bias is != nullptr.
auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
// create a conv primitive descriptor and save it for usage in backward
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
mkldnn_engine);
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
if (bias) {
bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc(
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine);
} else {
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
paddings, mkldnn_engine);
}
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd);
......@@ -315,8 +363,22 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
// create convolution op primitive
auto conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
std::shared_ptr<mkldnn::convolution_forward> conv_p;
if (bias) {
const T* bias_data = bias->data<T>();
auto user_bias_md = platform::MKLDNNMemDesc(
{bias_tz}, platform::MKLDNNGetDataType<T>(), memory::format::x);
auto user_bias_memory_p =
handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data));
auto bias_memory_p =
handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
bias_memory_p, dst_memory_p);
} else {
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
dst_memory_p);
}
// push primitive to stream and wait until it's executed
pipeline.push_back(*conv_p);
......@@ -346,6 +408,27 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd);
}
std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
const memory::desc& bias, const memory::desc& dst,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
auto conv_desc = mkldnn::convolution_forward::desc(
mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
bias, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
auto p_conv_pd =
new mkldnn::convolution_forward::primitive_desc(conv_desc, engine);
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd);
}
};
template <typename T>
......
......@@ -37,6 +37,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
int groups = ctx->Attrs().Get<int>("groups");
......@@ -57,7 +58,6 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[1] * groups,
"The number of input channels should be equal to filter "
"channels * groups.");
PADDLE_ENFORCE_EQ(
filter_dims[0] % groups, 0,
"The number of output channels should be divided by groups.");
......@@ -122,6 +122,11 @@ void Conv2DOpMaker::Make() {
"H is the height of the filter, and W is the width of the filter. "
"If the groups attribute is greater than 1, C equals the number of "
"input image channels divided by the groups.");
AddInput("Bias",
"(Tensor) Bias to be added to each output of filter application."
"The format of output tensor is X (one-dimensional) of size equal"
"to the number of output channels. Only used with MKL-DNN.")
.AsDispensable();
AddOutput("Output",
"(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW.")
......
......@@ -125,6 +125,11 @@ class MKLDNNHandler {
return this->AcquireMemory(md, ptr, "@user_weights_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireBiasMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_bias_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDstMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_dst_mem_p");
......
......@@ -59,8 +59,12 @@ class InferenceTranspiler(object):
scope = global_scope()
if not isinstance(scope, core.Scope):
raise TypeError("scope should be as Scope type or None")
self._fuse_batch_norm(program, place, scope)
use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False))
if use_mkldnn:
self._fuse_relu_mkldnn(program)
self._fuse_conv_bias_mkldnn(program)
else:
self._fuse_batch_norm(program, place, scope)
def _fuse_relu_mkldnn(self, program):
'''
......@@ -82,10 +86,6 @@ class InferenceTranspiler(object):
:param program: program to transpile
:type program: Program
'''
use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False))
if not use_mkldnn:
return
self.block = program.block(0)
i = 0
......@@ -106,6 +106,69 @@ class InferenceTranspiler(object):
# And a better solution will be considered later.
program = program.clone()
def _fuse_conv_bias_mkldnn(self, program):
'''
Transpile the program by fused convolution and elementwise_add.
Replace conv2d and elementwise_add ops with a new conv2d op
based on an old conv2d op and the :math:`Bias` taken from
elementwise_add.
For input :math:`X`:
- Conv process: :math:`X = input * W`
- Elementwise_add process: :math` X = X + bias`
After fuse into one operation:
.. math::
X = input * W + bias
The operator transformation is:
- before:
- conv->elementwise_add->any_other_op
- after:
- conv->any_other_op
The transpile stages are:
1. Extract bias and output variables from elementwise_add.
2. Extract Input, Weight and attributes from conv op.
3. Create a new convolution op based on extracted params.
4. Remove old conv op.
5. Remove elementwise_add.
5. Remove unused variables.
Args:
program (Program): program to transpile
'''
self.block = program.block(0)
i = 0
while i < len(self.block.ops) - 2:
current_op = self.block.ops[i]
next_op = self.block.ops[i + 1]
# conv2d with bias
if current_op.type in ['conv2d'] and \
next_op.type in ['elementwise_add']:
self._fuse_conv_bias(i, current_op, next_op)
self.block._remove_op(i + 1) # Remove old conv
self.block._remove_op(i + 1) # Remove elementwise_add
i = i + 1
i = i + 1
self._remove_unused_var()
# TODO(luotao): use clone() method to flush the program.desc in force,
# since some large program.desc will not be flushed immediately.
# And a better solution will be considered later.
program = program.clone()
def _fuse_batch_norm(self, program, place, scope):
'''
Transpile the program by fused batch normalization.
......@@ -185,7 +248,6 @@ class InferenceTranspiler(object):
self.block._remove_op(i + 2)
i = i + 1
i = i + 1
self._adjust_input()
self._remove_unused_var()
# TODO(luotao): use clone() method to flush the program.desc in force,
......@@ -288,6 +350,33 @@ class InferenceTranspiler(object):
# collect the renamed input
self.input_map[bn_op.output("Y")[0]] = bias_op.output("Out")[0]
def _fuse_conv_bias(self, index, conv_op, elementwise_add_op):
'''
fuse the conv op with elementwise_add
:param index: index of the conv_op in ops list
:type index: Int
:param conv_op: convolution operator
:type conv_op: Operator
:param elementwise_add_op: convolution's bias operator
:type elementwise_add_op: Operator
'''
bias_var = self.block.var(elementwise_add_op.input("Y")[0])
out_var = self.block.var(elementwise_add_op.output("Out")[0])
filter_var = self.block.var(conv_op.input("Filter")[0])
in_var = self.block.var(conv_op.input("Input")[0])
attrs = {name: conv_op.attr(name) for name in conv_op.attr_names}
self.block._insert_op(
index,
type="conv2d",
inputs={"Input": in_var,
"Filter": filter_var,
"Bias": bias_var},
outputs={"Output": out_var},
attrs=attrs)
def _adjust_input(self):
for i in range(len(self.block.ops)):
current_op = self.block.ops[i]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册