From b2d0ee5159bb824e07b30dde12edaf941f8036cb Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Mon, 6 Aug 2018 09:42:56 +0000 Subject: [PATCH] refine elementwise_add op --- paddle/fluid/operators/elementwise_add_op.cu | 54 ++++++++++++++++++++ paddle/fluid/operators/elementwise_add_op.h | 29 +++++++++-- 2 files changed, 78 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/elementwise_add_op.cu b/paddle/fluid/operators/elementwise_add_op.cu index dfff518f17..6cbf6066c9 100644 --- a/paddle/fluid/operators/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise_add_op.cu @@ -16,6 +16,60 @@ limitations under the License. */ #include "paddle/fluid/operators/elementwise_add_op.h" #include "paddle/fluid/platform/float16.h" +namespace paddle { +namespace operators { + +template +__global__ void ElementwiseAddCUDAKernel(const T *x, const T *y, T *z, int n, + int post, int size) { + int idx_x = threadIdx.x + blockIdx.x * blockDim.x; + if (idx_x < size) { + int idx_y = idx_x / post - (idx_x / (n * post)) * n; + z[idx_x] = x[idx_x] + y[idx_y]; + } +} + +template +class ElementwiseAddKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + using Tensor = framework::Tensor; + + const auto x = ctx.Input("X"); + const auto y = ctx.Input("Y"); + auto z = ctx.Output("Out"); + auto *z_data = z->mutable_data(ctx.GetPlace()); + + auto &device = *(ctx.cuda_device_context().eigen_device()); + const framework::DDim &x_dim = x->dims(); + framework::DDim y_dim = y->dims(); + int size = x->numel(); + if (x_dim == y_dim) { + auto dim = framework::make_ddim({size}); + auto z_eigen = framework::EigenTensor::From(*z, dim); + auto x_eigen = framework::EigenTensor::From(*x, dim); + auto y_eigen = framework::EigenTensor::From(*y, dim); + z_eigen.device(device) = x_eigen + y_eigen; + } else { + int axis = ctx.Attr("axis"); + axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis); + y_dim = trim_trailing_singular_dims(y_dim); + axis = (y_dim.size() == 0) ? x_dim.size() : axis; + int pre, n, post; + get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post); + int threads = 512; + int grids = (size + threads - 1) / threads; + auto stream = ctx.cuda_device_context().stream(); + ElementwiseAddCUDAKernel<<>>( + x->data(), y->data(), z_data, n, post, size); + } + } +}; + +} // namespace operators +} // namespace paddle + namespace ops = paddle::operators; namespace plat = paddle::platform; diff --git a/paddle/fluid/operators/elementwise_add_op.h b/paddle/fluid/operators/elementwise_add_op.h index baf04c30b1..ccda870618 100644 --- a/paddle/fluid/operators/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise_add_op.h @@ -142,16 +142,35 @@ class ElementwiseAddGradKernel : public framework::OpKernel { auto* x = ctx.Input("X"); auto* y = ctx.Input("Y"); - auto* out = ctx.Input("Out"); auto* dout = ctx.Input(framework::GradVarName("Out")); auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); - if (platform::is_cpu_place(ctx.GetPlace()) && (x->dims() == y->dims())) { - elementwise_add_grad(ctx, x, y, out, dout, dx, dy); + if (dx != nullptr) dx->ShareDataWith(*dout); + if (dy == nullptr) return; + + if (x->dims() == y->dims()) { + dy->ShareDataWith(*dout); } else { - default_elementwise_add_grad(ctx, x, y, out, dout, dx, - dy); + dy->mutable_data(ctx.GetPlace()); + // Perform reduction to dout to calculate dy + const framework::DDim& x_dim = x->dims(); + framework::DDim y_dim = y->dims(); + int axis = ctx.Attr("axis"); + axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis); + y_dim = trim_trailing_singular_dims(y_dim); + axis = (y_dim.size() == 0) ? x_dim.size() : axis; + + auto* device = + ctx.template device_context().eigen_device(); + int pre, n, post; + get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post); + auto eigen_dout = framework::EigenTensor::From( + *dout, framework::make_ddim({pre, n, post})); + auto eigen_dy = + framework::EigenTensor::From(*dy, framework::make_ddim({n})); + eigen_dy.device(*device) = eigen_dout.sum( + framework::EigenDim<2>::From(framework::make_ddim({0, 2}))); } } }; -- GitLab