From ddf2448484cb6d183032e8d616ed51176dea9ded Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 22 Sep 2017 17:46:48 -0700 Subject: [PATCH] Update Input/Output of Op --- paddle/pybind/protobuf.cc | 145 +++++++++++------- .../v2/framework/tests/test_protobuf_descs.py | 19 +++ 2 files changed, 112 insertions(+), 52 deletions(-) diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 5511841c8b..67d6252af8 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/pybind/protobuf.h" #include +#include "paddle/framework/attribute.h" namespace paddle { namespace pybind { @@ -56,10 +57,90 @@ private: class OpDescBind { public: - OpDesc *Proto() { return &op_desc_; } + OpDesc *Proto() { + Sync(); + return &op_desc_; + } + + std::string Type() const { return op_desc_.type(); } + + void SetType(const std::string &type) { op_desc_.set_type(type); } + + const std::vector &Input(const std::string &name) const { + auto it = inputs_.find(name); + PADDLE_ENFORCE( + it != inputs_.end(), "Input %s cannot be found in Op %s", name, Type()); + return it->second; + } + + std::vector InputNames() const { + std::vector retv; + retv.reserve(this->inputs_.size()); + for (auto &ipt : this->inputs_) { + retv.push_back(ipt.first); + } + return retv; + } + + void SetInput(const std::string ¶m_name, + const std::vector &args) { + need_update_ = true; + inputs_[param_name] = args; + } + + const std::vector &Output(const std::string &name) const { + auto it = outputs_.find(name); + PADDLE_ENFORCE(it != outputs_.end(), + "Output %s cannot be found in Op %s", + name, + Type()); + return it->second; + } + + std::vector OutputNames() const { + std::vector retv; + retv.reserve(this->outputs_.size()); + for (auto &ipt : this->outputs_) { + retv.push_back(ipt.first); + } + return retv; + } + + void SetOutput(const std::string ¶m_name, + const std::vector &args) { + need_update_ = true; + this->outputs_[param_name] = args; + } + + std::string DebugString() { return this->Proto()->DebugString(); } + + void Sync() { + if (need_update_) { + this->op_desc_.mutable_inputs()->Clear(); + for (auto &ipt : inputs_) { + auto *input = op_desc_.add_inputs(); + input->set_parameter(ipt.first); + VectorToRepeated(ipt.second, input->mutable_arguments()); + } + + this->op_desc_.mutable_outputs()->Clear(); + for (auto &opt : outputs_) { + auto *output = op_desc_.add_outputs(); + output->set_parameter(opt.first); + VectorToRepeated(opt.second, output->mutable_arguments()); + } + + need_update_ = false; + } + } private: OpDesc op_desc_; + std::unordered_map> inputs_; + std::unordered_map> outputs_; + std::unordered_map attrs_; + + bool need_update_{false}; }; class BlockDescBind { @@ -141,8 +222,6 @@ public: return blocks_.back().get(); } - BlockDescBind *Root() { return blocks_.front().get(); } - BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); } std::string DebugString() { return Proto()->DebugString(); } @@ -196,9 +275,6 @@ void BindProgramDesc(py::module &m) { .def("append_block", &ProgramDescBind::AppendBlock, py::return_value_policy::reference) - .def("root_block", - &ProgramDescBind::Root, - py::return_value_policy::reference) .def("block", &ProgramDescBind::Block, py::return_value_policy::reference) .def("__str__", &ProgramDescBind::DebugString) .def("num_blocks", &ProgramDescBind::Size); @@ -241,52 +317,17 @@ void BindVarDsec(py::module &m) { } void BindOpDesc(py::module &m) { - // auto op_desc_set_var = [](OpDesc::Var *var, - // const std::string ¶meter, - // const std::vector &arguments) { - // var->set_parameter(parameter); - // VectorToRepeated(arguments, var->mutable_arguments()); - // }; - // - // auto op_desc_set_attr = [](OpDesc &desc, const std::string &name) { - // auto attr = desc.add_attrs(); - // attr->set_name(name); - // return attr; - // }; - py::class_(m, "OpDesc", ""); - - // .def("type", [](OpDesc &op) { return op.type(); }) - // .def("set_input", - // [op_desc_set_var](OpDesc &self, - // const std::string ¶meter, - // const std::vector &arguments) { - // auto ipt = self.add_inputs(); - // op_desc_set_var(ipt, parameter, arguments); - // }) - // .def("input_names", - // [](OpDesc &self) { - // std::vector ret_val; - // ret_val.reserve(static_cast(self.inputs().size())); - // std::transform( - // self.inputs().begin(), - // self.inputs().end(), - // std::back_inserter(ret_val), - // [](const OpDesc::Var &var) { return var.parameter(); }); - // return ret_val; - // }) - // .def("__str__", [](OpDesc &self) { return self.DebugString(); }) - // .def("set_output", - // [op_desc_set_var](OpDesc &self, - // const std::string ¶meter, - // const std::vector &arguments) { - // auto opt = self.add_outputs(); - // op_desc_set_var(opt, parameter, arguments); - // }) - // .def("set_attr", - // [op_desc_set_attr](OpDesc &self, const std::string &name, int i) - // { - // op_desc_set_attr(self, name)->set_i(i); - // }); + py::class_(m, "OpDesc", "") + .def("type", &OpDescBind::Type) + .def("set_type", &OpDescBind::SetType) + .def("input", &OpDescBind::Input) + .def("input_names", &OpDescBind::InputNames) + .def("set_input", &OpDescBind::SetInput) + .def("output", &OpDescBind::Output) + .def("output_names", &OpDescBind::OutputNames) + .def("set_output", &OpDescBind::SetOutput) + .def("__str__", &OpDescBind::DebugString) + .def("__repr__", &OpDescBind::DebugString); } } // namespace pybind } // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/framework/tests/test_protobuf_descs.py index fbe1f7152b..950a936307 100644 --- a/python/paddle/v2/framework/tests/test_protobuf_descs.py +++ b/python/paddle/v2/framework/tests/test_protobuf_descs.py @@ -2,6 +2,25 @@ import unittest import paddle.v2.framework.core as core +class TestOpDesc(unittest.TestCase): + def test_op_desc(self): + prog = core.ProgramDesc.__create_program_desc__() + self.assertIsNotNone(prog) + block = prog.block(0) + self.assertIsNotNone(block) + op = block.append_op() + self.assertIsNotNone(op) + op.set_type("test") + self.assertEqual("test", op.type()) + op.set_input("X", ["a", "b", "c"]) + self.assertEqual(["a", "b", "c"], op.input("X")) + self.assertEqual(["X"], op.input_names()) + + op.set_output("Out", ["z"]) + self.assertEqual(['z'], op.output("Out")) + self.assertEqual(["Out"], op.output_names()) + + class TestProgramDesc(unittest.TestCase): def test_instance(self): program_desc = core.ProgramDesc.instance() -- GitLab