Unverified commit 145cdb5a authored by Zeng Jinle, committed by GitHub

Add basic functions of Program Pass (#34524)

* add basic APIs

* add attr_types

* follow comments

* change pass attr types

* add set pass attribute codes

* refine PADDLE_THROW
Parent af886995
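For orientation, a minimal sketch of the Python-side workflow this change enables, assuming a static-graph program has already been built; the helper and the pass name come from the `_apply_pass` function and the unit test added further down in this diff:

import paddle
from paddle.fluid.framework import _apply_pass

paddle.enable_static()
# ... build a static network, define a loss, call optimizer.minimize(loss) ...
main = paddle.static.default_main_program()
startup = paddle.static.default_startup_program()

# Apply a registered C++ IR pass to both programs in place; the returned
# dict carries the pass attributes read back after the pass has run.
attrs = _apply_pass(main, startup, "fuse_elewise_add_act_pass")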
......@@ -77,10 +77,6 @@ typedef std::vector<std::pair<std::string, std::string>> ParamsAndGrads;
constexpr char kParamsAndDenseGrads[] = "params_and_dense_grads";
constexpr char kParamsAndSparseGrads[] = "params_and_sparse_grads";
typedef std::vector<ProgramDesc> ProgramDescs;
constexpr char kProgramDescs[] = "program_descs";
constexpr char kStartupProgramDescs[] = "startup_program_descs";
typedef std::unordered_set<std::string> PinnedVars;
constexpr char kPinnedVars[] = "pinned_vars";
......
......@@ -15,7 +15,9 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_helper.h"
#include <queue>
#include <stack>
#include "paddle/fluid/framework/op_proto_maker.h"
DECLARE_bool(convert_all_blocks);
DEFINE_string(print_sub_graph_dir, "",
"FLAGS_print_sub_graph_dir is used "
"to print the nodes of sub_graphs.");
......@@ -431,6 +433,117 @@ std::vector<ir::Node *> TopologySortGraphByDescOrder(const Graph &graph) {
return ret;
}
static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) {
desc->SetType("fill_constant");
desc->SetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName(),
(static_cast<int>(OpRole::kBackward) | static_cast<int>(OpRole::kLoss)));
desc->SetAttr("value", 1.0f);
std::vector<std::string> output_names;
for (auto out : node.outputs) {
output_names.emplace_back(out->Name());
}
desc->SetOutput("Out", output_names);
return desc;
}
static void GetGraphOpDesc(const std::vector<Node *> &nodes,
std::vector<OpDesc> *ops) {
for (Node *n : nodes) {
// if node is not Op, skip
if (!n->IsOp()) continue;
// create fill_constant op
if (n->Name() == "scale_loss_grad") {
ops->emplace_back();
auto &desc = ops->back();
ReplaceScaleLossGradOp(*n, &desc);
} else if (n->Op()) {
ops->emplace_back(*n->Op());
}
// nodes without an OpDesc are dropped
}
}
static void GraphToBlock(const Graph &graph, proto::BlockDesc *block,
const SortKind *sort_kind) {
// Remove the unneeded variables after memory optimization.
std::unordered_set<std::string> vars2remove;
if (graph.Has(kGraphToProgramVarsToRemove)) {
vars2remove =
graph.Get<std::unordered_set<std::string>>(kGraphToProgramVarsToRemove);
VLOG(2) << "graph (id: " << block->idx() << ") to program remove "
<< vars2remove.size() << " nodes";
}
block->clear_vars();
std::unordered_set<std::string> visited_vars;
for (Node *n : graph.Nodes()) {
if (n->IsVar()) {
if (n->Var() && visited_vars.count(n->Var()->Name()) == 0 &&
!vars2remove.count(n->Var()->Name()) &&
n->GetVarNodeBlockId() == graph.GetBlockId()) {
visited_vars.insert(n->Var()->Name());
block->add_vars()->MergeFrom(*n->Var()->Proto());
}
}
}
block->clear_ops();
std::vector<Node *> nodes;
if (sort_kind != nullptr) {
// Inference memory optimization relies on this branch.
nodes = TopologyVarientSort(graph, *sort_kind);
} else {
if (FLAGS_convert_all_blocks) {
nodes = TopologySortGraphByDescOrder(graph);
} else {
nodes = TopologySortOperations(graph);
}
}
std::vector<OpDesc> ops;
GetGraphOpDesc(nodes, &ops);
for (auto &op : ops) {
block->add_ops()->MergeFrom(*op.Proto());
}
}
void GraphToProgram(const Graph &graph, ProgramDesc *program,
const SortKind *sort_kind) {
PADDLE_ENFORCE_EQ(graph.IsMainGraph(), true,
platform::errors::InvalidArgument(
"This graph is a sub_graph, "
"and can't convert to program individually"));
PADDLE_ENFORCE_NOT_NULL(
program,
platform::errors::InvalidArgument(
"program must not be nullptr when converting graph to program"));
proto::ProgramDesc program_pb(*(program->Proto()));
auto block = program_pb.mutable_blocks(kRootBlockIndex);
block->set_idx(kRootBlockIndex);
if (FLAGS_convert_all_blocks) {
GraphToBlock(*graph.GetSubGraph(kRootBlockIndex), block, sort_kind);
VLOG(3) << "Graph to program need convert " << graph.SubGraphsSize()
<< " sub graph";
for (size_t idx = 0; idx < graph.SubGraphsSize(); ++idx) {
// skip the root block here in case kRootBlockIndex is not 0
if (idx == kRootBlockIndex) continue;
block = program_pb.add_blocks();
block->set_idx(idx);
GraphToBlock(*graph.GetSubGraph(idx), block, sort_kind);
}
} else {
GraphToBlock(graph, block, sort_kind);
}
program->CopyFrom(program_pb);
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -27,6 +27,10 @@ namespace paddle {
namespace framework {
namespace ir {
constexpr char kGraphToProgramVarsToRemove[] =
"__graph_to_program_vars_to_remove__";
constexpr char kGraphToProgramSortKind[] = "__graph_to_program_sort_kind__";
// Compare nodes via node id.
class Graph;
......@@ -117,6 +121,9 @@ std::vector<T *> FilterByNodeWrapper(const Graph &graph) {
std::vector<ir::Node *> TopologySortGraphByDescOrder(const Graph &graph);
void GraphToProgram(const Graph &graph, ProgramDesc *p_program,
const SortKind *sort_kind = nullptr);
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -17,11 +17,8 @@ limitations under the License. */
#include <gflags/gflags.h>
#include <algorithm>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"
DECLARE_bool(convert_all_blocks);
namespace paddle {
namespace framework {
class ProgramDesc;
......@@ -33,116 +30,12 @@ namespace framework {
namespace ir {
void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_EQ(graph->IsMainGraph(), true,
platform::errors::InvalidArgument(
"This graph is a sub_graph, "
"and can't convert to program individually"));
ProgramDesc& program = Get<ProgramDesc>("program");
std::unique_ptr<proto::ProgramDesc> program_pb(
new proto::ProgramDesc(*program.Proto()));
auto block = program_pb->mutable_blocks(kRootBlockIndex);
block->set_idx(kRootBlockIndex);
if (FLAGS_convert_all_blocks) {
GraphToBlock(graph->GetSubGraph(kRootBlockIndex), block);
VLOG(3) << "Graph to program need convert " << graph->SubGraphsSize()
<< " sub graph";
for (size_t idx = 0; idx < graph->SubGraphsSize(); ++idx) {
// avoid kRootBlockIndex not 0
if (idx == kRootBlockIndex) continue;
block = program_pb->add_blocks();
block->set_idx(idx);
GraphToBlock(graph->GetSubGraph(idx), block);
}
} else {
GraphToBlock(graph, block);
}
program.CopyFrom(*program_pb);
}
OpDesc* ReplaceScaleLossGradOp(ir::Node* node, OpDesc* desc) {
desc->SetType("fill_constant");
desc->SetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName(),
(static_cast<int>(OpRole::kBackward) | static_cast<int>(OpRole::kLoss)));
desc->SetAttr("value", 1.0f);
std::vector<std::string> output_names;
for (auto out : node->outputs) {
output_names.emplace_back(out->Name());
}
desc->SetOutput("Out", output_names);
return desc;
}
std::vector<OpDesc>* GetGraphOpDesc(const std::vector<ir::Node*>& nodes,
std::vector<OpDesc>* ops) {
for (ir::Node* n : nodes) {
// if node is not Op, skip
if (!n->IsOp()) continue;
// create fill_constant op
if (n->Name() == "scale_loss_grad") {
ops->emplace_back();
auto& desc = ops->back();
ReplaceScaleLossGradOp(n, &desc);
} else if (n->Op()) {
ops->emplace_back(*n->Op());
} else {
// delete no OpDesc op
}
}
return ops;
}
void GraphToProgramPass::GraphToBlock(const Graph* graph,
proto::BlockDesc* block) const {
// Remove the unneeded variables after memory optimization.
std::unordered_set<std::string> vars2remove;
if (graph->Has(kGraphToProgramVarsToRemove)) {
vars2remove = graph->Get<std::unordered_set<std::string>>(
kGraphToProgramVarsToRemove);
VLOG(2) << "graph (id: " << block->idx() << ") to program remove "
<< vars2remove.size() << " nodes";
}
block->clear_vars();
std::unordered_set<std::string> visited_vars;
for (ir::Node* n : graph->Nodes()) {
if (n->IsVar()) {
if (n->Var() && visited_vars.count(n->Var()->Name()) == 0 &&
!vars2remove.count(n->Var()->Name()) &&
n->GetVarNodeBlockId() == graph->GetBlockId()) {
visited_vars.insert(n->Var()->Name());
block->add_vars()->MergeFrom(*n->Var()->Proto());
}
}
}
block->clear_ops();
std::vector<ir::Node*> nodes;
auto& program = Get<ProgramDesc>("program");
if (Has(kGraphToProgramSortKind)) {
// Inference Memory Optimize relays on this branch.
int sort_kind = Get<int>(kGraphToProgramSortKind);
nodes = TopologyVarientSort(
*graph, static_cast<framework::ir::SortKind>(sort_kind));
auto sort_kind = static_cast<SortKind>(Get<int>(kGraphToProgramSortKind));
GraphToProgram(*graph, &program, &sort_kind);
} else {
if (FLAGS_convert_all_blocks) {
nodes = TopologySortGraphByDescOrder(*graph);
} else {
nodes = TopologySortOperations(*graph);
}
}
std::vector<OpDesc> ops;
GetGraphOpDesc(nodes, &ops);
for (auto& op : ops) {
block->add_ops()->MergeFrom(*op.Proto());
GraphToProgram(*graph, &program, nullptr);
}
}
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
......@@ -22,16 +23,9 @@ namespace ir {
class Graph;
const char kGraphToProgramVarsToRemove[] =
"__graph_to_program_vars_to_remove__";
const char kGraphToProgramSortKind[] = "__graph_to_program_sort_kind__";
class GraphToProgramPass : public Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
void GraphToBlock(const Graph* graph, proto::BlockDesc* block) const;
};
} // namespace ir
......
......@@ -69,6 +69,26 @@ Graph* Pass::Apply(Graph* graph) const {
return graph;
}
void Pass::Apply(ProgramDesc* main_program,
ProgramDesc* startup_program) const {
PADDLE_ENFORCE_NOT_NULL(main_program, platform::errors::InvalidArgument(
"main program must be provided"));
PADDLE_ENFORCE_NOT_NULL(
startup_program,
platform::errors::InvalidArgument("startup program must be provided"));
Graph graph(*main_program);
Apply(&graph);
// TODO(zjl): support details::kStartupProgramDescs and details::kProgramDescs
ProgramDesc new_main_program;
GraphToProgram(graph, &new_main_program);
main_program->CopyFrom(*new_main_program.Proto());
startup_program->Flush();
main_program->Flush();
}
PassRegistry& PassRegistry::Instance() {
static PassRegistry g_pass_info_map;
return g_pass_info_map;
......
......@@ -29,8 +29,15 @@ limitations under the License. */
namespace paddle {
namespace framework {
namespace details {
using ProgramDescs = std::vector<ProgramDesc>;
constexpr char kProgramDescs[] = "program_descs";
constexpr char kStartupProgramDescs[] = "startup_program_descs";
} // namespace details
namespace ir {
class Graph;
template <typename PassType>
struct PassRegistrar;
......@@ -57,6 +64,8 @@ class Pass {
Graph *Apply(Graph *graph) const;
void Apply(ProgramDesc *main_program, ProgramDesc *startup_program) const;
// Get a reference to the attribute previously set.
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const {
......
......@@ -3,7 +3,7 @@
include_directories(${PADDLE_SOURCE_DIR}/paddle/fluid/platform)
set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper prune
feed_fetch_method pass_builder parallel_executor profiler layer tracer engine scope_pool
feed_fetch_method pass pass_builder parallel_executor profiler layer tracer engine scope_pool
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator)
......
......@@ -23,7 +23,9 @@
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/python_headers.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_desc.h"
#include "pybind11/stl.h"
......@@ -184,5 +186,150 @@ void BindNode(py::module *m) {
.value("Variable", Node::Type::kVariable)
.export_values();
}
class PYBIND11_HIDDEN PassAttrGetterSetterRegistry {
private:
PassAttrGetterSetterRegistry() = default;
DISABLE_COPY_AND_ASSIGN(PassAttrGetterSetterRegistry);
using Getter = std::function<py::object(const framework::ir::Pass & /*pass*/,
const std::string & /*attr_name*/)>;
using Setter = std::function<void(const std::string & /*attr_name*/,
const py::object & /*attr_value*/,
framework::ir::Pass * /*pass*/)>;
struct GetterSetter {
Getter getter;
Setter setter;
};
public:
static PassAttrGetterSetterRegistry &Instance() {
static PassAttrGetterSetterRegistry instance;
return instance;
}
void Register(const std::string &attr_type, Getter getter, Setter setter) {
PADDLE_ENFORCE_NOT_NULL(
getter, platform::errors::InvalidArgument(
"getter of %s should not be nullptr", attr_type));
PADDLE_ENFORCE_NOT_NULL(
setter, platform::errors::InvalidArgument(
"setter of %s should not be nullptr", attr_type));
GetterSetter getter_setter;
getter_setter.getter = std::move(getter);
getter_setter.setter = std::move(setter);
PADDLE_ENFORCE_EQ(
getter_setter_map_.emplace(attr_type, getter_setter).second, true,
platform::errors::InvalidArgument(
"getter and setter of %s have been set before", attr_type));
}
py::object Get(const framework::ir::Pass &pass, const std::string &attr_name,
const std::string &attr_type) const {
auto iter = getter_setter_map_.find(attr_type);
PADDLE_ENFORCE_EQ(
iter != getter_setter_map_.end(), true,
platform::errors::InvalidArgument("unsupported attribute type %s of %s",
attr_type, attr_name));
const auto &getter = iter->second.getter;
return getter(pass, attr_name);
}
void Set(const std::string &attr_name, const std::string &attr_type,
const py::object &attr_value, framework::ir::Pass *pass) const {
auto iter = getter_setter_map_.find(attr_type);
PADDLE_ENFORCE_EQ(
iter != getter_setter_map_.end(), true,
platform::errors::InvalidArgument("unsupported attribute type %s of %s",
attr_type, attr_name));
const auto &setter = iter->second.setter;
setter(attr_name, attr_value, pass);
}
private:
std::unordered_map<std::string, GetterSetter> getter_setter_map_;
};
#define REGISTER_PASS_ATTR_GETTER_SETTER(attr_type_name, cpp_type) \
do { \
auto getter = [](const framework::ir::Pass &pass, \
const std::string &attr_name) -> py::object { \
auto attr_value = pass.Get<cpp_type>(attr_name); \
return py::cast(attr_value); \
}; \
auto setter = [](const std::string &attr_name, \
const py::object &attr_value, \
framework::ir::Pass *pass) { \
PADDLE_ENFORCE_NOT_NULL( \
pass, platform::errors::InvalidArgument("pass should be provided")); \
try { \
const auto &cpp_attr_value = py::cast<cpp_type>(attr_value); \
pass->Set(attr_name, new cpp_type(cpp_attr_value)); \
} catch (py::cast_error &) { \
PADDLE_THROW(platform::errors::InvalidArgument( \
"type error of attribute %s, expected to be %s", attr_name, \
attr_type_name)); \
} \
}; \
PassAttrGetterSetterRegistry::Instance().Register(attr_type_name, getter, \
setter); \
} while (0)
// NOTE: attr_types may be changed
static void SetAttrsToPass(
const std::unordered_map<std::string, py::object> &attrs,
std::unordered_map<std::string, std::string> *attr_types,
framework::ir::Pass *pass) {
for (const auto &name_and_value : attrs) {
const auto &attr_name = name_and_value.first;
const auto &attr_value = name_and_value.second;
auto &attr_type = (*attr_types)[attr_name];
if (attr_type.empty()) {
attr_type = py::cast<std::string>(attr_value.get_type().attr("__name__"));
}
PassAttrGetterSetterRegistry::Instance().Set(attr_name, attr_type,
attr_value, pass);
}
}
void BindPass(py::module *m) {
// NOTE: pass_attr_types is a dict to indicate the type of each attribute.
// Python has only one integral type "int", but C++ has many integral types.
// If pass_attrs = {"nranks": 1} in Python, we cannot know whether the type
// of "nranks" is size_t or int in C++. Therefore, users can set
// pass_attr_types to indicate the type of "nranks" explicitly,
// i.e. pass_attr_types = {"nranks": "size_t"} means that the type of
// "nranks" is size_t in C++.
REGISTER_PASS_ATTR_GETTER_SETTER("int", int64_t);
REGISTER_PASS_ATTR_GETTER_SETTER("long", int64_t);
REGISTER_PASS_ATTR_GETTER_SETTER("size_t", size_t);
REGISTER_PASS_ATTR_GETTER_SETTER("float32", float);
// Python float is C++ double
REGISTER_PASS_ATTR_GETTER_SETTER("float", double);
REGISTER_PASS_ATTR_GETTER_SETTER("bytes", std::string);
REGISTER_PASS_ATTR_GETTER_SETTER("str", std::string);
m->def(
"apply_pass",
[](framework::ProgramDesc *main_program,
framework::ProgramDesc *startup_program, const std::string &pass_name,
const std::unordered_map<std::string, py::object> &pass_attrs,
std::unordered_map<std::string, std::string> pass_attr_types) {
auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
SetAttrsToPass(pass_attrs, &pass_attr_types, pass.get());
pass->Apply(main_program, startup_program);
std::unordered_map<std::string, py::object> result_attrs;
for (const auto &name_and_value : pass_attrs) {
const auto &attr_name = name_and_value.first;
const auto &attr_type = pass_attr_types.at(attr_name);
result_attrs[attr_name] =
PassAttrGetterSetterRegistry::Instance().Get(*pass, attr_name,
attr_type);
}
return result_attrs;
});
}
} // namespace pybind
} // namespace paddle
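The NOTE inside BindPass above explains why an explicit attribute-type map may be needed: Python has a single integral type, so a bare int cannot tell the C++ side whether a size_t or an int64_t attribute is meant. A hedged sketch of that disambiguation at the level of the raw binding, assuming main_program and startup_program are existing Program objects and reusing the NOTE's "nranks" example; the pass name is only illustrative, since the attribute machinery stores arbitrary names, as the unit test below does:

import paddle.fluid.core as core

# Wrap the current descs, as the _apply_pass helper further down does.
tmp_main = core.ProgramDesc(main_program.desc)
tmp_startup = core.ProgramDesc(startup_program.desc)

# Without the type map, the Python int 8 would be routed through the "int"
# getter/setter registered above (bound to int64_t); the map stores it as size_t.
attrs = core.apply_pass(tmp_main, tmp_startup, "fuse_elewise_add_act_pass",
                        {"nranks": 8}, {"nranks": "size_t"})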
......@@ -21,5 +21,6 @@ namespace paddle {
namespace pybind {
void BindGraph(pybind11::module *m);
void BindNode(pybind11::module *m);
void BindPass(pybind11::module *m);
} // namespace pybind
} // namespace paddle
......@@ -3105,6 +3105,7 @@ All parameter, weight, gradient are variables in Paddle.
#endif
BindGraph(&m);
BindNode(&m);
BindPass(&m);
BindInferenceApi(&m);
BindCompatible(&m);
BindDataset(&m);
......
......@@ -3232,6 +3232,22 @@ class Block(object):
return ret_var
def _apply_pass(main_program,
startup_program,
pass_name,
pass_attrs={},
pass_attr_types={}):
assert isinstance(pass_attrs, dict), "pass_attrs must be dict"
assert isinstance(pass_attr_types, dict), "pass_attr_types must be dict"
tmp_main_program = core.ProgramDesc(main_program.desc)
tmp_startup_program = core.ProgramDesc(startup_program.desc)
attrs = core.apply_pass(tmp_main_program, tmp_startup_program, pass_name,
pass_attrs, pass_attr_types)
main_program._rebuild_from_desc(tmp_main_program)
startup_program._rebuild_from_desc(tmp_startup_program)
return attrs
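A short usage sketch of the helper just defined, mirroring the unit test added at the end of this diff; main_program and startup_program are assumed to hold a network with an elementwise_add followed by an activation so the fusion pass has something to rewrite:

from paddle.fluid.framework import _apply_pass

attrs = {"size_t_attr": 10, "float32_attr": -4.5}
attr_types = {"size_t_attr": "size_t", "float32_attr": "float32"}
ret_attrs = _apply_pass(main_program, startup_program,
                        "fuse_elewise_add_act_pass", attrs, attr_types)
# The return value is read back from the pass after it has run, so a pass
# that updates its attributes would surface the new values; here it simply
# echoes what was set.
assert ret_attrs == attrs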
class IrNode(object):
"""
Python IrNode. Beneath it is a core.Node, which is used for Ir Pass.
......@@ -4148,6 +4164,91 @@ class Program(object):
# compiled program, i.e. Graph
self._graph = None
def _find_var_class_kwargs(self, new_desc):
old_desc = self.desc
all_new_vars = []
block_num = new_desc.num_blocks()
for idx in range(block_num):
new_block_desc = new_desc.block(idx)
all_new_vars.append([])
block_new_vars = all_new_vars[-1]
for new_var_desc in new_block_desc.all_vars():
if self.blocks[idx].has_var(new_var_desc.name()):
old_var = self.blocks[idx].var(new_var_desc.name())
else:
old_var = None
kwargs = {
'type': new_var_desc.type(),
'name': new_var_desc.name(),
'shape': new_var_desc.shape(),
'dtype': new_var_desc.dtype(),
'lod_level': new_var_desc.lod_level(),
'error_clip': old_var.error_clip
if old_var is not None else None,
'stop_gradient': old_var.stop_gradient
if old_var is not None else False,
'is_data': old_var.is_data
if old_var is not None else False,
'need_check_feed': new_var_desc.need_check_feed(),
'belong_to_optimizer': old_var.belong_to_optimizer
if old_var is not None else False,
}
if isinstance(old_var, Parameter):
kwargs.update({
'trainable': old_var.trainable,
'optimize_attr': old_var.optimize_attr,
'regularizer': old_var.regularizer,
'do_model_average': old_var.do_model_average,
'need_clip': old_var.need_clip,
'is_distributed': old_var.is_distributed,
'is_parameter': old_var.is_parameter,
})
block_new_vars.append({
'class': Parameter,
'kwargs': copy.deepcopy(kwargs),
})
else:
kwargs['persistable'] = new_var_desc.persistable()
block_new_vars.append({
'class': Variable,
'kwargs': copy.deepcopy(kwargs),
})
return all_new_vars
def _rebuild_from_desc(self, desc):
all_new_vars = self._find_var_class_kwargs(desc)
block_num = desc.num_blocks()
assert block_num == len(all_new_vars)
# clear old blocks and desc
self.blocks = []
self.desc = None
# create new blocks and set desc
self.desc = desc
self.blocks = [Block(self, idx) for idx in range(block_num)]
# add new vars first
for idx in range(block_num):
block = self.blocks[idx]
for new_var in all_new_vars[idx]:
clazz = new_var['class']
kwargs = new_var['kwargs']
kwargs['block'] = block
clazz(**kwargs)
# then append op
for idx in range(block_num):
block = self.blocks[idx]
block_desc = self.desc.block(idx)
for op_idx in range(block_desc.op_size()):
op_desc = block_desc.op(op_idx)
op = Operator(block=block, desc=op_desc)
block.ops.append(op)
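Since _rebuild_from_desc recreates the blocks, variables, and operators from the post-pass desc, the effect of a pass becomes visible through the ordinary Program API; a small check in the spirit of the unit test below, assuming main_program is the program that was passed to _apply_pass:

# After fuse_elewise_add_act_pass has run, the fused op should be reachable
# through the rebuilt global block.
fused = any(op.type == "fused_elemwise_add_activation"
            for op in main_program.global_block().ops)
assert fused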
def global_seed(self, seed=0):
"""
Set global seed for Program
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.vision.models import resnet50
from paddle.nn import CrossEntropyLoss
from paddle.fluid.framework import _apply_pass
import unittest
class TestApplyPassToProgram(unittest.TestCase):
def setUp(self):
paddle.enable_static()
def global_block_contains_op(self, program, op_type):
for op in program.global_block().ops:
if op.type == op_type:
return True
return False
def test_case(self):
image = paddle.static.data(
name="image", shape=[None, 3, 224, 224], dtype="float32")
label = paddle.static.data(name="label", shape=[None, 1], dtype="int64")
model = resnet50()
loss_fn = CrossEntropyLoss()
pred = model(image)
loss = loss_fn(pred, label)
optimizer = paddle.optimizer.SGD(learning_rate=1e-3)
optimizer.minimize(loss)
startup = paddle.static.default_startup_program()
main = paddle.static.default_main_program()
fused_op = "fused_elemwise_add_activation"
self.assertFalse(self.global_block_contains_op(main, fused_op))
attrs = {
"int_attr": -3,
"size_t_attr": 10,
"float_attr": 3.25,
"float32_attr": -4.5,
"str_attr": "any string attr value",
}
attr_types = {
"size_t_attr": "size_t",
"float32_attr": "float32",
}
ret_attrs = _apply_pass(main, startup, "fuse_elewise_add_act_pass",
attrs, attr_types)
self.assertEqual(attrs, ret_attrs)
self.assertTrue(self.global_block_contains_op(main, fused_op))
if __name__ == "__main__":
unittest.main()