提交 74ca3ae8 编写于 作者: B bingyanghuang 提交者: Tao Luo

cherry-pick #21059, test=release/1.6 (#21153)

上级 e7d5e0ea
...@@ -43,7 +43,6 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> { ...@@ -43,7 +43,6 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
auto* z = ctx.Output<Tensor>("Out"); auto* z = ctx.Output<Tensor>("Out");
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
const T* y_data = y->data<T>(); const T* y_data = y->data<T>();
T* z_data = z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
...@@ -92,6 +91,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> { ...@@ -92,6 +91,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
_x.ShareDataWith(*x); _x.ShareDataWith(*x);
} }
z->mutable_data<T>(ctx.GetPlace());
auto sum_func = [](T a, T b) -> T { return a + b; }; auto sum_func = [](T a, T b) -> T { return a + b; };
TransformFunctor<decltype(sum_func), T, TransformFunctor<decltype(sum_func), T,
...@@ -155,6 +155,9 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> { ...@@ -155,6 +155,9 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
auto sum_pd = handler.AcquireSumPrimitiveDescriptor( auto sum_pd = handler.AcquireSumPrimitiveDescriptor(
{src_x_memory, src_y_memory}, scales, dst_md); {src_x_memory, src_y_memory}, scales, dst_md);
T* z_data = z->mutable_data<T>(ctx.GetPlace(),
sum_pd->dst_primitive_desc().get_size());
auto dst_memory = handler.AcquireDstMemoryFromPrimitive(z_data); auto dst_memory = handler.AcquireDstMemoryFromPrimitive(z_data);
std::vector<primitive::at> inputs({*src_x_memory, *src_y_memory}); std::vector<primitive::at> inputs({*src_x_memory, *src_y_memory});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册