未验证 提交 58ce85f7 编写于 作者: Y YuanRisheng 提交者: GitHub

[NewIR]Add program pybind demo for IR (#55102)

* add program pybind

* rename file

* fix compile bugs
上级 cc231cb3
......@@ -125,6 +125,7 @@ set(PYBIND_SRCS
imperative.cc
inference_api.cc
ir.cc
graph.cc
bind_fleet_executor.cc
reader_py.cc
protobuf.cc
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/pybind/graph.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/python_headers.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_desc.h"
#include "pybind11/stl.h"
namespace py = pybind11;
using paddle::framework::OpDesc;
using paddle::framework::ProgramDesc;
using paddle::framework::Scope;
using paddle::framework::VarDesc;
using paddle::framework::ir::BuildOperationAdjList;
using paddle::framework::ir::Graph;
using paddle::framework::ir::GraphNum;
using paddle::framework::ir::GraphSafeRemoveNodes;
using paddle::framework::ir::HasCircle;
using paddle::framework::ir::Node;
using paddle::framework::ir::NodeComp;
using paddle::framework::ir::TopologySortOperations;
using pybind11::return_value_policy;
namespace paddle {
namespace pybind {
void BindGraph(py::module *m) {
m->def("graph_safe_remove_nodes",
[](Graph *graph, const std::unordered_set<const Node *> &nodes) {
return GraphSafeRemoveNodes(graph, nodes);
});
m->def("has_circle", HasCircle);
m->def("graph_num", GraphNum);
m->def(
"topology_sort", TopologySortOperations, return_value_policy::reference);
m->def("build_adjacency_list",
BuildOperationAdjList<NodeComp>,
return_value_policy::reference);
py::class_<Graph, std::shared_ptr<Graph>>(
*m,
"Graph",
"The graph is a Directed Acyclic Single Static Assignment Graph, see "
"`paddle::ir::Graph` for details.")
.def(py::init<const ProgramDesc &>())
.def(py::init<const ProgramDesc &, int64_t, int64_t>())
.def("clone", &Graph::Clone)
.def("has", &Graph::Has)
.def("get_bool", &Graph::Get<bool>)
.def("get_int", &Graph::Get<int>)
.def("get_float", &Graph::Get<float>)
.def("get_double", &Graph::Get<double>)
.def("get_string", &Graph::Get<std::string>)
.def("get_marked_nodes",
&Graph::Get<std::unordered_set<const Node *>>,
return_value_policy::reference)
.def("set",
[](Graph &self, const std::string &attr_name, bool attr) {
return self.Set(attr_name, new bool(attr));
})
.def("set",
[](Graph &self, const std::string &attr_name, int attr) {
return self.Set(attr_name, new int(attr));
})
.def("set",
[](Graph &self,
const std::string &attr_name,
const std::string &attr) {
return self.Set(attr_name, new std::string(attr));
})
.def("set",
[](Graph &self, const std::string &attr_name, float attr) {
return self.Set(attr_name, new float(attr));
})
.def("set",
[](Graph &self, const std::string &attr_name, double attr) {
return self.Set(attr_name, new double(attr));
})
.def("set",
[](Graph &self,
const std::string &attr_name,
const std::unordered_set<const Node *> &attr) {
return self.Set(attr_name,
new std::unordered_set<const Node *>(attr));
})
.def("set",
[](Graph &self,
const std::string &attr_name,
const std::unordered_set<std::string> &attr) {
return self.Set(attr_name,
new std::unordered_set<std::string>(attr));
})
.def("set_not_owned",
[](Graph &self, const std::string &attr_name, Scope &attr) {
self.SetNotOwned<Scope>(attr_name, &attr);
})
.def("erase", &Graph::Erase)
.def("nodes", &Graph::Nodes, return_value_policy::reference)
.def(
"create_var_node",
[](Graph &self, VarDesc &var_desc) {
return self.CreateVarNode(&var_desc);
},
return_value_policy::reference)
.def(
"create_op_node",
[](Graph &self, OpDesc &op_desc) {
return self.CreateOpNode(&op_desc);
},
return_value_policy::reference)
.def("create_control_dep_var",
&Graph::CreateControlDepVar,
return_value_policy::reference)
.def("create_empty_node",
&Graph::CreateEmptyNode,
return_value_policy::reference)
.def("release_nodes", &Graph::ReleaseNodes)
.def("remove_node",
[](Graph &self, Node &node) { return self.RemoveNode(&node); })
.def(
"retrieve_node", &Graph::RetrieveNode, return_value_policy::reference)
.def("resolve_hazard", &Graph::ResolveHazard)
.def("origin_program_desc",
&Graph::OriginProgram,
return_value_policy::reference)
.def("sub_graph_size", &Graph::SubGraphsSize)
.def("get_sub_graph", [](Graph &self, int i) {
/* Here we use a lambda function as an empty deleter to avoid the double
free of smart pointer.
Otherwise, this shared pointer will be free both in python and
cpp scope, which will lead a core dumped. */
return std::shared_ptr<Graph>(self.GetSubGraph(i), [](Graph *) {});
});
}
void BindNode(py::module *m) {
py::class_<Node> node(*m, "Node");
node.def("name", &Node::Name)
.def("node_type", &Node::NodeType)
.def("var", &Node::Var, return_value_policy::reference)
.def("op", &Node::Op, return_value_policy::reference)
.def("id", &Node::id)
.def("graph_id", &Node::GraphId)
.def("original_desc_id", &Node::OriginalDescId)
.def("is_op", &Node::IsOp)
.def("is_var", &Node::IsVar)
.def("is_ctrl_var", &Node::IsCtrlVar)
.def("clear_inputs", [](Node &self) { self.inputs.clear(); })
.def("remove_input",
[](Node &self, int node_id) {
auto pos = std::find_if(
self.inputs.begin(),
self.inputs.end(),
[&node_id](const Node *n) { return n->id() == node_id; });
if (pos != self.inputs.end()) {
self.inputs.erase(pos);
}
})
.def("remove_input",
[](Node &self, Node &node) {
auto pos =
std::find(self.inputs.begin(), self.inputs.end(), &node);
if (pos != self.inputs.end()) {
self.inputs.erase(pos);
}
})
.def("append_input",
[](Node &self, Node &node) { self.inputs.push_back(&node); })
.def("clear_outputs", [](Node &self) { self.outputs.clear(); })
.def("remove_output",
[](Node &self, int node_id) {
auto pos = std::find_if(
self.outputs.begin(),
self.outputs.end(),
[&node_id](const Node *n) { return n->id() == node_id; });
if (pos != self.outputs.end()) {
self.outputs.erase(pos);
}
})
.def("remove_output",
[](Node &self, Node &node) {
auto pos =
std::find(self.outputs.begin(), self.outputs.end(), &node);
if (pos != self.outputs.end()) {
self.outputs.erase(pos);
}
})
.def("append_output",
[](Node &self, Node &node) { self.outputs.push_back(&node); })
.def_readwrite("inputs", &Node::inputs)
.def_readwrite("outputs", &Node::outputs);
py::enum_<Node::Type>(node, "Type")
.value("Operation", Node::Type::kOperation)
.value("Variable", Node::Type::kVariable)
.export_values();
py::enum_<Node::Dep>(node, "Dep")
.value("Same", Node::Dep::kSame)
.value("Before", Node::Dep::kBefore)
.value("After", Node::Dep::kAfter)
.value("NoDep", Node::Dep::kNoDep)
.export_values();
}
class PYBIND11_HIDDEN PassAttrGetterSetterRegistry {
private:
PassAttrGetterSetterRegistry() = default;
DISABLE_COPY_AND_ASSIGN(PassAttrGetterSetterRegistry);
using Getter = std::function<py::object(const framework::ir::Pass & /*pass*/,
const std::string & /*attr_name*/)>;
using Setter = std::function<void(const std::string & /*attr_name*/,
const py::object & /*attr_value*/,
framework::ir::Pass * /*pass*/)>;
struct GetterSetter {
Getter getter;
Setter setter;
};
public:
static PassAttrGetterSetterRegistry &Instance() {
static PassAttrGetterSetterRegistry instance;
return instance;
}
void Register(const std::string &attr_type, Getter getter, Setter setter) {
PADDLE_ENFORCE_NOT_NULL(
getter,
platform::errors::InvalidArgument("getter of %s should not be nullptr",
attr_type));
PADDLE_ENFORCE_NOT_NULL(
setter,
platform::errors::InvalidArgument("setter of %s should not be nullptr",
attr_type));
GetterSetter getter_setter;
getter_setter.getter = std::move(getter);
getter_setter.setter = std::move(setter);
PADDLE_ENFORCE_EQ(
getter_setter_map_.emplace(attr_type, getter_setter).second,
true,
platform::errors::InvalidArgument(
"getter and setter of %s have been set before", attr_type));
}
py::object Get(const framework::ir::Pass &pass,
const std::string &attr_name,
const std::string &attr_type) const {
auto iter = getter_setter_map_.find(attr_type);
PADDLE_ENFORCE_EQ(
iter != getter_setter_map_.end(),
true,
platform::errors::InvalidArgument(
"unsupported attribute type %s of %s", attr_type, attr_name));
const auto &getter = iter->second.getter;
return getter(pass, attr_name);
}
void Set(const std::string &attr_name,
const std::string &attr_type,
const py::object &attr_value,
framework::ir::Pass *pass) const {
auto iter = getter_setter_map_.find(attr_type);
PADDLE_ENFORCE_EQ(
iter != getter_setter_map_.end(),
true,
platform::errors::InvalidArgument(
"unsupported attribute type %s of %s", attr_type, attr_name));
const auto &setter = iter->second.setter;
setter(attr_name, attr_value, pass);
}
private:
std::unordered_map<std::string, GetterSetter> getter_setter_map_;
};
#define REGISTER_PASS_ATTR_GETTER_SETTER(attr_type_name, cpp_type) \
do { \
auto getter = [](const framework::ir::Pass &pass, \
const std::string &attr_name) -> py::object { \
auto attr_value = pass.Get<cpp_type>(attr_name); \
return py::cast(attr_value); \
}; \
auto setter = [](const std::string &attr_name, \
const py::object &attr_value, \
framework::ir::Pass *pass) { \
PADDLE_ENFORCE_NOT_NULL( \
pass, platform::errors::InvalidArgument("pass should be provided")); \
try { \
const auto &cpp_attr_value = py::cast<cpp_type>(attr_value); \
pass->Set(attr_name, new cpp_type(cpp_attr_value)); \
} catch (py::cast_error &) { \
PADDLE_THROW(platform::errors::InvalidArgument( \
"type error of attribute %s, expected to be %s", \
attr_name, \
attr_type_name)); \
} \
}; \
PassAttrGetterSetterRegistry::Instance().Register( \
attr_type_name, getter, setter); \
} while (0)
// NOTE: attr_types may be changed
static void SetAttrsToPass(
const std::unordered_map<std::string, py::object> &attrs,
std::unordered_map<std::string, std::string> *attr_types,
framework::ir::Pass *pass) {
for (const auto &name_and_value : attrs) {
const auto &attr_name = name_and_value.first;
const auto &attr_value = name_and_value.second;
auto &attr_type = (*attr_types)[attr_name];
if (attr_type.empty()) {
attr_type = py::cast<std::string>(attr_value.get_type().attr("__name__"));
}
PassAttrGetterSetterRegistry::Instance().Set(
attr_name, attr_type, attr_value, pass);
}
}
static std::vector<std::string> GetPassNames(const py::object &names) {
try {
return {py::cast<std::string>(names)};
} catch (py::cast_error &) {
try {
return py::cast<std::vector<std::string>>(names);
} catch (py::cast_error &) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Pass names must be either str or list[str]"));
}
}
}
void BindPass(py::module *m) {
// NOTE: pass_attr_types is a dict to indicate the type of each attribute.
// Python has only one integral type "int", but C++ has many integral types.
// If pass_attrs = {"nranks": 1} in Python, we cannot know whether the type
// of "nranks" is size_t or int in C++. Therefore, users can set
// pass_attr_types to indicate the type of "nranks" explicitly,
// i.e. pass_attr_types = {"nranks": "size_t"} means that the type of
// "nranks" is size_t in C++.
REGISTER_PASS_ATTR_GETTER_SETTER("bool", bool);
REGISTER_PASS_ATTR_GETTER_SETTER("int", int64_t);
REGISTER_PASS_ATTR_GETTER_SETTER("long", int64_t);
REGISTER_PASS_ATTR_GETTER_SETTER("size_t", size_t);
REGISTER_PASS_ATTR_GETTER_SETTER("float32", float);
// Python float is C++ double
REGISTER_PASS_ATTR_GETTER_SETTER("float", double);
REGISTER_PASS_ATTR_GETTER_SETTER("bytes", std::string);
REGISTER_PASS_ATTR_GETTER_SETTER("str", std::string);
REGISTER_PASS_ATTR_GETTER_SETTER("list[str]", std::vector<std::string>);
m->def("apply_pass",
[](framework::ProgramDesc *main_program,
framework::ProgramDesc *startup_program,
const py::object &py_pass_names,
const std::unordered_map<std::string, py::object> &pass_attrs,
std::unordered_map<std::string, std::string> pass_attr_types) {
auto pass_names = GetPassNames(py_pass_names);
std::vector<std::unique_ptr<framework::ir::Pass>> passes;
std::vector<const framework::ir::Pass *> passes_not_owned;
passes.reserve(pass_names.size());
passes_not_owned.reserve(pass_names.size());
for (const auto &name : pass_names) {
auto pass = framework::ir::PassRegistry::Instance().Get(name);
SetAttrsToPass(pass_attrs, &pass_attr_types, pass.get());
passes.push_back(std::move(pass));
passes_not_owned.push_back(passes.back().get());
}
framework::ir::Pass::ApplyPassesToProgram(
passes_not_owned, main_program, startup_program);
std::unordered_map<std::string, py::object> result_attrs;
for (const auto &pass : passes) {
for (const auto &name_and_value : pass_attrs) {
const auto &attr_name = name_and_value.first;
const auto &attr_type = pass_attr_types.at(attr_name);
result_attrs[attr_name] =
PassAttrGetterSetterRegistry::Instance().Get(
*pass, attr_name, attr_type);
}
}
return result_attrs;
});
}
} // namespace pybind
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <pybind11/pybind11.h>
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace pybind {
void BindGraph(pybind11::module *m);
void BindNode(pybind11::module *m);
void BindPass(pybind11::module *m);
} // namespace pybind
} // namespace paddle
......@@ -21,391 +21,51 @@
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/python_headers.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/program.h"
#include "pybind11/stl.h"
namespace py = pybind11;
using paddle::framework::OpDesc;
using paddle::framework::ProgramDesc;
using paddle::framework::Scope;
using paddle::framework::VarDesc;
using paddle::framework::ir::BuildOperationAdjList;
using paddle::framework::ir::Graph;
using paddle::framework::ir::GraphNum;
using paddle::framework::ir::GraphSafeRemoveNodes;
using paddle::framework::ir::HasCircle;
using paddle::framework::ir::Node;
using paddle::framework::ir::NodeComp;
using paddle::framework::ir::TopologySortOperations;
using ir::Block;
using ir::Operation;
using ir::Program;
using pybind11::return_value_policy;
namespace paddle {
namespace pybind {
void BindGraph(py::module *m) {
m->def("graph_safe_remove_nodes",
[](Graph *graph, const std::unordered_set<const Node *> &nodes) {
return GraphSafeRemoveNodes(graph, nodes);
});
m->def("has_circle", HasCircle);
m->def("graph_num", GraphNum);
m->def(
"topology_sort", TopologySortOperations, return_value_policy::reference);
m->def("build_adjacency_list",
BuildOperationAdjList<NodeComp>,
return_value_policy::reference);
py::class_<Graph, std::shared_ptr<Graph>>(
*m,
"Graph",
"The graph is a Directed Acyclic Single Static Assignment Graph, see "
"`paddle::ir::Graph` for details.")
.def(py::init<const ProgramDesc &>())
.def(py::init<const ProgramDesc &, int64_t, int64_t>())
.def("clone", &Graph::Clone)
.def("has", &Graph::Has)
.def("get_bool", &Graph::Get<bool>)
.def("get_int", &Graph::Get<int>)
.def("get_float", &Graph::Get<float>)
.def("get_double", &Graph::Get<double>)
.def("get_string", &Graph::Get<std::string>)
.def("get_marked_nodes",
&Graph::Get<std::unordered_set<const Node *>>,
return_value_policy::reference)
.def("set",
[](Graph &self, const std::string &attr_name, bool attr) {
return self.Set(attr_name, new bool(attr));
})
.def("set",
[](Graph &self, const std::string &attr_name, int attr) {
return self.Set(attr_name, new int(attr));
})
.def("set",
[](Graph &self,
const std::string &attr_name,
const std::string &attr) {
return self.Set(attr_name, new std::string(attr));
})
.def("set",
[](Graph &self, const std::string &attr_name, float attr) {
return self.Set(attr_name, new float(attr));
})
.def("set",
[](Graph &self, const std::string &attr_name, double attr) {
return self.Set(attr_name, new double(attr));
})
.def("set",
[](Graph &self,
const std::string &attr_name,
const std::unordered_set<const Node *> &attr) {
return self.Set(attr_name,
new std::unordered_set<const Node *>(attr));
})
.def("set",
[](Graph &self,
const std::string &attr_name,
const std::unordered_set<std::string> &attr) {
return self.Set(attr_name,
new std::unordered_set<std::string>(attr));
})
.def("set_not_owned",
[](Graph &self, const std::string &attr_name, Scope &attr) {
self.SetNotOwned<Scope>(attr_name, &attr);
})
.def("erase", &Graph::Erase)
.def("nodes", &Graph::Nodes, return_value_policy::reference)
.def(
"create_var_node",
[](Graph &self, VarDesc &var_desc) {
return self.CreateVarNode(&var_desc);
},
return_value_policy::reference)
.def(
"create_op_node",
[](Graph &self, OpDesc &op_desc) {
return self.CreateOpNode(&op_desc);
},
return_value_policy::reference)
.def("create_control_dep_var",
&Graph::CreateControlDepVar,
return_value_policy::reference)
.def("create_empty_node",
&Graph::CreateEmptyNode,
return_value_policy::reference)
.def("release_nodes", &Graph::ReleaseNodes)
.def("remove_node",
[](Graph &self, Node &node) { return self.RemoveNode(&node); })
.def(
"retrieve_node", &Graph::RetrieveNode, return_value_policy::reference)
.def("resolve_hazard", &Graph::ResolveHazard)
.def("origin_program_desc",
&Graph::OriginProgram,
return_value_policy::reference)
.def("sub_graph_size", &Graph::SubGraphsSize)
.def("get_sub_graph", [](Graph &self, int i) {
/* Here we use a lambda function as an empty deleter to avoid the double
free of smart pointer.
Otherwise, this shared pointer will be free both in python and
cpp scope, which will lead a core dumped. */
return std::shared_ptr<Graph>(self.GetSubGraph(i), [](Graph *) {});
});
}
void BindNode(py::module *m) {
py::class_<Node> node(*m, "Node");
node.def("name", &Node::Name)
.def("node_type", &Node::NodeType)
.def("var", &Node::Var, return_value_policy::reference)
.def("op", &Node::Op, return_value_policy::reference)
.def("id", &Node::id)
.def("graph_id", &Node::GraphId)
.def("original_desc_id", &Node::OriginalDescId)
.def("is_op", &Node::IsOp)
.def("is_var", &Node::IsVar)
.def("is_ctrl_var", &Node::IsCtrlVar)
.def("clear_inputs", [](Node &self) { self.inputs.clear(); })
.def("remove_input",
[](Node &self, int node_id) {
auto pos = std::find_if(
self.inputs.begin(),
self.inputs.end(),
[&node_id](const Node *n) { return n->id() == node_id; });
if (pos != self.inputs.end()) {
self.inputs.erase(pos);
}
})
.def("remove_input",
[](Node &self, Node &node) {
auto pos =
std::find(self.inputs.begin(), self.inputs.end(), &node);
if (pos != self.inputs.end()) {
self.inputs.erase(pos);
}
})
.def("append_input",
[](Node &self, Node &node) { self.inputs.push_back(&node); })
.def("clear_outputs", [](Node &self) { self.outputs.clear(); })
.def("remove_output",
[](Node &self, int node_id) {
auto pos = std::find_if(
self.outputs.begin(),
self.outputs.end(),
[&node_id](const Node *n) { return n->id() == node_id; });
if (pos != self.outputs.end()) {
self.outputs.erase(pos);
}
})
.def("remove_output",
[](Node &self, Node &node) {
auto pos =
std::find(self.outputs.begin(), self.outputs.end(), &node);
if (pos != self.outputs.end()) {
self.outputs.erase(pos);
}
})
.def("append_output",
[](Node &self, Node &node) { self.outputs.push_back(&node); })
.def_readwrite("inputs", &Node::inputs)
.def_readwrite("outputs", &Node::outputs);
py::enum_<Node::Type>(node, "Type")
.value("Operation", Node::Type::kOperation)
.value("Variable", Node::Type::kVariable)
.export_values();
py::enum_<Node::Dep>(node, "Dep")
.value("Same", Node::Dep::kSame)
.value("Before", Node::Dep::kBefore)
.value("After", Node::Dep::kAfter)
.value("NoDep", Node::Dep::kNoDep)
.export_values();
void BindProgram(py::module *m) {
py::class_<Program> program(*m, "Program");
program.def("parameters_num", &Program::parameters_num)
.def("block", &Program::block, return_value_policy::reference)
.def("print", [](Program &self) {
std::ostringstream print_stream;
self.Print(print_stream);
LOG(INFO) << print_stream.str();
});
}
class PYBIND11_HIDDEN PassAttrGetterSetterRegistry {
private:
PassAttrGetterSetterRegistry() = default;
DISABLE_COPY_AND_ASSIGN(PassAttrGetterSetterRegistry);
using Getter = std::function<py::object(const framework::ir::Pass & /*pass*/,
const std::string & /*attr_name*/)>;
using Setter = std::function<void(const std::string & /*attr_name*/,
const py::object & /*attr_value*/,
framework::ir::Pass * /*pass*/)>;
struct GetterSetter {
Getter getter;
Setter setter;
};
public:
static PassAttrGetterSetterRegistry &Instance() {
static PassAttrGetterSetterRegistry instance;
return instance;
}
void Register(const std::string &attr_type, Getter getter, Setter setter) {
PADDLE_ENFORCE_NOT_NULL(
getter,
platform::errors::InvalidArgument("getter of %s should not be nullptr",
attr_type));
PADDLE_ENFORCE_NOT_NULL(
setter,
platform::errors::InvalidArgument("setter of %s should not be nullptr",
attr_type));
GetterSetter getter_setter;
getter_setter.getter = std::move(getter);
getter_setter.setter = std::move(setter);
PADDLE_ENFORCE_EQ(
getter_setter_map_.emplace(attr_type, getter_setter).second,
true,
platform::errors::InvalidArgument(
"getter and setter of %s have been set before", attr_type));
}
py::object Get(const framework::ir::Pass &pass,
const std::string &attr_name,
const std::string &attr_type) const {
auto iter = getter_setter_map_.find(attr_type);
PADDLE_ENFORCE_EQ(
iter != getter_setter_map_.end(),
true,
platform::errors::InvalidArgument(
"unsupported attribute type %s of %s", attr_type, attr_name));
const auto &getter = iter->second.getter;
return getter(pass, attr_name);
}
void Set(const std::string &attr_name,
const std::string &attr_type,
const py::object &attr_value,
framework::ir::Pass *pass) const {
auto iter = getter_setter_map_.find(attr_type);
PADDLE_ENFORCE_EQ(
iter != getter_setter_map_.end(),
true,
platform::errors::InvalidArgument(
"unsupported attribute type %s of %s", attr_type, attr_name));
const auto &setter = iter->second.setter;
setter(attr_name, attr_value, pass);
}
private:
std::unordered_map<std::string, GetterSetter> getter_setter_map_;
};
#define REGISTER_PASS_ATTR_GETTER_SETTER(attr_type_name, cpp_type) \
do { \
auto getter = [](const framework::ir::Pass &pass, \
const std::string &attr_name) -> py::object { \
auto attr_value = pass.Get<cpp_type>(attr_name); \
return py::cast(attr_value); \
}; \
auto setter = [](const std::string &attr_name, \
const py::object &attr_value, \
framework::ir::Pass *pass) { \
PADDLE_ENFORCE_NOT_NULL( \
pass, platform::errors::InvalidArgument("pass should be provided")); \
try { \
const auto &cpp_attr_value = py::cast<cpp_type>(attr_value); \
pass->Set(attr_name, new cpp_type(cpp_attr_value)); \
} catch (py::cast_error &) { \
PADDLE_THROW(platform::errors::InvalidArgument( \
"type error of attribute %s, expected to be %s", \
attr_name, \
attr_type_name)); \
} \
}; \
PassAttrGetterSetterRegistry::Instance().Register( \
attr_type_name, getter, setter); \
} while (0)
// NOTE: attr_types may be changed
static void SetAttrsToPass(
const std::unordered_map<std::string, py::object> &attrs,
std::unordered_map<std::string, std::string> *attr_types,
framework::ir::Pass *pass) {
for (const auto &name_and_value : attrs) {
const auto &attr_name = name_and_value.first;
const auto &attr_value = name_and_value.second;
auto &attr_type = (*attr_types)[attr_name];
if (attr_type.empty()) {
attr_type = py::cast<std::string>(attr_value.get_type().attr("__name__"));
}
PassAttrGetterSetterRegistry::Instance().Set(
attr_name, attr_type, attr_value, pass);
}
void BindBlock(py::module *m) {
py::class_<Block> block(*m, "Block");
block.def("front", &Block::front, return_value_policy::reference)
.def("get_op_list", [](Block &self) -> py::list {
py::list op_list;
for (auto iter = self.begin(); iter != self.end(); iter++) {
op_list.append(*iter);
}
return op_list;
});
}
static std::vector<std::string> GetPassNames(const py::object &names) {
try {
return {py::cast<std::string>(names)};
} catch (py::cast_error &) {
try {
return py::cast<std::vector<std::string>>(names);
} catch (py::cast_error &) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Pass names must be either str or list[str]"));
}
}
void BindOperation(py::module *m) {
py::class_<Operation> op(*m, "Operation");
op.def("name", &Operation::name);
}
void BindPass(py::module *m) {
// NOTE: pass_attr_types is a dict to indicate the type of each attribute.
// Python has only one integral type "int", but C++ has many integral types.
// If pass_attrs = {"nranks": 1} in Python, we cannot know whether the type
// of "nranks" is size_t or int in C++. Therefore, users can set
// pass_attr_types to indicate the type of "nranks" explicitly,
// i.e. pass_attr_types = {"nranks": "size_t"} means that the type of
// "nranks" is size_t in C++.
REGISTER_PASS_ATTR_GETTER_SETTER("bool", bool);
REGISTER_PASS_ATTR_GETTER_SETTER("int", int64_t);
REGISTER_PASS_ATTR_GETTER_SETTER("long", int64_t);
REGISTER_PASS_ATTR_GETTER_SETTER("size_t", size_t);
REGISTER_PASS_ATTR_GETTER_SETTER("float32", float);
// Python float is C++ double
REGISTER_PASS_ATTR_GETTER_SETTER("float", double);
REGISTER_PASS_ATTR_GETTER_SETTER("bytes", std::string);
REGISTER_PASS_ATTR_GETTER_SETTER("str", std::string);
REGISTER_PASS_ATTR_GETTER_SETTER("list[str]", std::vector<std::string>);
m->def("apply_pass",
[](framework::ProgramDesc *main_program,
framework::ProgramDesc *startup_program,
const py::object &py_pass_names,
const std::unordered_map<std::string, py::object> &pass_attrs,
std::unordered_map<std::string, std::string> pass_attr_types) {
auto pass_names = GetPassNames(py_pass_names);
std::vector<std::unique_ptr<framework::ir::Pass>> passes;
std::vector<const framework::ir::Pass *> passes_not_owned;
passes.reserve(pass_names.size());
passes_not_owned.reserve(pass_names.size());
for (const auto &name : pass_names) {
auto pass = framework::ir::PassRegistry::Instance().Get(name);
SetAttrsToPass(pass_attrs, &pass_attr_types, pass.get());
passes.push_back(std::move(pass));
passes_not_owned.push_back(passes.back().get());
}
framework::ir::Pass::ApplyPassesToProgram(
passes_not_owned, main_program, startup_program);
std::unordered_map<std::string, py::object> result_attrs;
for (const auto &pass : passes) {
for (const auto &name_and_value : pass_attrs) {
const auto &attr_name = name_and_value.first;
const auto &attr_type = pass_attr_types.at(attr_name);
result_attrs[attr_name] =
PassAttrGetterSetterRegistry::Instance().Get(
*pass, attr_name, attr_type);
}
}
return result_attrs;
});
void BindNewIR(pybind11::module *m) {
BindProgram(m);
BindBlock(m);
BindOperation(m);
}
} // namespace pybind
......
......@@ -16,12 +16,8 @@
#include <pybind11/pybind11.h>
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace pybind {
void BindGraph(pybind11::module *m);
void BindNode(pybind11::module *m);
void BindPass(pybind11::module *m);
void BindNewIR(pybind11::module *m);
} // namespace pybind
} // namespace paddle
......@@ -103,11 +103,11 @@ limitations under the License. */
#include "paddle/fluid/pybind/global_value_getter_setter.h"
#include "paddle/fluid/pybind/gloo_context_py.h"
#include "paddle/fluid/pybind/gloo_wrapper_py.h"
#include "paddle/fluid/pybind/graph.h"
#include "paddle/fluid/pybind/heter_wrapper_py.h"
#include "paddle/fluid/pybind/imperative.h"
#include "paddle/fluid/pybind/inference_api.h"
#include "paddle/fluid/pybind/io.h"
#include "paddle/fluid/pybind/ir.h"
#include "paddle/fluid/pybind/metrics_py.h"
#include "paddle/fluid/pybind/ps_gpu_wrapper_py.h"
#include "paddle/fluid/pybind/pybind_variant_caster.h"
......
......@@ -103,11 +103,11 @@ limitations under the License. */
#include "paddle/fluid/pybind/global_value_getter_setter.h"
#include "paddle/fluid/pybind/gloo_context_py.h"
#include "paddle/fluid/pybind/gloo_wrapper_py.h"
#include "paddle/fluid/pybind/graph.h"
#include "paddle/fluid/pybind/heter_wrapper_py.h"
#include "paddle/fluid/pybind/imperative.h"
#include "paddle/fluid/pybind/inference_api.h"
#include "paddle/fluid/pybind/io.h"
#include "paddle/fluid/pybind/ir.h"
#include "paddle/fluid/pybind/metrics_py.h"
#include "paddle/fluid/pybind/ps_gpu_wrapper_py.h"
#include "paddle/fluid/pybind/pybind_variant_caster.h"
......
......@@ -115,6 +115,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/global_value_getter_setter.h"
#include "paddle/fluid/pybind/gloo_context_py.h"
#include "paddle/fluid/pybind/gloo_wrapper_py.h"
#include "paddle/fluid/pybind/graph.h"
#include "paddle/fluid/pybind/heter_wrapper_py.h"
#include "paddle/fluid/pybind/imperative.h"
#include "paddle/fluid/pybind/inference_api.h"
......@@ -192,6 +193,7 @@ limitations under the License. */
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/nan_inf_utils.h"
#include "paddle/fluid/imperative/layout_autotune.h"
#include "paddle/fluid/ir_adaptor/translator/translate.h"
#include "paddle/fluid/prim/utils/eager/eager_tensor_operants.h"
#include "paddle/fluid/prim/utils/static/static_tensor_operants.h"
#include "paddle/fluid/pybind/eager_utils.h"
......@@ -2722,7 +2724,7 @@ All parameter, weight, gradient are variables in Paddle.
// Add skipped op list
m.def("set_skipped_op_list",
[](const std::string &op_list) { egr::SetSkipOpList(op_list); });
m.def("translate_newirprogram", &paddle::TranslateLegacyProgramToProgram);
BindFleetWrapper(&m);
BindIO(&m);
BindParallelExecutor(m);
......@@ -2799,6 +2801,8 @@ All parameter, weight, gradient are variables in Paddle.
GetCurrentWorkerInfo(&m);
GetAllWorkerInfos(&m);
#endif
BindNewIR(&m);
}
} // namespace pybind
} // namespace paddle
......@@ -103,11 +103,11 @@ limitations under the License. */
#include "paddle/fluid/pybind/global_value_getter_setter.h"
#include "paddle/fluid/pybind/gloo_context_py.h"
#include "paddle/fluid/pybind/gloo_wrapper_py.h"
#include "paddle/fluid/pybind/graph.h"
#include "paddle/fluid/pybind/heter_wrapper_py.h"
#include "paddle/fluid/pybind/imperative.h"
#include "paddle/fluid/pybind/inference_api.h"
#include "paddle/fluid/pybind/io.h"
#include "paddle/fluid/pybind/ir.h"
#include "paddle/fluid/pybind/metrics_py.h"
#include "paddle/fluid/pybind/ps_gpu_wrapper_py.h"
#include "paddle/fluid/pybind/pybind_variant_caster.h"
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle.fluid import core
paddle.enable_static()
def get_ir_program():
x = paddle.randn([4, 4])
main_program, start_program = (
paddle.static.Program(),
paddle.static.Program(),
)
with paddle.static.program_guard(main_program, start_program):
x_s = paddle.static.data('x', [4, 4], x.dtype)
x_s.stop_gradient = False
y_s = paddle.matmul(x_s, x_s)
y_s = paddle.add(x_s, y_s)
y_s = paddle.tanh(y_s)
newir_program = core.translate_newirprogram(main_program.desc)
return newir_program
class TestPybind(unittest.TestCase):
def test_program(self):
newir_program = get_ir_program()
newir_program.print()
ops = newir_program.block().get_op_list()
self.assertTrue(
len(ops), 4
) # ir program add "builtin.get_parameter" by default, so size is 4
for op in ops:
# check op.name function
if op.name() == 'pd.tanh':
self.assertTrue(True)
return
self.assertTrue(False)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册