Unverified commit 1f7b2516 authored by fwenguang, committed by GitHub

[MLU] add merged_momentum mlu kernel (#40406)

Parent 5720537e
...@@ -11,7 +11,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/operators/controlflow/compare_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/mlu/mlu_baseop.h"
 namespace paddle {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {
template <typename T>
class MLUMergedMomentumOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
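// Collect the list-typed inputs and outputs: merged_momentum updates a whole
// group of parameters in a single operator call.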
auto params = ctx.MultiInput<framework::Tensor>("Param");
auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
size_t n = params.size();
PADDLE_ENFORCE_EQ(n, params_out.size(),
platform::errors::InvalidArgument(
"The size of Output(ParamOut) must be equal to "
"Input(Param), but got the size of Output(ParamOut) "
"is %d, the size of Input(Param) is %d.",
params_out.size(), n));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(params[i], params_out[i],
platform::errors::InvalidArgument(
"The size of Input(Param) and Output(ParamOut) "
"must be the same Tensors."));
}
auto grads = ctx.MultiInput<framework::Tensor>("Grad");
PADDLE_ENFORCE_EQ(
n, grads.size(),
platform::errors::InvalidArgument(
"The size of Input(Grad) must be equal to Input(Param), but got "
"the size of Input(Grad) is %d, the size of Input(Param) is %d.",
grads.size(), n));
auto velocitys = ctx.MultiInput<framework::Tensor>("Velocity");
PADDLE_ENFORCE_EQ(n, velocitys.size(),
platform::errors::InvalidArgument(
"The size of Input(Velocity) must be equal to "
"Input(Param), but got the size of Input(Velocity) "
"is %d, the size of Input(Param) is %d.",
velocitys.size(), n));
auto velocitys_out = ctx.MultiOutput<framework::Tensor>("VelocityOut");
PADDLE_ENFORCE_EQ(
n, velocitys_out.size(),
platform::errors::InvalidArgument(
"The size of Output(VelocityOut) must be "
"equal to Input(Param), but got the size of Output(VelocityOut) is "
"%d, the size of Input(Param) is %d.",
velocitys_out.size(), n));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(velocitys[i], velocitys_out[i],
platform::errors::InvalidArgument(
"Input(Velocity) and Output(VelocityOut) must be "
"the same Tensors."));
}
auto mu = ctx.Attr<float>("mu");
auto lrs = ctx.MultiInput<framework::Tensor>("LearningRate");
if (lrs.size() != 1) {
PADDLE_ENFORCE_EQ(
n, lrs.size(),
platform::errors::InvalidArgument(
"If the size of Input(LearningRate) is not 1, the size of "
"Input(LearningRate) must be "
"equal to Input(Param), but got the size of Input(LearningRate) "
"is %d, the size of Input(Param) is %d.",
lrs.size(), n));
}
auto use_nesterov = ctx.Attr<bool>("use_nesterov");
auto regularization_methods =
ctx.Attr<std::vector<std::string>>("regularization_method");
auto regularization_coeffs =
ctx.Attr<std::vector<float>>("regularization_coeff");
if (regularization_methods.size() != 0) {
PADDLE_ENFORCE_EQ(
n, regularization_methods.size(),
platform::errors::InvalidArgument(
"The size of Attr(regularization_method) must be equal "
"to Input(Param), but got the size of "
"Attr(regularization_method) is %d, the size of Input(Param) is "
"%d.",
regularization_methods.size(), n));
PADDLE_ENFORCE_EQ(
n, regularization_coeffs.size(),
platform::errors::InvalidArgument(
"The size of Attr(regularization_coeff) must be equal "
"to Input(Param), but got the size of Attr(regularization_coeff) "
"is %d, the size of Input(Param) is %d.",
regularization_coeffs.size(), n));
}
VLOG(5) << "use_nesterov: " << use_nesterov
<< ", regularization_methods.size(): "
<< regularization_methods.size()
<< ", regularization_coeffs.size(): "
<< regularization_coeffs.size();
auto& dev_ctx = ctx.template device_context<platform::MLUDeviceContext>();
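// The CNNL momentum primitive takes mu as a device tensor, so materialize the
// scalar attribute as a one-element temporary tensor on the MLU.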
Tensor mu_tensor = ctx.AllocateTmpTensor<T, MLUDeviceContext>({1}, dev_ctx);
MLUCnnlTensorDesc mu_tensor_desc(mu_tensor);
MLUCnnl::Fill(ctx, mu, mu_tensor_desc.get(), GetBasePtr(&mu_tensor));
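// Per-parameter update loop. Roughly, the momentum step computes
//   velocity_out = mu * velocity + grad
//   param_out    = param - lr * (grad + mu * velocity_out)   (Nesterov)
//   param_out    = param - lr * velocity_out                 (otherwise)
// the exact arithmetic is delegated to the CNNL momentum primitive.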
for (size_t idx = 0; idx < n; ++idx) {
RegularizationType regularization_flag =
regularization_methods.size() > 0 &&
regularization_methods[idx] == "l2_decay"
? RegularizationType::kL2DECAY
: RegularizationType::kNONE;
T regularization_coeff = static_cast<T>(0.0);
if (regularization_coeffs.size() != 0) {
regularization_coeff = static_cast<T>(regularization_coeffs[idx]);
}
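// A single LearningRate may be shared by all parameters, or one may be
// provided per parameter.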
auto learning_rate = lrs.size() > 1 ? lrs[idx] : lrs[0];
auto param_out = params_out[idx];
auto velocity_out = velocitys_out[idx];
auto grad = grads[idx];
Tensor regularized_grad;
MLUCnnlTensorDesc param_desc(*param_out);
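// With l2_decay the weight-decay term is folded into the gradient before the
// momentum update: the coefficient is passed as the scale of the Param operand
// of the CNNL add, so regularized_grad is effectively
// regularization_coeff * param + grad.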
if (regularization_flag == RegularizationType::kL2DECAY) {
regularized_grad = ctx.AllocateTmpTensor<T, MLUDeviceContext>(
param_out->dims(), dev_ctx);
MLUCnnlOpTensorDesc op_tensor_desc(
CNNL_OP_TENSOR_ADD, ToCnnlDataType<T>(), CNNL_NOT_PROPAGATE_NAN);
MLUCnnl::OpTensor(ctx, op_tensor_desc.get(), param_desc.get(),
GetBasePtr(param_out), param_desc.get(),
GetBasePtr(grad), param_desc.get(),
GetBasePtr(&regularized_grad), ToCnnlDataType<T>(),
regularization_coeff);
} else {
regularized_grad = *grad;
}
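// ApplyMomentum updates ParamOut and VelocityOut in place on the device,
// consuming the (possibly regularized) gradient, the learning rate tensor and
// the shared mu tensor.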
MLUCnnl::ApplyMomentum(ctx, param_desc.get(),
GetBasePtr(&regularized_grad), use_nesterov,
GetBasePtr(learning_rate), GetBasePtr(&mu_tensor),
GetBasePtr(param_out), GetBasePtr(velocity_out));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(merged_momentum, ops::MLUMergedMomentumOpKernel<float>,
ops::MLUMergedMomentumOpKernel<plat::float16>);
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append('..')
import unittest
import paddle
import numpy as np
from paddle.fluid.layer_helper import LayerHelper
from collections import OrderedDict
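# Builds a small static program that runs either one `momentum` op per
# parameter or a single `merged_momentum` op over all of them, feeds the given
# numpy arrays, and returns the updated Param/Velocity values so the two paths
# can be compared.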
def run_momentum_op(params,
grads,
velocitys,
master_params,
learning_rate,
place,
multi_precision,
mu=0.9,
rescale_grad=0.01,
use_merged=False):
assert len(params) == len(grads)
assert len(params) == len(velocitys)
if multi_precision:
assert len(params) == len(master_params)
op_type = 'merged_momentum' if use_merged else 'momentum'
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
helper = LayerHelper(op_type, **locals())
attrs = {
'mu': mu,
'multi_precision': multi_precision,
'rescale_grad': rescale_grad,
}
param_vars = [
helper.create_variable(
persistable=True, shape=p.shape, dtype=p.dtype) for p in params
]
grad_vars = [
helper.create_variable(
shape=g.shape, dtype=g.dtype) for g in grads
]
velocity_vars = [
helper.create_variable(
persistable=True, shape=v.shape, dtype=v.dtype)
for v in velocitys
]
lr_var = helper.create_variable(
persistable=True,
shape=learning_rate.shape,
dtype=learning_rate.dtype)
feed_dict = OrderedDict()
feed_dict.update(
OrderedDict([(p_var.name, p_val)
for p_var, p_val in zip(param_vars, params)]))
feed_dict.update(
OrderedDict([(v_var.name, v_val)
for v_var, v_val in zip(velocity_vars, velocitys)]))
fetch_list = list(feed_dict.keys())
feed_dict.update(
OrderedDict([(g_var.name, g_val)
for g_var, g_val in zip(grad_vars, grads)]))
feed_dict.update({lr_var.name: learning_rate})
if multi_precision:
master_param_vars = [
helper.create_variable(
persistable=True, shape=p.shape, dtype=p.dtype)
for p in master_params
]
feed_dict.update(
OrderedDict([(mp_var.name, mp_val)
for mp_var, mp_val in zip(master_param_vars,
master_params)]))
# CPUPlace does not use MasterParam
if isinstance(place, paddle.CUDAPlace):
fetch_list = fetch_list + [
mp_var.name for mp_var in master_param_vars
]
else:
master_param_vars = None
if not use_merged:
for i, (p, g,
v) in enumerate(zip(param_vars, grad_vars, velocity_vars)):
inputs = {
'Param': p,
'Grad': g,
'Velocity': v,
'LearningRate': lr_var,
}
outputs = {'ParamOut': p, 'VelocityOut': v}
if multi_precision:
inputs['MasterParam'] = master_param_vars[i]
outputs['MasterParamOut'] = master_param_vars[i]
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
else:
inputs = {
'Param': param_vars,
'Grad': grad_vars,
'Velocity': velocity_vars,
'LearningRate': lr_var,
}
outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars}
if multi_precision:
inputs['MasterParam'] = master_param_vars
outputs['MasterParamOut'] = master_param_vars
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
exe = paddle.static.Executor(place)
with paddle.static.scope_guard(paddle.static.Scope()):
exe.run(startup)
return exe.run(main, feed=feed_dict, fetch_list=fetch_list)
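# Same as run_momentum_op, but additionally exercises the use_nesterov flag
# and the l2_decay regularization attributes.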
def run_momentum_op2(params,
grads,
velocitys,
master_params,
learning_rate,
place,
multi_precision,
mu=0.9,
rescale_grad=0.01,
use_merged=False,
use_nesterov=True):
assert len(params) == len(grads)
assert len(params) == len(velocitys)
if multi_precision:
assert len(params) == len(master_params)
op_type = 'merged_momentum' if use_merged else 'momentum'
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
helper = LayerHelper(op_type, **locals())
param_vars = [
helper.create_variable(
persistable=True, shape=p.shape, dtype=p.dtype) for p in params
]
grad_vars = [
helper.create_variable(
shape=g.shape, dtype=g.dtype) for g in grads
]
velocity_vars = [
helper.create_variable(
persistable=True, shape=v.shape, dtype=v.dtype)
for v in velocitys
]
lr_var = helper.create_variable(
persistable=True,
shape=learning_rate.shape,
dtype=learning_rate.dtype)
feed_dict = OrderedDict()
feed_dict.update(
OrderedDict([(p_var.name, p_val)
for p_var, p_val in zip(param_vars, params)]))
feed_dict.update(
OrderedDict([(v_var.name, v_val)
for v_var, v_val in zip(velocity_vars, velocitys)]))
fetch_list = list(feed_dict.keys())
feed_dict.update(
OrderedDict([(g_var.name, g_val)
for g_var, g_val in zip(grad_vars, grads)]))
feed_dict.update({lr_var.name: learning_rate})
if multi_precision:
master_param_vars = [
helper.create_variable(
persistable=True, shape=p.shape, dtype=p.dtype)
for p in master_params
]
feed_dict.update(
OrderedDict([(mp_var.name, mp_val)
for mp_var, mp_val in zip(master_param_vars,
master_params)]))
# CPUPlace does not use MasterParam
if isinstance(place, paddle.CUDAPlace):
fetch_list = fetch_list + [
mp_var.name for mp_var in master_param_vars
]
else:
master_param_vars = None
if not use_merged:
for i, (p, g,
v) in enumerate(zip(param_vars, grad_vars, velocity_vars)):
inputs = {
'Param': p,
'Grad': g,
'Velocity': v,
'LearningRate': lr_var,
}
outputs = {'ParamOut': p, 'VelocityOut': v}
if multi_precision:
inputs['MasterParam'] = master_param_vars[i]
outputs['MasterParamOut'] = master_param_vars[i]
attrs = {
'mu': mu,
'multi_precision': multi_precision,
'rescale_grad': rescale_grad,
'use_nesterov': use_nesterov,
'regularization_method': 'l2_decay',
'regularization_coeff': 2.0,
}
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
else:
inputs = {
'Param': param_vars,
'Grad': grad_vars,
'Velocity': velocity_vars,
'LearningRate': lr_var,
}
outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars}
if multi_precision:
inputs['MasterParam'] = master_param_vars
outputs['MasterParamOut'] = master_param_vars
attrs = {
'mu': mu,
'multi_precision': multi_precision,
'rescale_grad': rescale_grad,
'use_nesterov': use_nesterov,
'regularization_method':
['l2_decay' for i in range(len(param_vars))],
'regularization_coeff': [2.0 for i in range(len(param_vars))],
}
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
exe = paddle.static.Executor(place)
with paddle.static.scope_guard(paddle.static.Scope()):
exe.run(startup)
return exe.run(main, feed=feed_dict, fetch_list=fetch_list)
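# For readers, a minimal NumPy sketch of one momentum step under the semantics
# the tests above rely on (l2_decay folded into the gradient, Nesterov applied
# on top of the updated velocity). This helper is illustrative only and is not
# used by the test cases below.
def numpy_momentum_step(param, grad, velocity, lr, mu=0.9,
                        use_nesterov=False, regularization_coeff=0.0):
    grad = grad + regularization_coeff * param
    velocity_out = mu * velocity + grad
    if use_nesterov:
        param_out = param - lr * (grad + mu * velocity_out)
    else:
        param_out = param - lr * velocity_out
    return param_out, velocity_out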
class TestMergedMomentum(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]]
self.seed = 10
self.place = paddle.device.MLUPlace(0)
self.__class__.use_mlu = True
def gen_rand_data(self, shapes, dtype):
return [np.random.random(s).astype(dtype) for s in shapes]
def prepare_data(self, shapes, multi_precision, seed, place):
np.random.seed(seed)
mp_dtype = np.float32
dtype = np.float32
params = self.gen_rand_data(shapes, dtype)
grads = self.gen_rand_data(shapes, dtype)
velocitys = self.gen_rand_data(shapes, mp_dtype)
learning_rate = self.gen_rand_data([[1]], mp_dtype)[0]
if multi_precision:
master_params = [p.astype(mp_dtype) for p in params]
else:
master_params = None
return params, grads, velocitys, master_params, learning_rate
def check_with_place(self, place, multi_precision):
params, grads, velocitys, master_params, learning_rate = self.prepare_data(
self.shapes, multi_precision, self.seed, place)
def run_op(use_merged):
# MLU Momentum Op does not support rescale_grad
rescale_grad = 1.0
return run_momentum_op(
params,
grads,
velocitys,
master_params,
learning_rate,
place,
multi_precision,
rescale_grad=rescale_grad,
use_merged=use_merged)
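# Running the merged op over all parameters should match running the plain
# momentum op parameter-by-parameter.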
outs1 = run_op(True)
outs2 = run_op(False)
self.assertEqual(len(outs1), len(outs2))
for i, (out1, out2) in enumerate(zip(outs1, outs2)):
self.assertTrue(np.allclose(out1, out2, atol=1e-7))
def test_main(self):
self.check_with_place(self.place, multi_precision=False)
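# Same as TestMergedMomentum, but checks both use_nesterov=True and
# use_nesterov=False (with l2_decay regularization set in run_momentum_op2).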
class TestMergedMomentum2(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]]
self.seed = 10
self.place = paddle.device.MLUPlace(0)
self.__class__.use_mlu = True
def gen_rand_data(self, shapes, dtype):
return [np.random.random(s).astype(dtype) for s in shapes]
def prepare_data(self, shapes, multi_precision, seed, place):
np.random.seed(seed)
mp_dtype = np.float32
dtype = np.float32 # np.float16
params = self.gen_rand_data(shapes, dtype)
grads = self.gen_rand_data(shapes, dtype)
velocitys = self.gen_rand_data(shapes, mp_dtype)
learning_rate = self.gen_rand_data([[1]], mp_dtype)[0]
if multi_precision:
master_params = [p.astype(mp_dtype) for p in params]
else:
master_params = None
return params, grads, velocitys, master_params, learning_rate
def check_with_place(self, place, multi_precision):
params, grads, velocitys, master_params, learning_rate = self.prepare_data(
self.shapes, multi_precision, self.seed, place)
def run_op(use_nesterov, use_merged):
# MLU Momentum Op does not support rescale_grad
rescale_grad = 1.0
return run_momentum_op2(
params,
grads,
velocitys,
master_params,
learning_rate,
place,
multi_precision,
rescale_grad=rescale_grad,
use_merged=use_merged,
use_nesterov=use_nesterov)
outs1 = run_op(use_nesterov=True, use_merged=True)
outs2 = run_op(use_nesterov=True, use_merged=False)
self.assertEqual(len(outs1), len(outs2))
for i, (out1, out2) in enumerate(zip(outs1, outs2)):
self.assertTrue(np.allclose(out1, out2, atol=1e-7))
outs3 = run_op(use_nesterov=False, use_merged=True)
outs4 = run_op(use_nesterov=False, use_merged=False)
self.assertEqual(len(outs3), len(outs4))
for j, (out3, out4) in enumerate(zip(outs3, outs4)):
self.assertTrue(np.allclose(out3, out4, atol=1e-7))
def test_main(self):
self.check_with_place(self.place, multi_precision=False)
if __name__ == "__main__":
unittest.main()