提交 f6881971 作者: Michal Gallus 提交者: Tomasz Patejko

MKLDNN conv + elementwise_add fusion: Fix output_data to point to the right tensor, also fix transpiler integration

MKLDNN conv + elementwise_add fusion: Fix output_data to point to the right tensor, also fix transpiler integration
上级 efd76614
......@@ -399,8 +399,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"Output and elementwise parameter need to have the "
"same dimension sizes");
output_data = output->mutable_data<T>(ctx.GetPlace());
output->ShareDataWith(*residual_param);
output_data = output->mutable_data<T>(ctx.GetPlace());
} else {
output_data =
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
......
......@@ -92,7 +92,8 @@ class InferenceTranspiler(object):
if current_op.type in ['conv2d']:
next_op = self.block.ops[i + 1]
if next_op.type == 'elementwise_add':
self._fuse_conv_eltwise(current_op, next_op)
self._fuse_conv_eltwise(i, current_op, next_op)
self.block._remove_op(i + 1) # Remove old conv
self.block._remove_op(i + 1) # Remove elementwise_add
i = i + 1
self._adjust_input()
......@@ -444,7 +445,7 @@ class InferenceTranspiler(object):
outputs={"Output": out_var},
attrs=attrs)
def _fuse_conv_eltwise(self, conv_op, eltwise_op):
def _fuse_conv_eltwise(self, index, conv_op, eltwise_op):
'''
fuse the conv op with elementwise_add
......@@ -454,9 +455,26 @@ class InferenceTranspiler(object):
:type eltwise_op: Operator
'''
conv_op._set_attr("fuse_eltwise", True)
self.input_map[conv_op.output("Output")[0]] = eltwise_op.input("Y")[0]
self.input_map[eltwise_op.output("Out")[0]] = eltwise_op.input("Y")[0]
residual_var = self.block.var(eltwise_op.input("X")[0])
out_var = self.block.var(eltwise_op.output("Out")[0])
filter_var = self.block.var(conv_op.input("Filter")[0])
in_var = self.block.var(conv_op.input("Input")[0])
bias_var = self.block.var(conv_op.input("Bias")[0])
conv_op.set_attr("fuse_eltwise", True)
attrs = {name: conv_op.attr(name) for name in conv_op.attr_names}
self.block._insert_op(
index,
type="conv2d",
inputs={
"Input": in_var,
"Filter": filter_var,
"Bias": bias_var,
"ResidualData": residual_var
},
outputs={"Output": out_var},
attrs=attrs)
def _adjust_input(self):
for i in range(len(self.block.ops)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册