Unverified · commit 4dda18a8 authored by Zeng Jinle, committed by GitHub

fix momentum ops (#36452)

Parent 8566cc98
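What the change does: `DenseMomentumFunctor` previously took the regularization type as a runtime constructor argument (`regularization_flag_`) and tested it inside `operator()`, i.e. once per element. The commit promotes it to a non-type template parameter `kRegType`, so the branch is resolved at compile time and each instantiation's kernel body carries no flag check; the launch site tests the runtime flag once and picks the matching specialization. A minimal sketch of the pattern, with hypothetical names rather than Paddle's actual classes:

```cpp
#include <cstddef>

enum class RegularizationType { kNONE, kL2DECAY };

// The flag is a compile-time constant, so the dead branch in operator()
// is eliminated in each instantiation -- no per-element branching.
template <RegularizationType kRegType>
struct UpdateFunctor {
  float coeff_;
  const float* param_;
  float* grad_;
  void operator()(size_t i) const {
    if (kRegType == RegularizationType::kL2DECAY) {
      grad_[i] += coeff_ * param_[i];
    }
  }
};

// The runtime flag is tested once at launch, not once per element.
void Launch(RegularizationType flag, float coeff, const float* param,
            float* grad, size_t n) {
  if (flag == RegularizationType::kL2DECAY) {
    UpdateFunctor<RegularizationType::kL2DECAY> f{coeff, param, grad};
    for (size_t i = 0; i < n; ++i) f(i);
  } else {
    UpdateFunctor<RegularizationType::kNONE> f{coeff, param, grad};
    for (size_t i = 0; i < n; ++i) f(i);
  }
}
```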
@@ -173,14 +173,15 @@ class CPUDenseMomentumFunctor {
   }
 };
 
-template <typename T, typename MT, typename UpdateMethod>
+template <typename T, typename MT, RegularizationType kRegType,
+          typename UpdateMethod>
 class DenseMomentumFunctor;
 
 // NOTE(dzh) for performance.
 // avoid if/else in inside kernel, implement GPU UseNesterov/NoNesterov as two
 // functor.
-template <typename T, typename MT>
-class DenseMomentumFunctor<T, MT, UseNesterov> {
+template <typename T, typename MT, RegularizationType kRegType>
+class DenseMomentumFunctor<T, MT, kRegType, UseNesterov> {
  private:
   const T* param_;
   const T* grad_;
@@ -193,7 +194,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
   T* param_out_;
   MT* velocity_out_;
   MT* master_param_out_;
-  const RegularizationType regularization_flag_;
   const MT regularization_coeff_;
 
  public:
@@ -201,7 +201,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
                        const MultiPrecisionType<MT>* learning_rate,
                        const MT* master_param, const MT mu,
                        const MT rescale_grad, const int64_t num,
-                       const RegularizationType regularization_flag,
                        const MT regularization_coeff, T* param_out,
                        MT* velocity_out, MT* master_param_out)
       : param_(param),
......@@ -215,7 +214,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
param_out_(param_out),
velocity_out_(velocity_out),
master_param_out_(master_param_out),
regularization_flag_(regularization_flag),
regularization_coeff_(regularization_coeff) {}
inline HOSTDEVICE void operator()(size_t i) const {
// put memory access in register
@@ -225,9 +223,9 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
     const MT lr = static_cast<MT>(lr_[0]);
     const MT velocity = velocity_[i];
 
-    grad = regularization_flag_ == RegularizationType::kL2DECAY
-               ? grad + regularization_coeff_ * param
-               : grad;
+    if (kRegType == RegularizationType::kL2DECAY) {
+      grad += regularization_coeff_ * param;
+    }
 
     MT velocity_out = velocity * mu_ + grad;
     MT param_out = param - (grad + velocity_out * mu_) * lr;
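For reference, the per-element update the `UseNesterov` functor computes, read directly off the code above (with g the rescaled gradient, p the parameter, v the velocity, mu = `mu_`, lambda = `regularization_coeff_`):

```latex
g' = \begin{cases} g + \lambda p & \text{if } \mathit{kRegType} = \text{kL2DECAY} \\ g & \text{otherwise} \end{cases},
\qquad v' = \mu v + g', \qquad p' = p - (g' + \mu v') \cdot \mathit{lr}
```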
@@ -240,8 +238,8 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
   }
 };
 
-template <typename T, typename MT>
-class DenseMomentumFunctor<T, MT, NoNesterov> {
+template <typename T, typename MT, RegularizationType kRegType>
+class DenseMomentumFunctor<T, MT, kRegType, NoNesterov> {
  private:
   const T* param_;
   const T* grad_;
@@ -254,7 +252,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
   T* param_out_;
   MT* velocity_out_;
   MT* master_param_out_;
-  const RegularizationType regularization_flag_;
   const MT regularization_coeff_;
 
  public:
@@ -262,7 +259,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
                        const MultiPrecisionType<MT>* learning_rate,
                        const MT* master_param, const MT mu,
                        const MT rescale_grad, const int64_t num,
-                       const RegularizationType regularization_flag,
                        const MT regularization_coeff, T* param_out,
                        MT* velocity_out, MT* master_param_out)
       : param_(param),
@@ -276,7 +272,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
         param_out_(param_out),
         velocity_out_(velocity_out),
         master_param_out_(master_param_out),
-        regularization_flag_(regularization_flag),
         regularization_coeff_(regularization_coeff) {}
   inline HOSTDEVICE void operator()(size_t i) const {
     // put memory access in register
@@ -286,9 +281,9 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
     const MT lr = static_cast<MT>(lr_[0]);
     const MT velocity = velocity_[i];
 
-    grad = regularization_flag_ == RegularizationType::kL2DECAY
-               ? grad + regularization_coeff_ * param
-               : grad;
+    if (kRegType == RegularizationType::kL2DECAY) {
+      grad += regularization_coeff_ * param;
+    }
 
     MT velocity_out = velocity * mu_ + grad;
     MT param_out = param - lr * velocity_out;
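The `NoNesterov` functor shares the velocity update above; only the parameter step differs, using the plain momentum form:

```latex
v' = \mu v + g', \qquad p' = p - \mathit{lr} \cdot v'
```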
@@ -522,23 +517,31 @@ class MomentumOpKernel : public framework::OpKernel<T> {
       platform::ForRange<DeviceContext> for_range(
           static_cast<const DeviceContext&>(ctx.device_context()),
           param->numel());
-      if (use_nesterov) {
-        DenseMomentumFunctor<T, MT, UseNesterov> functor(
-            param->data<T>(), grad->data<T>(), velocity->data<MT>(),
-            learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
-            param->numel(), regularization_flag, regularization_coeff,
-            param_out->mutable_data<T>(ctx.GetPlace()),
-            velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
-        for_range(functor);
+#define PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(__nesterov, __reg_type)        \
+  DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor(             \
+      param->data<T>(), grad->data<T>(), velocity->data<MT>(),             \
+      learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,    \
+      param->numel(), regularization_coeff,                                \
+      param_out->mutable_data<T>(ctx.GetPlace()),                          \
+      velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);    \
+  for_range(functor);
+
+      if (use_nesterov) {
+        if (regularization_flag == RegularizationType::kL2DECAY) {
+          PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(UseNesterov,
+                                              RegularizationType::kL2DECAY);
+        } else {
+          PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(UseNesterov,
+                                              RegularizationType::kNONE);
+        }
       } else {
-        DenseMomentumFunctor<T, MT, NoNesterov> functor(
-            param->data<T>(), grad->data<T>(), velocity->data<MT>(),
-            learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
-            param->numel(), regularization_flag, regularization_coeff,
-            param_out->mutable_data<T>(ctx.GetPlace()),
-            velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
-        for_range(functor);
+        if (regularization_flag == RegularizationType::kL2DECAY) {
+          PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(NoNesterov,
+                                              RegularizationType::kL2DECAY);
+        } else {
+          PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(NoNesterov,
+                                              RegularizationType::kNONE);
+        }
       }
     }
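The macro exists so the long constructor argument list is written once instead of four times; the two nested `if`s select among the four compile-time specializations (Nesterov × regularization type). Expanded by hand for one branch, this is roughly what the compiler sees (a sketch of the expansion, reformatted, not actual preprocessor output):

```cpp
// PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(UseNesterov, RegularizationType::kL2DECAY)
DenseMomentumFunctor<T, MT, RegularizationType::kL2DECAY, UseNesterov> functor(
    param->data<T>(), grad->data<T>(), velocity->data<MT>(),
    learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
    param->numel(), regularization_coeff,
    param_out->mutable_data<T>(ctx.GetPlace()),
    velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
for_range(functor);
```

Each expansion declares its own `functor` inside its `if`/`else` arm, so the four instantiations stay in separate scopes. The remaining hunks update the Python unit test accordingly.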
@@ -102,7 +102,7 @@ def run_momentum_op(params,
             'Param': p,
             'Grad': g,
             'Velocity': v,
-            'LearningRate': lr_var
+            'LearningRate': lr_var,
         }
         outputs = {'ParamOut': p, 'VelocityOut': v}
         if multi_precision:
@@ -115,7 +115,7 @@ def run_momentum_op(params,
             'Param': param_vars,
             'Grad': grad_vars,
             'Velocity': velocity_vars,
-            'LearningRate': lr_var
+            'LearningRate': lr_var,
         }
         outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars}
         if multi_precision:
@@ -176,7 +176,10 @@ class TestMergedMomentum(unittest.TestCase):
         outs2 = run_op(False)
         self.assertEqual(len(outs1), len(outs2))
         for i, (out1, out2) in enumerate(zip(outs1, outs2)):
-            self.assertTrue(np.allclose(out1, out2, atol=1e-7))
+            if isinstance(place, paddle.CUDAPlace):
+                self.assertTrue(np.array_equal(out1, out2))
+            else:
+                self.assertTrue(np.allclose(out1, out2, atol=1e-7))
 
     def get_places(self):
         places = [paddle.CPUPlace()]