From c8ffdecbf4f165d48f0b3adbd1f869eba6723331 Mon Sep 17 00:00:00 2001
From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com>
Date: Tue, 30 Nov 2021 19:31:10 +0800
Subject: [PATCH] [opt] Add regularization and Nesterov for merged_momentum op
 (#37527)

* add regularization and Nesterov for merged_momentum

* refine unittest for use_nesterov attr

* refine op check

* refine code

* fix bug

* refine code of regularization_flag

* delete useless code
---
 .../optimizers/merged_momentum_op.cc          |  15 +-
 .../operators/optimizers/merged_momentum_op.h | 174 +++++++++++++---
 paddle/fluid/pybind/op_function.h             |   2 +-
 .../unittests/test_merged_momentum_op.py      | 197 ++++++++++++++++++
 4 files changed, 360 insertions(+), 28 deletions(-)

diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op.cc b/paddle/fluid/operators/optimizers/merged_momentum_op.cc
index 6c63376b5eb..1733150f271 100644
--- a/paddle/fluid/operators/optimizers/merged_momentum_op.cc
+++ b/paddle/fluid/operators/optimizers/merged_momentum_op.cc
@@ -50,7 +50,8 @@ class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
         .AsDuplicable();
     AddInput("LearningRate",
              "(Tensor, default Tensor<float>) "
-             "Input learning rate");
+             "Input learning rate")
+        .AsDuplicable();
     AddInput("MasterParam", "FP32 master weight for AMP.")
         .AsDispensable()
         .AsDuplicable();
@@ -68,6 +69,18 @@ class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
         .AsDispensable()
         .AsDuplicable();
     AddAttr<float>("mu", "(float) Momentum coefficient");
+    AddAttr<bool>("use_nesterov",
+                  "(bool, default false) "
+                  "Use Nesterov Momentum or not.")
+        .SetDefault(false);
+    AddAttr<std::vector<std::string>>(
+        "regularization_method",
+        "(string) regularization_method, right now only "
+        "support l2decay or none")
+        .SetDefault({});
+    AddAttr<std::vector<float>>("regularization_coeff",
+                                "(float) regularization_coeff")
+        .SetDefault({});
     AddAttr<bool>("multi_precision",
                   "(bool, default false) "
                   "Whether to use multi-precision during weight updating.")
diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op.h b/paddle/fluid/operators/optimizers/merged_momentum_op.h
index 4dfaa4de3ad..7560b4fd8e5 100644
--- a/paddle/fluid/operators/optimizers/merged_momentum_op.h
+++ b/paddle/fluid/operators/optimizers/merged_momentum_op.h
@@ -18,6 +18,7 @@
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
+#include "paddle/fluid/operators/optimizers/momentum_op.h"
 #include "paddle/fluid/platform/for_range.h"
 #include "paddle/fluid/platform/macros.h"
 
@@ -85,33 +86,43 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
     auto params = ctx.MultiInput<framework::Tensor>("Param");
     auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
     size_t n = params.size();
-    PADDLE_ENFORCE_EQ(
-        n, params_out.size(),
-        platform::errors::InvalidArgument(
-            "Output(ParamOut) number must be equal to Input(Param) number."));
+    PADDLE_ENFORCE_EQ(n, params_out.size(),
+                      platform::errors::InvalidArgument(
+                          "The size of Output(ParamOut) must be equal to "
+                          "Input(Param), but got the size of Output(ParamOut) "
+                          "is %d, the size of Input(Param) is %d.",
+                          params_out.size(), n));
     for (size_t i = 0; i < n; ++i) {
-      PADDLE_ENFORCE_EQ(
-          params[i], params_out[i],
-          platform::errors::InvalidArgument(
-              "Input(Param) and Output(ParamOut) must be the same Tensors."));
+      PADDLE_ENFORCE_EQ(params[i], params_out[i],
+                        platform::errors::InvalidArgument(
+                            "Input(Param) and Output(ParamOut) "
+                            "must be the same Tensors."));
     }
     auto grads = ctx.MultiInput<framework::Tensor>("Grad");
     PADDLE_ENFORCE_EQ(
         n, grads.size(),
         platform::errors::InvalidArgument(
-            "Input(Grad) number must be equal to Input(Param) number."));
+            "The size of Input(Grad) must be equal to Input(Param), but got "
+            "the size of Input(Grad) is %d, the size of Input(Param) is %d.",
+            grads.size(), n));
 
     auto velocitys = ctx.MultiInput<framework::Tensor>("Velocity");
     PADDLE_ENFORCE_EQ(n, velocitys.size(),
                       platform::errors::InvalidArgument(
-                          "Input(Velocity) number and Input(Param) number."));
+                          "The size of Input(Velocity) must be equal to "
+                          "Input(Param), but got the size of Input(Velocity) "
+                          "is %d, the size of Input(Param) is %d.",
+                          velocitys.size(), n));
 
     auto velocitys_out = ctx.MultiOutput<framework::Tensor>("VelocityOut");
     PADDLE_ENFORCE_EQ(
         n, velocitys_out.size(),
-        platform::errors::InvalidArgument("Output(VelocityOut) number must be "
-                                          "equal to Input(Param) number."));
+        platform::errors::InvalidArgument(
+            "The size of Output(VelocityOut) must be "
+            "equal to Input(Param), but got the size of Output(VelocityOut) is "
+            "%d, the size of Input(Param) is %d.",
+            velocitys_out.size(), n));
     for (size_t i = 0; i < n; ++i) {
       PADDLE_ENFORCE_EQ(velocitys[i], velocitys_out[i],
                         platform::errors::InvalidArgument(
@@ -126,12 +137,18 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
     if (multi_precision) {
       PADDLE_ENFORCE_EQ(
           n, master_params.size(),
-          platform::errors::InvalidArgument("Input(MasterParam) number must be "
-                                            "equal to Input(Param) number."));
-      PADDLE_ENFORCE_EQ(n, master_params_out.size(),
-                        platform::errors::InvalidArgument(
-                            "Output(MasterParamOut) number must be equal to "
-                            "Input(MasterParam) number."));
+          platform::errors::InvalidArgument(
+              "The size of Input(MasterParam) must be "
+              "equal to Input(Param), but got the size of Input(MasterParam) "
+              "is %d, the size of Input(Param) is %d.",
+              master_params.size(), n));
+      PADDLE_ENFORCE_EQ(
+          n, master_params_out.size(),
+          platform::errors::InvalidArgument(
+              "The size of Output(MasterParamOut) must be equal to "
+              "Input(MasterParam), but got the size of Output(MasterParamOut) "
+              "is %d, the size of Input(Param) is %d.",
+              master_params_out.size(), n));
       for (size_t i = 0; i < n; ++i) {
         PADDLE_ENFORCE_EQ(master_params[i], master_params_out[i],
                           platform::errors::InvalidArgument(
@@ -147,20 +164,61 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
       master_params_out.clear();
     }
 
-    auto lr = ctx.Input<framework::Tensor>("LearningRate");
     auto mu = ctx.Attr<float>("mu");
     auto rescale_grad = ctx.Attr<float>("rescale_grad");
+    auto lrs = ctx.MultiInput<framework::Tensor>("LearningRate");
+    if (lrs.size() != 1) {
+      PADDLE_ENFORCE_EQ(
+          n, lrs.size(),
+          platform::errors::InvalidArgument(
+              "If the size of Input(LearningRate) is not 1, the size of "
+              "Input(LearningRate) must be "
+              "equal to Input(Param), but got the size of Input(LearningRate) "
+              "is %d, the size of Input(Param) is %d.",
+              lrs.size(), n));
+    }
+    auto use_nesterov = ctx.Attr<bool>("use_nesterov");
+    auto regularization_methods =
+        ctx.Attr<std::vector<std::string>>("regularization_method");
+    auto regularization_coeffs =
+        ctx.Attr<std::vector<float>>("regularization_coeff");
+    if (regularization_methods.size() != 0) {
+      PADDLE_ENFORCE_EQ(
+          n, regularization_methods.size(),
+          platform::errors::InvalidArgument(
+              "The size of Attr(regularization_method) must be equal "
+              "to Input(Param), but got the size of "
+              "Attr(regularization_method) is %d, the size of Input(Param) is "
+              "%d.",
+              regularization_methods.size(), n));
+      PADDLE_ENFORCE_EQ(
+          n, regularization_coeffs.size(),
+          platform::errors::InvalidArgument(
+              "The size of Attr(regularization_coeff) must be equal "
+              "to Input(Param), but got the size of Attr(regularization_coeff) "
+              "is %d, the size of Input(Param) is %d.",
+              regularization_coeffs.size(), n));
+    }
+
+    VLOG(5) << "use_nesterov: " << use_nesterov
+            << ", regularization_methods.size(): "
+            << regularization_methods.size()
+            << ", regularization_coeffs.size(): "
+            << regularization_coeffs.size();
+
     using MPType = typename operators::details::MPTypeTrait<T>::Type;
     auto &dev_ctx = ctx.template device_context<DeviceContext>();
+    if (lrs.size() == 1 && use_nesterov == false &&
+        regularization_methods.size() == 0) {
 #define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision)                \
   MergedMomentumKernelParam<T, MPType, kMultiPrecision> kernel_params;       \
   constexpr auto kMaxMergedNum = decltype(kernel_params)::N;                 \
   size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum;               \
   kernel_params.mu = static_cast<MPType>(mu);                                \
   kernel_params.rescale_grad = static_cast<MPType>(rescale_grad);            \
-  kernel_params.lr = lr->data<MPType>();                                     \
+  kernel_params.lr = lrs[0]->data<MPType>();                                 \
   for (size_t i = 0; i < kernel_num; ++i) {                                  \
     size_t start = i * kMaxMergedNum;                                        \
     size_t end = std::min((i + 1) * kMaxMergedNum, n);                       \
@@ -182,14 +240,78 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
     VLOG(10) << "Launch MergedMomentum kernel " << i << " "                  \
              << kernel_params.param_num;                                     \
   }
-
-    if (multi_precision) {
-      PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
+      if (multi_precision) {
+        PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
+      } else {
+        PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false);
+      }
+#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL
     } else {
-      PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false);
-    }
+      for (size_t idx = 0; idx < n; idx++) {
+        RegularizationType regularization_flag =
+            regularization_methods.size() > 0 &&
+                    regularization_methods[idx] == "l2_decay"
+                ? RegularizationType::kL2DECAY
+                : RegularizationType::kNONE;
 
-#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL
+        MPType regularization_coeff = static_cast<MPType>(0.0);
+        if (regularization_coeffs.size() != 0) {
+          regularization_coeff =
+              static_cast<MPType>(regularization_coeffs[idx]);
+        }
+        auto lr_temp = lrs.size() > 1 ? lrs[idx] : lrs[0];
+
+        const MPType *master_in_data =
+            multi_precision ? master_params[idx]->data<MPType>() : nullptr;
+        MPType *master_out_data =
+            multi_precision ? master_params_out[idx]->data<MPType>() : nullptr;
+        if (platform::is_cpu_place(ctx.GetPlace())) {
+          CPUDenseMomentumFunctor<MPType> functor;
+          functor(params[idx], grads[idx], velocitys[idx], lr_temp, mu,
+                  use_nesterov, regularization_flag, regularization_coeff,
+                  params_out[idx], velocitys_out[idx]);
+          VLOG(10) << "Launch MergedMomentum cpu kernel.";
+        } else if (platform::is_gpu_place(ctx.GetPlace())) {
+          platform::ForRange<DeviceContext> for_range(
+              static_cast<const DeviceContext &>(ctx.device_context()),
+              params[idx]->numel());
+#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type)          \
+  DenseMomentumFunctor<T, MPType, __reg_type, __nesterov> functor(             \
+      params[idx]->data<T>(), grads[idx]->data<T>(),                           \
+      velocitys[idx]->data<MPType>(), lr_temp->data<MPType>(), master_in_data, \
+      mu, rescale_grad, params[idx]->numel(), regularization_coeff,            \
+      params_out[idx]->data<T>(), velocitys_out[idx]->data<MPType>(),          \
+      master_out_data);                                                        \
+  for_range(functor);
+          if (use_nesterov) {
+            if (regularization_flag == RegularizationType::kL2DECAY) {
+              PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
+                  UseNesterov, RegularizationType::kL2DECAY);
+              VLOG(10)
+                  << "Launch MergedMomentum gpu kernel use_nesterov kL2DECAY.";
+            } else {
+              PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(UseNesterov,
+                                                    RegularizationType::kNONE);
+              VLOG(10)
+                  << "Launch MergedMomentum gpu kernel use_nesterov kNONE.";
+            }
+          } else {
+            if (regularization_flag == RegularizationType::kL2DECAY) {
+              PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
+                  NoNesterov, RegularizationType::kL2DECAY);
+              VLOG(10)
+                  << "Launch MergedMomentum gpu kernel no_nesterov kL2DECAY.";
+            } else {
+              PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(NoNesterov,
+                                                    RegularizationType::kNONE);
+              VLOG(10) << "Launch MergedMomentum gpu kernel no_nesterov kNONE.";
+            }
+          }
+        }
+      }
+      VLOG(10)
+          << "Launch MergedMomentum kernel with multi_lr and regularization.";
+    }
   }
 };
diff --git a/paddle/fluid/pybind/op_function.h b/paddle/fluid/pybind/op_function.h
index 997cb610faf..324cd4b1b16 100644
--- a/paddle/fluid/pybind/op_function.h
+++ b/paddle/fluid/pybind/op_function.h
@@ -827,7 +827,7 @@ GetVarBaseListFromArgs(const std::string& op_type, const std::string& arg_name,
                        bool dispensable = false) {
   PyObject* list = PyTuple_GET_ITEM(args, arg_idx);
 
-  if (list == nullptr) {
+  if (list == nullptr || list == Py_None) {
     if (!dispensable) {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "%s(): argument '%s' (position %d) must be list of Tensor, but got "
diff --git a/python/paddle/fluid/tests/unittests/test_merged_momentum_op.py b/python/paddle/fluid/tests/unittests/test_merged_momentum_op.py
index 96e458795a3..9bc3bb7ad34 100644
--- a/python/paddle/fluid/tests/unittests/test_merged_momentum_op.py
+++ b/python/paddle/fluid/tests/unittests/test_merged_momentum_op.py
@@ -130,6 +130,130 @@ def run_momentum_op(params,
         return exe.run(main, feed=feed_dict, fetch_list=fetch_list)
 
 
+def run_momentum_op2(params,
+                     grads,
+                     velocitys,
+                     master_params,
+                     learning_rate,
+                     place,
+                     multi_precision,
+                     mu=0.9,
+                     rescale_grad=0.01,
+                     use_merged=False,
+                     use_nesterov=True):
+    assert len(params) == len(grads)
+    assert len(params) == len(velocitys)
+    if multi_precision:
+        assert len(params) == len(master_params)
+    op_type = 'merged_momentum' if use_merged else 'momentum'
+    main = paddle.static.Program()
+    startup = paddle.static.Program()
+    with paddle.static.program_guard(main, startup):
+        helper = LayerHelper(op_type, **locals())
+
+        param_vars = [
+            helper.create_variable(
+                persistable=True, shape=p.shape, dtype=p.dtype) for p in params
+        ]
+        grad_vars = [
+            helper.create_variable(
+                shape=g.shape, dtype=g.dtype) for g in grads
+        ]
+        velocity_vars = [
+            helper.create_variable(
+                persistable=True, shape=v.shape, dtype=v.dtype)
+            for v in velocitys
+        ]
+        lr_var = helper.create_variable(
+            persistable=True,
+            shape=learning_rate.shape,
+            dtype=learning_rate.dtype)
+
+        feed_dict = OrderedDict()
+
+        feed_dict.update(
+            OrderedDict([(p_var.name, p_val)
+                         for p_var, p_val in zip(param_vars, params)]))
+        feed_dict.update(
+            OrderedDict([(v_var.name, v_val)
+                         for v_var, v_val in zip(velocity_vars, velocitys)]))
+        fetch_list = list(feed_dict.keys())
+
+        feed_dict.update(
+            OrderedDict([(g_var.name, g_val)
+                         for g_var, g_val in zip(grad_vars, grads)]))
+        feed_dict.update({lr_var.name: learning_rate})
+
+        if multi_precision:
+            master_param_vars = [
+                helper.create_variable(
+                    persistable=True, shape=p.shape, dtype=p.dtype)
+                for p in master_params
+            ]
+            feed_dict.update(
+                OrderedDict([(mp_var.name, mp_val)
+                             for mp_var, mp_val in zip(master_param_vars,
+                                                       master_params)]))
+            # CPUPlace does not use MasterParam
+            if isinstance(place, paddle.CUDAPlace):
+                fetch_list = fetch_list + [
+                    mp_var.name for mp_var in master_param_vars
+                ]
+        else:
+            master_param_vars = None
+
+        if not use_merged:
+            for i, (p, g,
+                    v) in enumerate(zip(param_vars, grad_vars, velocity_vars)):
+                inputs = {
+                    'Param': p,
+                    'Grad': g,
+                    'Velocity': v,
+                    'LearningRate': lr_var,
+                }
+                outputs = {'ParamOut': p, 'VelocityOut': v}
+                if multi_precision:
+                    inputs['MasterParam'] = master_param_vars[i]
+                    outputs['MasterParamOut'] = master_param_vars[i]
+                attrs = {
+                    'mu': mu,
+                    'multi_precision': multi_precision,
+                    'rescale_grad': rescale_grad,
+                    'use_nesterov': use_nesterov,
+                    'regularization_method': 'l2_decay',
+                    'regularization_coeff': 2.0,
+                }
+                helper.append_op(
+                    type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
+        else:
+            inputs = {
+                'Param': param_vars,
+                'Grad': grad_vars,
+                'Velocity': velocity_vars,
+                'LearningRate': lr_var,
+            }
+            outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars}
+            if multi_precision:
+                inputs['MasterParam'] = master_param_vars
+                outputs['MasterParamOut'] = master_param_vars
+            attrs = {
+                'mu': mu,
+                'multi_precision': multi_precision,
+                'rescale_grad': rescale_grad,
+                'use_nesterov': use_nesterov,
+                'regularization_method':
+                ['l2_decay' for i in range(len(param_vars))],
+                'regularization_coeff': [2.0 for i in range(len(param_vars))],
+            }
+            helper.append_op(
+                type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
+
+    exe = paddle.static.Executor(place)
+    with paddle.static.scope_guard(paddle.static.Scope()):
+        exe.run(startup)
+        return exe.run(main, feed=feed_dict, fetch_list=fetch_list)
+
+
 class TestMergedMomentum(unittest.TestCase):
     def setUp(self):
         paddle.enable_static()
@@ -193,5 +317,78 @@ class TestMergedMomentum(unittest.TestCase):
                 self.check_with_place(place, multi_precision)
 
 
+class TestMergedMomentum2(unittest.TestCase):
+    def setUp(self):
+        paddle.enable_static()
+        self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]]
+        self.seed = 10
+
+    def gen_rand_data(self, shapes, dtype):
+        return [np.random.random(s).astype(dtype) for s in shapes]
+
+    def prepare_data(self, shapes, multi_precision, seed, place):
+        np.random.seed(seed)
+        mp_dtype = np.float32
+        dtype = np.float16 if multi_precision and isinstance(
+            place, paddle.CUDAPlace) else np.float32
+        params = self.gen_rand_data(shapes, dtype)
+        grads = self.gen_rand_data(shapes, dtype)
+        velocitys = self.gen_rand_data(shapes, mp_dtype)
+        learning_rate = self.gen_rand_data([[1]], mp_dtype)[0]
+        if multi_precision:
+            master_params = [p.astype(mp_dtype) for p in params]
+        else:
+            master_params = None
+        return params, grads, velocitys, master_params, learning_rate
+
+    def check_with_place(self, place, multi_precision):
+        params, grads, velocitys, master_params, learning_rate = self.prepare_data(
+            self.shapes, multi_precision, self.seed, place)
+
+        def run_op(use_nesterov, use_merged):
+            # FIXME(zengjinle): CPU Momentum Op does not support rescale_grad
+            rescale_grad = 1.0 if isinstance(place, paddle.CPUPlace) else 0.01
+            return run_momentum_op2(
+                params,
+                grads,
+                velocitys,
+                master_params,
+                learning_rate,
+                place,
+                multi_precision,
+                rescale_grad=rescale_grad,
+                use_merged=use_merged,
+                use_nesterov=use_nesterov)
+
+        outs1 = run_op(use_nesterov=True, use_merged=True)
+        outs2 = run_op(use_nesterov=True, use_merged=False)
+        self.assertEqual(len(outs1), len(outs2))
+        for i, (out1, out2) in enumerate(zip(outs1, outs2)):
+            if isinstance(place, paddle.CUDAPlace):
+                self.assertTrue(np.array_equal(out1, out2))
+            else:
+                self.assertTrue(np.allclose(out1, out2, atol=1e-7))
+
+        outs3 = run_op(use_nesterov=False, use_merged=True)
+        outs4 = run_op(use_nesterov=False, use_merged=False)
+        self.assertEqual(len(outs3), len(outs4))
+        for j, (out3, out4) in enumerate(zip(outs3, outs4)):
+            if isinstance(place, paddle.CUDAPlace):
+                self.assertTrue(np.array_equal(out3, out4))
+            else:
+                self.assertTrue(np.allclose(out3, out4, atol=1e-7))
+
+    def get_places(self):
+        places = [paddle.CPUPlace()]
+        if paddle.is_compiled_with_cuda():
+            places.append(paddle.CUDAPlace(0))
+        return places
+
+    def test_main(self):
+        for multi_precision in [False, True]:
+            for place in self.get_places():
+                self.check_with_place(place, multi_precision)
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab
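
For readers who want to sanity-check the new non-fused path above outside of Paddle, the sketch below restates the per-tensor update it performs (gradient rescaling, optional L2 decay folded into the gradient, optional Nesterov look-ahead) in plain NumPy. It is a minimal illustration under the assumption that the op follows Paddle's standard momentum formulation; the helper name, argument defaults, and float64 accumulation are chosen here for clarity and do not appear in the patch.

    import numpy as np


    def momentum_step_reference(param, grad, velocity, lr, mu,
                                use_nesterov=False,
                                regularization_method="",
                                regularization_coeff=0.0,
                                rescale_grad=1.0):
        # Hypothetical reference helper (not part of the patch): one momentum
        # update for a single parameter tensor, mirroring the per-tensor math
        # of the non-fused merged_momentum path.
        g = rescale_grad * np.asarray(grad, dtype=np.float64)
        if regularization_method == "l2_decay":
            # L2 decay is folded into the gradient before the velocity update.
            g = g + regularization_coeff * param
        velocity_out = mu * velocity + g
        if use_nesterov:
            # Nesterov look-ahead form of the update.
            param_out = param - (g + mu * velocity_out) * lr
        else:
            param_out = param - lr * velocity_out
        return param_out, velocity_out

Applying this helper to each (param, grad, velocity) triple, with either one shared learning rate or one per parameter now that LearningRate is duplicable, should reproduce what a single merged_momentum call computes; TestMergedMomentum2 checks the analogous property by comparing use_merged=True against use_merged=False.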