diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index d91993bd4f8c04539cd189a4145350498911c513..75f922d2cca6855a67be7284ae407e549a1a1afb 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -226,7 +226,7 @@ std::unique_ptr MultiDevSSAGraphBuilderBase::ApplyImpl( * Only variables should be the leaves of graph. */ AddOutputToLeafOps(&result); - result.Erase(kGraphOps); + result.Erase(kGraphOps); return graph; } diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 47fcf96a3f92b1f915e5254fff36feb8b2870730..8bb3c27bdd32d07d58913db043569f6a3bf69aeb 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -109,7 +109,6 @@ class Graph { attr_dels_[attr_name] = []() {}; } - template void Erase(const std::string &attr_name) { PADDLE_ENFORCE(attrs_.count(attr_name) != 0, "%s not set in the graph", attr_name); diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 72b0f216d3aafbb95931590935b0bf967a8d5be8..2545f5312f4c1799dc656dc51114421d794ba72e 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -3,7 +3,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune fe if(WITH_PYTHON) list(APPEND PYBIND_DEPS py_func_op) endif() -set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc imperative.cc) +set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc imperative.cc ir.cc) if(WITH_PYTHON) if(WITH_AMD_GPU) diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc new file mode 100644 index 0000000000000000000000000000000000000000..d32fe58f8695a5c14f276ef038416f5c47f3400f --- /dev/null +++ b/paddle/fluid/pybind/ir.cc @@ -0,0 +1,103 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pybind/ir.h" +#include +#include +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/var_desc.h" +#include "pybind11/stl.h" + +namespace py = pybind11; +using paddle::framework::ir::Graph; +using paddle::framework::ir::Node; +using paddle::framework::OpDesc; +using paddle::framework::ProgramDesc; +using paddle::framework::VarDesc; +using pybind11::return_value_policy; + +namespace paddle { +namespace pybind { +void BindGraph(py::module *m) { + py::class_>( + *m, "Graph", + "The graph is a Directed Acyclic Single Static Assignment Graph, see " + "`paddle::ir::Graph` for details.") + .def(py::init()) + .def("has", &Graph::Has) + .def("get_int", &Graph::Get) + .def("get_float", &Graph::Get) + .def("get_double", &Graph::Get) + .def("get_string", &Graph::Get) + .def("set", [](Graph &self, const std::string &attr_name, + int attr) { return self.Set(attr_name, new int(attr)); }) + .def("set", + [](Graph &self, const std::string &attr_name, + const std::string &attr) { + return self.Set(attr_name, new std::string(attr)); + }) + .def("set", + [](Graph &self, const std::string &attr_name, float attr) { + return self.Set(attr_name, new float(attr)); + }) + .def("set", + [](Graph &self, const std::string &attr_name, double attr) { + return self.Set(attr_name, new double(attr)); + }) + .def("erase", &Graph::Erase) + .def("nodes", &Graph::Nodes, return_value_policy::reference) + .def("create_var_node", + [](Graph &self, VarDesc &var_desc) { + return self.CreateVarNode(&var_desc); + }, + return_value_policy::reference) + .def("create_op_node", + [](Graph &self, OpDesc &op_desc) { + return self.CreateOpNode(&op_desc); + }, + return_value_policy::reference) + .def("create_control_dep_var", &Graph::CreateControlDepVar, + return_value_policy::reference) + .def("create_empty_node", &Graph::CreateEmptyNode, + return_value_policy::reference) + .def("release_nodes", &Graph::ReleaseNodes) + .def("remove_node", + [](Graph &self, Node &node) { return self.RemoveNode(&node); }) + .def("retrieve_node", &Graph::RetrieveNode, + return_value_policy::reference) + .def("resolve_hazard", &Graph::ResolveHazard); +} + +void BindNode(py::module *m) { + py::class_ node(*m, "Node"); + node.def("name", &Node::Name) + .def("node_type", &Node::NodeType) + .def("var", &Node::Var) + .def("op", &Node::Op) + .def("id", &Node::id) + .def("is_op", &Node::IsOp) + .def("is_var", &Node::IsVar) + .def("is_ctrl_var", &Node::IsCtrlVar) + .def_readwrite("inputs", &Node::inputs) + .def_readwrite("outputs", &Node::outputs); + + py::enum_(node, "Type") + .value("Operation", Node::Type::kOperation) + .value("Variable", Node::Type::kVariable) + .export_values(); +} +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/ir.h b/paddle/fluid/pybind/ir.h new file mode 100644 index 0000000000000000000000000000000000000000..5bee70eba695b6d71c4df03e7ffe5d8d11384172 --- /dev/null +++ b/paddle/fluid/pybind/ir.h @@ -0,0 +1,25 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/framework/ir/graph.h" + +namespace paddle { +namespace pybind { +void BindGraph(pybind11::module *m); +void BindNode(pybind11::module *m); +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index a540c6fca1f18830b5ac07d6b45c0e586351a21d..1edff3a1f5a089eddc76e35f0233bee57afa6cb7 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -49,6 +49,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/const_value.h" #include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/imperative.h" +#include "paddle/fluid/pybind/ir.h" #include "paddle/fluid/pybind/protobuf.h" #include "paddle/fluid/pybind/pybind.h" // NOLINT #include "paddle/fluid/pybind/recordio.h" @@ -775,7 +776,12 @@ All parameter, weight, gradient are variables in Paddle. }) .def("set_int", [](ir::Pass &self, const std::string &name, int val) { self.Set(name, new int(val)); }) - .def("type", &ir::Pass::Type); + .def("type", &ir::Pass::Type) + .def("apply", [](ir::Pass &self, std::shared_ptr graph) { + std::unique_ptr origin_graph(graph.get()); + auto optim_graph = self.Apply(std::move(origin_graph)); + graph.reset(optim_graph.release()); + }); py::class_> pb( m, "PassBuilder"); @@ -1042,6 +1048,9 @@ All parameter, weight, gradient are variables in Paddle. BindRecordIOWriter(&m); BindAsyncExecutor(&m); + + BindGraph(&m); + BindNode(&m); } } // namespace pybind } // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_ir_graph.py b/python/paddle/fluid/tests/unittests/test_ir_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..ba6e4a8b2effade67821f5da9c2bbf7849a8cf79 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_ir_graph.py @@ -0,0 +1,146 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import six +from paddle import fluid + + +class TestIRGraph(unittest.TestCase): + """ + TODO(fc500110): `resolve_hazard` api will be tested when it can be used. + """ + + def test_nodes(self): + graph = build_graph() + self.assertTrue( + {node.name() + for node in graph.nodes()} == {"x1", "x2", "out", "sum"}) + + def test_has_set_get(self): + graph = build_graph() + for attr_name in ["int", "float", "string"]: + self.assertFalse(graph.has(attr_name)) + graph.set("int", 1) + graph.set("float", 0.5) + graph.set("string", "string") + for attr_name in ["int", "float", "string"]: + self.assertTrue(graph.has(attr_name)) + + self.assertTrue(graph.get_int("int") == 1) + self.assertTrue(graph.get_float("float") == 0.5) + self.assertTrue(graph.get_string("string") == "string") + + def test_erase(self): + graph = build_graph() + graph.set("test", 0) + self.assertTrue(graph.has("test")) + graph.erase("test") + self.assertFalse(graph.has("test")) + + def test_create_var_node(self): + prog = fluid.core.ProgramDesc() + block = prog.block(0) + shape = [10, 20] + x1 = block.var(six.b("x1")) + x1.set_type(fluid.core.VarDesc.VarType.LOD_TENSOR) + x1.set_shape(shape) + graph = fluid.core.Graph(prog) + node = graph.create_var_node(x1) + self.assertTrue(node.node_type() == fluid.core.Node.Type.Variable) + + def test_create_op_node(self): + prog = fluid.core.ProgramDesc() + block = prog.block(0) + sum_op_desc = block.append_op() + graph = fluid.core.Graph(prog) + node = graph.create_op_node(sum_op_desc) + self.assertTrue(node.node_type() == fluid.core.Node.Type.Operation) + + def test_create_control_dep_var(self): + graph = build_graph() + name = "__control_var@{}".format(len(graph.nodes())) + node = graph.create_control_dep_var() + self.assertTrue(node.name() == name) + + def test_create_empty_node(self): + prog = fluid.core.ProgramDesc() + graph = fluid.core.Graph(prog) + n1 = graph.create_empty_node('x', fluid.core.Node.Type.Operation) + self.assertTrue(n1.name() == 'x') + n2 = graph.create_empty_node('y', fluid.core.Node.Type.Variable) + self.assertTrue(n2.name() == 'y') + + def test_release_nodes(self): + graph = build_graph() + nodes = graph.release_nodes() + self.assertTrue(len(graph.nodes()) == 0) + self.assertTrue({node.name() + for node in nodes} == {"x1", "x2", "out", "sum"}) + + def test_remove_node(self): + graph = build_graph() + nodes = graph.nodes() + for node in nodes: + if node.name() == "sum": + break + self.assertTrue({node.name() + for node in nodes} == {"x1", "x2", "out", "sum"}) + nodes.remove(node) + self.assertTrue({node.name() for node in nodes} == {"x1", "x2", "out"}) + + def test_retrieve_node(self): + graph = build_graph() + nodes = [] + for i in range(len(graph.nodes())): + nodes.append(graph.retrieve_node(i)) + + for node in nodes: + self.assertTrue(node in graph.nodes()) + + def resolve_hazard(self): + pass + + +def build_graph(): + prog = fluid.core.ProgramDesc() + block = prog.block(0) + + shape = [10, 20] + + # prepare input/output + x1 = block.var(six.b("x1")) + x1.set_type(fluid.core.VarDesc.VarType.LOD_TENSOR) + x1.set_shape(shape) + x2 = block.var(six.b("x2")) + x2.set_type(fluid.core.VarDesc.VarType.LOD_TENSOR) + x2.set_shape(shape) + + out = block.var(six.b("out")) + out.set_type(fluid.core.VarDesc.VarType.LOD_TENSOR) + + sum_op_desc = block.append_op() + sum_op_desc.set_type("sum") + sum_op_desc.set_input("X", ["x1", "x2"]) + sum_op_desc.set_output("Out", ["out"]) + + sum_op_desc.check_attrs() + sum_op_desc.infer_shape(block) + graph = fluid.core.Graph(prog) + return graph + + +if __name__ == "__main__": + unittest.main()