提交 0c37a061 编写于 作者: Y Yu Yang

Merge branch 'feature/change_proto_to_desc' into feature/complete_variable_bind

......@@ -22,7 +22,7 @@ Whenever we create a block, we need to set its parent block to the current block
```python
class Program(objects):
def __init__(self):
self.proto = core.NewProgram() # a C++ ProgramDesc pointer.
self.desc = core.NewProgram() # a C++ ProgramDesc pointer.
self.blocks = vector<Block>()
self.blocks.append(Block(self, -1)) # the global block
self.current_block = 0 # initialized to the global block
......@@ -57,7 +57,7 @@ A [Block](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.m
```python
class Block(objects):
def __init__(self, program, parent_idx):
self.proto = core.NewBlock(program.proto)
self.desc = core.NewBlock(program.desc)
self.program = program
self.vars = map<string, Variable>()
self.ops = vector<Operator>()
......@@ -98,11 +98,11 @@ class Operator(object):
outputs,# dict<stirng, Variable>
attrs # dict<string, Any>
):
self.proto = core.NewOpDesc(block.proto, type, inputs, outputs, attrs)
core.infer_shape(self.proto, inputs, outputs)
self.desc = core.NewOpDesc(block.desc, type, inputs, outputs, attrs)
core.infer_shape(self.desc, inputs, outputs)
def type(self):
return self.proto.type()
return self.desc.type()
```
`Operator` creates the `OpDesc` message in C++ space, so that it can call the `InferShape` function, which is in C++.
......@@ -124,7 +124,7 @@ class Variable(object):
name = unique_name_generator()
self.name = name
self.block = block
self.proto = core.NewVarDesc(block.proto, name, shape, lod_level)
self.desc = core.NewVarDesc(block.desc, name, shape, lod_level)
self.writer = None
```
......
......@@ -19,7 +19,7 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope)
proto_library(framework_proto SRCS framework.proto)
cc_library(attribute SRCS attribute.cc DEPS framework_proto)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc)
......
......@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_desc.h"
#include <functional>
#include <unordered_map>
#include "paddle/framework/block_desc.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace framework {
......@@ -185,5 +188,38 @@ void OpDescBind::Sync() {
need_update_ = false;
}
}
using InferShapeFuncMap =
std::unordered_map<std::string /*op_type*/,
std::function<void(InferShapeContext *)>>;
static InferShapeFuncMap &InferShapeFuncs() {
static InferShapeFuncMap *g_map = nullptr;
if (g_map == nullptr) {
g_map = new InferShapeFuncMap();
auto &info_map = OpInfoMap::Instance();
// all registered kernels
for (auto &pair : OperatorWithKernel::AllOpKernels()) {
auto &info = info_map.Get(pair.first);
// use empty type here to avoid runtime checks.
auto op =
static_cast<OperatorWithKernel *>(info.Creator()("", {}, {}, {}));
g_map->insert(
{pair.first, [op](InferShapeContext *ctx) { op->InferShape(ctx); }});
}
}
return *g_map;
}
void OpDescBind::InferShape(const BlockDescBind &block) const {
auto &funcs = InferShapeFuncs();
auto it = funcs.find(this->Type());
if (it == funcs.end()) {
PADDLE_THROW("Operator %s has not been registered", this->Type());
}
CompileTimeInferShapeContext ctx(*this, block);
it->second(&ctx);
}
} // namespace framework
} // namespace paddle
......@@ -100,6 +100,8 @@ class OpDescBind {
return &this->attrs_;
}
void InferShape(const BlockDescBind &block) const;
private:
template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
......
......@@ -198,7 +198,8 @@ void BindOpDesc(py::module &m) {
.def("set_attr", &OpDescBind::SetAttr)
.def("attr", &OpDescBind::GetAttr)
.def("set_block_attr", &OpDescBind::SetBlockAttr)
.def("get_block_attr", &OpDescBind::GetBlockAttr);
.def("get_block_attr", &OpDescBind::GetBlockAttr)
.def("infer_shape", &OpDescBind::InferShape);
}
} // namespace pybind
......
......@@ -231,21 +231,6 @@ All parameter, weight, gradient are variables in Paddle.
desc.InitializationErrorString());
return OpRegistry::CreateOp(desc);
})
.def_static("infer_shape",
[](OpDescBind &op_desc, BlockDescBind &block) {
auto op = OpRegistry::CreateOp(*op_desc.Proto());
auto *op_with_kernel =
dynamic_cast<OperatorWithKernel *>(op.get());
if (op_with_kernel != nullptr) {
auto ctx = CompileTimeInferShapeContext(op_desc, block);
op_with_kernel->InferShape(&ctx);
} else {
PADDLE_THROW(
"OP(%s) is not type of OperatorWithKernel, "
"should not call this function",
op_desc.Type());
}
})
.def("backward",
[](const OperatorBase &forwardOp,
const std::unordered_set<std::string> &no_grad_vars) {
......
......@@ -13,15 +13,15 @@ class Variable(object):
if name is None:
name = Variable._unique_var_name_()
try:
self.proto = self.block.proto.var(name)
self.desc = self.block.desc.var(name)
is_new_var = False
except core.EnforceNotMet:
self.proto = self.block.proto.new_var(name)
self.desc = self.block.desc.new_var(name)
is_new_var = True
if shape is not None:
if is_new_var:
self.proto.set_shape(shape)
self.desc.set_shape(shape)
else:
old_shape = self.shape
shape = tuple(shape)
......@@ -34,7 +34,7 @@ class Variable(object):
if not isinstance(dtype, core.DataType):
dtype = Variable._convert_np_dtype_to_dtype_(dtype)
if is_new_var:
self.proto.set_data_type(dtype)
self.desc.set_data_type(dtype)
else:
old_dtype = self.data_type()
if dtype != old_shape:
......@@ -46,7 +46,7 @@ class Variable(object):
if lod_level is not None:
if is_new_var:
self.proto.set_lod_level(lod_level)
self.desc.set_lod_level(lod_level)
else:
if lod_level != self.lod_level:
raise ValueError("Variable {0} has been created before. "
......@@ -54,26 +54,25 @@ class Variable(object):
"lod_level is {2}. They are not "
"matched".format(self.name, self.lod_level,
lod_level))
self.block.vars[name] = self
self.op = None
@property
def name(self):
return self.proto.name()
return self.desc.name()
@property
def shape(self):
# convert to tuple, make it as same as numpy API.
return tuple(self.proto.shape())
return tuple(self.desc.shape())
@property
def data_type(self):
return self.proto.data_type()
return self.desc.data_type()
@property
def lod_level(self):
return self.proto.lod_level()
return self.desc.lod_level()
@staticmethod
def _unique_var_name_():
......@@ -104,13 +103,13 @@ class Variable(object):
class Operator(object):
def __init__(self,
block,
proto,
desc,
type=None,
inputs=None,
outputs=None,
attrs=None):
self.block = block
self.proto = proto
self.desc = desc
if type is not None:
# TODO.
pass
......@@ -129,31 +128,31 @@ class Operator(object):
class Block(object):
def __init__(self, program, idx):
self.proto = program.proto.block(idx)
self.desc = program.desc.block(idx)
self.vars = dict() # var_name --> var
self.ops = collections.deque() # operator list
self.program = program
@property
def parent_idx(self):
return self.proto.parent
return self.desc.parent
@property
def idx(self):
return self.proto.id
return self.desc.id
def create_var(self, *args, **kwargs):
return Variable(self, *args, **kwargs)
def append_op(self, *args, **kwargs):
op_proto = self.proto.append_op()
op = Operator(self, op_proto, *args, **kwargs)
op_desc = self.desc.append_op()
op = Operator(self, op_desc, *args, **kwargs)
self.ops.append(op)
return op
def prepend_op(self, *args, **kwargs):
op_proto = self.proto.prepend_op()
op = Operator(self, op_proto, *args, **kwargs)
op_desc = self.desc.prepend_op()
op = Operator(self, op_desc, *args, **kwargs)
self.ops.appendleft(op)
return op
......@@ -170,7 +169,7 @@ class Program(object):
def __init__(self):
assert not hasattr(self.__class__,
'_instance'), 'Do not call constructor directly!'
self.proto = core.ProgramDesc.instance()
self.desc = core.ProgramDesc.instance()
self.blocks = [Block(self, 0)]
self.current_block_idx = 0
......@@ -182,7 +181,7 @@ class Program(object):
def create_block(self):
new_block_idx = len(self.blocks)
self.proto.append_block(self.current_block().proto)
self.desc.append_block(self.current_block().desc)
self.current_block_idx = new_block_idx
self.blocks.append(Block(self, self.current_block_idx))
return self.current_block()
......
import unittest
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
class TestInferShape(unittest.TestCase):
......@@ -26,7 +26,7 @@ class TestInferShape(unittest.TestCase):
sum_op_desc.set_input("X", ["x1", "x2"])
sum_op_desc.set_output("Out", ["out"])
core.Operator.infer_shape(sum_op_desc, block)
sum_op_desc.infer_shape(block)
self.assertEqual(out.shape(), shape)
def test_mul_op(self):
......@@ -55,7 +55,7 @@ class TestInferShape(unittest.TestCase):
mul_op_desc.set_attr("x_num_col_dims", 1)
mul_op_desc.set_attr("y_num_col_dims", 1)
core.Operator.infer_shape(mul_op_desc, block)
mul_op_desc.infer_shape(block)
self.assertEqual(out.shape(), [x_shape[0], y_shape[1]])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册