diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc index a3a5a031e971338ba59c88c4bd303f60fee55889..742f10cc4bc54eba45259476254e7a93ad2ff63f 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc @@ -43,7 +43,6 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel { auto* z = ctx.Output("Out"); const T* x_data = x->data(); const T* y_data = y->data(); - T* z_data = z->mutable_data(ctx.GetPlace()); int axis = ctx.Attr("axis"); @@ -92,6 +91,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel { _x.ShareDataWith(*x); } + z->mutable_data(ctx.GetPlace()); auto sum_func = [](T a, T b) -> T { return a + b; }; TransformFunctor { auto sum_pd = handler.AcquireSumPrimitiveDescriptor( {src_x_memory, src_y_memory}, scales, dst_md); + T* z_data = z->mutable_data(ctx.GetPlace(), + sum_pd->dst_primitive_desc().get_size()); + auto dst_memory = handler.AcquireDstMemoryFromPrimitive(z_data); std::vector inputs({*src_x_memory, *src_y_memory});