未验证 提交 1e52f324 编写于 作者: J jiangcheng 提交者: GitHub

Optimize elementwise_add_grad op, test=develop (#32051)

上级 36687d7a
......@@ -112,18 +112,39 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy) {
auto size = x->numel();
int vec_size = max(static_cast<int>(sizeof(float4) / sizeof(T)), 1);
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
dim3 grid_size =
dim3(((size + vec_size - 1) / vec_size + PADDLE_CUDA_THREAD_SIZE - 1) /
PADDLE_CUDA_THREAD_SIZE,
1);
SimpleElemwiseAddGradCUDAKernel<
T><<<grid_size, block_size, 0,
ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
dout->data<T>(), size, vec_size, dx->mutable_data<T>(ctx.GetPlace()),
dy->mutable_data<T>(ctx.GetPlace()));
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
auto* dout_data = dout->data<T>();
if (dx_data == dout_data && dy_data != dout_data) {
VLOG(4) << "Special case when dx_data is the same as dout_data, "
"only need copy dout to dy";
framework::TensorCopy(
*dout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dy);
} else if (dx_data != dout_data && dy_data == dout_data) {
VLOG(4) << "Special case when dy_data is the same as dout_data, "
"only need copy dout to dx";
framework::TensorCopy(
*dout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dx);
} else if (dx_data != dout_data && dy_data != dout_data) {
auto size = x->numel();
int vec_size = max(static_cast<int>(sizeof(float4) / sizeof(T)), 1);
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
dim3 grid_size =
dim3(((size + vec_size - 1) / vec_size + PADDLE_CUDA_THREAD_SIZE - 1) /
PADDLE_CUDA_THREAD_SIZE,
1);
SimpleElemwiseAddGradCUDAKernel<
T><<<grid_size, block_size, 0,
ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
dout->data<T>(), size, vec_size, dx->mutable_data<T>(ctx.GetPlace()),
dy->mutable_data<T>(ctx.GetPlace()));
} else {
VLOG(4) << "Special case when dy_data is the same as dout_data, "
"and dx_data is the same as dout_data, do not need "
"any operator";
}
}
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册