From 74ca3ae8810b50cc587bac25a9cb548ff6d3d1db Mon Sep 17 00:00:00 2001 From: bingyanghuang <33643817+bingyanghuang@users.noreply.github.com> Date: Wed, 13 Nov 2019 16:21:57 +0800 Subject: [PATCH] cherry-pick #21059, test=release/1.6 (#21153) --- .../elementwise/mkldnn/elementwise_add_mkldnn_op.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc index a3a5a031e97..742f10cc4bc 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc @@ -43,7 +43,6 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel { auto* z = ctx.Output("Out"); const T* x_data = x->data(); const T* y_data = y->data(); - T* z_data = z->mutable_data(ctx.GetPlace()); int axis = ctx.Attr("axis"); @@ -92,6 +91,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel { _x.ShareDataWith(*x); } + z->mutable_data(ctx.GetPlace()); auto sum_func = [](T a, T b) -> T { return a + b; }; TransformFunctor { auto sum_pd = handler.AcquireSumPrimitiveDescriptor( {src_x_memory, src_y_memory}, scales, dst_md); + T* z_data = z->mutable_data(ctx.GetPlace(), + sum_pd->dst_primitive_desc().get_size()); + auto dst_memory = handler.AcquireDstMemoryFromPrimitive(z_data); std::vector inputs({*src_x_memory, *src_y_memory}); -- GitLab