Unverified commit 703a64a3, authored by Zhang Ting, committed by GitHub

[AMP] support master_grad for adam and momentum (#54240)

* support master_grad for adam and momentum

Co-authored-by: zhangting_2017@163.com <zhangting2020>
Parent: cdbf62f8
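For orientation before the diffs: in O2 AMP training, fp16/bf16 gradients can be promoted to fp32 "master" gradients, and this commit teaches the Adam and Momentum kernels to consume them. A minimal dygraph sketch of the workflow, assuming a trivial `paddle.nn.Linear` model (the updated unit test below uses its own `SimpleNet`):

```python
import numpy as np
import paddle

model = paddle.nn.Linear(2, 4)
opt = paddle.optimizer.Adam(parameters=model.parameters())
# master_grad=True is the switch this commit wires up for Adam and Momentum.
model, opt = paddle.amp.decorate(
    model, optimizers=opt, level='O2', master_grad=True
)
scaler = paddle.amp.GradScaler()

x = paddle.to_tensor(np.random.random((2, 2)).astype('float32'))
with paddle.amp.auto_cast(level='O2'):
    loss = model(x).mean()
scaler.scale(loss).backward()  # fp16 grads are replaced by fp32 master grads
scaler.step(opt)               # the optimizer now consumes the fp32 grads
scaler.update()
opt.clear_grad()
```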
@@ -1254,10 +1254,9 @@ static PyObject* eager_api_set_master_grads(PyObject* self,
   auto tensor_list = CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 0), 0);
   for (auto& tensor : tensor_list) {
     VLOG(6) << "set master_grad for tensor: " << tensor.name();
-    PADDLE_ENFORCE_EQ(
-        egr::egr_utils_api::IsLeafTensor(tensor),
-        true,
-        paddle::platform::errors::Fatal("Only leaf Tensor can be set grad."));
+    if (!egr::egr_utils_api::IsLeafTensor(tensor)) {
+      continue;
+    }
     paddle::Tensor* grad = egr::EagerUtils::mutable_grad(tensor);
     PADDLE_ENFORCE_NE(grad,
                       nullptr,
@@ -1265,13 +1264,13 @@ static PyObject* eager_api_set_master_grads(PyObject* self,
                           "Detected NULL grad"
                           "Please check if you have manually cleared"
                           "the grad inside autograd_meta"));
-    auto dtype = (*grad).dtype();
-    if ((*grad).initialized() &&
-        (dtype == phi::DataType::FLOAT16 || dtype == phi::DataType::BFLOAT16)) {
+    if ((*grad).initialized() && ((*grad).dtype() == phi::DataType::FLOAT16 ||
+                                  (*grad).dtype() == phi::DataType::BFLOAT16)) {
       auto master_grad =
           paddle::experimental::cast(*grad, phi::DataType::FLOAT32);
       grad->set_impl(master_grad.impl());
     }
+    VLOG(6) << "finish setting master_grad for tensor: " << tensor.name();
   }
   RETURN_PY_NONE
   EAGER_CATCH_AND_THROW_RETURN_NULL
...
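Two behavioral changes land in this hunk: non-leaf tensors are now skipped instead of raising a fatal error, and an initialized fp16/bf16 grad is replaced in place by an fp32 copy. A framework-free Python sketch of that control flow (`FakeTensor` is a hypothetical stand-in, not a Paddle type):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeTensor:               # hypothetical stand-in for paddle.Tensor
    name: str
    is_leaf: bool
    grad_dtype: Optional[str]   # None models a grad cleared from autograd_meta

def set_master_grads(tensors):
    for t in tensors:
        if not t.is_leaf:
            continue            # previously a fatal error, now silently skipped
        if t.grad_dtype is None:
            raise RuntimeError(f"Detected NULL grad for {t.name}")
        if t.grad_dtype in ("float16", "bfloat16"):
            t.grad_dtype = "float32"   # cast the grad to an fp32 master grad

params = [FakeTensor("w", True, "float16"), FakeTensor("tmp", False, None)]
set_master_grads(params)
assert params[0].grad_dtype == "float32"   # non-leaf "tmp" was skipped
```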
@@ -30,7 +30,7 @@
 namespace phi {
-template <typename T, typename MT>
+template <typename T, typename TG, typename MT>
 __global__ void AdamKernelREG(MT beta1,
                               MT beta2,
                               MT epsilon,
@@ -41,7 +41,7 @@ __global__ void AdamKernelREG(MT beta1,
                               const MT* moment2,
                               MT* moment2_out,
                               const MT* lr_,
-                              const T* grad,
+                              const TG* grad,
                               const T* param,
                               T* param_out,
                               const MT* master_param,
@@ -73,7 +73,7 @@ __global__ void AdamKernelREG(MT beta1,
   }
 }
-template <typename T, typename MT>
+template <typename T, typename TG, typename MT>
 __global__ void AdamKernelMEM(MT beta1,
                               MT beta2,
                               MT epsilon,
@@ -84,7 +84,7 @@ __global__ void AdamKernelMEM(MT beta1,
                               const MT* moment2,
                               MT* moment2_out,
                               const MT* lr_,
-                              const T* grad,
+                              const TG* grad,
                               const T* param,
                               T* param_out,
                               const MT* master_param,
@@ -152,6 +152,7 @@ void AdamDenseKernel(const Context& dev_ctx,
                      DenseTensor* beta2_pow_out,
                      DenseTensor* master_param_outs) {
   using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
+  const auto grad_type = grad.dtype();
   VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
@@ -212,23 +213,44 @@ void AdamDenseKernel(const Context& dev_ctx,
   if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) {
     // Compute with betapow in REG
-    AdamKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-        beta1_,
-        beta2_,
-        epsilon_,
-        *beta1_pow.data<MPDType>(),
-        *beta2_pow.data<MPDType>(),
-        moment1.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment1_out),
-        moment2.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment2_out),
-        learning_rate.data<MPDType>(),
-        grad.data<T>(),
-        param.data<T>(),
-        dev_ctx.template Alloc<T>(param_out),
-        master_in_data,
-        master_out_data,
-        param.numel());
+    if (grad_type == phi::DataType::FLOAT32) {
+      AdamKernelREG<T, float, MPDType>
+          <<<blocks, threads, 0, dev_ctx.stream()>>>(
+              beta1_,
+              beta2_,
+              epsilon_,
+              *beta1_pow.data<MPDType>(),
+              *beta2_pow.data<MPDType>(),
+              moment1.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment1_out),
+              moment2.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment2_out),
+              learning_rate.data<MPDType>(),
+              grad.data<float>(),
+              param.data<T>(),
+              dev_ctx.template Alloc<T>(param_out),
+              master_in_data,
+              master_out_data,
+              param.numel());
+    } else {
+      AdamKernelREG<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+          beta1_,
+          beta2_,
+          epsilon_,
+          *beta1_pow.data<MPDType>(),
+          *beta2_pow.data<MPDType>(),
+          moment1.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment1_out),
+          moment2.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment2_out),
+          learning_rate.data<MPDType>(),
+          grad.data<T>(),
+          param.data<T>(),
+          dev_ctx.template Alloc<T>(param_out),
+          master_in_data,
+          master_out_data,
+          param.numel());
+    }
     if (!use_global_beta_pow) {
       // Cpu update
       dev_ctx.template HostAlloc<MPDType>(beta1_pow_out)[0] =
@@ -237,23 +259,44 @@ void AdamDenseKernel(const Context& dev_ctx,
           beta2_ * beta2_pow.data<MPDType>()[0];
     }
   } else {
-    AdamKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-        beta1_,
-        beta2_,
-        epsilon_,
-        beta1_pow.data<MPDType>(),
-        beta2_pow.data<MPDType>(),
-        moment1.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment1_out),
-        moment2.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment2_out),
-        learning_rate.data<MPDType>(),
-        grad.data<T>(),
-        param.data<T>(),
-        dev_ctx.template Alloc<T>(param_out),
-        master_in_data,
-        master_out_data,
-        param.numel());
+    if (grad_type == phi::DataType::FLOAT32) {
+      AdamKernelMEM<T, float, MPDType>
+          <<<blocks, threads, 0, dev_ctx.stream()>>>(
+              beta1_,
+              beta2_,
+              epsilon_,
+              beta1_pow.data<MPDType>(),
+              beta2_pow.data<MPDType>(),
+              moment1.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment1_out),
+              moment2.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment2_out),
+              learning_rate.data<MPDType>(),
+              grad.data<float>(),
+              param.data<T>(),
+              dev_ctx.template Alloc<T>(param_out),
+              master_in_data,
+              master_out_data,
+              param.numel());
+    } else {
+      AdamKernelMEM<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+          beta1_,
+          beta2_,
+          epsilon_,
+          beta1_pow.data<MPDType>(),
+          beta2_pow.data<MPDType>(),
+          moment1.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment1_out),
+          moment2.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment2_out),
+          learning_rate.data<MPDType>(),
+          grad.data<T>(),
+          param.data<T>(),
+          dev_ctx.template Alloc<T>(param_out),
+          master_in_data,
+          master_out_data,
+          param.numel());
+    }
     if (!use_global_beta_pow) {
       // Update with gpu
       UpdateBetaPow<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
@@ -308,26 +351,48 @@ void MergedAdamKernel(
     int threads = 512;
     int blocks = (param[idx]->numel() + threads - 1) / threads;
+    const auto grad_type = grad[idx]->dtype();
     if (beta1_pow[idx]->place() == CPUPlace() &&
         beta2_pow[idx]->place() == CPUPlace()) {
       // Compute with betapow in REG
-      AdamKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          beta1_,
-          beta2_,
-          epsilon_,
-          *beta1_pow[idx]->data<MPDType>(),
-          *beta2_pow[idx]->data<MPDType>(),
-          moment1[idx]->data<MPDType>(),
-          dev_ctx.template Alloc<MPDType>(moment1_out[idx]),
-          moment2[idx]->data<MPDType>(),
-          dev_ctx.template Alloc<MPDType>(moment2_out[idx]),
-          learning_rate[idx]->data<MPDType>(),
-          grad[idx]->data<T>(),
-          param[idx]->data<T>(),
-          dev_ctx.template Alloc<T>(param_out[idx]),
-          master_in_data,
-          master_out_data,
-          param[idx]->numel());
+      if (grad_type == phi::DataType::FLOAT32) {
+        AdamKernelREG<T, float, MPDType>
+            <<<blocks, threads, 0, dev_ctx.stream()>>>(
+                beta1_,
+                beta2_,
+                epsilon_,
+                *beta1_pow[idx]->data<MPDType>(),
+                *beta2_pow[idx]->data<MPDType>(),
+                moment1[idx]->data<MPDType>(),
+                dev_ctx.template Alloc<MPDType>(moment1_out[idx]),
+                moment2[idx]->data<MPDType>(),
+                dev_ctx.template Alloc<MPDType>(moment2_out[idx]),
+                learning_rate[idx]->data<MPDType>(),
+                grad[idx]->data<float>(),
+                param[idx]->data<T>(),
+                dev_ctx.template Alloc<T>(param_out[idx]),
+                master_in_data,
+                master_out_data,
+                param[idx]->numel());
+      } else {
+        AdamKernelREG<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            beta1_,
+            beta2_,
+            epsilon_,
+            *beta1_pow[idx]->data<MPDType>(),
+            *beta2_pow[idx]->data<MPDType>(),
+            moment1[idx]->data<MPDType>(),
+            dev_ctx.template Alloc<MPDType>(moment1_out[idx]),
+            moment2[idx]->data<MPDType>(),
+            dev_ctx.template Alloc<MPDType>(moment2_out[idx]),
+            learning_rate[idx]->data<MPDType>(),
+            grad[idx]->data<T>(),
+            param[idx]->data<T>(),
+            dev_ctx.template Alloc<T>(param_out[idx]),
+            master_in_data,
+            master_out_data,
+            param[idx]->numel());
+      }
      if (!use_global_beta_pow) {
        // Cpu update
        dev_ctx.template HostAlloc<MPDType>(beta1_pow_out[idx])[0] =
@@ -336,23 +401,44 @@ void MergedAdamKernel(
            beta2_ * beta2_pow[idx]->data<MPDType>()[0];
      }
    } else {
-      AdamKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          beta1_,
-          beta2_,
-          epsilon_,
-          beta1_pow[idx]->data<MPDType>(),
-          beta2_pow[idx]->data<MPDType>(),
-          moment1[idx]->data<MPDType>(),
-          dev_ctx.template Alloc<MPDType>(moment1_out[idx]),
-          moment2[idx]->data<MPDType>(),
-          dev_ctx.template Alloc<MPDType>(moment2_out[idx]),
-          learning_rate[idx]->data<MPDType>(),
-          grad[idx]->data<T>(),
-          param[idx]->data<T>(),
-          dev_ctx.template Alloc<T>(param_out[idx]),
-          master_in_data,
-          master_out_data,
-          param[idx]->numel());
+      if (grad_type == phi::DataType::FLOAT32) {
+        AdamKernelMEM<T, float, MPDType>
+            <<<blocks, threads, 0, dev_ctx.stream()>>>(
+                beta1_,
+                beta2_,
+                epsilon_,
+                beta1_pow[idx]->data<MPDType>(),
+                beta2_pow[idx]->data<MPDType>(),
+                moment1[idx]->data<MPDType>(),
+                dev_ctx.template Alloc<MPDType>(moment1_out[idx]),
+                moment2[idx]->data<MPDType>(),
+                dev_ctx.template Alloc<MPDType>(moment2_out[idx]),
+                learning_rate[idx]->data<MPDType>(),
+                grad[idx]->data<float>(),
+                param[idx]->data<T>(),
+                dev_ctx.template Alloc<T>(param_out[idx]),
+                master_in_data,
+                master_out_data,
+                param[idx]->numel());
+      } else {
+        AdamKernelMEM<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            beta1_,
+            beta2_,
+            epsilon_,
+            beta1_pow[idx]->data<MPDType>(),
+            beta2_pow[idx]->data<MPDType>(),
+            moment1[idx]->data<MPDType>(),
+            dev_ctx.template Alloc<MPDType>(moment1_out[idx]),
+            moment2[idx]->data<MPDType>(),
+            dev_ctx.template Alloc<MPDType>(moment2_out[idx]),
+            learning_rate[idx]->data<MPDType>(),
+            grad[idx]->data<T>(),
+            param[idx]->data<T>(),
+            dev_ctx.template Alloc<T>(param_out[idx]),
+            master_in_data,
+            master_out_data,
+            param[idx]->numel());
+      }
     if (!use_global_beta_pow) {
       // Update with gpu
       UpdateBetaPow<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
...
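The kernel bodies are collapsed above, so for reference here is a NumPy sketch of the update they implement. Params may be fp16 (`T`) with an fp32 master copy (`MPDType`), and after this change the grad can arrive as either fp16 (`TG == T`) or fp32 (`TG == float`); either way the arithmetic happens in fp32. The formulas below follow the standard bias-corrected Adam update and are illustrative, not a line-for-line port of the kernel:

```python
import numpy as np

def adam_step(master_fp32, grad, m1, m2, beta1_pow, beta2_pow,
              lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    g = grad.astype(np.float32)          # TG may be fp16 or fp32; promote once
    m1 = beta1 * m1 + (1.0 - beta1) * g
    m2 = beta2 * m2 + (1.0 - beta2) * g * g
    lr_t = lr * np.sqrt(1.0 - beta2_pow) / (1.0 - beta1_pow)  # bias correction
    master_fp32 = master_fp32 - lr_t * m1 / (np.sqrt(m2) + eps)
    # param_out is written back in T; the fp32 master copy is carried forward
    return master_fp32.astype(np.float16), master_fp32, m1, m2

p32 = np.zeros(4, np.float32)
m1, m2 = np.zeros(4, np.float32), np.zeros(4, np.float32)
g = np.random.randn(4).astype(np.float32)   # an fp32 master grad (TG = float)
p16, p32, m1, m2 = adam_step(p32, g, m1, m2, beta1_pow=0.9, beta2_pow=0.999)
```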
@@ -300,21 +300,40 @@ void MergedMomentumInnerCompute(
     } else if (ctx.GetPlace().GetType() == phi::AllocationType::GPU) {
       phi::funcs::ForRange<Context> for_range(
           static_cast<const Context &>(ctx), params[idx]->numel());
-#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \
-  phi::DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor(   \
-      params[idx]->data<T>(),                                         \
-      grads[idx]->data<T>(),                                          \
-      velocitys[idx]->data<MT>(),                                     \
-      lr_temp->data<MPType>(),                                        \
-      master_in_data,                                                 \
-      static_cast<MT>(mu),                                            \
-      static_cast<MT>(rescale_grad),                                  \
-      params[idx]->numel(),                                           \
-      regularization_coeff,                                           \
-      params_out[idx]->data<T>(),                                     \
-      velocitys_out[idx]->data<MT>(),                                 \
-      master_out_data);                                               \
-  for_range(functor);
+      const auto grad_type = grads[idx]->dtype();
+#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type)   \
+  if (grad_type == phi::DataType::FLOAT32) {                            \
+    DenseMomentumFunctor<T, float, MT, __reg_type, __nesterov> functor( \
+        params[idx]->data<T>(),                                         \
+        grads[idx]->data<float>(),                                      \
+        velocitys[idx]->data<MT>(),                                     \
+        lr_temp->data<MPType>(),                                        \
+        master_in_data,                                                 \
+        static_cast<MT>(mu),                                            \
+        static_cast<MT>(rescale_grad),                                  \
+        params[idx]->numel(),                                           \
+        regularization_coeff,                                           \
+        params_out[idx]->data<T>(),                                     \
+        velocitys_out[idx]->data<MT>(),                                 \
+        master_out_data);                                               \
+    for_range(functor);                                                 \
+  } else {                                                              \
+    DenseMomentumFunctor<T, T, MT, __reg_type, __nesterov> functor(     \
+        params[idx]->data<T>(),                                         \
+        grads[idx]->data<T>(),                                          \
+        velocitys[idx]->data<MT>(),                                     \
+        lr_temp->data<MPType>(),                                        \
+        master_in_data,                                                 \
+        static_cast<MT>(mu),                                            \
+        static_cast<MT>(rescale_grad),                                  \
+        params[idx]->numel(),                                           \
+        regularization_coeff,                                           \
+        params_out[idx]->data<T>(),                                     \
+        velocitys_out[idx]->data<MT>(),                                 \
+        master_out_data);                                               \
+    for_range(functor);                                                 \
+  }
       if (use_nesterov) {
         if (regularization_flag == phi::RegularizationType::kL2DECAY) {
           PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
...
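Stripped of the backslash continuations, the reworked macro does one thing: a single runtime check on the grad dtype selects between two compile-time instantiations of the functor, so the per-element loop never branches on dtype. A hedged Python analogue of that dispatch (the kernel callables are placeholders, not Paddle APIs):

```python
def launch_momentum(grad, fp32_grad_kernel, native_grad_kernel, *args):
    # Mirrors the macro: DenseMomentumFunctor<T, float, MT, ...> for fp32
    # master grads, DenseMomentumFunctor<T, T, MT, ...> otherwise.
    if str(grad.dtype) == 'float32':
        return fp32_grad_kernel(grad, *args)
    return native_grad_kernel(grad, *args)
```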
@@ -104,6 +104,7 @@ class CPUDenseMomentumFunctor {
 };
 template <typename T,
+          typename TG,
           typename MT,
           RegularizationType kRegType,
           typename UpdateMethod>
@@ -112,11 +113,11 @@ class DenseMomentumFunctor;
 // NOTE(dzh) for performance.
 // avoid if/else in inside kernel, implement GPU UseNesterov/NoNesterov as two
 // functor.
-template <typename T, typename MT, RegularizationType kRegType>
-class DenseMomentumFunctor<T, MT, kRegType, UseNesterov> {
+template <typename T, typename TG, typename MT, RegularizationType kRegType>
+class DenseMomentumFunctor<T, TG, MT, kRegType, UseNesterov> {
  private:
   const T* param_;
-  const T* grad_;
+  const TG* grad_;
   const MT* velocity_;
   const MultiPrecisionType<MT>* lr_;
   const MT* master_param_;
@@ -130,7 +131,7 @@ class DenseMomentumFunctor<T, MT, kRegType, UseNesterov> {
  public:
   DenseMomentumFunctor(const T* param,
-                       const T* grad,
+                       const TG* grad,
                        const MT* velocity,
                        const MultiPrecisionType<MT>* learning_rate,
                        const MT* master_param,
@@ -176,11 +177,11 @@ class DenseMomentumFunctor<T, MT, kRegType, UseNesterov> {
   }
 };
-template <typename T, typename MT, RegularizationType kRegType>
-class DenseMomentumFunctor<T, MT, kRegType, NoNesterov> {
+template <typename T, typename TG, typename MT, RegularizationType kRegType>
+class DenseMomentumFunctor<T, TG, MT, kRegType, NoNesterov> {
  private:
   const T* param_;
-  const T* grad_;
+  const TG* grad_;
   const MT* velocity_;
   const MultiPrecisionType<MT>* lr_;
   const MT* master_param_;
@@ -194,7 +195,7 @@ class DenseMomentumFunctor<T, MT, kRegType, NoNesterov> {
  public:
   DenseMomentumFunctor(const T* param,
-                       const T* grad,
+                       const TG* grad,
                        const MT* velocity,
                        const MultiPrecisionType<MT>* learning_rate,
                        const MT* master_param,
@@ -459,21 +460,39 @@ void MomentumDenseImpl(const Context& ctx,
                        velocity_out);
   } else if (ctx.GetPlace().GetType() == phi::AllocationType::GPU) {
     funcs::ForRange<Context> for_range(ctx, param.numel());
-#define PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(__nesterov, __reg_type) \
-  DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor(      \
-      param.data<T>(),                                              \
-      grad.data<T>(),                                               \
-      velocity.data<MT>(),                                          \
-      learning_rate.data<MultiPrecisionType<T>>(),                  \
-      master_in_data,                                               \
-      mu,                                                           \
-      rescale_grad,                                                 \
-      param.numel(),                                                \
-      regularization_coeff,                                         \
-      ctx.template Alloc<T>(param_out),                             \
-      ctx.template Alloc<MT>(velocity_out),                         \
-      master_out_data);                                             \
-  for_range(functor);
+    const auto grad_type = grad.dtype();
+#define PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(__nesterov, __reg_type)     \
+  if (grad_type == phi::DataType::FLOAT32) {                            \
+    DenseMomentumFunctor<T, float, MT, __reg_type, __nesterov> functor( \
+        param.data<T>(),                                                \
+        grad.data<float>(),                                             \
+        velocity.data<MT>(),                                            \
+        learning_rate.data<MultiPrecisionType<T>>(),                    \
+        master_in_data,                                                 \
+        mu,                                                             \
+        rescale_grad,                                                   \
+        param.numel(),                                                  \
+        regularization_coeff,                                           \
+        ctx.template Alloc<T>(param_out),                               \
+        ctx.template Alloc<MT>(velocity_out),                           \
+        master_out_data);                                               \
+    for_range(functor);                                                 \
+  } else {                                                              \
+    DenseMomentumFunctor<T, T, MT, __reg_type, __nesterov> functor(     \
+        param.data<T>(),                                                \
+        grad.data<T>(),                                                 \
+        velocity.data<MT>(),                                            \
+        learning_rate.data<MultiPrecisionType<T>>(),                    \
+        master_in_data,                                                 \
+        mu,                                                             \
+        rescale_grad,                                                   \
+        param.numel(),                                                  \
+        regularization_coeff,                                           \
+        ctx.template Alloc<T>(param_out),                               \
+        ctx.template Alloc<MT>(velocity_out),                           \
+        master_out_data);                                               \
+    for_range(functor);                                                 \
+  }
     if (use_nesterov) {
       if (regularization_flag == RegularizationType::kL2DECAY) {
...
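For completeness, a NumPy sketch of the math the two functor specializations implement: L2 decay folds the (rescaled) grad together with the param, the velocity accumulates it, and Nesterov applies a look-ahead step. The functor bodies are collapsed in this diff, so the formulas below are a reading of Paddle's usual momentum semantics; treat them as illustrative:

```python
import numpy as np

def momentum_step(master_fp32, grad, velocity, mu=0.9, lr=0.1,
                  l2_coeff=1e-4, rescale_grad=1.0, use_nesterov=False):
    g = grad.astype(np.float32) * rescale_grad  # TG may be fp16 or fp32
    g = g + l2_coeff * master_fp32              # kL2DECAY regularization
    v_out = mu * velocity + g
    if use_nesterov:                            # UseNesterov specialization
        p_out = master_fp32 - (g + mu * v_out) * lr
    else:                                       # NoNesterov specialization
        p_out = master_fp32 - lr * v_out
    # param_out in T (fp16 here), fp32 master copy carried forward
    return p_out.astype(np.float16), p_out, v_out
```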
@@ -1265,6 +1265,23 @@ class Optimizer:
         ):
             return grad
         regularization_term = None
+        # When master_grad is enabled in AMP training the grad is fp32 but the
+        # param may be fp16/bf16; use the master weight to avoid dtype mismatch.
+        def get_target_param(param, grad):
+            target_param = param
+            if param.dtype != grad.dtype:
+                find_master = (
+                    self._multi_precision
+                    and self._is_dtype_fp16_or_bf16(param.dtype)
+                )
+                if find_master and len(self._master_weights) != 0:
+                    target_param = self._master_weights[param.name]
+                else:
+                    target_param = param.astype(grad.dtype)
+            return target_param
+
+        param = get_target_param(param, grad)
         if hasattr(param, 'regularizer') and param.regularizer is not None:
             # Add variable for regularization term in grad block
             regularization_term = param.regularizer(param, grad, grad.block)
...
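The helper above guards a concrete failure mode: with master_grad enabled, the regularizer would otherwise combine an fp16 param with an fp32 grad. A framework-free sketch of the rule it applies (`master_weights` is a plain dict standing in for `self._master_weights`; NumPy has no bf16, so fp16 stands for both low-precision cases):

```python
import numpy as np

def get_target_param(param, grad, name, master_weights, multi_precision=True):
    if param.dtype == grad.dtype:
        return param
    if multi_precision and param.dtype == np.float16 and master_weights:
        return master_weights[name]      # fp32 master weight matches the grad
    return param.astype(grad.dtype)      # fallback: cast the param itself

p = np.ones(3, np.float16)
g = np.full(3, 0.5, np.float32)                      # fp32 master grad
masters = {'w': p.astype(np.float32)}
reg = 1e-4 * get_target_param(p, g, 'w', masters)    # dtypes now agree
```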
@@ -44,7 +44,7 @@ class TestMasterGrad(unittest.TestCase):
         # fp16 calls
         self.assertEqual(int(op_list['matmul_v2'].split(',')[0]), total_steps)
         self.assertEqual(
-            int(op_list['adamw_'].split(',')[0]),
+            int(op_list['adam_'].split(',')[0]),
             2 * (total_steps / accumulate_batchs_num),
         )
         self.assertEqual(
@@ -52,14 +52,11 @@ class TestMasterGrad(unittest.TestCase):
             total_steps + total_steps * 2,
         )
-    def run_dygraph(self, total_steps, accumulate_batchs_num):
-        model = SimpleNet(2, 4)
-        opt = paddle.optimizer.AdamW(parameters=model.parameters())
+    def run_dygraph(self, total_steps, accumulate_batchs_num, model, optimizer):
         model, opt = paddle.amp.decorate(
-            model, optimizers=opt, level='O2', master_grad=True
+            model, optimizers=optimizer, level='O2', master_grad=True
         )
         scaler = paddle.amp.GradScaler()
         paddle.amp.debugging.enable_operator_stats_collection()
         for i in range(total_steps):
             x = np.random.random((2, 2)).astype('float32')
@@ -81,16 +78,32 @@ class TestMasterGrad(unittest.TestCase):
         op_list = paddle.fluid.core.get_low_precision_op_list()
         return fp32_grads, op_list
-    def test_master_grad(self):
+    def test_adam_master_grad(self):
         total_steps = 4
         accumulate_batchs_num = 2
+        model = SimpleNet(2, 4)
+        opt = paddle.optimizer.Adam(parameters=model.parameters())
         fp32_grads, op_list = self.run_dygraph(
-            total_steps, accumulate_batchs_num
+            total_steps, accumulate_batchs_num, model, opt
         )
         self.check_results(
             fp32_grads, op_list, total_steps, accumulate_batchs_num
         )
+    def test_momentum_master_grad(self):
+        total_steps = 4
+        accumulate_batchs_num = 1
+        model = SimpleNet(2, 4)
+        L1Decay = paddle.regularizer.L1Decay(0.0001)
+        opt = paddle.optimizer.Momentum(
+            parameters=model.parameters(), weight_decay=L1Decay
+        )
+        fp32_grads, op_list = self.run_dygraph(
+            total_steps, accumulate_batchs_num, model, opt
+        )
+        for grad in fp32_grads:
+            self.assertEqual(grad.dtype, paddle.float32)
 if __name__ == '__main__':
     unittest.main()