Unverified commit 1f7b2516 authored by F fwenguang, committed by GitHub

[MLU] add merged_momentum mlu kernel (#40406)

Parent 5720537e
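For reference, the kernel added below applies the same per-parameter update as the existing momentum op, which the new unit test compares it against. A minimal NumPy sketch of that update (illustrative only, not part of the patch; the function and variable names are hypothetical):

import numpy as np

def momentum_update(param, grad, velocity, lr, mu=0.9,
                    regularization_coeff=0.0, use_nesterov=False):
    # Optional l2_decay regularization is folded into the gradient first.
    grad = grad + regularization_coeff * param
    # Standard momentum / Nesterov momentum update.
    velocity_out = mu * velocity + grad
    if use_nesterov:
        param_out = param - lr * (grad + mu * velocity_out)
    else:
        param_out = param - lr * velocity_out
    return param_out, velocity_out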
@@ -11,7 +11,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/controlflow/compare_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {

template <typename T>
class MLUMergedMomentumOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto params = ctx.MultiInput<framework::Tensor>("Param");
    auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
    size_t n = params.size();
    PADDLE_ENFORCE_EQ(n, params_out.size(),
                      platform::errors::InvalidArgument(
                          "The size of Output(ParamOut) must be equal to "
                          "Input(Param), but got the size of Output(ParamOut) "
                          "is %d, the size of Input(Param) is %d.",
                          params_out.size(), n));
    for (size_t i = 0; i < n; ++i) {
      PADDLE_ENFORCE_EQ(params[i], params_out[i],
                        platform::errors::InvalidArgument(
                            "Input(Param) and Output(ParamOut) must be "
                            "the same Tensors."));
    }
    auto grads = ctx.MultiInput<framework::Tensor>("Grad");
    PADDLE_ENFORCE_EQ(
        n, grads.size(),
        platform::errors::InvalidArgument(
            "The size of Input(Grad) must be equal to Input(Param), but got "
            "the size of Input(Grad) is %d, the size of Input(Param) is %d.",
            grads.size(), n));

    auto velocitys = ctx.MultiInput<framework::Tensor>("Velocity");
    PADDLE_ENFORCE_EQ(n, velocitys.size(),
                      platform::errors::InvalidArgument(
                          "The size of Input(Velocity) must be equal to "
                          "Input(Param), but got the size of Input(Velocity) "
                          "is %d, the size of Input(Param) is %d.",
                          velocitys.size(), n));

    auto velocitys_out = ctx.MultiOutput<framework::Tensor>("VelocityOut");
    PADDLE_ENFORCE_EQ(
        n, velocitys_out.size(),
        platform::errors::InvalidArgument(
            "The size of Output(VelocityOut) must be "
            "equal to Input(Param), but got the size of Output(VelocityOut) is "
            "%d, the size of Input(Param) is %d.",
            velocitys_out.size(), n));
    for (size_t i = 0; i < n; ++i) {
      PADDLE_ENFORCE_EQ(velocitys[i], velocitys_out[i],
                        platform::errors::InvalidArgument(
                            "Input(Velocity) and Output(VelocityOut) must be "
                            "the same Tensors."));
    }

    auto mu = ctx.Attr<float>("mu");
    auto lrs = ctx.MultiInput<framework::Tensor>("LearningRate");
    if (lrs.size() != 1) {
      PADDLE_ENFORCE_EQ(
          n, lrs.size(),
          platform::errors::InvalidArgument(
              "If the size of Input(LearningRate) is not 1, the size of "
              "Input(LearningRate) must be "
              "equal to Input(Param), but got the size of Input(LearningRate) "
              "is %d, the size of Input(Param) is %d.",
              lrs.size(), n));
    }
    auto use_nesterov = ctx.Attr<bool>("use_nesterov");
    auto regularization_methods =
        ctx.Attr<std::vector<std::string>>("regularization_method");
    auto regularization_coeffs =
        ctx.Attr<std::vector<float>>("regularization_coeff");
    if (regularization_methods.size() != 0) {
      PADDLE_ENFORCE_EQ(
          n, regularization_methods.size(),
          platform::errors::InvalidArgument(
              "The size of Attr(regularization_method) must be equal "
              "to Input(Param), but got the size of "
              "Attr(regularization_method) is %d, the size of Input(Param) is "
              "%d.",
              regularization_methods.size(), n));
      PADDLE_ENFORCE_EQ(
          n, regularization_coeffs.size(),
          platform::errors::InvalidArgument(
              "The size of Attr(regularization_coeff) must be equal "
              "to Input(Param), but got the size of Attr(regularization_coeff) "
              "is %d, the size of Input(Param) is %d.",
              regularization_coeffs.size(), n));
    }

    VLOG(5) << "use_nesterov: " << use_nesterov
            << ", regularization_methods.size(): "
            << regularization_methods.size()
            << ", regularization_coeffs.size(): "
            << regularization_coeffs.size();

    auto& dev_ctx = ctx.template device_context<platform::MLUDeviceContext>();

    // ApplyMomentum takes mu as a device pointer, so materialize the scalar
    // attribute on the MLU once and reuse it for every parameter.
    Tensor mu_tensor = ctx.AllocateTmpTensor<T, MLUDeviceContext>({1}, dev_ctx);
    MLUCnnlTensorDesc mu_tensor_desc(mu_tensor);
    MLUCnnl::Fill(ctx, mu, mu_tensor_desc.get(), GetBasePtr(&mu_tensor));
    for (size_t idx = 0; idx < n; ++idx) {
      RegularizationType regularization_flag =
          regularization_methods.size() > 0 &&
                  regularization_methods[idx] == "l2_decay"
              ? RegularizationType::kL2DECAY
              : RegularizationType::kNONE;
      T regularization_coeff = static_cast<T>(0.0);
      if (regularization_coeffs.size() != 0) {
        regularization_coeff = static_cast<T>(regularization_coeffs[idx]);
      }

      auto learning_rate = lrs.size() > 1 ? lrs[idx] : lrs[0];
      auto param_out = params_out[idx];
      auto velocity_out = velocitys_out[idx];
      auto grad = grads[idx];

      Tensor regularized_grad;
      MLUCnnlTensorDesc param_desc(*param_out);
      if (regularization_flag == RegularizationType::kL2DECAY) {
        // Fold L2 decay into the gradient:
        //   regularized_grad = grad + regularization_coeff * param
        regularized_grad = ctx.AllocateTmpTensor<T, MLUDeviceContext>(
            param_out->dims(), dev_ctx);
        MLUCnnlOpTensorDesc op_tensor_desc(
            CNNL_OP_TENSOR_ADD, ToCnnlDataType<T>(), CNNL_NOT_PROPAGATE_NAN);
        MLUCnnl::OpTensor(ctx, op_tensor_desc.get(), param_desc.get(),
                          GetBasePtr(param_out), param_desc.get(),
                          GetBasePtr(grad), param_desc.get(),
                          GetBasePtr(&regularized_grad), ToCnnlDataType<T>(),
                          regularization_coeff);
      } else {
        regularized_grad = *grad;
      }
      // Param and Velocity are updated in place (ParamOut/VelocityOut alias
      // the inputs, as enforced above).
      MLUCnnl::ApplyMomentum(ctx, param_desc.get(),
                             GetBasePtr(&regularized_grad), use_nesterov,
                             GetBasePtr(learning_rate), GetBasePtr(&mu_tensor),
                             GetBasePtr(param_out), GetBasePtr(velocity_out));
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_MLU_KERNEL(merged_momentum, ops::MLUMergedMomentumOpKernel<float>,
                       ops::MLUMergedMomentumOpKernel<plat::float16>);
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append('..')
import unittest
import paddle
import numpy as np
from paddle.fluid.layer_helper import LayerHelper
from collections import OrderedDict
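

# The helpers below build equivalent static programs that update the same
# parameters either with a single merged_momentum op or with one momentum op
# per parameter, so the tests can compare their outputs on MLU.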
def run_momentum_op(params,
                    grads,
                    velocitys,
                    master_params,
                    learning_rate,
                    place,
                    multi_precision,
                    mu=0.9,
                    rescale_grad=0.01,
                    use_merged=False):
    assert len(params) == len(grads)
    assert len(params) == len(velocitys)
    if multi_precision:
        assert len(params) == len(master_params)
    op_type = 'merged_momentum' if use_merged else 'momentum'
    main = paddle.static.Program()
    startup = paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        helper = LayerHelper(op_type, **locals())
        attrs = {
            'mu': mu,
            'multi_precision': multi_precision,
            'rescale_grad': rescale_grad,
        }

        param_vars = [
            helper.create_variable(
                persistable=True, shape=p.shape, dtype=p.dtype) for p in params
        ]
        grad_vars = [
            helper.create_variable(
                shape=g.shape, dtype=g.dtype) for g in grads
        ]
        velocity_vars = [
            helper.create_variable(
                persistable=True, shape=v.shape, dtype=v.dtype)
            for v in velocitys
        ]
        lr_var = helper.create_variable(
            persistable=True,
            shape=learning_rate.shape,
            dtype=learning_rate.dtype)

        feed_dict = OrderedDict()

        feed_dict.update(
            OrderedDict([(p_var.name, p_val)
                         for p_var, p_val in zip(param_vars, params)]))
        feed_dict.update(
            OrderedDict([(v_var.name, v_val)
                         for v_var, v_val in zip(velocity_vars, velocitys)]))
        # Param and Velocity are updated in place, so fetching these names
        # returns ParamOut and VelocityOut after the run.
        fetch_list = list(feed_dict.keys())

        feed_dict.update(
            OrderedDict([(g_var.name, g_val)
                         for g_var, g_val in zip(grad_vars, grads)]))
        feed_dict.update({lr_var.name: learning_rate})

        if multi_precision:
            master_param_vars = [
                helper.create_variable(
                    persistable=True, shape=p.shape, dtype=p.dtype)
                for p in master_params
            ]
            feed_dict.update(
                OrderedDict([(mp_var.name, mp_val)
                             for mp_var, mp_val in zip(master_param_vars,
                                                       master_params)]))
            # CPUPlace does not use MasterParam
            if isinstance(place, paddle.CUDAPlace):
                fetch_list = fetch_list + [
                    mp_var.name for mp_var in master_param_vars
                ]
        else:
            master_param_vars = None

        if not use_merged:
            for i, (p, g,
                    v) in enumerate(zip(param_vars, grad_vars, velocity_vars)):
                inputs = {
                    'Param': p,
                    'Grad': g,
                    'Velocity': v,
                    'LearningRate': lr_var,
                }
                outputs = {'ParamOut': p, 'VelocityOut': v}
                if multi_precision:
                    inputs['MasterParam'] = master_param_vars[i]
                    outputs['MasterParamOut'] = master_param_vars[i]
                helper.append_op(
                    type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
        else:
            inputs = {
                'Param': param_vars,
                'Grad': grad_vars,
                'Velocity': velocity_vars,
                'LearningRate': lr_var,
            }
            outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars}
            if multi_precision:
                inputs['MasterParam'] = master_param_vars
                outputs['MasterParamOut'] = master_param_vars
            helper.append_op(
                type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)

    exe = paddle.static.Executor(place)
    with paddle.static.scope_guard(paddle.static.Scope()):
        exe.run(startup)
        return exe.run(main, feed=feed_dict, fetch_list=fetch_list)
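

# Same as run_momentum_op, but also exercises use_nesterov and the l2_decay
# regularization attributes supported by the MLU kernel.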
def run_momentum_op2(params,
                     grads,
                     velocitys,
                     master_params,
                     learning_rate,
                     place,
                     multi_precision,
                     mu=0.9,
                     rescale_grad=0.01,
                     use_merged=False,
                     use_nesterov=True):
    assert len(params) == len(grads)
    assert len(params) == len(velocitys)
    if multi_precision:
        assert len(params) == len(master_params)
    op_type = 'merged_momentum' if use_merged else 'momentum'
    main = paddle.static.Program()
    startup = paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        helper = LayerHelper(op_type, **locals())

        param_vars = [
            helper.create_variable(
                persistable=True, shape=p.shape, dtype=p.dtype) for p in params
        ]
        grad_vars = [
            helper.create_variable(
                shape=g.shape, dtype=g.dtype) for g in grads
        ]
        velocity_vars = [
            helper.create_variable(
                persistable=True, shape=v.shape, dtype=v.dtype)
            for v in velocitys
        ]
        lr_var = helper.create_variable(
            persistable=True,
            shape=learning_rate.shape,
            dtype=learning_rate.dtype)

        feed_dict = OrderedDict()

        feed_dict.update(
            OrderedDict([(p_var.name, p_val)
                         for p_var, p_val in zip(param_vars, params)]))
        feed_dict.update(
            OrderedDict([(v_var.name, v_val)
                         for v_var, v_val in zip(velocity_vars, velocitys)]))
        fetch_list = list(feed_dict.keys())

        feed_dict.update(
            OrderedDict([(g_var.name, g_val)
                         for g_var, g_val in zip(grad_vars, grads)]))
        feed_dict.update({lr_var.name: learning_rate})

        if multi_precision:
            master_param_vars = [
                helper.create_variable(
                    persistable=True, shape=p.shape, dtype=p.dtype)
                for p in master_params
            ]
            feed_dict.update(
                OrderedDict([(mp_var.name, mp_val)
                             for mp_var, mp_val in zip(master_param_vars,
                                                       master_params)]))
            # CPUPlace does not use MasterParam
            if isinstance(place, paddle.CUDAPlace):
                fetch_list = fetch_list + [
                    mp_var.name for mp_var in master_param_vars
                ]
        else:
            master_param_vars = None

        if not use_merged:
            for i, (p, g,
                    v) in enumerate(zip(param_vars, grad_vars, velocity_vars)):
                inputs = {
                    'Param': p,
                    'Grad': g,
                    'Velocity': v,
                    'LearningRate': lr_var,
                }
                outputs = {'ParamOut': p, 'VelocityOut': v}
                if multi_precision:
                    inputs['MasterParam'] = master_param_vars[i]
                    outputs['MasterParamOut'] = master_param_vars[i]
                attrs = {
                    'mu': mu,
                    'multi_precision': multi_precision,
                    'rescale_grad': rescale_grad,
                    'use_nesterov': use_nesterov,
                    'regularization_method': 'l2_decay',
                    'regularization_coeff': 2.0,
                }
                helper.append_op(
                    type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
        else:
            inputs = {
                'Param': param_vars,
                'Grad': grad_vars,
                'Velocity': velocity_vars,
                'LearningRate': lr_var,
            }
            outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars}
            if multi_precision:
                inputs['MasterParam'] = master_param_vars
                outputs['MasterParamOut'] = master_param_vars
            attrs = {
                'mu': mu,
                'multi_precision': multi_precision,
                'rescale_grad': rescale_grad,
                'use_nesterov': use_nesterov,
                'regularization_method':
                ['l2_decay' for i in range(len(param_vars))],
                'regularization_coeff': [2.0 for i in range(len(param_vars))],
            }
            helper.append_op(
                type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)

    exe = paddle.static.Executor(place)
    with paddle.static.scope_guard(paddle.static.Scope()):
        exe.run(startup)
        return exe.run(main, feed=feed_dict, fetch_list=fetch_list)


class TestMergedMomentum(unittest.TestCase):
    def setUp(self):
        paddle.enable_static()
        self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]]
        self.seed = 10
        self.place = paddle.device.MLUPlace(0)
        self.__class__.use_mlu = True

    def gen_rand_data(self, shapes, dtype):
        return [np.random.random(s).astype(dtype) for s in shapes]

    def prepare_data(self, shapes, multi_precision, seed, place):
        np.random.seed(seed)
        mp_dtype = np.float32
        dtype = np.float32
        params = self.gen_rand_data(shapes, dtype)
        grads = self.gen_rand_data(shapes, dtype)
        velocitys = self.gen_rand_data(shapes, mp_dtype)
        learning_rate = self.gen_rand_data([[1]], mp_dtype)[0]
        if multi_precision:
            master_params = [p.astype(mp_dtype) for p in params]
        else:
            master_params = None
        return params, grads, velocitys, master_params, learning_rate

    def check_with_place(self, place, multi_precision):
        params, grads, velocitys, master_params, learning_rate = self.prepare_data(
            self.shapes, multi_precision, self.seed, place)

        def run_op(use_merged):
            # MLU Momentum Op does not support rescale_grad
            rescale_grad = 1.0
            return run_momentum_op(
                params,
                grads,
                velocitys,
                master_params,
                learning_rate,
                place,
                multi_precision,
                rescale_grad=rescale_grad,
                use_merged=use_merged)

        outs1 = run_op(True)
        outs2 = run_op(False)
        self.assertEqual(len(outs1), len(outs2))
        for i, (out1, out2) in enumerate(zip(outs1, outs2)):
            self.assertTrue(np.allclose(out1, out2, atol=1e-7))

    def test_main(self):
        self.check_with_place(self.place, multi_precision=False)


class TestMergedMomentum2(unittest.TestCase):
    def setUp(self):
        paddle.enable_static()
        self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]]
        self.seed = 10
        self.place = paddle.device.MLUPlace(0)
        self.__class__.use_mlu = True

    def gen_rand_data(self, shapes, dtype):
        return [np.random.random(s).astype(dtype) for s in shapes]

    def prepare_data(self, shapes, multi_precision, seed, place):
        np.random.seed(seed)
        mp_dtype = np.float32
        dtype = np.float32  # np.float16
        params = self.gen_rand_data(shapes, dtype)
        grads = self.gen_rand_data(shapes, dtype)
        velocitys = self.gen_rand_data(shapes, mp_dtype)
        learning_rate = self.gen_rand_data([[1]], mp_dtype)[0]
        if multi_precision:
            master_params = [p.astype(mp_dtype) for p in params]
        else:
            master_params = None
        return params, grads, velocitys, master_params, learning_rate

    def check_with_place(self, place, multi_precision):
        params, grads, velocitys, master_params, learning_rate = self.prepare_data(
            self.shapes, multi_precision, self.seed, place)

        def run_op(use_nesterov, use_merged):
            # MLU Momentum Op does not support rescale_grad
            rescale_grad = 1.0
            return run_momentum_op2(
                params,
                grads,
                velocitys,
                master_params,
                learning_rate,
                place,
                multi_precision,
                rescale_grad=rescale_grad,
                use_merged=use_merged,
                use_nesterov=use_nesterov)

        outs1 = run_op(use_nesterov=True, use_merged=True)
        outs2 = run_op(use_nesterov=True, use_merged=False)
        self.assertEqual(len(outs1), len(outs2))
        for i, (out1, out2) in enumerate(zip(outs1, outs2)):
            self.assertTrue(np.allclose(out1, out2, atol=1e-7))

        outs3 = run_op(use_nesterov=False, use_merged=True)
        outs4 = run_op(use_nesterov=False, use_merged=False)
        self.assertEqual(len(outs3), len(outs4))
        for j, (out3, out4) in enumerate(zip(outs3, outs4)):
            self.assertTrue(np.allclose(out3, out4, atol=1e-7))

    def test_main(self):
        self.check_with_place(self.place, multi_precision=False)


if __name__ == "__main__":
    unittest.main()