diff --git a/doc/design/python_api.md b/doc/design/python_api.md
index 6213da65c8c5931bc16e42574b8628b676424873..c4665e44fca6e75878d76ba0a686f87f10222988 100644
--- a/doc/design/python_api.md
+++ b/doc/design/python_api.md
@@ -22,7 +22,7 @@ Whenever we create a block, we need to set its parent block to the current block
 ```python
 class Program(objects):
     def __init__(self):
-        self.proto = core.NewProgram() # a C++ ProgramDesc pointer.
+        self.desc = core.NewProgram() # a C++ ProgramDesc pointer.
         self.blocks = vector()
         self.blocks.append(Block(self, -1)) # the global block
         self.current_block = 0 # initialized to the global block
@@ -57,7 +57,7 @@ A [Block](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.m
 ```python
 class Block(objects):
     def __init__(self, program, parent_idx):
-        self.proto = core.NewBlock(program.proto)
+        self.desc = core.NewBlock(program.desc)
         self.program = program
         self.vars = map()
         self.ops = vector()
@@ -98,11 +98,11 @@ class Operator(object):
                  outputs,# dict
                  attrs   # dict
                  ):
-        self.proto = core.NewOpDesc(block.proto, type, inputs, outputs, attrs)
-        core.infer_shape(self.proto, inputs, outputs)
+        self.desc = core.NewOpDesc(block.desc, type, inputs, outputs, attrs)
+        core.infer_shape(self.desc, inputs, outputs)
 
     def type(self):
-        return self.proto.type()
+        return self.desc.type()
 ```
 
 `Operator` creates the `OpDesc` message in C++ space, so that it can call the `InferShape` function, which is in C++.
@@ -124,7 +124,7 @@ class Variable(object):
             name = unique_name_generator()
         self.name = name
         self.block = block
-        self.proto = core.NewVarDesc(block.proto, name, shape, lod_level)
+        self.desc = core.NewVarDesc(block.desc, name, shape, lod_level)
         self.writer = None
 ```
 
diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index 3e0e0f59038daa33cae1952ffbe5fc0bb1870485..1bf80b3e58df591376b79253c3eaf69355b3397f 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -19,7 +19,7 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope)
 proto_library(framework_proto SRCS framework.proto)
 
 cc_library(attribute SRCS attribute.cc DEPS framework_proto)
-cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute)
+cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim)
 cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
 cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
 cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc)
diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc
index c2e796b7c1b6e359765bafd6cd66fa16d69897a1..e7538b4af3429e566a439d5a0db8496efcd94969 100644
--- a/paddle/framework/op_desc.cc
+++ b/paddle/framework/op_desc.cc
@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/framework/op_desc.h"
+#include <functional>
+#include <unordered_map>
 #include "paddle/framework/block_desc.h"
+#include "paddle/framework/operator.h"
 
 namespace paddle {
 namespace framework {
@@ -185,5 +188,38 @@ void OpDescBind::Sync() {
     need_update_ = false;
   }
 }
+
+using InferShapeFuncMap =
+    std::unordered_map<std::string,
+                       std::function<void(InferShapeContext *)>>;
+
+static InferShapeFuncMap &InferShapeFuncs() {
+  static InferShapeFuncMap *g_map = nullptr;
+  if (g_map == nullptr) {
+    g_map = new InferShapeFuncMap();
+    auto &info_map = OpInfoMap::Instance();
+    // all registered kernels
+    for (auto &pair : OperatorWithKernel::AllOpKernels()) {
+      auto &info = info_map.Get(pair.first);
+      // use empty type here to avoid runtime checks.
+      auto op =
+          static_cast<OperatorWithKernel *>(info.Creator()("", {}, {}, {}));
+      g_map->insert(
+          {pair.first, [op](InferShapeContext *ctx) { op->InferShape(ctx); }});
+    }
+  }
+  return *g_map;
+}
+
+void OpDescBind::InferShape(const BlockDescBind &block) const {
+  auto &funcs = InferShapeFuncs();
+  auto it = funcs.find(this->Type());
+  if (it == funcs.end()) {
+    PADDLE_THROW("Operator %s has not been registered", this->Type());
+  }
+  CompileTimeInferShapeContext ctx(*this, block);
+  it->second(&ctx);
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h
index d0c314771c04d2a293f2d9ae0b7fc2be0ccb3add..81c4225041157ac600d1db73ef2363ebcd4abfc0 100644
--- a/paddle/framework/op_desc.h
+++ b/paddle/framework/op_desc.h
@@ -100,6 +100,8 @@ class OpDescBind {
     return &this->attrs_;
   }
 
+  void InferShape(const BlockDescBind &block) const;
+
  private:
   template <typename MapType>
   static std::vector<std::string> MapKeys(const MapType &map) {
diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc
index 3219f0a18a6e643eb13c76dfeefdbe4026962676..116c99bd2c1ca59b093392f9e6cc481c089309bc 100644
--- a/paddle/pybind/protobuf.cc
+++ b/paddle/pybind/protobuf.cc
@@ -198,7 +198,8 @@ void BindOpDesc(py::module &m) {
       .def("set_attr", &OpDescBind::SetAttr)
       .def("attr", &OpDescBind::GetAttr)
       .def("set_block_attr", &OpDescBind::SetBlockAttr)
-      .def("get_block_attr", &OpDescBind::GetBlockAttr);
+      .def("get_block_attr", &OpDescBind::GetBlockAttr)
+      .def("infer_shape", &OpDescBind::InferShape);
 }
 
 }  // namespace pybind
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 356c4986e2e182e904215f7ebb8cac5146364f8b..0f6e3101e26c5ac249664ce8badc10adc939305f 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -231,21 +231,6 @@ All parameter, weight, gradient are variables in Paddle.
                                    desc.InitializationErrorString());
                     return OpRegistry::CreateOp(desc);
                   })
-      .def_static("infer_shape",
-                  [](OpDescBind &op_desc, BlockDescBind &block) {
-                    auto op = OpRegistry::CreateOp(*op_desc.Proto());
-                    auto *op_with_kernel =
-                        dynamic_cast<OperatorWithKernel *>(op.get());
-                    if (op_with_kernel != nullptr) {
-                      auto ctx = CompileTimeInferShapeContext(op_desc, block);
-                      op_with_kernel->InferShape(&ctx);
-                    } else {
-                      PADDLE_THROW(
-                          "OP(%s) is not type of OperatorWithKernel, "
-                          "should not call this function",
-                          op_desc.Type());
-                    }
-                  })
       .def("backward",
            [](const OperatorBase &forwardOp,
               const std::unordered_set<std::string> &no_grad_vars) {
diff --git a/python/paddle/v2/framework/graph.py b/python/paddle/v2/framework/graph.py
index a66e7a9d7319ffd000d12589846dd69dab1be025..ba1488546274546c85facd4def0593cec5ca177a 100644
--- a/python/paddle/v2/framework/graph.py
+++ b/python/paddle/v2/framework/graph.py
@@ -13,15 +13,15 @@ class Variable(object):
         if name is None:
             name = Variable._unique_var_name_()
         try:
-            self.proto = self.block.proto.var(name)
+            self.desc = self.block.desc.var(name)
             is_new_var = False
         except core.EnforceNotMet:
-            self.proto = self.block.proto.new_var(name)
+            self.desc = self.block.desc.new_var(name)
             is_new_var = True
 
         if shape is not None:
             if is_new_var:
-                self.proto.set_shape(shape)
+                self.desc.set_shape(shape)
             else:
                 old_shape = self.shape
                 shape = tuple(shape)
@@ -34,7 +34,7 @@ class Variable(object):
             if not isinstance(dtype, core.DataType):
                 dtype = Variable._convert_np_dtype_to_dtype_(dtype)
             if is_new_var:
-                self.proto.set_data_type(dtype)
+                self.desc.set_data_type(dtype)
             else:
                 old_dtype = self.data_type()
                 if dtype != old_shape:
@@ -46,7 +46,7 @@
 
         if lod_level is not None:
             if is_new_var:
-                self.proto.set_lod_level(lod_level)
+                self.desc.set_lod_level(lod_level)
             else:
                 if lod_level != self.lod_level:
                     raise ValueError("Variable {0} has been created before. "
@@ -54,26 +54,25 @@
                                      "lod_level is {2}. They are not "
                                      "matched".format(self.name, self.lod_level,
                                                       lod_level))
 
-        self.block.vars[name] = self
         self.op = None
 
     @property
     def name(self):
-        return self.proto.name()
+        return self.desc.name()
 
     @property
     def shape(self):
         # convert to tuple, make it as same as numpy API.
-        return tuple(self.proto.shape())
+        return tuple(self.desc.shape())
 
     @property
     def data_type(self):
-        return self.proto.data_type()
+        return self.desc.data_type()
 
     @property
     def lod_level(self):
-        return self.proto.lod_level()
+        return self.desc.lod_level()
 
     @staticmethod
     def _unique_var_name_():
@@ -104,13 +103,13 @@ class Variable(object):
 class Operator(object):
     def __init__(self,
                  block,
-                 proto,
+                 desc,
                  type=None,
                  inputs=None,
                  outputs=None,
                  attrs=None):
         self.block = block
-        self.proto = proto
+        self.desc = desc
         if type is not None:
             # TODO.
             pass
@@ -129,31 +128,31 @@ class Operator(object):
 
 class Block(object):
     def __init__(self, program, idx):
-        self.proto = program.proto.block(idx)
+        self.desc = program.desc.block(idx)
         self.vars = dict()  # var_name --> var
         self.ops = collections.deque()  # operator list
         self.program = program
 
     @property
     def parent_idx(self):
-        return self.proto.parent
+        return self.desc.parent
 
     @property
     def idx(self):
-        return self.proto.id
+        return self.desc.id
 
     def create_var(self, *args, **kwargs):
         return Variable(self, *args, **kwargs)
 
     def append_op(self, *args, **kwargs):
-        op_proto = self.proto.append_op()
-        op = Operator(self, op_proto, *args, **kwargs)
+        op_desc = self.desc.append_op()
+        op = Operator(self, op_desc, *args, **kwargs)
         self.ops.append(op)
         return op
 
     def prepend_op(self, *args, **kwargs):
-        op_proto = self.proto.prepend_op()
-        op = Operator(self, op_proto, *args, **kwargs)
+        op_desc = self.desc.prepend_op()
+        op = Operator(self, op_desc, *args, **kwargs)
         self.ops.appendleft(op)
         return op
 
@@ -170,7 +169,7 @@ class Program(object):
     def __init__(self):
         assert not hasattr(self.__class__,
                            '_instance'), 'Do not call constructor directly!'
-        self.proto = core.ProgramDesc.instance()
+        self.desc = core.ProgramDesc.instance()
         self.blocks = [Block(self, 0)]
         self.current_block_idx = 0
 
@@ -182,7 +181,7 @@
     def create_block(self):
         new_block_idx = len(self.blocks)
-        self.proto.append_block(self.current_block().proto)
+        self.desc.append_block(self.current_block().desc)
         self.current_block_idx = new_block_idx
         self.blocks.append(Block(self, self.current_block_idx))
         return self.current_block()
 
diff --git a/python/paddle/v2/framework/tests/test_infer_shape.py b/python/paddle/v2/framework/tests/test_infer_shape.py
index b38ec9c03740a2e69f1247c094ce56ab43fa8e32..99562890fdd4d8b10f420869f1ba9f694db5969a 100644
--- a/python/paddle/v2/framework/tests/test_infer_shape.py
+++ b/python/paddle/v2/framework/tests/test_infer_shape.py
@@ -1,6 +1,6 @@
 import unittest
+
 import paddle.v2.framework.core as core
-from paddle.v2.framework.op import Operator
 
 
 class TestInferShape(unittest.TestCase):
@@ -26,7 +26,7 @@ class TestInferShape(unittest.TestCase):
         sum_op_desc.set_input("X", ["x1", "x2"])
         sum_op_desc.set_output("Out", ["out"])
 
-        core.Operator.infer_shape(sum_op_desc, block)
+        sum_op_desc.infer_shape(block)
         self.assertEqual(out.shape(), shape)
 
     def test_mul_op(self):
@@ -55,7 +55,7 @@ class TestInferShape(unittest.TestCase):
         mul_op_desc.set_attr("x_num_col_dims", 1)
         mul_op_desc.set_attr("y_num_col_dims", 1)
 
-        core.Operator.infer_shape(mul_op_desc, block)
+        mul_op_desc.infer_shape(block)
         self.assertEqual(out.shape(), [x_shape[0], y_shape[1]])
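
A minimal sketch of the Python-side flow after this change, stitched together from the bindings touched in this diff (`core.ProgramDesc.instance()` and `block(idx)` from `graph.py`; `new_var`, `set_shape`, `append_op`, `set_input`, `set_output`, `infer_shape`, and `shape()` from `test_infer_shape.py`). The `set_type` call is assumed and does not appear in the hunks above, so treat this as illustrative rather than a verbatim excerpt of the tests:

```python
import paddle.v2.framework.core as core

# Build op descriptions at compile time; no tensors are allocated.
prog = core.ProgramDesc.instance()  # as used by Program.__init__ in graph.py
block = prog.block(0)               # the global block

x1 = block.new_var("x1")
x1.set_shape([10, 20])
x2 = block.new_var("x2")
x2.set_shape([10, 20])
out = block.new_var("out")          # shape left unset; InferShape fills it in

sum_op_desc = block.append_op()
sum_op_desc.set_type("sum")         # assumed binding, not shown in this diff
sum_op_desc.set_input("X", ["x1", "x2"])
sum_op_desc.set_output("Out", ["out"])

# Previously: core.Operator.infer_shape(sum_op_desc, block)
# Now InferShape lives on OpDescBind, so it is called on the op desc itself;
# it dispatches to the registered kernel's InferShape (see op_desc.cc above).
sum_op_desc.infer_shape(block)
assert out.shape() == [10, 20]
```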