提交 ddf24484 编写于 作者: Y Yu Yang

Update Input/Output of Op

上级 e05e27a7
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/pybind/protobuf.h"
#include <deque>
#include "paddle/framework/attribute.h"
namespace paddle {
namespace pybind {
......@@ -56,10 +57,90 @@ private:
class OpDescBind {
public:
OpDesc *Proto() { return &op_desc_; }
OpDesc *Proto() {
Sync();
return &op_desc_;
}
std::string Type() const { return op_desc_.type(); }
void SetType(const std::string &type) { op_desc_.set_type(type); }
const std::vector<std::string> &Input(const std::string &name) const {
auto it = inputs_.find(name);
PADDLE_ENFORCE(
it != inputs_.end(), "Input %s cannot be found in Op %s", name, Type());
return it->second;
}
std::vector<std::string> InputNames() const {
std::vector<std::string> retv;
retv.reserve(this->inputs_.size());
for (auto &ipt : this->inputs_) {
retv.push_back(ipt.first);
}
return retv;
}
void SetInput(const std::string &param_name,
const std::vector<std::string> &args) {
need_update_ = true;
inputs_[param_name] = args;
}
const std::vector<std::string> &Output(const std::string &name) const {
auto it = outputs_.find(name);
PADDLE_ENFORCE(it != outputs_.end(),
"Output %s cannot be found in Op %s",
name,
Type());
return it->second;
}
std::vector<std::string> OutputNames() const {
std::vector<std::string> retv;
retv.reserve(this->outputs_.size());
for (auto &ipt : this->outputs_) {
retv.push_back(ipt.first);
}
return retv;
}
void SetOutput(const std::string &param_name,
const std::vector<std::string> &args) {
need_update_ = true;
this->outputs_[param_name] = args;
}
std::string DebugString() { return this->Proto()->DebugString(); }
void Sync() {
if (need_update_) {
this->op_desc_.mutable_inputs()->Clear();
for (auto &ipt : inputs_) {
auto *input = op_desc_.add_inputs();
input->set_parameter(ipt.first);
VectorToRepeated(ipt.second, input->mutable_arguments());
}
this->op_desc_.mutable_outputs()->Clear();
for (auto &opt : outputs_) {
auto *output = op_desc_.add_outputs();
output->set_parameter(opt.first);
VectorToRepeated(opt.second, output->mutable_arguments());
}
need_update_ = false;
}
}
private:
OpDesc op_desc_;
std::unordered_map<std::string, std::vector<std::string>> inputs_;
std::unordered_map<std::string, std::vector<std::string>> outputs_;
std::unordered_map<std::string, Attribute> attrs_;
bool need_update_{false};
};
class BlockDescBind {
......@@ -141,8 +222,6 @@ public:
return blocks_.back().get();
}
BlockDescBind *Root() { return blocks_.front().get(); }
BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); }
std::string DebugString() { return Proto()->DebugString(); }
......@@ -196,9 +275,6 @@ void BindProgramDesc(py::module &m) {
.def("append_block",
&ProgramDescBind::AppendBlock,
py::return_value_policy::reference)
.def("root_block",
&ProgramDescBind::Root,
py::return_value_policy::reference)
.def("block", &ProgramDescBind::Block, py::return_value_policy::reference)
.def("__str__", &ProgramDescBind::DebugString)
.def("num_blocks", &ProgramDescBind::Size);
......@@ -241,52 +317,17 @@ void BindVarDsec(py::module &m) {
}
void BindOpDesc(py::module &m) {
// auto op_desc_set_var = [](OpDesc::Var *var,
// const std::string &parameter,
// const std::vector<std::string> &arguments) {
// var->set_parameter(parameter);
// VectorToRepeated(arguments, var->mutable_arguments());
// };
//
// auto op_desc_set_attr = [](OpDesc &desc, const std::string &name) {
// auto attr = desc.add_attrs();
// attr->set_name(name);
// return attr;
// };
py::class_<OpDescBind>(m, "OpDesc", "");
// .def("type", [](OpDesc &op) { return op.type(); })
// .def("set_input",
// [op_desc_set_var](OpDesc &self,
// const std::string &parameter,
// const std::vector<std::string> &arguments) {
// auto ipt = self.add_inputs();
// op_desc_set_var(ipt, parameter, arguments);
// })
// .def("input_names",
// [](OpDesc &self) {
// std::vector<std::string> ret_val;
// ret_val.reserve(static_cast<size_t>(self.inputs().size()));
// std::transform(
// self.inputs().begin(),
// self.inputs().end(),
// std::back_inserter(ret_val),
// [](const OpDesc::Var &var) { return var.parameter(); });
// return ret_val;
// })
// .def("__str__", [](OpDesc &self) { return self.DebugString(); })
// .def("set_output",
// [op_desc_set_var](OpDesc &self,
// const std::string &parameter,
// const std::vector<std::string> &arguments) {
// auto opt = self.add_outputs();
// op_desc_set_var(opt, parameter, arguments);
// })
// .def("set_attr",
// [op_desc_set_attr](OpDesc &self, const std::string &name, int i)
// {
// op_desc_set_attr(self, name)->set_i(i);
// });
py::class_<OpDescBind>(m, "OpDesc", "")
.def("type", &OpDescBind::Type)
.def("set_type", &OpDescBind::SetType)
.def("input", &OpDescBind::Input)
.def("input_names", &OpDescBind::InputNames)
.def("set_input", &OpDescBind::SetInput)
.def("output", &OpDescBind::Output)
.def("output_names", &OpDescBind::OutputNames)
.def("set_output", &OpDescBind::SetOutput)
.def("__str__", &OpDescBind::DebugString)
.def("__repr__", &OpDescBind::DebugString);
}
} // namespace pybind
} // namespace paddle
......@@ -2,6 +2,25 @@ import unittest
import paddle.v2.framework.core as core
class TestOpDesc(unittest.TestCase):
def test_op_desc(self):
prog = core.ProgramDesc.__create_program_desc__()
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)
op = block.append_op()
self.assertIsNotNone(op)
op.set_type("test")
self.assertEqual("test", op.type())
op.set_input("X", ["a", "b", "c"])
self.assertEqual(["a", "b", "c"], op.input("X"))
self.assertEqual(["X"], op.input_names())
op.set_output("Out", ["z"])
self.assertEqual(['z'], op.output("Out"))
self.assertEqual(["Out"], op.output_names())
class TestProgramDesc(unittest.TestCase):
def test_instance(self):
program_desc = core.ProgramDesc.instance()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册