From 7ddb93d0dc347c89571be5ba43eae5152076a769 Mon Sep 17 00:00:00 2001 From: zhangkeliang Date: Sat, 19 Sep 2020 02:59:15 +0000 Subject: [PATCH] refactor momentum op to combine weight --- .../fluid/operators/optimizers/momentum_op.cc | 5 + .../fluid/operators/optimizers/momentum_op.h | 120 +++++++-- python/paddle/fluid/contrib/__init__.py | 3 + python/paddle/fluid/contrib/optimizer.py | 228 ++++++++++++++++++ .../fluid/tests/unittests/test_momentum_op.py | 167 +++++++++++++ 5 files changed, 501 insertions(+), 22 deletions(-) create mode 100644 python/paddle/fluid/contrib/optimizer.py diff --git a/paddle/fluid/operators/optimizers/momentum_op.cc b/paddle/fluid/operators/optimizers/momentum_op.cc index ccebfeca26c..a1aafc1c0f0 100644 --- a/paddle/fluid/operators/optimizers/momentum_op.cc +++ b/paddle/fluid/operators/optimizers/momentum_op.cc @@ -61,6 +61,11 @@ void MomentumOpMaker::Make() { "(bool, default false) " "Use Nesterov Momentum") .SetDefault(false); + AddAttr("regularization_method", + "(string) regularization_method") + .SetDefault(""); + AddAttr("regularization_coeff", "(float) regularization_coeff") + .SetDefault(1.0); AddComment(R"DOC( Momentum Optimizer. diff --git a/paddle/fluid/operators/optimizers/momentum_op.h b/paddle/fluid/operators/optimizers/momentum_op.h index 10b72524efd..fe7ec6182df 100644 --- a/paddle/fluid/operators/optimizers/momentum_op.h +++ b/paddle/fluid/operators/optimizers/momentum_op.h @@ -29,6 +29,12 @@ using framework::SelectedRows; struct NoNesterov; struct UseNesterov; +enum class RegularizationFlag { + kNONE = 0, + kL1DECAY = 1, // do not need support right now + kL2DECAY = 2, +}; + class MomentumOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override; @@ -100,6 +106,8 @@ class CPUDenseMomentumFunctor { const Tensor* learning_rate; const T mu; const T use_nesterov; + const RegularizationFlag regularization_flag; + const T regularization_coeff; Tensor* param_out; Tensor* velocity_out; @@ -107,13 +115,17 @@ class CPUDenseMomentumFunctor { CPUDenseMomentumFunctor(const Tensor* param, const Tensor* grad, const Tensor* velocity, const Tensor* learning_rate, const T mu, const bool use_nesterov, - Tensor* param_out, Tensor* velocity_out) + const RegularizationFlag regularization_flag, + const T regularization_coeff, Tensor* param_out, + Tensor* velocity_out) : param(param), grad(grad), velocity(velocity), learning_rate(learning_rate), mu(mu), use_nesterov(use_nesterov), + regularization_flag(regularization_flag), + regularization_coeff(regularization_coeff), param_out(param_out), velocity_out(velocity_out) {} @@ -126,11 +138,20 @@ class CPUDenseMomentumFunctor { auto g = framework::EigenVector::Flatten(*grad); auto* lr = learning_rate->data(); - v_out = v * mu + g; - if (use_nesterov) { - p_out = p - (g + v_out * mu) * lr[0]; + if (regularization_flag == RegularizationFlag::kL2DECAY) { + v_out = v * mu + p * regularization_coeff + g; + if (use_nesterov) { + p_out = p - (p * regularization_coeff + g + v_out * mu) * lr[0]; + } else { + p_out = p - lr[0] * v_out; + } } else { - p_out = p - lr[0] * v_out; + v_out = v * mu + g; + if (use_nesterov) { + p_out = p - (g + v_out * mu) * lr[0]; + } else { + p_out = p - lr[0] * v_out; + } } } }; @@ -152,11 +173,14 @@ class DenseMomentumFunctor { const int64_t num_; T* p_out_; T* v_out_; + const RegularizationFlag regularization_flag; + const T regularization_coeff; public: DenseMomentumFunctor(const T* p, const T* g, const T* v, const T* learning_rate, const T mu, const int64_t num, - T* p_out, 
T* v_out) + const RegularizationFlag regularization_flag, + const T regularization_coeff, T* p_out, T* v_out) : p_(p), g_(g), v_(v), @@ -164,13 +188,20 @@ class DenseMomentumFunctor { mu_(mu), num_(num), p_out_(p_out), - v_out_(v_out) {} + v_out_(v_out), + regularization_flag(regularization_flag), + regularization_coeff(regularization_coeff) {} inline HOSTDEVICE void operator()(size_t i) const { // put memory access in register const T p = p_[i]; - const T g = g_[i]; + T g = g_[i]; const T lr = lr_[0]; const T v = v_[i]; + + g = regularization_flag == RegularizationFlag::kL2DECAY + ? g + regularization_coeff * p + : g; + T v_out = v * mu_ + g; T p_out = p - (g + v_out * mu_) * lr; // write reigster to memory @@ -190,11 +221,14 @@ class DenseMomentumFunctor { const int64_t num_; T* p_out_; T* v_out_; + const RegularizationFlag regularization_flag; + const T regularization_coeff; public: DenseMomentumFunctor(const T* p, const T* g, const T* v, const T* learning_rate, const T mu, const int64_t num, - T* p_out, T* v_out) + const RegularizationFlag regularization_flag, + const T regularization_coeff, T* p_out, T* v_out) : p_(p), g_(g), v_(v), @@ -202,13 +236,20 @@ class DenseMomentumFunctor { mu_(mu), num_(num), p_out_(p_out), - v_out_(v_out) {} + v_out_(v_out), + regularization_flag(regularization_flag), + regularization_coeff(regularization_coeff) {} inline HOSTDEVICE void operator()(size_t i) const { // put memory access in register const T p = p_[i]; - const T g = g_[i]; + T g = g_[i]; const T lr = lr_[0]; const T v = v_[i]; + + g = regularization_flag == RegularizationFlag::kL2DECAY + ? g + regularization_coeff * p + : g; + T v_out = v * mu_ + g; T p_out = p - lr * v_out; // write reigster to memory @@ -233,11 +274,15 @@ class SparseMomentumFunctor { const int64_t row_height_; T* p_out_; T* v_out_; + const RegularizationFlag regularization_flag; + const T regularization_coeff; public: SparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr, const T mu, const int64_t* rows, int64_t row_numel, - int64_t row_height, T* p_out, T* v_out) + int64_t row_height, + const RegularizationFlag regularization_flag, + const T regularization_coeff, T* p_out, T* v_out) : p_(p), g_(g), v_(v), @@ -247,7 +292,9 @@ class SparseMomentumFunctor { row_numel_(row_numel), row_height_(row_height), p_out_(p_out), - v_out_(v_out) {} + v_out_(v_out), + regularization_flag(regularization_flag), + regularization_coeff(regularization_coeff) {} inline HOSTDEVICE void operator()(size_t i) { auto row_idx = @@ -258,6 +305,11 @@ class SparseMomentumFunctor { const T p = p_[i]; const T lr = lr_[0]; const T v = v_[i]; + + g = regularization_flag == RegularizationFlag::kL2DECAY + ? 
g + regularization_coeff * p + : g; + T v_out = v * mu_ + g; T p_out = p - (g + v_out * mu_) * lr; // write reigster to memory @@ -279,11 +331,15 @@ class SparseMomentumFunctor { const int64_t row_height_; T* p_out_; T* v_out_; + const RegularizationFlag regularization_flag; + const T regularization_coeff; public: SparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr, const T mu, const int64_t* rows, int64_t row_numel, - int64_t row_height, T* p_out, T* v_out) + int64_t row_height, + const RegularizationFlag regularization_flag, + const T regularization_coeff, T* p_out, T* v_out) : p_(p), g_(g), v_(v), @@ -293,7 +349,9 @@ class SparseMomentumFunctor { row_numel_(row_numel), row_height_(row_height), p_out_(p_out), - v_out_(v_out) {} + v_out_(v_out), + regularization_flag(regularization_flag), + regularization_coeff(regularization_coeff) {} inline HOSTDEVICE void operator()(size_t i) { auto row_idx = @@ -304,6 +362,11 @@ class SparseMomentumFunctor { const T p = p_[i]; const T lr = lr_[0]; const T v = v_[i]; + + g = regularization_flag == RegularizationFlag::kL2DECAY + ? g + regularization_coeff * p + : g; + T v_out = v * mu_ + g; T p_out = p - v_out * lr; // write reigster to memory @@ -316,6 +379,16 @@ template class MomentumOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + std::string regularization_method = + ctx.Attr("regularization_method"); + T regularization_coeff = + static_cast(ctx.Attr("regularization_coeff")); + RegularizationFlag regularization_flag{ + RegularizationFlag::kNONE}; // disable regularization + if (regularization_method == "l2_decay") { + regularization_flag = RegularizationFlag::kL2DECAY; + } + T mu = static_cast(ctx.Attr("mu")); bool use_nesterov = ctx.Attr("use_nesterov"); @@ -324,6 +397,7 @@ class MomentumOpKernel : public framework::OpKernel { auto param_out = ctx.Output("ParamOut"); auto* velocity = ctx.Input("Velocity"); auto velocity_out = ctx.Output("VelocityOut"); + param_out->mutable_data(ctx.GetPlace()); velocity_out->mutable_data(ctx.GetPlace()); @@ -331,9 +405,9 @@ class MomentumOpKernel : public framework::OpKernel { if (grad_var->IsType()) { auto grad = ctx.Input("Grad"); if (platform::is_cpu_place(ctx.GetPlace())) { - CPUDenseMomentumFunctor functor(param, grad, velocity, learning_rate, - mu, use_nesterov, param_out, - velocity_out); + CPUDenseMomentumFunctor functor( + param, grad, velocity, learning_rate, mu, use_nesterov, + regularization_flag, regularization_coeff, param_out, velocity_out); functor(); } else if (platform::is_gpu_place(ctx.GetPlace())) { platform::ForRange for_range( @@ -342,16 +416,16 @@ class MomentumOpKernel : public framework::OpKernel { if (use_nesterov) { DenseMomentumFunctor functor( param->data(), grad->data(), velocity->data(), - learning_rate->data(), mu, param->numel(), - param_out->mutable_data(ctx.GetPlace()), + learning_rate->data(), mu, param->numel(), regularization_flag, + regularization_coeff, param_out->mutable_data(ctx.GetPlace()), velocity_out->mutable_data(ctx.GetPlace())); for_range(functor); } else { DenseMomentumFunctor functor( param->data(), grad->data(), velocity->data(), - learning_rate->data(), mu, param->numel(), - param_out->mutable_data(ctx.GetPlace()), + learning_rate->data(), mu, param->numel(), regularization_flag, + regularization_coeff, param_out->mutable_data(ctx.GetPlace()), velocity_out->mutable_data(ctx.GetPlace())); for_range(functor); } @@ -384,6 +458,7 @@ class MomentumOpKernel : public 
framework::OpKernel { param->data(), merged_grad->value().data(), velocity->data(), learning_rate->data(), mu, rows, row_numel, static_cast(merged_grad->rows().size()), + regularization_flag, regularization_coeff, param_out->mutable_data(ctx.GetPlace()), velocity_out->mutable_data(ctx.GetPlace())); for_range(functor); @@ -393,6 +468,7 @@ class MomentumOpKernel : public framework::OpKernel { param->data(), merged_grad->value().data(), velocity->data(), learning_rate->data(), mu, rows, row_numel, static_cast(merged_grad->rows().size()), + regularization_flag, regularization_coeff, param_out->mutable_data(ctx.GetPlace()), velocity_out->mutable_data(ctx.GetPlace())); for_range(functor); diff --git a/python/paddle/fluid/contrib/__init__.py b/python/paddle/fluid/contrib/__init__.py index 5ae06cb1a0f..9f6b61bd3da 100644 --- a/python/paddle/fluid/contrib/__init__.py +++ b/python/paddle/fluid/contrib/__init__.py @@ -35,6 +35,8 @@ from . import mixed_precision from .mixed_precision import * from . import layers from .layers import * +from . import optimizer +from .optimizer import * __all__ = [] __all__ += decoder.__all__ @@ -46,3 +48,4 @@ __all__ += utils.__all__ __all__ += extend_optimizer.__all__ __all__ += ['mixed_precision'] __all__ += layers.__all__ +__all__ += optimizer.__all__ diff --git a/python/paddle/fluid/contrib/optimizer.py b/python/paddle/fluid/contrib/optimizer.py new file mode 100644 index 00000000000..8297ad880e7 --- /dev/null +++ b/python/paddle/fluid/contrib/optimizer.py @@ -0,0 +1,228 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ..optimizer import Optimizer +from ..regularizer import L1DecayRegularizer +from ..regularizer import L2DecayRegularizer +from .. import framework +from .. import core +from ..framework import program_guard +from ..clip import append_gradient_clip_ops + +__all__ = ['Momentum'] + + +class Momentum(Optimizer): + """ + + Simple Momentum optimizer with velocity state + + This optimizer has a flag for Nestrov Momentum. + + The update equations are as follows: + + .. math:: + + & velocity = mu * velocity + gradient + + & if (use\_nesterov): + + &\quad param = param - (gradient + mu * velocity) * learning\_rate + + & else: + + &\quad param = param - learning\_rate * velocity + + Parameters: + learning_rate (float|Variable): The learning rate used to update parameters. \ + Can be a float value or a Variable with one float value as data element. + momentum (float): Momentum factor + parameter_list (Iterable, optional): Iterable of ``Variable`` names to update to minimize ``loss``. \ + This parameter is required in dygraph mode. \ + The default value is None in static mode, at this time all parameters will be updated. + use_nesterov (bool, optional): Enables Nesterov momentum, default is false. + regularization (WeightDecayRegularizer, optional): The strategy of regularization. There are two method: \ + :ref:`api_fluid_regularizer_L1Decay` , :ref:`api_fluid_regularizer_L2Decay` . 
If a parameter has set \ + regularizer using :ref:`api_fluid_ParamAttr` already, the regularization setting here in optimizer will be \ + ignored for this parameter. Otherwise, the regularization setting here in optimizer will take effect. \ + Default None, meaning there is no regularization. + grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of + some derived class of ``GradientClipBase`` . There are three cliping strategies + ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , + :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping. + name (str, optional): This parameter is used by developers to print debugging information. \ + For details, please refer to :ref:`api_guide_Name`. Default is None. + + Examples: + .. code-block:: python + + import paddle + import paddle.fluid as fluid + import numpy as np + + place = fluid.CPUPlace() + main = fluid.Program() + with fluid.program_guard(main): + x = fluid.layers.data(name='x', shape=[13], dtype='float32') + y = fluid.layers.data(name='y', shape=[1], dtype='float32') + y_predict = fluid.layers.fc(input=x, size=1, act=None) + cost = fluid.layers.square_error_cost(input=y_predict, label=y) + avg_cost = fluid.layers.mean(cost) + + moment_optimizer = fluid.optimizer.MomentumOptimizer(learning_rate=0.001, momentum=0.9) + moment_optimizer.minimize(avg_cost) + + fetch_list = [avg_cost] + train_reader = paddle.batch( + paddle.dataset.uci_housing.train(), batch_size=1) + feeder = fluid.DataFeeder(place=place, feed_list=[x, y]) + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + for data in train_reader(): + exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list) + + """ + _velocity_acc_str = "velocity" + + def __init__(self, + learning_rate, + momentum, + parameter_list=None, + use_nesterov=False, + regularization=None, + grad_clip=None, + name=None): + assert learning_rate is not None + assert momentum is not None + super(Momentum, self).__init__( + learning_rate=learning_rate, + parameter_list=parameter_list, + regularization=regularization, + grad_clip=grad_clip, + name=name) + self.type = "momentum" + self._momentum = momentum + self._use_nesterov = bool(use_nesterov) + self._regularization_method = "" + self._regularization_coef = 0 + if (isinstance(regularization, L2DecayRegularizer)): + self._regularization_method = "l2_decay" + self._regularization_coef = regularization._regularization_coeff + if (isinstance(regularization, L1DecayRegularizer)): + self._regularization_method = "l1_decay" + self._regularization_coef = regularization._regularization_coeff + + def _create_accumulators(self, block, parameters): + assert isinstance(block, framework.Block) + + for p in parameters: + self._add_accumulator(self._velocity_acc_str, p) + + def _append_optimize_op(self, block, param_and_grad): + assert isinstance(block, framework.Block) + + velocity_acc = self._get_accumulator(self._velocity_acc_str, + param_and_grad[0]) + lr = self._create_param_lr(param_and_grad) + + if framework.in_dygraph_mode(): + _, _ = core.ops.momentum(param_and_grad[0], param_and_grad[1], + velocity_acc, lr, param_and_grad[0], + velocity_acc, 'mu', self._momentum, + 'use_nesterov', self._use_nesterov) + return None + + attrs = { + "mu": self._momentum, + "use_nesterov": self._use_nesterov, + "regularization_method": self._regularization_method, + "regularization_coeff": self._regularization_coef + } + inputs = { + "Param": 
[param_and_grad[0]], + "Grad": [param_and_grad[1]], + "Velocity": [velocity_acc], + "LearningRate": [lr] + } + + outputs = { + "ParamOut": [param_and_grad[0]], + "VelocityOut": [velocity_acc] + } + # create the momentum optimize op + momentum_op = block.append_op( + type=self.type, + inputs=inputs, + outputs=outputs, + attrs=attrs, + stop_gradient=True) + + return momentum_op + + def apply_gradients(self, params_grads): + """ + Second part of `minimize`, appending optimization operators for + given `params_grads` pairs. + + Args: + params_grads (list): list of (param, grad) pair to do optimization. + + Returns: + list: A list of operators appended to the current program. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + loss = network() + optimizer = fluid.optimizer.SGD(learning_rate=0.1) + params_grads = optimizer.backward(loss) + # you may append operations for params_grads here + # ... + optimizer.apply_gradients(params_grads) + """ + + params_grads = sorted(params_grads, key=lambda x: x[0].name) + + # 'optimizer(grad_clip)' or 'set_gradient_clip' + if self._grad_clip is not None: + params_grads = self._grad_clip(params_grads) + else: + params_grads = append_gradient_clip_ops(params_grads) + + optimize_ops = self._create_optimization_pass(params_grads) + return optimize_ops + + def apply_optimize(self, loss, startup_program, params_grads): + """ + Second part of `minimize`, appending optimization operators for + given `params_grads` pairs. + Args: + loss (Variable): loss variable to run optimizations. + startup_program (Program): startup_program for initializing parameters + in `parameter_list`. + params_grads (list): list of (param, grad) pair to do optimization. + Returns: + list: A list of operators appended to the current program. 
+ """ + if framework.in_dygraph_mode(): + with program_guard(framework.default_main_program(), + framework.default_startup_program()): + if self._grad_clip is not None: + params_grads = self._grad_clip(params_grads) + optimize_ops = self._create_optimization_pass(params_grads) + else: + program = loss.block.program + with program_guard(program, startup_program): + optimize_ops = self.apply_gradients(params_grads) + return optimize_ops diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py index a535ef5e603..c09992529bc 100644 --- a/python/paddle/fluid/tests/unittests/test_momentum_op.py +++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py @@ -279,5 +279,172 @@ class TestMomentumV2(unittest.TestCase): self.assertRaises(ValueError, paddle.optimizer.Momentum, momentum=None) +class TestMomentumOpWithDecay(OpTest): + def setUp(self): + self.op_type = "momentum" + self.dtype = np.float32 + self.use_nesterov = True + self.regularization_method = 'l2_decay' + self.regularization_coeff = 0.9 + self.init_config() + + param = np.random.random((123, 321)).astype(self.dtype) + grad = np.random.random((123, 321)).astype(self.dtype) + velocity = np.zeros((123, 321)).astype(self.dtype) + learning_rate = np.array([0.001]).astype(self.dtype) + mu = 0.0001 + use_nesterov = self.use_nesterov + regularization_method = self.regularization_method + regularization_coeff = self.regularization_coeff + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Velocity': velocity, + 'LearningRate': learning_rate + } + + self.attrs = { + 'mu': mu, + 'use_nesterov': use_nesterov, + 'regularization_method': regularization_method, + 'regularization_coeff': regularization_coeff + } + + param_decay = regularization_coeff * param + grad_new = grad + param_decay + grad = grad_new + + velocity_out = mu * velocity + grad + if use_nesterov: + param_out = param - grad * learning_rate - \ + velocity_out * mu * learning_rate + else: + param_out = param - learning_rate * velocity_out + + self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out} + + def init_config(self): + pass + + def test_check_output(self): + self.check_output() + + +class TestMomentumOpWithDecayFP16(TestMomentumOpWithDecay): + def init_config(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output(atol=1e-3) + + +class TestMomentumOpWithDecay2(OpTest): + def init_config(self): + self.use_nesterov = False + + +class TestSparseMomentumOpWithDecay(unittest.TestCase): + def setUp(self): + self.use_nesterov = False + + def check_with_place(self, place): + self.init_kernel() + scope = core.Scope() + # create and initialize Grad Variable + height = 10 + rows = [0, 4, 7] + row_numel = 12 + mu = 1.0 + use_nesterov = self.use_nesterov + regularization_method = 'l2_decay' + regularization_coeff = 0.9 + + # create and initialize Param Variable + param = scope.var('Param').get_tensor() + param_array = np.full((height, row_numel), 5.0).astype("float32") + param.set(param_array, place) + param_out = scope.var("ParamOut").get_tensor() + param_out_array = np.full((height, row_numel), 0.0).astype("float32") + param_out.set(param_out_array, place) + + grad_selected_rows = scope.var('Grad').get_selected_rows() + grad_selected_rows.set_height(height) + grad_selected_rows.set_rows(rows) + grad_np_array = np.ones((len(rows), row_numel)).astype("float32") + grad_np_array[0, 0] = 2.0 + grad_np_array[2, 8] = 4.0 + grad_tensor = grad_selected_rows.get_tensor() + 
grad_tensor.set(grad_np_array, place) + + velocity = scope.var('Velocity').get_tensor() + velocity_np_array = np.ones((height, row_numel)).astype("float32") + velocity.set(velocity_np_array, place) + velocity_out = scope.var('VelocityOut').get_tensor() + velocity_out_np_array = np.full((height, row_numel), + 0.0).astype("float32") + velocity_out.set(velocity_out_np_array, place) + + # create and initialize LeraningRate Variable + lr = scope.var('LearningRate').get_tensor() + lr_array = np.full((1), 2.0).astype("float32") + lr.set(lr_array, place) + + # create and run operator + op = Operator( + "momentum", + Param='Param', + Grad='Grad', + Velocity='Velocity', + ParamOut='ParamOut', + VelocityOut='VelocityOut', + LearningRate='LearningRate', + mu=mu, + use_nesterov=use_nesterov, + regularization_method=regularization_method, + regularization_coeff=regularization_coeff) + op.run(scope, place) + + # get and compare result + param_out_np_array = np.array(param_out) + velocity_out_np_array = np.array(velocity_out) + + # TODO(dzh): add a more suitable general numpy interface + # for sparse update. + _grad_np_array = np.full((height, row_numel), 0.0).astype("float32") + for i in range(len(rows)): + _grad_np_array[rows[i]] = grad_np_array[i] + + _param = param_array + + _param_decay = regularization_coeff * _param + _grad_np_array_new = _grad_np_array + _param_decay + _grad_np_array = _grad_np_array_new + + _velocity_out = mu * velocity_np_array + _grad_np_array + if use_nesterov: + _param_out = _param - (_grad_np_array + _velocity_out * mu + ) * lr_array + else: + _param_out = _param - lr_array * _velocity_out + self.assertTrue((_velocity_out == velocity_out_np_array).all()) + self.assertTrue((_param_out == param_out_np_array).all()) + + def init_kernel(self): + pass + + def test_sparse_momentum(self): + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + for place in places: + self.check_with_place(place) + + +class TestSparseMomentumOpWithDecay2(TestSparseMomentumOpWithDecay): + def init_kernel(self): + self.use_nesterov = True + + if __name__ == "__main__": unittest.main() -- GitLab
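
Note: the fused update rule added by this patch (and mirrored by the reference computation in the new unit tests) folds L2 decay into the gradient before the velocity update. Below is a minimal NumPy sketch of that math, for illustration only; the helper name and signature are not part of the patch.

    import numpy as np

    def momentum_with_l2_decay(param, grad, velocity, lr, mu,
                               regularization_coeff=0.0, use_nesterov=False):
        # L2 decay is folded into the gradient first, matching the kernel's
        # g = g + regularization_coeff * p branch taken when
        # regularization_method == 'l2_decay'.
        grad = grad + regularization_coeff * param
        velocity_out = mu * velocity + grad
        if use_nesterov:
            param_out = param - (grad + mu * velocity_out) * lr
        else:
            param_out = param - lr * velocity_out
        return param_out, velocity_out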
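Usage sketch for the new contrib optimizer: passing an L2 regularizer to fluid.contrib.optimizer.Momentum is what sets the new regularization_method / regularization_coeff attributes on the single momentum op, and the overridden apply_gradients no longer appends separate regularization ops for it. This example is adapted from the docstring in the patch; the L2Decay keyword argument name is assumed from fluid.regularizer and is not shown in the diff itself.

    import paddle.fluid as fluid

    main = fluid.Program()
    startup = fluid.Program()
    with fluid.program_guard(main, startup):
        x = fluid.layers.data(name='x', shape=[13], dtype='float32')
        y = fluid.layers.data(name='y', shape=[1], dtype='float32')
        y_predict = fluid.layers.fc(input=x, size=1, act=None)
        cost = fluid.layers.square_error_cost(input=y_predict, label=y)
        avg_cost = fluid.layers.mean(cost)

        # An L2DecayRegularizer maps to regularization_method='l2_decay' and
        # regularization_coeff=0.1 on the emitted momentum op.
        optimizer = fluid.contrib.optimizer.Momentum(
            learning_rate=0.001,
            momentum=0.9,
            use_nesterov=True,
            regularization=fluid.regularizer.L2Decay(regularization_coeff=0.1))
        optimizer.minimize(avg_cost)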