未验证 提交 08ee7485 编写于 作者: L Leo Chen 提交者: GitHub

Refine the reshape grad and double-grad kernels to use asynchronous tensor copy (#29128) (#29446)

上级 14cf420e
......@@ -405,7 +405,9 @@ class ReshapeGradKernel {
auto in_dims = d_x->dims();
d_x->mutable_data(ctx.GetPlace(), d_out->type());
-framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
+framework::TensorCopy(
+*d_out, ctx.GetPlace(),
+ctx.template device_context<platform::DeviceContext>(), d_x);
d_x->Resize(in_dims);
}
};
......@@ -419,7 +421,9 @@ class ReshapeDoubleGradKernel {
auto out_dims = dd_out->dims();
dd_out->mutable_data(ctx.GetPlace(), dd_x->type());
-framework::TensorCopySync(*dd_x, ctx.GetPlace(), dd_out);
+framework::TensorCopy(
+*dd_x, ctx.GetPlace(),
+ctx.template device_context<platform::DeviceContext>(), dd_out);
dd_out->Resize(out_dims);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册