From adaeee4d3d3834616e121c32c95b09a87f24712d Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Fri, 17 Sep 2021 23:49:30 +0800 Subject: [PATCH] [AMP] Support pure fp16 training mode for dygraph (#35521) * add pure fp16 major function in auto_cast & tracer * support master weight in dygraph for pure fp16 * check mix dtype of fp16&fp32 for check_finite_and_unscale op * change pure fp16 funtion name * refine some bug in auto_cast * refine auto_cast interface logic * add param _casted_by_pure_fp16 for class Layer * support state_dict hook for save model by user appointed dtype in pure_fp16_decorator * refine pure_fp16_decorator as decorator * add unittest * add comment * add comment * support recompute * add comment for auto_cast and decorator * support to_static_state_dict for paddle.jit.save * unlimite models num and optimizers num * add lookup_table in black_list * fix momentum and layer state_dict * fix bug in layer state_dict * fix bug in layer state_dict_helper * refine unittest * refine test_momentun_op * refine interface and some code * refine amp_decorator interface * refine pure fp16 interface * refine master weight interface --- paddle/fluid/imperative/amp_auto_cast.cc | 27 +- paddle/fluid/imperative/amp_auto_cast.h | 16 +- paddle/fluid/imperative/tracer.cc | 5 +- paddle/fluid/imperative/tracer.h | 6 +- paddle/fluid/pybind/imperative.cc | 4 +- paddle/fluid/pybind/op_function_generator.cc | 18 +- python/paddle/amp/__init__.py | 3 +- python/paddle/amp/auto_cast.py | 73 ++- .../fleet/meta_parallel/pp_utils/utils.py | 9 +- .../distributed/fleet/utils/recompute.py | 12 +- python/paddle/fluid/contrib/optimizer.py | 20 +- python/paddle/fluid/dygraph/amp/auto_cast.py | 281 +++++++++++- .../paddle/fluid/dygraph/amp/loss_scaler.py | 32 +- python/paddle/fluid/dygraph/jit.py | 11 +- python/paddle/fluid/dygraph/layers.py | 119 ++++- python/paddle/fluid/optimizer.py | 63 +-- .../test_imperative_auto_mixed_precision.py | 416 +++++++++++++++++- .../tests/unittests/test_jit_save_load.py | 1 - python/paddle/optimizer/adam.py | 56 +-- python/paddle/optimizer/adamw.py | 14 +- python/paddle/optimizer/momentum.py | 75 ++-- 21 files changed, 1069 insertions(+), 192 deletions(-) diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc index eba30ff8ed..48e5e430b1 100644 --- a/paddle/fluid/imperative/amp_auto_cast.cc +++ b/paddle/fluid/imperative/amp_auto_cast.cc @@ -117,7 +117,7 @@ static inline std::shared_ptr CastToType( imperative::NameVarBaseMap outs = {{"Out", {out}}}; { - AutoCastGuard guard(tracer, false); + AutoCastGuard guard(tracer, 0); tracer->TraceOp("cast", ins, outs, std::move(attrs)); } @@ -225,5 +225,30 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type, return new_ins; } +NameVarBaseMap CastPureFp16Inputs(const std::string& op_type, + const NameVarBaseMap& ins) { + NameVarBaseMap new_ins(ins); + auto dst_type = framework::proto::VarType::FP16; + if (AmpOperators::Instance().GetMutableUnsupportedFp16Ops()->count(op_type) || + AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) { + dst_type = framework::proto::VarType::FP32; + } + for (auto& pair : new_ins) { + if ((op_type == "batch_norm" || op_type == "layer_norm" || + op_type == "sync_batch_norm") && + pair.first != "X") { + continue; + } + VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " + << GetDtypeStr(*pair.second.cbegin()) << " to " + << framework::DataTypeToString(dst_type); + for (auto& var : pair.second) { + var = (dst_type == framework::proto::VarType::FP32 ? CastToFP32(var) + : CastToFP16(var)); + } + } + return new_ins; +} + } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/imperative/amp_auto_cast.h b/paddle/fluid/imperative/amp_auto_cast.h index fa76c19688..79bc83a777 100644 --- a/paddle/fluid/imperative/amp_auto_cast.h +++ b/paddle/fluid/imperative/amp_auto_cast.h @@ -63,15 +63,16 @@ std::ostream& operator<<(std::ostream& os, AmpOperators& ops); // NOTE(zhiqiu): AutoCastGuard is used for RAII. class AutoCastGuard { public: - AutoCastGuard(std::shared_ptr tracer, bool guard_mode) + AutoCastGuard(std::shared_ptr tracer, int guard_level) : tracer_(tracer) { - pre_mode_ = tracer_->IsAutoCastEnabled(); - if (pre_mode_ != guard_mode) { - tracer_->SetEnableAutoCast(guard_mode); + pre_amp_level_ = tracer_->AMPLevel(); + + if (pre_amp_level_ != guard_level) { + tracer_->SetAMPLevel(guard_level); } } - ~AutoCastGuard() { tracer_->SetEnableAutoCast(pre_mode_); } + ~AutoCastGuard() { tracer_->SetAMPLevel(pre_amp_level_); } // forbid copy and operator= AutoCastGuard(const AutoCastGuard& guard) = delete; @@ -79,11 +80,14 @@ class AutoCastGuard { private: std::shared_ptr tracer_; - bool pre_mode_; + int pre_amp_level_; }; NameVarBaseMap AutoCastInputs(const std::string& op_type, const NameVarBaseMap& ins); +NameVarBaseMap CastPureFp16Inputs(const std::string& op_type, + const NameVarBaseMap& ins); + } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 9dc9c4d90a..49e079c58c 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -176,9 +176,12 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, : attr_checker->GetDefaultAttrMap(); NameVarBaseMap new_ins = ins; - if (enable_autocast_) { + if (amp_level_ == 1) { VLOG(5) << "Auto mixed precision run operator: " << type; new_ins = AutoCastInputs(type, ins); + } else if (amp_level_ == 2) { + VLOG(5) << "Pure fp16 run operator: " << type; + new_ins = CastPureFp16Inputs(type, ins); } try { diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h index b734ae5c49..e77623d7a4 100644 --- a/paddle/fluid/imperative/tracer.h +++ b/paddle/fluid/imperative/tracer.h @@ -105,9 +105,9 @@ class Tracer { void SetHasGrad(bool has_grad) { has_grad_ = has_grad; } - void SetEnableAutoCast(bool enabled) { enable_autocast_ = enabled; } + void SetAMPLevel(int level) { amp_level_ = level; } - bool IsAutoCastEnabled() const { return enable_autocast_; } + int AMPLevel() const { return amp_level_; } paddle::framework::GarbageCollector* MutableGarbageCollectorIfNotExists( const platform::Place& place); @@ -118,9 +118,9 @@ class Tracer { bool enable_program_desc_tracing_{false}; std::unique_ptr generator_; platform::Place expected_place_; - bool enable_autocast_{false}; GarbageCollectorMap gcs_; static thread_local bool has_grad_; + int amp_level_{0}; }; // To access static variable current_tracer diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 62279449e3..5aae05db8c 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -1947,8 +1947,8 @@ void BindImperative(py::module *m_ptr) { .def_property("_enable_program_desc_tracing", &imperative::Tracer::IsProgramDescTracingEnabled, &imperative::Tracer::SetEnableProgramDescTracing) - .def_property("_enable_autocast", &imperative::Tracer::IsAutoCastEnabled, - &imperative::Tracer::SetEnableAutoCast) + .def_property("_amp_level", &imperative::Tracer::AMPLevel, + &imperative::Tracer::SetAMPLevel) .def_property("_has_grad", &imperative::Tracer::HasGrad, &imperative::Tracer::SetHasGrad) .def_property( diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 3da4a4b8e8..f9d11e8154 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -63,11 +63,15 @@ std::map> op_ins_map = { {"moving_average_abs_max_scale", {"X", "InAccum", "InState"}}, {"multiclass_nms3", {"BBoxes", "Scores", "RoisNum"}}, {"box_coder", {"PriorBox", "PriorBoxVar", "TargetBox"}}, - {"momentum", {"Param", "Grad", "Velocity", "LearningRate"}}, + {"momentum", {"Param", "Grad", "Velocity", "LearningRate", "MasterParam"}}, {"sparse_momentum", {"Param", "Grad", "Velocity", "Index", "LearningRate"}}, {"rnn", {"Input", "PreState", "WeightList", "SequenceLength"}}, {"run_program", {"X", "Params"}}, - {"matrix_rank", {"X", "TolTensor"}}}; + {"matrix_rank", {"X", "TolTensor"}}, + {"adam", + {"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow", + "Beta2Pow", "MasterParam"}}, +}; // NOTE(zhiqiu): Like op_ins_map. // Commonly, the outputs in auto-generated OP function are determined by the @@ -97,12 +101,15 @@ std::map> op_outs_map = { {"Out", "OutScale", "OutAccum", "OutState"}}, {"multiclass_nms3", {"Out", "NmsRoisNum"}}, {"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}}, - {"momentum", {"ParamOut", "VelocityOut"}}, + {"momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}}, {"sparse_momentum", {"ParamOut", "VelocityOut"}}, {"rnn", {"DropoutState", "Reserve", "Out", "State"}}, {"lamb", {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}}, {"run_program", {"DOut"}}, + {"adam", + {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut", + "MasterParamOut"}}, }; // NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are @@ -119,13 +126,14 @@ std::map> op_outs_map = { std::map> op_passing_outs_map = { {"sgd", {"ParamOut"}}, {"adam", - {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}}, + {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut", + "MasterParamOut"}}, {"adamw", {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}}, {"average_accumulates", {"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates", "out_old_num_accumulates", "out_num_updates"}}, - {"momentum", {"ParamOut", "VelocityOut"}}, + {"momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}}, {"sparse_momentum", {"ParamOut", "VelocityOut"}}, {"batch_norm", {"MeanOut", "VarianceOut"}}, {"sync_batch_norm", {"MeanOut", "VarianceOut"}}, diff --git a/python/paddle/amp/__init__.py b/python/paddle/amp/__init__.py index 64992752b2..381aad8850 100644 --- a/python/paddle/amp/__init__.py +++ b/python/paddle/amp/__init__.py @@ -14,5 +14,6 @@ from .auto_cast import auto_cast # noqa: F401 from .grad_scaler import GradScaler # noqa: F401 +from .auto_cast import decorate # noqa: F401 -__all__ = ['auto_cast', 'GradScaler'] +__all__ = ['auto_cast', 'GradScaler', 'decorate'] diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py index 974f718c2d..9d4b84c504 100644 --- a/python/paddle/amp/auto_cast.py +++ b/python/paddle/amp/auto_cast.py @@ -13,18 +13,22 @@ # limitations under the License. from paddle.fluid.dygraph.amp import amp_guard +from paddle.fluid.dygraph.amp import amp_decorate __all__ = [] -def auto_cast(enable=True, custom_white_list=None, custom_black_list=None): +def auto_cast(enable=True, + custom_white_list=None, + custom_black_list=None, + level='O1'): """ Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode. If enabled, the input data type (float32 or float16) of each operator is decided by autocast algorithm for better performance. Commonly, it is used together with `GradScaler` to achieve Auto-Mixed-Precision in - imperative mode. + imperative mode. It is used together with `decorator` to achieve Pure fp16 in imperative mode. Args: enable(bool, optional): Enable auto-mixed-precision or not. Default is True. @@ -34,6 +38,8 @@ def auto_cast(enable=True, custom_white_list=None, custom_black_list=None): custom_black_list(set|list|tuple, optional): The custom black_list. The set of ops that support fp16 calculation and are considered numerically-dangerous and whose effects may also be observed in downstream ops. These ops will not be converted to fp16. + level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the input data type of each operator will be casted by white_list and black_list; + O2 represent Pure fp16, all operators parameters and input data will be casted to fp16, except operators in black_list, don't support fp16 kernel and batchnorm. Default is O1(amp) Examples: @@ -61,6 +67,67 @@ def auto_cast(enable=True, custom_white_list=None, custom_black_list=None): with paddle.amp.auto_cast(custom_white_list={'elementwise_add'}): c = a + b print(c.dtype) # FP16 + + with paddle.amp.auto_cast(custom_white_list={'elementwise_add'}, level='O2'): + d = a + b + print(d.dtype) # FP16 + + """ + return amp_guard(enable, custom_white_list, custom_black_list, level) + + +def decorate(models, + optimizers=None, + level='O1', + master_weight=None, + save_dtype=None): + """ + Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing. + When level is O2(pure fp16), the decorate will cast all parameters of models to FP16, except BatchNorm and LayerNorm. + + Commonly, it is used together with `auto_cast` to achieve Pure fp16 in imperative mode. + + Args: + models(Layer|list of Layer, optional): The defined models by user, models must be either a single model or a list of models. Default is None. + optimizers(Optimizer|list of Optimizer, optional): The defined optimizers by user, optimizers must be either a single optimizer or a list of optimizers. Default is None. + level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the decorator will do nothing; + O2 represent Pure fp16, the decorator will cast all parameters of models to FP16, except BatchNorm and LayerNorm. Default is O1(amp) + master_weight(bool, optinal): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None. + save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, float32, float64 or None. + The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None. + + Examples: + + .. code-block:: python + + # required: gpu + # Demo1: single model and optimizer: + import paddle + + model = paddle.nn.Conv2D(3, 2, 3, bias_attr=False) + optimzier = paddle.optimizer.SGD(parameters=model.parameters()) + + model, optimizer = paddle.amp.decorate(models=model, optimizers=optimzier, level='O2') + + data = paddle.rand([10, 3, 32, 32]) + + with paddle.amp.auto_cast(enable=True, custom_white_list=None, custom_black_list=None, level='O2'): + output = model(data) + print(output.dtype) # FP16 + + # required: gpu + # Demo2: multi models and optimizers: + model2 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False) + optimizer2 = paddle.optimizer.Adam(parameters=model2.parameters()) + + models, optimizers = paddle.amp.decorate(models=[model, model2], optimizers=[optimzier, optimizer2], level='O2') + + data = paddle.rand([10, 3, 32, 32]) + with paddle.amp.auto_cast(enable=True, custom_white_list=None, custom_black_list=None, level='O2'): + output = models[0](data) + output2 = models[1](data) + print(output.dtype) # FP16 + print(output2.dtype) # FP16 """ - return amp_guard(enable, custom_white_list, custom_black_list) + return amp_decorate(models, optimizers, level, master_weight, save_dtype) diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py index 598c4b2642..b29b0b3e27 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ -198,7 +198,11 @@ class _HPRecomputeFunction(PyLayer): # TODO support AMP tracer = framework._dygraph_tracer() - ctx.is_fw_autocast = tracer._enable_autocast + if tracer._amp_level == 0: + ctx.is_fw_autocast = False + else: + ctx.is_fw_autocast = True + ctx.amp_mode = 'O1' ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() with paddle.no_grad(): @@ -258,7 +262,8 @@ class _HPRecomputeFunction(PyLayer): with paddle.amp.auto_cast( enable=ctx.is_fw_autocast, custom_white_list=ctx.amp_white_list, - custom_black_list=ctx.amp_black_list): + custom_black_list=ctx.amp_black_list, + level=ctx.amp_mode): detached_inputs = detach_variable(tuple(inputs)) outputs = ctx.run_function(*detached_inputs) diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/utils/recompute.py index 89b14258c1..302877e51f 100755 --- a/python/paddle/distributed/fleet/utils/recompute.py +++ b/python/paddle/distributed/fleet/utils/recompute.py @@ -98,7 +98,11 @@ class RecomputeFunction(PyLayer): # TODO support AMP tracer = framework._dygraph_tracer() - ctx.is_fw_autocast = tracer._enable_autocast + if tracer._amp_level == 0: + ctx.is_fw_autocast = False + else: + ctx.is_fw_autocast = True + ctx.amp_mode = 'O1' ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() with paddle.no_grad(): @@ -128,14 +132,16 @@ class RecomputeFunction(PyLayer): with paddle.amp.auto_cast( enable=ctx.is_fw_autocast, custom_white_list=ctx.amp_white_list, - custom_black_list=ctx.amp_black_list): + custom_black_list=ctx.amp_black_list, + level=ctx.amp_mode): detached_inputs = detach_variable(tuple(inputs)) outputs = ctx.run_function(*detached_inputs) else: with paddle.amp.auto_cast( enable=ctx.is_fw_autocast, custom_white_list=ctx.amp_white_list, - custom_black_list=ctx.amp_black_list): + custom_black_list=ctx.amp_black_list, + level=ctx.amp_mode): detached_inputs = detach_variable(tuple(inputs)) outputs = ctx.run_function(*detached_inputs) diff --git a/python/paddle/fluid/contrib/optimizer.py b/python/paddle/fluid/contrib/optimizer.py index 7f742adb41..3fb808a88a 100644 --- a/python/paddle/fluid/contrib/optimizer.py +++ b/python/paddle/fluid/contrib/optimizer.py @@ -203,19 +203,21 @@ class Momentum(Optimizer): param_and_grad[0]) lr = self._create_param_lr(param_and_grad) - if framework.in_dygraph_mode(): - _, _ = _C_ops.momentum( - param_and_grad[0], param_and_grad[1], velocity_acc, lr, - param_and_grad[0], velocity_acc, 'mu', self._momentum, - 'use_nesterov', self._use_nesterov, 'regularization_method', - self._regularization_method, 'regularization_coeff', - self._regularization_coeff) - return None - find_master = self._multi_precision and param_and_grad[ 0].dtype == core.VarDesc.VarType.FP16 master_weight = (self._master_weights[param_and_grad[0].name] if find_master else None) + + if framework.in_dygraph_mode(): + _, _, _ = _C_ops.momentum( + param_and_grad[0], param_and_grad[1], velocity_acc, lr, + master_weight, param_and_grad[0], velocity_acc, master_weight, + 'mu', self._momentum, 'use_nesterov', self._use_nesterov, + 'regularization_method', self._regularization_method, + 'regularization_coeff', self._regularization_coeff, + 'multi_precision', find_master) + return None + attrs = { "mu": self._momentum, "use_nesterov": self._use_nesterov, diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index 759ce3d16a..25a7323063 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -19,8 +19,13 @@ import contextlib from paddle.fluid.framework import Variable, in_dygraph_mode, OpProtoHolder, Parameter, _dygraph_tracer, dygraph_only, set_flags, get_flags import warnings import copy +import functools +import paddle +import operator +import types +import paddle.fluid as fluid -__all__ = ['amp_guard'] +__all__ = ['amp_guard', 'amp_decorate'] # The set of ops that support fp16 calculation and are considered numerically- # safe and performance-critical. These ops are always converted to fp16. @@ -64,15 +69,22 @@ AMP_RELATED_FLAGS_SETTING = { 'FLAGS_cudnn_batchnorm_spatial_persistent': 1, } +PURE_FP16_BLACK_LIST = {' '} +PURE_FP16_WHITE_LIST = {'lookup_table', 'lookup_table_v2'} + #NOTE(zhiqiu): similar as paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list # The reason why not use AutoMixedPrecisionLists is that custom_black_varnames is not suitable for imperative mode. -def _update_list(custom_white_list, custom_black_list): +def _update_list(custom_white_list, custom_black_list, level='O1'): """ Update black and white list according to users' custom list. """ - _white_list = copy.copy(WHITE_LIST) - _black_list = copy.copy(BLACK_LIST) + if level == 'O1': + _white_list = copy.copy(WHITE_LIST) + _black_list = copy.copy(BLACK_LIST) + else: + _white_list = copy.copy(PURE_FP16_WHITE_LIST) + _black_list = copy.copy(PURE_FP16_BLACK_LIST) if custom_white_list and custom_black_list: for op_name in custom_white_list: if op_name in custom_black_list: @@ -97,28 +109,111 @@ def _in_amp_guard(): """ tracer = _dygraph_tracer() if tracer: - return tracer._enable_autocast + if tracer._amp_level == 1: + return True + else: + return False else: return False +@dygraph_only +def pure_fp16_initialize(enable_pure_fp16, models, optimizers): + if not enable_pure_fp16: + return models, optimizers + + for idx in range(len(models)): + for layer in models[idx].sublayers(include_self=True): + layer._casted_by_pure_fp16 = True + if len(layer._sub_layers) is 0: + + if (layer._dtype is 'float16') or isinstance(layer, ( + paddle.nn.BatchNorm, paddle.nn.LayerNorm)): + continue + layer.to(dtype='float16') + + for idx_opt in range(len(optimizers)): + # update _param_groups + if getattr(optimizers[idx_opt], '_param_groups', None) and isinstance( + optimizers[idx_opt]._param_groups[0], dict): + for param_group in optimizers[idx_opt]._param_groups: + for i, param in enumerate(param_group['params']): + for idx_model in range(len(models)): + for layer in models[idx_model].sublayers( + include_self=True): + if id(param) in layer._parameters_transform_map: + param_group['params'][ + i] = layer._parameters_transform_map[id( + param)][0] + for param_group in optimizers[idx_opt]._parameter_list: + params = param_group['params'] + for i, param in enumerate(params): + for idx_model in range(len(models)): + for layer in models[idx_model].sublayers( + include_self=True): + if id(param) in layer._parameters_transform_map: + params[i] = layer._parameters_transform_map[id( + param)][0] + # update _parameter_list + else: + for i, param in enumerate(optimizers[idx_opt]._parameter_list): + for idx_model in range(len(models)): + for layer in models[idx_model].sublayers(include_self=True): + if id(param) in layer._parameters_transform_map: + optimizers[idx_opt]._parameter_list[ + i] = layer._parameters_transform_map[id(param)][ + 0] + if hasattr(optimizers[idx_opt], '_param_groups'): + optimizers[idx_opt]._param_groups[ + i] = layer._parameters_transform_map[id( + param)][0] + return models, optimizers + + +def check_models(models): + for model in models: + if not isinstance(model, paddle.nn.Layer): + raise RuntimeError( + "Current train mode is pure fp16, models should be paddle.nn.Layer, but receive {}.". + format(type(model))) + + +def check_optimizers(optimizers): + for optimizer in optimizers: + if not isinstance(optimizer, (paddle.optimizer.Optimizer, + paddle.fluid.optimizer.Optimizer)): + raise RuntimeError( + "Current train mode is pure fp16, optimizers should be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer, but receive {}.". + format(type(optimizer))) + + @signature_safe_contextmanager @dygraph_only -def amp_guard(enable=True, custom_white_list=None, custom_black_list=None): +def amp_guard(enable=True, + custom_white_list=None, + custom_black_list=None, + level='O1'): """ :api_attr: imperative - Create a context which enables auto-mixed-precision(AMP) of operators executed in imperative mode. + Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode. If enabled, the input data type (float32 or float16) of each operator is decided by autocast algorithm for better performance. - Commonly, it is used together with `AmpScaler` to achieve Auto-Mixed-Precision in - imperative mode. + Commonly, it is used together with `GradScaler` to achieve Auto-Mixed-Precision in + imperative mode. It is used together with `decorator` to achieve Pure fp16 in imperative mode. Args: enable(bool, optional): Enable auto-mixed-precision or not. Default is True. - custom_white_list(set|list, optional): The custom white_list. - custom_black_list(set|list, optional): The custom black_list. + custom_white_list(set|list|tuple, optional): The custom white_list. It's the set of ops that support + fp16 calculation and are considered numerically-safe and performance-critical. These ops + will be converted to fp16. + custom_black_list(set|list|tuple, optional): The custom black_list. The set of ops that support fp16 + calculation and are considered numerically-dangerous and whose effects may also be + observed in downstream ops. These ops will not be converted to fp16. + level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the input data type of each operator will be casted by white_list and black_list; + O2 represent Pure fp16, all operators parameters and input data will be casted to fp16, except operators in black_list, don't support fp16 kernel and batchnorm. Default is O1(amp) + Examples: @@ -139,6 +234,11 @@ def amp_guard(enable=True, custom_white_list=None, custom_black_list=None): print(conv.dtype) # FP32 """ + if not (level in ['O1', 'O2']): + raise ValueError( + "level should be O1 or O2, O1 represent AMP train mode, O2 represent Pure fp16 train mode." + ) + tracer = _dygraph_tracer() if not tracer: raise ValueError( @@ -151,17 +251,27 @@ def amp_guard(enable=True, custom_white_list=None, custom_black_list=None): % tracer._expected_place) enable = False - # use default white_list and black_list if no custom lists provided - _white_list = WHITE_LIST - _black_list = BLACK_LIST + if level == 'O1': + amp_level = 1 + _white_list = WHITE_LIST + _black_list = BLACK_LIST + else: + amp_level = 2 + _white_list = PURE_FP16_WHITE_LIST + _black_list = PURE_FP16_BLACK_LIST + if custom_white_list or custom_black_list: _white_list, _black_list = _update_list(custom_white_list, - custom_black_list) + custom_black_list, level) + + if not enable: + amp_level = 0 if tracer: # enable auto_cast - original_enable = tracer._enable_autocast - tracer._enable_autocast = enable + original_amp_level = tracer._amp_level + tracer._amp_level = amp_level + # set amp op list original_white_list, original_black_list = tracer._get_amp_op_list() tracer._set_amp_op_list(_white_list, _black_list) @@ -179,6 +289,141 @@ def amp_guard(enable=True, custom_white_list=None, custom_black_list=None): yield finally: if tracer: - tracer._enable_autocast = original_enable + tracer._amp_level = original_amp_level tracer._set_amp_op_list(original_white_list, original_black_list) # set_flags(original_flags) + + +class StateDictHook(object): + def __init__(self, save_dtype): + self._save_dtype = save_dtype + + def __call__(self, state_dict): + for key in state_dict: + param = state_dict[key] + with fluid.dygraph.guard(): + param_applied = paddle.cast(param, self._save_dtype) + param_applied.name = param.name + state_dict[key] = param_applied + + +@dygraph_only +def amp_decorate(models, + optimizers=None, + level='O1', + master_weight=None, + save_dtype=None): + """ + Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing. + When level is O2(pure fp16), the decorate will cast all parameters of models to FP16, except BatchNorm and LayerNorm. + + Commonly, it is used together with `amp_guard` to achieve Pure fp16 in imperative mode. + + Args: + models(Layer|list of Layer, optional): The defined models by user, models must be either a single model or a list of models. Default is None. + optimizers(Optimizer|list of Optimizer, optional): The defined optimizers by user, optimizers must be either a single optimizer or a list of optimizers. Default is None. + level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the decorator will do nothing; + O2 represent Pure fp16, the decorator will cast all parameters of models to FP16, except BatchNorm and LayerNorm. Default is O1(amp) + master_weight(bool, optinal): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None. + save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, float32, float64 or None. + The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None. + + Examples: + + .. code-block:: python + + # required: gpu + # Demo1: single model and optimizer: + import paddle + import paddle.fluid as fluid + + model = paddle.nn.Conv2D(3, 2, 3, bias_attr=False) + optimzier = paddle.optimizer.SGD(parameters=model.parameters()) + + model, optimizer = fluid.dygraph.amp_decorate(models=model, optimizers=optimzier, level='O2') + + data = paddle.rand([10, 3, 32, 32]) + + with fluid.dygraph.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'): + output = model(data) + print(output.dtype) # FP16 + + # required: gpu + # Demo2: multi models and optimizers: + model2 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False) + optimizer2 = paddle.optimizer.Adam(parameters=model2.parameters()) + + models, optimizers = fluid.dygraph.amp_decorate(models=[model, model2], optimizers=[optimzier, optimizer2], level='O2') + + data = paddle.rand([10, 3, 32, 32]) + + with fluid.dygraph.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'): + output = models[0](data) + output2 = models[1](data) + print(output.dtype) # FP16 + print(output2.dtype) # FP16 + """ + if not (level in ['O1', 'O2']): + raise ValueError( + "level should be O1 or O2, O1 represent AMP train mode, O2 represent Pure fp16 train mode." + ) + + if level == 'O1': + return models, optimizers + + models_is_list = False + if isinstance(models, paddle.nn.Layer): + models_is_list = False + models = [models] + check_models(models) + elif isinstance(models, list): + check_models(models) + models_is_list = True + else: + raise TypeError( + "models must be either a single model or a list of models.") + + optimizers_is_list = False + if isinstance(optimizers, (paddle.optimizer.Optimizer, + paddle.fluid.optimizer.Optimizer)): + optimizers_is_list = False + optimizers = [optimizers] + check_optimizers(optimizers) + elif isinstance(optimizers, list): + check_optimizers(optimizers) + optimizers_is_list = True + else: + raise TypeError( + "optimizers must be either a single optimizer or a list of optimizers." + ) + + models, optimizers = pure_fp16_initialize( + enable_pure_fp16=True, models=models, optimizers=optimizers) + + # supprot master_weight + for idx_opt in range(len(optimizers)): + if hasattr(optimizers[idx_opt], '_multi_precision'): + if master_weight is False: + optimizers[idx_opt]._multi_precision = False + else: + optimizers[idx_opt]._multi_precision = True + + if save_dtype is not None: + if not (save_dtype in ['float16', 'float32', 'float64']): + raise ValueError( + "save_dtype can only be float16 float32 or float64, but your input save_dtype is %s." + % save_dtype) + for idx in range(len(models)): + for layer in models[idx].sublayers(include_self=True): + layer.register_state_dict_hook(StateDictHook(save_dtype)) + + if models_is_list: + if optimizers_is_list: + return models, optimizers + else: + return models, optimizers[0] + else: + if optimizers_is_list: + return models[0], optimizers + else: + return models[0], optimizers[0] diff --git a/python/paddle/fluid/dygraph/amp/loss_scaler.py b/python/paddle/fluid/dygraph/amp/loss_scaler.py index a9fe2c9f3e..38881e43c0 100644 --- a/python/paddle/fluid/dygraph/amp/loss_scaler.py +++ b/python/paddle/fluid/dygraph/amp/loss_scaler.py @@ -216,17 +216,45 @@ class AmpScaler(object): if getattr(optimizer, '_param_groups', None) and isinstance( optimizer._param_groups[0], dict): param_grads = [] + param_grads_fp16 = [] + param_grads_fp32 = [] for group in optimizer._param_groups: for param in group['params']: if param._grad_ivar() is not None: param_grads.append(param._grad_ivar()) + if param._grad_ivar( + ).dtype == core.VarDesc.VarType.FP16: + param_grads_fp16.append(param._grad_ivar()) + else: + param_grads_fp32.append(param._grad_ivar()) else: param_grads = [ param._grad_ivar() for param in optimizer._parameter_list if param._grad_ivar() is not None ] - _C_ops.check_finite_and_unscale(param_grads, self._scale, param_grads, - self._found_inf) + param_grads_fp16 = [ + param._grad_ivar() for param in optimizer._parameter_list + if (param._grad_ivar() is not None + ) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP16 + ) + ] + param_grads_fp32 = [ + param._grad_ivar() for param in optimizer._parameter_list + if (param._grad_ivar() is not None + ) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP32 + ) + ] + temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool)) + temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool)) + if len(param_grads_fp16): + _C_ops.check_finite_and_unscale(param_grads_fp16, self._scale, + param_grads_fp16, + temp_found_inf_fp16) + if len(param_grads_fp32): + _C_ops.check_finite_and_unscale(param_grads_fp32, self._scale, + param_grads_fp32, + temp_found_inf_fp32) + self._found_inf = temp_found_inf_fp16 or temp_found_inf_fp32 def _update(self): """ diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 10c3861e77..d41c373bf5 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -779,10 +779,11 @@ def save(layer, path, input_spec=None, **configs): dygraph_state_dict = None if isinstance(inner_layer, Layer): - dygraph_state_dict = inner_layer.state_dict() + dygraph_state_dict = inner_layer.to_static_state_dict() elif isinstance(attr_func, StaticFunction): if attr_func._class_instance: - dygraph_state_dict = attr_func._class_instance.state_dict() + dygraph_state_dict = attr_func._class_instance.to_static_state_dict( + ) if dygraph_state_dict: # NOTE(chenweihang): we maintain the mapping of variable name to @@ -790,15 +791,19 @@ def save(layer, path, input_spec=None, **configs): # saved to inference program may not need by dygraph Layer, # we only record the state_dict variable's structured name state_names_dict = dict() + state_var_dict = dict() for structured_name, var in six.iteritems(dygraph_state_dict): state_names_dict[var.name] = structured_name + state_var_dict[var.name] = var # 3. share parameters from Layer to scope & record var info for param_or_buffer in concrete_program.parameters: # share to scope param_or_buffer_tensor = scope.var( param_or_buffer.name).get_tensor() - src_tensor = param_or_buffer.value().get_tensor() + #src_tensor = param_or_buffer.value().get_tensor() + src_tensor = state_var_dict[param_or_buffer.name].value( + ).get_tensor() param_or_buffer_tensor._share_data_with(src_tensor) # record var info if param_or_buffer.name not in extra_var_info: diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index cb7666b353..30d5ee4417 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -121,6 +121,13 @@ class Layer(core.Layer): self._forward_pre_hooks = collections.OrderedDict() self._forward_post_hooks = collections.OrderedDict() + self._parameters_transform_map = {} + self._buffers_transform_map = {} + + self._casted_by_pure_fp16 = False + + self._state_dict_hooks = collections.OrderedDict() + def train(self): """ Sets this Layer and all its sublayers to training mode. @@ -1259,6 +1266,87 @@ class Layer(core.Layer): final_str += ')' return final_str + def register_state_dict_hook(self, hook): + hook_remove_helper = HookRemoveHelper(self._state_dict_hooks) + self._state_dict_hooks[hook_remove_helper._hook_id] = hook + return hook_remove_helper + + def _state_dict_impl(self, + destination=None, + include_sublayers=True, + structured_name_prefix="", + include_non_persistable_buffer=False): + """ + Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict + + Parameters: + destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None + include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True + include_non_persistable_buffer(bool, optional): If true, include non persistable buffers of current layer and its sub-layers, it is used in pure fp16 and jit.save. Default: False + """ + + if destination is None: + destination = collections.OrderedDict() + for name, data in self._parameters.items(): + if data is not None: + destination[structured_name_prefix + name] = data + for name, buffer in self._buffers.items(): + if not include_non_persistable_buffer: + if buffer is not None and name not in self._non_persistable_buffer_names_set: + destination[structured_name_prefix + name] = buffer + else: + if buffer is not None: + destination[structured_name_prefix + name] = buffer + + if include_sublayers: + for layer_name, layer_item in self._sub_layers.items(): + if layer_item is not None: + destination_temp = destination.copy() + destination_temp.update( + layer_item._state_dict_impl( + destination_temp, include_sublayers, + structured_name_prefix + layer_name + ".", + include_non_persistable_buffer)) + destination = destination_temp + + for state_dict_hook in self._state_dict_hooks.values(): + hook_result = state_dict_hook(destination) + if hook_result is not None: + destination = hook_result + + return destination + + def to_static_state_dict(self, + destination=None, + include_sublayers=True, + structured_name_prefix=""): + ''' + Get all parameters and buffers of current layer and its sub-layers. And set them into a dict + + Parameters: + destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None + include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True + + Retruns: + dict: a dict contains all the parameters and persistable buffers. + + Examples: + .. code-block:: python + + import paddle + + emb = paddle.nn.Embedding(10, 10) + + state_dict = emb.to_static_state_dict() + paddle.save( state_dict, "paddle_dy.pdparams") + + ''' + return self._state_dict_impl( + destination=destination, + include_sublayers=include_sublayers, + structured_name_prefix=structured_name_prefix, + include_non_persistable_buffer=True) + def state_dict(self, destination=None, include_sublayers=True, @@ -1269,7 +1357,7 @@ class Layer(core.Layer): Parameters: destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True - + Retruns: dict: a dict contains all the parameters and persistable buffers. @@ -1284,26 +1372,11 @@ class Layer(core.Layer): paddle.save( state_dict, "paddle_dy.pdparams") ''' - - if destination is None: - destination = collections.OrderedDict() - for name, data in self._parameters.items(): - if data is not None: - destination[structured_name_prefix + name] = data - for name, buffer in self._buffers.items(): - if buffer is not None and name not in self._non_persistable_buffer_names_set: - destination[structured_name_prefix + name] = buffer - - if include_sublayers: - for layer_name, layer_item in self._sub_layers.items(): - if layer_item is not None: - destination_temp = destination.copy() - destination_temp.update( - layer_item.state_dict( - destination_temp, include_sublayers, - structured_name_prefix + layer_name + ".")) - destination = destination_temp - return destination + return self._state_dict_impl( + destination=destination, + include_sublayers=include_sublayers, + structured_name_prefix=structured_name_prefix, + include_non_persistable_buffer=False) @framework.deprecate_stat_dict def set_state_dict(self, state_dict, use_structured_name=True): @@ -1404,8 +1477,11 @@ class Layer(core.Layer): ).stop_gradient self._parameters[key]._set_grad_ivar(grad_applied) + self._parameters_transform_map[id(param)] = [param_applied, key] + for key, buf in self._buffers.items(): self._buffers[key] = func(buf, device, dtype, blocking) + self._buffers_transform_map[id(buf)] = [self._buffers[key], key] def to(self, device=None, dtype=None, blocking=None): ''' @@ -1501,6 +1577,7 @@ class Layer(core.Layer): return new_t self._apply(transform, device, dtype, blocking) + self._dtype = dtype # [aliases] Compatible with old method names set_dict = set_state_dict diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 8b2495fb2a..f809f1bda0 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -1433,12 +1433,12 @@ class MomentumOptimizer(Optimizer): velocity_acc = self._get_accumulator(self._velocity_acc_str, param_and_grad[0]) lr = self._create_param_lr(param_and_grad) - + master_weight = None if framework.in_dygraph_mode(): - _, _ = _C_ops.momentum(param_and_grad[0], param_and_grad[1], - velocity_acc, lr, param_and_grad[0], - velocity_acc, 'mu', self._momentum, - 'use_nesterov', self._use_nesterov) + _, _, _ = _C_ops.momentum( + param_and_grad[0], param_and_grad[1], velocity_acc, lr, + master_weight, param_and_grad[0], velocity_acc, master_weight, + 'mu', self._momentum, 'use_nesterov', self._use_nesterov) return None attrs = {"mu": self._momentum, "use_nesterov": self._use_nesterov} @@ -1982,26 +1982,29 @@ class LarsMomentumOptimizer(Optimizer): self._master_weights = {} def _create_master_weight(self, param): - assert isinstance(self.helper, LayerHelper) + if param.name in self._master_weights: + var = self._master_weights[param.name] + else: + assert isinstance(self.helper, LayerHelper) - var_name = param.name + '_fp32_master' - var_name = unique_name.generate(var_name) - var = layers.create_global_var( - name=var_name, - shape=param.shape, - value=0, - dtype='float32', - persistable=True) - block = self.helper.startup_program.global_block() - block.append_op( - type="cast", - inputs={"X": [param]}, - outputs={"Out": [var]}, - attrs={ - "in_dtype": param.dtype, - "out_dtype": core.VarDesc.VarType.FP32 - }) - self._master_weights[param.name] = var + var_name = param.name + '_fp32_master' + var_name = unique_name.generate(var_name) + var = layers.create_global_var( + name=var_name, + shape=param.shape, + value=0, + dtype='float32', + persistable=True) + block = self.helper.startup_program.global_block() + block.append_op( + type="cast", + inputs={"X": [param]}, + outputs={"Out": [var]}, + attrs={ + "in_dtype": param.dtype, + "out_dtype": core.VarDesc.VarType.FP32 + }) + self._master_weights[param.name] = var return var def _get_accumulator(self, name, param): @@ -2462,12 +2465,14 @@ class AdamOptimizer(Optimizer): self._beta1, Variable) else self._beta1.numpy().item(0) _beta2 = self._beta2 if not isinstance( self._beta2, Variable) else self._beta2.numpy().item(0) - _, _, _, _, _ = _C_ops.adam( + master_weight = None + _, _, _, _, _, _ = _C_ops.adam( param_and_grad[0], param_and_grad[1], lr, moment1, moment2, - beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1, - moment2, beta1_pow_acc, beta2_pow_acc, 'epsilon', self._epsilon, - 'lazy_mode', self._lazy_mode, 'min_row_size_to_use_multithread', - 1000, 'beta1', _beta1, 'beta2', _beta2, 'use_global_beta_pow', + beta1_pow_acc, beta2_pow_acc, master_weight, param_and_grad[0], + moment1, moment2, beta1_pow_acc, beta2_pow_acc, master_weight, + 'epsilon', self._epsilon, 'lazy_mode', self._lazy_mode, + 'min_row_size_to_use_multithread', 1000, 'beta1', _beta1, + 'beta2', _beta2, 'use_global_beta_pow', self._use_global_beta_pow) return None diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py index 330c4c5ffe..ed98195363 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py @@ -18,6 +18,8 @@ import paddle.fluid as fluid import numpy as np import six from test_imperative_resnet import ResNet, BottleneckBlock, ConvBNLayer, train_parameters, optimizer_setting +import paddle.nn as nn +from paddle.static import InputSpec if fluid.core.is_compiled_with_cuda(): fluid.set_flags({"FLAGS_cudnn_deterministic": True}) @@ -89,6 +91,21 @@ class TestAutoCast(unittest.TestCase): set(black_list) == (set(base_black_list) - {"log"}) | {"conv2d"}) + base_white_list = fluid.dygraph.amp.auto_cast.PURE_FP16_WHITE_LIST + base_black_list = fluid.dygraph.amp.auto_cast.PURE_FP16_BLACK_LIST + with fluid.dygraph.amp_guard( + custom_white_list=["log"], + custom_black_list=["conv2d"], + level='O2'): + white_list, black_list = tracer._get_amp_op_list() + self.assertTrue( + set(white_list) == + (set(base_white_list) | {"log"}) - {"conv2d"}) + + self.assertTrue( + set(black_list) == + (set(base_black_list) - {"log"}) | {"conv2d"}) + def test_custom_op_list_exception(self): inp_np = np.random.random(size=[1, 3, 128, 128]).astype(np.float32) @@ -115,13 +132,36 @@ class TestAutoCast(unittest.TestCase): conv2d = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None) data = fluid.dygraph.to_variable(data) with fluid.dygraph.amp_guard(True): - out_fp16 = conv2d(data) - out_fp32 = paddle.expand_as( - out_fp16, out_fp16) # expand_as_v2 has no fp16 kernel + out_amp_fp16 = conv2d(data) + out_amp_fp32 = paddle.expand_as( + out_amp_fp16, + out_amp_fp16) # expand_as_v2 has no fp16 kernel + + with fluid.dygraph.amp_guard(True, level='O2'): + out_purefp16_fp16 = conv2d(data) + out_purefp16_fp32 = paddle.expand_as( + out_purefp16_fp16, + out_purefp16_fp16) # expand_as_v2 has no fp16 kernel self.assertTrue(data.dtype == fluid.core.VarDesc.VarType.FP32) - self.assertTrue(out_fp16.dtype == fluid.core.VarDesc.VarType.FP16) - self.assertTrue(out_fp32.dtype == fluid.core.VarDesc.VarType.FP32) + self.assertTrue(out_amp_fp16.dtype == fluid.core.VarDesc.VarType.FP16) + self.assertTrue(out_amp_fp32.dtype == fluid.core.VarDesc.VarType.FP32) + self.assertTrue( + out_purefp16_fp16.dtype == fluid.core.VarDesc.VarType.FP16) + self.assertTrue( + out_purefp16_fp32.dtype == fluid.core.VarDesc.VarType.FP32) + + def test_mode_exception(self): + def func(): + data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32') + with fluid.dygraph.guard(): + conv2d = fluid.dygraph.Conv2D( + 3, 2, 3, bias_attr=False, act=None) + data = fluid.dygraph.to_variable(data) + with fluid.dygraph.amp_guard(level='O'): + out = conv2d(data) + + self.assertRaises(ValueError, func) class TestAmpScaler(unittest.TestCase): @@ -386,6 +426,315 @@ class TestGradScalerStateDict(unittest.TestCase): np.allclose(out_use_state_dict[0], out_no_state_dict[0])) +class TestAmpDecorator(unittest.TestCase): + def test_mode_exception(self): + def func(): + with fluid.dygraph.guard(): + model = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None) + opt = paddle.optimizer.SGD(parameters=model.parameters()) + model, opt = paddle.amp.decorate( + models=model, optimizers=opt, level='O') + + self.assertRaises(ValueError, func) + + def test_input_formate_exception(self): + def test_model_error(): + with fluid.dygraph.guard(): + model = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None) + opt = paddle.optimizer.SGD(parameters=model.parameters()) + paddle.amp.decorate(models=None, optimizers=opt, level='O2') + + self.assertRaises(TypeError, test_model_error) + + def test_optimizer_error(): + with fluid.dygraph.guard(): + model = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None) + paddle.amp.decorate(models=model, optimizers=None, level='O2') + + self.assertRaises(TypeError, test_optimizer_error) + + def test_input_type_exception(self): + def test_error_model_optimizer(): + class MyModel(object): + def __init__(self): + print("A fake Model") + + class MyOptimizer(object): + def __init__(self): + print("A fake Optimizer") + + model = MyModel() + opt = MyOptimizer() + with fluid.dygraph.guard(): + paddle.amp.decorate(models=model, optimizers=opt, level='O2') + + self.assertRaises(TypeError, test_error_model_optimizer) + + def test_set_master_weight(self): + model1 = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None) + opt1 = paddle.optimizer.Adam( + learning_rate=0.0001, + parameters=model1.parameters(), + multi_precision=True) + model1, opt1 = paddle.amp.decorate( + models=model1, optimizers=opt1, level='O2', master_weight=None) + self.assertEqual(opt1._multi_precision, True) + + model2 = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None) + opt2 = paddle.optimizer.Adam( + learning_rate=0.0001, + parameters=model2.parameters(), + multi_precision=False) + model2, opt2 = paddle.amp.decorate( + models=model2, optimizers=opt2, level='O2', master_weight=None) + self.assertEqual(opt2._multi_precision, True) + + model3 = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None) + opt3 = paddle.optimizer.Adam( + learning_rate=0.0001, parameters=model3.parameters()) + model3, opt3 = paddle.amp.decorate( + models=model3, optimizers=opt3, level='O2', master_weight=True) + self.assertEqual(opt3._multi_precision, True) + + model4 = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None) + opt4 = paddle.optimizer.Adam( + learning_rate=0.0001, parameters=model4.parameters()) + model4, opt4 = paddle.amp.decorate( + models=model4, optimizers=opt4, level='O2', master_weight=False) + self.assertEqual(opt4._multi_precision, False) + + +class TestPureFp16SaveLoad(unittest.TestCase): + def test_save_dtype_exception(self): + def func(): + paddle.disable_static() + model = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None) + opt = paddle.optimizer.SGD(parameters=model.parameters()) + paddle.amp.decorate( + models=model, optimizers=opt, level='O2', save_dtype='int') + + self.assertRaises(ValueError, func) + + def train_resnet(self, + enable_amp=True, + use_data_loader=True, + use_save_load=True): + seed = 90 + + batch_size = train_parameters["batch_size"] + batch_num = 4 + + paddle.seed(seed) + paddle.framework.random._manual_program_seed(seed) + + resnet = ResNet(use_cudnn=True) + optimizer = optimizer_setting( + train_parameters, parameter_list=resnet.parameters()) + np.random.seed(seed) + train_reader = paddle.batch( + paddle.dataset.flowers.train(use_xmap=False), batch_size=batch_size) + + dy_param_init_value = {} + for param in resnet.parameters(): + dy_param_init_value[param.name] = param.numpy() + + program = None + scaler = paddle.amp.GradScaler( + enable=enable_amp, init_loss_scaling=2.**10) + + if use_data_loader: + train_reader = paddle.batch( + reader_decorator(paddle.dataset.flowers.train(use_xmap=False)), + batch_size=batch_size, + drop_last=True) + train_loader = fluid.io.DataLoader.from_generator( + capacity=4, + use_double_buffer=True, + iterable=True, + return_list=True) + train_loader.set_sample_list_generator(train_reader) + train_reader = train_loader + + if enable_amp: + resnet, optimizer = paddle.amp.decorate( + models=resnet, + optimizers=optimizer, + level='O2', + save_dtype='float32') + + for batch_id, data in enumerate(train_reader()): + if batch_id >= batch_num: + break + if use_data_loader: + img, label = data + else: + dy_x_data = np.array([x[0].reshape(3, 224, 224) + for x in data]).astype('float32') + if len(np.array([x[1] + for x in data]).astype('int64')) != batch_size: + continue + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(-1, 1) + + img = paddle.to_tensor(dy_x_data) + label = paddle.to_tensor(y_data) + label.stop_gradient = True + + with paddle.amp.auto_cast(enable=enable_amp, level='O2'): + out = resnet(img) + + loss = paddle.nn.functional.cross_entropy(input=out, label=label) + loss = paddle.cast(loss, 'float32') + avg_loss = paddle.mean(x=loss) + + dy_out = avg_loss.numpy() + + scaled_loss = scaler.scale(avg_loss) + scaled_loss.backward() + + scaler.minimize(optimizer, scaled_loss) + + dy_grad_value = {} + for param in resnet.parameters(): + if param.trainable: + np_array = np.array(param._grad_ivar().value().get_tensor()) + dy_grad_value[param.name + fluid.core.grad_var_suffix( + )] = np_array + + resnet.clear_gradients() + + dy_param_value = {} + for param in resnet.parameters(): + dy_param_value[param.name] = param.numpy() + + if use_save_load and batch_id == 2: + # paddle.save + obj = { + 'model': resnet.state_dict(), + 'opt': optimizer.state_dict(), + 'scaler': scaler.state_dict() + } + path = 'model.pdparams' + paddle.save(obj, path) + # paddle.load + obj_load = paddle.load(path) + resnet = ResNet(use_cudnn=True) + optimizer = optimizer_setting( + train_parameters, parameter_list=resnet.parameters()) + resnet.set_state_dict(obj_load['model']) + optimizer.set_state_dict(obj_load['opt']) + scaler.load_state_dict(obj_load['scaler']) + resnet, optimizer = paddle.amp.decorate( + models=resnet, + optimizers=optimizer, + level='O2', + save_dtype='float32') + + if use_data_loader: + train_reader._reset() + return dy_out, dy_param_value, dy_grad_value + + def test_with_save_load(self): + with fluid.dygraph.guard(): + out_use_save_load = self.train_resnet( + enable_amp=True, use_data_loader=True, use_save_load=True) + out_no_save_load = self.train_resnet( + enable_amp=True, use_data_loader=True, use_save_load=False) + print('save_load:', out_use_save_load[0], out_no_save_load[0]) + self.assertTrue(np.allclose(out_use_save_load[0], out_no_save_load[0])) + + +class TestPureFp16InferenceSaveLoad(unittest.TestCase): + def test_inference_save_load(self): + BATCH_SIZE = 16 + BATCH_NUM = 4 + EPOCH_NUM = 4 + IMAGE_SIZE = 784 + CLASS_NUM = 10 + + # define a random dataset + class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + image = np.random.random([IMAGE_SIZE]).astype('float32') + label = np.random.randint(0, CLASS_NUM - 1, + (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples + + class LinearNet(nn.Layer): + def __init__(self): + super(LinearNet, self).__init__() + self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM) + + def forward(self, x): + return self._linear(x) + + def train(layer, loader, loss_fn, opt): + for epoch_id in range(EPOCH_NUM): + for batch_id, (image, label) in enumerate(loader()): + with paddle.amp.auto_cast( + enable=True, + custom_white_list=None, + custom_black_list=None, + level='O2'): + out = layer(image) + loss = loss_fn(out, label) + loss.backward() + opt.step() + opt.clear_grad() + + # train + layer = LinearNet() + adam = paddle.optimizer.Adam( + learning_rate=0.001, + parameters=layer.parameters(), + multi_precision=True) + loss_fn = nn.CrossEntropyLoss() + layer, adam = paddle.amp.decorate( + models=layer, optimizers=adam, save_dtype='float32') + dataset = RandomDataset(BATCH_NUM * BATCH_SIZE) + loader = paddle.io.DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + num_workers=2) + + train(layer, loader, loss_fn, adam) + + # save + path = "example_model/linear" + paddle.jit.save( + layer, path, input_spec=[InputSpec( + shape=[IMAGE_SIZE], name='x')]) + + # jit.load + loaded_layer = paddle.jit.load(path) + + # inference + loaded_layer.eval() + x = np.random.randn(1, IMAGE_SIZE).astype('float32') + x_tensor = paddle.to_tensor(x) + pred = loaded_layer(x_tensor) + + # load_inference_model + paddle.enable_static() + exe = paddle.static.Executor(paddle.CPUPlace()) + [inference_program, feed_target_names, fetch_targets] = ( + paddle.static.load_inference_model(path, exe)) + tensor_img = x + results = exe.run(inference_program, + feed={feed_target_names[0]: tensor_img}, + fetch_list=fetch_targets) + + self.assertTrue(np.allclose(pred.numpy(), results, atol=1.e-5)) + + class TestResnet2(unittest.TestCase): """ Use paddle-2.0 API @@ -393,6 +742,7 @@ class TestResnet2(unittest.TestCase): def train_resnet(self, enable_amp=True, + level='O1', use_data_loader=False, use_param_group=False): seed = 90 @@ -418,13 +768,15 @@ class TestResnet2(unittest.TestCase): # NOTE(zhiqiu): The Membership test operations(in / not in) calls "is" and "equal", # see details: https://docs.python.org/3/reference/expressions.html#membership-test-operations. # So do not use other_params = [p for p in resnet.parameters() if p not in conv_params] - optimizer = paddle.optimizer.Momentum(parameters=[{ - 'params': conv_params, - 'learning_rate': 0.01 - }, { - 'params': other_params, - 'learning_rate': 0.001 - }]) + optimizer = paddle.optimizer.Momentum( + parameters=[{ + 'params': conv_params, + 'learning_rate': 0.01 + }, { + 'params': other_params, + 'learning_rate': 0.001 + }], + multi_precision=True) else: optimizer = paddle.optimizer.SGD(parameters=resnet.parameters()) @@ -453,6 +805,10 @@ class TestResnet2(unittest.TestCase): train_loader.set_sample_list_generator(train_reader) train_reader = train_loader + if enable_amp and (level == 'O2'): + resnet, optimizer = paddle.amp.decorate( + models=resnet, optimizers=optimizer, level='O2') + for batch_id, data in enumerate(train_reader()): if batch_id >= batch_num: break @@ -471,10 +827,11 @@ class TestResnet2(unittest.TestCase): label = paddle.to_tensor(y_data) label.stop_gradient = True - with paddle.amp.auto_cast(enable=enable_amp): + with paddle.amp.auto_cast(enable=enable_amp, level=level): out = resnet(img) loss = paddle.nn.functional.cross_entropy(input=out, label=label) + loss = paddle.cast(loss, 'float32') avg_loss = paddle.mean(x=loss) dy_out = avg_loss.numpy() @@ -504,15 +861,20 @@ class TestResnet2(unittest.TestCase): with fluid.dygraph.guard(): out_fp32 = self.train_resnet(enable_amp=False) out_amp = self.train_resnet(enable_amp=True) - print(out_fp32[0], out_amp[0]) + out_pure_fp16 = self.train_resnet(enable_amp=True, level='O2') + print(out_fp32[0], out_amp[0], out_pure_fp16[0]) self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-5)) + self.assertTrue(np.allclose(out_fp32[0], out_pure_fp16[0], atol=1.e-2)) def test_with_data_loader(self): with fluid.dygraph.guard(): out_fp32 = self.train_resnet(enable_amp=False, use_data_loader=True) out_amp = self.train_resnet(enable_amp=True, use_data_loader=True) - print(out_fp32[0], out_amp[0]) + out_pure_fp16 = self.train_resnet( + enable_amp=True, use_data_loader=True, level='O2') + print(out_fp32[0], out_amp[0], out_pure_fp16[0]) self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-5)) + self.assertTrue(np.allclose(out_fp32[0], out_pure_fp16[0], atol=1.e-2)) def test_param_group(self): with fluid.dygraph.guard(): @@ -520,8 +882,14 @@ class TestResnet2(unittest.TestCase): enable_amp=False, use_data_loader=True, use_param_group=True) out_amp = self.train_resnet( enable_amp=True, use_data_loader=True, use_param_group=True) - print(out_fp32[0], out_amp[0]) + out_pure_fp16 = self.train_resnet( + enable_amp=True, + use_data_loader=True, + use_param_group=True, + level='O2') + print(out_fp32[0], out_amp[0], out_pure_fp16[0]) self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-5)) + self.assertTrue(np.allclose(out_fp32[0], out_pure_fp16[0], atol=1.e-2)) class TestResnet(unittest.TestCase): @@ -529,7 +897,7 @@ class TestResnet(unittest.TestCase): Use paddle-1.x API """ - def train_resnet(self, enable_amp=True): + def train_resnet(self, enable_amp=True, level='O1'): seed = 90 batch_size = train_parameters["batch_size"] @@ -542,6 +910,8 @@ class TestResnet(unittest.TestCase): resnet = ResNet(use_cudnn=True) optimizer = optimizer_setting( train_parameters, parameter_list=resnet.parameters()) + optimizer = paddle.optimizer.Momentum( + parameters=resnet.parameters(), multi_precision=True) np.random.seed(seed) train_reader = paddle.batch( paddle.dataset.flowers.train(use_xmap=False), @@ -554,6 +924,11 @@ class TestResnet(unittest.TestCase): program = None scaler = paddle.fluid.dygraph.AmpScaler( enable=enable_amp, init_loss_scaling=2.**10) + + if enable_amp and (level == 'O2'): + resnet, optimizer = paddle.fluid.dygraph.amp_decorate( + models=resnet, optimizers=optimizer, level='O2') + for batch_id, data in enumerate(train_reader()): if batch_id >= batch_num: break @@ -567,7 +942,8 @@ class TestResnet(unittest.TestCase): img = fluid.dygraph.to_variable(dy_x_data) label = fluid.dygraph.to_variable(y_data) label.stop_gradient = True - with paddle.fluid.dygraph.amp_guard(enable=enable_amp): + with paddle.fluid.dygraph.amp_guard( + enable=enable_amp, level=level): out = resnet(img) loss = fluid.layers.cross_entropy(input=out, label=label) @@ -599,8 +975,10 @@ class TestResnet(unittest.TestCase): def test_resnet(self): out_fp32 = self.train_resnet(enable_amp=False) out_amp = self.train_resnet(enable_amp=True) - print(out_fp32[0], out_amp[0]) + out_pure_fp16 = self.train_resnet(enable_amp=True, level='O2') + print(out_fp32[0], out_amp[0], out_pure_fp16[0]) self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-2)) + self.assertTrue(np.allclose(out_fp32[0], out_pure_fp16[0], atol=1.e-1)) class TestLayerNormFp16(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py index 1d24687a6b..fc58f979b4 100644 --- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py @@ -1099,7 +1099,6 @@ class TestJitSaveLoadSaveWithoutRunning(unittest.TestCase): paddle.static.InputSpec( shape=[None, IMAGE_SIZE], dtype='float32') ]) - result_00 = layer_save(inps0) result_01 = layer_save(inps1) #load and save without running diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py index e065ee91c6..cc28eead52 100644 --- a/python/paddle/optimizer/adam.py +++ b/python/paddle/optimizer/adam.py @@ -22,6 +22,8 @@ from ..fluid.layer_helper import LayerHelper import warnings from ..fluid.dygraph import base as imperative_base from collections import defaultdict +import numpy as np +import time import paddle from paddle import _C_ops @@ -208,26 +210,29 @@ class Adam(Optimizer): } def _create_master_weight(self, param): - assert isinstance(self.helper, LayerHelper) - - var_name = param.name + "_fp32_master" - var_name = unique_name.generate(var_name) - var = layers.create_global_var( - name=var_name, - shape=param.shape, - value=0, - dtype='float32', - persistable=True) - block = self.helper.startup_program.global_block() - block.append_op( - type="cast", - inputs={"X": [param]}, - outputs={"Out": [var]}, - attrs={ - "in_dtype": param.dtype, - "out_dtype": core.VarDesc.VarType.FP32 - }) - self._master_weights[param.name] = var + if param.name in self._master_weights: + var = self._master_weights[param.name] + else: + assert isinstance(self.helper, LayerHelper) + + var_name = param.name + "_fp32_master" + var_name = unique_name.generate(var_name) + var = layers.create_global_var( + name=var_name, + shape=param.shape, + value=0, + dtype='float32', + persistable=True) + block = self.helper.startup_program.global_block() + block.append_op( + type="cast", + inputs={"X": [param]}, + outputs={"Out": [var]}, + attrs={ + "in_dtype": param.dtype, + "out_dtype": core.VarDesc.VarType.FP32 + }) + self._master_weights[param.name] = var return var def _get_accumulator(self, name, param): @@ -317,12 +322,13 @@ class Adam(Optimizer): self._beta1, Variable) else self._beta1.numpy().item(0) _beta2 = self._beta2 if not isinstance( self._beta2, Variable) else self._beta2.numpy().item(0) - _, _, _, _, _ = _C_ops.adam( + _, _, _, _, _, _ = _C_ops.adam( param_and_grad[0], param_and_grad[1], lr, moment1, moment2, - beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1, - moment2, beta1_pow_acc, beta2_pow_acc, 'epsilon', self._epsilon, - 'lazy_mode', self._lazy_mode, 'min_row_size_to_use_multithread', - 1000, 'beta1', _beta1, 'beta2', _beta2) + beta1_pow_acc, beta2_pow_acc, master_weight, param_and_grad[0], + moment1, moment2, beta1_pow_acc, beta2_pow_acc, master_weight, + 'epsilon', self._epsilon, 'lazy_mode', self._lazy_mode, + 'min_row_size_to_use_multithread', 1000, 'beta1', _beta1, + 'beta2', _beta2, 'multi_precision', find_master) return None diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py index 0efc40d330..30b8fa975a 100644 --- a/python/paddle/optimizer/adamw.py +++ b/python/paddle/optimizer/adamw.py @@ -297,13 +297,15 @@ class AdamW(Adam): self._beta1, Variable) else self._beta1.numpy().item(0) _beta2 = self._beta2 if not isinstance( self._beta2, Variable) else self._beta2.numpy().item(0) - _, _, _, _, _ = _C_ops.adamw( + + _, _, _, _, _, _ = _C_ops.adam( param_and_grad[0], param_and_grad[1], lr, moment1, moment2, - beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1, - moment2, beta1_pow_acc, beta2_pow_acc, 'epsilon', self._epsilon, - 'lazy_mode', self._lazy_mode, 'min_row_size_to_use_multithread', - 1000, 'beta1', _beta1, 'beta2', _beta2, 'coeff', self._coeff, - "lr_ratio", lr_ratio_) + beta1_pow_acc, beta2_pow_acc, master_weight, param_and_grad[0], + moment1, moment2, beta1_pow_acc, beta2_pow_acc, master_weight, + 'epsilon', self._epsilon, 'lazy_mode', self._lazy_mode, + 'min_row_size_to_use_multithread', 1000, 'beta1', _beta1, + 'beta2', _beta2, 'coeff', self._coeff, 'multi_precision', + find_master) return None diff --git a/python/paddle/optimizer/momentum.py b/python/paddle/optimizer/momentum.py index d33c9ecbb4..fde3b28607 100644 --- a/python/paddle/optimizer/momentum.py +++ b/python/paddle/optimizer/momentum.py @@ -170,7 +170,7 @@ class Momentum(Optimizer): 'regularization_method': self._regularization_method, 'regularization_coeff': self._regularization_coeff, } - + ''' if framework.in_dygraph_mode(): self.helper = LayerHelper(self.__class__.__name__) if isinstance(self._parameter_list[0], dict): @@ -180,6 +180,7 @@ class Momentum(Optimizer): else: for p in parameters: self._add_accumulator(self._velocity_acc_str, p) + ''' def _update_regularization(self, weight_decay): reg_method = "" @@ -194,26 +195,29 @@ class Momentum(Optimizer): return reg_method, reg_coeff def _create_master_weight(self, param): - assert isinstance(self.helper, LayerHelper) - - var_name = param.name + "_fp32_master" - var_name = unique_name.generate(var_name) - var = layers.create_global_var( - name=var_name, - shape=param.shape, - value=0, - dtype='float32', - persistable=True) - block = self.helper.startup_program.global_block() - block.append_op( - type="cast", - inputs={"X": [param]}, - outputs={"Out": [var]}, - attrs={ - "in_dtype": param.dtype, - "out_dtype": core.VarDesc.VarType.FP32 - }) - self._master_weights[param.name] = var + if param.name in self._master_weights: + var = self._master_weights[param.name] + else: + assert isinstance(self.helper, LayerHelper) + + var_name = param.name + "_fp32_master" + var_name = unique_name.generate(var_name) + var = layers.create_global_var( + name=var_name, + shape=param.shape, + value=0, + dtype='float32', + persistable=True) + block = self.helper.startup_program.global_block() + block.append_op( + type="cast", + inputs={"X": [param]}, + outputs={"Out": [var]}, + attrs={ + "in_dtype": param.dtype, + "out_dtype": core.VarDesc.VarType.FP32 + }) + self._master_weights[param.name] = var return var def _get_accumulator(self, name, param): @@ -239,10 +243,15 @@ class Momentum(Optimizer): return self._accumulators[name][target_name] def _create_accumulators(self, block, parameters): + ''' if framework.in_dygraph_mode(): return - + ''' assert isinstance(block, framework.Block) + + if isinstance(parameters, dict): + parameters = self._update_param_group(parameters) + for p in parameters: if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: master_p = self._create_master_weight(p) @@ -291,21 +300,23 @@ class Momentum(Optimizer): regularization_method = "" regularization_coeff = 0 + find_master = self._multi_precision and param_and_grad[ + 0].dtype == core.VarDesc.VarType.FP16 + master_weight = (self._master_weights[param_and_grad[0].name] + if find_master else None) + if framework.in_dygraph_mode(): if isinstance(param_and_grad, dict): self._update_regularization(param_and_grad['weight_decay']) - _, _ = _C_ops.momentum( + _, _, _ = _C_ops.momentum( param_and_grad[0], param_and_grad[1], velocity_acc, lr, - param_and_grad[0], velocity_acc, 'mu', self._momentum, - 'use_nesterov', self._use_nesterov, 'regularization_method', - regularization_method, 'regularization_coeff', - regularization_coeff) - return None + master_weight, param_and_grad[0], velocity_acc, master_weight, + 'mu', self._momentum, 'use_nesterov', self._use_nesterov, + 'regularization_method', regularization_method, + 'regularization_coeff', regularization_coeff, 'multi_precision', + find_master) - find_master = self._multi_precision and param_and_grad[ - 0].dtype == core.VarDesc.VarType.FP16 - master_weight = (self._master_weights[param_and_grad[0].name] - if find_master else None) + return None attrs = { "mu": self._momentum, -- GitLab