diff --git a/paddle/fluid/framework/details/multi_devices_helper.h b/paddle/fluid/framework/details/multi_devices_helper.h index 7e2c41dd4f7950c304619243bc5c3db179407236..82ce045fad72368413aeea9fa23939c42a7de6e3 100644 --- a/paddle/fluid/framework/details/multi_devices_helper.h +++ b/paddle/fluid/framework/details/multi_devices_helper.h @@ -77,10 +77,6 @@ typedef std::vector> ParamsAndGrads; constexpr char kParamsAndDenseGrads[] = "params_and_dense_grads"; constexpr char kParamsAndSparseGrads[] = "params_and_sparse_grads"; -typedef std::vector ProgramDescs; -constexpr char kProgramDescs[] = "program_descs"; -constexpr char kStartupProgramDescs[] = "startup_program_descs"; - typedef std::unordered_set PinnedVars; constexpr char kPinnedVars[] = "pinned_vars"; diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 0a856330f8e742b5fc2bb797f1402174dc786889..652ce77d844570e8d3d4292122c8663c7941bfd4 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -15,7 +15,9 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph_helper.h" #include #include +#include "paddle/fluid/framework/op_proto_maker.h" +DECLARE_bool(convert_all_blocks); DEFINE_string(print_sub_graph_dir, "", "FLAGS_print_sub_graph_dir is used " "to print the nodes of sub_graphs."); @@ -431,6 +433,117 @@ std::vector TopologySortGraphByDescOrder(const Graph &graph) { return ret; } +static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) { + desc->SetType("fill_constant"); + desc->SetAttr( + OpProtoAndCheckerMaker::OpRoleAttrName(), + (static_cast(OpRole::kBackward) | static_cast(OpRole::kLoss))); + desc->SetAttr("value", 1.0f); + std::vector output_names; + for (auto out : node.outputs) { + output_names.emplace_back(out->Name()); + } + desc->SetOutput("Out", output_names); + return desc; +} + +static void GetGraphOpDesc(const std::vector &nodes, + std::vector *ops) { + for (Node *n : nodes) { + // if node is not Op, skip + if (!n->IsOp()) continue; + + // create fill_constant op + if (n->Name() == "scale_loss_grad") { + ops->emplace_back(); + auto &desc = ops->back(); + ReplaceScaleLossGradOp(*n, &desc); + } else if (n->Op()) { + ops->emplace_back(*n->Op()); + } + // delete no OpDesc op + } +} + +static void GraphToBlock(const Graph &graph, proto::BlockDesc *block, + const SortKind *sort_kind) { + // Remove the unneeded variables after memory optimization. + std::unordered_set vars2remove; + if (graph.Has(kGraphToProgramVarsToRemove)) { + vars2remove = + graph.Get>(kGraphToProgramVarsToRemove); + VLOG(2) << "graph (id: " << block->idx() << ") to program remove " + << vars2remove.size() << " nodes"; + } + + block->clear_vars(); + std::unordered_set visited_vars; + for (Node *n : graph.Nodes()) { + if (n->IsVar()) { + if (n->Var() && visited_vars.count(n->Var()->Name()) == 0 && + !vars2remove.count(n->Var()->Name()) && + n->GetVarNodeBlockId() == graph.GetBlockId()) { + visited_vars.insert(n->Var()->Name()); + block->add_vars()->MergeFrom(*n->Var()->Proto()); + } + } + } + block->clear_ops(); + + std::vector nodes; + if (sort_kind != nullptr) { + // Inference Memory Optimize relays on this branch. + nodes = TopologyVarientSort(graph, *sort_kind); + } else { + if (FLAGS_convert_all_blocks) { + nodes = TopologySortGraphByDescOrder(graph); + } else { + nodes = TopologySortOperations(graph); + } + } + + std::vector ops; + GetGraphOpDesc(nodes, &ops); + for (auto &op : ops) { + block->add_ops()->MergeFrom(*op.Proto()); + } +} + +void GraphToProgram(const Graph &graph, ProgramDesc *program, + const SortKind *sort_kind) { + PADDLE_ENFORCE_EQ(graph.IsMainGraph(), true, + platform::errors::InvalidArgument( + "This graph is a sub_graph, " + "and can't convert to program individually")); + PADDLE_ENFORCE_NOT_NULL( + program, + platform::errors::InvalidArgument( + "program must not be nullptr when converting graph to program")); + + proto::ProgramDesc program_pb(*(program->Proto())); + auto block = program_pb.mutable_blocks(kRootBlockIndex); + block->set_idx(kRootBlockIndex); + + if (FLAGS_convert_all_blocks) { + GraphToBlock(*graph.GetSubGraph(kRootBlockIndex), block, sort_kind); + + VLOG(3) << "Graph to program need convert " << graph.SubGraphsSize() + << " sub graph"; + for (size_t idx = 0; idx < graph.SubGraphsSize(); ++idx) { + // avoid kRootBlockIndex not 0 + if (idx == kRootBlockIndex) continue; + + block = program_pb.add_blocks(); + block->set_idx(idx); + GraphToBlock(*graph.GetSubGraph(idx), block, sort_kind); + } + } else { + GraphToBlock(graph, block, sort_kind); + } + + program->CopyFrom(program_pb); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_helper.h b/paddle/fluid/framework/ir/graph_helper.h index 3309f600730e8c3fa4e5a3ab5a186e1550a61cf0..f00e3ae37b4da249214162582e4d48aa309190bf 100644 --- a/paddle/fluid/framework/ir/graph_helper.h +++ b/paddle/fluid/framework/ir/graph_helper.h @@ -27,6 +27,10 @@ namespace paddle { namespace framework { namespace ir { +constexpr char kGraphToProgramVarsToRemove[] = + "__graph_to_program_vars_to_remove__"; +constexpr char kGraphToProgramSortKind[] = "__graph_to_program_sort_kind__"; + // Compare nodes via node id. class Graph; @@ -117,6 +121,9 @@ std::vector FilterByNodeWrapper(const Graph &graph) { std::vector TopologySortGraphByDescOrder(const Graph &graph); +void GraphToProgram(const Graph &graph, ProgramDesc *p_program, + const SortKind *sort_kind = nullptr); + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_to_program_pass.cc b/paddle/fluid/framework/ir/graph_to_program_pass.cc index b31ccd48aa98b8a08d5ec19efcb3cb5c80d82d5e..3ad591c6dff04ce6334d1675616f0cf2d5c39182 100644 --- a/paddle/fluid/framework/ir/graph_to_program_pass.cc +++ b/paddle/fluid/framework/ir/graph_to_program_pass.cc @@ -17,11 +17,8 @@ limitations under the License. */ #include #include -#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/op_proto_maker.h" -DECLARE_bool(convert_all_blocks); - namespace paddle { namespace framework { class ProgramDesc; @@ -33,116 +30,12 @@ namespace framework { namespace ir { void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const { - PADDLE_ENFORCE_EQ(graph->IsMainGraph(), true, - platform::errors::InvalidArgument( - "This graph is a sub_graph, " - "and can't convert to program individually")); - - ProgramDesc& program = Get("program"); - - std::unique_ptr program_pb( - new proto::ProgramDesc(*program.Proto())); - - auto block = program_pb->mutable_blocks(kRootBlockIndex); - block->set_idx(kRootBlockIndex); - - if (FLAGS_convert_all_blocks) { - GraphToBlock(graph->GetSubGraph(kRootBlockIndex), block); - - VLOG(3) << "Graph to program need convert " << graph->SubGraphsSize() - << " sub graph"; - for (size_t idx = 0; idx < graph->SubGraphsSize(); ++idx) { - // avoid kRootBlockIndex not 0 - if (idx == kRootBlockIndex) continue; - - block = program_pb->add_blocks(); - block->set_idx(idx); - GraphToBlock(graph->GetSubGraph(idx), block); - } - } else { - GraphToBlock(graph, block); - } - - program.CopyFrom(*program_pb); -} - -OpDesc* ReplaceScaleLossGradOp(ir::Node* node, OpDesc* desc) { - desc->SetType("fill_constant"); - desc->SetAttr( - OpProtoAndCheckerMaker::OpRoleAttrName(), - (static_cast(OpRole::kBackward) | static_cast(OpRole::kLoss))); - desc->SetAttr("value", 1.0f); - std::vector output_names; - for (auto out : node->outputs) { - output_names.emplace_back(out->Name()); - } - desc->SetOutput("Out", output_names); - return desc; -} - -std::vector* GetGraphOpDesc(const std::vector& nodes, - std::vector* ops) { - for (ir::Node* n : nodes) { - // if node is not Op, skip - if (!n->IsOp()) continue; - - // create fill_constant op - if (n->Name() == "scale_loss_grad") { - ops->emplace_back(); - auto& desc = ops->back(); - ReplaceScaleLossGradOp(n, &desc); - } else if (n->Op()) { - ops->emplace_back(*n->Op()); - } else { - // delete no OpDesc op - } - } - return ops; -} - -void GraphToProgramPass::GraphToBlock(const Graph* graph, - proto::BlockDesc* block) const { - // Remove the unneeded variables after memory optimization. - std::unordered_set vars2remove; - if (graph->Has(kGraphToProgramVarsToRemove)) { - vars2remove = graph->Get>( - kGraphToProgramVarsToRemove); - VLOG(2) << "graph (id: " << block->idx() << ") to program remove " - << vars2remove.size() << " nodes"; - } - - block->clear_vars(); - std::unordered_set visited_vars; - for (ir::Node* n : graph->Nodes()) { - if (n->IsVar()) { - if (n->Var() && visited_vars.count(n->Var()->Name()) == 0 && - !vars2remove.count(n->Var()->Name()) && - n->GetVarNodeBlockId() == graph->GetBlockId()) { - visited_vars.insert(n->Var()->Name()); - block->add_vars()->MergeFrom(*n->Var()->Proto()); - } - } - } - block->clear_ops(); - - std::vector nodes; + auto& program = Get("program"); if (Has(kGraphToProgramSortKind)) { - // Inference Memory Optimize relays on this branch. - int sort_kind = Get(kGraphToProgramSortKind); - nodes = TopologyVarientSort( - *graph, static_cast(sort_kind)); + auto sort_kind = static_cast(Get(kGraphToProgramSortKind)); + GraphToProgram(*graph, &program, &sort_kind); } else { - if (FLAGS_convert_all_blocks) { - nodes = TopologySortGraphByDescOrder(*graph); - } else { - nodes = TopologySortOperations(*graph); - } - } - - std::vector ops; - GetGraphOpDesc(nodes, &ops); - for (auto& op : ops) { - block->add_ops()->MergeFrom(*op.Proto()); + GraphToProgram(*graph, &program, nullptr); } } diff --git a/paddle/fluid/framework/ir/graph_to_program_pass.h b/paddle/fluid/framework/ir/graph_to_program_pass.h index 4997c67a92fdc8a09d170952ec82256fae6f148d..3789a0a623df2d231176acaa317ee25a53fa6322 100644 --- a/paddle/fluid/framework/ir/graph_to_program_pass.h +++ b/paddle/fluid/framework/ir/graph_to_program_pass.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/pass.h" namespace paddle { @@ -22,16 +23,9 @@ namespace ir { class Graph; -const char kGraphToProgramVarsToRemove[] = - "__graph_to_program_vars_to_remove__"; -const char kGraphToProgramSortKind[] = "__graph_to_program_sort_kind__"; - class GraphToProgramPass : public Pass { protected: void ApplyImpl(ir::Graph* graph) const override; - - private: - void GraphToBlock(const Graph* graph, proto::BlockDesc* block) const; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index 0e5f5867f47b25f3efdcf648c4243cec310ad4ca..42b6244788da0932e8735e039cad2a8bdb35b531 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -69,6 +69,26 @@ Graph* Pass::Apply(Graph* graph) const { return graph; } +void Pass::Apply(ProgramDesc* main_program, + ProgramDesc* startup_program) const { + PADDLE_ENFORCE_NOT_NULL(main_program, platform::errors::InvalidArgument( + "main program must be provided")); + PADDLE_ENFORCE_NOT_NULL( + startup_program, + platform::errors::InvalidArgument("startup program must be provided")); + + Graph graph(*main_program); + Apply(&graph); + + // TODO(zjl): support details::kStartupProgramDescs and details::kProgramDescs + ProgramDesc new_main_program; + GraphToProgram(graph, &new_main_program); + main_program->CopyFrom(*new_main_program.Proto()); + + startup_program->Flush(); + main_program->Flush(); +} + PassRegistry& PassRegistry::Instance() { static PassRegistry g_pass_info_map; return g_pass_info_map; diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index 9c306479bf5d6a656950fd6822594c735e55c6e3..8fb96bec9cbd56204725e4528c8801f5aa4308e6 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -29,8 +29,15 @@ limitations under the License. */ namespace paddle { namespace framework { +namespace details { +using ProgramDescs = std::vector; +constexpr char kProgramDescs[] = "program_descs"; +constexpr char kStartupProgramDescs[] = "startup_program_descs"; +} // namespace details + namespace ir { class Graph; + template struct PassRegistrar; @@ -57,6 +64,8 @@ class Pass { Graph *Apply(Graph *graph) const; + void Apply(ProgramDesc *main_program, ProgramDesc *startup_program) const; + // Get a reference to the attributed previously set. template AttrType &Get(const std::string &attr_name) const { diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 58471dd04ac65e251aeb09fb712fe6c87e7fd0b3..f362808a4b9528e603f2649d367543dd7460baf3 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -3,7 +3,7 @@ include_directories(${PADDLE_SOURCE_DIR}/paddle/fluid/platform) set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper prune - feed_fetch_method pass_builder parallel_executor profiler layer tracer engine scope_pool + feed_fetch_method pass pass_builder parallel_executor profiler layer tracer engine scope_pool analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator) diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index fc8d7ac949a0217931a33d58e8c506a86ab6eba4..4a4c34b149e400929be2a279327fc3d70ba5f40b 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -23,7 +23,9 @@ #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/python_headers.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/var_desc.h" #include "pybind11/stl.h" @@ -184,5 +186,150 @@ void BindNode(py::module *m) { .value("Variable", Node::Type::kVariable) .export_values(); } + +class PYBIND11_HIDDEN PassAttrGetterSetterRegistry { + private: + PassAttrGetterSetterRegistry() = default; + DISABLE_COPY_AND_ASSIGN(PassAttrGetterSetterRegistry); + + using Getter = std::function; + using Setter = std::function; + + struct GetterSetter { + Getter getter; + Setter setter; + }; + + public: + static PassAttrGetterSetterRegistry &Instance() { + static PassAttrGetterSetterRegistry instance; + return instance; + } + + void Register(const std::string &attr_type, Getter getter, Setter setter) { + PADDLE_ENFORCE_NOT_NULL( + getter, platform::errors::InvalidArgument( + "getter of %s should not be nullptr", attr_type)); + PADDLE_ENFORCE_NOT_NULL( + setter, platform::errors::InvalidArgument( + "setter of %s should not be nullptr", attr_type)); + GetterSetter getter_setter; + getter_setter.getter = std::move(getter); + getter_setter.setter = std::move(setter); + PADDLE_ENFORCE_EQ( + getter_setter_map_.emplace(attr_type, getter_setter).second, true, + platform::errors::InvalidArgument( + "getter and setter of %s have been set before", attr_type)); + } + + py::object Get(const framework::ir::Pass &pass, const std::string &attr_name, + const std::string &attr_type) const { + auto iter = getter_setter_map_.find(attr_type); + PADDLE_ENFORCE_EQ( + iter != getter_setter_map_.end(), true, + platform::errors::InvalidArgument("unsupported attribute type %s of %s", + attr_type, attr_name)); + const auto &getter = iter->second.getter; + return getter(pass, attr_name); + } + + void Set(const std::string &attr_name, const std::string &attr_type, + const py::object &attr_value, framework::ir::Pass *pass) const { + auto iter = getter_setter_map_.find(attr_type); + PADDLE_ENFORCE_EQ( + iter != getter_setter_map_.end(), true, + platform::errors::InvalidArgument("unsupported attribute type %s of %s", + attr_type, attr_name)); + const auto &setter = iter->second.setter; + setter(attr_name, attr_value, pass); + } + + private: + std::unordered_map getter_setter_map_; +}; + +#define REGISTER_PASS_ATTR_GETTER_SETTER(attr_type_name, cpp_type) \ + do { \ + auto getter = [](const framework::ir::Pass &pass, \ + const std::string &attr_name) -> py::object { \ + auto attr_value = pass.Get(attr_name); \ + return py::cast(attr_value); \ + }; \ + auto setter = [](const std::string &attr_name, \ + const py::object &attr_value, \ + framework::ir::Pass *pass) { \ + PADDLE_ENFORCE_NOT_NULL( \ + pass, platform::errors::InvalidArgument("pass should be provided")); \ + try { \ + const auto &cpp_attr_value = py::cast(attr_value); \ + pass->Set(attr_name, new cpp_type(cpp_attr_value)); \ + } catch (py::cast_error &) { \ + PADDLE_THROW(platform::errors::InvalidArgument( \ + "type error of attribute %s, expected to be %s", attr_name, \ + attr_type_name)); \ + } \ + }; \ + PassAttrGetterSetterRegistry::Instance().Register(attr_type_name, getter, \ + setter); \ + } while (0) + +// NOTE: attr_types may be changed +static void SetAttrsToPass( + const std::unordered_map &attrs, + std::unordered_map *attr_types, + framework::ir::Pass *pass) { + for (const auto &name_and_value : attrs) { + const auto &attr_name = name_and_value.first; + const auto &attr_value = name_and_value.second; + auto &attr_type = (*attr_types)[attr_name]; + if (attr_type.empty()) { + attr_type = py::cast(attr_value.get_type().attr("__name__")); + } + PassAttrGetterSetterRegistry::Instance().Set(attr_name, attr_type, + attr_value, pass); + } +} + +void BindPass(py::module *m) { + // NOTE: pass_attr_types is a dict to indicate the type of each attribute. + // Python has only one integral type "int", but C++ has many integral types. + // If pass_attrs = {"nranks": 1} in Python, we cannot know whether the type + // of "nranks" is size_t or int in C++. Therefore, users can set + // pass_attr_types to indicate the type of "nranks" explicitly, + // i.e. pass_attr_types = {"nranks": "size_t"} means that the type of + // "nranks" is size_t in C++. + REGISTER_PASS_ATTR_GETTER_SETTER("int", int64_t); + REGISTER_PASS_ATTR_GETTER_SETTER("long", int64_t); + REGISTER_PASS_ATTR_GETTER_SETTER("size_t", size_t); + REGISTER_PASS_ATTR_GETTER_SETTER("float32", float); + // Python float is C++ double + REGISTER_PASS_ATTR_GETTER_SETTER("float", double); + REGISTER_PASS_ATTR_GETTER_SETTER("bytes", std::string); + REGISTER_PASS_ATTR_GETTER_SETTER("str", std::string); + + m->def( + "apply_pass", + [](framework::ProgramDesc *main_program, + framework::ProgramDesc *startup_program, const std::string &pass_name, + const std::unordered_map &pass_attrs, + std::unordered_map pass_attr_types) { + auto pass = framework::ir::PassRegistry::Instance().Get(pass_name); + SetAttrsToPass(pass_attrs, &pass_attr_types, pass.get()); + pass->Apply(main_program, startup_program); + std::unordered_map result_attrs; + for (const auto &name_and_value : pass_attrs) { + const auto &attr_name = name_and_value.first; + const auto &attr_type = pass_attr_types.at(attr_name); + result_attrs[attr_name] = + PassAttrGetterSetterRegistry::Instance().Get(*pass, attr_name, + attr_type); + } + return result_attrs; + }); +} + } // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/ir.h b/paddle/fluid/pybind/ir.h index 5bee70eba695b6d71c4df03e7ffe5d8d11384172..2cc1459bbe0fe8dc27f292999d01ed34211ed080 100644 --- a/paddle/fluid/pybind/ir.h +++ b/paddle/fluid/pybind/ir.h @@ -21,5 +21,6 @@ namespace paddle { namespace pybind { void BindGraph(pybind11::module *m); void BindNode(pybind11::module *m); +void BindPass(pybind11::module *m); } // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 2cda20959178c8fdcd426fbb6d0de1359181c51e..3bbe6d6ef4b1f28bac142e4c9c9088de1f7f1810 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -3105,6 +3105,7 @@ All parameter, weight, gradient are variables in Paddle. #endif BindGraph(&m); BindNode(&m); + BindPass(&m); BindInferenceApi(&m); BindCompatible(&m); BindDataset(&m); diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 5e644fefa3ffb964cbaf74358ae86a3abed28261..2247d49483035ca9b418f6f53bf31a8112603cec 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -3232,6 +3232,22 @@ class Block(object): return ret_var +def _apply_pass(main_program, + startup_program, + pass_name, + pass_attrs={}, + pass_attr_types={}): + assert isinstance(pass_attrs, dict), "pass_attrs must be dict" + assert isinstance(pass_attr_types, dict), "pass_attr_types must be dict" + tmp_main_program = core.ProgramDesc(main_program.desc) + tmp_startup_program = core.ProgramDesc(startup_program.desc) + attrs = core.apply_pass(tmp_main_program, tmp_startup_program, pass_name, + pass_attrs, pass_attr_types) + main_program._rebuild_from_desc(tmp_main_program) + startup_program._rebuild_from_desc(tmp_startup_program) + return attrs + + class IrNode(object): """ Python IrNode. Beneath it is a core.Node, which is used for Ir Pass. @@ -4148,6 +4164,91 @@ class Program(object): # compiled program, i.e. Graph self._graph = None + def _find_var_class_kwargs(self, new_desc): + old_desc = self.desc + all_new_vars = [] + block_num = new_desc.num_blocks() + for idx in range(block_num): + new_block_desc = new_desc.block(idx) + all_new_vars.append([]) + block_new_vars = all_new_vars[-1] + for new_var_desc in new_block_desc.all_vars(): + if self.blocks[idx].has_var(new_var_desc.name()): + old_var = self.blocks[idx].var(new_var_desc.name()) + else: + old_var = None + + kwargs = { + 'type': new_var_desc.type(), + 'name': new_var_desc.name(), + 'shape': new_var_desc.shape(), + 'dtype': new_var_desc.dtype(), + 'lod_level': new_var_desc.lod_level(), + 'error_clip': old_var.error_clip + if old_var is not None else None, + 'stop_gradient': old_var.stop_gradient + if old_var is not None else False, + 'is_data': old_var.is_data + if old_var is not None else False, + 'need_check_feed': new_var_desc.need_check_feed(), + 'belong_to_optimizer': old_var.belong_to_optimizer + if old_var is not None else False, + } + + if isinstance(old_var, Parameter): + kwargs.update({ + 'trainable': old_var.trainable, + 'optimize_attr': old_var.optimize_attr, + 'regularizer': old_var.regularizer, + 'do_model_average': old_var.do_model_average, + 'need_clip': old_var.need_clip, + 'is_distributed': old_var.is_distributed, + 'is_parameter': old_var.is_parameter, + }) + block_new_vars.append({ + 'class': Parameter, + 'kwargs': copy.deepcopy(kwargs), + }) + else: + kwargs['persistable'] = new_var_desc.persistable() + block_new_vars.append({ + 'class': Variable, + 'kwargs': copy.deepcopy(kwargs), + }) + + return all_new_vars + + def _rebuild_from_desc(self, desc): + all_new_vars = self._find_var_class_kwargs(desc) + block_num = desc.num_blocks() + assert block_num == len(all_new_vars) + + # clear old blocks and desc + self.blocks = [] + self.desc = None + + # create new blocks and set desc + self.desc = desc + self.blocks = [Block(self, idx) for idx in range(block_num)] + + # add new vars first + for idx in range(block_num): + block = self.blocks[idx] + for new_var in all_new_vars[idx]: + clazz = new_var['class'] + kwargs = new_var['kwargs'] + kwargs['block'] = block + clazz(**kwargs) + + # then append op + for idx in range(block_num): + block = self.blocks[idx] + block_desc = self.desc.block(idx) + for op_idx in range(block_desc.op_size()): + op_desc = block_desc.op(op_idx) + op = Operator(block=block, desc=op_desc) + block.ops.append(op) + def global_seed(self, seed=0): """ Set global seed for Program diff --git a/python/paddle/fluid/tests/unittests/test_apply_pass_to_program.py b/python/paddle/fluid/tests/unittests/test_apply_pass_to_program.py new file mode 100644 index 0000000000000000000000000000000000000000..b35fc9bae651a8dd766dd1fd2372c311e64aeb52 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_apply_pass_to_program.py @@ -0,0 +1,66 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.vision.models import resnet50 +from paddle.nn import CrossEntropyLoss +from paddle.fluid.framework import _apply_pass +import unittest + + +class TestApplyPassToProgram(unittest.TestCase): + def setUp(self): + paddle.enable_static() + + def global_block_contains_op(self, program, op_type): + for op in program.global_block().ops: + if op.type == op_type: + return True + return False + + def test_case(self): + image = paddle.static.data( + name="image", shape=[None, 3, 224, 224], dtype="float32") + label = paddle.static.data(name="label", shape=[None, 1], dtype="int64") + model = resnet50() + loss_fn = CrossEntropyLoss() + pred = model(image) + loss = loss_fn(pred, label) + optimizer = paddle.optimizer.SGD(learning_rate=1e-3) + optimizer.minimize(loss) + + startup = paddle.static.default_startup_program() + main = paddle.static.default_main_program() + + fused_op = "fused_elemwise_add_activation" + self.assertFalse(self.global_block_contains_op(main, fused_op)) + attrs = { + "int_attr": -3, + "size_t_attr": 10, + "float_attr": 3.25, + "float32_attr": -4.5, + "str_attr": "any string attr value", + } + attr_types = { + "size_t_attr": "size_t", + "float32_attr": "float32", + } + ret_attrs = _apply_pass(main, startup, "fuse_elewise_add_act_pass", + attrs, attr_types) + self.assertEqual(attrs, ret_attrs) + self.assertTrue(self.global_block_contains_op(main, fused_op)) + + +if __name__ == "__main__": + unittest.main()