From 4dda18a8b4f1af281483a16d456798ab00aed1db Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Fri, 15 Oct 2021 11:07:29 +0800 Subject: [PATCH] fix momentum ops (#36452) --- .../fluid/operators/optimizers/momentum_op.h | 67 ++++++++++--------- .../unittests/test_merged_momentum_op.py | 9 ++- 2 files changed, 41 insertions(+), 35 deletions(-) diff --git a/paddle/fluid/operators/optimizers/momentum_op.h b/paddle/fluid/operators/optimizers/momentum_op.h index f461dec66c0..2d713308fd9 100644 --- a/paddle/fluid/operators/optimizers/momentum_op.h +++ b/paddle/fluid/operators/optimizers/momentum_op.h @@ -173,14 +173,15 @@ class CPUDenseMomentumFunctor { } }; -template +template class DenseMomentumFunctor; // NOTE(dzh) for performance. // avoid if/else in inside kernel, implement GPU UseNesterov/NoNesterov as two // functor. -template -class DenseMomentumFunctor { +template +class DenseMomentumFunctor { private: const T* param_; const T* grad_; @@ -193,7 +194,6 @@ class DenseMomentumFunctor { T* param_out_; MT* velocity_out_; MT* master_param_out_; - const RegularizationType regularization_flag_; const MT regularization_coeff_; public: @@ -201,7 +201,6 @@ class DenseMomentumFunctor { const MultiPrecisionType* learning_rate, const MT* master_param, const MT mu, const MT rescale_grad, const int64_t num, - const RegularizationType regularization_flag, const MT regularization_coeff, T* param_out, MT* velocity_out, MT* master_param_out) : param_(param), @@ -215,7 +214,6 @@ class DenseMomentumFunctor { param_out_(param_out), velocity_out_(velocity_out), master_param_out_(master_param_out), - regularization_flag_(regularization_flag), regularization_coeff_(regularization_coeff) {} inline HOSTDEVICE void operator()(size_t i) const { // put memory access in register @@ -225,9 +223,9 @@ class DenseMomentumFunctor { const MT lr = static_cast(lr_[0]); const MT velocity = velocity_[i]; - grad = regularization_flag_ == RegularizationType::kL2DECAY - ? grad + regularization_coeff_ * param - : grad; + if (kRegType == RegularizationType::kL2DECAY) { + grad += regularization_coeff_ * param; + } MT velocity_out = velocity * mu_ + grad; MT param_out = param - (grad + velocity_out * mu_) * lr; @@ -240,8 +238,8 @@ class DenseMomentumFunctor { } }; -template -class DenseMomentumFunctor { +template +class DenseMomentumFunctor { private: const T* param_; const T* grad_; @@ -254,7 +252,6 @@ class DenseMomentumFunctor { T* param_out_; MT* velocity_out_; MT* master_param_out_; - const RegularizationType regularization_flag_; const MT regularization_coeff_; public: @@ -262,7 +259,6 @@ class DenseMomentumFunctor { const MultiPrecisionType* learning_rate, const MT* master_param, const MT mu, const MT rescale_grad, const int64_t num, - const RegularizationType regularization_flag, const MT regularization_coeff, T* param_out, MT* velocity_out, MT* master_param_out) : param_(param), @@ -276,7 +272,6 @@ class DenseMomentumFunctor { param_out_(param_out), velocity_out_(velocity_out), master_param_out_(master_param_out), - regularization_flag_(regularization_flag), regularization_coeff_(regularization_coeff) {} inline HOSTDEVICE void operator()(size_t i) const { // put memory access in register @@ -286,9 +281,9 @@ class DenseMomentumFunctor { const MT lr = static_cast(lr_[0]); const MT velocity = velocity_[i]; - grad = regularization_flag_ == RegularizationType::kL2DECAY - ? grad + regularization_coeff_ * param - : grad; + if (kRegType == RegularizationType::kL2DECAY) { + grad += regularization_coeff_ * param; + } MT velocity_out = velocity * mu_ + grad; MT param_out = param - lr * velocity_out; @@ -522,23 +517,31 @@ class MomentumOpKernel : public framework::OpKernel { platform::ForRange for_range( static_cast(ctx.device_context()), param->numel()); - if (use_nesterov) { - DenseMomentumFunctor functor( - param->data(), grad->data(), velocity->data(), - learning_rate->data(), master_in_data, mu, rescale_grad, - param->numel(), regularization_flag, regularization_coeff, - param_out->mutable_data(ctx.GetPlace()), - velocity_out->mutable_data(ctx.GetPlace()), master_out_data); - for_range(functor); +#define PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(__nesterov, __reg_type) \ + DenseMomentumFunctor functor( \ + param->data(), grad->data(), velocity->data(), \ + learning_rate->data(), master_in_data, mu, rescale_grad, \ + param->numel(), regularization_coeff, \ + param_out->mutable_data(ctx.GetPlace()), \ + velocity_out->mutable_data(ctx.GetPlace()), master_out_data); \ + for_range(functor); + if (use_nesterov) { + if (regularization_flag == RegularizationType::kL2DECAY) { + PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(UseNesterov, + RegularizationType::kL2DECAY); + } else { + PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(UseNesterov, + RegularizationType::kNONE); + } } else { - DenseMomentumFunctor functor( - param->data(), grad->data(), velocity->data(), - learning_rate->data(), master_in_data, mu, rescale_grad, - param->numel(), regularization_flag, regularization_coeff, - param_out->mutable_data(ctx.GetPlace()), - velocity_out->mutable_data(ctx.GetPlace()), master_out_data); - for_range(functor); + if (regularization_flag == RegularizationType::kL2DECAY) { + PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(NoNesterov, + RegularizationType::kL2DECAY); + } else { + PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(NoNesterov, + RegularizationType::kNONE); + } } } diff --git a/python/paddle/fluid/tests/unittests/test_merged_momentum_op.py b/python/paddle/fluid/tests/unittests/test_merged_momentum_op.py index 0118a372c3f..96e458795a3 100644 --- a/python/paddle/fluid/tests/unittests/test_merged_momentum_op.py +++ b/python/paddle/fluid/tests/unittests/test_merged_momentum_op.py @@ -102,7 +102,7 @@ def run_momentum_op(params, 'Param': p, 'Grad': g, 'Velocity': v, - 'LearningRate': lr_var + 'LearningRate': lr_var, } outputs = {'ParamOut': p, 'VelocityOut': v} if multi_precision: @@ -115,7 +115,7 @@ def run_momentum_op(params, 'Param': param_vars, 'Grad': grad_vars, 'Velocity': velocity_vars, - 'LearningRate': lr_var + 'LearningRate': lr_var, } outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars} if multi_precision: @@ -176,7 +176,10 @@ class TestMergedMomentum(unittest.TestCase): outs2 = run_op(False) self.assertEqual(len(outs1), len(outs2)) for i, (out1, out2) in enumerate(zip(outs1, outs2)): - self.assertTrue(np.allclose(out1, out2, atol=1e-7)) + if isinstance(place, paddle.CUDAPlace): + self.assertTrue(np.array_equal(out1, out2)) + else: + self.assertTrue(np.allclose(out1, out2, atol=1e-7)) def get_places(self): places = [paddle.CPUPlace()] -- GitLab