Unverified Commit c8ffdecb authored by zhangbo9674 and committed by GitHub

[opt] Add regularization and Nesterov for merged_momentum op (#37527)

* add regularization and Nesterov for merged_momentum

* refine unittest for use_nesterov attr

* refine op check

* refine code

* fix bug

* refine code of regularization_flag

* delete useless code
Parent 5440d2f9
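For context, the update rule this PR fuses is Paddle's standard momentum step with optional L2 decay and an optional Nesterov correction. Below is a minimal NumPy sketch of that rule for a single parameter; it assumes the same semantics as Paddle's single-tensor momentum op, and the function name is illustrative only.

import numpy as np

def momentum_update(param, grad, velocity, lr, mu=0.9, rescale_grad=1.0,
                    use_nesterov=False, regularization_method='',
                    regularization_coeff=0.0):
    # Rescale the raw gradient first, then apply L2 decay if requested.
    g = grad * rescale_grad
    if regularization_method == 'l2_decay':
        g = g + regularization_coeff * param
    # Velocity accumulates the (regularized) gradient.
    velocity_out = mu * velocity + g
    if use_nesterov:
        # Nesterov looks one step ahead along the updated velocity.
        param_out = param - (g + mu * velocity_out) * lr
    else:
        param_out = param - lr * velocity_out
    return param_out, velocity_out

The merged op applies this same rule to a whole list of parameters in one launch instead of running one momentum op per parameter.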
......@@ -50,7 +50,8 @@ class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDuplicable();
AddInput("LearningRate",
"(Tensor, default Tensor<float>) "
"Input learning rate");
"Input learning rate")
.AsDuplicable();
AddInput("MasterParam", "FP32 master weight for AMP.")
.AsDispensable()
.AsDuplicable();
......@@ -68,6 +69,18 @@ class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDispensable()
.AsDuplicable();
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<bool>("use_nesterov",
"(bool, default false) "
"Use Nesterov Momentum or not.")
.SetDefault(false);
AddAttr<std::vector<std::string>>(
"regularization_method",
"(string) regularization_method, right now only "
"support l2decay or none")
.SetDefault({});
AddAttr<std::vector<float>>("regularization_coeff",
"(float) regularization_coeff")
.SetDefault({});
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
......
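Making LearningRate duplicable and adding vector-valued regularization attrs changes the op's size contract: the learning-rate list holds either one shared tensor or one tensor per parameter, and the regularization lists are either empty or have exactly one entry per parameter. A small Python sketch of that contract, under the assumption that it mirrors the kernel checks below (function name hypothetical):

def check_merged_momentum_sizes(n_params, n_lrs, n_reg_methods, n_reg_coeffs):
    # One shared learning rate, or one learning rate per parameter.
    assert n_lrs == 1 or n_lrs == n_params
    # Regularization attrs are either omitted entirely or given per parameter.
    assert n_reg_methods in (0, n_params)
    assert n_reg_coeffs in (0, n_params)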
......@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/momentum_op.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/macros.h"
......@@ -85,33 +86,43 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
auto params = ctx.MultiInput<framework::Tensor>("Param");
auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
size_t n = params.size();
PADDLE_ENFORCE_EQ(
n, params_out.size(),
platform::errors::InvalidArgument(
"Output(ParamOut) number must be equal to Input(Param) number."));
PADDLE_ENFORCE_EQ(n, params_out.size(),
platform::errors::InvalidArgument(
"The size of Output(ParamOut) must be equal to "
"Input(Param), but got the size of Output(ParamOut) "
"is %d, the size of Input(Param) is %d.",
params_out.size(), n));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(
params[i], params_out[i],
platform::errors::InvalidArgument(
"Input(Param) and Output(ParamOut) must be the same Tensors."));
PADDLE_ENFORCE_EQ(params[i], params_out[i],
platform::errors::InvalidArgument(
"The Input(Param) and Output(ParamOut) must be "
"the same Tensors."));
}
auto grads = ctx.MultiInput<framework::Tensor>("Grad");
PADDLE_ENFORCE_EQ(
n, grads.size(),
platform::errors::InvalidArgument(
"Input(Grad) number must be equal to Input(Param) number."));
"The size of Input(Grad) must be equal to Input(Param), but got "
"the size of Input(Grad) is %d, the size of Input(Param) is %d.",
grads.size(), n));
auto velocitys = ctx.MultiInput<framework::Tensor>("Velocity");
PADDLE_ENFORCE_EQ(n, velocitys.size(),
platform::errors::InvalidArgument(
"Input(Velocity) number and Input(Param) number."));
"The size of Input(Velocity) must be equal to "
"Input(Param), but got the size of Input(Velocity) "
"is %d, the size of Input(Param) is %d.",
velocitys.size(), n));
auto velocitys_out = ctx.MultiOutput<framework::Tensor>("VelocityOut");
PADDLE_ENFORCE_EQ(
n, velocitys_out.size(),
platform::errors::InvalidArgument("Output(VelocityOut) number must be "
"equal to Input(Param) number."));
platform::errors::InvalidArgument(
"The size of Output(VelocityOut) must be "
"equal to Input(Param), but got the size of Output(VelocityOut) is "
"%d, the size of Input(Param) is %d.",
velocitys_out.size(), n));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(velocitys[i], velocitys_out[i],
platform::errors::InvalidArgument(
......@@ -126,12 +137,18 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
if (multi_precision) {
PADDLE_ENFORCE_EQ(
n, master_params.size(),
platform::errors::InvalidArgument("Input(MasterParam) number must be "
"equal to Input(Param) number."));
PADDLE_ENFORCE_EQ(n, master_params_out.size(),
platform::errors::InvalidArgument(
"Output(MasterParamOut) number must be equal to "
"Input(MasterParam) number."));
platform::errors::InvalidArgument(
"The size of Input(MasterParam) must be "
"equal to Input(Param), but got the size of Input(MasterParam) "
"is %d, the size of Input(Param) is %d.",
master_params.size(), n));
PADDLE_ENFORCE_EQ(
n, master_params_out.size(),
platform::errors::InvalidArgument(
"The size of Output(MasterParamOut) must be equal to "
"Input(Param), but got the size of Output(MasterParamOut) "
"is %d, the size of Input(Param) is %d.",
master_params_out.size(), n));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(master_params[i], master_params_out[i],
platform::errors::InvalidArgument(
......@@ -147,20 +164,61 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
master_params_out.clear();
}
auto lr = ctx.Input<framework::Tensor>("LearningRate");
auto mu = ctx.Attr<float>("mu");
auto rescale_grad = ctx.Attr<float>("rescale_grad");
auto lrs = ctx.MultiInput<framework::Tensor>("LearningRate");
if (lrs.size() != 1) {
PADDLE_ENFORCE_EQ(
n, lrs.size(),
platform::errors::InvalidArgument(
"If the size of Input(LearningRate) is not 1, the size of "
"Input(LearningRate) must be "
"equal to Input(Param), but got the size of Input(LearningRate) "
"is %d, the size of Input(Param) is %d.",
lrs.size(), n));
}
auto use_nesterov = ctx.Attr<bool>("use_nesterov");
auto regularization_methods =
ctx.Attr<std::vector<std::string>>("regularization_method");
auto regularization_coeffs =
ctx.Attr<std::vector<float>>("regularization_coeff");
if (regularization_methods.size() != 0) {
PADDLE_ENFORCE_EQ(
n, regularization_methods.size(),
platform::errors::InvalidArgument(
"The size of Attr(regularization_method) must be equal "
"to Input(Param), but got the size of "
"Attr(regularization_method) is %d, the size of Input(Param) is "
"%d.",
regularization_methods.size(), n));
PADDLE_ENFORCE_EQ(
n, regularization_coeffs.size(),
platform::errors::InvalidArgument(
"The size of Attr(regularization_coeff) must be equal "
"to Input(Param), but got the size of Attr(regularization_coeff) "
"is %d, the size of Input(Param) is %d.",
regularization_coeffs.size(), n));
}
VLOG(5) << "use_nesterov: " << use_nesterov
<< ", regularization_methods.size(): "
<< regularization_methods.size()
<< ", regularization_coeffs.size(): "
<< regularization_coeffs.size();
using MPType = typename operators::details::MPTypeTrait<T>::Type;
auto &dev_ctx = ctx.template device_context<DeviceContext>();
if (lrs.size() == 1 && use_nesterov == false &&
regularization_methods.size() == 0) {
#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \
MergedMomentumKernelParam<T, MPType, kMultiPrecision> kernel_params; \
constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \
size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum; \
kernel_params.mu = static_cast<MPType>(mu); \
kernel_params.rescale_grad = static_cast<MPType>(rescale_grad); \
kernel_params.lr = lr->data<MPType>(); \
kernel_params.lr = lrs[0]->data<MPType>(); \
for (size_t i = 0; i < kernel_num; ++i) { \
size_t start = i * kMaxMergedNum; \
size_t end = std::min((i + 1) * kMaxMergedNum, n); \
......@@ -182,14 +240,78 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
VLOG(10) << "Launch MergedMomentum kernel " << i << " " \
<< kernel_params.param_num; \
}
if (multi_precision) {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
} else {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false);
}
#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL
} else {
for (size_t idx = 0; idx < n; idx++) {
RegularizationType regularization_flag =
regularization_methods.size() > 0 &&
regularization_methods[idx] == "l2_decay"
? RegularizationType::kL2DECAY
: RegularizationType::kNONE;
MPType regularization_coeff = static_cast<MPType>(0.0);
if (regularization_coeffs.size() != 0) {
regularization_coeff =
static_cast<MPType>(regularization_coeffs[idx]);
}
auto lr_temp = lrs.size() > 1 ? lrs[idx] : lrs[0];
const MPType *master_in_data =
multi_precision ? master_params[idx]->data<MPType>() : nullptr;
MPType *master_out_data =
multi_precision ? master_params_out[idx]->data<MPType>() : nullptr;
if (platform::is_cpu_place(ctx.GetPlace())) {
CPUDenseMomentumFunctor<MPType> functor;
functor(params[idx], grads[idx], velocitys[idx], lr_temp, mu,
use_nesterov, regularization_flag, regularization_coeff,
params_out[idx], velocitys_out[idx]);
VLOG(10) << "Launch MergedMomentum cpu kernel.";
} else if (platform::is_gpu_place(ctx.GetPlace())) {
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext &>(ctx.device_context()),
params[idx]->numel());
#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \
DenseMomentumFunctor<T, MPType, __reg_type, __nesterov> functor( \
params[idx]->data<T>(), grads[idx]->data<T>(), \
velocitys[idx]->data<MPType>(), lr_temp->data<MPType>(), master_in_data, \
mu, rescale_grad, params[idx]->numel(), regularization_coeff, \
params_out[idx]->data<T>(), velocitys_out[idx]->data<MPType>(), \
master_out_data); \
for_range(functor);
if (use_nesterov) {
if (regularization_flag == RegularizationType::kL2DECAY) {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
UseNesterov, RegularizationType::kL2DECAY);
VLOG(10)
<< "Launch MergedMomentum gpu kernel use_nesterov kL2DECAY.";
} else {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(UseNesterov,
RegularizationType::kNONE);
VLOG(10)
<< "Launch MergedMomentum gpu kernel use_nesterov kNONE.";
}
} else {
if (regularization_flag == RegularizationType::kL2DECAY) {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
NoNesterov, RegularizationType::kL2DECAY);
VLOG(10)
<< "Launch MergedMomentum gpu kernel no_nesterov kL2DECAY.";
} else {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(NoNesterov,
RegularizationType::kNONE);
VLOG(10) << "Launch MergedMomentum gpu kernel no_nesterov kNONE.";
}
}
}
}
VLOG(10)
<< "Launch MergedMomentum kernel with multi_lr and regularization.";
}
}
};
......
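The kernel now dispatches between two paths: the pre-existing fused launch is kept for the common case (a single shared learning rate, no Nesterov, no regularization), and every other configuration falls back to a per-parameter loop over the existing momentum functors. A sketch of that decision, with illustrative names:

def pick_kernel_path(n_lrs, use_nesterov, n_reg_methods):
    # Fast path: one fused kernel updates many parameters per launch.
    if n_lrs == 1 and not use_nesterov and n_reg_methods == 0:
        return 'fused merged kernel'
    # Fallback: one DenseMomentumFunctor (GPU) or CPUDenseMomentumFunctor
    # (CPU) per parameter, so LR and regularization can differ per tensor.
    return 'per-parameter functors'

Even on the fallback path the per-parameter work happens inside a single op call, so the framework's per-op dispatch overhead stays amortized.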
......@@ -827,7 +827,7 @@ GetVarBaseListFromArgs(const std::string& op_type, const std::string& arg_name,
bool dispensable = false) {
PyObject* list = PyTuple_GET_ITEM(args, arg_idx);
if (list == nullptr) {
if (list == nullptr || list == Py_None) {
if (!dispensable) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensor, but got "
......
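The GetVarBaseListFromArgs change makes an explicit Python None behave like an omitted argument for dispensable list inputs (e.g. MasterParam when multi_precision is off). A rough Python rendering of the new check (simplified; the real code operates on CPython objects):

def get_list_arg(value, dispensable):
    # After this change, an explicit None is treated like a missing argument
    # instead of raising for dispensable inputs.
    if value is None:  # covers both 'not passed' and an explicit None
        if not dispensable:
            raise ValueError("argument must be a list of Tensor")
        return []
    return value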
......@@ -130,6 +130,130 @@ def run_momentum_op(params,
return exe.run(main, feed=feed_dict, fetch_list=fetch_list)
def run_momentum_op2(params,
grads,
velocitys,
master_params,
learning_rate,
place,
multi_precision,
mu=0.9,
rescale_grad=0.01,
use_merged=False,
use_nesterov=True):
assert len(params) == len(grads)
assert len(params) == len(velocitys)
if multi_precision:
assert len(params) == len(master_params)
op_type = 'merged_momentum' if use_merged else 'momentum'
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
helper = LayerHelper(op_type, **locals())
param_vars = [
helper.create_variable(
persistable=True, shape=p.shape, dtype=p.dtype) for p in params
]
grad_vars = [
helper.create_variable(
shape=g.shape, dtype=g.dtype) for g in grads
]
velocity_vars = [
helper.create_variable(
persistable=True, shape=v.shape, dtype=v.dtype)
for v in velocitys
]
lr_var = helper.create_variable(
persistable=True,
shape=learning_rate.shape,
dtype=learning_rate.dtype)
feed_dict = OrderedDict()
feed_dict.update(
OrderedDict([(p_var.name, p_val)
for p_var, p_val in zip(param_vars, params)]))
feed_dict.update(
OrderedDict([(v_var.name, v_val)
for v_var, v_val in zip(velocity_vars, velocitys)]))
fetch_list = list(feed_dict.keys())
feed_dict.update(
OrderedDict([(g_var.name, g_val)
for g_var, g_val in zip(grad_vars, grads)]))
feed_dict.update({lr_var.name: learning_rate})
if multi_precision:
master_param_vars = [
helper.create_variable(
persistable=True, shape=p.shape, dtype=p.dtype)
for p in master_params
]
feed_dict.update(
OrderedDict([(mp_var.name, mp_val)
for mp_var, mp_val in zip(master_param_vars,
master_params)]))
# CPUPlace does not use MasterParam
if isinstance(place, paddle.CUDAPlace):
fetch_list = fetch_list + [
mp_var.name for mp_var in master_param_vars
]
else:
master_param_vars = None
if not use_merged:
for i, (p, g,
v) in enumerate(zip(param_vars, grad_vars, velocity_vars)):
inputs = {
'Param': p,
'Grad': g,
'Velocity': v,
'LearningRate': lr_var,
}
outputs = {'ParamOut': p, 'VelocityOut': v}
if multi_precision:
inputs['MasterParam'] = master_param_vars[i]
outputs['MasterParamOut'] = master_param_vars[i]
attrs = {
'mu': mu,
'multi_precision': multi_precision,
'rescale_grad': rescale_grad,
'use_nesterov': use_nesterov,
'regularization_method': 'l2_decay',
'regularization_coeff': 2.0,
}
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
else:
inputs = {
'Param': param_vars,
'Grad': grad_vars,
'Velocity': velocity_vars,
'LearningRate': lr_var,
}
outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars}
if multi_precision:
inputs['MasterParam'] = master_param_vars
outputs['MasterParamOut'] = master_param_vars
attrs = {
'mu': mu,
'multi_precision': multi_precision,
'rescale_grad': rescale_grad,
'use_nesterov': use_nesterov,
'regularization_method':
['l2_decay' for i in range(len(param_vars))],
'regularization_coeff': [2.0 for i in range(len(param_vars))],
}
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
exe = paddle.static.Executor(place)
with paddle.static.scope_guard(paddle.static.Scope()):
exe.run(startup)
return exe.run(main, feed=feed_dict, fetch_list=fetch_list)
class TestMergedMomentum(unittest.TestCase):
def setUp(self):
paddle.enable_static()
......@@ -193,5 +317,78 @@ class TestMergedMomentum(unittest.TestCase):
self.check_with_place(place, multi_precision)
class TestMergedMomentum2(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]]
self.seed = 10
def gen_rand_data(self, shapes, dtype):
return [np.random.random(s).astype(dtype) for s in shapes]
def prepare_data(self, shapes, multi_precision, seed, place):
np.random.seed(seed)
mp_dtype = np.float32
dtype = np.float16 if multi_precision and isinstance(
place, paddle.CUDAPlace) else np.float32
params = self.gen_rand_data(shapes, dtype)
grads = self.gen_rand_data(shapes, dtype)
velocitys = self.gen_rand_data(shapes, mp_dtype)
learning_rate = self.gen_rand_data([[1]], mp_dtype)[0]
if multi_precision:
master_params = [p.astype(mp_dtype) for p in params]
else:
master_params = None
return params, grads, velocitys, master_params, learning_rate
def check_with_place(self, place, multi_precision):
params, grads, velocitys, master_params, learning_rate = self.prepare_data(
self.shapes, multi_precision, self.seed, place)
def run_op(use_nesterov, use_merged):
# FIXME(zengjinle): CPU Momentum Op does not support rescale_grad
rescale_grad = 1.0 if isinstance(place, paddle.CPUPlace) else 0.01
return run_momentum_op2(
params,
grads,
velocitys,
master_params,
learning_rate,
place,
multi_precision,
rescale_grad=rescale_grad,
use_merged=use_merged,
use_nesterov=use_nesterov)
outs1 = run_op(use_nesterov=True, use_merged=True)
outs2 = run_op(use_nesterov=True, use_merged=False)
self.assertEqual(len(outs1), len(outs2))
for i, (out1, out2) in enumerate(zip(outs1, outs2)):
if isinstance(place, paddle.CUDAPlace):
self.assertTrue(np.array_equal(out1, out2))
else:
self.assertTrue(np.allclose(out1, out2, atol=1e-7))
outs3 = run_op(use_nesterov=False, use_merged=True)
outs4 = run_op(use_nesterov=False, use_merged=False)
self.assertEqual(len(outs3), len(outs4))
for j, (out3, out4) in enumerate(zip(outs3, outs4)):
if isinstance(place, paddle.CUDAPlace):
self.assertTrue(np.array_equal(out3, out4))
else:
self.assertTrue(np.allclose(out3, out4, atol=1e-7))
def get_places(self):
places = [paddle.CPUPlace()]
if paddle.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
return places
def test_main(self):
for multi_precision in [False, True]:
for place in self.get_places():
self.check_with_place(place, multi_precision)
if __name__ == "__main__":
unittest.main()