From af3ff422cc41447dd473dfd6651fdbb6b075495d Mon Sep 17 00:00:00 2001 From: Wojciech Uss Date: Fri, 8 Nov 2019 06:21:31 +0100 Subject: [PATCH] Fix dst memory allocation in elementwise_add (#21059) test=develop --- .../elementwise/mkldnn/elementwise_add_mkldnn_op.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc index a3a5a031e9..742f10cc4b 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc @@ -43,7 +43,6 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel { auto* z = ctx.Output("Out"); const T* x_data = x->data(); const T* y_data = y->data(); - T* z_data = z->mutable_data(ctx.GetPlace()); int axis = ctx.Attr("axis"); @@ -92,6 +91,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel { _x.ShareDataWith(*x); } + z->mutable_data(ctx.GetPlace()); auto sum_func = [](T a, T b) -> T { return a + b; }; TransformFunctor { auto sum_pd = handler.AcquireSumPrimitiveDescriptor( {src_x_memory, src_y_memory}, scales, dst_md); + T* z_data = z->mutable_data(ctx.GetPlace(), + sum_pd->dst_primitive_desc().get_size()); + auto dst_memory = handler.AcquireDstMemoryFromPrimitive(z_data); std::vector inputs({*src_x_memory, *src_y_memory}); -- GitLab