From 2b74b7391d4fdc8c4d8e8d4f7a147e95909c6981 Mon Sep 17 00:00:00 2001 From: Aganlengzi Date: Fri, 25 Mar 2022 10:43:35 +0800 Subject: [PATCH] [NPU] add merged_momentum (#40875) * [NPU] add merged_momentum * fix * fix device --- .../optimizers/merged_momentum_op_npu.cc | 167 ++++++++ .../npu/test_merged_momentum_op_npu.py | 373 ++++++++++++++++++ 2 files changed, 540 insertions(+) create mode 100644 paddle/fluid/operators/optimizers/merged_momentum_op_npu.cc create mode 100644 python/paddle/fluid/tests/unittests/npu/test_merged_momentum_op_npu.py diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op_npu.cc b/paddle/fluid/operators/optimizers/merged_momentum_op_npu.cc new file mode 100644 index 00000000000..f29a42be9d9 --- /dev/null +++ b/paddle/fluid/operators/optimizers/merged_momentum_op_npu.cc @@ -0,0 +1,167 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/optimizers/merged_momentum_op.h" + +#include "paddle/fluid/platform/device/npu/npu_op_runner.h" + +namespace paddle { +namespace operators { + +template +class NPUMergedMomentumOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto params = ctx.MultiInput("Param"); + auto params_out = ctx.MultiOutput("ParamOut"); + size_t n = params.size(); + PADDLE_ENFORCE_EQ(n, params_out.size(), + platform::errors::InvalidArgument( + "The size of Output(ParamOut) must be equal to " + "Input(Param), but got the size of Output(ParamOut) " + "is %d, the size of Input(Param) is %d.", + params_out.size(), n)); + for (size_t i = 0; i < n; ++i) { + PADDLE_ENFORCE_EQ(params[i], params_out[i], + platform::errors::InvalidArgument( + "The size of Input(Param) and Output(ParamOut) " + "must be the same Tensors.")); + } + + auto grads = ctx.MultiInput("Grad"); + PADDLE_ENFORCE_EQ( + n, grads.size(), + platform::errors::InvalidArgument( + "The size of Input(Grad) must be equal to Input(Param), but got " + "the size of Input(Grad) is %d, the size of Input(Param) is %d.", + grads.size(), n)); + + auto velocitys = ctx.MultiInput("Velocity"); + PADDLE_ENFORCE_EQ(n, velocitys.size(), + platform::errors::InvalidArgument( + "The size of Input(Velocity) must be equal to " + "Input(Param), but got the size of Input(Velocity) " + "is %d, the size of Input(Param) is %d.", + velocitys.size(), n)); + + auto velocitys_out = ctx.MultiOutput("VelocityOut"); + PADDLE_ENFORCE_EQ( + n, velocitys_out.size(), + platform::errors::InvalidArgument( + "The size of Output(VelocityOut) must be " + "equal to Input(Param), but got the size of Output(VelocityOut) is " + "%d, the size of Input(Param) is %d.", + velocitys_out.size(), n)); + for (size_t i = 0; i < n; ++i) { + PADDLE_ENFORCE_EQ(velocitys[i], velocitys_out[i], + platform::errors::InvalidArgument( + "Input(Velocity) and Output(VelocityOut) must be " + "the same Tensors.")); + } + + T mu = static_cast(ctx.Attr("mu")); + auto lrs = ctx.MultiInput("LearningRate"); + if (lrs.size() != 1) { + PADDLE_ENFORCE_EQ( + n, lrs.size(), + platform::errors::InvalidArgument( + "If the size of Input(LearningRate) is not 1, the size of " + "Input(LearningRate) must be " + "equal to Input(Param), but got the size of Input(LearningRate) " + "is %d, the size of Input(Param) is %d.", + lrs.size(), n)); + } + auto use_nesterov = ctx.Attr("use_nesterov"); + auto regularization_methods = + ctx.Attr>("regularization_method"); + auto regularization_coeffs = + ctx.Attr>("regularization_coeff"); + if (regularization_methods.size() != 0) { + PADDLE_ENFORCE_EQ( + n, regularization_methods.size(), + platform::errors::InvalidArgument( + "The size of Attr(regularization_method) must be equal " + "to Input(Param), but got the size of " + "Attr(regularization_method) is %d, the size of Input(Param) is " + "%d.", + regularization_methods.size(), n)); + PADDLE_ENFORCE_EQ( + n, regularization_coeffs.size(), + platform::errors::InvalidArgument( + "The size of Attr(regularization_coeff) must be equal " + "to Input(Param), but got the size of Attr(regularization_coeff) " + "is %d, the size of Input(Param) is %d.", + regularization_coeffs.size(), n)); + } + + VLOG(5) << "use_nesterov: " << use_nesterov + << ", regularization_methods.size(): " + << regularization_methods.size() + << ", regularization_coeffs.size(): " + << regularization_coeffs.size(); + + auto& dev_ctx = ctx.template device_context(); + + Tensor mu_tensor; + mu_tensor.mutable_data(phi::make_ddim({1}), ctx.GetPlace()); + FillNpuTensorWithConstant(&mu_tensor, mu); + + for (size_t idx = 0; idx < n; ++idx) { + RegularizationType regularization_flag = + regularization_methods.size() > 0 && + regularization_methods[idx] == "l2_decay" + ? RegularizationType::kL2DECAY + : RegularizationType::kNONE; + float regularization_coeff = 0.0; + if (regularization_coeffs.size() != 0) { + regularization_coeff = regularization_coeffs[idx]; + } + + auto learning_rate = lrs.size() > 1 ? lrs[idx] : lrs[0]; + auto param = params[idx]; + auto param_out = params_out[idx]; + auto velocity = velocitys[idx]; + auto velocity_out = velocitys_out[idx]; + + auto grad = grads[idx]; + Tensor regularized_grad; + if (regularization_flag == RegularizationType::kL2DECAY) { + regularized_grad.mutable_data(grad->dims(), ctx.GetPlace()); + const auto& runner1 = NpuOpRunner("Muls", {*param}, {regularized_grad}, + {{"value", regularization_coeff}}); + runner1.Run(dev_ctx.stream()); + const auto& runner2 = NpuOpRunner("Add", {regularized_grad, *grad}, + {regularized_grad}, {}); + runner2.Run(dev_ctx.stream()); + } else { + regularized_grad.ShareDataWith(*grad); + } + framework::TensorCopy(*param, ctx.GetPlace(), dev_ctx, param_out); + framework::TensorCopy(*velocity, ctx.GetPlace(), dev_ctx, velocity_out); + // NOTE: ApplyMomentum will change the input + const auto& runner = NpuOpRunner( + "ApplyMomentum", {*param_out, *velocity_out, *learning_rate, + regularized_grad, mu_tensor}, + {*param_out}, {{"use_nesterov", use_nesterov}}); + runner.Run(dev_ctx.stream()); + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_NPU_KERNEL(merged_momentum, ops::NPUMergedMomentumOpKernel, + ops::NPUMergedMomentumOpKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_merged_momentum_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_merged_momentum_op_npu.py new file mode 100644 index 00000000000..96a15fc1caa --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_merged_momentum_op_npu.py @@ -0,0 +1,373 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +sys.path.append('..') +import unittest +import paddle +import numpy as np +from paddle.fluid.layer_helper import LayerHelper +from collections import OrderedDict + + +def run_momentum_op(params, + grads, + velocitys, + master_params, + learning_rate, + place, + multi_precision, + mu=0.9, + rescale_grad=0.01, + use_merged=False): + assert len(params) == len(grads) + assert len(params) == len(velocitys) + if multi_precision: + assert len(params) == len(master_params) + op_type = 'merged_momentum' if use_merged else 'momentum' + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): + helper = LayerHelper(op_type, **locals()) + attrs = { + 'mu': mu, + 'multi_precision': multi_precision, + 'rescale_grad': rescale_grad, + } + + param_vars = [ + helper.create_variable( + persistable=True, shape=p.shape, dtype=p.dtype) for p in params + ] + grad_vars = [ + helper.create_variable( + shape=g.shape, dtype=g.dtype) for g in grads + ] + velocity_vars = [ + helper.create_variable( + persistable=True, shape=v.shape, dtype=v.dtype) + for v in velocitys + ] + lr_var = helper.create_variable( + persistable=True, + shape=learning_rate.shape, + dtype=learning_rate.dtype) + + feed_dict = OrderedDict() + + feed_dict.update( + OrderedDict([(p_var.name, p_val) + for p_var, p_val in zip(param_vars, params)])) + feed_dict.update( + OrderedDict([(v_var.name, v_val) + for v_var, v_val in zip(velocity_vars, velocitys)])) + fetch_list = list(feed_dict.keys()) + + feed_dict.update( + OrderedDict([(g_var.name, g_val) + for g_var, g_val in zip(grad_vars, grads)])) + feed_dict.update({lr_var.name: learning_rate}) + + if multi_precision: + master_param_vars = [ + helper.create_variable( + persistable=True, shape=p.shape, dtype=p.dtype) + for p in master_params + ] + feed_dict.update( + OrderedDict([(mp_var.name, mp_val) + for mp_var, mp_val in zip(master_param_vars, + master_params)])) + # CPUPlace does not use MasterParam + if isinstance(place, paddle.CUDAPlace): + fetch_list = fetch_list + [ + mp_var.name for mp_var in master_param_vars + ] + else: + master_param_vars = None + + if not use_merged: + for i, (p, g, + v) in enumerate(zip(param_vars, grad_vars, velocity_vars)): + inputs = { + 'Param': p, + 'Grad': g, + 'Velocity': v, + 'LearningRate': lr_var, + } + outputs = {'ParamOut': p, 'VelocityOut': v} + if multi_precision: + inputs['MasterParam'] = master_param_vars[i] + outputs['MasterParamOut'] = master_param_vars[i] + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + else: + inputs = { + 'Param': param_vars, + 'Grad': grad_vars, + 'Velocity': velocity_vars, + 'LearningRate': lr_var, + } + outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars} + if multi_precision: + inputs['MasterParam'] = master_param_vars + outputs['MasterParamOut'] = master_param_vars + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + + exe = paddle.static.Executor(place) + with paddle.static.scope_guard(paddle.static.Scope()): + exe.run(startup) + return exe.run(main, feed=feed_dict, fetch_list=fetch_list) + + +def run_momentum_op2(params, + grads, + velocitys, + master_params, + learning_rate, + place, + multi_precision, + mu=0.9, + rescale_grad=0.01, + use_merged=False, + use_nesterov=True): + assert len(params) == len(grads) + assert len(params) == len(velocitys) + if multi_precision: + assert len(params) == len(master_params) + op_type = 'merged_momentum' if use_merged else 'momentum' + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): + helper = LayerHelper(op_type, **locals()) + + param_vars = [ + helper.create_variable( + persistable=True, shape=p.shape, dtype=p.dtype) for p in params + ] + grad_vars = [ + helper.create_variable( + shape=g.shape, dtype=g.dtype) for g in grads + ] + velocity_vars = [ + helper.create_variable( + persistable=True, shape=v.shape, dtype=v.dtype) + for v in velocitys + ] + lr_var = helper.create_variable( + persistable=True, + shape=learning_rate.shape, + dtype=learning_rate.dtype) + + feed_dict = OrderedDict() + + feed_dict.update( + OrderedDict([(p_var.name, p_val) + for p_var, p_val in zip(param_vars, params)])) + feed_dict.update( + OrderedDict([(v_var.name, v_val) + for v_var, v_val in zip(velocity_vars, velocitys)])) + fetch_list = list(feed_dict.keys()) + + feed_dict.update( + OrderedDict([(g_var.name, g_val) + for g_var, g_val in zip(grad_vars, grads)])) + feed_dict.update({lr_var.name: learning_rate}) + + if multi_precision: + master_param_vars = [ + helper.create_variable( + persistable=True, shape=p.shape, dtype=p.dtype) + for p in master_params + ] + feed_dict.update( + OrderedDict([(mp_var.name, mp_val) + for mp_var, mp_val in zip(master_param_vars, + master_params)])) + # CPUPlace does not use MasterParam + if isinstance(place, paddle.CUDAPlace): + fetch_list = fetch_list + [ + mp_var.name for mp_var in master_param_vars + ] + else: + master_param_vars = None + + if not use_merged: + for i, (p, g, + v) in enumerate(zip(param_vars, grad_vars, velocity_vars)): + inputs = { + 'Param': p, + 'Grad': g, + 'Velocity': v, + 'LearningRate': lr_var, + } + outputs = {'ParamOut': p, 'VelocityOut': v} + if multi_precision: + inputs['MasterParam'] = master_param_vars[i] + outputs['MasterParamOut'] = master_param_vars[i] + attrs = { + 'mu': mu, + 'multi_precision': multi_precision, + 'rescale_grad': rescale_grad, + 'use_nesterov': use_nesterov, + 'regularization_method': 'l2_decay', + 'regularization_coeff': 2.0, + } + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + else: + inputs = { + 'Param': param_vars, + 'Grad': grad_vars, + 'Velocity': velocity_vars, + 'LearningRate': lr_var, + } + outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars} + if multi_precision: + inputs['MasterParam'] = master_param_vars + outputs['MasterParamOut'] = master_param_vars + attrs = { + 'mu': mu, + 'multi_precision': multi_precision, + 'rescale_grad': rescale_grad, + 'use_nesterov': use_nesterov, + 'regularization_method': + ['l2_decay' for i in range(len(param_vars))], + 'regularization_coeff': [2.0 for i in range(len(param_vars))], + } + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + + exe = paddle.static.Executor(place) + with paddle.static.scope_guard(paddle.static.Scope()): + exe.run(startup) + return exe.run(main, feed=feed_dict, fetch_list=fetch_list) + + +class TestMergedMomentum(unittest.TestCase): + def setUp(self): + paddle.enable_static() + self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]] + self.seed = 10 + self.place = paddle.fluid.NPUPlace(0) + self.__class__.use_npu = True + + def gen_rand_data(self, shapes, dtype): + return [np.random.random(s).astype(dtype) for s in shapes] + + def prepare_data(self, shapes, multi_precision, seed, place): + np.random.seed(seed) + mp_dtype = np.float32 + dtype = np.float32 + params = self.gen_rand_data(shapes, dtype) + grads = self.gen_rand_data(shapes, dtype) + velocitys = self.gen_rand_data(shapes, mp_dtype) + learning_rate = self.gen_rand_data([[1]], mp_dtype)[0] + if multi_precision: + master_params = [p.astype(mp_dtype) for p in params] + else: + master_params = None + return params, grads, velocitys, master_params, learning_rate + + def check_with_place(self, place, multi_precision): + params, grads, velocitys, master_params, learning_rate = self.prepare_data( + self.shapes, multi_precision, self.seed, place) + + def run_op(use_merged): + # NPU Momentum Op does not support rescale_grad + rescale_grad = 1.0 + return run_momentum_op( + params, + grads, + velocitys, + master_params, + learning_rate, + place, + multi_precision, + rescale_grad=rescale_grad, + use_merged=use_merged) + + outs1 = run_op(True) + outs2 = run_op(False) + self.assertEqual(len(outs1), len(outs2)) + for i, (out1, out2) in enumerate(zip(outs1, outs2)): + self.assertTrue(np.allclose(out1, out2, atol=1e-7)) + + def test_main(self): + self.check_with_place(self.place, multi_precision=False) + + +class TestMergedMomentum2(unittest.TestCase): + def setUp(self): + paddle.enable_static() + self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]] + self.seed = 10 + self.place = paddle.fluid.NPUPlace(0) + self.__class__.use_npu = True + + def gen_rand_data(self, shapes, dtype): + return [np.random.random(s).astype(dtype) for s in shapes] + + def prepare_data(self, shapes, multi_precision, seed, place): + np.random.seed(seed) + mp_dtype = np.float32 + dtype = np.float32 # np.float16 + params = self.gen_rand_data(shapes, dtype) + grads = self.gen_rand_data(shapes, dtype) + velocitys = self.gen_rand_data(shapes, mp_dtype) + learning_rate = self.gen_rand_data([[1]], mp_dtype)[0] + if multi_precision: + master_params = [p.astype(mp_dtype) for p in params] + else: + master_params = None + return params, grads, velocitys, master_params, learning_rate + + def check_with_place(self, place, multi_precision): + params, grads, velocitys, master_params, learning_rate = self.prepare_data( + self.shapes, multi_precision, self.seed, place) + + def run_op(use_nesterov, use_merged): + # NPU Momentum Op does not support rescale_grad + rescale_grad = 1.0 + return run_momentum_op2( + params, + grads, + velocitys, + master_params, + learning_rate, + place, + multi_precision, + rescale_grad=rescale_grad, + use_merged=use_merged, + use_nesterov=use_nesterov) + + outs1 = run_op(use_nesterov=True, use_merged=True) + outs2 = run_op(use_nesterov=True, use_merged=False) + self.assertEqual(len(outs1), len(outs2)) + for i, (out1, out2) in enumerate(zip(outs1, outs2)): + self.assertTrue(np.allclose(out1, out2, atol=1e-7)) + + outs3 = run_op(use_nesterov=False, use_merged=True) + outs4 = run_op(use_nesterov=False, use_merged=False) + self.assertEqual(len(outs3), len(outs4)) + for j, (out3, out4) in enumerate(zip(outs3, outs4)): + self.assertTrue(np.allclose(out3, out4, atol=1e-7)) + + def test_main(self): + self.check_with_place(self.place, multi_precision=False) + + +if __name__ == "__main__": + unittest.main() -- GitLab