From d8ea8a0623e83dd4edf5dd0c4e6d7d1f66237040 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Thu, 3 Dec 2020 16:40:10 +0800 Subject: [PATCH] [Cherry-pick] Add pure fp16 training with master weights. (#29301) * Add pure fp16 training with master weights. (#27712) * add the weight decay func for the momentum op * Add the multi_precision function in Momentum Optimizer. * Make sure that the initial value of master weights are same with the fp16 weights. * add static loss scaling. * add the rescale_grad function in the pure fp16 training. * use the original momentum updating method. * Polish some codes, such as variable names. * add docstring for apis. * update the var creation details of _create_master_weight. * not modify codes about imperative momentum updating. * Fix the error of test_dist_sparse_tensor_load_momentum UT. * add unit test for multi precision fp16 training. * add more unit tests for CI. * Use lower threshold values for allclose comparing in test_multi_precision_fp16_train UT. --- .../fluid/operators/optimizers/momentum_op.cc | 28 +- .../fluid/operators/optimizers/momentum_op.h | 429 +++++++++++------- paddle/fluid/pybind/op_function_generator.cc | 2 + .../contrib/mixed_precision/fp16_utils.py | 127 ++++++ python/paddle/fluid/contrib/optimizer.py | 84 +++- .../paddle/fluid/contrib/tests/CMakeLists.txt | 6 + .../tests/test_multi_precision_fp16_train.py | 269 +++++++++++ .../fleet/parameter_server/ir/pserver_pass.py | 3 +- .../fluid/tests/unittests/test_momentum_op.py | 115 ++++- 9 files changed, 879 insertions(+), 184 deletions(-) create mode 100644 python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py diff --git a/paddle/fluid/operators/optimizers/momentum_op.cc b/paddle/fluid/operators/optimizers/momentum_op.cc index edffb093a6..1b01f5ebd8 100644 --- a/paddle/fluid/operators/optimizers/momentum_op.cc +++ b/paddle/fluid/operators/optimizers/momentum_op.cc @@ -49,13 +49,17 @@ void MomentumOpMaker::Make() { AddInput("LearningRate", "(Tensor, default Tensor) " "Input learning rate"); - + AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable(); AddOutput("ParamOut", "(Tensor) This output is updated parameter. " "It shared memory with Input(Param)."); AddOutput("VelocityOut", "(Tensor) This output is updated velocity. " "It shared memory with Input(Velocity)."); + AddOutput("MasterParamOut", + "The updated FP32 master weight for AMP. " + "It shared memory with Input(MasterParam).") + .AsDispensable(); AddAttr("mu", "(float) Momentum coefficient"); AddAttr("use_nesterov", @@ -67,7 +71,17 @@ void MomentumOpMaker::Make() { "(string) regularization_method, right now only support l2decay or none") .SetDefault(""); AddAttr("regularization_coeff", "(float) regularization_coeff") - .SetDefault(0); + .SetDefault(0.0f); + AddAttr("multi_precision", + "(bool, default false) " + "Whether to use multi-precision during weight updating.") + .SetDefault(false); + AddAttr( + "rescale_grad", + "(float, default 1.0) Multiply the gradient with `rescale_grad`" + "before updating. Often choose to be `1.0/batch_size`.") + .SetDefault(1.0f); + AddComment(R"DOC( Momentum Optimizer. @@ -109,4 +123,12 @@ REGISTER_OP_VERSION(momentum) "l2decay or none", std::string("")) .NewAttr("regularization_coeff", "(float) regularization_coeff", - 0.0f)); + 0.0f) + .NewAttr( + "multi_precision", + "(bool) Whether to use multi-precision during weight updating.", + false) + .NewAttr("rescale_grad", + "(float) Multiply the gradient with `rescale_grad`" + "before updating. Often choose to be `1.0/batch_size`.", + 1.0f)); diff --git a/paddle/fluid/operators/optimizers/momentum_op.h b/paddle/fluid/operators/optimizers/momentum_op.h index 3b22e0b7a1..64acdfe890 100644 --- a/paddle/fluid/operators/optimizers/momentum_op.h +++ b/paddle/fluid/operators/optimizers/momentum_op.h @@ -19,6 +19,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/algorithm.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/for_range.h" namespace paddle { @@ -29,6 +30,44 @@ using framework::SelectedRows; struct NoNesterov; struct UseNesterov; +namespace details { + +template +class MPTypeTrait { + public: + using Type = T; +}; +template <> +class MPTypeTrait { + public: + using Type = float; +}; + +template +struct CPUDenseUpdater { + template + void operator()(const Tensor& param, const Tensor& velocity, const T& mu, + const T& lr, const bool use_nesterov, G&& grad, + Tensor* param_out, Tensor* velocity_out) const { + auto param_out_vec = framework::EigenVector::Flatten(*param_out); + auto velocity_out_vec = framework::EigenVector::Flatten(*velocity_out); + + auto param_vec = framework::EigenVector::Flatten(param); + auto velocity_vec = framework::EigenVector::Flatten(velocity); + velocity_out_vec = velocity_vec * mu + grad; + if (use_nesterov) { + param_out_vec = param_vec - (grad + velocity_out_vec * mu) * lr; + } else { + param_out_vec = param_vec - lr * velocity_out_vec; + } + } +}; + +} // namespace details + +template +using MultiPrecisionType = typename details::MPTypeTrait::Type; + enum class RegularizationType { kNONE = 0, kL1DECAY = 1, // do not need support right now @@ -118,350 +157,394 @@ class MomentumOp : public framework::OperatorWithKernel { template class CPUDenseMomentumFunctor { - private: - const Tensor* param_; - const Tensor* grad_; - const Tensor* velocity_; - const Tensor* learning_rate_; - const T mu_; - const T use_nesterov_; - RegularizationType regularization_flag_; - const T regularization_coeff_; - Tensor* param_out_; - Tensor* velocity_out_; - public: - CPUDenseMomentumFunctor(const Tensor* param, const Tensor* grad, - const Tensor* velocity, const Tensor* learning_rate, - const T mu, const bool use_nesterov, - RegularizationType regularization_flag, - const T regularization_coeff, Tensor* param_out, - Tensor* velocity_out) - : param_(param), - grad_(grad), - velocity_(velocity), - learning_rate_(learning_rate), - mu_(mu), - use_nesterov_(use_nesterov), - regularization_flag_(regularization_flag), - regularization_coeff_(regularization_coeff), - param_out_(param_out), - velocity_out_(velocity_out) {} - - inline void operator()() { - auto param_out = framework::EigenVector::Flatten(*param_out_); - auto velocity_out = framework::EigenVector::Flatten(*velocity_out_); - - auto param = framework::EigenVector::Flatten(*param_); - auto velocity = framework::EigenVector::Flatten(*velocity_); - auto grad = framework::EigenVector::Flatten(*grad_); - auto* lr = learning_rate_->data(); - - if (regularization_flag_ == RegularizationType::kL2DECAY) { - velocity_out = velocity * mu_ + param * regularization_coeff_ + grad; - if (use_nesterov_) { - param_out = - param - - (param * regularization_coeff_ + grad + velocity_out * mu_) * lr[0]; - } else { - param_out = param - lr[0] * velocity_out; - } + void operator()(const Tensor* param, const Tensor* grad, + const Tensor* velocity, const Tensor* learning_rate, + const T mu, const bool use_nesterov, + const RegularizationType regularization_flag, + const T regularization_coeff, Tensor* param_out, + Tensor* velocity_out) { + auto grad_vec = framework::EigenVector::Flatten(*grad); + auto* lr = learning_rate->data>(); + + details::CPUDenseUpdater updater; + if (regularization_flag == RegularizationType::kL2DECAY) { + auto param_vec = framework::EigenVector::Flatten(*param); + updater(*param, *velocity, mu, static_cast(lr[0]), use_nesterov, + param_vec * regularization_coeff + grad_vec, param_out, + velocity_out); } else { - velocity_out = velocity * mu_ + grad; - if (use_nesterov_) { - param_out = param - (grad + velocity_out * mu_) * lr[0]; - } else { - param_out = param - lr[0] * velocity_out; - } + updater(*param, *velocity, mu, static_cast(lr[0]), use_nesterov, + grad_vec, param_out, velocity_out); } } }; -template +template class DenseMomentumFunctor; // NOTE(dzh) for performance. // avoid if/else in inside kernel, implement GPU UseNesterov/NoNesterov as two // functor. -template -class DenseMomentumFunctor { +template +class DenseMomentumFunctor { private: const T* param_; const T* grad_; - const T* velocity_; - const T* lr_; - const T mu_; + const MT* velocity_; + const MultiPrecisionType* lr_; + const MT* master_param_; + const MT mu_; + const MT rescale_grad_; const int64_t num_; T* param_out_; - T* velocity_out_; - RegularizationType regularization_flag_; - const T regularization_coeff_; + MT* velocity_out_; + MT* master_param_out_; + const RegularizationType regularization_flag_; + const MT regularization_coeff_; public: - DenseMomentumFunctor(const T* param, const T* grad, const T* velocity, - const T* learning_rate, const T mu, const int64_t num, - RegularizationType regularization_flag, - const T regularization_coeff, T* param_out, - T* velocity_out) + DenseMomentumFunctor(const T* param, const T* grad, const MT* velocity, + const MultiPrecisionType* learning_rate, + const MT* master_param, const MT mu, + const MT rescale_grad, const int64_t num, + const RegularizationType regularization_flag, + const MT regularization_coeff, T* param_out, + MT* velocity_out, MT* master_param_out) : param_(param), grad_(grad), velocity_(velocity), lr_(learning_rate), + master_param_(master_param), mu_(mu), + rescale_grad_(rescale_grad), num_(num), param_out_(param_out), velocity_out_(velocity_out), + master_param_out_(master_param_out), regularization_flag_(regularization_flag), regularization_coeff_(regularization_coeff) {} - inline HOSTDEVICE void operator()(size_t i) const { // put memory access in register - const T param = param_[i]; - T grad = grad_[i]; - const T lr = lr_[0]; - const T velocity = velocity_[i]; + const MT param = + master_param_ ? master_param_[i] : static_cast(param_[i]); + MT grad = static_cast(grad_[i]) * rescale_grad_; + const MT lr = static_cast(lr_[0]); + const MT velocity = velocity_[i]; grad = regularization_flag_ == RegularizationType::kL2DECAY ? grad + regularization_coeff_ * param : grad; - T velocity_out = velocity * mu_ + grad; - T param_out = param - (grad + velocity_out * mu_) * lr; + MT velocity_out = velocity * mu_ + grad; + MT param_out = param - (grad + velocity_out * mu_) * lr; // write reigster to memory velocity_out_[i] = velocity_out; - param_out_[i] = param_out; + param_out_[i] = static_cast(param_out); + if (master_param_out_) { + master_param_out_[i] = param_out; + } } }; -template -class DenseMomentumFunctor { +template +class DenseMomentumFunctor { private: const T* param_; const T* grad_; - const T* velocity_; - const T* lr_; - const T mu_; + const MT* velocity_; + const MultiPrecisionType* lr_; + const MT* master_param_; + const MT mu_; + const MT rescale_grad_; const int64_t num_; T* param_out_; - T* velocity_out_; - RegularizationType regularization_flag_; - const T regularization_coeff_; + MT* velocity_out_; + MT* master_param_out_; + const RegularizationType regularization_flag_; + const MT regularization_coeff_; public: - DenseMomentumFunctor(const T* param, const T* grad, const T* velocity, - const T* learning_rate, const T mu, const int64_t num, - RegularizationType regularization_flag, - const T regularization_coeff, T* param_out, - T* velocity_out) + DenseMomentumFunctor(const T* param, const T* grad, const MT* velocity, + const MultiPrecisionType* learning_rate, + const MT* master_param, const MT mu, + const MT rescale_grad, const int64_t num, + const RegularizationType regularization_flag, + const MT regularization_coeff, T* param_out, + MT* velocity_out, MT* master_param_out) : param_(param), grad_(grad), velocity_(velocity), lr_(learning_rate), + master_param_(master_param), mu_(mu), + rescale_grad_(rescale_grad), num_(num), param_out_(param_out), velocity_out_(velocity_out), + master_param_out_(master_param_out), regularization_flag_(regularization_flag), regularization_coeff_(regularization_coeff) {} - inline HOSTDEVICE void operator()(size_t i) const { // put memory access in register - const T param = param_[i]; - T grad = grad_[i]; - const T lr = lr_[0]; - const T velocity = velocity_[i]; + const MT param = + master_param_ ? master_param_[i] : static_cast(param_[i]); + MT grad = static_cast(grad_[i]) * rescale_grad_; + const MT lr = static_cast(lr_[0]); + const MT velocity = velocity_[i]; grad = regularization_flag_ == RegularizationType::kL2DECAY ? grad + regularization_coeff_ * param : grad; - T velocity_out = velocity * mu_ + grad; - T param_out = param - lr * velocity_out; + MT velocity_out = velocity * mu_ + grad; + MT param_out = param - lr * velocity_out; // write reigster to memory velocity_out_[i] = velocity_out; - param_out_[i] = param_out; + param_out_[i] = static_cast(param_out); + if (master_param_out_) { + master_param_out_[i] = param_out; + } } }; -template +template class SparseMomentumFunctor; -template -class SparseMomentumFunctor { +template +class SparseMomentumFunctor { private: const T* param_; const T* grad_; - const T* velocity_; - const T* lr_; - const T mu_; + const MT* velocity_; + const MultiPrecisionType* lr_; + const MT* master_param_; + const MT mu_; + const MT rescale_grad_; const int64_t* rows_; const int64_t row_numel_; const int64_t row_height_; T* param_out_; - T* velocity_out_; - RegularizationType regularization_flag_; - const T regularization_coeff_; + MT* velocity_out_; + MT* master_param_out_; + const RegularizationType regularization_flag_; + const MT regularization_coeff_; public: - SparseMomentumFunctor(const T* param, const T* grad, const T* velocity, - const T* lr, const T mu, const int64_t* rows, + SparseMomentumFunctor(const T* param, const T* grad, const MT* velocity, + const MultiPrecisionType* lr, + const MT* master_param, const MT mu, + const MT rescale_grad, const int64_t* rows, int64_t row_numel, int64_t row_height, - RegularizationType regularization_flag, - const T regularization_coeff, T* param_out, - T* velocity_out) + const RegularizationType regularization_flag, + const MT regularization_coeff, T* param_out, + MT* velocity_out, MT* master_param_out) : param_(param), grad_(grad), velocity_(velocity), lr_(lr), + master_param_(master_param), mu_(mu), + rescale_grad_(rescale_grad), rows_(rows), row_numel_(row_numel), row_height_(row_height), param_out_(param_out), velocity_out_(velocity_out), + master_param_out_(master_param_out), regularization_flag_(regularization_flag), regularization_coeff_(regularization_coeff) {} inline HOSTDEVICE void operator()(size_t i) { auto row_idx = math::BinarySearch(rows_, row_height_, i / row_numel_); - T grad = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] - : static_cast(0); + MT grad = + row_idx >= 0 + ? static_cast(grad_[row_idx * row_numel_ + i % row_numel_]) * + rescale_grad_ + : static_cast(0); // put memory access in register - const T param = param_[i]; - const T lr = lr_[0]; - const T velocity = velocity_[i]; + const MT param = + master_param_ ? master_param_[i] : static_cast(param_[i]); + const MT lr = static_cast(lr_[0]); + const MT velocity = velocity_[i]; grad = regularization_flag_ == RegularizationType::kL2DECAY ? grad + regularization_coeff_ * param : grad; - T velocity_out = velocity * mu_ + grad; - T param_out = param - (grad + velocity_out * mu_) * lr; + MT velocity_out = velocity * mu_ + grad; + MT param_out = param - (grad + velocity_out * mu_) * lr; // write reigster to memory velocity_out_[i] = velocity_out; - param_out_[i] = param_out; + param_out_[i] = static_cast(param_out); + if (master_param_out_) { + master_param_out_[i] = param_out; + } } }; -template -class SparseMomentumFunctor { +template +class SparseMomentumFunctor { private: const T* param_; const T* grad_; - const T* velocity_; - const T* lr_; - const T mu_; + const MT* velocity_; + const MultiPrecisionType* lr_; + const MT* master_param_; + const MT mu_; + const MT rescale_grad_; const int64_t* rows_; const int64_t row_numel_; const int64_t row_height_; T* param_out_; - T* velocity_out_; - RegularizationType regularization_flag_; - const T regularization_coeff_; + MT* velocity_out_; + MT* master_param_out_; + const RegularizationType regularization_flag_; + const MT regularization_coeff_; public: - SparseMomentumFunctor(const T* param, const T* grad, const T* velocity, - const T* lr, const T mu, const int64_t* rows, + SparseMomentumFunctor(const T* param, const T* grad, const MT* velocity, + const MultiPrecisionType* lr, + const MT* master_param, const MT mu, + const MT rescale_grad, const int64_t* rows, int64_t row_numel, int64_t row_height, - RegularizationType regularization_flag, - const T regularization_coeff, T* param_out, - T* velocity_out) + const RegularizationType regularization_flag, + const MT regularization_coeff, T* param_out, + MT* velocity_out, MT* master_param_out) : param_(param), grad_(grad), velocity_(velocity), lr_(lr), + master_param_(master_param), mu_(mu), + rescale_grad_(rescale_grad), rows_(rows), row_numel_(row_numel), row_height_(row_height), param_out_(param_out), velocity_out_(velocity_out), + master_param_out_(master_param_out), regularization_flag_(regularization_flag), regularization_coeff_(regularization_coeff) {} inline HOSTDEVICE void operator()(size_t i) { auto row_idx = math::BinarySearch(rows_, row_height_, i / row_numel_); - T grad = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] - : static_cast(0); + MT grad = + row_idx >= 0 + ? static_cast(grad_[row_idx * row_numel_ + i % row_numel_]) * + rescale_grad_ + : static_cast(0); // put memory access in register - const T param = param_[i]; - const T lr = lr_[0]; - const T velocity = velocity_[i]; + const MT param = + master_param_ ? master_param_[i] : static_cast(param_[i]); + const MT lr = static_cast(lr_[0]); + const MT velocity = velocity_[i]; grad = regularization_flag_ == RegularizationType::kL2DECAY ? grad + regularization_coeff_ * param : grad; - T velocity_out = velocity * mu_ + grad; - T param_out = param - velocity_out * lr; + MT velocity_out = velocity * mu_ + grad; + MT param_out = param - velocity_out * lr; // write reigster to memory velocity_out_[i] = velocity_out; - param_out_[i] = param_out; + param_out_[i] = static_cast(param_out); + if (master_param_out_) { + master_param_out_[i] = param_out; + } } }; template class MomentumOpKernel : public framework::OpKernel { + using MPDType = MultiPrecisionType; + public: void Compute(const framework::ExecutionContext& ctx) const override { - std::string regularization_method = - ctx.Attr("regularization_method"); - if (regularization_method != "" || !regularization_method.empty()) { - PADDLE_ENFORCE_EQ("l2_decay", regularization_method, - platform::errors::InvalidArgument( - "if regularization_method is not null, " - "it should be l2_decay, but received %s", - regularization_method)); + const bool multi_precision = ctx.Attr("multi_precision"); + if (multi_precision) { + InnerCompute(ctx, multi_precision); + } else { + InnerCompute(ctx, multi_precision); } + } - T regularization_coeff = - static_cast(ctx.Attr("regularization_coeff")); + private: + template + void InnerCompute(const framework::ExecutionContext& ctx, + const bool multi_precision) const { + std::string regularization_method = + ctx.Attr("regularization_method"); + MT regularization_coeff = + static_cast(ctx.Attr("regularization_coeff")); RegularizationType regularization_flag{ RegularizationType::kNONE}; // disable regularization if (regularization_method == "l2_decay") { regularization_flag = RegularizationType::kL2DECAY; } - T mu = static_cast(ctx.Attr("mu")); + MT mu = static_cast(ctx.Attr("mu")); + MT rescale_grad = static_cast(ctx.Attr("rescale_grad")); bool use_nesterov = ctx.Attr("use_nesterov"); auto learning_rate = ctx.Input("LearningRate"); auto param = ctx.Input("Param"); auto param_out = ctx.Output("ParamOut"); - auto* velocity = ctx.Input("Velocity"); + auto velocity = ctx.Input("Velocity"); auto velocity_out = ctx.Output("VelocityOut"); + const framework::Tensor* master_param = nullptr; + framework::Tensor* master_param_out = nullptr; + if (multi_precision) { + bool has_master = + ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut"); + PADDLE_ENFORCE_EQ(has_master, true, + platform::errors::InvalidArgument( + "The Input(MasterParam) and Output(MasterParamOut) " + "should not be null when " + "the attr `multi_precision` is true")); + master_param = ctx.Input("MasterParam"); + master_param_out = ctx.Output("MasterParamOut"); + } + param_out->mutable_data(ctx.GetPlace()); - velocity_out->mutable_data(ctx.GetPlace()); + velocity_out->mutable_data(ctx.GetPlace()); + const MT* master_in_data = + multi_precision ? master_param->data() : nullptr; + MT* master_out_data = + multi_precision ? master_param_out->mutable_data(ctx.GetPlace()) + : nullptr; auto* grad_var = ctx.InputVar("Grad"); if (grad_var->IsType()) { auto grad = ctx.Input("Grad"); if (platform::is_cpu_place(ctx.GetPlace())) { - CPUDenseMomentumFunctor functor( - param, grad, velocity, learning_rate, mu, use_nesterov, - regularization_flag, regularization_coeff, param_out, velocity_out); - functor(); + CPUDenseMomentumFunctor functor; + functor(param, grad, velocity, learning_rate, mu, use_nesterov, + regularization_flag, regularization_coeff, param_out, + velocity_out); } else if (platform::is_gpu_place(ctx.GetPlace())) { platform::ForRange for_range( static_cast(ctx.device_context()), param->numel()); if (use_nesterov) { - DenseMomentumFunctor functor( - param->data(), grad->data(), velocity->data(), - learning_rate->data(), mu, param->numel(), regularization_flag, - regularization_coeff, param_out->mutable_data(ctx.GetPlace()), - velocity_out->mutable_data(ctx.GetPlace())); + DenseMomentumFunctor functor( + param->data(), grad->data(), velocity->data(), + learning_rate->data(), master_in_data, mu, rescale_grad, + param->numel(), regularization_flag, regularization_coeff, + param_out->mutable_data(ctx.GetPlace()), + velocity_out->mutable_data(ctx.GetPlace()), master_out_data); for_range(functor); } else { - DenseMomentumFunctor functor( - param->data(), grad->data(), velocity->data(), - learning_rate->data(), mu, param->numel(), regularization_flag, - regularization_coeff, param_out->mutable_data(ctx.GetPlace()), - velocity_out->mutable_data(ctx.GetPlace())); + DenseMomentumFunctor functor( + param->data(), grad->data(), velocity->data(), + learning_rate->data(), master_in_data, mu, rescale_grad, + param->numel(), regularization_flag, regularization_coeff, + param_out->mutable_data(ctx.GetPlace()), + velocity_out->mutable_data(ctx.GetPlace()), master_out_data); for_range(functor); } } @@ -489,23 +572,25 @@ class MomentumOpKernel : public framework::OpKernel { static_cast(ctx.device_context()), param->numel()); if (use_nesterov) { - SparseMomentumFunctor functor( + SparseMomentumFunctor functor( param->data(), merged_grad->value().data(), - velocity->data(), learning_rate->data(), mu, rows, row_numel, + velocity->data(), learning_rate->data(), + master_in_data, mu, rescale_grad, rows, row_numel, static_cast(merged_grad->rows().size()), regularization_flag, regularization_coeff, param_out->mutable_data(ctx.GetPlace()), - velocity_out->mutable_data(ctx.GetPlace())); + velocity_out->mutable_data(ctx.GetPlace()), master_out_data); for_range(functor); } else { - SparseMomentumFunctor functor( + SparseMomentumFunctor functor( param->data(), merged_grad->value().data(), - velocity->data(), learning_rate->data(), mu, rows, row_numel, + velocity->data(), learning_rate->data(), + master_in_data, mu, rescale_grad, rows, row_numel, static_cast(merged_grad->rows().size()), regularization_flag, regularization_coeff, param_out->mutable_data(ctx.GetPlace()), - velocity_out->mutable_data(ctx.GetPlace())); + velocity_out->mutable_data(ctx.GetPlace()), master_out_data); for_range(functor); } } else { diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 0f5ce84155..07218b8f3e 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -54,6 +54,7 @@ std::map> op_ins_map = { {"moving_average_abs_max_scale", {"X", "InAccum", "InState"}}, {"multiclass_nms3", {"BBoxes", "Scores", "RoisNum"}}, {"box_coder", {"PriorBox", "PriorBoxVar", "TargetBox"}}, + {"momentum", {"Param", "Grad", "Velocity", "LearningRate"}}, }; // NOTE(zhiqiu): Like op_ins_map. @@ -82,6 +83,7 @@ std::map> op_outs_map = { {"moving_average_abs_max_scale", {"OutScale", "OutAccum", "OutState"}}, {"multiclass_nms3", {"Out", "NmsRoisNum"}}, {"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}}, + {"momentum", {"ParamOut", "VelocityOut"}}, }; // NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index 1d9f8af102..6987b92a89 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -16,6 +16,12 @@ from __future__ import print_function from ... import core from ... import layers +from ... import global_scope +from ...log_helper import get_logger +import logging +import numpy as np +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') def _rename_arg(op, old_name, new_name): @@ -190,6 +196,127 @@ def _is_in_black_varnames(op, amp_lists): return False +def cast_model_to_fp16(main_program): + """ + Traverse all ops in the whole model and set their inputs and outputs + to the fp16 data type. This function will do some special process for + the batch normalization, which keeps the computational process of + batchnorms in FP32. + Args: + main_program (Program): The main program for training. + """ + valid_types = [ + core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS, + core.VarDesc.VarType.LOD_TENSOR_ARRAY + ] + global_block = main_program.global_block() + + for block in main_program.blocks: + ops = block.ops + for op in ops: + if op.type == 'create_py_reader' or op.type == 'read': + continue + for in_name in op.input_names: + if op.type in { + 'batch_norm', 'fused_bn_add_activation', 'layer_norm' + } and in_name not in {'X', 'Z'}: + continue + for in_var_name in op.input(in_name): + in_var = None + try: + in_var = block.var(in_var_name) + except ValueError as e: + _logger.debug( + "-- {}, try to get it in the global block. --". + format(e)) + in_var = global_block.var(in_var_name) + if in_var is not None: + _logger.debug( + "-- var {} is got in the global block. --". + format(in_var_name)) + + if in_var is None or in_var.type not in valid_types: + continue + + if in_var.dtype == core.VarDesc.VarType.FP32: + in_var.desc.set_dtype(core.VarDesc.VarType.FP16) + + _logger.debug( + "-- op type: {}, in var name: {}, in var dtype: {} --". + format(op.type, in_var_name, in_var.dtype)) + + for out_name in op.output_names: + if op.type in { + 'batch_norm', 'fused_bn_add_activation', 'layer_norm' + } and out_name != 'Y': + continue + for out_var_name in op.output(out_name): + out_var = None + try: + out_var = block.var(out_var_name) + except ValueError as e: + _logger.debug( + "-- {}, try to get it in the global block. --". + format(e)) + out_var = global_block.var(out_var_name) + if out_var is not None: + _logger.debug( + "-- var {} is got in the global block. --". + format(out_var_name)) + + if out_var is None or out_var.type not in valid_types: + continue + + if out_var.dtype == core.VarDesc.VarType.FP32: + out_var.desc.set_dtype(core.VarDesc.VarType.FP16) + + _logger.debug( + "-- op type: {}, out var name: {}, out var dtype: {} --". + format(op.type, out_var_name, out_var.dtype)) + if op.has_attr('in_dtype') and op.attr( + 'in_dtype') == core.VarDesc.VarType.FP32: + op._set_attr('in_dtype', core.VarDesc.VarType.FP16) + if op.has_attr('out_dtype') and op.attr( + 'out_dtype') == core.VarDesc.VarType.FP32: + op._set_attr('out_dtype', core.VarDesc.VarType.FP16) + if op.has_attr('dtype') and op.attr( + 'dtype') == core.VarDesc.VarType.FP32: + op._set_attr('dtype', core.VarDesc.VarType.FP16) + + +def cast_parameters_to_fp16(place, main_program, scope=None): + """ + Traverse all parameters in the whole model and set them to the fp16 data type. + Whereas, this function will keep parameters of batchnorms in FP32. + Args: + place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors. + main_program (Program): The main program for training. + scope(fluid.Scope, optional): scope is used to get the weight tensor values. + Default is None. + """ + all_ops = [] + for block in main_program.blocks: + all_ops.extend(block.ops) + bn_params = set() + for op in all_ops: + if op.type not in { + 'batch_norm', 'fused_bn_add_activation', 'layer_norm' + }: + continue + for in_name in op.input_names: + if in_name not in {'X', 'Z'}: + for in_var_name in op.input(in_name): + bn_params.add(in_var_name) + global_block = main_program.global_block() + all_parameters = global_block.all_parameters() + var_scope = scope if scope is not None else global_scope() + for param in all_parameters: + if param.name not in bn_params: + param_t = var_scope.find_var(param.name).get_tensor() + data = np.array(param_t) + param_t.set(np.float16(data), place) + + def rewrite_program(main_prog, amp_lists): """ Traverse all ops in current block and insert cast op according to diff --git a/python/paddle/fluid/contrib/optimizer.py b/python/paddle/fluid/contrib/optimizer.py index 968bfa92b5..2a22969d52 100644 --- a/python/paddle/fluid/contrib/optimizer.py +++ b/python/paddle/fluid/contrib/optimizer.py @@ -14,11 +14,13 @@ from paddle.fluid.optimizer import Optimizer from paddle.fluid.regularizer import L1DecayRegularizer from paddle.fluid.regularizer import L2DecayRegularizer -from paddle.fluid.regularizer import append_regularization_ops -from paddle.fluid import framework from paddle.fluid import core +from paddle.fluid import framework from paddle.fluid.framework import program_guard -from paddle.fluid.clip import append_gradient_clip_ops +from paddle.fluid import unique_name +from paddle.fluid import layers +from paddle.fluid.layer_helper import LayerHelper +import warnings __all__ = ['Momentum'] @@ -61,6 +63,9 @@ class Momentum(Optimizer): some derived class of ``GradientClipBase`` . There are three cliping strategies ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping. + multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false. + rescale_grad (float, optional): Multiply the gradient with `rescale_grad` before updating. \ + Often choose to be ``1.0/batch_size``. name (str, optional): This parameter is used by developers to print debugging information. \ For details, please refer to :ref:`api_guide_Name`. Default is None. @@ -105,6 +110,8 @@ class Momentum(Optimizer): use_nesterov=False, regularization=None, grad_clip=None, + multi_precision=False, + rescale_grad=1.0, name=None): assert learning_rate is not None assert momentum is not None @@ -124,11 +131,68 @@ class Momentum(Optimizer): if (isinstance(regularization, L2DecayRegularizer)): self._regularization_method = "l2_decay" self._regularization_coeff = regularization._regularization_coeff + self._multi_precision = multi_precision + self._rescale_grad = rescale_grad + self._master_weights = {} + + def _create_master_weight(self, param): + assert isinstance(self.helper, LayerHelper) + + var_name = param.name + "_fp32_master" + var_name = unique_name.generate(var_name) + var = layers.create_global_var( + name=var_name, + shape=param.shape, + value=0, + dtype='float32', + persistable=True) + block = self.helper.startup_program.global_block() + block.append_op( + type="cast", + inputs={"X": [param]}, + outputs={"Out": [var]}, + attrs={ + "in_dtype": param.dtype, + "out_dtype": core.VarDesc.VarType.FP32 + }) + self._master_weights[param.name] = var + return var + + def _get_accumulator(self, name, param): + """Utility function to fetch an accumulator for a parameter + + Args: + name: name of the accumulator + param: parameter variable for which accumulator is to be fetched + + Returns: + accumulator variable for the parameter + """ + if self._name is not None: + name = self._name + "_" + name + find_master = self._multi_precision and param.dtype == core.VarDesc.VarType.FP16 + target_param = self._master_weights[ + param.name] if find_master else param + target_name = target_param.name + if (name not in self._accumulators or + target_name not in self._accumulators[name]): + raise Exception("Accumulator {} does not exist for parameter {}". + format(name, target_name)) + return self._accumulators[name][target_name] def _create_accumulators(self, block, parameters): assert isinstance(block, framework.Block) for p in parameters: + if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: + master_p = self._create_master_weight(p) + self._add_accumulator(self._velocity_acc_str, master_p) + continue + if p.dtype == core.VarDesc.VarType.FP16 and not self._multi_precision: + warnings.warn( + "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." + "Consider using multi_precision=True option of the Momentum optimizer." + ) self._add_accumulator(self._velocity_acc_str, p) def _append_optimize_op(self, block, param_and_grad): @@ -136,6 +200,10 @@ class Momentum(Optimizer): velocity_acc = self._get_accumulator(self._velocity_acc_str, param_and_grad[0]) + find_master = self._multi_precision and param_and_grad[ + 0].dtype == core.VarDesc.VarType.FP16 + master_weight = (self._master_weights[param_and_grad[0].name] + if find_master else None) lr = self._create_param_lr(param_and_grad) if framework.in_dygraph_mode(): @@ -151,7 +219,9 @@ class Momentum(Optimizer): "mu": self._momentum, "use_nesterov": self._use_nesterov, "regularization_method": self._regularization_method, - "regularization_coeff": self._regularization_coeff + "regularization_coeff": self._regularization_coeff, + "multi_precision": find_master, + "rescale_grad": self._rescale_grad } inputs = { "Param": [param_and_grad[0]], @@ -159,11 +229,15 @@ class Momentum(Optimizer): "Velocity": [velocity_acc], "LearningRate": [lr] } - outputs = { "ParamOut": [param_and_grad[0]], "VelocityOut": [velocity_acc] } + + if find_master: + inputs["MasterParam"] = master_weight + outputs["MasterParamOut"] = master_weight + # create the momentum optimize op momentum_op = block.append_op( type=self.type, diff --git a/python/paddle/fluid/contrib/tests/CMakeLists.txt b/python/paddle/fluid/contrib/tests/CMakeLists.txt index ab84257205..a28588bfa5 100644 --- a/python/paddle/fluid/contrib/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/tests/CMakeLists.txt @@ -1,8 +1,14 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") +list(REMOVE_ITEM TEST_OPS test_multi_precision_fp16_train) + foreach(src ${TEST_OPS}) py_test(${src} SRCS ${src}.py) endforeach() + +py_test_modules(test_multi_precision_fp16_train MODULES test_multi_precision_fp16_train ENVS FLAGS_cudnn_deterministic=true FLAGS_cudnn_batchnorm_spatial_persistent=true FLAGS_conv_workspace_size_limit=1000) + set_tests_properties(test_image_classification_fp16 PROPERTIES TIMEOUT 120) set_tests_properties(test_weight_decay_extend PROPERTIES TIMEOUT 120) +set_tests_properties(test_multi_precision_fp16_train PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py b/python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py new file mode 100644 index 0000000000..64ef2e26bb --- /dev/null +++ b/python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py @@ -0,0 +1,269 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +import contextlib +import unittest +import numpy as np +from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_model_to_fp16 +from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_parameters_to_fp16 + +paddle.enable_static() + + +def resnet_cifar10(input, depth=32): + def conv_bn_layer(input, + ch_out, + filter_size, + stride, + padding, + act='relu', + bias_attr=False): + tmp = fluid.layers.conv2d( + input=input, + filter_size=filter_size, + num_filters=ch_out, + stride=stride, + padding=padding, + act=None, + bias_attr=bias_attr) + return fluid.layers.batch_norm(input=tmp, act=act) + + def shortcut(input, ch_in, ch_out, stride): + if ch_in != ch_out: + return conv_bn_layer(input, ch_out, 1, stride, 0, None) + else: + return input + + def basicblock(input, ch_in, ch_out, stride): + tmp = conv_bn_layer(input, ch_out, 3, stride, 1) + tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None, bias_attr=True) + short = shortcut(input, ch_in, ch_out, stride) + return fluid.layers.elementwise_add(x=tmp, y=short, act='relu') + + def layer_warp(block_func, input, ch_in, ch_out, count, stride): + tmp = block_func(input, ch_in, ch_out, stride) + for i in range(1, count): + tmp = block_func(tmp, ch_out, ch_out, 1) + return tmp + + assert (depth - 2) % 6 == 0 + n = (depth - 2) // 6 + conv1 = conv_bn_layer( + input=input, ch_out=16, filter_size=3, stride=1, padding=1) + res1 = layer_warp(basicblock, conv1, 16, 16, n, 1) + res2 = layer_warp(basicblock, res1, 16, 32, n, 2) + res3 = layer_warp(basicblock, res2, 32, 64, n, 2) + pool = fluid.layers.pool2d( + input=res3, pool_size=8, pool_type='avg', pool_stride=1) + return pool + + +def compile(program, loss_name=None): + build_strategy = paddle.static.BuildStrategy() + exec_strategy = paddle.static.ExecutionStrategy() + + exec_strategy.num_threads = 1 + exec_strategy.num_iteration_per_drop_scope = 10000 + + build_strategy.fuse_bn_act_ops = True + build_strategy.fuse_elewise_add_act_ops = True + build_strategy.fuse_bn_add_act_ops = True + + compiled_program = paddle.static.CompiledProgram( + program).with_data_parallel( + loss_name=loss_name, + build_strategy=build_strategy, + exec_strategy=exec_strategy) + + return compiled_program + + +def train(use_pure_fp16=True, use_nesterov=False): + classdim = 10 + data_shape = [3, 32, 32] + BATCH_SIZE = 128 + PASS_NUM = 1 + + train_program = fluid.Program() + startup_prog = fluid.Program() + train_program.random_seed = 123 + startup_prog.random_seed = 456 + with fluid.program_guard(train_program, startup_prog): + images = fluid.layers.data( + name='pixel', shape=data_shape, dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + net = resnet_cifar10(images, 32) + + logits = fluid.layers.fc(input=net, size=classdim, act="softmax") + if use_pure_fp16: + cast_model_to_fp16(fluid.default_main_program()) + logits_fp32 = fluid.layers.cast(x=logits, dtype="float32") + else: + logits_fp32 = logits + cost = fluid.layers.softmax_with_cross_entropy( + logits_fp32, label, return_softmax=False) + sum_cost = fluid.layers.reduce_sum(cost) + + # Test program + test_program = train_program.clone(for_test=True) + + optimizer = fluid.contrib.optimizer.Momentum( + learning_rate=0.001, + momentum=0.9, + use_nesterov=use_nesterov, + regularization=fluid.regularizer.L2Decay(1e-4), + multi_precision=use_pure_fp16, + rescale_grad=1.0 / BATCH_SIZE) + + optimizer.minimize(sum_cost) + + # no shuffle for unit test + train_reader = paddle.batch( + paddle.dataset.cifar.train10(), batch_size=BATCH_SIZE) + + test_reader = paddle.batch( + paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE) + + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + feeder = fluid.DataFeeder(place=place, feed_list=[images, label]) + + def train_loop(main_program): + exe.run(startup_prog) + if use_pure_fp16: + cast_parameters_to_fp16(place, train_program, fluid.global_scope()) + compiled_program = compile(train_program, sum_cost.name) + loss = 0.0 + for pass_id in range(PASS_NUM): + train_loss_list = [] + for batch_id, data in enumerate(train_reader()): + loss, = exe.run(compiled_program, + feed=feeder.feed(data), + fetch_list=[sum_cost]) + print('PassID {0:1}, Train Batch ID {1:04}, train loss {2:2.4}'. + format(pass_id, batch_id + 1, float(loss))) + train_loss_list.append(float(loss)) + + if batch_id >= 4: # For speeding up CI + test_loss_list = [] + for tid, test_data in enumerate(test_reader()): + loss_t, = exe.run(program=test_program, + feed=feeder.feed(test_data), + fetch_list=[sum_cost]) + test_loss_list.append(float(loss_t)) + print( + 'PassID {0:1}, Test Batch ID {1:04}, test loss {2:2.4}'. + format(pass_id, tid + 1, float(loss_t))) + if tid >= 4: + break # For speeding up CI + return train_loss_list, test_loss_list + + return train_loop(train_program) + + +class TestImageMultiPrecision(unittest.TestCase): + def test_resnet_pure_fp16(self): + if not fluid.core.is_compiled_with_cuda(): + return + + def do_test(use_nesterov=False): + suffix = "with Nesterov" if use_nesterov else "without Nesterov" + with self.scope_prog_guard(): + print("-----------------FP16 Train {}-----------------".format( + suffix)) + train_loss_fp16, test_loss_fp16 = train( + use_pure_fp16=True, use_nesterov=use_nesterov) + with self.scope_prog_guard(): + print("-----------------FP32 Train {}-----------------".format( + suffix)) + train_loss_fp32, test_loss_fp32 = train( + use_pure_fp16=False, use_nesterov=use_nesterov) + + self.assertTrue( + np.allclose( + np.array(train_loss_fp16), + np.array(train_loss_fp32), + rtol=1e-02, + atol=1e-05, + equal_nan=True), + msg='Failed to train in pure FP16.') + self.assertTrue( + np.allclose( + np.array(test_loss_fp16), + np.array(test_loss_fp32), + rtol=1e-02, + atol=1e-05, + equal_nan=True), + msg='Failed to test in pure FP16.') + + do_test(use_nesterov=False) + do_test(use_nesterov=True) + + @contextlib.contextmanager + def scope_prog_guard(self): + prog = fluid.Program() + startup_prog = fluid.Program() + scope = fluid.core.Scope() + with fluid.scope_guard(scope): + with fluid.program_guard(prog, startup_prog): + yield + + +class TestAmpWithNonIterableDataLoader(unittest.TestCase): + def decorate_with_data_loader(self): + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, start_prog): + with paddle.fluid.unique_name.guard(): + image = fluid.layers.data( + name='image', shape=[3, 224, 224], dtype='float32') + label = fluid.layers.data( + name='label', shape=[1], dtype='int64') + py_reader = fluid.io.DataLoader.from_generator( + feed_list=[image, label], + capacity=4, + iterable=False, + use_double_buffer=False) + zero_var = fluid.layers.fill_constant( + shape=[1], dtype='int64', value=0) + one_var = fluid.layers.fill_constant( + shape=[1], dtype='int64', value=1) + with fluid.layers.control_flow.Switch() as switch: + with switch.case(label != zero_var): + fluid.layers.assign(input=zero_var, output=label) + with switch.default(): + fluid.layers.assign(input=one_var, output=label) + + net = resnet_cifar10(image) + logits = fluid.layers.fc(input=net, size=10, act="softmax") + + block = main_prog.global_block() + for op in block.ops: + if op.type == "mul": + op._set_attr('in_dtype', fluid.core.VarDesc.VarType.FP32) + op._set_attr('out_dtype', fluid.core.VarDesc.VarType.FP32) + op._set_attr('dtype', fluid.core.VarDesc.VarType.FP32) + + cast_model_to_fp16(main_prog) + + def test_non_iterable_dataloader(self): + self.decorate_with_data_loader() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/pserver_pass.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/pserver_pass.py index a60c4e149f..295f02e73c 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/pserver_pass.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/pserver_pass.py @@ -246,7 +246,8 @@ def _append_pserver_ops(optimize_block, opt_op, endpoint, grad_to_block_id, for key in opt_op.input_names: new_shape = None if key in [ - "Param", "Grad", "LearningRate", "Beta1Tensor", "Beta2Tensor" + "Param", "Grad", "LearningRate", "MasterParam", "Beta1Tensor", + "Beta2Tensor" ]: continue var = origin_program.global_block().vars[opt_op.input(key)[0]] diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py index 1bb57409b7..8f629b1522 100644 --- a/python/paddle/fluid/tests/unittests/test_momentum_op.py +++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py @@ -59,7 +59,7 @@ class TestMomentumOp1(OpTest): param = np.random.random((123, 321)).astype(self.dtype) grad = np.random.random((123, 321)).astype(self.dtype) velocity = np.zeros((123, 321)).astype(self.dtype) - learning_rate = np.array([0.001]).astype(self.dtype) + learning_rate = np.array([0.001]).astype(np.float32) mu = 0.0001 use_nesterov = False @@ -217,7 +217,7 @@ class TestSparseMomentumOp(unittest.TestCase): 0.0).astype("float32") velocity_out.set(velocity_out_np_array, place) - # create and initialize LeraningRate Variable + # create and initialize LearningRate Variable lr = scope.var('LearningRate').get_tensor() lr_array = np.full((1), 2.0).astype("float32") lr.set(lr_array, place) @@ -278,6 +278,115 @@ class TestSparseMomentumOp2(TestSparseMomentumOp): self.use_nesterov = True +class TestSparseMomentumOpWithMultiPrecision(unittest.TestCase): + def setUp(self): + self.init_args() + self.regularization_method = "" + self.regularization_coeff = 1.0 + + def check_with_place(self, place): + scope = core.Scope() + # create and initialize Grad Variable + height = 10 + rows = [0, 4, 7] + row_numel = 12 + mu = 1.0 + use_nesterov = self.use_nesterov + regularization_method = self.regularization_method + regularization_coeff = self.regularization_coeff + + # create and initialize Param Variable + param_array = np.full((height, row_numel), 5.0).astype("float32") + param_out_array = np.full((height, row_numel), 0.0).astype("float32") + + param = scope.var('Param').get_tensor() + param.set(param_array.astype("float16"), place) + param_out = scope.var("ParamOut").get_tensor() + param_out.set(param_out_array.astype("float16"), place) + + master_param = scope.var('MasterParam').get_tensor() + master_param.set(param_array, place) + master_param_out = scope.var("MasterParamOut").get_tensor() + master_param_out.set(param_out_array, place) + + grad_selected_rows = scope.var('Grad').get_selected_rows() + grad_selected_rows.set_height(height) + grad_selected_rows.set_rows(rows) + grad_np_array = np.ones((len(rows), row_numel)).astype("float32") + grad_np_array[0, 0] = 2.0 + grad_np_array[2, 8] = 4.0 + grad_tensor = grad_selected_rows.get_tensor() + grad_tensor.set(grad_np_array.astype("float16"), place) + + velocity = scope.var('Velocity').get_tensor() + velocity_np_array = np.ones((height, row_numel)).astype("float32") + velocity.set(velocity_np_array, place) + velocity_out = scope.var('VelocityOut').get_tensor() + velocity_out_np_array = np.full((height, row_numel), + 0.0).astype("float32") + velocity_out.set(velocity_out_np_array, place) + + # create and initialize LearningRate Variable + lr = scope.var('LearningRate').get_tensor() + lr_array = np.full((1), 2.0).astype("float32") + lr.set(lr_array, place) + + # create and run operator + op = Operator( + "momentum", + Param='Param', + Grad='Grad', + Velocity='Velocity', + MasterParam='MasterParam', + ParamOut='ParamOut', + VelocityOut='VelocityOut', + MasterParamOut='MasterParamOut', + LearningRate='LearningRate', + mu=mu, + use_nesterov=use_nesterov, + regularization_method=regularization_method, + regularization_coeff=regularization_coeff, + multi_precision=True, + rescale_grad=1.0) + op.run(scope, place) + + # get and compare result + param_out_np_array = np.array(param_out) + velocity_out_np_array = np.array(velocity_out) + + _grad_np_array = np.full((height, row_numel), 0.0).astype("float32") + for i in range(len(rows)): + _grad_np_array[rows[i]] = grad_np_array[i] + + _param = param_array + + _param_out, _velocity_out = calculate_momentum_by_numpy( + param=_param, + grad=_grad_np_array, + mu=mu, + velocity=velocity_np_array, + use_nesterov=use_nesterov, + learning_rate=lr_array, + regularization_method=regularization_method, + regularization_coeff=regularization_coeff) + + self.assertTrue((_velocity_out == velocity_out_np_array).all()) + self.assertTrue((_param_out == param_out_np_array).all()) + + def init_args(self): + self.use_nesterov = False + + def test_sparse_momentum(self): + if core.is_compiled_with_cuda(): + self.check_with_place(fluid.CUDAPlace(0)) + + +class TestSparseMomentumOpWithMultiPrecision2( + TestSparseMomentumOpWithMultiPrecision): + def init_args(self): + self.use_nesterov = True + + class TestMomentumV2(unittest.TestCase): def test_momentum_dygraph(self): paddle.disable_static() @@ -334,7 +443,7 @@ class TestMomentumOpWithDecay(OpTest): param = np.random.random((123, 321)).astype(self.dtype) grad = np.random.random((123, 321)).astype(self.dtype) velocity = np.zeros((123, 321)).astype(self.dtype) - learning_rate = np.array([0.001]).astype(self.dtype) + learning_rate = np.array([0.001]).astype(np.float32) mu = 0.0001 use_nesterov = self.use_nesterov regularization_method = self.regularization_method -- GitLab