From 1e52f324318bfe31e6dc43bcf5eeb682a44ec5d7 Mon Sep 17 00:00:00 2001 From: jiangcheng Date: Sat, 3 Apr 2021 11:07:07 +0800 Subject: [PATCH] Optimize elementwise_add_grad op, test=develop (#32051) --- .../elementwise/elementwise_add_op.cu | 45 ++++++++++++++----- 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu index 68fd81f826..313607d975 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu @@ -112,18 +112,39 @@ elementwise_add_grad(const framework::ExecutionContext& ctx, const framework::Tensor* out, const framework::Tensor* dout, framework::Tensor* dx, framework::Tensor* dy) { - auto size = x->numel(); - int vec_size = max(static_cast(sizeof(float4) / sizeof(T)), 1); - dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); - dim3 grid_size = - dim3(((size + vec_size - 1) / vec_size + PADDLE_CUDA_THREAD_SIZE - 1) / - PADDLE_CUDA_THREAD_SIZE, - 1); - SimpleElemwiseAddGradCUDAKernel< - T><<().stream()>>>( - dout->data(), size, vec_size, dx->mutable_data(ctx.GetPlace()), - dy->mutable_data(ctx.GetPlace())); + auto* dx_data = dx->mutable_data(ctx.GetPlace()); + auto* dy_data = dy->mutable_data(ctx.GetPlace()); + auto* dout_data = dout->data(); + if (dx_data == dout_data && dy_data != dout_data) { + VLOG(4) << "Special case when dx_data is the same as dout_data, " + "only need copy dout to dy"; + framework::TensorCopy( + *dout, ctx.GetPlace(), + ctx.template device_context(), dy); + } else if (dx_data != dout_data && dy_data == dout_data) { + VLOG(4) << "Special case when dy_data is the same as dout_data, " + "only need copy dout to dx"; + framework::TensorCopy( + *dout, ctx.GetPlace(), + ctx.template device_context(), dx); + } else if (dx_data != dout_data && dy_data != dout_data) { + auto size = x->numel(); + int vec_size = max(static_cast(sizeof(float4) / sizeof(T)), 1); + dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); + dim3 grid_size = + dim3(((size + vec_size - 1) / vec_size + PADDLE_CUDA_THREAD_SIZE - 1) / + PADDLE_CUDA_THREAD_SIZE, + 1); + SimpleElemwiseAddGradCUDAKernel< + T><<().stream()>>>( + dout->data(), size, vec_size, dx->mutable_data(ctx.GetPlace()), + dy->mutable_data(ctx.GetPlace())); + } else { + VLOG(4) << "Special case when dy_data is the same as dout_data, " + "and dx_data is the same as dout_data, do not need " + "any operator"; + } } } // namespace operators -- GitLab