diff --git a/paddle/fluid/operators/elementwise_add_op.cu b/paddle/fluid/operators/elementwise_add_op.cu
index dfff518f170b56d180b6883c363effb8dbd677b6..6cbf6066c92b02eb75922587f9da5192bed15580 100644
--- a/paddle/fluid/operators/elementwise_add_op.cu
+++ b/paddle/fluid/operators/elementwise_add_op.cu
@@ -16,6 +16,60 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise_add_op.h"
 #include "paddle/fluid/platform/float16.h"
 
+namespace paddle {
+namespace operators {
+
+template <typename T>
+__global__ void ElementwiseAddCUDAKernel(const T *x, const T *y, T *z, int n,
+                                         int post, int size) {
+  int idx_x = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx_x < size) {
+    int idx_y = idx_x / post - (idx_x / (n * post)) * n;
+    z[idx_x] = x[idx_x] + y[idx_y];
+  }
+}
+
+template <typename T>
+class ElementwiseAddKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    using Tensor = framework::Tensor;
+
+    const auto x = ctx.Input<Tensor>("X");
+    const auto y = ctx.Input<Tensor>("Y");
+    auto z = ctx.Output<Tensor>("Out");
+    auto *z_data = z->mutable_data<T>(ctx.GetPlace());
+
+    auto &device = *(ctx.cuda_device_context().eigen_device());
+    const framework::DDim &x_dim = x->dims();
+    framework::DDim y_dim = y->dims();
+    int size = x->numel();
+    if (x_dim == y_dim) {
+      auto dim = framework::make_ddim({size});
+      auto z_eigen = framework::EigenTensor<T, 1>::From(*z, dim);
+      auto x_eigen = framework::EigenTensor<T, 1>::From(*x, dim);
+      auto y_eigen = framework::EigenTensor<T, 1>::From(*y, dim);
+      z_eigen.device(device) = x_eigen + y_eigen;
+    } else {
+      int axis = ctx.Attr<int>("axis");
+      axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis);
+      y_dim = trim_trailing_singular_dims(y_dim);
+      axis = (y_dim.size() == 0) ? x_dim.size() : axis;
+      int pre, n, post;
+      get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
+      int threads = 512;
+      int grids = (size + threads - 1) / threads;
+      auto stream = ctx.cuda_device_context().stream();
+      ElementwiseAddCUDAKernel<T><<<grids, threads, 0, stream>>>(
+          x->data<T>(), y->data<T>(), z_data, n, post, size);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
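A note on the forward kernel above: x is addressed as a flat array but logically viewed as {pre, n, post}, with y of length n broadcast along the outer and inner dimensions, so idx_x / post - (idx_x / (n * post)) * n is just (idx_x / post) % n, the middle coordinate that indexes y. The host-only sketch below (illustrative, not part of the patch; the shape constants are made up) checks that arithmetic against a naive triple loop:

// Standalone sketch of the index mapping used by ElementwiseAddCUDAKernel:
// x is viewed as {pre, n, post}, y has shape {n}, and each flat output
// index idx_x reads y at (idx_x / post) % n.
#include <cassert>
#include <vector>

int main() {
  const int pre = 2, n = 3, post = 4;  // illustrative shape only
  const int size = pre * n * post;

  std::vector<float> x(size), y(n), z(size);
  for (int i = 0; i < size; ++i) x[i] = static_cast<float>(i);
  for (int j = 0; j < n; ++j) y[j] = 100.0f * j;

  // Same arithmetic as the CUDA kernel, run sequentially on the host.
  for (int idx_x = 0; idx_x < size; ++idx_x) {
    int idx_y = idx_x / post - (idx_x / (n * post)) * n;  // == (idx_x / post) % n
    z[idx_x] = x[idx_x] + y[idx_y];
  }

  // Cross-check against the naive triple loop over {pre, n, post}.
  for (int i = 0; i < pre; ++i)
    for (int j = 0; j < n; ++j)
      for (int k = 0; k < post; ++k) {
        int idx = (i * n + j) * post + k;
        assert(z[idx] == x[idx] + y[j]);
      }
  return 0;
}
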
diff --git a/paddle/fluid/operators/elementwise_add_op.h b/paddle/fluid/operators/elementwise_add_op.h
index 5356105e2e551c0528694091608fc7585dce66d2..0b19723720171a857c946880c246e2247a0023a7 100644
--- a/paddle/fluid/operators/elementwise_add_op.h
+++ b/paddle/fluid/operators/elementwise_add_op.h
@@ -144,16 +144,41 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    // skip out, x, y
-    auto* out = dout;
-    auto *x = dout, *y = dout;
 
-    if (platform::is_cpu_place(ctx.GetPlace()) && dx != nullptr &&
-        dy != nullptr && (dx->dims() == dy->dims())) {
-      elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
+    if (dx != nullptr) {
+      // dx could simply share dout's memory here, but doing so may trigger
+      // a bug in the memory optimizer
+      // dx->ShareDataWith(*dout);
+      framework::TensorCopy(*dout, ctx.GetPlace(),
+                            ctx.template device_context<DeviceContext>(), dx);
+    }
+
+    if (dy == nullptr) return;
+
+    const framework::DDim& x_dim = dout->dims();
+    framework::DDim y_dim = dy->dims();
+    if (x_dim == y_dim) {
+      // dy->ShareDataWith(*dout);
+      framework::TensorCopy(*dout, ctx.GetPlace(),
+                            ctx.template device_context<DeviceContext>(), dy);
     } else {
-      default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
-                                                     dy);
+      dy->mutable_data<T>(ctx.GetPlace());
+      // Reduce dout along the broadcast dims to calculate dy
+      int axis = ctx.Attr<int>("axis");
+      axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis);
+      y_dim = trim_trailing_singular_dims(y_dim);
+      axis = (y_dim.size() == 0) ? x_dim.size() : axis;
+
+      auto& device =
+          *(ctx.template device_context<DeviceContext>().eigen_device());
+      int pre, n, post;
+      get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
+      auto eigen_dout = framework::EigenTensor<T, 3>::From(
+          *dout, framework::make_ddim({pre, n, post}));
+      auto eigen_dy =
+          framework::EigenTensor<T, 1>::From(*dy, framework::make_ddim({n}));
+      eigen_dy.device(device) = eigen_dout.sum(
+          framework::EigenDim<2>::From(framework::make_ddim({0, 2})));
     }
   }
 };
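On the backward path: when Y was broadcast, the new code computes dy by summing dout over the broadcast dimensions, i.e. with dout viewed as {pre, n, post} the Eigen expression sums over dims 0 and 2 and leaves a length-n result. A plain-loop sketch of the same reduction (illustrative, not part of the patch; shape constants are made up):

// Standalone sketch of the reduction the Eigen expression performs:
// with dout viewed as {pre, n, post}, the broadcast gradient is
// dy[j] = sum over i and k of dout[i][j][k].
#include <cstdio>
#include <vector>

int main() {
  const int pre = 2, n = 3, post = 4;             // illustrative shape only
  std::vector<float> dout(pre * n * post, 1.0f);  // pretend upstream grad is all ones
  std::vector<float> dy(n, 0.0f);

  // Equivalent of eigen_dout.sum over dims {0, 2}, done with plain loops.
  for (int i = 0; i < pre; ++i)
    for (int j = 0; j < n; ++j)
      for (int k = 0; k < post; ++k)
        dy[j] += dout[(i * n + j) * post + k];

  // Each y[j] was broadcast to pre * post output elements, so with an
  // all-ones dout every dy[j] should equal pre * post == 8.
  for (int j = 0; j < n; ++j) std::printf("dy[%d] = %g\n", j, dy[j]);
  return 0;
}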