提交 56528531 编写于 作者: T Tomasz Patejko

MKLDNN conv + elementwis_add fusion: initial work on passing eltwise data to conv primitive

上级 42f569fd
...@@ -386,8 +386,22 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -386,8 +386,22 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto user_weights_memory_p = handler.AcquireWeightsMemory( auto user_weights_memory_p = handler.AcquireWeightsMemory(
user_weights_md, to_void_cast<T>(filter_data)); user_weights_md, to_void_cast<T>(filter_data));
T* output_data =
T* output_data = nullptr;
if (fuse_eltwise) {
auto eltwise_param = ctx.Input<Tensor>("EltwiseParameter");
auto eltwise_param_data = eltwise_param->data<T>();
PADDLE_ENFORCE(eltwise_param_data != nullptr, "Provide data if you want MKLDNN conv+elementwise_add fusion");
PADDLE_ENFORCE_EQ(output->dims(), eltwise_param->dims(), "Output and elementwise parameter need to have the same dimension sizes");
output_data = const_cast<T*>(eltwise_param_data);
} else {
output_data =
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize()); output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
}
// create reorder primitive if the input format is not the preferred one // create reorder primitive if the input format is not the preferred one
auto src_memory_p = auto src_memory_p =
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
......
...@@ -132,6 +132,9 @@ void Conv2DOpMaker::Make() { ...@@ -132,6 +132,9 @@ void Conv2DOpMaker::Make() {
"(Tensor) The output tensor of convolution operator. " "(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW.") "The format of output tensor is also NCHW.")
.Reuse("Input"); .Reuse("Input");
AddInput("EltwiseParameter",
"(Tensor) Tensor to which convolution output will be added."
"Used on with fuse_eltwise fusion.");
AddAttr<std::vector<int>>("strides", AddAttr<std::vector<int>>("strides",
"(vector<int> default:{1, 1}), the " "(vector<int> default:{1, 1}), the "
"strides(h_stride, w_stride) of " "strides(h_stride, w_stride) of "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部