提交 a30239ce 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #4691 from Canpio/dev_opdesc_in_python

Complete `Operator` of Python API
......@@ -91,9 +91,5 @@ BlockDescBind *BlockDescBind::ParentBlock() const {
return prog_->Block(static_cast<size_t>(this->desc_->parent_idx()));
}
void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) {
BlockDesc *desc = block.RawPtr();
this->attrs_[name] = desc;
}
} // namespace framework
} // namespace paddle
......@@ -100,6 +100,12 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) {
need_update_ = true;
}
void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) {
BlockDesc *desc = block.RawPtr();
this->attrs_[name] = desc;
need_update_ = true;
}
void OpDescBind::SetAttrMap(
const std::unordered_map<std::string, Attribute> &attr_map) {
attrs_ = attr_map;
......
......@@ -34,6 +34,7 @@ inline std::vector<T> RepeatedToVector(
template <typename T, typename RepeatedField>
inline void VectorToRepeated(const std::vector<T> &vec,
RepeatedField *repeated_field) {
repeated_field->Clear();
repeated_field->Reserve(vec.size());
for (const auto &elem : vec) {
*repeated_field->Add() = elem;
......@@ -44,6 +45,7 @@ inline void VectorToRepeated(const std::vector<T> &vec,
template <typename RepeatedField>
inline void VectorToRepeated(const std::vector<bool> &vec,
RepeatedField *repeated_field) {
repeated_field->Clear();
repeated_field->Reserve(vec.size());
for (auto elem : vec) {
*repeated_field->Add() = elem;
......
......@@ -54,7 +54,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(
ctx->Attrs().Get<float>("min") < ctx->Attrs().Get<float>("max"),
"uniform_random's min must less then max");
auto dims = Attr<std::vector<int>>("dims");
auto& dims = ctx->Attrs().Get<std::vector<int>>("dims");
std::vector<int64_t> temp;
temp.reserve(dims.size());
for (auto dim : dims) {
......
......@@ -204,7 +204,7 @@ void BindOpDesc(py::module &m) {
.def("set_attr", &OpDescBind::SetAttr)
.def("attr", &OpDescBind::GetAttr)
.def("set_block_attr", &OpDescBind::SetBlockAttr)
.def("get_block_attr", &OpDescBind::GetBlockAttr)
.def("block_attr", &OpDescBind::GetBlockAttr)
.def("check_attrs", &OpDescBind::CheckAttrs)
.def("infer_shape", &OpDescBind::InferShape);
}
......
import paddle.v2.framework.core as core
import paddle.v2.framework.proto.framework_pb2 as framework_pb2
import collections
import numpy as np
import copy
......@@ -106,6 +107,40 @@ class Variable(object):
raise ValueError("Not supported numpy dtype " + str(dtype))
def get_all_op_protos():
"""
Get all registered op proto from PaddlePaddle C++ end.
:return: A list of registered OpProto.
"""
protostrs = core.get_all_op_protos()
ret_values = []
for pbstr in protostrs:
op_proto = framework_pb2.OpProto.FromString(str(pbstr))
ret_values.append(op_proto)
return ret_values
class OpProtoHolder(object):
@classmethod
def instance(cls):
if not hasattr(cls, '_instance'):
cls._instance = cls()
return cls._instance
def __init__(self):
assert not hasattr(
self.__class__,
'_instance'), 'Please use `instance()` to get OpProtoHolder opject!'
op_protos = get_all_op_protos()
self.op_proto_map = {}
for proto in op_protos:
self.op_proto_map[proto.type] = proto
def get_op_proto(self, type):
assert type in self.op_proto_map, "Operator \"%s\" has not been registered." % type
return self.op_proto_map[type]
class Operator(object):
def __init__(self,
block,
......@@ -116,20 +151,89 @@ class Operator(object):
attrs=None):
self.block = block
self.desc = desc
if type is not None:
# TODO.
pass
if len(self.desc.type()) != 0:
return
if type is None:
raise ValueError(
"`type` to initilized an Operator can not be None.")
self.desc.set_type(type)
proto = OpProtoHolder.instance().get_op_proto(type)
if inputs is not None:
# TODO
pass
for in_proto in proto.inputs:
in_argus = inputs[in_proto.name]
if not isinstance(in_argus, list):
in_argus = [in_argus]
if not in_proto.duplicable and len(in_argus) > 1:
raise ValueError(
"Input %s expects only one input, but %d are given." %
(in_proto.name, len(in_argus)))
in_argu_names = []
for argu in in_argus:
in_argu_names.append(argu.name)
self.desc.set_input(in_proto.name, in_argu_names)
if outputs is not None:
# TODO
pass
for out_proto in proto.outputs:
out_argus = outputs[out_proto.name]
if not isinstance(out_argus, list):
out_argus = [out_argus]
if not out_proto.duplicable and len(out_argus) > 1:
raise ValueError(
"Output %s expects only one output, but %d are given." %
(out_proto.name, len(out_argus)))
out_argu_names = []
for argu in out_argus:
out_argu_names.append(argu.name)
argu.op = self
self.desc.set_output(out_proto.name, out_argu_names)
if attrs is not None:
# TODO
pass
for attr in proto.attrs:
attr_name = attr.name
if not attr_name in attrs:
continue
if not isinstance(attrs[attr_name], Block):
self.desc.set_attr(attr_name, attrs[attr_name])
else:
self.desc.set_block_attr(attr_name, attrs[attr_name].desc)
self.desc.check_attrs()
self.desc.infer_shape(self.block.desc)
@property
def type(self):
return self.desc.type()
def input(self, name):
return self.desc.input(name)
@property
def input_names(self):
return self.desc.input_names()
def output(self, name):
return self.desc.output(name)
@property
def output_names(self):
return self.desc.output_names()
def has_attr(self, name):
return self.desc.has_attr(name)
def attr_type(self, name):
return self.desc.attr_type(name)
@property
def attr_names(self):
return self.desc.attr_names()
def attr(self, name):
return self.desc.attr(name)
# TODO: Getters
def block_attr(self, name):
return self.desc.block_attr(name)
class Block(object):
......
import unittest
from paddle.v2.framework.graph import Variable, g_program
import paddle.v2.framework.core as core
class TestOperator(unittest.TestCase):
def test_error_type(self):
block = g_program.create_block()
try:
block.append_op()
self.assertFail()
except ValueError as v_err:
self.assertEqual(
v_err.message,
"`type` to initilized an Operator can not be None.")
try:
block.append_op(type="no_such_op")
self.assertFail()
except AssertionError as a_err:
self.assertEqual(a_err.message,
"Operator \"no_such_op\" has not been registered.")
def test_op_desc_creation(self):
block = g_program.current_block()
mul_x = block.create_var(
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
mul_op = block.append_op(
type="mul",
inputs={"X": [mul_x],
"Y": mul_y},
outputs={"Out": [mul_out]},
attrs={"x_num_col_dims": 1})
self.assertEqual(mul_op.type, "mul")
self.assertEqual(mul_op.input_names, ["X", "Y"])
self.assertEqual(mul_op.input("X"), ["mul.x"])
self.assertEqual(mul_op.input("Y"), ["mul.y"])
self.assertEqual(mul_op.output_names, ["Out"])
self.assertEqual(mul_op.output("Out"), ["mul.out"])
self.assertEqual(
set(mul_op.attr_names), set(["x_num_col_dims", "y_num_col_dims"]))
self.assertEqual(mul_op.has_attr("x_num_col_dims"), True)
self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT)
self.assertEqual(mul_op.attr("x_num_col_dims"), 1)
self.assertEqual(mul_op.has_attr("y_num_col_dims"), True)
self.assertEqual(mul_op.attr_type("y_num_col_dims"), core.AttrType.INT)
self.assertEqual(mul_op.attr("y_num_col_dims"), 1)
self.assertEqual(mul_out.op, mul_op)
def test_mult_input(self):
block = g_program.current_block()
sum_x1 = block.create_var(
dtype="int", shape=[3, 4], lod_level=0, name="sum.x1")
sum_x2 = block.create_var(
dtype="int", shape=[3, 4], lod_level=0, name="sum.x2")
sum_x3 = block.create_var(
dtype="int", shape=[3, 4], lod_level=0, name="sum.x3")
sum_out = block.create_var(
dtype="int", shape=[3, 4], lod_level=0, name="sum.out")
sum_op = block.append_op(
type="sum",
inputs={"X": [sum_x1, sum_x2, sum_x3]},
outputs={"Out": sum_out})
self.assertEqual(sum_op.type, "sum")
self.assertEqual(sum_op.input_names, ["X"])
self.assertEqual(sum_op.input("X"), ["sum.x1", "sum.x2", "sum.x3"])
self.assertEqual(sum_op.output_names, ["Out"])
self.assertEqual(sum_op.output("Out"), ["sum.out"])
self.assertEqual(sum_out.op, sum_op)
if __name__ == '__main__':
unittest.main()
......@@ -53,7 +53,7 @@ class TestOpDesc(unittest.TestCase):
self.assertEqual(8, len(op.attr_names()))
op.set_block_attr("block_attr", prog.block(0))
self.assertEqual(0, op.get_block_attr("block_attr"))
self.assertEqual(0, op.block_attr("block_attr"))
mul_op = block.append_op()
mul_op.set_type("mul")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册