Unverified · Commit 66098bff · authored by yuehuayingxueluo · committed by GitHub

Add Fuse Adamw Pass (#50484)

* add fuse adamw pass

* fix some bugs

* fix CI bug

* change chunk_size

* fix CI bug

* rm test_fused_adam_op.py

* fix CI bugs

* fix fuse_adamw_op_pass.cc

* change code style

* fix CI bug

* fix ut bug and fuse_adamw_op_pass.cc

* fix test_fuse_adamw_pass.py

* fix CI bug

* remove fluid

* fix ci bug

* fix CI bug
Parent 5c76b38b
@@ -383,6 +383,7 @@ set(IR_PASS_DEPS
    fix_op_run_order_pass
    fuse_gemm_epilogue_pass
    fused_attention_pass
    fuse_adamw_op_pass
    fused_feedforward_pass
    delete_dropout_op_pass)
...
@@ -189,6 +189,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
#ifdef PADDLE_WITH_CUDA
    AppendPassWithCheck(strategy_.fused_attention_, "fused_attention_pass");
    AppendPassWithCheck(strategy_.fuse_adamw_, "fuse_adamw_op_pass");
#endif
#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060)
@@ -528,6 +529,7 @@ USE_PASS(add_reader_dependency_pass);
USE_PASS(delete_dropout_op_x_pass);
#ifdef PADDLE_WITH_CUDA
USE_PASS(fused_attention_pass);
USE_PASS(fuse_adamw_op_pass);
#endif
#ifdef PADDLE_WITH_CINN
USE_PASS(build_cinn_pass);
...
@@ -131,6 +131,8 @@ struct BuildStrategy {
  bool fuse_gemm_epilogue_{false};
  // Fused multi head attention
  bool fused_attention_{false};
  // Fuse adamw
  bool fuse_adamw_{false};
  // Fused feed forward
  bool fused_feedforward_{false};
...
@@ -264,6 +264,10 @@ cc_library(
  fuse_relu_depthwise_conv_pass
  SRCS fuse_relu_depthwise_conv_pass.cc
  DEPS pass graph_pattern_detector)
cc_library(
  fuse_adamw_op_pass
  SRCS fuse_adamw_op_pass.cc
  DEPS pass graph_pattern_detector)
cc_library(
  fused_feedforward_pass
  SRCS fused_feedforward_pass.cc
...
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 NVIDIA Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_adamw_op_pass.h"
#include <string>
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
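// FuseAdamWPass looks for adamw ops that share the same LearningRate, coeff
// and SkipUpdate inputs (handled separately for each with_decay /
// multi_precision combination), removes them from the graph, and inserts a
// single fused_adam op whose list-valued inputs and outputs are the
// concatenated Param/Grad/Moment/BetaPow nodes of the original ops.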
std::vector<std::string> GetNodeNames(const std::vector<Node *> &node_vector) {
std::vector<std::string> out_vector;
for (auto i : node_vector) {
out_vector.emplace_back(i->Name());
}
return out_vector;
}
Node *GetInputNode(const Node *op, const std::string &name) {
Node *out = nullptr;
for (auto &node : op->inputs) {
if (node->Name() == op->Op()->Input(name)[0]) {
out = node;
break;
}
}
PADDLE_ENFORCE_NOT_NULL(
out, platform::errors::InvalidArgument("Input's name cannot be found."));
return out;
}
Node *GetOutputNode(const Node *op, const std::string &name) {
Node *out = nullptr;
for (auto &node : op->outputs) {
if (node->Name() == op->Op()->Output(name)[0]) {
out = node;
break;
}
}
PADDLE_ENFORCE_NOT_NULL(
out, platform::errors::InvalidArgument("Output's name cannot be found."));
return out;
}
void SaveInOutNodes(std::vector<std::vector<Node *>> *inout_node_vectors,
const AdamWConfig &config,
const Node *op) {
size_t i = 0;
for (auto &name : config.inputs_name) {
(*inout_node_vectors)[i].emplace_back(GetInputNode(op, name));
i++;
}
for (auto &name : config.outputs_name) {
(*inout_node_vectors)[i].emplace_back(GetOutputNode(op, name));
i++;
}
if (config.multi_precision) {
(*inout_node_vectors)[i++].emplace_back(GetInputNode(op, "MasterParam"));
(*inout_node_vectors)[i].emplace_back(GetOutputNode(op, "MasterParamOut"));
}
}
void InsertOpToGraph(const std::vector<std::vector<Node *>> &inout_node_vectors,
const AdamWConfig &config,
ir::Graph *graph) {
float weight_decay = static_cast<float>(0.0);
bool use_adamw = false;
if (config.with_decay) {
weight_decay = config.first_coeff;
use_adamw = true;
}
if (inout_node_vectors[0].size() > 0 && config.replace_adamw) {
OpDesc fuse_adamw_op_desc(config.block);
fuse_adamw_op_desc.SetType("fused_adam");
size_t i = 0;
for (auto &name : config.replace_inputs_name) {
fuse_adamw_op_desc.SetInput(name, GetNodeNames(inout_node_vectors[i]));
i++;
}
fuse_adamw_op_desc.SetInput("LearningRate", {config.first_lr->Name()});
if (config.use_skip_update) {
fuse_adamw_op_desc.SetInput("SkipUpdate",
{config.first_skip_update->Name()});
} else {
fuse_adamw_op_desc.SetInput("SkipUpdate", {});
}
    for (auto &name : config.replace_outputs_name) {
fuse_adamw_op_desc.SetOutput(name, GetNodeNames(inout_node_vectors[i]));
i++;
}
if (config.multi_precision) {
fuse_adamw_op_desc.SetInput("MasterParams",
GetNodeNames(inout_node_vectors[i++]));
fuse_adamw_op_desc.SetOutput("MasterParamsOut",
GetNodeNames(inout_node_vectors[i]));
} else {
fuse_adamw_op_desc.SetInput("MasterParams", {});
}
fuse_adamw_op_desc.SetAttr("beta1", config.beta1);
fuse_adamw_op_desc.SetAttr("beta2", config.beta2);
fuse_adamw_op_desc.SetAttr("op_role", config.op_role);
fuse_adamw_op_desc.SetAttr("epsilon", config.epsilon);
fuse_adamw_op_desc.SetAttr("chunk_size", 16 * 2048);
fuse_adamw_op_desc.SetAttr("weight_decay", weight_decay);
fuse_adamw_op_desc.SetAttr("use_adamw", use_adamw);
fuse_adamw_op_desc.SetAttr("multi_precision", config.multi_precision);
fuse_adamw_op_desc.SetAttr("use_global_beta_pow",
config.use_global_beta_pow);
auto fuse_adamw_node = graph->CreateOpNode(&fuse_adamw_op_desc);
IR_NODE_LINK_TO(config.first_lr, fuse_adamw_node);
if (config.use_skip_update) {
IR_NODE_LINK_TO(config.first_skip_update, fuse_adamw_node);
}
for (size_t k = 0; k < inout_node_vectors[0].size(); k++) {
size_t j = 0;
for (; j < config.replace_inputs_name.size(); j++) {
IR_NODE_LINK_TO(inout_node_vectors[j][k], fuse_adamw_node);
}
      for (; j < config.replace_inputs_name.size() +
                     config.replace_outputs_name.size();
j++) {
IR_NODE_LINK_TO(fuse_adamw_node, inout_node_vectors[j][k]);
}
if (config.multi_precision) {
IR_NODE_LINK_TO(inout_node_vectors[j][k], fuse_adamw_node);
j++;
IR_NODE_LINK_TO(fuse_adamw_node, inout_node_vectors[j][k]);
}
}
}
}
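// Read the attributes of the current adamw op, remember those of the first op
// seen, and check that later ops match them. Returns false when an unsupported
// pattern is found (which disables the fusion); sets *is_continue when the op
// belongs to a different with_decay / multi_precision group and should simply
// be skipped in this round.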
bool InitAndCheckAttrs(const size_t &found_adamw_count,
AdamWConfig *config,
const Node *op,
bool *is_continue) {
const Node *adamw_op = op;
Node *skip_update = nullptr;
Node *learning_rate = GetInputNode(adamw_op, "LearningRate");
auto adamw_op_desc = adamw_op->Op();
// Initialize variables
float coeff = 0.0, lr_ratio = 1.0;
bool lazy_mode = false;
int64_t min_row_size_to_use_multithread = 1000;
  // Get skip_update and coeff; these will be used to check whether we can
  // use fuse_adamw.
for (auto &node : adamw_op->inputs) {
auto in_name = adamw_op_desc->Input("SkipUpdate");
if (!in_name.empty()) {
if (node->Name() == in_name[0]) {
config->use_skip_update = true;
skip_update = node;
break;
}
}
}
coeff = PADDLE_GET_CONST(float, adamw_op_desc->GetAttr("coeff"));
// Get attrs and block
if (found_adamw_count == 0) {
    // Get block
config->block = adamw_op_desc->Block();
// Get attrs
config->beta1 = PADDLE_GET_CONST(float, adamw_op_desc->GetAttr("beta1"));
config->beta2 = PADDLE_GET_CONST(float, adamw_op_desc->GetAttr("beta2"));
config->op_role = PADDLE_GET_CONST(int, adamw_op_desc->GetAttr("op_role"));
config->epsilon =
PADDLE_GET_CONST(float, adamw_op_desc->GetAttr("epsilon"));
config->use_global_beta_pow =
PADDLE_GET_CONST(bool, adamw_op_desc->GetAttr("use_global_beta_pow"));
lazy_mode = PADDLE_GET_CONST(bool, adamw_op_desc->GetAttr("lazy_mode"));
min_row_size_to_use_multithread = PADDLE_GET_CONST(
int64_t, adamw_op_desc->GetAttr("min_row_size_to_use_multithread"));
lr_ratio = PADDLE_GET_CONST(float, adamw_op_desc->GetAttr("lr_ratio"));
config->first_lr = learning_rate;
config->first_coeff = coeff;
if (config->use_skip_update) {
config->first_skip_update = skip_update;
}
// We do not support these patterns
if (lazy_mode != false || lr_ratio != 1.0 ||
min_row_size_to_use_multithread != 1000) {
return false;
}
}
  // Check whether with_decay and multi_precision are matched.
if (config->with_decay !=
PADDLE_GET_CONST(bool, adamw_op_desc->GetAttr("with_decay")) ||
config->multi_precision !=
PADDLE_GET_CONST(bool, adamw_op_desc->GetAttr("multi_precision"))) {
*is_continue = true;
return true;
}
// We do not support these patterns
if ((learning_rate->Name() != config->first_lr->Name()) ||
(coeff != config->first_coeff) ||
(config->use_skip_update &&
skip_update->Name() != config->first_skip_update->Name())) {
return false;
}
return true;
}
void FuseAdamWPass::ApplyImpl(ir::Graph *graph) const {
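  // Run the rewrite once for every (with_decay, multi_precision) combination;
  // each combination is fused into its own fused_adam op.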
graph = FuseAdamWFun(graph, true, true);
graph = FuseAdamWFun(graph, true, false);
graph = FuseAdamWFun(graph, false, true);
graph = FuseAdamWFun(graph, false, false);
}
ir::Graph *FuseAdamWPass::FuseAdamWFun(ir::Graph *graph,
const bool with_decay,
const bool multi_precision) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
VLOG(4) << "handle fuse AdadW";
const std::string scope_name("fuse_adamw");
FusePassBase::Init(scope_name, graph);
size_t found_adamw_count = 0;
AdamWConfig config;
config.with_decay = with_decay;
config.multi_precision = multi_precision;
// Used to store Nodes of input and output for each pattern
std::vector<std::vector<Node *>> inout_node_vectors(13);
std::unordered_set<const Node *> adamw_op_del_set;
for (auto &node : graph->Nodes()) {
if (node->Name() == "adamw") {
const Node *adamw_op = node;
bool is_continue = false;
// Initialize attrs and check attrs to determine whether we support this
// pattern.
if (!InitAndCheckAttrs(found_adamw_count, &config, node, &is_continue)) {
config.replace_adamw = false;
return graph;
}
if (is_continue) {
continue;
}
adamw_op_del_set.insert(adamw_op);
// Save input and output Nodes
SaveInOutNodes(&inout_node_vectors, config, adamw_op);
found_adamw_count++;
}
}
// Remove old op
if (config.replace_adamw && (inout_node_vectors[0].size() > 0)) {
GraphSafeRemoveNodes(graph, adamw_op_del_set);
}
// Insert new op to graph
InsertOpToGraph(inout_node_vectors, config, graph);
VLOG(4) << "replace adamw with fuse_adamw";
AddStatis(found_adamw_count);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_adamw_op_pass, paddle::framework::ir::FuseAdamWPass);
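For orientation, here is a minimal, self-contained Python sketch of the grouping idea behind this pass. It uses plain dictionaries as stand-in "ops" rather than real IR nodes, and it groups mismatching optimizers for illustration, whereas the actual pass simply gives up (replace_adamw = false) when learning rates, coeffs or skip_update inputs differ. All names in the sketch are hypothetical.

from collections import defaultdict

def group_adamw_ops(ops):
    # Group adamw "ops" that agree on the inputs the C++ pass checks, then
    # emit one fused_adam-style record per group with concatenated lists.
    groups = defaultdict(list)
    for op in ops:
        if op["type"] != "adamw":
            continue
        key = (op["with_decay"], op["multi_precision"],
               op["lr"], op["coeff"], op.get("skip_update"))
        groups[key].append(op)
    fused = []
    for (with_decay, multi_precision, lr, coeff, _skip), members in groups.items():
        fused.append({
            "type": "fused_adam",
            "Params": [m["param"] for m in members],
            "Grads": [m["grad"] for m in members],
            "LearningRate": lr,
            "weight_decay": coeff if with_decay else 0.0,
            "use_adamw": with_decay,
            "multi_precision": multi_precision,
            "chunk_size": 16 * 2048,
        })
    return fused

ops = [
    {"type": "adamw", "with_decay": True, "multi_precision": False,
     "lr": "lr_0", "coeff": 0.01, "param": "w0", "grad": "w0@GRAD"},
    {"type": "adamw", "with_decay": True, "multi_precision": False,
     "lr": "lr_0", "coeff": 0.01, "param": "w1", "grad": "w1@GRAD"},
]
print(group_adamw_ops(ops))  # one fused_adam record covering w0 and w1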
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 NVIDIA Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mutex>
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
class Node;
struct AdamWConfig {
Node *first_lr = nullptr;
Node *first_skip_update = nullptr;
paddle::framework::BlockDesc *block = nullptr;
int op_role = 0;
float beta1 = 0.9;
float beta2 = 0.99;
float epsilon = 1e-8;
float first_coeff = 0.0;
bool use_global_beta_pow = false;
bool replace_adamw = true;
bool use_skip_update = false;
bool with_decay = true;
bool multi_precision = true;
  // Initialize the input and output names of the adamw op and the fused_adam op
const std::vector<std::string> inputs_name = {
"Param", "Grad", "Moment1", "Moment2", "Beta1Pow", "Beta2Pow"};
const std::vector<std::string> outputs_name = {
"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"};
const std::vector<std::string> replace_inputs_name = {
"Params", "Grads", "Moments1", "Moments2", "Beta1Pows", "Beta2Pows"};
  const std::vector<std::string> replace_outputs_name = {"ParamsOut",
                                                         "Moments1Out",
                                                         "Moments2Out",
                                                         "Beta1PowsOut",
                                                         "Beta2PowsOut"};
};
class FuseAdamWPass : public FusePassBase {
public:
virtual ~FuseAdamWPass() {}
protected:
void ApplyImpl(ir::Graph *graph) const override;
ir::Graph *FuseAdamWFun(ir::Graph *graph,
const bool with_decay,
const bool multi_precision) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -697,6 +697,28 @@ void BindParallelExecutor(pybind11::module &m) {  // NOLINT
                    build_strategy = static.BuildStrategy()
                    build_strategy.fuse_gemm_epilogue = True
          )DOC")
.def_property(
"fuse_adamw",
[](const BuildStrategy &self) { return self.fuse_adamw_; },
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE_NE(self.IsFinalized(),
true,
platform::errors::PreconditionNotMet(
"BuildStrategy has been finlaized, cannot be "
"configured again."));
self.fuse_adamw_ = b;
},
R"DOC((bool, optional): fuse_adamw indicate whether
to fuse all adamw optimizers with multi_tensor_adam,
it may make the execution faster. Default is False.
Examples:
.. code-block:: python
import paddle
import paddle.static as static
paddle.enable_static()
build_strategy = static.BuildStrategy()
build_strategy.fuse_adamw = True
)DOC")
      .def_property(
          "fused_attention",
          [](const BuildStrategy &self) { return self.fused_attention_; },
...
@@ -15,6 +15,7 @@
#pragma once
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/tensor_utils.h"
namespace phi {
@@ -82,12 +83,14 @@ void LaunchMultiTensorApplyKernel(
      0,
      errors::InvalidArgument(
          "input_vector[0].size() is not > 0, please check params."));
-  auto place = input_vector[0][0]->place();
+  auto ctx_place = dev_ctx.GetPlace();
  PADDLE_ENFORCE_EQ(
-      place,
-      GPUPlace(),
-      errors::InvalidArgument(
-          "expected input to be on gpu, but input is on cpu now."));
+      ctx_place.GetType() == AllocationType::GPU,
+      true,
+      errors::PreconditionNotMet(
+          "Context place error, expected GPUPlace, but actually %s.",
+          ctx_place));
+  auto place = input_vector[0][0]->place();
  for (size_t i = 0; i < input_vector.size(); i++) {
    PADDLE_ENFORCE_EQ(
        input_vector[i].size(),
...
@@ -110,6 +110,19 @@ class FuseGemmEpiloguePass(CPPPassWrapper):
        return PassType.FUSION_OPT
@register_pass("fuse_adamw")
class FuseAdamWPass(CPPPassWrapper):
def __init__(self):
super().__init__()
@property
def cpp_name(self):
return "fuse_adamw_op_pass"
def _type(self):
return PassType.FUSION_OPT
@register_pass("fuse_optimizer") @register_pass("fuse_optimizer")
class FuseOptimizerPass(CPPPassWrapper): class FuseOptimizerPass(CPPPassWrapper):
def __init__(self): def __init__(self):
......
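The new "fuse_adamw" wrapper registered above is driven through the usual PassManager API; the unit test added later in this commit does exactly this (main_prog and startup_prog are the static Programs being compiled):

    from paddle.distributed.passes import PassManager, new_pass

    pass_manager = PassManager([new_pass("fuse_adamw")])
    pass_manager.apply([main_prog], [startup_prog])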
@@ -255,6 +255,7 @@ PassBase._PASS_PROCESS_ORDER_LIST = [
    "fused_attention",
    "fused_feedforward",
    "fuse_gemm_epilogue",
    "fuse_adamw",
    "fuse_optimizer",
]
...
@@ -80,6 +80,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
  set_tests_properties(test_align_tool PROPERTIES TIMEOUT 20)
  py_test_modules(test_pass_base_list MODULES test_pass_base_list)
  set_tests_properties(test_pass_base_list PROPERTIES TIMEOUT 20)
  py_test_modules(test_fuse_adamw_pass MODULES test_fuse_adamw_pass)
  set_tests_properties(test_fuse_adamw_pass PROPERTIES TIMEOUT 20)
  # End of unittests WITH single card and timeout
  # NOTE(zyl): unittests WITH single card and WITHOUT timeout
...
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle import nn
from paddle.distributed.passes import PassManager, new_pass
def apply_passes(main_prog, startup_prog):
pass_manager = PassManager([new_pass("fuse_adamw")])
pass_manager.apply([main_prog], [startup_prog])
class MLPLayer(nn.Layer):
def __init__(self, input_size, hidden_size, output_size, n):
super().__init__()
self.linear_first = nn.Linear(input_size, hidden_size)
self.decoder_layers = nn.LayerList()
for i in range(n):
self.decoder_layers.append(nn.Linear(hidden_size, hidden_size))
self.linear_last = nn.Linear(hidden_size, output_size)
def forward(self, x):
x = self.linear_first(x)
for layer in self.decoder_layers:
x = layer(x)
x = self.linear_last(x)
return x.mean()
class TestFuseAdamWPass(unittest.TestCase):
def setUp(self):
paddle.disable_static()
np.random.seed(10)
self.input_size = 30
self.hidden_size = 50
self.output_size = 20
self.n = 2
self.range_num = 5
def get_input_x(self, use_amp):
x = []
for _ in range(self.range_num):
if use_amp:
x.append(
np.random.random(size=(10, self.input_size)).astype(
'float16'
)
)
else:
x.append(
np.random.random(size=(10, self.input_size)).astype(
'float32'
)
)
return x
def get_loss_data(self, place, x, use_amp=False, use_apply_passes=False):
paddle.enable_static()
paddle.seed(10)
if place == 'cpu':
use_amp = False
exe = paddle.static.Executor(place=place)
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
optimizer = paddle.optimizer.AdamW(multi_precision=use_amp)
if use_amp:
optimizer = paddle.static.amp.decorate(
optimizer,
init_loss_scaling=128.0,
use_dynamic_loss_scaling=True,
use_pure_fp16=True,
use_fp16_guard=False,
)
with paddle.static.program_guard(train_program, startup_program):
if use_amp:
data = paddle.static.data(
shape=[10, self.input_size], name='X', dtype='float16'
)
else:
data = paddle.static.data(
shape=[10, self.input_size], name='X', dtype='float32'
)
model = MLPLayer(
self.input_size, self.hidden_size, self.output_size, self.n
)
out = model(data)
loss = paddle.mean(out)
optimizer.minimize(loss)
if use_apply_passes:
apply_passes(train_program, startup_program)
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place=place, scope=paddle.static.global_scope())
for i in range(5):
loss_data = exe.run(
train_program, feed={"X": x[i]}, fetch_list=[loss.name]
)
return loss_data
def test_fuse_adamw_pass(self):
place = paddle.CUDAPlace(0)
for use_amp in [True, False]:
x = self.get_input_x(use_amp)
            loss_with_passes = self.get_loss_data(place, x, use_amp, True)
            loss_without_passes = self.get_loss_data(place, x, use_amp, False)
np.testing.assert_allclose(
np.array(loss_without_passes),
np.array(loss_with_passes),
rtol=1e-6,
atol=1e-6,
)
if __name__ == "__main__":
unittest.main()