diff --git a/paddle/fluid/operators/adam_op.h b/paddle/fluid/operators/adam_op.h index 84a584f424823a450effd4c36e9da600f5851da2..5b27068c9e805146b8bce03f4f676ef0d4d16c53 100644 --- a/paddle/fluid/operators/adam_op.h +++ b/paddle/fluid/operators/adam_op.h @@ -174,12 +174,13 @@ struct SparseAdamFunctor { const int64_t* rows_; int64_t row_numel_; + int64_t row_count_; SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow, const T* beta2_pow, const T* mom1, T* mom1_out, const T* mom2, T* mom2_out, const T* lr, const T* grad, const T* param, T* param_out, const int64_t* rows, - int64_t row_numel) + int64_t row_numel, int64_t row_count) : beta1_(beta1), beta2_(beta2), epsilon_(epsilon), @@ -194,28 +195,47 @@ struct SparseAdamFunctor { param_(param), param_out_(param_out), rows_(rows), - row_numel_(row_numel) {} + row_numel_(row_numel), + row_count_(row_count) {} + + inline HOSTDEVICE int64_t BinarySearchInRows(int64_t row) const { + int64_t beg = 0, end = row_count_ - 1; + while (beg <= end) { + auto mid = ((beg + end) >> 1); + if (rows_[mid] == row) + return mid; + else if (rows_[mid] < row) + beg = mid + 1; + else + end = mid - 1; + } + return -1; + } inline HOSTDEVICE void operator()(size_t i) const { + int64_t row = i / row_numel_; + auto row_idx = BinarySearchInRows(row); + T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0; + + // The following code is the same as dense + T mom1 = moment1_[i]; + T mom2 = moment2_[i]; + T lr = *lr_; T beta1_pow = *beta1_pow_; T beta2_pow = *beta2_pow_; - for (int64_t j = 0; j < row_numel_; ++j) { - T g = grad_[i * row_numel_ + j]; - T mom1 = moment1_[rows_[i] * row_numel_ + j]; - T mom2 = moment2_[rows_[i] * row_numel_ + j]; - T lr = *lr_; - T p = param_[rows_[i] * row_numel_ + j]; - - lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow); - - mom1 = beta1_ * mom1 + (1 - beta1_) * g; - mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; - p -= lr * (mom1 / (sqrt(mom2) + epsilon_)); - - moment1_out_[rows_[i] * row_numel_ + j] = mom1; - moment2_out_[rows_[i] * row_numel_ + j] = mom2; - param_out_[rows_[i] * row_numel_ + j] = p; - } // for col id + T p = param_[i]; + + // Calculation + lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow); + + mom1 = beta1_ * mom1 + (1 - beta1_) * g; + mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; + p -= lr * (mom1 / (sqrt(mom2) + epsilon_)); + + // Write back to global memory + moment1_out_[i] = mom1; + moment2_out_[i] = mom2; + param_out_[i] = p; } }; @@ -287,9 +307,14 @@ class AdamOpKernel : public framework::OpKernel { return; } // merge duplicated rows if any. + // The rows of grad_merge have been sorted inside MergeAdd functor scatter::MergeAdd merge_func; - auto grad_merge = - merge_func(ctx.template device_context(), grad); + auto& grad_merge = *(ctx.scope() + .NewScope() + .Var("sparse_adam_grad_merge") + ->GetMutable()); + merge_func(ctx.template device_context(), grad, + &grad_merge); auto& grad_tensor = grad_merge.value(); const T* grad_data = grad_tensor.template data(); int64_t* rows = nullptr; @@ -314,10 +339,11 @@ class AdamOpKernel : public framework::OpKernel { mom2.template data(), mom2_out.template mutable_data(ctx.GetPlace()), lr.template data(), grad_data, param.template data(), - param_out.template mutable_data(ctx.GetPlace()), rows, row_numel); + param_out.template mutable_data(ctx.GetPlace()), rows, row_numel, + grad_merge.rows().size()); platform::ForRange for_range( static_cast(ctx.device_context()), - grad_merge.rows().size()); + param.numel()); for_range(functor); } else { PADDLE_THROW("Variable type not supported by adam_op"); diff --git a/paddle/fluid/operators/clip_op.h b/paddle/fluid/operators/clip_op.h index 85607a6b0e4239b063ba75888e6859a5d85eafc9..daf06f370ffb591e25ad846b94c8284aad19a8dd 100644 --- a/paddle/fluid/operators/clip_op.h +++ b/paddle/fluid/operators/clip_op.h @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/platform/transform.h" namespace paddle { @@ -61,14 +62,32 @@ class ClipKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto max = context.Attr("max"); auto min = context.Attr("min"); - auto* x = context.Input("X"); - auto* out = context.Output("Out"); - T* out_data = out->mutable_data(context.GetPlace()); - const T* x_data = x->data(); - int64_t numel = x->numel(); - Transform trans; - trans(context.template device_context(), x_data, - x_data + numel, out_data, ClipFunctor(min, max)); + auto* x_var = context.InputVar("X"); + if (x_var->IsType()) { + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + T* out_data = out->mutable_data(context.GetPlace()); + const T* x_data = x->data(); + int64_t numel = x->numel(); + Transform trans; + trans(context.template device_context(), x_data, + x_data + numel, out_data, ClipFunctor(min, max)); + } else if (x_var->IsType()) { + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + PADDLE_ENFORCE_NE(x, out, + "Inplace clip is not allowed when x is SelectedRows"); + math::scatter::MergeAdd merge_func; + merge_func(context.template device_context(), *x, out); + auto* out_tensor = out->mutable_value(); + auto* out_data = out_tensor->data(); + int64_t numel = out_tensor->numel(); + Transform trans; + trans(context.template device_context(), out_data, + out_data + numel, out_data, ClipFunctor(min, max)); + } else { + PADDLE_THROW("ClipOp only supports LoDTensor and SelectedRows"); + } } }; @@ -78,10 +97,12 @@ class ClipGradKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto max = context.Attr("max"); auto min = context.Attr("min"); - auto* d_out = context.Input(framework::GradVarName("Out")); - auto* d_x = context.Output(framework::GradVarName("X")); + auto* d_out = + context.Input(framework::GradVarName("Out")); + auto* d_x = + context.Output(framework::GradVarName("X")); if (d_x != nullptr) { - auto* x = context.Input("X"); + auto* x = context.Input("X"); int64_t numel = d_out->numel(); auto* d_x_data = d_x->mutable_data(context.GetPlace()); const T* d_out_data = d_out->data(); diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index a830dc5250a6aea7e622da4046b512d0c7c5d6f9..8e8baf49b2330e95ff1a868b0b0a03bc10d84484 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -199,6 +199,14 @@ struct MergeAdd { framework::SelectedRows operator()(const platform::CPUDeviceContext& context, const framework::SelectedRows& input) { framework::SelectedRows out; + (*this)(context, input, &out); + return out; + } + + void operator()(const platform::CPUDeviceContext& context, + const framework::SelectedRows& input, + framework::SelectedRows* output) { + framework::SelectedRows& out = *output; auto input_rows = input.rows(); std::set row_set(input_rows.begin(), input_rows.end()); std::vector merge_rows(row_set.begin(), row_set.end()); @@ -223,7 +231,6 @@ struct MergeAdd { out_data[out_i * input_width + j] += input_data[i * input_width + j]; } } - return out; } }; diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index d559aaa7210eca1f169585760b73e7d95b71d281..ae51a53a7197950338ef773d63103fa13bf0a5f5 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -234,7 +234,7 @@ template __global__ void MergeAddKernel(const T* input, const int64_t* input_rows, T* out, const int64_t* out_rows, size_t out_rows_size, int64_t row_numel) { - const int ty = blockIdx.y; + const int ty = blockIdx.x; int tid = threadIdx.x; __shared__ size_t out_idx; @@ -260,6 +260,14 @@ struct MergeAdd { framework::SelectedRows operator()(const platform::CUDADeviceContext& context, const framework::SelectedRows& input) { framework::SelectedRows out; + (*this)(context, input, &out); + return out; + } + + void operator()(const platform::CUDADeviceContext& context, + const framework::SelectedRows& input, + framework::SelectedRows* output) { + framework::SelectedRows& out = *output; framework::Vector input_rows(input.rows()); std::set row_set(input_rows.begin(), input_rows.end()); std::vector merge_rows(row_set.begin(), row_set.end()); @@ -281,16 +289,12 @@ struct MergeAdd { const int block_size = 256; dim3 threads(block_size, 1); - dim3 grid1(1, input_rows.size()); + dim3 grid1(input_rows.size(), 1); - MergeAddKernel< - T, 256><<(context) - .stream()>>>( + MergeAddKernel<<>>( input_data, input_rows.CUDAData(context.GetPlace()), out_data, out.mutable_rows()->CUDAMutableData(context.GetPlace()), out.rows().size(), input_width); - return out; } }; diff --git a/paddle/fluid/operators/math/selected_rows_functor.h b/paddle/fluid/operators/math/selected_rows_functor.h index 18304f83f8706f822ce628e2374b00a71f1cc171..aa419f74fcd2a53cdd734ec270bc154b78c9f2ff 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.h +++ b/paddle/fluid/operators/math/selected_rows_functor.h @@ -65,6 +65,9 @@ struct MergeAdd { // the input SelectedRows object. framework::SelectedRows operator()(const DeviceContext& context, const framework::SelectedRows& input); + void operator()(const DeviceContext& context, + const framework::SelectedRows& input, + framework::SelectedRows* output); }; template