diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index ef89dbb3ffe6e680473713376595a8959f52586f..b13cb45bf988f1116d0557d5202525e0f99e6b5c 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -383,6 +383,7 @@ set(IR_PASS_DEPS
     fix_op_run_order_pass
     fuse_gemm_epilogue_pass
     fused_attention_pass
+    fuse_adamw_op_pass
     fused_feedforward_pass
     delete_dropout_op_pass)
diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index 395da4b0a092861cdfc63422672c38898fea082d..91024a9dbe317491605a0af62b50f0c65169326e 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -189,6 +189,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
 #ifdef PADDLE_WITH_CUDA
     AppendPassWithCheck(strategy_.fused_attention_, "fused_attention_pass");
+    AppendPassWithCheck(strategy_.fuse_adamw_, "fuse_adamw_op_pass");
 #endif
 
 #if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060)
@@ -528,6 +529,7 @@ USE_PASS(add_reader_dependency_pass);
 USE_PASS(delete_dropout_op_x_pass);
 #ifdef PADDLE_WITH_CUDA
 USE_PASS(fused_attention_pass);
+USE_PASS(fuse_adamw_op_pass);
 #endif
 #ifdef PADDLE_WITH_CINN
 USE_PASS(build_cinn_pass);
diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h
index 1cd15746ddd3230a926a6f39728cb8765d2e368d..be836c380ed7cf5bac6c767a85f280d86abd7631 100644
--- a/paddle/fluid/framework/details/build_strategy.h
+++ b/paddle/fluid/framework/details/build_strategy.h
@@ -131,6 +131,8 @@ struct BuildStrategy {
   bool fuse_gemm_epilogue_{false};
   // Fused multi head attention
   bool fused_attention_{false};
+  // Fuse adamw
+  bool fuse_adamw_{false};
   // Fused feed forward
   bool fused_feedforward_{false};
 
diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index b619bef90102793bd6152ca8fd50dce4fe49c6c9..be115c4d8e77cbf966e711fa8d6b64521d13dd2b 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -264,6 +264,10 @@ cc_library(
   fuse_relu_depthwise_conv_pass
   SRCS fuse_relu_depthwise_conv_pass.cc
   DEPS pass graph_pattern_detector)
+cc_library(
+  fuse_adamw_op_pass
+  SRCS fuse_adamw_op_pass.cc
+  DEPS pass graph_pattern_detector)
 cc_library(
   fused_feedforward_pass
   SRCS fused_feedforward_pass.cc
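The BuildStrategy flag wired in above is only appended to the pass pipeline under PADDLE_WITH_CUDA. A minimal sketch of how it is expected to be toggled from Python, assuming the pybind binding added later in this diff (the surrounding executor setup is omitted):

    import paddle
    import paddle.static as static

    paddle.enable_static()

    build_strategy = static.BuildStrategy()
    # Maps to BuildStrategy::fuse_adamw_ and appends fuse_adamw_op_pass on CUDA builds.
    build_strategy.fuse_adamw = True
    # build_strategy is then handed to the usual compiled-program/executor machinery.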
+#include "paddle/fluid/framework/ir/fuse_adamw_op_pass.h" +#include +#include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +std::vector GetNodeNames(const std::vector &node_vector) { + std::vector out_vector; + for (auto i : node_vector) { + out_vector.emplace_back(i->Name()); + } + return out_vector; +} + +Node *GetInputNode(const Node *op, const std::string &name) { + Node *out = nullptr; + + for (auto &node : op->inputs) { + if (node->Name() == op->Op()->Input(name)[0]) { + out = node; + break; + } + } + + PADDLE_ENFORCE_NOT_NULL( + out, platform::errors::InvalidArgument("Input's name cannot be found.")); + + return out; +} + +Node *GetOutputNode(const Node *op, const std::string &name) { + Node *out = nullptr; + + for (auto &node : op->outputs) { + if (node->Name() == op->Op()->Output(name)[0]) { + out = node; + break; + } + } + + PADDLE_ENFORCE_NOT_NULL( + out, platform::errors::InvalidArgument("Output's name cannot be found.")); + + return out; +} + +void SaveInOutNodes(std::vector> *inout_node_vectors, + const AdamWConfig &config, + const Node *op) { + size_t i = 0; + + for (auto &name : config.inputs_name) { + (*inout_node_vectors)[i].emplace_back(GetInputNode(op, name)); + i++; + } + for (auto &name : config.outputs_name) { + (*inout_node_vectors)[i].emplace_back(GetOutputNode(op, name)); + i++; + } + if (config.multi_precision) { + (*inout_node_vectors)[i++].emplace_back(GetInputNode(op, "MasterParam")); + (*inout_node_vectors)[i].emplace_back(GetOutputNode(op, "MasterParamOut")); + } +} + +void InsertOpToGraph(const std::vector> &inout_node_vectors, + const AdamWConfig &config, + ir::Graph *graph) { + float weight_decay = static_cast(0.0); + bool use_adamw = false; + + if (config.with_decay) { + weight_decay = config.first_coeff; + use_adamw = true; + } + if (inout_node_vectors[0].size() > 0 && config.replace_adamw) { + OpDesc fuse_adamw_op_desc(config.block); + fuse_adamw_op_desc.SetType("fused_adam"); + + size_t i = 0; + + for (auto &name : config.replace_inputs_name) { + fuse_adamw_op_desc.SetInput(name, GetNodeNames(inout_node_vectors[i])); + i++; + } + + fuse_adamw_op_desc.SetInput("LearningRate", {config.first_lr->Name()}); + if (config.use_skip_update) { + fuse_adamw_op_desc.SetInput("SkipUpdate", + {config.first_skip_update->Name()}); + } else { + fuse_adamw_op_desc.SetInput("SkipUpdate", {}); + } + + for (auto &name : config.repalce_outputs_name) { + fuse_adamw_op_desc.SetOutput(name, GetNodeNames(inout_node_vectors[i])); + i++; + } + + if (config.multi_precision) { + fuse_adamw_op_desc.SetInput("MasterParams", + GetNodeNames(inout_node_vectors[i++])); + fuse_adamw_op_desc.SetOutput("MasterParamsOut", + GetNodeNames(inout_node_vectors[i])); + } else { + fuse_adamw_op_desc.SetInput("MasterParams", {}); + } + + fuse_adamw_op_desc.SetAttr("beta1", config.beta1); + fuse_adamw_op_desc.SetAttr("beta2", config.beta2); + fuse_adamw_op_desc.SetAttr("op_role", config.op_role); + fuse_adamw_op_desc.SetAttr("epsilon", config.epsilon); + fuse_adamw_op_desc.SetAttr("chunk_size", 16 * 2048); + fuse_adamw_op_desc.SetAttr("weight_decay", weight_decay); + fuse_adamw_op_desc.SetAttr("use_adamw", use_adamw); + fuse_adamw_op_desc.SetAttr("multi_precision", config.multi_precision); + fuse_adamw_op_desc.SetAttr("use_global_beta_pow", + config.use_global_beta_pow); + + auto fuse_adamw_node = 
graph->CreateOpNode(&fuse_adamw_op_desc); + + IR_NODE_LINK_TO(config.first_lr, fuse_adamw_node); + if (config.use_skip_update) { + IR_NODE_LINK_TO(config.first_skip_update, fuse_adamw_node); + } + + for (size_t k = 0; k < inout_node_vectors[0].size(); k++) { + size_t j = 0; + + for (; j < config.replace_inputs_name.size(); j++) { + IR_NODE_LINK_TO(inout_node_vectors[j][k], fuse_adamw_node); + } + for (; j < config.replace_inputs_name.size() + + config.repalce_outputs_name.size(); + j++) { + IR_NODE_LINK_TO(fuse_adamw_node, inout_node_vectors[j][k]); + } + if (config.multi_precision) { + IR_NODE_LINK_TO(inout_node_vectors[j][k], fuse_adamw_node); + j++; + IR_NODE_LINK_TO(fuse_adamw_node, inout_node_vectors[j][k]); + } + } + } +} + +bool InitAndCheckAttrs(const size_t &found_adamw_count, + AdamWConfig *config, + const Node *op, + bool *is_continue) { + const Node *adamw_op = op; + Node *skip_update = nullptr; + Node *learning_rate = GetInputNode(adamw_op, "LearningRate"); + auto adamw_op_desc = adamw_op->Op(); + + // Initialize variables + float coeff = 0.0, lr_ratio = 1.0; + bool lazy_mode = false; + int64_t min_row_size_to_use_multithread = 1000; + + // Get skip_update and coeff, these wiil be used to check whether we can + // use fuse_adamw. + for (auto &node : adamw_op->inputs) { + auto in_name = adamw_op_desc->Input("SkipUpdate"); + if (!in_name.empty()) { + if (node->Name() == in_name[0]) { + config->use_skip_update = true; + skip_update = node; + break; + } + } + } + coeff = PADDLE_GET_CONST(float, adamw_op_desc->GetAttr("coeff")); + + // Get attrs and block + if (found_adamw_count == 0) { + // Get blokc + config->block = adamw_op_desc->Block(); + // Get attrs + config->beta1 = PADDLE_GET_CONST(float, adamw_op_desc->GetAttr("beta1")); + config->beta2 = PADDLE_GET_CONST(float, adamw_op_desc->GetAttr("beta2")); + config->op_role = PADDLE_GET_CONST(int, adamw_op_desc->GetAttr("op_role")); + config->epsilon = + PADDLE_GET_CONST(float, adamw_op_desc->GetAttr("epsilon")); + config->use_global_beta_pow = + PADDLE_GET_CONST(bool, adamw_op_desc->GetAttr("use_global_beta_pow")); + + lazy_mode = PADDLE_GET_CONST(bool, adamw_op_desc->GetAttr("lazy_mode")); + min_row_size_to_use_multithread = PADDLE_GET_CONST( + int64_t, adamw_op_desc->GetAttr("min_row_size_to_use_multithread")); + lr_ratio = PADDLE_GET_CONST(float, adamw_op_desc->GetAttr("lr_ratio")); + + config->first_lr = learning_rate; + config->first_coeff = coeff; + if (config->use_skip_update) { + config->first_skip_update = skip_update; + } + + // We do not support these patterns + if (lazy_mode != false || lr_ratio != 1.0 || + min_row_size_to_use_multithread != 1000) { + return false; + } + } + + // Check whether with_decay and multi_precision are matched。 + if (config->with_decay != + PADDLE_GET_CONST(bool, adamw_op_desc->GetAttr("with_decay")) || + config->multi_precision != + PADDLE_GET_CONST(bool, adamw_op_desc->GetAttr("multi_precision"))) { + *is_continue = true; + return true; + } + + // We do not support these patterns + if ((learning_rate->Name() != config->first_lr->Name()) || + (coeff != config->first_coeff) || + (config->use_skip_update && + skip_update->Name() != config->first_skip_update->Name())) { + return false; + } + + return true; +} + +void FuseAdamWPass::ApplyImpl(ir::Graph *graph) const { + graph = FuseAdamWFun(graph, true, true); + graph = FuseAdamWFun(graph, true, false); + graph = FuseAdamWFun(graph, false, true); + graph = FuseAdamWFun(graph, false, false); +} + +ir::Graph *FuseAdamWPass::FuseAdamWFun(ir::Graph 
*graph, + const bool with_decay, + const bool multi_precision) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + + VLOG(4) << "handle fuse AdadW"; + + const std::string scope_name("fuse_adamw"); + FusePassBase::Init(scope_name, graph); + + size_t found_adamw_count = 0; + + AdamWConfig config; + + config.with_decay = with_decay; + config.multi_precision = multi_precision; + + // Used to store Nodes of input and output for each pattern + std::vector> inout_node_vectors(13); + + std::unordered_set adamw_op_del_set; + + for (auto &node : graph->Nodes()) { + if (node->Name() == "adamw") { + const Node *adamw_op = node; + bool is_continue = false; + + // Initialize attrs and check attrs to determine whether we support this + // pattern. + if (!InitAndCheckAttrs(found_adamw_count, &config, node, &is_continue)) { + config.replace_adamw = false; + return graph; + } + + if (is_continue) { + continue; + } + + adamw_op_del_set.insert(adamw_op); + + // Save input and output Nodes + SaveInOutNodes(&inout_node_vectors, config, adamw_op); + + found_adamw_count++; + } + } + + // Remove old op + if (config.replace_adamw && (inout_node_vectors[0].size() > 0)) { + GraphSafeRemoveNodes(graph, adamw_op_del_set); + } + + // Insert new op to graph + InsertOpToGraph(inout_node_vectors, config, graph); + + VLOG(4) << "replace adamw with fuse_adamw"; + + AddStatis(found_adamw_count); + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(fuse_adamw_op_pass, paddle::framework::ir::FuseAdamWPass); diff --git a/paddle/fluid/framework/ir/fuse_adamw_op_pass.h b/paddle/fluid/framework/ir/fuse_adamw_op_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..0fa8907e687508dfa5b7de988ac010fbbd3c9e09 --- /dev/null +++ b/paddle/fluid/framework/ir/fuse_adamw_op_pass.h @@ -0,0 +1,76 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2023 NVIDIA Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
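To make the effect of the pass concrete, the sketch below applies it through the Python pass registry added further down in this diff and inspects the resulting program. The op types ("adamw", "fused_adam") come from the pass above; the tiny network and optimizer setup are illustrative only:

    import paddle
    from paddle.distributed.passes import PassManager, new_pass

    paddle.enable_static()
    main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name="x", shape=[8, 16], dtype="float32")
        loss = paddle.mean(paddle.static.nn.fc(x, size=4))
        paddle.optimizer.AdamW().minimize(loss)

    # One adamw op per parameter before the pass ...
    print(sum(op.type == "adamw" for op in main_prog.global_block().ops))
    PassManager([new_pass("fuse_adamw")]).apply([main_prog], [startup_prog])
    # ... and a single fused_adam op per with_decay/multi_precision pattern afterwards.
    print(sum(op.type == "fused_adam" for op in main_prog.global_block().ops))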
diff --git a/paddle/fluid/framework/ir/fuse_adamw_op_pass.h b/paddle/fluid/framework/ir/fuse_adamw_op_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..0fa8907e687508dfa5b7de988ac010fbbd3c9e09
--- /dev/null
+++ b/paddle/fluid/framework/ir/fuse_adamw_op_pass.h
@@ -0,0 +1,76 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 NVIDIA Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class Graph;
+class Node;
+
+struct AdamWConfig {
+  Node *first_lr = nullptr;
+  Node *first_skip_update = nullptr;
+  paddle::framework::BlockDesc *block = nullptr;
+  int op_role = 0;
+  float beta1 = 0.9;
+  float beta2 = 0.99;
+  float epsilon = 1e-8;
+  float first_coeff = 0.0;
+  bool use_global_beta_pow = false;
+  bool replace_adamw = true;
+  bool use_skip_update = false;
+  bool with_decay = true;
+  bool multi_precision = true;
+
+  // Input and output names of the adamw op and the fused_adam op
+  const std::vector<std::string> inputs_name = {
+      "Param", "Grad", "Moment1", "Moment2", "Beta1Pow", "Beta2Pow"};
+  const std::vector<std::string> outputs_name = {
+      "ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"};
+  const std::vector<std::string> replace_inputs_name = {
+      "Params", "Grads", "Moments1", "Moments2", "Beta1Pows", "Beta2Pows"};
+  const std::vector<std::string> replace_outputs_name = {"ParamsOut",
+                                                         "Moments1Out",
+                                                         "Moments2Out",
+                                                         "Beta1PowsOut",
+                                                         "Beta2PowsOut"};
+};
+
+class FuseAdamWPass : public FusePassBase {
+ public:
+  virtual ~FuseAdamWPass() {}
+
+ protected:
+  void ApplyImpl(ir::Graph *graph) const override;
+
+  ir::Graph *FuseAdamWFun(ir::Graph *graph,
+                          const bool with_decay,
+                          const bool multi_precision) const;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/pybind/parallel_executor.cc b/paddle/fluid/pybind/parallel_executor.cc
index d1938238a99824cc33ca6be44f65bc1a11740429..0a5fd7eb16a764e20a26fd9a1d9ebfa40f65914a 100644
--- a/paddle/fluid/pybind/parallel_executor.cc
+++ b/paddle/fluid/pybind/parallel_executor.cc
@@ -697,6 +697,28 @@ void BindParallelExecutor(pybind11::module &m) {  // NOLINT
                         build_strategy = static.BuildStrategy()
                         build_strategy.fuse_gemm_epilogue = True
                 )DOC")
+      .def_property(
+          "fuse_adamw",
+          [](const BuildStrategy &self) { return self.fuse_adamw_; },
+          [](BuildStrategy &self, bool b) {
+            PADDLE_ENFORCE_NE(self.IsFinalized(),
+                              true,
+                              platform::errors::PreconditionNotMet(
+                                  "BuildStrategy has been finalized, cannot be "
+                                  "configured again."));
+            self.fuse_adamw_ = b;
+          },
+          R"DOC((bool, optional): fuse_adamw indicates whether to fuse all
+                adamw optimizer ops with multi_tensor_adam, which may make
+                execution faster. Default is False.
+                Examples:
+                    .. code-block:: python
+
+                        import paddle
+                        import paddle.static as static
+
+                        paddle.enable_static()
+
+                        build_strategy = static.BuildStrategy()
+                        build_strategy.fuse_adamw = True
+                )DOC")
       .def_property(
           "fused_attention",
           [](const BuildStrategy &self) { return self.fused_attention_; },
diff --git a/paddle/phi/kernels/funcs/multi_tensor_apply.h b/paddle/phi/kernels/funcs/multi_tensor_apply.h
index 5be64dcab2ef107d49253ac4953a7480b7115b14..6811793c02dcb2e45f839d1223fff60eb18f8e80 100644
--- a/paddle/phi/kernels/funcs/multi_tensor_apply.h
+++ b/paddle/phi/kernels/funcs/multi_tensor_apply.h
@@ -15,6 +15,7 @@
 #pragma once
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/place.h"
 #include "paddle/phi/core/tensor_utils.h"
 
 namespace phi {
@@ -82,12 +83,14 @@ void LaunchMultiTensorApplyKernel(
       0,
       errors::InvalidArgument(
           "input_vector[0].size() is not > 0, please cheack params."));
-  auto place = input_vector[0][0]->place();
+  auto ctx_place = dev_ctx.GetPlace();
   PADDLE_ENFORCE_EQ(
-      place,
-      GPUPlace(),
-      errors::InvalidArgument(
-          "expected input to be on gpu, but input is on cpu now."));
+      ctx_place.GetType() == AllocationType::GPU,
+      true,
+      errors::PreconditionNotMet(
+          "Context place error, expected GPUPlace, but actually %s.",
+          ctx_place));
+  auto place = input_vector[0][0]->place();
   for (size_t i = 0; i < input_vector.size(); i++) {
     PADDLE_ENFORCE_EQ(
         input_vector[i].size(),
diff --git a/python/paddle/distributed/passes/cpp_pass.py b/python/paddle/distributed/passes/cpp_pass.py
index c0da3050463b4bbe8432d852157f0c44c4aba959..9d936eb5848fc967b2b04ba0c4061572073eed15 100755
--- a/python/paddle/distributed/passes/cpp_pass.py
+++ b/python/paddle/distributed/passes/cpp_pass.py
@@ -110,6 +110,19 @@ class FuseGemmEpiloguePass(CPPPassWrapper):
         return PassType.FUSION_OPT
 
 
+@register_pass("fuse_adamw")
+class FuseAdamWPass(CPPPassWrapper):
+    def __init__(self):
+        super().__init__()
+
+    @property
+    def cpp_name(self):
+        return "fuse_adamw_op_pass"
+
+    def _type(self):
+        return PassType.FUSION_OPT
+
+
 @register_pass("fuse_optimizer")
 class FuseOptimizerPass(CPPPassWrapper):
     def __init__(self):
diff --git a/python/paddle/distributed/passes/pass_base.py b/python/paddle/distributed/passes/pass_base.py
index fca239b41dcc9626049c39d5c20035480d128e80..beed64da7b11eb5eccd45464d23a927168d6abf3 100755
--- a/python/paddle/distributed/passes/pass_base.py
+++ b/python/paddle/distributed/passes/pass_base.py
@@ -255,6 +255,7 @@ PassBase._PASS_PROCESS_ORDER_LIST = [
     "fused_attention",
     "fused_feedforward",
     "fuse_gemm_epilogue",
+    "fuse_adamw",
     "fuse_optimizer",
 ]
 
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
index d64aaf35b4d465be41523eb38f9caa2baa15ce92..7db4d58bd8b917660a627b184a8421144bb61344 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
@@ -80,6 +80,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   set_tests_properties(test_align_tool PROPERTIES TIMEOUT 20)
   py_test_modules(test_pass_base_list MODULES test_pass_base_list)
   set_tests_properties(test_pass_base_list PROPERTIES TIMEOUT 20)
+  py_test_modules(test_fuse_adamw_pass MODULES test_fuse_adamw_pass)
+  set_tests_properties(test_fuse_adamw_pass PROPERTIES TIMEOUT 20)
   # End of unittests WITH single card and timeout
 
   # NOTE(zyl): unittests WITH single card and WITHOUT timeout
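For completeness, the registration above also makes the pass reachable through the generic distributed-pass machinery, which is the route the unit test below exercises. A minimal sketch (program construction omitted):

    from paddle.distributed.passes import PassManager, new_pass

    # "fuse_adamw" resolves to the C++ pass registered as fuse_adamw_op_pass.
    pass_manager = PassManager([new_pass("fuse_adamw")])
    pass_manager.apply([main_program], [startup_program])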
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_fuse_adamw_pass.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_fuse_adamw_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..5860c30121ce210bde2b7870a33965fe0927f1d1
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_fuse_adamw_pass.py
@@ -0,0 +1,138 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle import nn
+from paddle.distributed.passes import PassManager, new_pass
+
+
+def apply_passes(main_prog, startup_prog):
+    pass_manager = PassManager([new_pass("fuse_adamw")])
+    pass_manager.apply([main_prog], [startup_prog])
+
+
+class MLPLayer(nn.Layer):
+    def __init__(self, input_size, hidden_size, output_size, n):
+        super().__init__()
+        self.linear_first = nn.Linear(input_size, hidden_size)
+        self.decoder_layers = nn.LayerList()
+        for i in range(n):
+            self.decoder_layers.append(nn.Linear(hidden_size, hidden_size))
+
+        self.linear_last = nn.Linear(hidden_size, output_size)
+
+    def forward(self, x):
+        x = self.linear_first(x)
+        for layer in self.decoder_layers:
+            x = layer(x)
+        x = self.linear_last(x)
+        return x.mean()
+
+
+class TestFuseAdamWPass(unittest.TestCase):
+    def setUp(self):
+        paddle.disable_static()
+        np.random.seed(10)
+        self.input_size = 30
+        self.hidden_size = 50
+        self.output_size = 20
+        self.n = 2
+        self.range_num = 5
+
+    def get_input_x(self, use_amp):
+        x = []
+        for _ in range(self.range_num):
+            if use_amp:
+                x.append(
+                    np.random.random(size=(10, self.input_size)).astype(
+                        'float16'
+                    )
+                )
+            else:
+                x.append(
+                    np.random.random(size=(10, self.input_size)).astype(
+                        'float32'
+                    )
+                )
+
+        return x
+
+    def get_loss_data(self, place, x, use_amp=False, use_apply_passes=False):
+        paddle.enable_static()
+        paddle.seed(10)
+
+        if place == 'cpu':
+            use_amp = False
+
+        exe = paddle.static.Executor(place=place)
+        train_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        optimizer = paddle.optimizer.AdamW(multi_precision=use_amp)
+        if use_amp:
+            optimizer = paddle.static.amp.decorate(
+                optimizer,
+                init_loss_scaling=128.0,
+                use_dynamic_loss_scaling=True,
+                use_pure_fp16=True,
+                use_fp16_guard=False,
+            )
+        with paddle.static.program_guard(train_program, startup_program):
+            if use_amp:
+                data = paddle.static.data(
+                    shape=[10, self.input_size], name='X', dtype='float16'
+                )
+            else:
+                data = paddle.static.data(
+                    shape=[10, self.input_size], name='X', dtype='float32'
+                )
+            model = MLPLayer(
+                self.input_size, self.hidden_size, self.output_size, self.n
+            )
+            out = model(data)
+            loss = paddle.mean(out)
+            optimizer.minimize(loss)
+
+        if use_apply_passes:
+            apply_passes(train_program, startup_program)
+
+        exe.run(startup_program)
+        if use_amp:
+            optimizer.amp_init(place=place, scope=paddle.static.global_scope())
+
+        for i in range(5):
+            loss_data = exe.run(
+                train_program, feed={"X": x[i]}, fetch_list=[loss.name]
+            )
+        return loss_data
+
+    def test_fuse_adamw_pass(self):
+        place = paddle.CUDAPlace(0)
+        for use_amp in [True, False]:
+            x = self.get_input_x(use_amp)
+            loss_with_passes = self.get_loss_data(place, x, use_amp, True)
+            loss_without_passes = self.get_loss_data(place, x, use_amp, False)
+            np.testing.assert_allclose(
+                np.array(loss_without_passes),
+                np.array(loss_with_passes),
+                rtol=1e-6,
+                atol=1e-6,
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()