diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc
index 3eba268cfa9712e4bc5475dd44076bc768552bce..0c2e6d40241218a6697b896262608b3a002e0571 100644
--- a/paddle/fluid/operators/math/selected_rows_functor.cc
+++ b/paddle/fluid/operators/math/selected_rows_functor.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <algorithm>
 #include <set>
 #include <vector>
@@ -301,6 +302,9 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
     }
     std::vector<int64_t> merge_rows(merged_row_set.begin(),
                                     merged_row_set.end());
+    if (sorted_result_) {
+      std::sort(merge_rows.begin(), merge_rows.end());
+    }
     std::unordered_map<int64_t, size_t> rows_to_id;
     for (size_t i = 0; i < merge_rows.size(); ++i) {
       rows_to_id[merge_rows[i]] = i;
diff --git a/paddle/fluid/operators/math/selected_rows_functor.h b/paddle/fluid/operators/math/selected_rows_functor.h
index 6d146d39d6d07678e859b82b25ba60ed7661546d..b7b19f130e5269629c27ae51b5b346ef1a789e3e 100644
--- a/paddle/fluid/operators/math/selected_rows_functor.h
+++ b/paddle/fluid/operators/math/selected_rows_functor.h
@@ -78,6 +78,10 @@ namespace scatter {
 // functors for manuplating SelectedRows data
 template <typename DeviceContext, typename T>
 struct MergeAdd {
+  MergeAdd() : sorted_result_(false) {}
+
+  explicit MergeAdd(bool sorted_result) : sorted_result_(sorted_result) {}
+
   // unary functor, merge by adding duplicated rows in
   // the input SelectedRows object.
   framework::SelectedRows operator()(const DeviceContext& context,
@@ -88,6 +92,9 @@ struct MergeAdd {
   void operator()(const DeviceContext& context,
                   const std::vector<const framework::SelectedRows*>& inputs,
                   framework::SelectedRows* output);
+
+ private:
+  bool sorted_result_;
 };
 
 enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h
index 3455d1ee54e8e6e498d0b0e6932ec099af9c0b30..c2bf7040d77315030b5108edb64dd1bb914d81d8 100644
--- a/paddle/fluid/operators/optimizers/adam_op.h
+++ b/paddle/fluid/operators/optimizers/adam_op.h
@@ -158,7 +158,7 @@ struct AdamFunctor {
 };
 
 template <typename T>
-struct SparseAdamFunctor {
+struct SparseAdamFunctor<T, GPUAdam> {
   T beta1_;
   T beta2_;
   T epsilon_;
@@ -227,6 +227,78 @@ struct SparseAdamFunctor {
   }
 };
 
+template <typename T>
+struct SparseAdamFunctor<T, CPUAdam> {
+  T beta1_;
+  T beta2_;
+  T epsilon_;
+
+  const T* beta1_pow_;
+  const T* beta2_pow_;
+  const T* moment1_;
+  T* moment1_out_;
+  const T* moment2_;
+  T* moment2_out_;
+  const T* lr_;
+  const T* grad_;
+  const T* param_;
+  T* param_out_;
+
+  const int64_t* rows_;
+  int64_t row_numel_;
+  int64_t row_count_;
+
+  SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
+                    const T* beta2_pow, const T* mom1, T* mom1_out,
+                    const T* mom2, T* mom2_out, const T* lr, const T* grad,
+                    const T* param, T* param_out, const int64_t* rows,
+                    int64_t row_numel, int64_t row_count)
+      : beta1_(beta1),
+        beta2_(beta2),
+        epsilon_(epsilon),
+        beta1_pow_(beta1_pow),
+        beta2_pow_(beta2_pow),
+        moment1_(mom1),
+        moment1_out_(mom1_out),
+        moment2_(mom2),
+        moment2_out_(mom2_out),
+        lr_(lr),
+        grad_(grad),
+        param_(param),
+        param_out_(param_out),
+        rows_(rows),
+        row_numel_(row_numel),
+        row_count_(row_count) {}
+
+  inline void operator()(size_t numel) const {
+    // lr could be reused for every element of this update
+    T lr = *lr_;
+    T beta1_pow = *beta1_pow_;
+    T beta2_pow = *beta2_pow_;
+    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
+
+    for (size_t i = 0U, j = 0U; i != numel; ++i) {
+      T mom1 = moment1_[i];
+      T mom2 = moment2_[i];
+      T p = param_[i];
+
+      // Calculation: rows_ is sorted, so a linear scan with cursor j
+      // replaces the BinarySearch used by the GPU functor.
+      if (j < static_cast<size_t>(row_count_) &&
+          static_cast<int64_t>(i) / row_numel_ == *(rows_ + j)) {
+        T g = grad_[j * row_numel_ + i % row_numel_];
+        mom1 = beta1_ * mom1 + (1 - beta1_) * g;
+        mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
+        // Advance to the next merged row after its last element.
+        if ((i + 1) % row_numel_ == 0) ++j;
+      } else {
+        // Rows without a gradient still decay the moments (g == 0).
+        mom1 = beta1_ * mom1;
+        mom2 = beta2_ * mom2;
+      }
+      p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
+
+      // Write back to global memory
+      moment1_out_[i] = mom1;
+      moment2_out_[i] = mom2;
+      param_out_[i] = p;
+    }
+  }
+};
+
 template <typename DeviceContext, typename T>
 class AdamOpKernel : public framework::OpKernel<T> {
  public:
@@ -316,7 +388,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
       } else {
         // merge duplicated rows if any.
         // The rows of grad_merge have been sorted inside MergeAdd functor
-        scatter::MergeAdd<DeviceContext, T> merge_func;
+        scatter::MergeAdd<DeviceContext, T> merge_func(true);
         auto* grad_merge_var = const_cast<framework::Scope&>(ctx.scope())
                                    .Var()
                                    ->GetMutable<framework::SelectedRows>();
@@ -337,25 +409,40 @@
       } else {
 #endif
         rows = grad_merge.rows().data();
-
 #if defined(PADDLE_WITH_CUDA)
       }
 #endif
       auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
 
-      SparseAdamFunctor<T> functor(
-          beta1, beta2, epsilon, beta1_pow.template data<T>(),
-          beta2_pow.template data<T>(), mom1.template data<T>(),
-          mom1_out.template mutable_data<T>(ctx.GetPlace()),
-          mom2.template data<T>(),
-          mom2_out.template mutable_data<T>(ctx.GetPlace()),
-          lr.template data<T>(), grad_data, param.template data<T>(),
-          param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
-          grad_merge.rows().size());
-      platform::ForRange<DeviceContext> for_range(
-          static_cast<const DeviceContext&>(ctx.device_context()),
-          param.numel());
-      for_range(functor);
+      if (platform::is_cpu_place(ctx.GetPlace())) {
+        SparseAdamFunctor<T, CPUAdam> functor(
+            beta1, beta2, epsilon, beta1_pow.template data<T>(),
+            beta2_pow.template data<T>(), mom1.template data<T>(),
+            mom1_out.template mutable_data<T>(ctx.GetPlace()),
+            mom2.template data<T>(),
+            mom2_out.template mutable_data<T>(ctx.GetPlace()),
+            lr.template data<T>(), grad_data, param.template data<T>(),
+            param_out.template mutable_data<T>(ctx.GetPlace()), rows,
+            row_numel, grad_merge.rows().size());
+
+        functor(param.numel());
+      } else if (platform::is_gpu_place(ctx.GetPlace())) {
+        SparseAdamFunctor<T, GPUAdam> functor(
+            beta1, beta2, epsilon, beta1_pow.template data<T>(),
+            beta2_pow.template data<T>(), mom1.template data<T>(),
+            mom1_out.template mutable_data<T>(ctx.GetPlace()),
+            mom2.template data<T>(),
+            mom2_out.template mutable_data<T>(ctx.GetPlace()),
+            lr.template data<T>(), grad_data, param.template data<T>(),
+            param_out.template mutable_data<T>(ctx.GetPlace()), rows,
+            row_numel, grad_merge.rows().size());
+
+        // FIXME(minqiyang): remove BinarySearch in GPU later
+        platform::ForRange<DeviceContext> for_range(
+            static_cast<const DeviceContext&>(ctx.device_context()),
+            param.numel());
+        for_range(functor);
+      }
     } else {
       PADDLE_THROW("Variable type not supported by adam_op");
     }
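
For review purposes, a minimal standalone sketch of the indexing scheme the new `SparseAdamFunctor<T, CPUAdam>` relies on: because `MergeAdd(true)` produces the merged rows sorted ascending, a single cursor `j`, advanced at each row boundary, visits exactly the gradient elements that the GPU path locates via `BinarySearch`. This sketch is not part of the patch; the helper `lookup_grad`, the `main` driver, and the toy sizes are hypothetical.

```cpp
// Standalone check: the sorted-rows linear scan (CPU path) selects the same
// gradient element as a per-row binary-search lookup (GPU path analogue).
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Per-element lookup, analogous to the GPU functor's BinarySearch.
double lookup_grad(const std::vector<int64_t>& rows,
                   const std::vector<double>& grad, int64_t row_numel,
                   size_t i) {
  int64_t row = static_cast<int64_t>(i) / row_numel;
  auto it = std::lower_bound(rows.begin(), rows.end(), row);
  if (it == rows.end() || *it != row) return 0.0;  // row has no gradient
  size_t j = static_cast<size_t>(it - rows.begin());
  return grad[j * row_numel + i % row_numel];
}

int main() {
  const int64_t row_numel = 3;
  const std::vector<int64_t> rows = {1, 4};  // sorted, as MergeAdd(true) yields
  const std::vector<double> grad = {1, 2, 3, 4, 5, 6};  // rows.size() * row_numel
  const size_t numel = 6 * row_numel;  // six parameter rows in total

  // Linear scan over sorted rows, as in SparseAdamFunctor<T, CPUAdam>.
  for (size_t i = 0, j = 0; i != numel; ++i) {
    double g = 0.0;
    if (j < rows.size() && static_cast<int64_t>(i) / row_numel == rows[j]) {
      g = grad[j * row_numel + i % row_numel];
      if ((i + 1) % row_numel == 0) ++j;  // finished this merged row
    }
    assert(g == lookup_grad(rows, grad, row_numel, i));
  }
  return 0;
}
```

The sorted precondition is why the CPU branch constructs `merge_func(true)`: with unsorted rows the cursor could pass a gradient row without consuming it and silently drop updates. The GPU branch depends on the same ordering, since its binary search over `rows_` assumes a sorted array, hence the comment that the rows of grad_merge have been sorted inside the MergeAdd functor.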