diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc index 509aa235d3ee226adef15f08f5785866700499f1..b77d5525d4508056c9d6d487e63e500265e1d700 100644 --- a/paddle/framework/block_desc.cc +++ b/paddle/framework/block_desc.cc @@ -91,9 +91,5 @@ BlockDescBind *BlockDescBind::ParentBlock() const { return prog_->Block(static_cast(this->desc_->parent_idx())); } -void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) { - BlockDesc *desc = block.RawPtr(); - this->attrs_[name] = desc; -} } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index d3c11ad60a0f9319329a59c16bfc4668cd75b7ae..a5d515bbca729220ca6df5fa07d02f1b3f025109 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -100,6 +100,12 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) { need_update_ = true; } +void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) { + BlockDesc *desc = block.RawPtr(); + this->attrs_[name] = desc; + need_update_ = true; +} + void OpDescBind::SetAttrMap( const std::unordered_map &attr_map) { attrs_ = attr_map; diff --git a/paddle/framework/var_desc.h b/paddle/framework/var_desc.h index 464fece85fe5c674690c2034054e551f14db2138..44368795645664a343e2706fb670f104a42c5c9f 100644 --- a/paddle/framework/var_desc.h +++ b/paddle/framework/var_desc.h @@ -34,6 +34,7 @@ inline std::vector RepeatedToVector( template inline void VectorToRepeated(const std::vector &vec, RepeatedField *repeated_field) { + repeated_field->Clear(); repeated_field->Reserve(vec.size()); for (const auto &elem : vec) { *repeated_field->Add() = elem; @@ -44,6 +45,7 @@ inline void VectorToRepeated(const std::vector &vec, template inline void VectorToRepeated(const std::vector &vec, RepeatedField *repeated_field) { + repeated_field->Clear(); repeated_field->Reserve(vec.size()); for (auto elem : vec) { *repeated_field->Add() = elem; diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index e330877fc4283b796dcb5c5d745881884ae491ae..75928f1ec818ab028ea06cfa72273fb99430c3c8 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -54,7 +54,7 @@ class UniformRandomOp : public framework::OperatorWithKernel { PADDLE_ENFORCE( ctx->Attrs().Get("min") < ctx->Attrs().Get("max"), "uniform_random's min must less then max"); - auto dims = Attr>("dims"); + auto& dims = ctx->Attrs().Get>("dims"); std::vector temp; temp.reserve(dims.size()); for (auto dim : dims) { diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 0e4bbe8415fd86ab29c6809e7652dc581b4e6004..7ab4e6a451846199d249ee8c6cf24483802a58da 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -204,7 +204,7 @@ void BindOpDesc(py::module &m) { .def("set_attr", &OpDescBind::SetAttr) .def("attr", &OpDescBind::GetAttr) .def("set_block_attr", &OpDescBind::SetBlockAttr) - .def("get_block_attr", &OpDescBind::GetBlockAttr) + .def("block_attr", &OpDescBind::GetBlockAttr) .def("check_attrs", &OpDescBind::CheckAttrs) .def("infer_shape", &OpDescBind::InferShape); } diff --git a/python/paddle/v2/framework/graph.py b/python/paddle/v2/framework/graph.py index 0f0a2847e58a1ca172bf1ba382abb2ebc1ecb8ed..2afbd0c83158d583dd637cceeb35321ddb68f323 100644 --- a/python/paddle/v2/framework/graph.py +++ b/python/paddle/v2/framework/graph.py @@ -1,4 +1,5 @@ import paddle.v2.framework.core as core +import paddle.v2.framework.proto.framework_pb2 as framework_pb2 import collections import numpy as np import copy @@ -106,6 +107,40 @@ class Variable(object): raise ValueError("Not supported numpy dtype " + str(dtype)) +def get_all_op_protos(): + """ + Get all registered op proto from PaddlePaddle C++ end. + :return: A list of registered OpProto. + """ + protostrs = core.get_all_op_protos() + ret_values = [] + for pbstr in protostrs: + op_proto = framework_pb2.OpProto.FromString(str(pbstr)) + ret_values.append(op_proto) + return ret_values + + +class OpProtoHolder(object): + @classmethod + def instance(cls): + if not hasattr(cls, '_instance'): + cls._instance = cls() + return cls._instance + + def __init__(self): + assert not hasattr( + self.__class__, + '_instance'), 'Please use `instance()` to get OpProtoHolder opject!' + op_protos = get_all_op_protos() + self.op_proto_map = {} + for proto in op_protos: + self.op_proto_map[proto.type] = proto + + def get_op_proto(self, type): + assert type in self.op_proto_map, "Operator \"%s\" has not been registered." % type + return self.op_proto_map[type] + + class Operator(object): def __init__(self, block, @@ -116,20 +151,89 @@ class Operator(object): attrs=None): self.block = block self.desc = desc - if type is not None: - # TODO. - pass + if len(self.desc.type()) != 0: + return + if type is None: + raise ValueError( + "`type` to initilized an Operator can not be None.") + self.desc.set_type(type) + proto = OpProtoHolder.instance().get_op_proto(type) + if inputs is not None: - # TODO - pass + for in_proto in proto.inputs: + in_argus = inputs[in_proto.name] + if not isinstance(in_argus, list): + in_argus = [in_argus] + if not in_proto.duplicable and len(in_argus) > 1: + raise ValueError( + "Input %s expects only one input, but %d are given." % + (in_proto.name, len(in_argus))) + in_argu_names = [] + for argu in in_argus: + in_argu_names.append(argu.name) + self.desc.set_input(in_proto.name, in_argu_names) + if outputs is not None: - # TODO - pass + for out_proto in proto.outputs: + out_argus = outputs[out_proto.name] + if not isinstance(out_argus, list): + out_argus = [out_argus] + if not out_proto.duplicable and len(out_argus) > 1: + raise ValueError( + "Output %s expects only one output, but %d are given." % + (out_proto.name, len(out_argus))) + out_argu_names = [] + for argu in out_argus: + out_argu_names.append(argu.name) + argu.op = self + self.desc.set_output(out_proto.name, out_argu_names) + if attrs is not None: - # TODO - pass + for attr in proto.attrs: + attr_name = attr.name + if not attr_name in attrs: + continue + if not isinstance(attrs[attr_name], Block): + self.desc.set_attr(attr_name, attrs[attr_name]) + else: + self.desc.set_block_attr(attr_name, attrs[attr_name].desc) + + self.desc.check_attrs() + self.desc.infer_shape(self.block.desc) + + @property + def type(self): + return self.desc.type() + + def input(self, name): + return self.desc.input(name) + + @property + def input_names(self): + return self.desc.input_names() + + def output(self, name): + return self.desc.output(name) + + @property + def output_names(self): + return self.desc.output_names() + + def has_attr(self, name): + return self.desc.has_attr(name) + + def attr_type(self, name): + return self.desc.attr_type(name) + + @property + def attr_names(self): + return self.desc.attr_names() + + def attr(self, name): + return self.desc.attr(name) - # TODO: Getters + def block_attr(self, name): + return self.desc.block_attr(name) class Block(object): diff --git a/python/paddle/v2/framework/tests/test_operator_desc.py b/python/paddle/v2/framework/tests/test_operator_desc.py new file mode 100644 index 0000000000000000000000000000000000000000..ec6c6bc1833ac708b9702d9c973e29fc343b4c09 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_operator_desc.py @@ -0,0 +1,76 @@ +import unittest +from paddle.v2.framework.graph import Variable, g_program +import paddle.v2.framework.core as core + + +class TestOperator(unittest.TestCase): + def test_error_type(self): + block = g_program.create_block() + try: + block.append_op() + self.assertFail() + except ValueError as v_err: + self.assertEqual( + v_err.message, + "`type` to initilized an Operator can not be None.") + try: + block.append_op(type="no_such_op") + self.assertFail() + except AssertionError as a_err: + self.assertEqual(a_err.message, + "Operator \"no_such_op\" has not been registered.") + + def test_op_desc_creation(self): + block = g_program.current_block() + mul_x = block.create_var( + dtype="float32", shape=[5, 10], lod_level=0, name="mul.x") + mul_y = block.create_var( + dtype="float32", shape=[10, 8], lod_level=0, name="mul.y") + mul_out = block.create_var( + dtype="float32", shape=[5, 8], lod_level=0, name="mul.out") + mul_op = block.append_op( + type="mul", + inputs={"X": [mul_x], + "Y": mul_y}, + outputs={"Out": [mul_out]}, + attrs={"x_num_col_dims": 1}) + self.assertEqual(mul_op.type, "mul") + self.assertEqual(mul_op.input_names, ["X", "Y"]) + self.assertEqual(mul_op.input("X"), ["mul.x"]) + self.assertEqual(mul_op.input("Y"), ["mul.y"]) + self.assertEqual(mul_op.output_names, ["Out"]) + self.assertEqual(mul_op.output("Out"), ["mul.out"]) + self.assertEqual( + set(mul_op.attr_names), set(["x_num_col_dims", "y_num_col_dims"])) + self.assertEqual(mul_op.has_attr("x_num_col_dims"), True) + self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT) + self.assertEqual(mul_op.attr("x_num_col_dims"), 1) + self.assertEqual(mul_op.has_attr("y_num_col_dims"), True) + self.assertEqual(mul_op.attr_type("y_num_col_dims"), core.AttrType.INT) + self.assertEqual(mul_op.attr("y_num_col_dims"), 1) + self.assertEqual(mul_out.op, mul_op) + + def test_mult_input(self): + block = g_program.current_block() + sum_x1 = block.create_var( + dtype="int", shape=[3, 4], lod_level=0, name="sum.x1") + sum_x2 = block.create_var( + dtype="int", shape=[3, 4], lod_level=0, name="sum.x2") + sum_x3 = block.create_var( + dtype="int", shape=[3, 4], lod_level=0, name="sum.x3") + sum_out = block.create_var( + dtype="int", shape=[3, 4], lod_level=0, name="sum.out") + sum_op = block.append_op( + type="sum", + inputs={"X": [sum_x1, sum_x2, sum_x3]}, + outputs={"Out": sum_out}) + self.assertEqual(sum_op.type, "sum") + self.assertEqual(sum_op.input_names, ["X"]) + self.assertEqual(sum_op.input("X"), ["sum.x1", "sum.x2", "sum.x3"]) + self.assertEqual(sum_op.output_names, ["Out"]) + self.assertEqual(sum_op.output("Out"), ["sum.out"]) + self.assertEqual(sum_out.op, sum_op) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/framework/tests/test_protobuf_descs.py index 3db1e79ce43b7f559c7caab8397817b76d56161e..af5ed6801fa7b87e9193df78c7d28cf637eafa42 100644 --- a/python/paddle/v2/framework/tests/test_protobuf_descs.py +++ b/python/paddle/v2/framework/tests/test_protobuf_descs.py @@ -53,7 +53,7 @@ class TestOpDesc(unittest.TestCase): self.assertEqual(8, len(op.attr_names())) op.set_block_attr("block_attr", prog.block(0)) - self.assertEqual(0, op.get_block_attr("block_attr")) + self.assertEqual(0, op.block_attr("block_attr")) mul_op = block.append_op() mul_op.set_type("mul")