提交 d17eb73e 编写于 作者: Y Yu Yang 提交者: GitHub

Update VarDesc from design doc (#4769)

* Update VarDesc from design doc

* Fix GCC compile

* Fix unittest
上级 9a6dffd4
...@@ -46,6 +46,7 @@ void AddOp(const std::string& type, const VariableNameMap& inputs, ...@@ -46,6 +46,7 @@ void AddOp(const std::string& type, const VariableNameMap& inputs,
for (auto kv : outputs) { for (auto kv : outputs) {
for (auto v : kv.second) { for (auto v : kv.second) {
auto var = block->NewVar(v); auto var = block->NewVar(v);
var->SetType(VarDesc::LOD_TENSOR);
var->SetDataType(paddle::framework::DataType::FP32); var->SetDataType(paddle::framework::DataType::FP32);
} }
} }
......
...@@ -97,16 +97,26 @@ enum DataType { ...@@ -97,16 +97,26 @@ enum DataType {
FP64 = 6; FP64 = 6;
} }
message LoDTensorDesc { message TensorDesc {
required DataType data_type = 1; required DataType data_type = 1;
repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480] repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
optional int32 lod_level = 3 [ default = 0 ]; }
message LoDTensorDesc {
required TensorDesc tensor = 1;
optional int32 lod_level = 2 [ default = 0 ];
} }
message VarDesc { message VarDesc {
enum VarType {
LOD_TENSOR = 1;
SELECTED_ROWS = 2;
}
required string name = 1; required string name = 1;
optional LoDTensorDesc lod_tensor = 2; required VarType type = 2;
optional bool persistable = 3 [ default = false ]; optional LoDTensorDesc lod_tensor = 3;
optional TensorDesc selected_rows = 4;
optional bool persistable = 5 [ default = false ];
} }
message BlockDesc { message BlockDesc {
......
...@@ -13,32 +13,58 @@ See the License for the specific language governing permissions and ...@@ -13,32 +13,58 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/framework/var_desc.h" #include "paddle/framework/var_desc.h"
#include "paddle/platform/enforce.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void VarDescBind::SetShape(const std::vector<int64_t> &dims) { void VarDescBind::SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims()); VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
} }
void VarDescBind::SetDataType(DataType data_type) { void VarDescBind::SetDataType(DataType data_type) {
desc_.mutable_lod_tensor()->set_data_type(data_type); mutable_tensor_desc()->set_data_type(data_type);
} }
std::vector<int64_t> VarDescBind::Shape() const { std::vector<int64_t> VarDescBind::Shape() const {
return RepeatedToVector(desc_.lod_tensor().dims()); return RepeatedToVector(tensor_desc().dims());
} }
DataType VarDescBind::GetDataType() const { DataType VarDescBind::GetDataType() const { return tensor_desc().data_type(); }
return desc_.lod_tensor().data_type();
}
void VarDescBind::SetLoDLevel(int32_t lod_level) { void VarDescBind::SetLoDLevel(int32_t lod_level) {
PADDLE_ENFORCE(desc_.type() == VarDesc::LOD_TENSOR);
desc_.mutable_lod_tensor()->set_lod_level(lod_level); desc_.mutable_lod_tensor()->set_lod_level(lod_level);
} }
int32_t VarDescBind::GetLodLevel() const { int32_t VarDescBind::GetLodLevel() const {
PADDLE_ENFORCE(desc_.type() == VarDesc::LOD_TENSOR);
return desc_.lod_tensor().lod_level(); return desc_.lod_tensor().lod_level();
} }
const TensorDesc &VarDescBind::tensor_desc() const {
PADDLE_ENFORCE(desc_.has_type(), "invoke TensorDesc must after set type");
switch (desc_.type()) {
case VarDesc::SELECTED_ROWS:
return desc_.selected_rows();
case VarDesc::LOD_TENSOR:
return desc_.lod_tensor().tensor();
default:
PADDLE_THROW("Unexpected branch.");
}
}
TensorDesc *VarDescBind::mutable_tensor_desc() {
PADDLE_ENFORCE(desc_.has_type(),
"invoke MutableTensorDesc must after set type");
switch (desc_.type()) {
case VarDesc::SELECTED_ROWS:
return desc_.mutable_selected_rows();
case VarDesc::LOD_TENSOR:
return desc_.mutable_lod_tensor()->mutable_tensor();
default:
PADDLE_THROW("Unexpected branch.");
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -72,7 +72,14 @@ class VarDescBind { ...@@ -72,7 +72,14 @@ class VarDescBind {
int32_t GetLodLevel() const; int32_t GetLodLevel() const;
VarDesc::VarType GetType() const { return desc_.type(); }
void SetType(VarDesc::VarType type) { desc_.set_type(type); }
private: private:
const TensorDesc &tensor_desc() const;
TensorDesc *mutable_tensor_desc();
VarDesc desc_; VarDesc desc_;
}; };
} // namespace framework } // namespace framework
......
...@@ -162,7 +162,8 @@ void BindVarDsec(py::module &m) { ...@@ -162,7 +162,8 @@ void BindVarDsec(py::module &m) {
.value("FP32", DataType::FP32) .value("FP32", DataType::FP32)
.value("FP64", DataType::FP64); .value("FP64", DataType::FP64);
py::class_<VarDescBind>(m, "VarDesc", "") py::class_<VarDescBind> var_desc(m, "VarDesc", "");
var_desc
.def("name", .def("name",
[](const VarDescBind &self) { [](const VarDescBind &self) {
py::bytes name = self.Name(); py::bytes name = self.Name();
...@@ -174,7 +175,13 @@ void BindVarDsec(py::module &m) { ...@@ -174,7 +175,13 @@ void BindVarDsec(py::module &m) {
.def("shape", &VarDescBind::Shape, py::return_value_policy::reference) .def("shape", &VarDescBind::Shape, py::return_value_policy::reference)
.def("data_type", &VarDescBind::GetDataType) .def("data_type", &VarDescBind::GetDataType)
.def("lod_level", &VarDescBind::GetLodLevel) .def("lod_level", &VarDescBind::GetLodLevel)
.def("set_lod_level", &VarDescBind::SetLoDLevel); .def("set_lod_level", &VarDescBind::SetLoDLevel)
.def("type", &VarDescBind::GetType)
.def("set_type", &VarDescBind::SetType);
py::enum_<VarDesc::VarType>(var_desc, "VarType", "")
.value("LOD_TENSOR", VarDesc::LOD_TENSOR)
.value("SELECTED_ROWS", VarDesc::SELECTED_ROWS);
} }
void BindOpDesc(py::module &m) { void BindOpDesc(py::module &m) {
......
...@@ -10,6 +10,7 @@ __all__ = ['Block', 'Variable', 'Program', 'Operator'] ...@@ -10,6 +10,7 @@ __all__ = ['Block', 'Variable', 'Program', 'Operator']
class Variable(object): class Variable(object):
def __init__(self, def __init__(self,
block, block,
type=core.VarDesc.VarType.LOD_TENSOR,
name=None, name=None,
shape=None, shape=None,
dtype=None, dtype=None,
...@@ -26,6 +27,14 @@ class Variable(object): ...@@ -26,6 +27,14 @@ class Variable(object):
self.desc = self.block.desc.new_var(name) self.desc = self.block.desc.new_var(name)
is_new_var = True is_new_var = True
if is_new_var:
self.desc.set_type(type)
elif self.desc.type() != type:
raise ValueError("Variable {0} has been created before. The "
"previous type is {1}; the new type is {2}. They"
" are not matched".format(self.name,
self.desc.type(), type))
if shape is not None: if shape is not None:
if is_new_var: if is_new_var:
self.desc.set_shape(shape) self.desc.set_shape(shape)
......
...@@ -14,11 +14,14 @@ class TestInferShape(unittest.TestCase): ...@@ -14,11 +14,14 @@ class TestInferShape(unittest.TestCase):
# prepare input/output # prepare input/output
x1 = block.new_var("x1") x1 = block.new_var("x1")
x1.set_type(core.VarDesc.VarType.LOD_TENSOR)
x1.set_shape(shape) x1.set_shape(shape)
x2 = block.new_var("x2") x2 = block.new_var("x2")
x2.set_type(core.VarDesc.VarType.LOD_TENSOR)
x2.set_shape(shape) x2.set_shape(shape)
out = block.new_var("out") out = block.new_var("out")
out.set_type(core.VarDesc.VarType.LOD_TENSOR)
# prepare the operator # prepare the operator
sum_op_desc = block.append_op() sum_op_desc = block.append_op()
...@@ -40,11 +43,14 @@ class TestInferShape(unittest.TestCase): ...@@ -40,11 +43,14 @@ class TestInferShape(unittest.TestCase):
# prepare input/output # prepare input/output
x1 = block.new_var("x") x1 = block.new_var("x")
x1.set_type(core.VarDesc.VarType.LOD_TENSOR)
x1.set_shape(x_shape) x1.set_shape(x_shape)
x2 = block.new_var("y") x2 = block.new_var("y")
x2.set_type(core.VarDesc.VarType.LOD_TENSOR)
x2.set_shape(y_shape) x2.set_shape(y_shape)
out = block.new_var("out") out = block.new_var("out")
out.set_type(core.VarDesc.VarType.LOD_TENSOR)
# prepare the operator # prepare the operator
mul_op_desc = block.append_op() mul_op_desc = block.append_op()
......
...@@ -94,17 +94,21 @@ class TestVarDesc(unittest.TestCase): ...@@ -94,17 +94,21 @@ class TestVarDesc(unittest.TestCase):
program_desc = core.ProgramDesc.__create_program_desc__() program_desc = core.ProgramDesc.__create_program_desc__()
block = program_desc.block(0) block = program_desc.block(0)
var = block.new_var('my_var') var = block.new_var('my_var')
var.set_type(core.VarDesc.VarType.SELECTED_ROWS)
src_shape = [3, 2, 10, 8] src_shape = [3, 2, 10, 8]
var.set_shape(src_shape) var.set_shape(src_shape)
res_shape = var.shape() res_shape = var.shape()
self.assertEqual(src_shape, res_shape) self.assertEqual(src_shape, res_shape)
self.assertEqual(core.VarDesc.VarType.SELECTED_ROWS, var.type())
def test_data_type(self): def test_data_type(self):
program_desc = core.ProgramDesc.__create_program_desc__() program_desc = core.ProgramDesc.__create_program_desc__()
block = program_desc.block(0) block = program_desc.block(0)
var = block.new_var('my_var') var = block.new_var('my_var')
var.set_type(core.VarDesc.VarType.LOD_TENSOR)
var.set_data_type(core.DataType.INT32) var.set_data_type(core.DataType.INT32)
self.assertEqual(core.DataType.INT32, var.data_type()) self.assertEqual(core.DataType.INT32, var.data_type())
self.assertEqual(core.VarDesc.VarType.LOD_TENSOR, var.type())
class TestBlockDesc(unittest.TestCase): class TestBlockDesc(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册