From 70b9f2ac361cf5a5d060136bcc6989f784d213f0 Mon Sep 17 00:00:00 2001
From: taixiurong
Date: Fri, 18 Feb 2022 11:03:40 +0800
Subject: [PATCH] dropout support Seed, fix elementwise_add_grad bug,
 test=kunlun (#39656)

---
 paddle/fluid/operators/dropout_op_xpu.cc      | 36 +++++-----
 .../elementwise/elementwise_add_op_xpu.cc     | 72 +++----------
 2 files changed, 27 insertions(+), 81 deletions(-)

diff --git a/paddle/fluid/operators/dropout_op_xpu.cc b/paddle/fluid/operators/dropout_op_xpu.cc
index 2f3f7e05f8c..51aa0b78c9a 100644
--- a/paddle/fluid/operators/dropout_op_xpu.cc
+++ b/paddle/fluid/operators/dropout_op_xpu.cc
@@ -32,20 +32,18 @@ class DropoutXPUKernel : public framework::OpKernel<T> {
         context.Attr<std::string>("dropout_implementation");
     auto& dev_ctx = context.template device_context<DeviceContext>();
 
-    PADDLE_ENFORCE_EQ(!context.HasInput("Seed"), true,
-                      platform::errors::InvalidArgument(
-                          ("Input(Seed) not supported on XPU")));
+    auto* seed =
+        context.HasInput("Seed") ? context.Input<Tensor>("Seed") : nullptr;
+
     int is_upscale = (dropout_implementation == "upscale_in_train");
 
     if (!context.Attr<bool>("is_test")) {
-      std::random_device rnd;
-      // int seed = (context.Attr<bool>("fix_seed")) ?
-      // int(context.Attr<int>("seed")) : (rnd());
-      int seed = 0;
-      if (context.Attr<bool>("fix_seed") == true) {
-        seed = static_cast<int>(context.Attr<int>("seed"));
+      int seed_data = 0;
+      if (seed) {
+        seed_data = *(seed->data<int>());
       } else {
-        seed = rnd();
+        seed_data =
+            context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : 0;
       }
 
       auto* mask = context.Output<Tensor>("Mask");
@@ -55,26 +53,26 @@ class DropoutXPUKernel : public framework::OpKernel<T> {
         int r = xpu::constant(dev_ctx.x_context(),
                               reinterpret_cast<XPUTyp*>(y_data), y->numel(),
                               XPUTyp(0));
-        PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant ");
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
         r = xpu::constant(dev_ctx.x_context(),
                           reinterpret_cast<XPUTyp*>(mask_data), mask->numel(),
                           XPUTyp(0));
-        PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant ");
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
         return;
       }
       int r = xpu::dropout(dev_ctx.x_context(),
                            reinterpret_cast<const XPUTyp*>(x->data<T>()),
                            reinterpret_cast<XPUTyp*>(y->data<T>()),
-                           reinterpret_cast<XPUTyp*>(mask_data), seed,
+                           reinterpret_cast<XPUTyp*>(mask_data), seed_data,
                            mask->numel(), is_upscale, dropout_prob);
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout ");
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout");
     } else {
       float scale =
           (is_upscale) ? (1.0) : (static_cast<float>(1.0f - dropout_prob));
       int r = xpu::scale(
           dev_ctx.x_context(), reinterpret_cast<const XPUTyp*>(x_data),
           reinterpret_cast<XPUTyp*>(y_data), x->numel(), false, scale, 0.0f);
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale ");
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
     }
   }
 };
@@ -103,7 +101,7 @@ class DropoutGradXPUKernel : public framework::OpKernel<T> {
                        reinterpret_cast<const XPUType*>(mask_data),
                        reinterpret_cast<XPUType*>(grad_x->data<T>()),
                        grad_y->numel());
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul ");
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
       return;
     }
@@ -117,13 +115,13 @@ class DropoutGradXPUKernel : public framework::OpKernel<T> {
                          reinterpret_cast<const XPUType*>(mask->data<T>()),
                          reinterpret_cast<XPUType*>(mask_new), mask->numel(),
                          false, scale, 0.0f);
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale ");
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
       r = xpu::mul(dev_ctx.x_context(),
                    reinterpret_cast<const XPUType*>(grad_y->data<T>()),
                    reinterpret_cast<const XPUType*>(mask_new),
                    reinterpret_cast<XPUType*>(grad_x->data<T>()),
                    grad_y->numel());
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul ");
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
     } else {
       int r =
           xpu::dropout_grad(dev_ctx.x_context(),
@@ -131,7 +129,7 @@ class DropoutGradXPUKernel : public framework::OpKernel<T> {
                             reinterpret_cast<const XPUType*>(grad_y->data<T>()),
                             reinterpret_cast<XPUType*>(grad_x->data<T>()),
                             dropout_prob, grad_y->numel());
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout_grad ");
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout_grad");
     }
   }
 };
diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op_xpu.cc b/paddle/fluid/operators/elementwise/elementwise_add_op_xpu.cc
index 6167452728a..3df2c7d05d4 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op_xpu.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op_xpu.cc
@@ -34,17 +34,6 @@ class ElementwiseAddXPUKernel : public framework::OpKernel<T> {
   }
 };
 
-static std::vector<int> get_rdims(const std::vector<int>& xdims,
-                                  const std::vector<int>& ydims) {
-  std::vector<int> rdims;
-  for (size_t i = 0; i < xdims.size(); i++) {
-    if (xdims[i] != ydims[i]) {
-      rdims.push_back(i);
-    }
-  }
-  return rdims;
-}
-
 template <typename T>
 class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
   using XPUType = typename XPUTypeTrait<T>::Type;
@@ -53,64 +42,19 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     ElemwiseGradKernel<T>::Compute(ctx);
     auto* x = ctx.Input<framework::Tensor>("X");
-    auto* y = ctx.Input<framework::Tensor>("Y");
     auto* dz = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
-    const framework::DDim& x_dims = x->dims();
-    const framework::DDim& y_dims = y->dims();
     const framework::DDim& dz_dims = dz->dims();
     int axis = ctx.Attr<int>("axis");
-    axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
-    int max_dim = std::max(x_dims.size(), y_dims.size());
-    PADDLE_ENFORCE_GE(
-        axis, 0,
-        platform::errors::InvalidArgument(
-            "Axis should be great than or equal to 0, but received axis is %d.",
-            axis));
-    PADDLE_ENFORCE_LT(
-        axis, max_dim,
-        platform::errors::InvalidArgument(
-            "Axis should be less than %d, but received axis is %d.", max_dim,
-            axis));
-
-    std::vector<int> x_dims_vec(max_dim, 1);
-    std::vector<int> y_dims_vec(max_dim, 1);
-    std::vector<int> z_dims_vec(max_dim, 1);
-    if (x_dims.size() == max_dim) {
-      for (int i = 0; i < max_dim; i++) {
-        x_dims_vec[i] = x_dims[i];
-      }
-    } else {
-      for (int i = 0; i < x_dims.size(); i++) {
-        x_dims_vec[i + axis] = x_dims[i];
-      }
-    }
-    if (y_dims.size() == max_dim) {
-      for (int i = 0; i < max_dim; i++) {
-        y_dims_vec[i] = y_dims[i];
-      }
-    } else {
-      for (int i = 0; i < y_dims.size(); i++) {
-        y_dims_vec[i + axis] = y_dims[i];
-      }
-    }
-
-    for (int i = 0; i < max_dim; i++) {
-      z_dims_vec[i] = dz_dims[i];
-    }
-    std::vector<int> rdims_for_x;
-    std::vector<int> rdims_for_y;
-    rdims_for_x = get_rdims(x_dims_vec, z_dims_vec);
-    rdims_for_y = get_rdims(y_dims_vec, z_dims_vec);
     const T* dz_data = dz->data<T>();
     auto& dev_ctx =
         ctx.template device_context<paddle::platform::XPUDeviceContext>();
 
     if (dx != nullptr) {
       T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
-      if (rdims_for_x.size() == 0) {
+      if (dx->dims() == dz_dims) {
         if (dx_data != dz_data) {
           framework::TensorCopy(
               *dz, ctx.GetPlace(),
@@ -123,27 +67,31 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
           dx->clear();
           dx->mutable_data<T>(x->dims(), ctx.GetPlace());
         }
+        std::vector<int> reduce_dims = GetReduceDim(dx->dims(), dz_dims, axis);
+        std::vector<int> dz_vector = framework::vectorize<int>(dz_dims);
         int ret = xpu::reduce_sum<XPUType>(
             dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data),
-            reinterpret_cast<XPUType*>(dx_data), z_dims_vec, rdims_for_x);
-        PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum ");
+            reinterpret_cast<XPUType*>(dx_data), dz_vector, reduce_dims);
+        PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum");
       }
     }
 
     if (dy != nullptr) {
       T* dy_data = dy->mutable_data<T>(ctx.GetPlace());
-      if (rdims_for_y.size() == 0) {
+      if (dy->dims() == dz_dims) {
         if (dy_data != dz_data) {
           framework::TensorCopy(
               *dz, ctx.GetPlace(),
               ctx.template device_context<platform::DeviceContext>(), dy);
         }
       } else {
+        std::vector<int> reduce_dims = GetReduceDim(dy->dims(), dz_dims, axis);
+        std::vector<int> dz_vector = framework::vectorize<int>(dz_dims);
         int ret = xpu::reduce_sum<XPUType>(
             dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data),
-            reinterpret_cast<XPUType*>(dy_data), z_dims_vec, rdims_for_y);
-        PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum ");
+            reinterpret_cast<XPUType*>(dy_data), dz_vector, reduce_dims);
+        PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum");
       }
     }
   }
--
GitLab
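Why the elementwise_add_grad rework holds: when an input of the add was broadcast to produce Out, its gradient is dOut summed over exactly the axes where the axis-aligned input shape and dOut's shape disagree; when nothing was broadcast, that reduce list is empty and the gradient is a plain copy of dOut, which is what the new dx->dims() == dz_dims / dy->dims() == dz_dims fast path checks directly. The sketch below is a minimal standalone illustration of that axis-selection rule, not Paddle source: ReduceDims, main, and the example shapes are hypothetical stand-ins for what GetReduceDim computes from the input dims, dz_dims, and axis.

#include <cstdio>
#include <vector>

// Given the (unaligned) input shape, the output shape, and the broadcast
// offset `axis`, return the axes that reduce_sum must sum over so the
// gradient shrinks back to the input's shape.
std::vector<int> ReduceDims(const std::vector<int>& in_dims,
                            const std::vector<int>& out_dims, int axis) {
  // Align the input to the output rank, padding the other axes with 1
  // (Paddle's elementwise broadcast rule).
  std::vector<int> aligned(out_dims.size(), 1);
  for (size_t i = 0; i < in_dims.size(); ++i) {
    aligned[i + axis] = in_dims[i];
  }
  // Every axis whose extent differs was broadcast and must be reduced.
  std::vector<int> rdims;
  for (size_t i = 0; i < out_dims.size(); ++i) {
    if (aligned[i] != out_dims[i]) {
      rdims.push_back(static_cast<int>(i));
    }
  }
  return rdims;
}

int main() {
  // Out = X + Y with X: [2, 3, 4] and Y: [3] broadcast at axis = 1,
  // so dY = reduce_sum(dOut, {0, 2}) while dX is a plain copy of dOut.
  for (int d : ReduceDims({3}, {2, 3, 4}, 1)) {
    std::printf("%d ", d);  // prints: 0 2
  }
  std::printf("\n");
  return 0;
}

This also shows why the removed rdims_for_x.size() == 0 test and the new dims-equality test agree: the reduce list is empty exactly when the input already has dOut's shape, so the kernel can TensorCopy dz instead of calling xpu::reduce_sum.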