未验证 提交 1e52f324 编写于 作者: J jiangcheng 提交者: GitHub

Optimize elementwise_add_grad op, test=develop (#32051)

上级 36687d7a
...@@ -112,6 +112,22 @@ elementwise_add_grad(const framework::ExecutionContext& ctx, ...@@ -112,6 +112,22 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* out, const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx, const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy) { framework::Tensor* dy) {
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
auto* dout_data = dout->data<T>();
if (dx_data == dout_data && dy_data != dout_data) {
VLOG(4) << "Special case when dx_data is the same as dout_data, "
"only need copy dout to dy";
framework::TensorCopy(
*dout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dy);
} else if (dx_data != dout_data && dy_data == dout_data) {
VLOG(4) << "Special case when dy_data is the same as dout_data, "
"only need copy dout to dx";
framework::TensorCopy(
*dout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dx);
} else if (dx_data != dout_data && dy_data != dout_data) {
auto size = x->numel(); auto size = x->numel();
int vec_size = max(static_cast<int>(sizeof(float4) / sizeof(T)), 1); int vec_size = max(static_cast<int>(sizeof(float4) / sizeof(T)), 1);
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
...@@ -124,6 +140,11 @@ elementwise_add_grad(const framework::ExecutionContext& ctx, ...@@ -124,6 +140,11 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
ctx.template device_context<plat::CUDADeviceContext>().stream()>>>( ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
dout->data<T>(), size, vec_size, dx->mutable_data<T>(ctx.GetPlace()), dout->data<T>(), size, vec_size, dx->mutable_data<T>(ctx.GetPlace()),
dy->mutable_data<T>(ctx.GetPlace())); dy->mutable_data<T>(ctx.GetPlace()));
} else {
VLOG(4) << "Special case when dy_data is the same as dout_data, "
"and dx_data is the same as dout_data, do not need "
"any operator";
}
} }
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册