未验证 提交 be3777a5 编写于 作者: Z Zhen Wang 提交者: GitHub

Add pure fp16 training with master weights. (#27712)

* add the weight decay func for the momentum op

* Add the multi_precision function in Momentum Optimizer.

* Make sure that the initial value of master weights are same with the fp16 weights.

* add static loss scaling.

* add the rescale_grad function in the pure fp16 training.

* use the original momentum updating method.

* Polish some codes, such as variable names.

* add docstring for apis.

* update the var creation details of _create_master_weight.

* not modify codes about imperative momentum updating.

* Fix the error of test_dist_sparse_tensor_load_momentum UT.

* add unit test for multi precision fp16 training.

* add more unit tests for CI.

* Use lower threshold values for allclose comparing in test_multi_precision_fp16_train UT.

* For CI Coverage Checking.
上级 976961de
......@@ -49,13 +49,17 @@ void MomentumOpMaker::Make() {
AddInput("LearningRate",
"(Tensor, default Tensor<float>) "
"Input learning rate");
AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
AddOutput("ParamOut",
"(Tensor) This output is updated parameter. "
"It shared memory with Input(Param).");
AddOutput("VelocityOut",
"(Tensor) This output is updated velocity. "
"It shared memory with Input(Velocity).");
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDispensable();
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<bool>("use_nesterov",
......@@ -67,7 +71,17 @@ void MomentumOpMaker::Make() {
"(string) regularization_method, right now only support l2decay or none")
.SetDefault("");
AddAttr<float>("regularization_coeff", "(float) regularization_coeff")
.SetDefault(0);
.SetDefault(0.0f);
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
.SetDefault(false);
AddAttr<float>(
"rescale_grad",
"(float, default 1.0) Multiply the gradient with `rescale_grad`"
"before updating. Often choose to be `1.0/batch_size`.")
.SetDefault(1.0f);
AddComment(R"DOC(
Momentum Optimizer.
......@@ -109,4 +123,12 @@ REGISTER_OP_VERSION(momentum)
"l2decay or none",
std::string(""))
.NewAttr("regularization_coeff", "(float) regularization_coeff",
0.0f));
0.0f)
.NewAttr(
"multi_precision",
"(bool) Whether to use multi-precision during weight updating.",
false)
.NewAttr("rescale_grad",
"(float) Multiply the gradient with `rescale_grad`"
"before updating. Often choose to be `1.0/batch_size`.",
1.0f));
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
......@@ -29,6 +30,44 @@ using framework::SelectedRows;
struct NoNesterov;
struct UseNesterov;
namespace details {
template <typename T>
class MPTypeTrait {
public:
using Type = T;
};
template <>
class MPTypeTrait<platform::float16> {
public:
using Type = float;
};
template <typename T>
struct CPUDenseUpdater {
template <typename G>
void operator()(const Tensor& param, const Tensor& velocity, const T& mu,
const T& lr, const bool use_nesterov, G&& grad,
Tensor* param_out, Tensor* velocity_out) const {
auto param_out_vec = framework::EigenVector<T>::Flatten(*param_out);
auto velocity_out_vec = framework::EigenVector<T>::Flatten(*velocity_out);
auto param_vec = framework::EigenVector<T>::Flatten(param);
auto velocity_vec = framework::EigenVector<T>::Flatten(velocity);
velocity_out_vec = velocity_vec * mu + grad;
if (use_nesterov) {
param_out_vec = param_vec - (grad + velocity_out_vec * mu) * lr;
} else {
param_out_vec = param_vec - lr * velocity_out_vec;
}
}
};
} // namespace details
template <typename T>
using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;
enum class RegularizationType {
kNONE = 0,
kL1DECAY = 1, // do not need support right now
......@@ -118,350 +157,427 @@ class MomentumOp : public framework::OperatorWithKernel {
template <typename T>
class CPUDenseMomentumFunctor {
private:
const Tensor* param_;
const Tensor* grad_;
const Tensor* velocity_;
const Tensor* learning_rate_;
const T mu_;
const T use_nesterov_;
RegularizationType regularization_flag_;
const T regularization_coeff_;
Tensor* param_out_;
Tensor* velocity_out_;
public:
CPUDenseMomentumFunctor(const Tensor* param, const Tensor* grad,
const Tensor* velocity, const Tensor* learning_rate,
const T mu, const bool use_nesterov,
RegularizationType regularization_flag,
const T regularization_coeff, Tensor* param_out,
Tensor* velocity_out)
: param_(param),
grad_(grad),
velocity_(velocity),
learning_rate_(learning_rate),
mu_(mu),
use_nesterov_(use_nesterov),
regularization_flag_(regularization_flag),
regularization_coeff_(regularization_coeff),
param_out_(param_out),
velocity_out_(velocity_out) {}
inline void operator()() {
auto param_out = framework::EigenVector<T>::Flatten(*param_out_);
auto velocity_out = framework::EigenVector<T>::Flatten(*velocity_out_);
auto param = framework::EigenVector<T>::Flatten(*param_);
auto velocity = framework::EigenVector<T>::Flatten(*velocity_);
auto grad = framework::EigenVector<T>::Flatten(*grad_);
auto* lr = learning_rate_->data<T>();
if (regularization_flag_ == RegularizationType::kL2DECAY) {
velocity_out = velocity * mu_ + param * regularization_coeff_ + grad;
if (use_nesterov_) {
param_out =
param -
(param * regularization_coeff_ + grad + velocity_out * mu_) * lr[0];
} else {
param_out = param - lr[0] * velocity_out;
}
void operator()(const Tensor* param, const Tensor* grad,
const Tensor* velocity, const Tensor* learning_rate,
const T mu, const bool use_nesterov,
const RegularizationType regularization_flag,
const T regularization_coeff, Tensor* param_out,
Tensor* velocity_out) {
auto grad_vec = framework::EigenVector<T>::Flatten(*grad);
auto* lr = learning_rate->data<MultiPrecisionType<T>>();
details::CPUDenseUpdater<T> updater;
if (regularization_flag == RegularizationType::kL2DECAY) {
auto param_vec = framework::EigenVector<T>::Flatten(*param);
updater(*param, *velocity, mu, static_cast<T>(lr[0]), use_nesterov,
param_vec * regularization_coeff + grad_vec, param_out,
velocity_out);
} else {
velocity_out = velocity * mu_ + grad;
if (use_nesterov_) {
param_out = param - (grad + velocity_out * mu_) * lr[0];
} else {
param_out = param - lr[0] * velocity_out;
}
updater(*param, *velocity, mu, static_cast<T>(lr[0]), use_nesterov,
grad_vec, param_out, velocity_out);
}
}
};
template <typename T, typename UpdateMethod>
template <typename T, typename MT, typename UpdateMethod>
class DenseMomentumFunctor;
// NOTE(dzh) for performance.
// avoid if/else in inside kernel, implement GPU UseNesterov/NoNesterov as two
// functor.
template <typename T>
class DenseMomentumFunctor<T, UseNesterov> {
template <typename T, typename MT>
class DenseMomentumFunctor<T, MT, UseNesterov> {
private:
const T* param_;
const T* grad_;
const T* velocity_;
const T* lr_;
const T mu_;
const MT* velocity_;
const MultiPrecisionType<MT>* lr_;
const MT* master_param_;
const MT mu_;
const MT rescale_grad_;
const int64_t num_;
T* param_out_;
T* velocity_out_;
RegularizationType regularization_flag_;
const T regularization_coeff_;
MT* velocity_out_;
MT* master_param_out_;
const RegularizationType regularization_flag_;
const MT regularization_coeff_;
public:
DenseMomentumFunctor(const T* param, const T* grad, const T* velocity,
const T* learning_rate, const T mu, const int64_t num,
RegularizationType regularization_flag,
const T regularization_coeff, T* param_out,
T* velocity_out)
DenseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
const MultiPrecisionType<MT>* learning_rate,
const MT* master_param, const MT mu,
const MT rescale_grad, const int64_t num,
const RegularizationType regularization_flag,
const MT regularization_coeff, T* param_out,
MT* velocity_out, MT* master_param_out)
: param_(param),
grad_(grad),
velocity_(velocity),
lr_(learning_rate),
master_param_(master_param),
mu_(mu),
rescale_grad_(rescale_grad),
num_(num),
param_out_(param_out),
velocity_out_(velocity_out),
master_param_out_(master_param_out),
regularization_flag_(regularization_flag),
regularization_coeff_(regularization_coeff) {}
inline HOSTDEVICE void operator()(size_t i) const {
// put memory access in register
const T param = param_[i];
T grad = grad_[i];
const T lr = lr_[0];
const T velocity = velocity_[i];
const MT param =
master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
MT grad = static_cast<MT>(grad_[i]) * rescale_grad_;
const MT lr = static_cast<MT>(lr_[0]);
const MT velocity = velocity_[i];
grad = regularization_flag_ == RegularizationType::kL2DECAY
? grad + regularization_coeff_ * param
: grad;
T velocity_out = velocity * mu_ + grad;
T param_out = param - (grad + velocity_out * mu_) * lr;
MT velocity_out = velocity * mu_ + grad;
MT param_out = param - (grad + velocity_out * mu_) * lr;
// write reigster to memory
velocity_out_[i] = velocity_out;
param_out_[i] = param_out;
param_out_[i] = static_cast<T>(param_out);
if (master_param_out_) {
master_param_out_[i] = param_out;
}
}
};
template <typename T>
class DenseMomentumFunctor<T, NoNesterov> {
template <typename T, typename MT>
class DenseMomentumFunctor<T, MT, NoNesterov> {
private:
const T* param_;
const T* grad_;
const T* velocity_;
const T* lr_;
const T mu_;
const MT* velocity_;
const MultiPrecisionType<MT>* lr_;
const MT* master_param_;
const MT mu_;
const MT rescale_grad_;
const int64_t num_;
T* param_out_;
T* velocity_out_;
RegularizationType regularization_flag_;
const T regularization_coeff_;
MT* velocity_out_;
MT* master_param_out_;
const RegularizationType regularization_flag_;
const MT regularization_coeff_;
public:
DenseMomentumFunctor(const T* param, const T* grad, const T* velocity,
const T* learning_rate, const T mu, const int64_t num,
RegularizationType regularization_flag,
const T regularization_coeff, T* param_out,
T* velocity_out)
DenseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
const MultiPrecisionType<MT>* learning_rate,
const MT* master_param, const MT mu,
const MT rescale_grad, const int64_t num,
const RegularizationType regularization_flag,
const MT regularization_coeff, T* param_out,
MT* velocity_out, MT* master_param_out)
: param_(param),
grad_(grad),
velocity_(velocity),
lr_(learning_rate),
master_param_(master_param),
mu_(mu),
rescale_grad_(rescale_grad),
num_(num),
param_out_(param_out),
velocity_out_(velocity_out),
master_param_out_(master_param_out),
regularization_flag_(regularization_flag),
regularization_coeff_(regularization_coeff) {}
inline HOSTDEVICE void operator()(size_t i) const {
// put memory access in register
const T param = param_[i];
T grad = grad_[i];
const T lr = lr_[0];
const T velocity = velocity_[i];
const MT param =
master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
MT grad = static_cast<MT>(grad_[i]) * rescale_grad_;
const MT lr = static_cast<MT>(lr_[0]);
const MT velocity = velocity_[i];
grad = regularization_flag_ == RegularizationType::kL2DECAY
? grad + regularization_coeff_ * param
: grad;
T velocity_out = velocity * mu_ + grad;
T param_out = param - lr * velocity_out;
MT velocity_out = velocity * mu_ + grad;
MT param_out = param - lr * velocity_out;
// write reigster to memory
velocity_out_[i] = velocity_out;
param_out_[i] = param_out;
param_out_[i] = static_cast<T>(param_out);
if (master_param_out_) {
master_param_out_[i] = param_out;
}
}
};
template <typename T, typename UpdateMethod>
template <typename T, typename MT, typename UpdateMethod>
class SparseMomentumFunctor;
template <typename T>
class SparseMomentumFunctor<T, UseNesterov> {
template <typename T, typename MT>
class SparseMomentumFunctor<T, MT, UseNesterov> {
private:
const T* param_;
const T* grad_;
const T* velocity_;
const T* lr_;
const T mu_;
const MT* velocity_;
const MultiPrecisionType<MT>* lr_;
const MT* master_param_;
const MT mu_;
const MT rescale_grad_;
const int64_t* rows_;
const int64_t row_numel_;
const int64_t row_height_;
T* param_out_;
T* velocity_out_;
RegularizationType regularization_flag_;
const T regularization_coeff_;
MT* velocity_out_;
MT* master_param_out_;
const RegularizationType regularization_flag_;
const MT regularization_coeff_;
public:
SparseMomentumFunctor(const T* param, const T* grad, const T* velocity,
const T* lr, const T mu, const int64_t* rows,
SparseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
const MultiPrecisionType<MT>* lr,
const MT* master_param, const MT mu,
const MT rescale_grad, const int64_t* rows,
int64_t row_numel, int64_t row_height,
RegularizationType regularization_flag,
const T regularization_coeff, T* param_out,
T* velocity_out)
const RegularizationType regularization_flag,
const MT regularization_coeff, T* param_out,
MT* velocity_out, MT* master_param_out)
: param_(param),
grad_(grad),
velocity_(velocity),
lr_(lr),
master_param_(master_param),
mu_(mu),
rescale_grad_(rescale_grad),
rows_(rows),
row_numel_(row_numel),
row_height_(row_height),
param_out_(param_out),
velocity_out_(velocity_out),
master_param_out_(master_param_out),
regularization_flag_(regularization_flag),
regularization_coeff_(regularization_coeff) {}
inline HOSTDEVICE void operator()(size_t i) {
auto row_idx =
math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
T grad = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_]
: static_cast<T>(0);
MT grad =
row_idx >= 0
? static_cast<MT>(grad_[row_idx * row_numel_ + i % row_numel_]) *
rescale_grad_
: static_cast<MT>(0);
// put memory access in register
const T param = param_[i];
const T lr = lr_[0];
const T velocity = velocity_[i];
const MT param =
master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
const MT lr = static_cast<MT>(lr_[0]);
const MT velocity = velocity_[i];
grad = regularization_flag_ == RegularizationType::kL2DECAY
? grad + regularization_coeff_ * param
: grad;
T velocity_out = velocity * mu_ + grad;
T param_out = param - (grad + velocity_out * mu_) * lr;
MT velocity_out = velocity * mu_ + grad;
MT param_out = param - (grad + velocity_out * mu_) * lr;
// write reigster to memory
velocity_out_[i] = velocity_out;
param_out_[i] = param_out;
param_out_[i] = static_cast<T>(param_out);
if (master_param_out_) {
master_param_out_[i] = param_out;
}
}
};
template <typename T>
class SparseMomentumFunctor<T, NoNesterov> {
template <typename T, typename MT>
class SparseMomentumFunctor<T, MT, NoNesterov> {
private:
const T* param_;
const T* grad_;
const T* velocity_;
const T* lr_;
const T mu_;
const MT* velocity_;
const MultiPrecisionType<MT>* lr_;
const MT* master_param_;
const MT mu_;
const MT rescale_grad_;
const int64_t* rows_;
const int64_t row_numel_;
const int64_t row_height_;
T* param_out_;
T* velocity_out_;
RegularizationType regularization_flag_;
const T regularization_coeff_;
MT* velocity_out_;
MT* master_param_out_;
const RegularizationType regularization_flag_;
const MT regularization_coeff_;
public:
SparseMomentumFunctor(const T* param, const T* grad, const T* velocity,
const T* lr, const T mu, const int64_t* rows,
SparseMomentumFunctor(const T* param, const T* grad, const MT* velocity,
const MultiPrecisionType<MT>* lr,
const MT* master_param, const MT mu,
const MT rescale_grad, const int64_t* rows,
int64_t row_numel, int64_t row_height,
RegularizationType regularization_flag,
const T regularization_coeff, T* param_out,
T* velocity_out)
const RegularizationType regularization_flag,
const MT regularization_coeff, T* param_out,
MT* velocity_out, MT* master_param_out)
: param_(param),
grad_(grad),
velocity_(velocity),
lr_(lr),
master_param_(master_param),
mu_(mu),
rescale_grad_(rescale_grad),
rows_(rows),
row_numel_(row_numel),
row_height_(row_height),
param_out_(param_out),
velocity_out_(velocity_out),
master_param_out_(master_param_out),
regularization_flag_(regularization_flag),
regularization_coeff_(regularization_coeff) {}
inline HOSTDEVICE void operator()(size_t i) {
auto row_idx =
math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
T grad = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_]
: static_cast<T>(0);
MT grad =
row_idx >= 0
? static_cast<MT>(grad_[row_idx * row_numel_ + i % row_numel_]) *
rescale_grad_
: static_cast<MT>(0);
// put memory access in register
const T param = param_[i];
const T lr = lr_[0];
const T velocity = velocity_[i];
const MT param =
master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
const MT lr = static_cast<MT>(lr_[0]);
const MT velocity = velocity_[i];
grad = regularization_flag_ == RegularizationType::kL2DECAY
? grad + regularization_coeff_ * param
: grad;
T velocity_out = velocity * mu_ + grad;
T param_out = param - velocity_out * lr;
MT velocity_out = velocity * mu_ + grad;
MT param_out = param - velocity_out * lr;
// write reigster to memory
velocity_out_[i] = velocity_out;
param_out_[i] = param_out;
param_out_[i] = static_cast<T>(param_out);
if (master_param_out_) {
master_param_out_[i] = param_out;
}
}
};
template <typename DeviceContext, typename T>
class MomentumOpKernel : public framework::OpKernel<T> {
using MPDType = MultiPrecisionType<T>;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
std::string regularization_method =
ctx.Attr<std::string>("regularization_method");
if (regularization_method != "" || !regularization_method.empty()) {
PADDLE_ENFORCE_EQ("l2_decay", regularization_method,
platform::errors::InvalidArgument(
"if regularization_method is not null, "
"it should be l2_decay, but received %s",
regularization_method));
const bool multi_precision = ctx.Attr<bool>("multi_precision");
if (multi_precision) {
LOG_FIRST_N(INFO, 1) << R"CODE(
InnerCompute<MPDType>(ctx, multi_precision);
)CODE";
InnerCompute<MPDType>(ctx, multi_precision);
} else {
LOG_FIRST_N(INFO, 1) << R"CODE(
InnerCompute<T>(ctx, multi_precision);
)CODE";
InnerCompute<T>(ctx, multi_precision);
}
}
T regularization_coeff =
static_cast<T>(ctx.Attr<float>("regularization_coeff"));
private:
template <typename MT>
void InnerCompute(const framework::ExecutionContext& ctx,
const bool multi_precision) const {
std::string regularization_method =
ctx.Attr<std::string>("regularization_method");
MT regularization_coeff =
static_cast<MT>(ctx.Attr<float>("regularization_coeff"));
RegularizationType regularization_flag{
RegularizationType::kNONE}; // disable regularization
if (regularization_method == "l2_decay") {
regularization_flag = RegularizationType::kL2DECAY;
}
T mu = static_cast<T>(ctx.Attr<float>("mu"));
MT mu = static_cast<MT>(ctx.Attr<float>("mu"));
MT rescale_grad = static_cast<MT>(ctx.Attr<float>("rescale_grad"));
bool use_nesterov = ctx.Attr<bool>("use_nesterov");
auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto param = ctx.Input<framework::Tensor>("Param");
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto* velocity = ctx.Input<framework::Tensor>("Velocity");
auto velocity = ctx.Input<framework::Tensor>("Velocity");
auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
const framework::Tensor* master_param = nullptr;
framework::Tensor* master_param_out = nullptr;
if (multi_precision) {
LOG_FIRST_N(INFO, 1) << R"CODE(
bool has_master =
ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
PADDLE_ENFORCE_EQ(has_master, true,
platform::errors::InvalidArgument(
"The Input(MasterParam) and Output(MasterParamOut) "
"should not be null when "
"the attr `multi_precision` is true"));
master_param = ctx.Input<framework::Tensor>("MasterParam");
master_param_out = ctx.Output<framework::Tensor>("MasterParamOut");
)CODE";
bool has_master =
ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
PADDLE_ENFORCE_EQ(has_master, true,
platform::errors::InvalidArgument(
"The Input(MasterParam) and Output(MasterParamOut) "
"should not be null when "
"the attr `multi_precision` is true"));
master_param = ctx.Input<framework::Tensor>("MasterParam");
master_param_out = ctx.Output<framework::Tensor>("MasterParamOut");
}
param_out->mutable_data<T>(ctx.GetPlace());
velocity_out->mutable_data<T>(ctx.GetPlace());
velocity_out->mutable_data<MT>(ctx.GetPlace());
const MT* master_in_data =
multi_precision ? master_param->data<MT>() : nullptr;
MT* master_out_data =
multi_precision ? master_param_out->mutable_data<MT>(ctx.GetPlace())
: nullptr;
auto* grad_var = ctx.InputVar("Grad");
if (grad_var->IsType<framework::LoDTensor>()) {
auto grad = ctx.Input<framework::Tensor>("Grad");
if (platform::is_cpu_place(ctx.GetPlace())) {
CPUDenseMomentumFunctor<T> functor(
param, grad, velocity, learning_rate, mu, use_nesterov,
regularization_flag, regularization_coeff, param_out, velocity_out);
functor();
CPUDenseMomentumFunctor<MT> functor;
functor(param, grad, velocity, learning_rate, mu, use_nesterov,
regularization_flag, regularization_coeff, param_out,
velocity_out);
} else if (platform::is_gpu_place(ctx.GetPlace())) {
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()),
param->numel());
if (use_nesterov) {
DenseMomentumFunctor<T, UseNesterov> functor(
param->data<T>(), grad->data<T>(), velocity->data<T>(),
learning_rate->data<T>(), mu, param->numel(), regularization_flag,
regularization_coeff, param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<T>(ctx.GetPlace()));
LOG_FIRST_N(INFO, 1) << R"CODE(
DenseMomentumFunctor<T, MT, UseNesterov> functor(
param->data<T>(), grad->data<T>(), velocity->data<MT>(),
learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
param->numel(), regularization_flag, regularization_coeff,
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
)CODE";
DenseMomentumFunctor<T, MT, UseNesterov> functor(
param->data<T>(), grad->data<T>(), velocity->data<MT>(),
learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
param->numel(), regularization_flag, regularization_coeff,
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
for_range(functor);
} else {
DenseMomentumFunctor<T, NoNesterov> functor(
param->data<T>(), grad->data<T>(), velocity->data<T>(),
learning_rate->data<T>(), mu, param->numel(), regularization_flag,
regularization_coeff, param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<T>(ctx.GetPlace()));
LOG_FIRST_N(INFO, 1) << R"CODE(
DenseMomentumFunctor<T, MT, NoNesterov> functor(
param->data<T>(), grad->data<T>(), velocity->data<MT>(),
learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
param->numel(), regularization_flag, regularization_coeff,
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
)CODE";
DenseMomentumFunctor<T, MT, NoNesterov> functor(
param->data<T>(), grad->data<T>(), velocity->data<MT>(),
learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
param->numel(), regularization_flag, regularization_coeff,
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
for_range(functor);
}
}
......@@ -489,23 +605,25 @@ class MomentumOpKernel : public framework::OpKernel<T> {
static_cast<const DeviceContext&>(ctx.device_context()),
param->numel());
if (use_nesterov) {
SparseMomentumFunctor<T, UseNesterov> functor(
SparseMomentumFunctor<T, MT, UseNesterov> functor(
param->data<T>(), merged_grad->value().data<T>(),
velocity->data<T>(), learning_rate->data<T>(), mu, rows, row_numel,
velocity->data<MT>(), learning_rate->data<MPDType>(),
master_in_data, mu, rescale_grad, rows, row_numel,
static_cast<int64_t>(merged_grad->rows().size()),
regularization_flag, regularization_coeff,
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<T>(ctx.GetPlace()));
velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
for_range(functor);
} else {
SparseMomentumFunctor<T, NoNesterov> functor(
SparseMomentumFunctor<T, MT, NoNesterov> functor(
param->data<T>(), merged_grad->value().data<T>(),
velocity->data<T>(), learning_rate->data<T>(), mu, rows, row_numel,
velocity->data<MT>(), learning_rate->data<MPDType>(),
master_in_data, mu, rescale_grad, rows, row_numel,
static_cast<int64_t>(merged_grad->rows().size()),
regularization_flag, regularization_coeff,
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<T>(ctx.GetPlace()));
velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
for_range(functor);
}
} else {
......
......@@ -54,6 +54,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"moving_average_abs_max_scale", {"X", "InAccum", "InState"}},
{"multiclass_nms3", {"BBoxes", "Scores", "RoisNum"}},
{"box_coder", {"PriorBox", "PriorBoxVar", "TargetBox"}},
{"momentum", {"Param", "Grad", "Velocity", "LearningRate"}},
};
// NOTE(zhiqiu): Like op_ins_map.
......@@ -82,6 +83,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"moving_average_abs_max_scale", {"OutScale", "OutAccum", "OutState"}},
{"multiclass_nms3", {"Out", "NmsRoisNum"}},
{"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
{"momentum", {"ParamOut", "VelocityOut"}},
};
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
......
......@@ -16,6 +16,12 @@ from __future__ import print_function
from ... import core
from ... import layers
from ... import global_scope
from ...log_helper import get_logger
import logging
import numpy as np
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
def _rename_arg(op, old_name, new_name):
......@@ -191,6 +197,127 @@ def _is_in_black_varnames(op, amp_lists):
return False
def cast_model_to_fp16(main_program):
"""
Traverse all ops in the whole model and set their inputs and outputs
to the fp16 data type. This function will do some special process for
the batch normalization, which keeps the computational process of
batchnorms in FP32.
Args:
main_program (Program): The main program for training.
"""
valid_types = [
core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS,
core.VarDesc.VarType.LOD_TENSOR_ARRAY
]
global_block = main_program.global_block()
for block in main_program.blocks:
ops = block.ops
for op in ops:
if op.type == 'create_py_reader' or op.type == 'read':
continue
for in_name in op.input_names:
if op.type in {
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
} and in_name not in {'X', 'Z'}:
continue
for in_var_name in op.input(in_name):
in_var = None
try:
in_var = block.var(in_var_name)
except ValueError as e:
_logger.debug(
"-- {}, try to get it in the global block. --".
format(e))
in_var = global_block.var(in_var_name)
if in_var is not None:
_logger.debug(
"-- var {} is got in the global block. --".
format(in_var_name))
if in_var is None or in_var.type not in valid_types:
continue
if in_var.dtype == core.VarDesc.VarType.FP32:
in_var.desc.set_dtype(core.VarDesc.VarType.FP16)
_logger.debug(
"-- op type: {}, in var name: {}, in var dtype: {} --".
format(op.type, in_var_name, in_var.dtype))
for out_name in op.output_names:
if op.type in {
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
} and out_name != 'Y':
continue
for out_var_name in op.output(out_name):
out_var = None
try:
out_var = block.var(out_var_name)
except ValueError as e:
_logger.debug(
"-- {}, try to get it in the global block. --".
format(e))
out_var = global_block.var(out_var_name)
if out_var is not None:
_logger.debug(
"-- var {} is got in the global block. --".
format(out_var_name))
if out_var is None or out_var.type not in valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(core.VarDesc.VarType.FP16)
_logger.debug(
"-- op type: {}, out var name: {}, out var dtype: {} --".
format(op.type, out_var_name, out_var.dtype))
if op.has_attr('in_dtype') and op.attr(
'in_dtype') == core.VarDesc.VarType.FP32:
op._set_attr('in_dtype', core.VarDesc.VarType.FP16)
if op.has_attr('out_dtype') and op.attr(
'out_dtype') == core.VarDesc.VarType.FP32:
op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
if op.has_attr('dtype') and op.attr(
'dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.FP16)
def cast_parameters_to_fp16(place, main_program, scope=None):
"""
Traverse all parameters in the whole model and set them to the fp16 data type.
Whereas, this function will keep parameters of batchnorms in FP32.
Args:
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors.
main_program (Program): The main program for training.
scope(fluid.Scope, optional): scope is used to get the weight tensor values.
Default is None.
"""
all_ops = []
for block in main_program.blocks:
all_ops.extend(block.ops)
bn_params = set()
for op in all_ops:
if op.type not in {
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
}:
continue
for in_name in op.input_names:
if in_name not in {'X', 'Z'}:
for in_var_name in op.input(in_name):
bn_params.add(in_var_name)
global_block = main_program.global_block()
all_parameters = global_block.all_parameters()
var_scope = scope if scope is not None else global_scope()
for param in all_parameters:
if param.name not in bn_params:
param_t = var_scope.find_var(param.name).get_tensor()
data = np.array(param_t)
param_t.set(np.float16(data), place)
def rewrite_program(main_prog, amp_lists):
"""
Traverse all ops in current block and insert cast op according to
......
......@@ -14,11 +14,13 @@
from paddle.fluid.optimizer import Optimizer
from paddle.fluid.regularizer import L1DecayRegularizer
from paddle.fluid.regularizer import L2DecayRegularizer
from paddle.fluid.regularizer import append_regularization_ops
from paddle.fluid import framework
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.framework import program_guard
from paddle.fluid.clip import append_gradient_clip_ops
from paddle.fluid import unique_name
from paddle.fluid import layers
from paddle.fluid.layer_helper import LayerHelper
import warnings
__all__ = ['Momentum']
......@@ -61,6 +63,9 @@ class Momentum(Optimizer):
some derived class of ``GradientClipBase`` . There are three cliping strategies
( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
:ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false.
rescale_grad (float, optional): Multiply the gradient with `rescale_grad` before updating. \
Often choose to be ``1.0/batch_size``.
name (str, optional): This parameter is used by developers to print debugging information. \
For details, please refer to :ref:`api_guide_Name`. Default is None.
......@@ -105,6 +110,8 @@ class Momentum(Optimizer):
use_nesterov=False,
regularization=None,
grad_clip=None,
multi_precision=False,
rescale_grad=1.0,
name=None):
assert learning_rate is not None
assert momentum is not None
......@@ -124,11 +131,68 @@ class Momentum(Optimizer):
if (isinstance(regularization, L2DecayRegularizer)):
self._regularization_method = "l2_decay"
self._regularization_coeff = regularization._regularization_coeff
self._multi_precision = multi_precision
self._rescale_grad = rescale_grad
self._master_weights = {}
def _create_master_weight(self, param):
assert isinstance(self.helper, LayerHelper)
var_name = param.name + "_fp32_master"
var_name = unique_name.generate(var_name)
var = layers.create_global_var(
name=var_name,
shape=param.shape,
value=0,
dtype='float32',
persistable=True)
block = self.helper.startup_program.global_block()
block.append_op(
type="cast",
inputs={"X": [param]},
outputs={"Out": [var]},
attrs={
"in_dtype": param.dtype,
"out_dtype": core.VarDesc.VarType.FP32
})
self._master_weights[param.name] = var
return var
def _get_accumulator(self, name, param):
"""Utility function to fetch an accumulator for a parameter
Args:
name: name of the accumulator
param: parameter variable for which accumulator is to be fetched
Returns:
accumulator variable for the parameter
"""
if self._name is not None:
name = self._name + "_" + name
find_master = self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
target_param = self._master_weights[
param.name] if find_master else param
target_name = target_param.name
if (name not in self._accumulators or
target_name not in self._accumulators[name]):
raise Exception("Accumulator {} does not exist for parameter {}".
format(name, target_name))
return self._accumulators[name][target_name]
def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block)
for p in parameters:
if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
master_p = self._create_master_weight(p)
self._add_accumulator(self._velocity_acc_str, master_p)
continue
if p.dtype == core.VarDesc.VarType.FP16 and not self._multi_precision:
warnings.warn(
"Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
"Consider using multi_precision=True option of the Momentum optimizer."
)
self._add_accumulator(self._velocity_acc_str, p)
def _append_optimize_op(self, block, param_and_grad):
......@@ -136,6 +200,10 @@ class Momentum(Optimizer):
velocity_acc = self._get_accumulator(self._velocity_acc_str,
param_and_grad[0])
find_master = self._multi_precision and param_and_grad[
0].dtype == core.VarDesc.VarType.FP16
master_weight = (self._master_weights[param_and_grad[0].name]
if find_master else None)
lr = self._create_param_lr(param_and_grad)
if framework.in_dygraph_mode():
......@@ -151,7 +219,9 @@ class Momentum(Optimizer):
"mu": self._momentum,
"use_nesterov": self._use_nesterov,
"regularization_method": self._regularization_method,
"regularization_coeff": self._regularization_coeff
"regularization_coeff": self._regularization_coeff,
"multi_precision": find_master,
"rescale_grad": self._rescale_grad
}
inputs = {
"Param": [param_and_grad[0]],
......@@ -159,11 +229,15 @@ class Momentum(Optimizer):
"Velocity": [velocity_acc],
"LearningRate": [lr]
}
outputs = {
"ParamOut": [param_and_grad[0]],
"VelocityOut": [velocity_acc]
}
if find_master:
inputs["MasterParam"] = master_weight
outputs["MasterParamOut"] = master_weight
# create the momentum optimize op
momentum_op = block.append_op(
type=self.type,
......
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
list(REMOVE_ITEM TEST_OPS test_multi_precision_fp16_train)
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
endforeach()
py_test_modules(test_multi_precision_fp16_train MODULES test_multi_precision_fp16_train ENVS FLAGS_cudnn_deterministic=true FLAGS_cudnn_batchnorm_spatial_persistent=true FLAGS_conv_workspace_size_limit=1000)
set_tests_properties(test_image_classification_fp16 PROPERTIES TIMEOUT 120)
set_tests_properties(test_weight_decay_extend PROPERTIES TIMEOUT 120)
set_tests_properties(test_multi_precision_fp16_train PROPERTIES TIMEOUT 120)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import contextlib
import unittest
import numpy as np
from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_model_to_fp16
from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_parameters_to_fp16
paddle.enable_static()
def resnet_cifar10(input, depth=32):
def conv_bn_layer(input,
ch_out,
filter_size,
stride,
padding,
act='relu',
bias_attr=False):
tmp = fluid.layers.conv2d(
input=input,
filter_size=filter_size,
num_filters=ch_out,
stride=stride,
padding=padding,
act=None,
bias_attr=bias_attr)
return fluid.layers.batch_norm(input=tmp, act=act)
def shortcut(input, ch_in, ch_out, stride):
if ch_in != ch_out:
return conv_bn_layer(input, ch_out, 1, stride, 0, None)
else:
return input
def basicblock(input, ch_in, ch_out, stride):
tmp = conv_bn_layer(input, ch_out, 3, stride, 1)
tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None, bias_attr=True)
short = shortcut(input, ch_in, ch_out, stride)
return fluid.layers.elementwise_add(x=tmp, y=short, act='relu')
def layer_warp(block_func, input, ch_in, ch_out, count, stride):
tmp = block_func(input, ch_in, ch_out, stride)
for i in range(1, count):
tmp = block_func(tmp, ch_out, ch_out, 1)
return tmp
assert (depth - 2) % 6 == 0
n = (depth - 2) // 6
conv1 = conv_bn_layer(
input=input, ch_out=16, filter_size=3, stride=1, padding=1)
res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
res2 = layer_warp(basicblock, res1, 16, 32, n, 2)
res3 = layer_warp(basicblock, res2, 32, 64, n, 2)
pool = fluid.layers.pool2d(
input=res3, pool_size=8, pool_type='avg', pool_stride=1)
return pool
def compile(program, loss_name=None):
build_strategy = paddle.static.BuildStrategy()
exec_strategy = paddle.static.ExecutionStrategy()
exec_strategy.num_threads = 1
exec_strategy.num_iteration_per_drop_scope = 10000
build_strategy.fuse_bn_act_ops = True
build_strategy.fuse_elewise_add_act_ops = True
build_strategy.fuse_bn_add_act_ops = True
compiled_program = paddle.static.CompiledProgram(
program).with_data_parallel(
loss_name=loss_name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
return compiled_program
def train(use_pure_fp16=True, use_nesterov=False):
classdim = 10
data_shape = [3, 32, 32]
BATCH_SIZE = 128
PASS_NUM = 1
train_program = fluid.Program()
startup_prog = fluid.Program()
train_program.random_seed = 123
startup_prog.random_seed = 456
with fluid.program_guard(train_program, startup_prog):
images = fluid.layers.data(
name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
net = resnet_cifar10(images, 32)
logits = fluid.layers.fc(input=net, size=classdim, act="softmax")
if use_pure_fp16:
cast_model_to_fp16(fluid.default_main_program())
logits_fp32 = fluid.layers.cast(x=logits, dtype="float32")
else:
logits_fp32 = logits
cost = fluid.layers.softmax_with_cross_entropy(
logits_fp32, label, return_softmax=False)
sum_cost = fluid.layers.reduce_sum(cost)
# Test program
test_program = train_program.clone(for_test=True)
optimizer = fluid.contrib.optimizer.Momentum(
learning_rate=0.001,
momentum=0.9,
use_nesterov=use_nesterov,
regularization=fluid.regularizer.L2Decay(1e-4),
multi_precision=use_pure_fp16,
rescale_grad=1.0 / BATCH_SIZE)
optimizer.minimize(sum_cost)
# no shuffle for unit test
train_reader = paddle.batch(
paddle.dataset.cifar.train10(), batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(place=place, feed_list=[images, label])
def train_loop(main_program):
exe.run(startup_prog)
if use_pure_fp16:
cast_parameters_to_fp16(place, train_program, fluid.global_scope())
compiled_program = compile(train_program, sum_cost.name)
loss = 0.0
for pass_id in range(PASS_NUM):
train_loss_list = []
for batch_id, data in enumerate(train_reader()):
loss, = exe.run(compiled_program,
feed=feeder.feed(data),
fetch_list=[sum_cost])
print('PassID {0:1}, Train Batch ID {1:04}, train loss {2:2.4}'.
format(pass_id, batch_id + 1, float(loss)))
train_loss_list.append(float(loss))
if batch_id >= 4: # For speeding up CI
test_loss_list = []
for tid, test_data in enumerate(test_reader()):
loss_t, = exe.run(program=test_program,
feed=feeder.feed(test_data),
fetch_list=[sum_cost])
test_loss_list.append(float(loss_t))
print(
'PassID {0:1}, Test Batch ID {1:04}, test loss {2:2.4}'.
format(pass_id, tid + 1, float(loss_t)))
if tid >= 4:
break # For speeding up CI
return train_loss_list, test_loss_list
return train_loop(train_program)
class TestImageMultiPrecision(unittest.TestCase):
def test_resnet_pure_fp16(self):
if not fluid.core.is_compiled_with_cuda():
return
def do_test(use_nesterov=False):
suffix = "with Nesterov" if use_nesterov else "without Nesterov"
with self.scope_prog_guard():
print("-----------------FP16 Train {}-----------------".format(
suffix))
train_loss_fp16, test_loss_fp16 = train(
use_pure_fp16=True, use_nesterov=use_nesterov)
with self.scope_prog_guard():
print("-----------------FP32 Train {}-----------------".format(
suffix))
train_loss_fp32, test_loss_fp32 = train(
use_pure_fp16=False, use_nesterov=use_nesterov)
self.assertTrue(
np.allclose(
np.array(train_loss_fp16),
np.array(train_loss_fp32),
rtol=1e-02,
atol=1e-05,
equal_nan=True),
msg='Failed to train in pure FP16.')
self.assertTrue(
np.allclose(
np.array(test_loss_fp16),
np.array(test_loss_fp32),
rtol=1e-02,
atol=1e-05,
equal_nan=True),
msg='Failed to test in pure FP16.')
do_test(use_nesterov=False)
do_test(use_nesterov=True)
@contextlib.contextmanager
def scope_prog_guard(self):
prog = fluid.Program()
startup_prog = fluid.Program()
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(prog, startup_prog):
yield
class TestAmpWithNonIterableDataLoader(unittest.TestCase):
def decorate_with_data_loader(self):
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
with paddle.fluid.unique_name.guard():
image = fluid.layers.data(
name='image', shape=[3, 224, 224], dtype='float32')
label = fluid.layers.data(
name='label', shape=[1], dtype='int64')
py_reader = fluid.io.DataLoader.from_generator(
feed_list=[image, label],
capacity=4,
iterable=False,
use_double_buffer=False)
zero_var = fluid.layers.fill_constant(
shape=[1], dtype='int64', value=0)
one_var = fluid.layers.fill_constant(
shape=[1], dtype='int64', value=1)
with fluid.layers.control_flow.Switch() as switch:
with switch.case(label != zero_var):
fluid.layers.assign(input=zero_var, output=label)
with switch.default():
fluid.layers.assign(input=one_var, output=label)
net = resnet_cifar10(image)
logits = fluid.layers.fc(input=net, size=10, act="softmax")
block = main_prog.global_block()
for op in block.ops:
if op.type == "mul":
op._set_attr('in_dtype', fluid.core.VarDesc.VarType.FP32)
op._set_attr('out_dtype', fluid.core.VarDesc.VarType.FP32)
op._set_attr('dtype', fluid.core.VarDesc.VarType.FP32)
cast_model_to_fp16(main_prog)
def test_non_iterable_dataloader(self):
self.decorate_with_data_loader()
if __name__ == '__main__':
unittest.main()
......@@ -246,7 +246,8 @@ def _append_pserver_ops(optimize_block, opt_op, endpoint, grad_to_block_id,
for key in opt_op.input_names:
new_shape = None
if key in [
"Param", "Grad", "LearningRate", "Beta1Tensor", "Beta2Tensor"
"Param", "Grad", "LearningRate", "MasterParam", "Beta1Tensor",
"Beta2Tensor"
]:
continue
var = origin_program.global_block().vars[opt_op.input(key)[0]]
......
......@@ -59,7 +59,7 @@ class TestMomentumOp1(OpTest):
param = np.random.random((123, 321)).astype(self.dtype)
grad = np.random.random((123, 321)).astype(self.dtype)
velocity = np.zeros((123, 321)).astype(self.dtype)
learning_rate = np.array([0.001]).astype(self.dtype)
learning_rate = np.array([0.001]).astype(np.float32)
mu = 0.0001
use_nesterov = False
......@@ -217,7 +217,7 @@ class TestSparseMomentumOp(unittest.TestCase):
0.0).astype("float32")
velocity_out.set(velocity_out_np_array, place)
# create and initialize LeraningRate Variable
# create and initialize LearningRate Variable
lr = scope.var('LearningRate').get_tensor()
lr_array = np.full((1), 2.0).astype("float32")
lr.set(lr_array, place)
......@@ -278,6 +278,115 @@ class TestSparseMomentumOp2(TestSparseMomentumOp):
self.use_nesterov = True
class TestSparseMomentumOpWithMultiPrecision(unittest.TestCase):
def setUp(self):
self.init_args()
self.regularization_method = ""
self.regularization_coeff = 1.0
def check_with_place(self, place):
scope = core.Scope()
# create and initialize Grad Variable
height = 10
rows = [0, 4, 7]
row_numel = 12
mu = 1.0
use_nesterov = self.use_nesterov
regularization_method = self.regularization_method
regularization_coeff = self.regularization_coeff
# create and initialize Param Variable
param_array = np.full((height, row_numel), 5.0).astype("float32")
param_out_array = np.full((height, row_numel), 0.0).astype("float32")
param = scope.var('Param').get_tensor()
param.set(param_array.astype("float16"), place)
param_out = scope.var("ParamOut").get_tensor()
param_out.set(param_out_array.astype("float16"), place)
master_param = scope.var('MasterParam').get_tensor()
master_param.set(param_array, place)
master_param_out = scope.var("MasterParamOut").get_tensor()
master_param_out.set(param_out_array, place)
grad_selected_rows = scope.var('Grad').get_selected_rows()
grad_selected_rows.set_height(height)
grad_selected_rows.set_rows(rows)
grad_np_array = np.ones((len(rows), row_numel)).astype("float32")
grad_np_array[0, 0] = 2.0
grad_np_array[2, 8] = 4.0
grad_tensor = grad_selected_rows.get_tensor()
grad_tensor.set(grad_np_array.astype("float16"), place)
velocity = scope.var('Velocity').get_tensor()
velocity_np_array = np.ones((height, row_numel)).astype("float32")
velocity.set(velocity_np_array, place)
velocity_out = scope.var('VelocityOut').get_tensor()
velocity_out_np_array = np.full((height, row_numel),
0.0).astype("float32")
velocity_out.set(velocity_out_np_array, place)
# create and initialize LearningRate Variable
lr = scope.var('LearningRate').get_tensor()
lr_array = np.full((1), 2.0).astype("float32")
lr.set(lr_array, place)
# create and run operator
op = Operator(
"momentum",
Param='Param',
Grad='Grad',
Velocity='Velocity',
MasterParam='MasterParam',
ParamOut='ParamOut',
VelocityOut='VelocityOut',
MasterParamOut='MasterParamOut',
LearningRate='LearningRate',
mu=mu,
use_nesterov=use_nesterov,
regularization_method=regularization_method,
regularization_coeff=regularization_coeff,
multi_precision=True,
rescale_grad=1.0)
op.run(scope, place)
# get and compare result
param_out_np_array = np.array(param_out)
velocity_out_np_array = np.array(velocity_out)
_grad_np_array = np.full((height, row_numel), 0.0).astype("float32")
for i in range(len(rows)):
_grad_np_array[rows[i]] = grad_np_array[i]
_param = param_array
_param_out, _velocity_out = calculate_momentum_by_numpy(
param=_param,
grad=_grad_np_array,
mu=mu,
velocity=velocity_np_array,
use_nesterov=use_nesterov,
learning_rate=lr_array,
regularization_method=regularization_method,
regularization_coeff=regularization_coeff)
self.assertTrue((_velocity_out == velocity_out_np_array).all())
self.assertTrue((_param_out == param_out_np_array).all())
def init_args(self):
self.use_nesterov = False
def test_sparse_momentum(self):
if core.is_compiled_with_cuda():
self.check_with_place(fluid.CUDAPlace(0))
class TestSparseMomentumOpWithMultiPrecision2(
TestSparseMomentumOpWithMultiPrecision):
def init_args(self):
self.use_nesterov = True
class TestMomentumV2(unittest.TestCase):
def test_momentum_dygraph(self):
paddle.disable_static()
......@@ -334,7 +443,7 @@ class TestMomentumOpWithDecay(OpTest):
param = np.random.random((123, 321)).astype(self.dtype)
grad = np.random.random((123, 321)).astype(self.dtype)
velocity = np.zeros((123, 321)).astype(self.dtype)
learning_rate = np.array([0.001]).astype(self.dtype)
learning_rate = np.array([0.001]).astype(np.float32)
mu = 0.0001
use_nesterov = self.use_nesterov
regularization_method = self.regularization_method
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册