未验证 提交 4dda18a8 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix momentum ops (#36452)

上级 8566cc98
...@@ -173,14 +173,15 @@ class CPUDenseMomentumFunctor { ...@@ -173,14 +173,15 @@ class CPUDenseMomentumFunctor {
} }
}; };
template <typename T, typename MT, typename UpdateMethod> template <typename T, typename MT, RegularizationType kRegType,
typename UpdateMethod>
class DenseMomentumFunctor; class DenseMomentumFunctor;
// NOTE(dzh) for performance. // NOTE(dzh) for performance.
// avoid if/else inside the kernel; implement GPU UseNesterov/NoNesterov as two // avoid if/else inside the kernel; implement GPU UseNesterov/NoNesterov as two
// functors. // functors.
template <typename T, typename MT> template <typename T, typename MT, RegularizationType kRegType>
class DenseMomentumFunctor<T, MT, UseNesterov> { class DenseMomentumFunctor<T, MT, kRegType, UseNesterov> {
private: private:
const T* param_; const T* param_;
const T* grad_; const T* grad_;
...@@ -193,7 +194,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> { ...@@ -193,7 +194,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
T* param_out_; T* param_out_;
MT* velocity_out_; MT* velocity_out_;
MT* master_param_out_; MT* master_param_out_;
const RegularizationType regularization_flag_;
const MT regularization_coeff_; const MT regularization_coeff_;
public: public:
...@@ -201,7 +201,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> { ...@@ -201,7 +201,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
const MultiPrecisionType<MT>* learning_rate, const MultiPrecisionType<MT>* learning_rate,
const MT* master_param, const MT mu, const MT* master_param, const MT mu,
const MT rescale_grad, const int64_t num, const MT rescale_grad, const int64_t num,
const RegularizationType regularization_flag,
const MT regularization_coeff, T* param_out, const MT regularization_coeff, T* param_out,
MT* velocity_out, MT* master_param_out) MT* velocity_out, MT* master_param_out)
: param_(param), : param_(param),
...@@ -215,7 +214,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> { ...@@ -215,7 +214,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
param_out_(param_out), param_out_(param_out),
velocity_out_(velocity_out), velocity_out_(velocity_out),
master_param_out_(master_param_out), master_param_out_(master_param_out),
regularization_flag_(regularization_flag),
regularization_coeff_(regularization_coeff) {} regularization_coeff_(regularization_coeff) {}
inline HOSTDEVICE void operator()(size_t i) const { inline HOSTDEVICE void operator()(size_t i) const {
// put memory access in register // put memory access in register
...@@ -225,9 +223,9 @@ class DenseMomentumFunctor<T, MT, UseNesterov> { ...@@ -225,9 +223,9 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
const MT lr = static_cast<MT>(lr_[0]); const MT lr = static_cast<MT>(lr_[0]);
const MT velocity = velocity_[i]; const MT velocity = velocity_[i];
grad = regularization_flag_ == RegularizationType::kL2DECAY if (kRegType == RegularizationType::kL2DECAY) {
? grad + regularization_coeff_ * param grad += regularization_coeff_ * param;
: grad; }
MT velocity_out = velocity * mu_ + grad; MT velocity_out = velocity * mu_ + grad;
MT param_out = param - (grad + velocity_out * mu_) * lr; MT param_out = param - (grad + velocity_out * mu_) * lr;
...@@ -240,8 +238,8 @@ class DenseMomentumFunctor<T, MT, UseNesterov> { ...@@ -240,8 +238,8 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
} }
}; };
template <typename T, typename MT> template <typename T, typename MT, RegularizationType kRegType>
class DenseMomentumFunctor<T, MT, NoNesterov> { class DenseMomentumFunctor<T, MT, kRegType, NoNesterov> {
private: private:
const T* param_; const T* param_;
const T* grad_; const T* grad_;
...@@ -254,7 +252,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> { ...@@ -254,7 +252,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
T* param_out_; T* param_out_;
MT* velocity_out_; MT* velocity_out_;
MT* master_param_out_; MT* master_param_out_;
const RegularizationType regularization_flag_;
const MT regularization_coeff_; const MT regularization_coeff_;
public: public:
...@@ -262,7 +259,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> { ...@@ -262,7 +259,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
const MultiPrecisionType<MT>* learning_rate, const MultiPrecisionType<MT>* learning_rate,
const MT* master_param, const MT mu, const MT* master_param, const MT mu,
const MT rescale_grad, const int64_t num, const MT rescale_grad, const int64_t num,
const RegularizationType regularization_flag,
const MT regularization_coeff, T* param_out, const MT regularization_coeff, T* param_out,
MT* velocity_out, MT* master_param_out) MT* velocity_out, MT* master_param_out)
: param_(param), : param_(param),
...@@ -276,7 +272,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> { ...@@ -276,7 +272,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
param_out_(param_out), param_out_(param_out),
velocity_out_(velocity_out), velocity_out_(velocity_out),
master_param_out_(master_param_out), master_param_out_(master_param_out),
regularization_flag_(regularization_flag),
regularization_coeff_(regularization_coeff) {} regularization_coeff_(regularization_coeff) {}
inline HOSTDEVICE void operator()(size_t i) const { inline HOSTDEVICE void operator()(size_t i) const {
// put memory access in register // put memory access in register
...@@ -286,9 +281,9 @@ class DenseMomentumFunctor<T, MT, NoNesterov> { ...@@ -286,9 +281,9 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
const MT lr = static_cast<MT>(lr_[0]); const MT lr = static_cast<MT>(lr_[0]);
const MT velocity = velocity_[i]; const MT velocity = velocity_[i];
grad = regularization_flag_ == RegularizationType::kL2DECAY if (kRegType == RegularizationType::kL2DECAY) {
? grad + regularization_coeff_ * param grad += regularization_coeff_ * param;
: grad; }
MT velocity_out = velocity * mu_ + grad; MT velocity_out = velocity * mu_ + grad;
MT param_out = param - lr * velocity_out; MT param_out = param - lr * velocity_out;
...@@ -522,23 +517,31 @@ class MomentumOpKernel : public framework::OpKernel<T> { ...@@ -522,23 +517,31 @@ class MomentumOpKernel : public framework::OpKernel<T> {
platform::ForRange<DeviceContext> for_range( platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()), static_cast<const DeviceContext&>(ctx.device_context()),
param->numel()); param->numel());
if (use_nesterov) { #define PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(__nesterov, __reg_type) \
DenseMomentumFunctor<T, MT, UseNesterov> functor( DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor( \
param->data<T>(), grad->data<T>(), velocity->data<MT>(), param->data<T>(), grad->data<T>(), velocity->data<MT>(), \
learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad, learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad, \
param->numel(), regularization_flag, regularization_coeff, param->numel(), regularization_coeff, \
param_out->mutable_data<T>(ctx.GetPlace()), param_out->mutable_data<T>(ctx.GetPlace()), \
velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data); velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data); \
for_range(functor); for_range(functor);
if (use_nesterov) {
if (regularization_flag == RegularizationType::kL2DECAY) {
PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(UseNesterov,
RegularizationType::kL2DECAY);
} else {
PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(UseNesterov,
RegularizationType::kNONE);
}
} else { } else {
DenseMomentumFunctor<T, MT, NoNesterov> functor( if (regularization_flag == RegularizationType::kL2DECAY) {
param->data<T>(), grad->data<T>(), velocity->data<MT>(), PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(NoNesterov,
learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad, RegularizationType::kL2DECAY);
param->numel(), regularization_flag, regularization_coeff, } else {
param_out->mutable_data<T>(ctx.GetPlace()), PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(NoNesterov,
velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data); RegularizationType::kNONE);
for_range(functor); }
} }
} }
......
...@@ -102,7 +102,7 @@ def run_momentum_op(params, ...@@ -102,7 +102,7 @@ def run_momentum_op(params,
'Param': p, 'Param': p,
'Grad': g, 'Grad': g,
'Velocity': v, 'Velocity': v,
'LearningRate': lr_var 'LearningRate': lr_var,
} }
outputs = {'ParamOut': p, 'VelocityOut': v} outputs = {'ParamOut': p, 'VelocityOut': v}
if multi_precision: if multi_precision:
...@@ -115,7 +115,7 @@ def run_momentum_op(params, ...@@ -115,7 +115,7 @@ def run_momentum_op(params,
'Param': param_vars, 'Param': param_vars,
'Grad': grad_vars, 'Grad': grad_vars,
'Velocity': velocity_vars, 'Velocity': velocity_vars,
'LearningRate': lr_var 'LearningRate': lr_var,
} }
outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars} outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars}
if multi_precision: if multi_precision:
...@@ -176,7 +176,10 @@ class TestMergedMomentum(unittest.TestCase): ...@@ -176,7 +176,10 @@ class TestMergedMomentum(unittest.TestCase):
outs2 = run_op(False) outs2 = run_op(False)
self.assertEqual(len(outs1), len(outs2)) self.assertEqual(len(outs1), len(outs2))
for i, (out1, out2) in enumerate(zip(outs1, outs2)): for i, (out1, out2) in enumerate(zip(outs1, outs2)):
self.assertTrue(np.allclose(out1, out2, atol=1e-7)) if isinstance(place, paddle.CUDAPlace):
self.assertTrue(np.array_equal(out1, out2))
else:
self.assertTrue(np.allclose(out1, out2, atol=1e-7))
def get_places(self): def get_places(self):
places = [paddle.CPUPlace()] places = [paddle.CPUPlace()]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册