提交 0c37a061 编写于 作者: Y Yu Yang

Merge branch 'feature/change_proto_to_desc' into feature/complete_variable_bind

...@@ -22,7 +22,7 @@ Whenever we create a block, we need to set its parent block to the current block ...@@ -22,7 +22,7 @@ Whenever we create a block, we need to set its parent block to the current block
```python ```python
class Program(objects): class Program(objects):
def __init__(self): def __init__(self):
self.proto = core.NewProgram() # a C++ ProgramDesc pointer. self.desc = core.NewProgram() # a C++ ProgramDesc pointer.
self.blocks = vector<Block>() self.blocks = vector<Block>()
self.blocks.append(Block(self, -1)) # the global block self.blocks.append(Block(self, -1)) # the global block
self.current_block = 0 # initialized to the global block self.current_block = 0 # initialized to the global block
...@@ -57,7 +57,7 @@ A [Block](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.m ...@@ -57,7 +57,7 @@ A [Block](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.m
```python ```python
class Block(objects): class Block(objects):
def __init__(self, program, parent_idx): def __init__(self, program, parent_idx):
self.proto = core.NewBlock(program.proto) self.desc = core.NewBlock(program.desc)
self.program = program self.program = program
self.vars = map<string, Variable>() self.vars = map<string, Variable>()
self.ops = vector<Operator>() self.ops = vector<Operator>()
...@@ -98,11 +98,11 @@ class Operator(object): ...@@ -98,11 +98,11 @@ class Operator(object):
outputs,# dict<string, Variable> outputs,# dict<string, Variable>
attrs # dict<string, Any> attrs # dict<string, Any>
): ):
self.proto = core.NewOpDesc(block.proto, type, inputs, outputs, attrs) self.desc = core.NewOpDesc(block.desc, type, inputs, outputs, attrs)
core.infer_shape(self.proto, inputs, outputs) core.infer_shape(self.desc, inputs, outputs)
def type(self): def type(self):
return self.proto.type() return self.desc.type()
``` ```
`Operator` creates the `OpDesc` message in C++ space, so that it can call the `InferShape` function, which is in C++. `Operator` creates the `OpDesc` message in C++ space, so that it can call the `InferShape` function, which is in C++.
...@@ -124,7 +124,7 @@ class Variable(object): ...@@ -124,7 +124,7 @@ class Variable(object):
name = unique_name_generator() name = unique_name_generator()
self.name = name self.name = name
self.block = block self.block = block
self.proto = core.NewVarDesc(block.proto, name, shape, lod_level) self.desc = core.NewVarDesc(block.desc, name, shape, lod_level)
self.writer = None self.writer = None
``` ```
......
...@@ -19,7 +19,7 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope) ...@@ -19,7 +19,7 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope)
proto_library(framework_proto SRCS framework.proto) proto_library(framework_proto SRCS framework.proto)
cc_library(attribute SRCS attribute.cc DEPS framework_proto) cc_library(attribute SRCS attribute.cc DEPS framework_proto)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc)
......
...@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/framework/op_desc.h" #include "paddle/framework/op_desc.h"
#include <functional>
#include <unordered_map>
#include "paddle/framework/block_desc.h" #include "paddle/framework/block_desc.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -185,5 +188,38 @@ void OpDescBind::Sync() { ...@@ -185,5 +188,38 @@ void OpDescBind::Sync() {
need_update_ = false; need_update_ = false;
} }
} }
// Maps an operator type name to a callback that runs that operator's
// InferShape on a caller-supplied InferShapeContext.
using InferShapeFuncMap =
    std::unordered_map<std::string /*op_type*/,
                       std::function<void(InferShapeContext *)>>;

// Returns the process-wide InferShape lookup table, building it on first use.
//
// The table gets one entry for every op type reported by
// OperatorWithKernel::AllOpKernels(). For each of them a prototype operator
// is created through the registered creator, and the stored callback forwards
// to that prototype's InferShape.
//
// The table is built inside the initializer of a function-local static:
//   * C++11 guarantees the initializer runs exactly once, even when several
//     threads make the first call concurrently;
//   * if building throws, the static stays uninitialized and the next call
//     retries, so a half-filled table is never cached.
// The table is assembled in a local map and only published once complete.
//
// Neither the table nor the prototype operators are ever deleted; they live
// until the process exits.
static InferShapeFuncMap &InferShapeFuncs() {
  static InferShapeFuncMap *g_map = [] {
    InferShapeFuncMap funcs;
    auto &info_map = OpInfoMap::Instance();
    // all registered kernels
    for (auto &pair : OperatorWithKernel::AllOpKernels()) {
      auto &info = info_map.Get(pair.first);
      // use empty type here to avoid runtime checks.
      // NOTE(review): the unchecked static_cast assumes every op type listed
      // in AllOpKernels() is created as an OperatorWithKernel -- confirm.
      auto op =
          static_cast<OperatorWithKernel *>(info.Creator()("", {}, {}, {}));
      funcs.insert(
          {pair.first, [op](InferShapeContext *ctx) { op->InferShape(ctx); }});
    }
    // Hand the finished table over to a heap object that is never freed.
    // Member swap() is used so no additional header is required.
    auto *published = new InferShapeFuncMap();
    published->swap(funcs);
    return published;
  }();
  return *g_map;
}
void OpDescBind::InferShape(const BlockDescBind &block) const {
auto &funcs = InferShapeFuncs();
auto it = funcs.find(this->Type());
if (it == funcs.end()) {
PADDLE_THROW("Operator %s has not been registered", this->Type());
}
CompileTimeInferShapeContext ctx(*this, block);
it->second(&ctx);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -100,6 +100,8 @@ class OpDescBind { ...@@ -100,6 +100,8 @@ class OpDescBind {
return &this->attrs_; return &this->attrs_;
} }
void InferShape(const BlockDescBind &block) const;
private: private:
template <typename MapType> template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) { static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
......
...@@ -198,7 +198,8 @@ void BindOpDesc(py::module &m) { ...@@ -198,7 +198,8 @@ void BindOpDesc(py::module &m) {
.def("set_attr", &OpDescBind::SetAttr) .def("set_attr", &OpDescBind::SetAttr)
.def("attr", &OpDescBind::GetAttr) .def("attr", &OpDescBind::GetAttr)
.def("set_block_attr", &OpDescBind::SetBlockAttr) .def("set_block_attr", &OpDescBind::SetBlockAttr)
.def("get_block_attr", &OpDescBind::GetBlockAttr); .def("get_block_attr", &OpDescBind::GetBlockAttr)
.def("infer_shape", &OpDescBind::InferShape);
} }
} // namespace pybind } // namespace pybind
......
...@@ -231,21 +231,6 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -231,21 +231,6 @@ All parameter, weight, gradient are variables in Paddle.
desc.InitializationErrorString()); desc.InitializationErrorString());
return OpRegistry::CreateOp(desc); return OpRegistry::CreateOp(desc);
}) })
.def_static("infer_shape",
[](OpDescBind &op_desc, BlockDescBind &block) {
auto op = OpRegistry::CreateOp(*op_desc.Proto());
auto *op_with_kernel =
dynamic_cast<OperatorWithKernel *>(op.get());
if (op_with_kernel != nullptr) {
auto ctx = CompileTimeInferShapeContext(op_desc, block);
op_with_kernel->InferShape(&ctx);
} else {
PADDLE_THROW(
"OP(%s) is not type of OperatorWithKernel, "
"should not call this function",
op_desc.Type());
}
})
.def("backward", .def("backward",
[](const OperatorBase &forwardOp, [](const OperatorBase &forwardOp,
const std::unordered_set<std::string> &no_grad_vars) { const std::unordered_set<std::string> &no_grad_vars) {
......
...@@ -13,15 +13,15 @@ class Variable(object): ...@@ -13,15 +13,15 @@ class Variable(object):
if name is None: if name is None:
name = Variable._unique_var_name_() name = Variable._unique_var_name_()
try: try:
self.proto = self.block.proto.var(name) self.desc = self.block.desc.var(name)
is_new_var = False is_new_var = False
except core.EnforceNotMet: except core.EnforceNotMet:
self.proto = self.block.proto.new_var(name) self.desc = self.block.desc.new_var(name)
is_new_var = True is_new_var = True
if shape is not None: if shape is not None:
if is_new_var: if is_new_var:
self.proto.set_shape(shape) self.desc.set_shape(shape)
else: else:
old_shape = self.shape old_shape = self.shape
shape = tuple(shape) shape = tuple(shape)
...@@ -34,7 +34,7 @@ class Variable(object): ...@@ -34,7 +34,7 @@ class Variable(object):
if not isinstance(dtype, core.DataType): if not isinstance(dtype, core.DataType):
dtype = Variable._convert_np_dtype_to_dtype_(dtype) dtype = Variable._convert_np_dtype_to_dtype_(dtype)
if is_new_var: if is_new_var:
self.proto.set_data_type(dtype) self.desc.set_data_type(dtype)
else: else:
old_dtype = self.data_type() old_dtype = self.data_type()
if dtype != old_shape: if dtype != old_shape:
...@@ -46,7 +46,7 @@ class Variable(object): ...@@ -46,7 +46,7 @@ class Variable(object):
if lod_level is not None: if lod_level is not None:
if is_new_var: if is_new_var:
self.proto.set_lod_level(lod_level) self.desc.set_lod_level(lod_level)
else: else:
if lod_level != self.lod_level: if lod_level != self.lod_level:
raise ValueError("Variable {0} has been created before. " raise ValueError("Variable {0} has been created before. "
...@@ -54,26 +54,25 @@ class Variable(object): ...@@ -54,26 +54,25 @@ class Variable(object):
"lod_level is {2}. They are not " "lod_level is {2}. They are not "
"matched".format(self.name, self.lod_level, "matched".format(self.name, self.lod_level,
lod_level)) lod_level))
self.block.vars[name] = self self.block.vars[name] = self
self.op = None self.op = None
@property @property
def name(self): def name(self):
return self.proto.name() return self.desc.name()
@property @property
def shape(self): def shape(self):
# convert to tuple, make it as same as numpy API. # convert to tuple, make it as same as numpy API.
return tuple(self.proto.shape()) return tuple(self.desc.shape())
@property @property
def data_type(self): def data_type(self):
return self.proto.data_type() return self.desc.data_type()
@property @property
def lod_level(self): def lod_level(self):
return self.proto.lod_level() return self.desc.lod_level()
@staticmethod @staticmethod
def _unique_var_name_(): def _unique_var_name_():
...@@ -104,13 +103,13 @@ class Variable(object): ...@@ -104,13 +103,13 @@ class Variable(object):
class Operator(object): class Operator(object):
def __init__(self, def __init__(self,
block, block,
proto, desc,
type=None, type=None,
inputs=None, inputs=None,
outputs=None, outputs=None,
attrs=None): attrs=None):
self.block = block self.block = block
self.proto = proto self.desc = desc
if type is not None: if type is not None:
# TODO. # TODO.
pass pass
...@@ -129,31 +128,31 @@ class Operator(object): ...@@ -129,31 +128,31 @@ class Operator(object):
class Block(object): class Block(object):
def __init__(self, program, idx): def __init__(self, program, idx):
self.proto = program.proto.block(idx) self.desc = program.desc.block(idx)
self.vars = dict() # var_name --> var self.vars = dict() # var_name --> var
self.ops = collections.deque() # operator list self.ops = collections.deque() # operator list
self.program = program self.program = program
@property @property
def parent_idx(self): def parent_idx(self):
return self.proto.parent return self.desc.parent
@property @property
def idx(self): def idx(self):
return self.proto.id return self.desc.id
def create_var(self, *args, **kwargs): def create_var(self, *args, **kwargs):
return Variable(self, *args, **kwargs) return Variable(self, *args, **kwargs)
def append_op(self, *args, **kwargs): def append_op(self, *args, **kwargs):
op_proto = self.proto.append_op() op_desc = self.desc.append_op()
op = Operator(self, op_proto, *args, **kwargs) op = Operator(self, op_desc, *args, **kwargs)
self.ops.append(op) self.ops.append(op)
return op return op
def prepend_op(self, *args, **kwargs): def prepend_op(self, *args, **kwargs):
op_proto = self.proto.prepend_op() op_desc = self.desc.prepend_op()
op = Operator(self, op_proto, *args, **kwargs) op = Operator(self, op_desc, *args, **kwargs)
self.ops.appendleft(op) self.ops.appendleft(op)
return op return op
...@@ -170,7 +169,7 @@ class Program(object): ...@@ -170,7 +169,7 @@ class Program(object):
def __init__(self): def __init__(self):
assert not hasattr(self.__class__, assert not hasattr(self.__class__,
'_instance'), 'Do not call constructor directly!' '_instance'), 'Do not call constructor directly!'
self.proto = core.ProgramDesc.instance() self.desc = core.ProgramDesc.instance()
self.blocks = [Block(self, 0)] self.blocks = [Block(self, 0)]
self.current_block_idx = 0 self.current_block_idx = 0
...@@ -182,7 +181,7 @@ class Program(object): ...@@ -182,7 +181,7 @@ class Program(object):
def create_block(self): def create_block(self):
new_block_idx = len(self.blocks) new_block_idx = len(self.blocks)
self.proto.append_block(self.current_block().proto) self.desc.append_block(self.current_block().desc)
self.current_block_idx = new_block_idx self.current_block_idx = new_block_idx
self.blocks.append(Block(self, self.current_block_idx)) self.blocks.append(Block(self, self.current_block_idx))
return self.current_block() return self.current_block()
......
import unittest import unittest
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
class TestInferShape(unittest.TestCase): class TestInferShape(unittest.TestCase):
...@@ -26,7 +26,7 @@ class TestInferShape(unittest.TestCase): ...@@ -26,7 +26,7 @@ class TestInferShape(unittest.TestCase):
sum_op_desc.set_input("X", ["x1", "x2"]) sum_op_desc.set_input("X", ["x1", "x2"])
sum_op_desc.set_output("Out", ["out"]) sum_op_desc.set_output("Out", ["out"])
core.Operator.infer_shape(sum_op_desc, block) sum_op_desc.infer_shape(block)
self.assertEqual(out.shape(), shape) self.assertEqual(out.shape(), shape)
def test_mul_op(self): def test_mul_op(self):
...@@ -55,7 +55,7 @@ class TestInferShape(unittest.TestCase): ...@@ -55,7 +55,7 @@ class TestInferShape(unittest.TestCase):
mul_op_desc.set_attr("x_num_col_dims", 1) mul_op_desc.set_attr("x_num_col_dims", 1)
mul_op_desc.set_attr("y_num_col_dims", 1) mul_op_desc.set_attr("y_num_col_dims", 1)
core.Operator.infer_shape(mul_op_desc, block) mul_op_desc.infer_shape(block)
self.assertEqual(out.shape(), [x_shape[0], y_shape[1]]) self.assertEqual(out.shape(), [x_shape[0], y_shape[1]])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册