diff --git a/paddle/fluid/operators/optimizers/multi_tensor_adam_op.cc b/paddle/fluid/operators/fused/fused_adam_op.cc
similarity index 84%
rename from paddle/fluid/operators/optimizers/multi_tensor_adam_op.cc
rename to paddle/fluid/operators/fused/fused_adam_op.cc
index 55a84dc0dbc8afd4161dd9aa920b6c4c5da55d66..d786dbd7c2728f5259971a3827954bec2a843000 100644
--- a/paddle/fluid/operators/optimizers/multi_tensor_adam_op.cc
+++ b/paddle/fluid/operators/fused/fused_adam_op.cc
@@ -22,19 +22,34 @@ namespace operators {
 
 using Tensor = phi::DenseTensor;
 
-class MultiTensorAdamOp : public framework::OperatorWithKernel {
+class FusedAdamOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
+      const framework::ExecutionContext &ctx) const override {
     auto param_dtype =
         framework::OperatorWithKernel::IndicateVarDataType(ctx, "Params");
     return phi::KernelKey(param_dtype, ctx.GetPlace());
   }
+
+  phi::KernelKey GetKernelTypeForVar(
+      const std::string &var_name,
+      const phi::DenseTensor &tensor,
+      const phi::KernelKey &expected_kernel_type) const override {
+    if (var_name == "Beta1Pows" || var_name == "Beta2Pows" ||
+        var_name == "SkipUpdate") {
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
+    } else {
+      return phi::KernelKey(
+          tensor.place(), tensor.layout(), expected_kernel_type.dtype());
+    }
+  }
 };
 
-class MultiTensorAdamOpMaker : public framework::OpProtoAndCheckerMaker {
+class FusedAdamOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("Params", "(Tensor) Input parameters").AsDuplicable();
@@ -144,13 +159,13 @@ $$
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-DECLARE_INFER_SHAPE_FUNCTOR(multi_tensor_adam,
-                            MultiTensorAdamInferShapeFunctor,
-                            PD_INFER_META(phi::MultiTensorAdamInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(fused_adam,
+                            FusedAdamInferShapeFunctor,
+                            PD_INFER_META(phi::FusedAdamInferMeta));
 REGISTER_OPERATOR(
-    multi_tensor_adam,
-    ops::MultiTensorAdamOp,
-    ops::MultiTensorAdamOpMaker,
+    fused_adam,
+    ops::FusedAdamOp,
+    ops::FusedAdamOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
-    MultiTensorAdamInferShapeFunctor);
+    FusedAdamInferShapeFunctor);
diff --git a/paddle/fluid/pybind/eager_generator.h b/paddle/fluid/pybind/eager_generator.h
index 1893373527d797d909e54ad82839cd4e4ddf5f9a..8101d506555a9d8fe1c3efbf3282194dbf57fdca 100644
--- a/paddle/fluid/pybind/eager_generator.h
+++ b/paddle/fluid/pybind/eager_generator.h
@@ -166,7 +166,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
       "Beta1Pow",
       "Beta2Pow",
       "MasterParam"}},
-    {"multi_tensor_adam",
+    {"fused_adam",
     {"Params",
      "Grads",
      "LearningRate",
@@ -332,7 +332,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
      "Beta1PowOut",
      "Beta2PowOut",
      "MasterParamOut"}},
-    {"multi_tensor_adam",
+    {"fused_adam",
     {"ParamsOut",
      "Moments1Out",
      "Moments2Out",
@@ -400,7 +400,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
      "Beta1PowOut",
      "Beta2PowOut",
      "MasterParamOut"}},
-    {"multi_tensor_adam",
+    {"fused_adam",
     {"ParamsOut",
      "Moments1Out",
      "Moments2Out",
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index 11f7449f4507c99420d7496330134305d53bc40d..e6b7124f79c5ecdce4edcac61690bc684aedff8a 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -781,6 +781,17 @@
   data_transform :
     skip_transform : x
 
+- op : fused_adam_
+  args : (Tensor[] params, Tensor[] grads, Tensor learning_rate, Tensor[] moments1, Tensor[] moments2, Tensor[] beta1_pows, Tensor[] beta2_pows, Tensor[] master_params, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, int chunk_size, float weight_decay, bool use_adamw, bool multi_precision, bool use_global_beta_pow)
+  output : Tensor[](params_out){params.size()}, Tensor[](moments1_out){params.size()}, Tensor[](moments2_out){params.size()}, Tensor[](beta1_pows_out){params.size()}, Tensor[](beta2_pows_out){params.size()}, Tensor[](master_params_out){params.size()}
+  infer_meta :
+    func : FusedAdamInferMeta
+  kernel :
+    func : fused_adam
+    data_type : params
+  optional : skip_update, master_params
+  inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out)
+
 - op : gather
   args : (Tensor x, Tensor index, Scalar(int) axis=0)
   output : Tensor(out)
@@ -1237,17 +1248,6 @@
   optional : master_param
   inplace : (param -> param_out), (velocity -> velocity_out), (master_param -> master_param_out)
 
-- op : multi_tensor_adam_
-  args : (Tensor[] params, Tensor[] grads, Tensor learning_rate, Tensor[] moments1, Tensor[] moments2, Tensor[] beta1_pows, Tensor[] beta2_pows, Tensor[] master_params, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, int chunk_size, float weight_decay, bool use_adamw, bool multi_precision, bool use_global_beta_pow)
-  output : Tensor[](params_out){params.size()}, Tensor[](moments1_out){params.size()}, Tensor[](moments2_out){params.size()}, Tensor[](beta1_pows_out){params.size()}, Tensor[](beta2_pows_out){params.size()}, Tensor[](master_params_out){params.size()}
-  infer_meta :
-    func : MultiTensorAdamInferMeta
-  kernel :
-    func : multi_tensor_adam
-    data_type : params
-  optional : skip_update, master_params
-  inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out)
-
 - op : multiclass_nms3
   args : (Tensor bboxes, Tensor scores, Tensor rois_num, float score_threshold, int nms_top_k, int keep_top_k, float nms_threshold=0.3, bool normalized=true, float nms_eta=1.0, int background_label=0)
   output : Tensor(out), Tensor(index), Tensor(nms_rois_num)
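
For reference, once this yaml entry is in place the renamed op can be exercised directly from dygraph. The sketch below is illustrative only and not part of the patch: it assumes the usual eager binding that code generation produces for a legacy_ops.yaml entry, i.e. a `paddle._C_ops.fused_adam_` function whose positional arguments follow the `args :` line above, with optional tensors passed as None.

    import paddle
    from paddle import _C_ops

    paddle.disable_static()
    n = 4
    params = [paddle.randn([8, 8]) for _ in range(n)]
    grads = [paddle.randn([8, 8]) for _ in range(n)]
    moments1 = [paddle.zeros([8, 8]) for _ in range(n)]
    moments2 = [paddle.zeros([8, 8]) for _ in range(n)]
    beta1_pows = [paddle.to_tensor([0.9]) for _ in range(n)]
    beta2_pows = [paddle.to_tensor([0.99]) for _ in range(n)]
    lr = paddle.to_tensor([1e-3])

    # Positional arguments mirror the yaml args line:
    # params, grads, learning_rate, moments1, moments2, beta1_pows, beta2_pows,
    # master_params, skip_update, beta1, beta2, epsilon, chunk_size,
    # weight_decay, use_adamw, multi_precision, use_global_beta_pow
    _C_ops.fused_adam_(params, grads, lr, moments1, moments2,
                       beta1_pows, beta2_pows, None, None,
                       0.9, 0.99, 1e-8, 32 * 2048,
                       0.0, False, False, False)
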
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index 8b168de57507a120042b7440bcb4c60409f45eef..670f2b6cc15173b918b30a79fd88cdef38621754 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -2982,7 +2982,7 @@ void YoloLossInferMeta(const MetaTensor& x,
   gt_match_mask->set_dtype(x.dtype());
 }
 
-void MultiTensorAdamInferMeta(
+void FusedAdamInferMeta(
     const std::vector<const MetaTensor*>& params,
     const std::vector<const MetaTensor*>& grads,
     const MetaTensor& learning_rate,
diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h
index 87ed5529046a638a195cf776d7975a907e72f6d1..d6f997eb3781676fa181d3fb54b3fe32401c9f7d 100644
--- a/paddle/phi/infermeta/multiary.h
+++ b/paddle/phi/infermeta/multiary.h
@@ -533,7 +533,7 @@ void YoloLossInferMeta(const MetaTensor& x,
                        MetaTensor* objectness_mask,
                        MetaTensor* gt_match_mask);
 
-void MultiTensorAdamInferMeta(
+void FusedAdamInferMeta(
     const std::vector<const MetaTensor*>& params,
     const std::vector<const MetaTensor*>& grads,
     const MetaTensor& learning_rate,
diff --git a/paddle/phi/kernels/cpu/multi_tensor_adam_kernel.cc b/paddle/phi/kernels/cpu/fused_adam_kernel.cc
similarity index 95%
rename from paddle/phi/kernels/cpu/multi_tensor_adam_kernel.cc
rename to paddle/phi/kernels/cpu/fused_adam_kernel.cc
index 1ca5641a2406a98008cced3f1d2b7cd31ecce660..9d71f2469423e5fecc732cdfefffb740ce3bc268 100644
--- a/paddle/phi/kernels/cpu/multi_tensor_adam_kernel.cc
+++ b/paddle/phi/kernels/cpu/fused_adam_kernel.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/phi/kernels/multi_tensor_adam_kernel.h"
+#include "paddle/phi/kernels/fused_adam_kernel.h"
 
 #include
 #include "paddle/phi/core/kernel_registry.h"
@@ -29,7 +29,7 @@ static paddle::optional TensorPtrToOptionalTensor(
 }
 
 template <typename T, typename Context>
-void MultiTensorAdamKernel(
+void FusedAdamKernel(
     const Context& dev_ctx,
     const std::vector<const DenseTensor*>& params,
     const std::vector<const DenseTensor*>& grads,
@@ -157,9 +157,5 @@ void MultiTensorAdamKernel(
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(multi_tensor_adam,
-                   CPU,
-                   ALL_LAYOUT,
-                   phi::MultiTensorAdamKernel,
-                   float,
-                   double) {}
+PD_REGISTER_KERNEL(
+    fused_adam, CPU, ALL_LAYOUT, phi::FusedAdamKernel, float, double) {}
diff --git a/paddle/phi/kernels/multi_tensor_adam_kernel.h b/paddle/phi/kernels/fused_adam_kernel.h
similarity index 98%
rename from paddle/phi/kernels/multi_tensor_adam_kernel.h
rename to paddle/phi/kernels/fused_adam_kernel.h
index 5bc6399a7c10e6eeceb1a9ffe1a372f3e320955a..b44c7250d148ffef89ae5cb46449b9e9e3b9c6cb 100644
--- a/paddle/phi/kernels/multi_tensor_adam_kernel.h
+++ b/paddle/phi/kernels/fused_adam_kernel.h
@@ -20,7 +20,7 @@ namespace phi {
 
 template <typename T, typename Context>
-void MultiTensorAdamKernel(
+void FusedAdamKernel(
     const Context &dev_ctx,
     const std::vector<const DenseTensor*> &params,
     const std::vector<const DenseTensor*> &grads,
diff --git a/paddle/phi/kernels/gpu/multi_tensor_adam_kernel.cu b/paddle/phi/kernels/gpu/fused_adam_kernel.cu
similarity index 94%
rename from paddle/phi/kernels/gpu/multi_tensor_adam_kernel.cu
rename to paddle/phi/kernels/gpu/fused_adam_kernel.cu
index 176b453596e3f2677bcd9dc23bf0bdf12cd243f2..644e2085039c5506f448e594535a9f2e6ef7af55 100644
--- a/paddle/phi/kernels/gpu/multi_tensor_adam_kernel.cu
+++ b/paddle/phi/kernels/gpu/fused_adam_kernel.cu
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/phi/kernels/multi_tensor_adam_kernel.h" +#include "paddle/phi/kernels/fused_adam_kernel.h" #include #include "glog/logging.h" #include "paddle/phi/backends/gpu/gpu_context.h" @@ -31,9 +31,9 @@ namespace phi { // https://github.com/NVIDIA/apex template -struct MultiTensorAdamBetaPowInfo { +struct FusedAdamBetaPowInfo { using MPDType = typename phi::dtype::MPTypeTrait::Type; - MultiTensorAdamBetaPowInfo(const MPDType* beta1pow, const MPDType* beta2pow) { + FusedAdamBetaPowInfo(const MPDType* beta1pow, const MPDType* beta2pow) { beta1pow_ = *beta1pow; beta2pow_ = *beta2pow; } @@ -48,9 +48,9 @@ struct MultiTensorAdamBetaPowInfo { }; template -struct MultiTensorAdamBetaPowInfo { +struct FusedAdamBetaPowInfo { using MPDType = typename phi::dtype::MPTypeTrait::Type; - MultiTensorAdamBetaPowInfo(const MPDType* beta1pow, const MPDType* beta2pow) { + FusedAdamBetaPowInfo(const MPDType* beta1pow, const MPDType* beta2pow) { beta1pow_ = beta1pow; beta2pow_ = beta2pow; } @@ -73,13 +73,13 @@ template -struct MultiTensorAdamFunctor { +struct FusedAdamFunctor { __device__ __forceinline__ void operator()( int chunk_size, const funcs::TensorAndBlockInfo& t_info, MT beta1, MT beta2, - MultiTensorAdamBetaPowInfo beta_pow, + FusedAdamBetaPowInfo beta_pow, MT epsilon, const MT* learning_rate, MT decay) const { @@ -261,7 +261,7 @@ static int GetVecSizeFromTensors(const std::vector& tensors, } template -void MultiTensorAdamKernel( +void FusedAdamKernel( const Context& dev_ctx, const std::vector& params, const std::vector& grads, @@ -365,17 +365,17 @@ void MultiTensorAdamKernel( constexpr int kMaxTensorSize = __multi_precision ? 48 : 60; \ constexpr int kMaxBlockSize = __multi_precision ? 320 : 320; \ constexpr int kBlockSize = 512; \ - MultiTensorAdamBetaPowInfo beta_pow_info( \ + FusedAdamBetaPowInfo beta_pow_info( \ beta1_pow_first->data(), beta2_pow_first->data()); \ - MultiTensorAdamFunctor \ + FusedAdamFunctor \ functor; \ funcs::LaunchMultiTensorApplyKernel in_names = {"Params", "Grads", "LearningRate", "Moments1", "Moments2", - "Beta1Pow", - "Beta2Pow", + "Beta1Pows", + "Beta2Pows", "MasterParams", "SkipUpdate"}; paddle::small_vector out_names = {"ParamsOut", "Moments1Out", "Moments2Out", - "Beta1PowOut", - "Beta2PowOut", + "Beta1PowsOut", + "Beta2PowsOut", "MasterParamsOut"}; paddle::small_vector attr_names = {"beta1", "beta2", @@ -44,7 +43,7 @@ KernelSignature MultiTensorAdamOpArgumentMapping( "multi_precision", "use_global_beta_pow"}; - return KernelSignature("multi_tensor_adam", + return KernelSignature("fused_adam", std::move(in_names), std::move(attr_names), std::move(out_names)); @@ -52,5 +51,4 @@ KernelSignature MultiTensorAdamOpArgumentMapping( } // namespace phi -PD_REGISTER_ARG_MAPPING_FN(multi_tensor_adam, - phi::MultiTensorAdamOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(fused_adam, phi::FusedAdamOpArgumentMapping); diff --git a/paddle/phi/tests/kernels/CMakeLists.txt b/paddle/phi/tests/kernels/CMakeLists.txt index 61be79303b094322c387bc265e774d297c9dfc4c..d3a50f97f22b07198eaa0b1a388fa6745eb80f0a 100644 --- a/paddle/phi/tests/kernels/CMakeLists.txt +++ b/paddle/phi/tests/kernels/CMakeLists.txt @@ -77,8 +77,8 @@ if(WITH_GPU) SRCS test_auto_tune.cu DEPS gtest) cc_test( - test_multi_tensor_adam_kernel - SRCS test_multi_tensor_adam_kernel.cc + test_fused_adam_kernel + SRCS test_fused_adam_kernel.cc DEPS gtest phi phi_api_utils) elseif(WITH_ROCM) hip_test( diff --git a/paddle/phi/tests/kernels/test_multi_tensor_adam_kernel.cc b/paddle/phi/tests/kernels/test_fused_adam_kernel.cc 
diff --git a/paddle/phi/tests/kernels/test_multi_tensor_adam_kernel.cc b/paddle/phi/tests/kernels/test_fused_adam_kernel.cc
similarity index 84%
rename from paddle/phi/tests/kernels/test_multi_tensor_adam_kernel.cc
rename to paddle/phi/tests/kernels/test_fused_adam_kernel.cc
index 68f4d8e5c14ce3c4d5ee6ebe80b945a830661485..43a29d4ae53d8fa11f96c2ac40d6f8da0957f353 100644
--- a/paddle/phi/tests/kernels/test_multi_tensor_adam_kernel.cc
+++ b/paddle/phi/tests/kernels/test_fused_adam_kernel.cc
@@ -32,8 +32,8 @@
 #include "paddle/phi/kernels/cast_kernel.h"
 #include "paddle/phi/kernels/elementwise_subtract_kernel.h"
 #include "paddle/phi/kernels/full_kernel.h"
+#include "paddle/phi/kernels/fused_adam_kernel.h"
 #include "paddle/phi/kernels/gaussian_kernel.h"
-#include "paddle/phi/kernels/multi_tensor_adam_kernel.h"
 #include "paddle/phi/kernels/reduce_max_kernel.h"
 
 namespace phi {
@@ -179,9 +179,9 @@ struct AdamInfo {
         GenerateConstantTensorVectors(*ctx, one_shapes, beta2);
   }
 
-  void Update(bool use_multi_tensor, const std::vector &grads) {
-    if (use_multi_tensor) {
-      UpdateWithMultiTensorAdam(grads);
+  void Update(bool use_fused, const std::vector &grads) {
+    if (use_fused) {
+      UpdateWithFusedAdam(grads);
     } else {
       for (size_t j = 0; j < params.size(); ++j) {
         if (use_adamw) {
@@ -226,7 +226,7 @@ struct AdamInfo {
   }
 
  private:
-  void UpdateWithMultiTensorAdam(const std::vector &grads) {
+  void UpdateWithFusedAdam(const std::vector &grads) {
     auto param_metas = ToMetaTensorVector(params);
     auto grad_metas = ToMetaTensorVector(grads);
     auto master_param_metas = ToMetaTensorVector(master_params);
@@ -235,34 +235,34 @@ struct AdamInfo {
     auto beta1_pow_metas = ToMetaTensorVector(beta1_pows);
     auto beta2_pow_metas = ToMetaTensorVector(beta2_pows);
 
-    MultiTensorAdamInferMeta(
-        ToConstMetaTensorPtrVector(param_metas),
-        ToConstMetaTensorPtrVector(grad_metas),
-        learning_rate,
-        ToConstMetaTensorPtrVector(moment1_metas),
-        ToConstMetaTensorPtrVector(moment2_metas),
-        ToConstMetaTensorPtrVector(beta1_pow_metas),
-        ToConstMetaTensorPtrVector(beta2_pow_metas),
-        multi_precision ? paddle::make_optional(
-                              ToConstMetaTensorPtrVector(master_param_metas))
-                        : paddle::none,
-        MetaTensor(),
-        beta1,
-        beta2,
-        epsilon,
-        chunk_size,
-        weight_decay,
-        use_adamw,
-        multi_precision,
-        false,
-        ToMutableMetaTensorPtrVector(param_metas),
-        ToMutableMetaTensorPtrVector(moment1_metas),
-        ToMutableMetaTensorPtrVector(moment2_metas),
-        ToMutableMetaTensorPtrVector(beta1_pow_metas),
-        ToMutableMetaTensorPtrVector(beta2_pow_metas),
-        ToMutableMetaTensorPtrVector(master_param_metas));
-
-    MultiTensorAdamKernel(
+    FusedAdamInferMeta(ToConstMetaTensorPtrVector(param_metas),
+                       ToConstMetaTensorPtrVector(grad_metas),
+                       learning_rate,
+                       ToConstMetaTensorPtrVector(moment1_metas),
+                       ToConstMetaTensorPtrVector(moment2_metas),
+                       ToConstMetaTensorPtrVector(beta1_pow_metas),
+                       ToConstMetaTensorPtrVector(beta2_pow_metas),
+                       multi_precision
+                           ? paddle::make_optional(
+                                 ToConstMetaTensorPtrVector(master_param_metas))
+                           : paddle::none,
+                       MetaTensor(),
+                       beta1,
+                       beta2,
+                       epsilon,
+                       chunk_size,
+                       weight_decay,
+                       use_adamw,
+                       multi_precision,
+                       false,
+                       ToMutableMetaTensorPtrVector(param_metas),
+                       ToMutableMetaTensorPtrVector(moment1_metas),
+                       ToMutableMetaTensorPtrVector(moment2_metas),
+                       ToMutableMetaTensorPtrVector(beta1_pow_metas),
+                       ToMutableMetaTensorPtrVector(beta2_pow_metas),
+                       ToMutableMetaTensorPtrVector(master_param_metas));
+
+    FusedAdamKernel(
         *ctx,
         ToConstTensorPtrVector(params),
         ToConstTensorPtrVector(grads),
@@ -395,15 +395,15 @@ auto MaxDiff(const Context &ctx,
 }
 
 template
-void TestMultiTensorAdamBase(const std::vector<std::vector<int64_t>> &shapes,
-                             float atol,
-                             bool use_adamw,
-                             bool multi_precision = false,
-                             float beta1 = 0.9,
-                             float beta2 = 0.99,
-                             float weight_decay = 0.1,
-                             size_t steps = 5,
-                             uint64_t seed = 10) {
+void TestFusedAdamBase(const std::vector<std::vector<int64_t>> &shapes,
+                       float atol,
+                       bool use_adamw,
+                       bool multi_precision = false,
+                       float beta1 = 0.9,
+                       float beta2 = 0.99,
+                       float weight_decay = 0.1,
+                       size_t steps = 5,
+                       uint64_t seed = 10) {
   const auto &ctx =
       *paddle::platform::DeviceContextPool::Instance().GetByPlace(PlaceType());
   using Context = typename std::remove_const<
@@ -448,29 +448,28 @@ static auto GenerateRandomShapes(size_t n, uint64_t low, uint64_t high) {
   return shapes;
 }
 
-TEST(multi_tensor_adam, test_fp32_cpu) {
+TEST(fused_adam, test_fp32_cpu) {
   auto shapes = GenerateRandomShapes(30, 10, 20);
   float atol = 0.0f;
   for (auto use_adamw : {false, true}) {
-    TestMultiTensorAdamBase(shapes, atol, use_adamw);
+    TestFusedAdamBase(shapes, atol, use_adamw);
   }
 }
 
 #ifdef PADDLE_WITH_CUDA
-TEST(multi_tensor_adam, test_fp32_gpu) {
+TEST(fused_adam, test_fp32_gpu) {
   auto shapes = GenerateRandomShapes(40, 0, 2 << 18);
   float atol = 0.0f;
   for (auto use_adamw : {false, true}) {
-    TestMultiTensorAdamBase(shapes, atol, use_adamw);
+    TestFusedAdamBase(shapes, atol, use_adamw);
   }
 }
 
-TEST(multi_tensor_adam, test_fp16_gpu) {
+TEST(fused_adam, test_fp16_gpu) {
   auto shapes = GenerateRandomShapes(40, 0, 2 << 18);
   float atol = 5e-3f;
   for (auto use_adamw : {false, true}) {
-    TestMultiTensorAdamBase(
-        shapes, atol, use_adamw, true);
+    TestFusedAdamBase(shapes, atol, use_adamw, true);
   }
 }
 #endif
diff --git a/python/paddle/fluid/tests/unittests/test_fused_adam_op.py b/python/paddle/fluid/tests/unittests/test_fused_adam_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b184e666cd3fabd8185b9ee38c2deadf7b3c0a4
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_fused_adam_op.py
@@ -0,0 +1,180 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from op_test import OpTest
+
+import paddle
+
+
+def fused_adam_step(inputs, attributes, num):
+    '''
+    Simulate one step of the fused_adam optimizer
+    :param inputs: dict of inputs
+    :param attributes: dict of attributes
+    :return tuple: tuple of output params, moments1, moments2, beta1_pows, beta2_pows
+    '''
+    params = inputs['Params']
+    grads = inputs['Grads']
+    moments1 = inputs['Moments1']
+    moments2 = inputs['Moments2']
+    lr = inputs['LearningRate']
+    beta1_pows = inputs['Beta1Pows']
+    beta2_pows = inputs['Beta2Pows']
+
+    params_out = []
+    moments1_out = []
+    moments2_out = []
+    beta1_pows_out = []
+    beta2_pows_out = []
+
+    epsilon = attributes['epsilon']
+
+    if 'beta1' in attributes:
+        beta1 = attributes['beta1']
+    else:
+        beta1 = inputs['Beta1Tensor'][0][0]
+    if 'beta2' in attributes:
+        beta2 = attributes['beta2']
+    else:
+        beta2 = inputs['Beta2Tensor'][0][0]
+
+    for i in range(num):
+        moments1_out.append(beta1 * moments1[i][1] + (1 - beta1) * grads[i][1])
+        moments2_out.append(
+            beta2 * moments2[i][1] + (1 - beta2) * np.square(grads[i][1])
+        )
+        lr_t = lr * np.sqrt(1 - beta2_pows[i][1]) / (1 - beta1_pows[i][1])
+        params_out.append(
+            params[i][1]
+            - lr_t * (moments1_out[i] / (np.sqrt(moments2_out[i]) + epsilon))
+        )
+
+    for i in range(num):
+        beta1_pows_out.append(
+            np.array([beta1_pows[i][1]]).astype("float32") * beta1
+        )
+        beta2_pows_out.append(
+            np.array([beta2_pows[i][1]]).astype("float32") * beta2
+        )
+
+    return (
+        params_out,
+        moments1_out,
+        moments2_out,
+        beta1_pows_out,
+        beta2_pows_out,
+    )
+
+
+class TestFusedAdamOp(OpTest):
+    def setUp(self):
+
+        paddle.enable_static()
+
+        '''Test FusedAdam Op with supplied attributes'''
+        self.__class__.op_type = "fused_adam"
+
+        num = 10
+        # Use independent inner lists so Params/Grads/Moments/BetaPows do not
+        # alias the same list object.
+        inputs_list = [[0] * num for _ in range(6)]
+        learning_rate = 0.004
+        beta1 = 0.78
+        beta2 = 0.836
+        epsilon = 1e-4
+        beta1_pow = beta1**10
+        beta2_pow = beta2**10
+
+        self.attrs = {
+            'epsilon': epsilon,
+            'beta1': beta1,
+            'beta2': beta2,
+            "chunk_size": 32 * 2048,
+        }
+
+        for i in range(num):
+
+            inputs_list[0][i] = np.random.uniform(-1, 1, (102, 105)).astype(
+                "float32"
+            )
+            inputs_list[1][i] = np.random.uniform(-1, 1, (102, 105)).astype(
+                "float32"
+            )
+            inputs_list[2][i] = np.random.uniform(-1, 1, (102, 105)).astype(
+                "float32"
+            )
+            inputs_list[3][i] = np.random.random((102, 105)).astype("float32")
+            inputs_list[4][i] = np.array([beta1_pow]).astype("float32")
+            inputs_list[5][i] = np.array([beta2_pow]).astype("float32")
+
+        self.inputs = {
+            'Params': [
+                ("params" + str(i), inputs_list[0][i]) for i in range(num)
+            ],
+            'Grads': [
+                ("grads" + str(i), inputs_list[1][i]) for i in range(num)
+            ],
+            'Moments1': [
+                ("moments1" + str(i), inputs_list[2][i]) for i in range(num)
+            ],
+            'Moments2': [
+                ("moments2" + str(i), inputs_list[3][i]) for i in range(num)
+            ],
+            'LearningRate': np.array([learning_rate]).astype("float32"),
+            'Beta1Pows': [
+                ("beta1_pows" + str(i), inputs_list[4][i]) for i in range(num)
+            ],
+            'Beta2Pows': [
+                ("beta2_pows" + str(i), inputs_list[5][i]) for i in range(num)
+            ],
+        }
+
+        (
+            params_out,
+            moments1_out,
+            moments2_out,
+            beta1_pows_out,
+            beta2_pows_out,
+        ) = fused_adam_step(self.inputs, self.attrs, num)
+
+        self.outputs = {
+            'Moments1Out': [
+                ("moments1_out" + str(i), moments1_out[i]) for i in range(num)
+            ],
+            'Moments2Out': [
+                ("moments2_out" + str(i), moments2_out[i]) for i in range(num)
+            ],
+            'ParamsOut': [
+                ("params_out" + str(i), params_out[i]) for i in range(num)
+            ],
+            'Beta1PowsOut': [
+                ("beta1_pows_out" + str(i), beta1_pows_out[i])
+                for i in range(num)
+            ],
+            'Beta2PowsOut': [
+                ("beta2_pows_out" + str(i), beta2_pows_out[i])
+                for i in range(num)
+            ],
+        }
+
+    def test_check_output(self):
+        paddle.enable_static()
+        if paddle.is_compiled_with_cuda():
+            self.check_output()
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
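
Note that fused_adam_step above only models the plain Adam path, while the C++ test (test_fused_adam_kernel.cc) also drives the op with use_adamw enabled and a weight_decay value. The following is a minimal numpy sketch of that variant, assuming the usual decoupled (AdamW-style) weight decay applied to the raw parameters before the Adam update; it reuses fused_adam_step and is illustrative rather than a statement of the kernel's exact implementation.

    def fused_adamw_step(inputs, attributes, num, weight_decay):
        # Decoupled weight decay: shrink each raw parameter by lr * weight_decay,
        # then apply the plain fused Adam step defined above to the decayed values.
        lr = inputs['LearningRate']
        decayed = dict(inputs)
        decayed['Params'] = [
            (name, p * (1.0 - lr * weight_decay)) for name, p in inputs['Params']
        ]
        return fused_adam_step(decayed, attributes, num)
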