From f4eda869f3f46d0f5097e4a10af4566a9e15e786 Mon Sep 17 00:00:00 2001
From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com>
Date: Thu, 14 Oct 2021 14:41:15 +0800
Subject: [PATCH] Merge momentum ops/kernels (#36380)

* merge momentum ops
* update
* add ut to improve coverage
* remove optimizer change
* fix error msg
* update ut
* add __restrict__ for CUDA
* update ut
* move merged_momentum_op to optimizer dir
* fix coverage
---
 .../optimizers/merged_momentum_op.cc           |  95 +++++++++
 .../optimizers/merged_momentum_op.cu           |  24 +++
 .../operators/optimizers/merged_momentum_op.h  | 197 ++++++++++++++++++
 paddle/fluid/platform/macros.h                 |   6 +
 .../unittests/test_merged_momentum_op.py       | 194 +++++++++++++++++
 5 files changed, 516 insertions(+)
 create mode 100644 paddle/fluid/operators/optimizers/merged_momentum_op.cc
 create mode 100644 paddle/fluid/operators/optimizers/merged_momentum_op.cu
 create mode 100644 paddle/fluid/operators/optimizers/merged_momentum_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_merged_momentum_op.py

diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op.cc b/paddle/fluid/operators/optimizers/merged_momentum_op.cc
new file mode 100644
index 00000000000..6c63376b5eb
--- /dev/null
+++ b/paddle/fluid/operators/optimizers/merged_momentum_op.cc
@@ -0,0 +1,95 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
+
+namespace paddle {
+namespace operators {
+
+class MergedMomentumOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext *ctx) const override {}
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto param_dtype =
+        framework::OperatorWithKernel::IndicateVarDataType(ctx, "Param");
+    return framework::OpKernelType(param_dtype, ctx.GetPlace());
+  }
+};
+
+class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Param",
+             "(Tensor, default Tensor<float>) "
+             "Input parameter that has to be updated")
+        .AsDuplicable();
+    AddInput("Grad",
+             "(Tensor, default Tensor<float>) "
+             "Input gradient of the parameter")
+        .AsDuplicable();
+    AddInput("Velocity",
+             "(Tensor, default Tensor<float>) "
+             "Input velocity (corresponding to the parameter) "
+             "that has to be updated")
+        .AsDuplicable();
+    AddInput("LearningRate",
+             "(Tensor, default Tensor<float>) "
+             "Input learning rate");
+    AddInput("MasterParam", "FP32 master weight for AMP.")
+        .AsDispensable()
+        .AsDuplicable();
+    AddOutput("ParamOut",
+              "(Tensor) This output is the updated parameter. "
+              "It shares memory with Input(Param).")
+        .AsDuplicable();
+    AddOutput("VelocityOut",
+              "(Tensor) This output is the updated velocity. "
" + "It shared memory with Input(Velocity).") + .AsDuplicable(); + AddOutput("MasterParamOut", + "The updated FP32 master weight for AMP. " + "It shared memory with Input(MasterParam).") + .AsDispensable() + .AsDuplicable(); + AddAttr("mu", "(float) Momentum coefficient"); + AddAttr("multi_precision", + "(bool, default false) " + "Whether to use multi-precision during weight updating.") + .SetDefault(false); + AddAttr( + "rescale_grad", + "(float, default 1.0) Multiply the gradient with `rescale_grad`" + "before updating. Often choose to be `1.0/batch_size`.") + .SetDefault(1.0f); + AddComment(R"DOC(Merged Momentum Optimizer.)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_WITHOUT_GRADIENT(merged_momentum, ops::MergedMomentumOp, + ops::MergedMomentumOpMaker); + +REGISTER_OP_CPU_KERNEL( + merged_momentum, ops::MergedMomentumOpKernel, + ops::MergedMomentumOpKernel); diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op.cu b/paddle/fluid/operators/optimizers/merged_momentum_op.cu new file mode 100644 index 00000000000..7e4bbd98079 --- /dev/null +++ b/paddle/fluid/operators/optimizers/merged_momentum_op.cu @@ -0,0 +1,24 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/optimizers/merged_momentum_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + merged_momentum, + ops::MergedMomentumOpKernel, + ops::MergedMomentumOpKernel, + ops::MergedMomentumOpKernel); diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op.h b/paddle/fluid/operators/optimizers/merged_momentum_op.h new file mode 100644 index 00000000000..4dfaa4de3ad --- /dev/null +++ b/paddle/fluid/operators/optimizers/merged_momentum_op.h @@ -0,0 +1,197 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
+#include "paddle/fluid/platform/for_range.h"
+#include "paddle/fluid/platform/macros.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename MT, uint32_t kParamNum, bool kHasMasterParams>
+struct MergedMomentumMasterParams {
+  MT *PADDLE_RESTRICT master_params[kParamNum];
+
+  HOSTDEVICE MT *MasterParam(size_t idx) const { return master_params[idx]; }
+  HOSTDEVICE void SetMasterParam(size_t idx, MT *p) { master_params[idx] = p; }
+};
+
+template <typename MT, uint32_t kParamNum>
+struct MergedMomentumMasterParams<MT, kParamNum, false> {
+  HOSTDEVICE constexpr MT *MasterParam(size_t) const { return nullptr; }
+  HOSTDEVICE constexpr void SetMasterParam(size_t, MT *) {}
+};
+
+// The default kParamNum caps the number of parameters packed into one kernel
+// launch so the argument struct stays within the kernel-argument size limit.
+template <typename T, typename MT, bool kHasMasterParams,
+          uint32_t kParamNum = kHasMasterParams ? 55 : 110>
+struct MergedMomentumKernelParam
+    : public MergedMomentumMasterParams<MT, kParamNum, kHasMasterParams> {
+  static constexpr auto N = kParamNum;
+  size_t sizes[N];
+  T *PADDLE_RESTRICT params[N];
+  const T *PADDLE_RESTRICT grads[N];
+  MT *PADDLE_RESTRICT velocitys[N];
+  const MT *PADDLE_RESTRICT lr;
+  MT mu;
+  MT rescale_grad;
+  uint32_t param_num;
+
+  HOSTDEVICE void operator()(size_t i) const {
+    const auto lr_val = *lr;
+    for (uint32_t idx = 0; idx < param_num; ++idx) {
+      auto size = sizes[idx];
+      if (i >= size) continue;
+
+      auto param_p = params[idx];
+      auto grad_p = grads[idx];
+      auto velocity_p = velocitys[idx];
+      auto master_param_p = this->MasterParam(idx);
+
+      const MT param =
+          master_param_p ? master_param_p[i] : static_cast<MT>(param_p[i]);
+      const MT grad = static_cast<MT>(grad_p[i]) * rescale_grad;
+      const MT velocity = velocity_p[i];
+      const MT velocity_out = velocity * mu + grad;
+      const MT param_out = param - lr_val * velocity_out;
+      velocity_p[i] = velocity_out;
+      param_p[i] = static_cast<T>(param_out);
+      if (master_param_p) {
+        master_param_p[i] = param_out;
+      }
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class MergedMomentumOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto params = ctx.MultiInput<framework::Tensor>("Param");
+    auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
+    size_t n = params.size();
+    PADDLE_ENFORCE_EQ(
+        n, params_out.size(),
+        platform::errors::InvalidArgument(
+            "Output(ParamOut) number must be equal to Input(Param) number."));
+    for (size_t i = 0; i < n; ++i) {
+      PADDLE_ENFORCE_EQ(
+          params[i], params_out[i],
+          platform::errors::InvalidArgument(
+              "Input(Param) and Output(ParamOut) must be the same Tensors."));
+    }
+
+    auto grads = ctx.MultiInput<framework::Tensor>("Grad");
+    PADDLE_ENFORCE_EQ(
+        n, grads.size(),
+        platform::errors::InvalidArgument(
+            "Input(Grad) number must be equal to Input(Param) number."));
+
+    auto velocitys = ctx.MultiInput<framework::Tensor>("Velocity");
+    PADDLE_ENFORCE_EQ(n, velocitys.size(),
+                      platform::errors::InvalidArgument(
+                          "Input(Velocity) number must be equal to "
+                          "Input(Param) number."));
+
+    auto velocitys_out = ctx.MultiOutput<framework::Tensor>("VelocityOut");
+    PADDLE_ENFORCE_EQ(
+        n, velocitys_out.size(),
+        platform::errors::InvalidArgument("Output(VelocityOut) number must be "
+                                          "equal to Input(Param) number."));
+    for (size_t i = 0; i < n; ++i) {
+      PADDLE_ENFORCE_EQ(velocitys[i], velocitys_out[i],
+                        platform::errors::InvalidArgument(
+                            "Input(Velocity) and Output(VelocityOut) must be "
+                            "the same Tensors."));
+    }
+
+    auto master_params = ctx.MultiInput<framework::Tensor>("MasterParam");
+    auto master_params_out =
+        ctx.MultiOutput<framework::Tensor>("MasterParamOut");
+    auto multi_precision = ctx.Attr<bool>("multi_precision");
+    if (multi_precision) {
+      PADDLE_ENFORCE_EQ(
+          n, master_params.size(),
platform::errors::InvalidArgument("Input(MasterParam) number must be " + "equal to Input(Param) number.")); + PADDLE_ENFORCE_EQ(n, master_params_out.size(), + platform::errors::InvalidArgument( + "Output(MasterParamOut) number must be equal to " + "Input(MasterParam) number.")); + for (size_t i = 0; i < n; ++i) { + PADDLE_ENFORCE_EQ(master_params[i], master_params_out[i], + platform::errors::InvalidArgument( + "Input(MasterParam) and Output(MasterParamOut) " + "must be the same Tensors.")); + PADDLE_ENFORCE_NOT_NULL(master_params[i], + platform::errors::InvalidArgument( + "Input(MasterParam) must be provided when " + "multi_precision=True.")); + } + } else { + master_params.clear(); + master_params_out.clear(); + } + + auto lr = ctx.Input("LearningRate"); + auto mu = ctx.Attr("mu"); + auto rescale_grad = ctx.Attr("rescale_grad"); + using MPType = typename operators::details::MPTypeTrait::Type; + + auto &dev_ctx = ctx.template device_context(); + +#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \ + MergedMomentumKernelParam kernel_params; \ + constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \ + size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum; \ + kernel_params.mu = static_cast(mu); \ + kernel_params.rescale_grad = static_cast(rescale_grad); \ + kernel_params.lr = lr->data(); \ + for (size_t i = 0; i < kernel_num; ++i) { \ + size_t start = i * kMaxMergedNum; \ + size_t end = std::min((i + 1) * kMaxMergedNum, n); \ + kernel_params.param_num = static_cast(end - start); \ + size_t max_size = 0; \ + for (size_t j = 0; j < kernel_params.param_num; ++j) { \ + auto size = static_cast(params_out[j + start]->numel()); \ + max_size = std::max(max_size, size); \ + kernel_params.sizes[j] = size; \ + kernel_params.params[j] = params_out[j + start]->data(); \ + kernel_params.grads[j] = grads[j + start]->data(); \ + kernel_params.velocitys[j] = velocitys_out[j + start]->data(); \ + kernel_params.SetMasterParam( \ + j, kMultiPrecision ? master_params_out[j + start]->data() \ + : nullptr); \ + } \ + platform::ForRange for_range(dev_ctx, max_size); \ + for_range(kernel_params); \ + VLOG(10) << "Launch MergedMomentum kernel " << i << " " \ + << kernel_params.param_num; \ + } + + if (multi_precision) { + PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true); + } else { + PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false); + } + +#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/macros.h b/paddle/fluid/platform/macros.h index fb5cf9fb319..bf089ac117d 100644 --- a/paddle/fluid/platform/macros.h +++ b/paddle/fluid/platform/macros.h @@ -30,3 +30,9 @@ limitations under the License. */ #define FLT_MAX __FLT_MAX__ #endif // __FLT_MAX__ #endif // PADDLE_WITH_MUSL + +#if defined(__NVCC__) || defined(__HIPCC__) +#define PADDLE_RESTRICT __restrict__ +#else +#define PADDLE_RESTRICT +#endif diff --git a/python/paddle/fluid/tests/unittests/test_merged_momentum_op.py b/python/paddle/fluid/tests/unittests/test_merged_momentum_op.py new file mode 100644 index 00000000000..0118a372c3f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_merged_momentum_op.py @@ -0,0 +1,194 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import paddle
+import numpy as np
+from paddle.fluid.layer_helper import LayerHelper
+from collections import OrderedDict
+
+
+def run_momentum_op(params,
+                    grads,
+                    velocitys,
+                    master_params,
+                    learning_rate,
+                    place,
+                    multi_precision,
+                    mu=0.9,
+                    rescale_grad=0.01,
+                    use_merged=False):
+    assert len(params) == len(grads)
+    assert len(params) == len(velocitys)
+    if multi_precision:
+        assert len(params) == len(master_params)
+    op_type = 'merged_momentum' if use_merged else 'momentum'
+    main = paddle.static.Program()
+    startup = paddle.static.Program()
+    with paddle.static.program_guard(main, startup):
+        helper = LayerHelper(op_type, **locals())
+        attrs = {
+            'mu': mu,
+            'multi_precision': multi_precision,
+            'rescale_grad': rescale_grad,
+        }
+
+        param_vars = [
+            helper.create_variable(
+                persistable=True, shape=p.shape, dtype=p.dtype) for p in params
+        ]
+        grad_vars = [
+            helper.create_variable(
+                shape=g.shape, dtype=g.dtype) for g in grads
+        ]
+        velocity_vars = [
+            helper.create_variable(
+                persistable=True, shape=v.shape, dtype=v.dtype)
+            for v in velocitys
+        ]
+        lr_var = helper.create_variable(
+            persistable=True,
+            shape=learning_rate.shape,
+            dtype=learning_rate.dtype)
+
+        feed_dict = OrderedDict()
+
+        feed_dict.update(
+            OrderedDict([(p_var.name, p_val)
+                         for p_var, p_val in zip(param_vars, params)]))
+        feed_dict.update(
+            OrderedDict([(v_var.name, v_val)
+                         for v_var, v_val in zip(velocity_vars, velocitys)]))
+        fetch_list = list(feed_dict.keys())
+
+        feed_dict.update(
+            OrderedDict([(g_var.name, g_val)
+                         for g_var, g_val in zip(grad_vars, grads)]))
+        feed_dict.update({lr_var.name: learning_rate})
+
+        if multi_precision:
+            master_param_vars = [
+                helper.create_variable(
+                    persistable=True, shape=p.shape, dtype=p.dtype)
+                for p in master_params
+            ]
+            feed_dict.update(
+                OrderedDict([(mp_var.name, mp_val)
+                             for mp_var, mp_val in zip(master_param_vars,
+                                                       master_params)]))
+            # CPUPlace does not use MasterParam
+            if isinstance(place, paddle.CUDAPlace):
+                fetch_list = fetch_list + [
+                    mp_var.name for mp_var in master_param_vars
+                ]
+        else:
+            master_param_vars = None
+
+        if not use_merged:
+            for i, (p, g,
+                    v) in enumerate(zip(param_vars, grad_vars, velocity_vars)):
+                inputs = {
+                    'Param': p,
+                    'Grad': g,
+                    'Velocity': v,
+                    'LearningRate': lr_var
+                }
+                outputs = {'ParamOut': p, 'VelocityOut': v}
+                if multi_precision:
+                    inputs['MasterParam'] = master_param_vars[i]
+                    outputs['MasterParamOut'] = master_param_vars[i]
+                helper.append_op(
+                    type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
+        else:
+            inputs = {
+                'Param': param_vars,
+                'Grad': grad_vars,
+                'Velocity': velocity_vars,
+                'LearningRate': lr_var
+            }
+            outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars}
+            if multi_precision:
+                inputs['MasterParam'] = master_param_vars
+                outputs['MasterParamOut'] = master_param_vars
+            helper.append_op(
+                type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
+
+    exe = paddle.static.Executor(place)
+    with paddle.static.scope_guard(paddle.static.Scope()):
+        exe.run(startup)
+        return exe.run(main, feed=feed_dict, fetch_list=fetch_list)
+
+
+class TestMergedMomentum(unittest.TestCase):
+    def setUp(self):
+        paddle.enable_static()
+        self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]]
+        self.seed = 10
+
+    def gen_rand_data(self, shapes, dtype):
+        return [np.random.random(s).astype(dtype) for s in shapes]
+
+    def prepare_data(self, shapes, multi_precision, seed, place):
+        np.random.seed(seed)
+        mp_dtype = np.float32
+        dtype = np.float16 if multi_precision and isinstance(
+            place, paddle.CUDAPlace) else np.float32
+        params = self.gen_rand_data(shapes, dtype)
+        grads = self.gen_rand_data(shapes, dtype)
+        velocitys = self.gen_rand_data(shapes, mp_dtype)
+        learning_rate = self.gen_rand_data([[1]], mp_dtype)[0]
+        if multi_precision:
+            master_params = [p.astype(mp_dtype) for p in params]
+        else:
+            master_params = None
+        return params, grads, velocitys, master_params, learning_rate
+
+    def check_with_place(self, place, multi_precision):
+        params, grads, velocitys, master_params, learning_rate = self.prepare_data(
+            self.shapes, multi_precision, self.seed, place)
+
+        def run_op(use_merged):
+            # FIXME(zengjinle): CPU Momentum Op does not support rescale_grad
+            rescale_grad = 1.0 if isinstance(place, paddle.CPUPlace) else 0.01
+            return run_momentum_op(
+                params,
+                grads,
+                velocitys,
+                master_params,
+                learning_rate,
+                place,
+                multi_precision,
+                rescale_grad=rescale_grad,
+                use_merged=use_merged)
+
+        outs1 = run_op(True)
+        outs2 = run_op(False)
+        self.assertEqual(len(outs1), len(outs2))
+        for i, (out1, out2) in enumerate(zip(outs1, outs2)):
+            self.assertTrue(np.allclose(out1, out2, atol=1e-7))
+
+    def get_places(self):
+        places = [paddle.CPUPlace()]
+        if paddle.is_compiled_with_cuda():
+            places.append(paddle.CUDAPlace(0))
+        return places
+
+    def test_main(self):
+        for multi_precision in [False, True]:
+            for place in self.get_places():
+                self.check_with_place(place, multi_precision)
+
+
+if __name__ == "__main__":
+    unittest.main()
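
Note for reviewers: the per-parameter update applied by MergedMomentumKernelParam::operator() is the plain momentum rule, only batched over many parameters per kernel launch. The NumPy sketch below is a minimal reference for that rule; the helper name momentum_update_reference is ours and not part of the patch, and the defaults simply mirror the test's mu=0.9 and rescale_grad=0.01.

    import numpy as np

    def momentum_update_reference(param, grad, velocity, lr,
                                  mu=0.9, rescale_grad=0.01):
        """Reference momentum step mirroring the merged kernel's math."""
        # Gradient is rescaled first, in the compute (master) precision.
        g = grad.astype(np.float32) * rescale_grad
        # v' = mu * v + g
        velocity_out = velocity * mu + g
        # p' = p - lr * v'
        param_out = param.astype(np.float32) - lr * velocity_out
        # Parameter is written back in its original dtype (e.g. fp16 for AMP).
        return param_out.astype(param.dtype), velocity_out

One could compare these values against the ParamOut/VelocityOut tensors fetched by run_momentum_op, in the same way the test cross-checks merged_momentum against per-parameter momentum.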