diff --git a/paddle/fluid/operators/optimizers/adam_op.cc b/paddle/fluid/operators/optimizers/adam_op.cc
index 130e10a1f8de307a5f26eca1cee157a8cdf17414..d4355c89f31cc35c2ed32966b3bdca50d7a9eb07 100644
--- a/paddle/fluid/operators/optimizers/adam_op.cc
+++ b/paddle/fluid/operators/optimizers/adam_op.cc
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/optimizers/adam_op.h"
 #include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/operators/optimizers/adamw_op.h"
 
 namespace paddle {
 namespace operators {
@@ -230,11 +231,30 @@ $$
 )DOC");
   }
 };
+
+class AdamWOpMaker : public AdamOpMaker {
+ public:
+  void Make() {
+    AdamOpMaker::Make();
+    AddAttr<float>("coeff",
+                   "(float, default 0.01) "
+                   "coefficient of the weight decay")
+        .SetDefault(0.01f);
+    AddAttr<bool>("with_decay",
+                  "(bool, default false) "
+                  "whether to do weight decay")
+        .SetDefault(false);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(adam, ops::AdamOp, ops::AdamOpMaker);
+
+REGISTER_OP_WITHOUT_GRADIENT(adamw, ops::AdamWOp, ops::AdamWOpMaker);
+
 REGISTER_OP_CPU_KERNEL(
     adam, ops::AdamOpKernel<paddle::platform::CPUDeviceContext, float>,
     ops::AdamOpKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/optimizers/adam_op_npu.cc b/paddle/fluid/operators/optimizers/adam_op_npu.cc
index d0de480c1a0ccce9d235d2c5c836bac8d41095c1..1169bc12ac230c0601ebe502a7eb6c3866f25381 100644
--- a/paddle/fluid/operators/optimizers/adam_op_npu.cc
+++ b/paddle/fluid/operators/optimizers/adam_op_npu.cc
@@ -225,6 +225,79 @@ class AdamNPUKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename T>
+class AdamWNPUKernel : public AdamNPUKernel<platform::NPUDeviceContext, T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    VLOG(3) << "NPU AdamW Kernel";
+    bool skip_update = false;
+    if (ctx.HasInput("SkipUpdate")) {
+      VLOG(3) << "Has SkipUpdate";
+      auto* skip_update_tensor = ctx.Input<framework::Tensor>("SkipUpdate");
+      PADDLE_ENFORCE_EQ(skip_update_tensor->numel(), 1,
+                        platform::errors::InvalidArgument(
+                            "Input(SkipUpdate) size must be 1, but get %d",
+                            skip_update_tensor->numel()));
+      std::vector<bool> skip_update_vec;
+      TensorToVector(*skip_update_tensor, ctx.device_context(),
+                     &skip_update_vec);
+      skip_update = skip_update_vec[0];
+    }
+    VLOG(3) << "Skip update: " << skip_update;
+    bool with_decay = ctx.Attr<bool>("with_decay");
+    if (!skip_update && with_decay) {
+      float coeff = ctx.Attr<float>("coeff");
+      auto* lr = ctx.Input<framework::LoDTensor>("LearningRate");
+
+      auto place = ctx.GetPlace();
+
+      auto stream =
+          ctx.template device_context<platform::NPUDeviceContext>()
+              .stream();
+
+      Tensor one(framework::proto::VarType::FP32);
+      Tensor decay(framework::proto::VarType::FP32);
+      Tensor tmp(framework::proto::VarType::FP32);
+
+      tmp.mutable_data<float>({1}, place);
+      one.mutable_data<float>({1}, place);
+      decay.mutable_data<float>({1}, place);
+
+      FillNpuTensorWithConstant<float>(&one, 1.0f);
+      framework::NPUAttributeMap attr_input = {{"value", coeff}};
+
+      // tmp = lr * coeff
+      const auto& runner1 = NpuOpRunner("Muls", {*lr}, {tmp}, attr_input);
+      runner1.Run(stream);
+
+      // decay = 1 - lr * coeff
+      const auto& runner2 = NpuOpRunner("Sub", {one, tmp}, {decay}, {});
+      runner2.Run(stream);
+
+      if (ctx.HasInput("MasterParam")) {
+        PADDLE_THROW(platform::errors::Unimplemented(
+            "Master Param is not supported on NPU"));
+      } else {
+        auto* param_out = ctx.Output<framework::LoDTensor>("ParamOut");
+        param_out->mutable_data<T>(ctx.GetPlace());
+
+        const auto* param_var = ctx.InputVar("Param");
+        PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
+                          platform::errors::InvalidArgument(
+                              "The Var(%s)'s type should be LoDTensor, "
+                              "but the received is %s",
+                              ctx.InputNames("Param").front(),
+                              framework::ToTypeName(param_var->Type())));
+        auto* param = ctx.Input<framework::LoDTensor>("Param");
+
+        // ParamOut = Param * (1 - lr * coeff)
+        const auto& runner =
+            NpuOpRunner("Mul", {*param, decay},
+                        {*const_cast<framework::LoDTensor*>(param)}, {});
+        runner.Run(stream);
+      }
+    }
+    AdamNPUKernel<platform::NPUDeviceContext, T>::Compute(ctx);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -234,3 +307,6 @@ REGISTER_OP_NPU_KERNEL(
     adam, ops::AdamNPUKernel<paddle::platform::NPUDeviceContext, float>,
     ops::AdamNPUKernel<paddle::platform::NPUDeviceContext,
                        paddle::platform::float16>);
+
+REGISTER_OP_NPU_KERNEL(adamw, ops::AdamWNPUKernel<float>,
+                       ops::AdamWNPUKernel<paddle::platform::float16>);
diff --git a/paddle/fluid/operators/optimizers/adamw_op.cc b/paddle/fluid/operators/optimizers/adamw_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c2111d53f3a45fedec31ff1abcd9263a26145b98
--- /dev/null
+++ b/paddle/fluid/operators/optimizers/adamw_op.cc
@@ -0,0 +1,20 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/optimizers/adamw_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_CPU_KERNEL(
+    adamw, ops::AdamWOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::AdamWOpKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/optimizers/adamw_op.h b/paddle/fluid/operators/optimizers/adamw_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..3301bc4808e3a8362fdf73d2f4cef35e6733693c
--- /dev/null
+++ b/paddle/fluid/operators/optimizers/adamw_op.h
@@ -0,0 +1,105 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/operators/optimizers/adam_op.h"
+
+namespace paddle {
+namespace operators {
+
+class AdamWOp : public AdamOp {
+  using AdamOp::AdamOp;
+};
+
+struct CPUAdamW;
+
+template <typename T, typename Flavour>
+class AdamWFunctor;
+
+template <typename T>
+class AdamWFunctor<T, CPUAdamW> {
+ private:
+  const float coeff_;
+  const float learning_rate_;
+  T* param_;
+
+ public:
+  AdamWFunctor(const float& coeff, const float& learning_rate, T* param)
+      : coeff_(coeff), learning_rate_(learning_rate), param_(param) {}
+
+  inline HOSTDEVICE void operator()(size_t numel) const {
+    Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> param{
+        param_, static_cast<Eigen::Index>(numel)};
+    // Calculation: decoupled weight decay, param *= (1 - lr * coeff)
+    param = param * (1.0f - learning_rate_ * coeff_);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class AdamWOpKernel : public AdamOpKernel<DeviceContext, T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const auto* param_var = ctx.InputVar("Param");
+    PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
+                      platform::errors::InvalidArgument(
+                          "The Var(%s)'s type should be LoDTensor, "
+                          "but the received is %s",
+                          ctx.InputNames("Param").front(),
+                          framework::ToTypeName(param_var->Type())));
+
+    using paddle::framework::LoDTensor;
+    bool skip_update = false;
+    // TODO(liupeng):
+    if (ctx.HasInput("SkipUpdate")) {
+      VLOG(3) << "Has SkipUpdate";
+      auto* skip_update_tensor = ctx.Input<framework::Tensor>("SkipUpdate");
+      PADDLE_ENFORCE_EQ(skip_update_tensor->numel(), 1,
+                        platform::errors::InvalidArgument(
+                            "Input(SkipUpdate) size must be 1, but get %d",
+                            skip_update_tensor->numel()));
+      std::vector<bool> skip_update_vec;
+      TensorToVector(*skip_update_tensor, ctx.device_context(),
+                     &skip_update_vec);
+      skip_update = skip_update_vec[0];
+    }
+    VLOG(3) << "Skip update: " << skip_update;
+    bool with_decay = ctx.Attr<bool>("with_decay");
+
+    if (skip_update || !with_decay) {
+      AdamOpKernel<DeviceContext, T>::Compute(ctx);
+      return;
+    }
+
+    float coeff = ctx.Attr<float>("coeff");
+    auto* lr = ctx.Input<LoDTensor>("LearningRate");
+
+    LoDTensor* param;
+
+    if (ctx.HasInput("MasterParam")) {
+      // TODO(liupeng): master
+      param = const_cast<LoDTensor*>(ctx.Input<LoDTensor>("MasterParam"));
+    } else {
+      param = const_cast<LoDTensor*>(ctx.Input<LoDTensor>("Param"));
+    }
+
+    // AdamWFunctor(const float& coeff, const float& learning_rate, T* param)
+    AdamWFunctor<T, CPUAdamW> functor(coeff, *lr->data<float>(),
+                                      param->data<T>());
+    functor(param->numel());
+
+    AdamOpKernel<DeviceContext, T>::Compute(ctx);
+  }
+};
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
index 07272404768ff781fce7d18d634a56129bd0fb1c..e939ac765b2c9e30e84f1358c0cca0f3b660db49 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
@@ -197,7 +197,6 @@ class FP16Utils(object):
             if op.type == "update_loss_scaling":
                 update_loss_scaling_op_idx = idx
                 inf_var_name = op.desc.input('FoundInfinite')[0]
-                op._rename_input(inf_var_name, inf_var_name + "@GLOBAL_WORLD")
                 break
 
         # not use amp
@@ -246,10 +245,10 @@ class FP16Utils(object):
            update_loss_scaling_op_idx,
            type='cast',
            inputs={'X': inf_var_int32},
-            outputs={'Out': inf_var_global},
+            outputs={'Out': inf_var},
            attrs={
                "in_dtype": inf_var_int32.dtype,
-                "out_dtype": inf_var_global.dtype,
+                "out_dtype": inf_var.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
        update_loss_scaling_op_idx += 1
diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator.py b/python/paddle/fluid/contrib/mixed_precision/decorator.py
index 22eb2d20f3db7f0c5124afdbdc1d5b1b7d6590e4..563c394c9fbfe82e491170a65fbf273ac6f5104f 100644
--- a/python/paddle/fluid/contrib/mixed_precision/decorator.py
+++ b/python/paddle/fluid/contrib/mixed_precision/decorator.py
@@ -399,12 +399,18 @@ class OptimizerWithMixedPrecision(object):
                 self._decr_ratio,
                 name="update_loss_scaling")
 
             # Pass found_inf to adam, to skip update for not only param, but also momentum and beta_pow
-            if isinstance(self._optimizer, paddle.fluid.optimizer.Adam):
+            # With fleet, optimizers are nested and the real optimizer set by user is the innermost one.
+            real_optimizer = self._optimizer
+            while hasattr(real_optimizer, "inner_opt"):
+                real_optimizer = real_optimizer.inner_opt
+            if isinstance(real_optimizer, (paddle.fluid.optimizer.Adam,
+                                           paddle.optimizer.AdamW)):
                 # NOTE(zhiqiu): Since found_inf needs to be on cpu in adam op, we
                 # copy it in advance to avoid multiple time copies.
-                found_inf = paddle.tensor.creation._memcpy(found_inf,
-                                                           paddle.CPUPlace())
-                self._optimizer._set_auxiliary_var('found_inf', found_inf)
+                with self._train_program._optimized_guard([]):
+                    found_inf = paddle.tensor.creation._memcpy(found_inf,
+                                                               paddle.CPUPlace())
+                real_optimizer._set_auxiliary_var('found_inf', found_inf)
            optimize_ops = self._optimizer.apply_gradients(params_grads)
 
        return optimize_ops
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 3cb6d24c86faf2faeee7f81e41a63966550024a0..9e87681c4bef306f7504f8678004fbdec9b7a12e 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -4661,12 +4661,8 @@ class PipelineOptimizer(object):
                 op._set_attr(self._op_device_key, f"{self._device}:all")
             else:
                 other_known_ops = [
-                    'update_loss_scaling',
-                    'reduce_any',
-                    'concat',
-                    'sum',
-                    'check_finite_and_unscale',
-                    'alloc_float_status',
+                    'update_loss_scaling', 'reduce_any', 'concat', 'sum',
+                    'check_finite_and_unscale', 'alloc_float_status', 'memcpy'
                 ]
                 assert op.type in other_known_ops, "For other ops without " \
                     "op_device set, they must be one of {}, but it " \
diff --git a/python/paddle/fluid/tests/unittests/npu/test_adamw_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_adamw_op_npu.py
new file mode 100644
index 0000000000000000000000000000000000000000..78ee572d11fee644dbc9b0d0b9fde969cff0b22c
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/npu/test_adamw_op_npu.py
@@ -0,0 +1,250 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import unittest
+import sys
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+from test_adam_op import adamw_step
+
+paddle.enable_static()
+SEED = 2021
+
+
+class TestAdamW(OpTest):
+    def setUp(self):
+        self.set_npu()
+        self.place = paddle.NPUPlace(0)
+        self.op_type = "adamw"
+        param = np.random.uniform(-1, 1, (105, 102)).astype("float32")
+        grad = np.random.uniform(-1, 1, (105, 102)).astype("float32")
+        moment1 = np.random.uniform(-1, 1, (105, 102)).astype("float32")
+        # The second moment is positive
+        moment2 = np.random.random((105, 102)).astype("float32")
+
+        learning_rate = 0.5
+        beta1 = 0.78
+        beta2 = 0.836
+        epsilon = 1e-4
+        beta1_pow = beta1**10
+        beta2_pow = beta2**10
+
+        self.inputs = {
+            'Param': param,
+            'Grad': grad,
+            'Moment1': moment1,
+            'Moment2': moment2,
+            'LearningRate': np.array([learning_rate]).astype("float32"),
+            'Beta1Pow': np.array([beta1_pow]).astype("float32"),
+            'Beta2Pow': np.array([beta2_pow]).astype("float32")
+        }
+
+        self.attrs = {
+            'epsilon': epsilon,
+            'beta1': beta1,
+            'beta2': beta2,
+            "coeff": 0.9,
+            "with_decay": True
+        }
+
+        param_out, moment1_out, \
+            moment2_out = adamw_step(self.inputs, self.attrs)
+
+        self.outputs = {
+            'Moment1Out': moment1_out,
+            'Moment2Out': moment2_out,
+            'ParamOut': param_out,
+            'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1,
+            'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2
+        }
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, atol=1e-5)
+
+
+class TestAdamOpWithSkipUpdate(OpTest):
+    def setUp(self):
+        self.set_npu()
+        self.place = paddle.NPUPlace(0)
+        self.op_type = "adamw"
+        param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        # The second moment is positive
+        moment2 = np.random.random((102, 105)).astype("float32")
+
+        learning_rate = 0.004
+        beta1 = 0.78
+        beta2 = 0.836
+        epsilon = 1e-4
+        beta1_pow = beta1**10
+        beta2_pow = beta2**10
+
+        self.inputs = {
+            'Param': param,
+            'Grad': grad,
+            'Moment1': moment1,
+            'Moment2': moment2,
+            'LearningRate': np.array([learning_rate]).astype("float32"),
+            'Beta1Pow': np.array([beta1_pow]).astype("float32"),
+            'Beta2Pow': np.array([beta2_pow]).astype("float32"),
+            'Beta1Tensor': np.array([beta1]).astype("float32"),
+            'Beta2Tensor': np.array([beta2]).astype("float32"),
+            'EpsilonTensor': np.array([epsilon]).astype("float32"),
+            "SkipUpdate": np.array([True]).astype("bool"),
+        }
+
+        self.attrs = {'epsilon': epsilon, "coeff": 0.02, "with_decay": True}
+
+        self.outputs = {
+            'Moment1Out': moment1,
+            'Moment2Out': moment2,
+            'ParamOut': param,
+            'Beta1PowOut': self.inputs['Beta1Pow'],
+            'Beta2PowOut': self.inputs['Beta2Pow'],
+        }
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, atol=1e-5)
+
+
+class TestAdamOpWithoutDecay(OpTest):
+    def setUp(self):
+        self.set_npu()
+        self.place = paddle.NPUPlace(0)
+        self.op_type = "adamw"
+        param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
+        # The second moment is positive
+        moment2 = np.random.random((102, 105)).astype("float32")
+
+        learning_rate = 0.004
+        beta1 = 0.78
+        beta2 = 0.836
+        epsilon = 1e-4
+        beta1_pow = beta1**10
+        beta2_pow = beta2**10
+
+        self.inputs = {
+            'Param': param,
+            'Grad': grad,
+            'Moment1': moment1,
+            'Moment2': moment2,
+            'LearningRate': np.array([learning_rate]).astype("float32"),
+            'Beta1Pow': np.array([beta1_pow]).astype("float32"),
+            'Beta2Pow': np.array([beta2_pow]).astype("float32"),
+            'Beta1Tensor': np.array([beta1]).astype("float32"),
+            'Beta2Tensor': np.array([beta2]).astype("float32"),
+            'EpsilonTensor': np.array([epsilon]).astype("float32"),
+            "SkipUpdate": np.array([True]).astype("bool"),
+        }
+
+        self.attrs = {'epsilon': epsilon, "coeff": 0.02, "with_decay": False}
+
+        self.outputs = {
+            'Moment1Out': moment1,
+            'Moment2Out': moment2,
+            'ParamOut': param,
+            'Beta1PowOut': self.inputs['Beta1Pow'],
+            'Beta2PowOut': self.inputs['Beta2Pow'],
+        }
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, atol=1e-5)
+
+
+class TestNet(unittest.TestCase):
+    def _test(self, run_npu=True):
+        main_prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        main_prog.random_seed = SEED
+        startup_prog.random_seed = SEED
+        np.random.seed(SEED)
+
+        a_np = np.random.random(size=(32, 32)).astype('float32')
+        b_np = np.random.random(size=(32, 32)).astype('float32')
+        label_np = np.random.randint(2, size=(32, 1)).astype('int64')
+
+        with paddle.static.program_guard(main_prog, startup_prog):
+            a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
+            b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
+            label = paddle.static.data(
+                name="label", shape=[32, 1], dtype='int64')
+
+            sum = paddle.add(a, b)
+            z = paddle.pow(sum, 2.0)
+
+            fc_1 = fluid.layers.fc(input=z, size=128)
+            prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
+
+            cost = fluid.layers.cross_entropy(input=prediction, label=label)
+            loss = fluid.layers.reduce_mean(cost)
+            adam = paddle.optimizer.AdamW(learning_rate=0.01, weight_decay=0.02)
+            adam.minimize(loss)
+
+        if run_npu:
+            place = paddle.NPUPlace(0)
+        else:
+            place = paddle.CPUPlace()
+
+        exe = paddle.static.Executor(place)
+        exe.run(startup_prog)
+
+        print("Start run on {}".format(place))
+        for epoch in range(100):
+
+            pred_res, loss_res = exe.run(
+                main_prog,
+                feed={"a": a_np,
+                      "b": b_np,
+                      "label": label_np},
+                fetch_list=[prediction, loss])
+            if epoch % 10 == 0:
+                print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
+                    epoch, pred_res[0], loss_res))
+
+        return pred_res, loss_res
+
+    def test_npu(self):
+        npu_pred, npu_loss = self._test(True)
+        cpu_pred, cpu_loss = self._test(False)
+        self.assertTrue(np.allclose(npu_pred, cpu_pred, rtol=1e-3))
+        self.assertTrue(np.allclose(npu_loss, cpu_loss, rtol=1e-3))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_adam_op.py b/python/paddle/fluid/tests/unittests/test_adam_op.py
index 27b66c13aecf33c053217a9980faecee2495953e..70109164960a33d732063405f6dc5afbf54984dc 100644
--- a/python/paddle/fluid/tests/unittests/test_adam_op.py
+++ b/python/paddle/fluid/tests/unittests/test_adam_op.py
@@ -215,6 +215,45 @@ def adam_step(inputs, attributes):
     return param_out, moment1_out, moment2_out
 
 
+def adamw_step(inputs, attributes):
+    '''
+    Simulate one step of the adamw optimizer
+    :param inputs: dict of inputs
+    :param attributes: dict of attributes
+    :return tuple: tuple of output param, moment1, moment2,
+        beta1 power accumulator and beta2 power accumulator
+    '''
+    param = inputs['Param']
+    grad = inputs['Grad']
+    moment1 = inputs['Moment1']
+    moment2 = inputs['Moment2']
+    lr = inputs['LearningRate']
+    beta1_pow = inputs['Beta1Pow']
+    beta2_pow = inputs['Beta2Pow']
+
+    epsilon = attributes['epsilon']
+    coeff = attributes["coeff"]
+    if attributes.get("with_decay", False):
+        decay = 1.0 - lr * coeff
+        param2 = param * decay
+        param = param2.copy()
+    if 'beta1' in attributes:
+        beta1 = attributes['beta1']
+    else:
+        beta1 = inputs['Beta1Tensor'][0]
+    if 'beta2' in attributes:
+        beta2 = attributes['beta2']
+    else:
+        beta2 = inputs['Beta2Tensor'][0]
+
+    moment1_out = beta1 * moment1 + (1 - beta1) * grad
+    moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)
+    lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
+    param_out = param - lr_t * (moment1_out / (np.sqrt(moment2_out) + epsilon))
+
+    return param_out, moment1_out, moment2_out
+
+
 def adam_step_sparse(inputs, attributes, height, rows, row_numel, np_grad,
                      lazy_mode):
     '''
diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py
index 11ba49c0707a37ac966f8df1f99bb8e40dece874..965785908979bb30e4d6e5016f32cff102988579 100644
--- a/python/paddle/optimizer/adamw.py
+++ b/python/paddle/optimizer/adamw.py
@@ -16,9 +16,12 @@ from .optimizer import Optimizer
 from .adam import Adam
 from ..fluid import core
 from ..fluid import framework
+from ..fluid.framework import Variable
 from ..fluid.dygraph import base as imperative_base
 import paddle
 
+_C_ops = core.ops
+
 __all__ = []
 
 
@@ -173,6 +176,23 @@ class AdamW(Adam):
             multi_precision=multi_precision)
         self._default_dict = {'coeff': coeff}
 
+        self.type = "adamw"
+
+        # The adamw op does not support CUDA yet; fall back to adam on GPU.
+        if core.is_compiled_with_cuda():
+            self.type = "adam"
+        # Use _auxiliary_vars with _set_auxiliary_var/_get_auxiliary_var to pass found_inf to the op.
+        self._auxiliary_vars = dict()
+
+    def _set_auxiliary_var(self, key, val):
+        self._auxiliary_vars[key] = val
+
+    def _get_auxiliary_var(self, key):
+        if key in self._auxiliary_vars:
+            return self._auxiliary_vars[key]
+        else:
+            return None
+
     def _append_decoupled_weight_decay(self, block, param_and_grad):
         """
         Add decoupled weight decay op.
@@ -228,8 +248,107 @@ class AdamW(Adam):
             paddle.fluid.layers.assign(input=scaled_param, output=param)
 
     def _append_optimize_op(self, block, param_and_grad):
-        self._append_decoupled_weight_decay(block, param_and_grad)
-        return super(AdamW, self)._append_optimize_op(block, param_and_grad)
+        if not core.is_compiled_with_npu():
+            self._append_decoupled_weight_decay(block, param_and_grad)
+            return super(AdamW, self)._append_optimize_op(block, param_and_grad)
+
+        assert isinstance(block, framework.Block)
+        if isinstance(param_and_grad, dict):
+            param_and_grad = self._update_param_group(param_and_grad)
+        param, grad = param_and_grad
+
+        # Whether we should do weight decay for the parameter.
+        with_decay = True
+        if self._apply_decay_param_fun is not None \
+                and not self._apply_decay_param_fun(param.name):
+            with_decay = False
+
+        moment1 = self._get_accumulator(self._moment1_acc_str,
+                                        param_and_grad[0])
+        moment2 = self._get_accumulator(self._moment2_acc_str,
+                                        param_and_grad[0])
+        beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
+                                              param_and_grad[0])
+        beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
+                                              param_and_grad[0])
+        find_master = self._multi_precision and param_and_grad[
+            0].dtype == core.VarDesc.VarType.FP16
+        master_weight = (self._master_weights[param_and_grad[0].name]
+                         if find_master else None)
+        lr = self._create_param_lr(param_and_grad)
+
+        # create the adam optimize op
+        if framework.in_dygraph_mode():
+
+            _beta1 = self._beta1 if not isinstance(
+                self._beta1, Variable) else self._beta1.numpy().item(0)
+            _beta2 = self._beta2 if not isinstance(
+                self._beta2, Variable) else self._beta2.numpy().item(0)
+            _, _, _, _, _ = _C_ops.adam(
+                param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
+                beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1,
+                moment2, beta1_pow_acc, beta2_pow_acc, 'epsilon', self._epsilon,
+                'lazy_mode', self._lazy_mode, 'min_row_size_to_use_multithread',
+                1000, 'beta1', _beta1, 'beta2', _beta2)
+
+            return None
+
+        inputs = {
+            "Param": [param_and_grad[0]],
+            "Grad": [param_and_grad[1]],
+            "LearningRate": [lr],
+            "Moment1": [moment1],
+            "Moment2": [moment2],
+            "Beta1Pow": [beta1_pow_acc],
+            "Beta2Pow": [beta2_pow_acc],
+        }
+
+        # Pass found_inf to adamw, to skip update for not only param, but also momentum and beta_pow
+        found_inf = self._get_auxiliary_var('found_inf')
+
+        if found_inf:
+            inputs['SkipUpdate'] = found_inf
+
+        outputs = {
+            "ParamOut": [param_and_grad[0]],
+            "Moment1Out": [moment1],
+            "Moment2Out": [moment2],
+            "Beta1PowOut": [beta1_pow_acc],
+            "Beta2PowOut": [beta2_pow_acc],
+        }
+        attrs = {
+            "lazy_mode": self._lazy_mode,
+            "min_row_size_to_use_multithread": 1000,
+            "multi_precision": find_master,
+            "with_decay": with_decay,
+            "coeff": self._coeff,
+        }
+
+        if isinstance(self._beta1, Variable):
+            inputs['Beta1Tensor'] = self._beta1
+        else:
+            attrs['beta1'] = self._beta1
+        if isinstance(self._beta2, Variable):
+            inputs['Beta2Tensor'] = self._beta2
+        else:
+            attrs['beta2'] = self._beta2
+        if isinstance(self._epsilon, Variable):
+            inputs['EpsilonTensor'] = self._epsilon
+        else:
+            attrs['epsilon'] = self._epsilon
+
+        if find_master:
+            inputs["MasterParam"] = master_weight
+            outputs["MasterParamOut"] = master_weight
+
+        adamw_op = block.append_op(
+            type=self.type,
+            inputs=inputs,
+            outputs=outputs,
+            attrs=attrs,
+            stop_gradient=True)
+
+        return adamw_op
 
     def _create_optimization_pass(self, parameters_and_grads):
         optimize_ops = super(
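
Note (not part of the patch): the core change in all of the kernels above is the decoupled weight decay step, which scales the parameter by 1 - lr * coeff before the ordinary Adam update runs, mirroring adamw_step in test_adam_op.py. A minimal NumPy sketch, with illustrative values only:

    import numpy as np

    lr, coeff = 0.5, 0.9  # learning rate and weight-decay coefficient (hypothetical values)
    param = np.random.uniform(-1, 1, (4, 4)).astype("float32")

    # Decoupled weight decay: shrink the parameter directly, instead of
    # adding coeff * param to the gradient as plain L2 regularization would.
    decayed_param = param * (1.0 - lr * coeff)

    # The regular Adam update (moment estimates, bias correction) is then
    # applied to decayed_param; see adamw_step above for the full reference step.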