From a29b4227eb2ffc2905b287e1de8b705a8dbc7cf5 Mon Sep 17 00:00:00 2001
From: sneaxiy
Date: Thu, 20 Sep 2018 01:53:48 +0000
Subject: [PATCH] fix sparse gradient clip

---
 paddle/fluid/operators/clip_op.h              | 43 ++++++++++++++-----
 .../operators/math/selected_rows_functor.cu   |  9 ++--
 2 files changed, 35 insertions(+), 17 deletions(-)

diff --git a/paddle/fluid/operators/clip_op.h b/paddle/fluid/operators/clip_op.h
index 85607a6b0..daf06f370 100644
--- a/paddle/fluid/operators/clip_op.h
+++ b/paddle/fluid/operators/clip_op.h
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/selected_rows_functor.h"
 #include "paddle/fluid/platform/transform.h"
 
 namespace paddle {
@@ -61,14 +62,32 @@ class ClipKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& context) const override {
     auto max = context.Attr<T>("max");
     auto min = context.Attr<T>("min");
-    auto* x = context.Input<Tensor>("X");
-    auto* out = context.Output<Tensor>("Out");
-    T* out_data = out->mutable_data<T>(context.GetPlace());
-    const T* x_data = x->data<T>();
-    int64_t numel = x->numel();
-    Transform<DeviceContext> trans;
-    trans(context.template device_context<DeviceContext>(), x_data,
-          x_data + numel, out_data, ClipFunctor<T>(min, max));
+    auto* x_var = context.InputVar("X");
+    if (x_var->IsType<framework::LoDTensor>()) {
+      auto* x = context.Input<framework::LoDTensor>("X");
+      auto* out = context.Output<framework::LoDTensor>("Out");
+      T* out_data = out->mutable_data<T>(context.GetPlace());
+      const T* x_data = x->data<T>();
+      int64_t numel = x->numel();
+      Transform<DeviceContext> trans;
+      trans(context.template device_context<DeviceContext>(), x_data,
+            x_data + numel, out_data, ClipFunctor<T>(min, max));
+    } else if (x_var->IsType<framework::SelectedRows>()) {
+      auto* x = context.Input<framework::SelectedRows>("X");
+      auto* out = context.Output<framework::SelectedRows>("Out");
+      PADDLE_ENFORCE_NE(x, out,
+                        "Inplace clip is not allowed when x is SelectedRows");
+      math::scatter::MergeAdd<DeviceContext, T> merge_func;
+      merge_func(context.template device_context<DeviceContext>(), *x, out);
+      auto* out_tensor = out->mutable_value();
+      auto* out_data = out_tensor->data<T>();
+      int64_t numel = out_tensor->numel();
+      Transform<DeviceContext> trans;
+      trans(context.template device_context<DeviceContext>(), out_data,
+            out_data + numel, out_data, ClipFunctor<T>(min, max));
+    } else {
+      PADDLE_THROW("ClipOp only supports LoDTensor and SelectedRows");
+    }
   }
 };
 
@@ -78,10 +97,12 @@ class ClipGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& context) const override {
     auto max = context.Attr<T>("max");
     auto min = context.Attr<T>("min");
-    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* d_out =
+        context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+    auto* d_x =
+        context.Output<framework::LoDTensor>(framework::GradVarName("X"));
     if (d_x != nullptr) {
-      auto* x = context.Input<Tensor>("X");
+      auto* x = context.Input<framework::LoDTensor>("X");
       int64_t numel = d_out->numel();
       auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
       const T* d_out_data = d_out->data<T>();
diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu
index 94258f662..f0f723dd1 100644
--- a/paddle/fluid/operators/math/selected_rows_functor.cu
+++ b/paddle/fluid/operators/math/selected_rows_functor.cu
@@ -236,7 +236,7 @@ template <typename T, int block_size>
 __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
                                T* out, const int64_t* out_rows,
                                size_t out_rows_size, int64_t row_numel) {
-  const int ty = blockIdx.y;
+  const int ty = blockIdx.x;
   int tid = threadIdx.x;
   __shared__ size_t out_idx;
 
@@ -291,12 +291,9 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
 
     const int block_size = 256;
     dim3 threads(block_size, 1);
-    dim3 grid1(1, input_rows.size());
+    dim3 grid1(input_rows.size(), 1);
 
-    MergeAddKernel<
-        T, 256><<<grid1, threads, 0,
-                  reinterpret_cast<const platform::CUDADeviceContext&>(context)
-                      .stream()>>>(
+    MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(
         input_data, input_rows.CUDAData(context.GetPlace()), out_data,
         out.mutable_rows()->CUDAMutableData(context.GetPlace()),
         out.rows().size(), input_width);
-- 
GitLab
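
Note (not part of the patch): the reason ClipKernel must run MergeAdd before ClipFunctor on the
SelectedRows path is that clip is non-linear, so clipping duplicate row fragments of a sparse
gradient and summing afterwards gives a different result than merging first. A minimal standalone
sketch of that effect in plain C++ (the clip helper below is illustrative, not a Paddle API):

    #include <algorithm>
    #include <cstdio>

    // Illustrative helper: clamp v into [lo, hi].
    float clip(float v, float lo, float hi) {
      return std::min(std::max(v, lo), hi);
    }

    int main() {
      // Two gradient fragments for the same row id, as a lookup-table
      // backward pass might emit into a SelectedRows variable.
      const float a = 0.8f, b = 0.8f, lo = -1.0f, hi = 1.0f;
      // Wrong order: clip each fragment, then merge -> 1.6, exceeding the bound.
      std::printf("clip then merge: %.1f\n", clip(a, lo, hi) + clip(b, lo, hi));
      // Fixed order (MergeAdd, then ClipFunctor): merge first, then clip -> 1.0.
      std::printf("merge then clip: %.1f\n", clip(a + b, lo, hi));
      return 0;
    }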