diff --git a/paddle/fluid/operators/optimizers/adam_op_mlu.cc b/paddle/fluid/operators/optimizers/adam_op_mlu.cc
index ecc527d5c72bf0757b71b382febe29e4c594a175..6ee63354fbff41572026db658719152f0633b20e 100644
--- a/paddle/fluid/operators/optimizers/adam_op_mlu.cc
+++ b/paddle/fluid/operators/optimizers/adam_op_mlu.cc
@@ -333,6 +333,224 @@ class AdamWMLUKernel : public AdamMLUKernel<T> {
   }
 };
 
+template <typename T>
+class MergedAdamMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    // Get inputs and outputs
+    auto params = ctx.MultiInput<framework::Tensor>("Param");
+    auto grads = ctx.MultiInput<framework::Tensor>("Grad");
+    auto lrs = ctx.MultiInput<framework::Tensor>("LearningRate");
+    auto mom1s = ctx.MultiInput<framework::Tensor>("Moment1");
+    auto mom2s = ctx.MultiInput<framework::Tensor>("Moment2");
+    auto beta1_pows = ctx.MultiInput<framework::Tensor>("Beta1Pow");
+    auto beta2_pows = ctx.MultiInput<framework::Tensor>("Beta2Pow");
+    auto master_params = ctx.MultiInput<framework::Tensor>("MasterParam");
+    auto param_outs = ctx.MultiOutput<framework::Tensor>("ParamOut");
+    auto mom1_outs = ctx.MultiOutput<framework::Tensor>("Moment1Out");
+    auto mom2_outs = ctx.MultiOutput<framework::Tensor>("Moment2Out");
+    auto beta1_pow_outs = ctx.MultiOutput<framework::Tensor>("Beta1PowOut");
+    auto beta2_pow_outs = ctx.MultiOutput<framework::Tensor>("Beta2PowOut");
+
+    // Check validation of inputs and outputs
+    size_t param_num = params.size();
+    PADDLE_ENFORCE_EQ(param_num,
+                      param_outs.size(),
+                      platform::errors::InvalidArgument(
+                          "The size of Output(ParamOut) must be equal to "
+                          "Input(Param), but got the size of Output(ParamOut) "
+                          "is %d, the size of Input(Param) is %d.",
+                          param_outs.size(),
+                          param_num));
+
+    bool skip_update = false;
+    if (ctx.HasInput("SkipUpdate")) {
+      auto* skip_update_tensor = ctx.Input<framework::Tensor>("SkipUpdate");
+      PADDLE_ENFORCE_EQ(skip_update_tensor->numel(),
+                        1,
+                        platform::errors::InvalidArgument(
+                            "Input(SkipUpdate) size must be 1, but get %d",
+                            skip_update_tensor->numel()));
+      std::vector<bool> skip_update_vec;
+      paddle::framework::TensorToVector(
+          *skip_update_tensor, ctx.device_context(), &skip_update_vec);
+      ctx.device_context().Wait();
+      skip_update = skip_update_vec[0];
+    }
+    // skip_update=true, just copy input to output, and TensorCopy will call
+    // mutable_data
+
+    if (skip_update) {
+      VLOG(4) << "MergedAdam skip update";
+      for (size_t i = 0; i < param_num; ++i) {
+        framework::TensorCopy(
+            *params[i],
+            ctx.GetPlace(),
+            ctx.template device_context<platform::MLUDeviceContext>(),
+            param_outs[i]);
+        framework::TensorCopy(
+            *mom1s[i],
+            ctx.GetPlace(),
+            ctx.template device_context<platform::MLUDeviceContext>(),
+            mom1_outs[i]);
+        framework::TensorCopy(
+            *mom2s[i],
+            ctx.GetPlace(),
+            ctx.template device_context<platform::MLUDeviceContext>(),
+            mom2_outs[i]);
+        framework::TensorCopy(
+            *beta1_pows[i],
+            beta1_pows[i]->place(),
+            ctx.template device_context<platform::MLUDeviceContext>(),
+            beta1_pow_outs[i]);
+        framework::TensorCopy(
+            *beta2_pows[i],
+            beta2_pows[i]->place(),
+            ctx.template device_context<platform::MLUDeviceContext>(),
+            beta2_pow_outs[i]);
+      }
+      return;
+    }
+
+    bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
+    VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
+
+    // Get beta1, beta2 and epsilon from attribute.
+    const Tensor* beta1_tensor = nullptr;
+    const Tensor* beta2_tensor = nullptr;
+    const Tensor* epsilon_tensor = nullptr;
+
+    Tensor beta1_tmp(experimental::DataType::FLOAT32);
+    Tensor beta2_tmp(experimental::DataType::FLOAT32);
+    Tensor epsilon_tmp(experimental::DataType::FLOAT32);
+
+    T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
+    T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
+    T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
+    beta1_tmp.mutable_data<T>({1}, ctx.GetPlace());
+    beta2_tmp.mutable_data<T>({1}, ctx.GetPlace());
+    epsilon_tmp.mutable_data<T>({1}, ctx.GetPlace());
+    MLUCnnlTensorDesc beta1_tmp_desc(beta1_tmp);
+    MLUCnnlTensorDesc beta2_tmp_desc(beta2_tmp);
+    MLUCnnlTensorDesc epsilon_tmp_desc(epsilon_tmp);
+    MLUCnnl::Fill(ctx,
+                  CNNL_POINTER_MODE_HOST,
+                  &beta1,
+                  beta1_tmp_desc.get(),
+                  GetBasePtr(&beta1_tmp));
+    MLUCnnl::Fill(ctx,
+                  CNNL_POINTER_MODE_HOST,
+                  &beta2,
+                  beta2_tmp_desc.get(),
+                  GetBasePtr(&beta2_tmp));
+    MLUCnnl::Fill(ctx,
+                  CNNL_POINTER_MODE_HOST,
+                  &epsilon,
+                  epsilon_tmp_desc.get(),
+                  GetBasePtr(&epsilon_tmp));
+    beta1_tensor = &beta1_tmp;
+    beta2_tensor = &beta2_tmp;
+    epsilon_tensor = &epsilon_tmp;
+
+    // Loop to compute
+    for (size_t i = 0; i < param_num; ++i) {
+      VLOG(4) << "[MergedAdam] loop: " << i;
+      param_outs[i]->ShareDataWith(*params[i]);
+      mom1_outs[i]->ShareDataWith(*mom1s[i]);
+      mom2_outs[i]->ShareDataWith(*mom2s[i]);
+
+      LoDTensor beta1_pow_tmp;
+      LoDTensor beta2_pow_tmp;
+      if (beta1_pows[i]->place() == platform::CPUPlace()) {
+        T beta1 = *beta1_pows[i]->data<T>();
+        beta1_pow_tmp.mutable_data<T>({1}, ctx.GetPlace());
+        MLUCnnlTensorDesc beta1_pow_tmp_desc(beta1_pow_tmp);
+        MLUCnnl::Fill(ctx,
+                      CNNL_POINTER_MODE_HOST,
+                      &beta1,
+                      beta1_pow_tmp_desc.get(),
+                      GetBasePtr(&beta1_pow_tmp));
+        beta1_pows[i] = &beta1_pow_tmp;
+      }
+      if (beta2_pows[i]->place() == platform::CPUPlace()) {
+        T beta2 = *beta2_pows[i]->data<T>();
+        beta2_pow_tmp.mutable_data<T>({1}, ctx.GetPlace());
+        MLUCnnlTensorDesc beta2_pow_tmp_desc(beta2_pow_tmp);
+        MLUCnnl::Fill(ctx,
+                      CNNL_POINTER_MODE_HOST,
+                      &beta2,
+                      beta2_pow_tmp_desc.get(),
+                      GetBasePtr(&beta2_pow_tmp));
+        beta2_pows[i] = &beta2_pow_tmp;
+      }
+
+      VLOG(3) << "beta1_pow.numel() : " << beta1_pows[i]->numel()
+              << "beta2_pow.numel() : " << beta2_pows[i]->numel();
+      VLOG(3) << "param.numel(): " << params[i]->numel();
+      PADDLE_ENFORCE_EQ(beta1_pow_outs[i]->numel(),
+                        1,
+                        platform::errors::InvalidArgument(
+                            "beta1 pow output size should be 1, but received "
+                            "value is:%d.",
+                            beta1_pow_outs[i]->numel()));
+
+      PADDLE_ENFORCE_EQ(beta2_pow_outs[i]->numel(),
+                        1,
+                        platform::errors::InvalidArgument(
+                            "beta2 pow output size should be 1, but received "
+                            "value is:%d.",
+                            beta2_pow_outs[i]->numel()));
+      MLUCnnlTensorDesc param_desc(*params[i]);
+      MLUCnnlTensorDesc mom1_desc(*mom1s[i]);
+      MLUCnnlTensorDesc mom2_desc(*mom2s[i]);
+      MLUCnnlTensorDesc grad_desc(*grads[i]);
+      MLUCnnl::ApplyAdam(ctx,
+                         param_desc.get(),
+                         GetBasePtr(param_outs[i]),
+                         mom1_desc.get(),
+                         GetBasePtr(mom1_outs[i]),
+                         mom2_desc.get(),
+                         GetBasePtr(mom2_outs[i]),
+                         grad_desc.get(),
+                         GetBasePtr(grads[i]),
+                         GetBasePtr(lrs[i]),
+                         GetBasePtr(beta1_tensor),
+                         GetBasePtr(beta2_tensor),
+                         GetBasePtr(beta1_pows[i]),
+                         GetBasePtr(beta2_pows[i]),
+                         GetBasePtr(epsilon_tensor),
+                         /*use_nesterov*/ false);
+      if (!use_global_beta_pow) {
+        beta1_pow_outs[i]->mutable_data<T>(ctx.GetPlace());
+        beta2_pow_outs[i]->mutable_data<T>(ctx.GetPlace());
+
+        MLUCnnlTensorDesc beta1_desc(*beta1_tensor);
+        MLUCnnlOpTensorDesc mul_op_desc(
+            CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(), CNNL_NOT_PROPAGATE_NAN);
+
+        MLUCnnl::OpTensor(ctx,
+                          mul_op_desc.get(),
+                          beta1_desc.get(),
+                          GetBasePtr(beta1_pows[i]),
+                          beta1_desc.get(),
+                          GetBasePtr(beta1_tensor),
+                          beta1_desc.get(),
+                          GetBasePtr(beta1_pow_outs[i]),
+                          ToCnnlDataType<T>());
+
+        MLUCnnl::OpTensor(ctx,
+                          mul_op_desc.get(),
+                          beta1_desc.get(),
+                          GetBasePtr(beta2_pows[i]),
+                          beta1_desc.get(),
+                          GetBasePtr(beta2_tensor),
+                          beta1_desc.get(),
+                          GetBasePtr(beta2_pow_outs[i]),
+                          ToCnnlDataType<T>());
+      }
+    }
+  }
+};
 }  // namespace operators
 }  // namespace paddle
@@ -346,3 +564,7 @@ REGISTER_OP_MLU_KERNEL(adam,
 REGISTER_OP_MLU_KERNEL(adamw,
                        ops::AdamWMLUKernel<float>,
                        ops::AdamWMLUKernel<plat::float16>);
+
+REGISTER_OP_MLU_KERNEL(merged_adam,
+                       ops::MergedAdamMLUKernel<float>,
+                       ops::MergedAdamMLUKernel<plat::float16>);
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_merged_adam_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_merged_adam_op_mlu.py
new file mode 100644
index 0000000000000000000000000000000000000000..3aa61e9f982fb8639f49df6af200196184cb41aa
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_merged_adam_op_mlu.py
@@ -0,0 +1,163 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+
+sys.path.append('..')
+import unittest
+import paddle
+import numpy as np
+from paddle import _C_ops, _legacy_C_ops
+from paddle.fluid.framework import in_dygraph_mode
+
+
+def run_adam_op(params,
+                grads,
+                lrs,
+                moment1s,
+                moment2s,
+                beta1_pows,
+                beta2_pows,
+                master_params,
+                epsilon,
+                beta1,
+                beta2,
+                place,
+                multi_precision=False,
+                use_merged=False):
+    assert len(params) == len(grads)
+    assert len(params) == len(lrs)
+    assert len(params) == len(moment1s)
+    assert len(params) == len(moment2s)
+    assert len(params) == len(beta1_pows)
+    assert len(params) == len(beta2_pows)
+    assert len(params) == len(master_params)
+    paddle.disable_static()
+    # paddle.set_device(place)
+
+    param_vars = [paddle.fluid.dygraph.to_variable(p) for p in params]
+    grad_vars = [paddle.fluid.dygraph.to_variable(g) for g in grads]
+    lr_vars = [paddle.fluid.dygraph.to_variable(l) for l in lrs]
+    moment1_vars = [paddle.fluid.dygraph.to_variable(m) for m in moment1s]
+    moment2_vars = [paddle.fluid.dygraph.to_variable(m) for m in moment2s]
+    beta1_pow_vars = [paddle.fluid.dygraph.to_variable(b) for b in beta1_pows]
+    beta2_pow_vars = [paddle.fluid.dygraph.to_variable(b) for b in beta2_pows]
+    master_param_vars = [
+        paddle.fluid.dygraph.to_variable(m_p) for m_p in master_params
+    ]
+
+    if not use_merged:
+        for i in range(len(param_vars)):
+            _, _, _, _, _, _ = _legacy_C_ops.adam(
+                param_vars[i], grad_vars[i], lr_vars[i], moment1_vars[i],
+                moment2_vars[i], beta1_pow_vars[i], beta2_pow_vars[i],
+                master_param_vars[i], param_vars[i], moment1_vars[i],
+                moment2_vars[i], beta1_pow_vars[i], beta2_pow_vars[i],
+                master_param_vars[i], 'epsilon', epsilon, 'beta1', beta1,
+                'beta2', beta2, 'multi_precision', multi_precision)
+    else:
+        if in_dygraph_mode():
+            _, _, _, _, _, _ = _C_ops.merged_adam_(
+                param_vars, grad_vars, lr_vars, moment1_vars, moment2_vars,
+                beta1_pow_vars, beta2_pow_vars, master_param_vars, beta1,
+                beta2, epsilon, multi_precision, False)
+        else:
+            _, _, _, _, _, _ = _legacy_C_ops.merged_adam(
+                param_vars, grad_vars, lr_vars, moment1_vars, moment2_vars,
+                beta1_pow_vars, beta2_pow_vars, master_param_vars, param_vars,
+                moment1_vars, moment2_vars, beta1_pow_vars, beta2_pow_vars,
+                master_param_vars, 'epsilon', epsilon, 'beta1', beta1, 'beta2',
+                beta2, 'multi_precision', multi_precision)
+
+    outputs = {
+        'ParamOut': param_vars,
+        'Moment1Out': moment1_vars,
+        'Moment2Out': moment2_vars,
+        'Beta1PowOut': beta1_pow_vars,
+        'Beta2PowOut': beta2_pow_vars,
+        'MasterParamOut': master_param_vars
+    }
+
+    return outputs
+
+
+class TestMergedAdam(unittest.TestCase):
+
+    def setUp(self):
+        paddle.disable_static()
+        self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]]
+        self.seed = 10
+        self.place = paddle.device.MLUPlace(0)
+        self.__class__.use_mlu = True
+
+    def gen_rand_data(self, shapes, dtype):
+        return [np.random.random(s).astype(dtype) for s in shapes]
+
+    def prepare_data(self, shapes, multi_precision, seed, place):
+        np.random.seed(seed)
+        mp_dtype = np.float32
+        # dtype = np.float16 if multi_precision and place == 'mlu' else np.float32
+        dtype = np.float32
+        params = self.gen_rand_data(shapes, dtype)
+        grads = self.gen_rand_data(shapes, dtype)
+        lrs = self.gen_rand_data([[1], [1], [1], [1]], mp_dtype)
+        moment1s = self.gen_rand_data(shapes, mp_dtype)
+        moment2s = self.gen_rand_data(shapes, mp_dtype)
+        beta1_pows = self.gen_rand_data([[1], [1], [1], [1]], mp_dtype)
+        beta2_pows = self.gen_rand_data([[1], [1], [1], [1]], mp_dtype)
+        master_params = [p.astype(mp_dtype) for p in params]
+        return params, grads, lrs, moment1s, moment2s, beta1_pows, beta2_pows, master_params
+
+    def check_with_place(self, place, multi_precision):
+        params, grads, lrs, moment1s, moment2s, beta1_pows, beta2_pows, master_params = self.prepare_data(
+            self.shapes, multi_precision, self.seed, place)
+
+        def run_op(use_merged):
+            return run_adam_op(params=params,
+                               grads=grads,
+                               lrs=lrs,
+                               moment1s=moment1s,
+                               moment2s=moment2s,
+                               beta1_pows=beta1_pows,
+                               beta2_pows=beta2_pows,
+                               master_params=master_params,
+                               epsilon=0.9,
+                               beta1=0.9,
+                               beta2=0.99,
+                               place=place,
+                               multi_precision=multi_precision,
+                               use_merged=use_merged)
+
+        outs1 = run_op(True)
+        outs2 = run_op(False)
+        self.assertEqual(len(outs1), len(outs2))
+
+        for key in outs1.keys():
+            value1 = outs1[key]
+            value2 = outs2[key]
+            for i in range(len(value1)):
+                if place == 'mlu':
+                    np.testing.assert_array_equal(value1[i], value2[i])
+                else:
+                    np.testing.assert_allclose(value1[i],
+                                               value2[i],
+                                               rtol=1e-05,
+                                               atol=1e-07)
+
+    def test_main(self):
+        for multi_precision in [False, True]:
+            self.check_with_place(self.place, multi_precision)
+
+
+if __name__ == "__main__":
+    unittest.main()
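
Note (not part of the patch): a minimal sketch of how the merged kernel above is typically reached from user-facing code, assuming an MLU build of Paddle. The "mlu" device string and the use_multi_tensor flag on paddle.optimizer.Adam (which dispatches to the merged_adam op in dygraph mode) are assumptions about the public API, not something this diff introduces.

    # Illustrative sketch only -- assumes an MLU build of Paddle and that
    # Adam(use_multi_tensor=True) routes parameter updates through merged_adam.
    import paddle

    paddle.set_device("mlu")  # device string assumed for MLU builds
    linear = paddle.nn.Linear(4, 4)
    opt = paddle.optimizer.Adam(learning_rate=1e-3,
                                parameters=linear.parameters(),
                                use_multi_tensor=True)  # merged update path
    loss = linear(paddle.rand([2, 4])).sum()
    loss.backward()
    opt.step()       # one merged_adam call covering all parameters
    opt.clear_grad()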