diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op_xpu.cc b/paddle/fluid/operators/optimizers/merged_momentum_op_xpu.cc
index 3993a46add4805f4fb57d786f4c07bd7997f604e..5ba1f8b98fae8d8af5a5658a12e3e76a73ce497b 100644
--- a/paddle/fluid/operators/optimizers/merged_momentum_op_xpu.cc
+++ b/paddle/fluid/operators/optimizers/merged_momentum_op_xpu.cc
@@ -41,6 +41,42 @@ class MergedMomentumOpXPUKernel : public framework::OpKernel<T> {
         ctx.Attr<std::vector<std::string>>("regularization_method");
     auto regularization_coeff =
         ctx.Attr<std::vector<float>>("regularization_coeff");
+    PADDLE_ENFORCE_EQ(op_num,
+                      params_out.size(),
+                      platform::errors::InvalidArgument(
+                          "The size of Output(ParamOut) must be equal to "
+                          "Input(Param), but got the size of Output(ParamOut) "
+                          "is %d, the size of Input(Param) is %d.",
+                          params_out.size(),
+                          op_num));
+    PADDLE_ENFORCE_EQ(op_num,
+                      velocity.size(),
+                      platform::errors::InvalidArgument(
+                          "The size of Output(Velocity) must be equal to "
+                          "Input(Param), but got the size of Output(Velocity) "
+                          "is %d, the size of Input(Param) is %d.",
+                          velocity.size(),
+                          op_num));
+    PADDLE_ENFORCE_EQ(
+        op_num,
+        velocity_out.size(),
+        platform::errors::InvalidArgument(
+            "The size of Output(VelocityOut) must be equal to "
+            "Input(Param), but got the size of Output(VelocityOut) "
+            "is %d, the size of Input(Param) is %d.",
+            velocity_out.size(),
+            op_num));
+    PADDLE_ENFORCE_EQ(
+        op_num,
+        grad.size(),
+        platform::errors::InvalidArgument(
+            "The size of Input(Grad) must be equal to Input(Param), but got "
+            "the size of Input(Grad) is %d, the size of Input(Param) is %d.",
+            grad.size(),
+            op_num));
+    if (regularization_method.size() == 0) {
+      regularization_method.resize(op_num);
+    }
     std::vector<T*> param_list(op_num);
     std::vector<T*> velocity_list(op_num);
     std::vector<T*> grad_list(op_num);
@@ -82,39 +118,6 @@ class MergedMomentumOpXPUKernel : public framework::OpKernel<T> {
       return;
     }
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    PADDLE_ENFORCE_EQ(op_num,
-                      params_out.size(),
-                      platform::errors::InvalidArgument(
-                          "The size of Output(ParamOut) must be equal to "
-                          "Input(Param), but got the size of Output(ParamOut) "
-                          "is %d, the size of Input(Param) is %d.",
-                          params_out.size(),
-                          op_num));
-    PADDLE_ENFORCE_EQ(op_num,
-                      velocity.size(),
-                      platform::errors::InvalidArgument(
-                          "The size of Output(Velocity) must be equal to "
-                          "Input(Param), but got the size of Output(Velocity) "
-                          "is %d, the size of Input(Param) is %d.",
-                          velocity.size(),
-                          op_num));
-    PADDLE_ENFORCE_EQ(
-        op_num,
-        velocity_out.size(),
-        platform::errors::InvalidArgument(
-            "The size of Output(VelocityOut) must be equal to "
-            "Input(Param), but got the size of Output(VelocityOut) "
-            "is %d, the size of Input(Param) is %d.",
-            velocity_out.size(),
-            op_num));
-    PADDLE_ENFORCE_EQ(
-        op_num,
-        grad.size(),
-        platform::errors::InvalidArgument(
-            "The size of Input(Grad) must be equal to Input(Param), but got "
-            "the size of Input(Grad) is %d, the size of Input(Param) is %d.",
-            grad.size(),
-            op_num));
     int r = xpu::merged_momentum(dev_ctx.x_context(),
                                  param_list,
                                  velocity_list,
diff --git a/paddle/fluid/operators/optimizers/momentum_op_xpu.cc b/paddle/fluid/operators/optimizers/momentum_op_xpu.cc
index fb19b4b7678a6bfdf2a7907fb94ffa3bffd415ba..bd62c7acaa802bd63c8754d61636987870b45abe 100644
--- a/paddle/fluid/operators/optimizers/momentum_op_xpu.cc
+++ b/paddle/fluid/operators/optimizers/momentum_op_xpu.cc
@@ -21,6 +21,8 @@ namespace operators {
 
 template <typename DeviceContext, typename T>
 class MomentumOpXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     T mu = static_cast<T>(ctx.Attr<float>("mu"));
@@ -33,15 +35,13 @@ class MomentumOpXPUKernel : public framework::OpKernel<T> {
     auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
     param_out->mutable_data<T>(ctx.GetPlace());
     velocity_out->mutable_data<T>(ctx.GetPlace());
-    auto* lr = learning_rate->data<T>();
-
+    auto* lr = learning_rate->data<float>();
     auto regularization_method = ctx.Attr<std::string>("regularization_method");
     auto regularization_coeff = ctx.Attr<float>("regularization_coeff");
     if (regularization_method != "l2_decay") {
       // only support l2_decay
       regularization_coeff = 0.0f;
     }
-
     auto* grad_var = ctx.InputVar("Grad");
     PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(),
                       true,
@@ -50,20 +50,18 @@ class MomentumOpXPUKernel : public framework::OpKernel<T> {
                           "MomentumOp-XPU. Excepted "
                           "LodTensor, But received [%s] and [%s]",
                           paddle::framework::ToTypeName(grad_var->Type())));
-
     auto grad = ctx.Input<framework::Tensor>("Grad");
-
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     // int momentum(Context* ctx, const T* param, const T* velocity, const T*
     // grad, T* param_out, T* velocity_out, int len, const float* lr, int
     // use_nesterov, float mu, float l2_weight_decay);
     int r = xpu::momentum(dev_ctx.x_context(),
-                          param->data<T>(),
-                          velocity->data<T>(),
-                          grad->data<T>(),
-                          param_out->data<T>(),
-                          velocity_out->data<T>(),
+                          reinterpret_cast<const XPUType*>(param->data<T>()),
+                          reinterpret_cast<const XPUType*>(velocity->data<T>()),
+                          reinterpret_cast<const XPUType*>(grad->data<T>()),
+                          reinterpret_cast<XPUType*>(param_out->data<T>()),
+                          reinterpret_cast<XPUType*>(velocity_out->data<T>()),
                           param_out->numel(),
                           lr,
                           use_nesterov,
@@ -78,5 +76,7 @@ class MomentumOpXPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 REGISTER_OP_XPU_KERNEL(
     momentum,
-    ops::MomentumOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    ops::MomentumOpXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::MomentumOpXPUKernel<paddle::platform::XPUDeviceContext,
+                             paddle::platform::float16>);
 #endif
diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
index 68bab4980a682cb8c97bf4aa9004944f891a6452..c246924e14b69ace3d824c557cc5df6dcf967b88 100644
--- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -358,7 +358,9 @@ XPUOpMap& get_kl2_ops() {
                     pOpKernelType(vartype::FP16, XPUPlace())})},
     {"mish_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"mish", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-    {"momentum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"momentum",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                   pOpKernelType(vartype::FP16, XPUPlace())})},
     {"mul",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                    pOpKernelType(vartype::FP16, XPUPlace())})},
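Note: with the op-list change above, the XPU test framework should now report float16 alongside float32 for "momentum". A minimal sketch (not part of this patch) of how to check that from Python, assuming it is run from python/paddle/fluid/tests/unittests/xpu and uses the same helper the new tests below rely on:

    import sys

    sys.path.append("..")  # make the unittests directory importable, as the tests do

    from xpu.get_test_cover_info import get_xpu_op_support_types

    # Expected to list both float32 and float16 for 'momentum' after this change;
    # the generated test classes below are created once per supported dtype.
    print(get_xpu_op_support_types('momentum'))
    print(get_xpu_op_support_types('merged_momentum'))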
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_merged_momentum_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_merged_momentum_op_xpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..28a99ebac88ada4d986fafc480f549e180700285
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_merged_momentum_op_xpu.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+
+sys.path.append("..")
+
+import paddle
+import paddle.fluid.core as core
+
+from op_test import OpTest
+from op_test_xpu import XPUOpTest
+from test_merged_momentum_op_xpu_base import TestMergedMomentumBase
+from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
+
+paddle.enable_static()
+
+
+class XPUTestMergedMomentumOP(XPUOpTestWrapper):
+
+    def __init__(self):
+        self.op_name = 'merged_momentum'
+        self.use_dynamic_create_class = False
+
+    class TestMergedMomentumOp(TestMergedMomentumBase):
+
+        def setUp(self):
+            super().setUp()
+            self.set_case()
+
+        def set_case(self):
+            self.shapes = [[3, 4], [2, 7], [5, 6, 8]]
+            self.place = paddle.fluid.XPUPlace(0)
+            self.seed = 1
+
+        def testalltype(self):
+            self.check_with_place(self.place, self.in_type)
+
+    class TestMergedMomentum1(TestMergedMomentumOp):
+
+        def set_case(self):
+            self.shapes = [[3, 4], [2, 7], [5, 6, 8]]
+
+    class TestMergedMomentum2(TestMergedMomentumOp):
+
+        def set_case(self):
+            self.shapes = [[3, 4], [2, 7]]
+
+    class TestMergedMomentum3(TestMergedMomentumOp):
+
+        def set_case(self):
+            self.shapes = [[3, 4]]
+
+    class TestMergedMomentum4(TestMergedMomentumOp):
+
+        def set_case(self):
+            self.shapes = [[3, 4], [2, 7], [5, 6, 7], [9, 9], [10, 12]]
+
+
+support_types = get_xpu_op_support_types('merged_momentum')
+for stype in support_types:
+    create_test_class(globals(), XPUTestMergedMomentumOP, stype)
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_merged_momentum_op_xpu_base.py b/python/paddle/fluid/tests/unittests/xpu/test_merged_momentum_op_xpu_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..64d1e5aa9b26cdcba6c6145131ee22ca4e375a24
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_merged_momentum_op_xpu_base.py
@@ -0,0 +1,212 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+
+sys.path.append('..')
+import unittest
+import paddle
+import numpy as np
+from paddle.fluid.layer_helper import LayerHelper
+from collections import OrderedDict
+
+
+def run_momentum_op(params,
+                    grads,
+                    velocitys,
+                    master_params,
+                    learning_rate,
+                    place,
+                    multi_precision,
+                    mu=0.9,
+                    rescale_grad=0.01,
+                    use_merged=False,
+                    use_nesterov=True):
+    assert len(params) == len(grads)
+    assert len(params) == len(velocitys)
+    if multi_precision:
+        assert len(params) == len(master_params)
+    op_type = 'merged_momentum' if use_merged else 'momentum'
+    main = paddle.static.Program()
+    startup = paddle.static.Program()
+    with paddle.static.program_guard(main, startup):
+        helper = LayerHelper(op_type, **locals())
+
+        param_vars = [
+            helper.create_variable(persistable=True,
+                                   shape=p.shape,
+                                   dtype=p.dtype) for p in params
+        ]
+        grad_vars = [
+            helper.create_variable(shape=g.shape, dtype=g.dtype) for g in grads
+        ]
+        velocity_vars = [
+            helper.create_variable(persistable=True,
+                                   shape=v.shape,
+                                   dtype=v.dtype) for v in velocitys
+        ]
+        lr_var = helper.create_variable(persistable=True,
+                                        shape=learning_rate.shape,
+                                        dtype=learning_rate.dtype)
+
+        feed_dict = OrderedDict()
+
+        feed_dict.update(
+            OrderedDict([(p_var.name, p_val)
+                         for p_var, p_val in zip(param_vars, params)]))
+        feed_dict.update(
+            OrderedDict([(v_var.name, v_val)
+                         for v_var, v_val in zip(velocity_vars, velocitys)]))
+        fetch_list = list(feed_dict.keys())
+
+        feed_dict.update(
+            OrderedDict([(g_var.name, g_val)
+                         for g_var, g_val in zip(grad_vars, grads)]))
+        feed_dict.update({lr_var.name: learning_rate})
+
+        if multi_precision:
+            master_param_vars = [
+                helper.create_variable(persistable=True,
+                                       shape=p.shape,
+                                       dtype=p.dtype) for p in master_params
+            ]
+            feed_dict.update(
+                OrderedDict([
+                    (mp_var.name, mp_val)
+                    for mp_var, mp_val in zip(master_param_vars, master_params)
+                ]))
+            # CPUPlace does not use MasterParam
+            if isinstance(place, paddle.CUDAPlace):
+                fetch_list = fetch_list + [
+                    mp_var.name for mp_var in master_param_vars
+                ]
+        else:
+            master_param_vars = None
+
+        if not use_merged:
+            for i, (p, g,
+                    v) in enumerate(zip(param_vars, grad_vars, velocity_vars)):
+                inputs = {
+                    'Param': p,
+                    'Grad': g,
+                    'Velocity': v,
+                    'LearningRate': lr_var,
+                }
+                outputs = {'ParamOut': p, 'VelocityOut': v}
+                if multi_precision:
+                    inputs['MasterParam'] = master_param_vars[i]
+                    outputs['MasterParamOut'] = master_param_vars[i]
+                attrs = {
+                    'mu': mu,
+                    'multi_precision': multi_precision,
+                    'rescale_grad': rescale_grad,
+                    'use_nesterov': use_nesterov,
+                    'regularization_method': 'l2_decay',
+                    'regularization_coeff': 2.0,
+                }
+                helper.append_op(type=op_type,
+                                 inputs=inputs,
+                                 outputs=outputs,
+                                 attrs=attrs)
+        else:
+            inputs = {
+                'Param': param_vars,
+                'Grad': grad_vars,
+                'Velocity': velocity_vars,
+                'LearningRate': lr_var,
+            }
+            outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars}
+            if multi_precision:
+                inputs['MasterParam'] = master_param_vars
+                outputs['MasterParamOut'] = master_param_vars
+            attrs = {
+                'mu':
+                mu,
+                'multi_precision':
+                multi_precision,
+                'rescale_grad':
+                rescale_grad,
+                'use_nesterov':
+                use_nesterov,
+                'regularization_method':
+                ['l2_decay' for i in range(len(param_vars))],
+                'regularization_coeff': [2.0 for i in range(len(param_vars))],
+            }
+            helper.append_op(type=op_type,
+                             inputs=inputs,
+                             outputs=outputs,
+                             attrs=attrs)
+
+    exe = paddle.static.Executor(place)
+    with paddle.static.scope_guard(paddle.static.Scope()):
+        exe.run(startup)
+        return exe.run(main, feed=feed_dict, fetch_list=fetch_list)
+
+
+class TestMergedMomentumBase(unittest.TestCase):
+
+    def setUp(self):
+        paddle.enable_static()
+        self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]]
+        self.seed = 10
+        self.place = paddle.fluid.XPUPlace(0)
+        self.__class__.use_xpu = True
+
+    def gen_rand_data(self, shapes, dtype):
+        return [np.random.random(s).astype(dtype) for s in shapes]
+
+    def prepare_data(self, shapes, multi_precision, seed, dtype, place):
+        np.random.seed(seed)
+        params = self.gen_rand_data(shapes, dtype)
+        grads = self.gen_rand_data(shapes, dtype)
+        velocitys = self.gen_rand_data(shapes, dtype)
+        learning_rate = self.gen_rand_data([[1]], np.float32)[0]
+        if multi_precision:
+            master_params = [p.astype(dtype) for p in params]
+        else:
+            master_params = None
+        return params, grads, velocitys, master_params, learning_rate
+
+    def check_with_place(self, place, dtype, multi_precision=False):
+        params, grads, velocitys, master_params, learning_rate = self.prepare_data(
+            self.shapes, multi_precision, self.seed, dtype, place)
+
+        def run_op(use_nesterov, use_merged):
+            # NPU Momentum Op does not support rescale_grad
+            rescale_grad = 1.0
+            return run_momentum_op(params,
+                                   grads,
+                                   velocitys,
+                                   master_params,
+                                   learning_rate,
+                                   place,
+                                   multi_precision,
+                                   rescale_grad=rescale_grad,
+                                   use_merged=use_merged,
+                                   use_nesterov=use_nesterov)
+
+        outs1 = run_op(use_nesterov=True, use_merged=True)
+        outs2 = run_op(use_nesterov=True, use_merged=False)
+        self.assertEqual(len(outs1), len(outs2))
+        for i, (out1, out2) in enumerate(zip(outs1, outs2)):
+            np.testing.assert_allclose(out1, out2, atol=1e-7)
+        outs3 = run_op(use_nesterov=False, use_merged=True)
+        outs4 = run_op(use_nesterov=False, use_merged=False)
+        self.assertEqual(len(outs3), len(outs4))
+        for j, (out3, out4) in enumerate(zip(outs3, outs4)):
+            np.testing.assert_allclose(out3, out4, atol=1e-7)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_momentum_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_momentum_op_xpu.py
index a33b3e475519668e6f362a156ec195e933b38195..fd840c90f5b91f3bf0e52acd60a6bbb6a89cbb77 100644
--- a/python/paddle/fluid/tests/unittests/xpu/test_momentum_op_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/test_momentum_op_xpu.py
@@ -66,7 +66,6 @@ class XPUTestMomentumOP(XPUOpTestWrapper):
 
         def set_case(self):
             self.op_type = 'momentum'
-            self.dtype = self.in_type
             self.init_config()
 
             self.param = np.random.uniform(-1, 1,
@@ -75,7 +74,6 @@ class XPUTestMomentumOP(XPUOpTestWrapper):
                                            self.input_shape).astype(self.dtype)
             self.velocity = np.random.uniform(-1, 1, self.input_shape).astype(
                 self.dtype)
-
             param_out, velocity_out = calculate_momentum_by_numpy(
                 param=self.param,
                 grad=self.grad,
@@ -85,6 +83,8 @@
                 learning_rate=self.learning_rate,
                 regularization_method=self.regularization_method,
                 regularization_coeff=self.regularization_coeff)
+            param_out = param_out.astype(self.dtype)
+            velocity_out = velocity_out.astype(self.dtype)
             self.inputs = {
                 'Param': self.param,
                 'Grad': self.grad,
@@ -101,14 +101,14 @@
             self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
 
         def init_dtype(self):
-            self.dtype = np.float32
+            self.dtype = self.in_type
 
         def test_check_output(self):
             self.check_output_with_place(self.place)
 
         def init_config(self):
             self.input_shape = [864]
-            self.learning_rate = np.array([0.001]).astype(self.dtype)
+            self.learning_rate = np.array([0.001]).astype(float)
             self.mu = 0.0001
             self.use_nesterov = False
             self.regularization_method = None
@@ -118,7 +118,7 @@
 
         def init_config(self):
             self.input_shape = [2, 768]
-            self.learning_rate = np.array([0.002]).astype(self.dtype)
+            self.learning_rate = np.array([0.002]).astype(float)
             self.mu = 0.001
             self.use_nesterov = False
             self.regularization_method = None
@@ -128,7 +128,7 @@
 
         def init_config(self):
             self.input_shape = [3, 8, 4096]
-            self.learning_rate = np.array([0.005]).astype(self.dtype)
+            self.learning_rate = np.array([0.005]).astype(float)
             self.mu = 0.002
             self.use_nesterov = True
             self.regularization_method = None
@@ -138,7 +138,7 @@
 
         def init_config(self):
             self.input_shape = [1024]
-            self.learning_rate = np.array([0.01]).astype(self.dtype)
+            self.learning_rate = np.array([0.01]).astype(float)
             self.mu = 0.0001
             self.use_nesterov = False
             if self.xpu_version != core.XPUVersion.XPU1:
@@ -153,7 +153,7 @@
 
         def init_config(self):
             self.input_shape = [2, 2, 255]
-            self.learning_rate = np.array([0.0005]).astype(self.dtype)
+            self.learning_rate = np.array([0.0005]).astype(float)
             self.mu = 0.005
             self.use_nesterov = True
             if self.xpu_version != core.XPUVersion.XPU1:
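Note: the tests above compare the XPU kernels against calculate_momentum_by_numpy, which is defined in the existing test_momentum_op_xpu.py and is not shown in this patch. For reference, a minimal NumPy sketch of the update being checked, assuming the standard Paddle momentum formulation with optional l2_decay regularization and Nesterov momentum:

    import numpy as np


    def momentum_update(param, grad, velocity, mu, learning_rate,
                        use_nesterov=False, regularization_method=None,
                        regularization_coeff=0.0):
        # Optional L2 weight decay is folded into the gradient first.
        if regularization_method == 'l2_decay':
            grad = grad + regularization_coeff * param
        # Velocity accumulates the (regularized) gradient.
        velocity_out = mu * velocity + grad
        if use_nesterov:
            # Nesterov look-ahead step.
            param_out = param - (grad + mu * velocity_out) * learning_rate
        else:
            param_out = param - learning_rate * velocity_out
        return param_out, velocity_out

For float16 inputs the NumPy reference is computed at higher precision and cast back to the test dtype, which is what the added param_out.astype(self.dtype) / velocity_out.astype(self.dtype) lines above take care of.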