未验证 提交 e0f1c9f2 编写于 作者: D dongfangshenzhu 提交者: GitHub

[XPU] add merged_momentum unittest and change momentum (#45241)

* add merged_momentum *test=kunlun

* add merged_momentum *test=kunlun

* add fp16 to merged_momentum,*test=kunlun

* change dist_model.cc

* add merged_momentum unittest and  change momentum,test=kunlun

* add merged_momentum unittest and  change momentum,test=kunlun

* add merged_momentum unittest and  change momentum,test=kunlun

* add merged_momentum unittest and  change momentum,test=kunlun
上级 4e2a3c11
......@@ -41,6 +41,42 @@ class MergedMomentumOpXPUKernel : public framework::OpKernel<T> {
ctx.Attr<std::vector<std::string>>("regularization_method");
auto regularization_coeff =
ctx.Attr<std::vector<float>>("regularization_coeff");
PADDLE_ENFORCE_EQ(op_num,
params_out.size(),
platform::errors::InvalidArgument(
"The size of Output(ParamOut) must be equal to "
"Input(Param), but got the size of Output(ParamOut) "
"is %d, the size of Input(Param) is %d.",
params_out.size(),
op_num));
PADDLE_ENFORCE_EQ(op_num,
velocity.size(),
platform::errors::InvalidArgument(
"The size of Output(Velocity) must be equal to "
"Input(Param), but got the size of Output(Velocity) "
"is %d, the size of Input(Param) is %d.",
velocity.size(),
op_num));
PADDLE_ENFORCE_EQ(
op_num,
velocity_out.size(),
platform::errors::InvalidArgument(
"The size of Output(VelocityOut) must be equal to "
"Input(Param), but got the size of Output(VelocityOut) "
"is %d, the size of Input(Param) is %d.",
velocity_out.size(),
op_num));
PADDLE_ENFORCE_EQ(
op_num,
grad.size(),
platform::errors::InvalidArgument(
"The size of Input(Grad) must be equal to Input(Param), but got "
"the size of Input(Grad) is %d, the size of Input(Param) is %d.",
grad.size(),
op_num));
if (regularization_method.size() == 0) {
regularization_method.resize(op_num);
}
std::vector<XPUType*> param_list(op_num);
std::vector<XPUType*> velocity_list(op_num);
std::vector<XPUType*> grad_list(op_num);
......@@ -82,39 +118,6 @@ class MergedMomentumOpXPUKernel : public framework::OpKernel<T> {
return;
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
PADDLE_ENFORCE_EQ(op_num,
params_out.size(),
platform::errors::InvalidArgument(
"The size of Output(ParamOut) must be equal to "
"Input(Param), but got the size of Output(ParamOut) "
"is %d, the size of Input(Param) is %d.",
params_out.size(),
op_num));
PADDLE_ENFORCE_EQ(op_num,
velocity.size(),
platform::errors::InvalidArgument(
"The size of Output(Velocity) must be equal to "
"Input(Param), but got the size of Output(Velocity) "
"is %d, the size of Input(Param) is %d.",
velocity.size(),
op_num));
PADDLE_ENFORCE_EQ(
op_num,
velocity_out.size(),
platform::errors::InvalidArgument(
"The size of Output(VelocityOut) must be equal to "
"Input(Param), but got the size of Output(VelocityOut) "
"is %d, the size of Input(Param) is %d.",
velocity_out.size(),
op_num));
PADDLE_ENFORCE_EQ(
op_num,
grad.size(),
platform::errors::InvalidArgument(
"The size of Input(Grad) must be equal to Input(Param), but got "
"the size of Input(Grad) is %d, the size of Input(Param) is %d.",
grad.size(),
op_num));
int r = xpu::merged_momentum(dev_ctx.x_context(),
param_list,
velocity_list,
......
......@@ -21,6 +21,8 @@ namespace operators {
template <typename DeviceContext, typename T>
class MomentumOpXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
T mu = static_cast<T>(ctx.Attr<float>("mu"));
......@@ -33,15 +35,13 @@ class MomentumOpXPUKernel : public framework::OpKernel<T> {
auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
param_out->mutable_data<T>(ctx.GetPlace());
velocity_out->mutable_data<T>(ctx.GetPlace());
auto* lr = learning_rate->data<T>();
auto* lr = learning_rate->data<float>();
auto regularization_method = ctx.Attr<std::string>("regularization_method");
auto regularization_coeff = ctx.Attr<float>("regularization_coeff");
if (regularization_method != "l2_decay") {
// only support l2_decay
regularization_coeff = 0.0f;
}
auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(),
true,
......@@ -50,20 +50,18 @@ class MomentumOpXPUKernel : public framework::OpKernel<T> {
"MomentumOp-XPU. Excepted "
"LodTensor, But received [%s] and [%s]",
paddle::framework::ToTypeName(grad_var->Type())));
auto grad = ctx.Input<framework::Tensor>("Grad");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
// int momentum(Context* ctx, const T* param, const T* velocity, const T*
// grad, T* param_out, T* velocity_out, int len, const float* lr, int
// use_nesterov, float mu, float l2_weight_decay);
int r = xpu::momentum(dev_ctx.x_context(),
param->data<float>(),
velocity->data<float>(),
grad->data<float>(),
param_out->data<float>(),
velocity_out->data<float>(),
reinterpret_cast<const XPUType*>(param->data<T>()),
reinterpret_cast<const XPUType*>(velocity->data<T>()),
reinterpret_cast<const XPUType*>(grad->data<T>()),
reinterpret_cast<XPUType*>(param_out->data<T>()),
reinterpret_cast<XPUType*>(velocity_out->data<T>()),
param_out->numel(),
lr,
use_nesterov,
......@@ -78,5 +76,7 @@ class MomentumOpXPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
momentum,
ops::MomentumOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
ops::MomentumOpXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::MomentumOpXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
#endif
......@@ -358,7 +358,9 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::FP16, XPUPlace())})},
{"mish_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"mish", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"momentum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"momentum",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"mul",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
import paddle
import paddle.fluid.core as core
from op_test import OpTest
from op_test_xpu import XPUOpTest
from test_merged_momentum_op_xpu_base import TestMergedMomentumBase
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
class XPUTestMergedMomentumOP(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'merged_momentum'
self.use_dynamic_create_class = False
class TestMergedMomentumOp(TestMergedMomentumBase):
def setUp(self):
super().setUp()
self.set_case()
def set_case(self):
self.shapes = [[3, 4], [2, 7], [5, 6, 8]]
self.place = paddle.fluid.XPUPlace(0)
self.seed = 1
def testalltype(self):
self.check_with_place(self.place, self.in_type)
class TestMergedMomentum1(TestMergedMomentumOp):
def set_case(self):
self.shapes = [[3, 4], [2, 7], [5, 6, 8]]
class TestMergedMomentum2(TestMergedMomentumOp):
def set_case(self):
self.shapes = [[3, 4], [2, 7]]
class TestMergedMomentum3(TestMergedMomentumOp):
def set_case(self):
self.shapes = [[3, 4]]
class TestMergedMomentum4(TestMergedMomentumOp):
def set_case(self):
self.shapes = [[3, 4], [2, 7], [5, 6, 7], [9, 9], [10, 12]]
support_types = get_xpu_op_support_types('merged_momentum')
for stype in support_types:
create_test_class(globals(), XPUTestMergedMomentumOP, stype)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append('..')
import unittest
import paddle
import numpy as np
from paddle.fluid.layer_helper import LayerHelper
from collections import OrderedDict
def run_momentum_op(params,
grads,
velocitys,
master_params,
learning_rate,
place,
multi_precision,
mu=0.9,
rescale_grad=0.01,
use_merged=False,
use_nesterov=True):
assert len(params) == len(grads)
assert len(params) == len(velocitys)
if multi_precision:
assert len(params) == len(master_params)
op_type = 'merged_momentum' if use_merged else 'momentum'
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
helper = LayerHelper(op_type, **locals())
param_vars = [
helper.create_variable(persistable=True,
shape=p.shape,
dtype=p.dtype) for p in params
]
grad_vars = [
helper.create_variable(shape=g.shape, dtype=g.dtype) for g in grads
]
velocity_vars = [
helper.create_variable(persistable=True,
shape=v.shape,
dtype=v.dtype) for v in velocitys
]
lr_var = helper.create_variable(persistable=True,
shape=learning_rate.shape,
dtype=learning_rate.dtype)
feed_dict = OrderedDict()
feed_dict.update(
OrderedDict([(p_var.name, p_val)
for p_var, p_val in zip(param_vars, params)]))
feed_dict.update(
OrderedDict([(v_var.name, v_val)
for v_var, v_val in zip(velocity_vars, velocitys)]))
fetch_list = list(feed_dict.keys())
feed_dict.update(
OrderedDict([(g_var.name, g_val)
for g_var, g_val in zip(grad_vars, grads)]))
feed_dict.update({lr_var.name: learning_rate})
if multi_precision:
master_param_vars = [
helper.create_variable(persistable=True,
shape=p.shape,
dtype=p.dtype) for p in master_params
]
feed_dict.update(
OrderedDict([
(mp_var.name, mp_val)
for mp_var, mp_val in zip(master_param_vars, master_params)
]))
# CPUPlace does not use MasterParam
if isinstance(place, paddle.CUDAPlace):
fetch_list = fetch_list + [
mp_var.name for mp_var in master_param_vars
]
else:
master_param_vars = None
if not use_merged:
for i, (p, g,
v) in enumerate(zip(param_vars, grad_vars, velocity_vars)):
inputs = {
'Param': p,
'Grad': g,
'Velocity': v,
'LearningRate': lr_var,
}
outputs = {'ParamOut': p, 'VelocityOut': v}
if multi_precision:
inputs['MasterParam'] = master_param_vars[i]
outputs['MasterParamOut'] = master_param_vars[i]
attrs = {
'mu': mu,
'multi_precision': multi_precision,
'rescale_grad': rescale_grad,
'use_nesterov': use_nesterov,
'regularization_method': 'l2_decay',
'regularization_coeff': 2.0,
}
helper.append_op(type=op_type,
inputs=inputs,
outputs=outputs,
attrs=attrs)
else:
inputs = {
'Param': param_vars,
'Grad': grad_vars,
'Velocity': velocity_vars,
'LearningRate': lr_var,
}
outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars}
if multi_precision:
inputs['MasterParam'] = master_param_vars
outputs['MasterParamOut'] = master_param_vars
attrs = {
'mu':
mu,
'multi_precision':
multi_precision,
'rescale_grad':
rescale_grad,
'use_nesterov':
use_nesterov,
'regularization_method':
['l2_decay' for i in range(len(param_vars))],
'regularization_coeff': [2.0 for i in range(len(param_vars))],
}
helper.append_op(type=op_type,
inputs=inputs,
outputs=outputs,
attrs=attrs)
exe = paddle.static.Executor(place)
with paddle.static.scope_guard(paddle.static.Scope()):
exe.run(startup)
return exe.run(main, feed=feed_dict, fetch_list=fetch_list)
class TestMergedMomentumBase(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]]
self.seed = 10
self.place = paddle.fluid.XPUPlace(0)
self.__class__.use_xpu = True
def gen_rand_data(self, shapes, dtype):
return [np.random.random(s).astype(dtype) for s in shapes]
def prepare_data(self, shapes, multi_precision, seed, dtype, place):
np.random.seed(seed)
params = self.gen_rand_data(shapes, dtype)
grads = self.gen_rand_data(shapes, dtype)
velocitys = self.gen_rand_data(shapes, dtype)
learning_rate = self.gen_rand_data([[1]], np.float32)[0]
if multi_precision:
master_params = [p.astype(dtype) for p in params]
else:
master_params = None
return params, grads, velocitys, master_params, learning_rate
def check_with_place(self, place, dtype, multi_precision=False):
params, grads, velocitys, master_params, learning_rate = self.prepare_data(
self.shapes, multi_precision, self.seed, dtype, place)
def run_op(use_nesterov, use_merged):
# NPU Momentum Op does not support rescale_grad
rescale_grad = 1.0
return run_momentum_op(params,
grads,
velocitys,
master_params,
learning_rate,
place,
multi_precision,
rescale_grad=rescale_grad,
use_merged=use_merged,
use_nesterov=use_nesterov)
outs1 = run_op(use_nesterov=True, use_merged=True)
outs2 = run_op(use_nesterov=True, use_merged=False)
self.assertEqual(len(outs1), len(outs2))
for i, (out1, out2) in enumerate(zip(outs1, outs2)):
np.testing.assert_allclose(out1, out2, atol=1e-7)
outs3 = run_op(use_nesterov=False, use_merged=True)
outs4 = run_op(use_nesterov=False, use_merged=False)
self.assertEqual(len(outs3), len(outs4))
for j, (out3, out4) in enumerate(zip(outs3, outs4)):
np.testing.assert_allclose(out3, out4, atol=1e-7)
if __name__ == "__main__":
unittest.main()
......@@ -66,7 +66,6 @@ class XPUTestMomentumOP(XPUOpTestWrapper):
def set_case(self):
self.op_type = 'momentum'
self.dtype = self.in_type
self.init_config()
self.param = np.random.uniform(-1, 1,
......@@ -75,7 +74,6 @@ class XPUTestMomentumOP(XPUOpTestWrapper):
self.input_shape).astype(self.dtype)
self.velocity = np.random.uniform(-1, 1, self.input_shape).astype(
self.dtype)
param_out, velocity_out = calculate_momentum_by_numpy(
param=self.param,
grad=self.grad,
......@@ -85,6 +83,8 @@ class XPUTestMomentumOP(XPUOpTestWrapper):
learning_rate=self.learning_rate,
regularization_method=self.regularization_method,
regularization_coeff=self.regularization_coeff)
param_out = param_out.astype(self.dtype)
velocity_out = velocity_out.astype(self.dtype)
self.inputs = {
'Param': self.param,
'Grad': self.grad,
......@@ -101,14 +101,14 @@ class XPUTestMomentumOP(XPUOpTestWrapper):
self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
def init_dtype(self):
self.dtype = np.float32
self.dtype = self.in_type
def test_check_output(self):
self.check_output_with_place(self.place)
def init_config(self):
self.input_shape = [864]
self.learning_rate = np.array([0.001]).astype(self.dtype)
self.learning_rate = np.array([0.001]).astype(float)
self.mu = 0.0001
self.use_nesterov = False
self.regularization_method = None
......@@ -118,7 +118,7 @@ class XPUTestMomentumOP(XPUOpTestWrapper):
def init_config(self):
self.input_shape = [2, 768]
self.learning_rate = np.array([0.002]).astype(self.dtype)
self.learning_rate = np.array([0.002]).astype(float)
self.mu = 0.001
self.use_nesterov = False
self.regularization_method = None
......@@ -128,7 +128,7 @@ class XPUTestMomentumOP(XPUOpTestWrapper):
def init_config(self):
self.input_shape = [3, 8, 4096]
self.learning_rate = np.array([0.005]).astype(self.dtype)
self.learning_rate = np.array([0.005]).astype(float)
self.mu = 0.002
self.use_nesterov = True
self.regularization_method = None
......@@ -138,7 +138,7 @@ class XPUTestMomentumOP(XPUOpTestWrapper):
def init_config(self):
self.input_shape = [1024]
self.learning_rate = np.array([0.01]).astype(self.dtype)
self.learning_rate = np.array([0.01]).astype(float)
self.mu = 0.0001
self.use_nesterov = False
if self.xpu_version != core.XPUVersion.XPU1:
......@@ -153,7 +153,7 @@ class XPUTestMomentumOP(XPUOpTestWrapper):
def init_config(self):
self.input_shape = [2, 2, 255]
self.learning_rate = np.array([0.0005]).astype(self.dtype)
self.learning_rate = np.array([0.0005]).astype(float)
self.mu = 0.005
self.use_nesterov = True
if self.xpu_version != core.XPUVersion.XPU1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册