提交 d17eb73e 编写于 作者: Y Yu Yang 提交者: GitHub

Update VarDesc from design doc (#4769)

* Update VarDesc from design doc

* Fix GCC compile

* Fix unittest
上级 9a6dffd4
......@@ -46,6 +46,7 @@ void AddOp(const std::string& type, const VariableNameMap& inputs,
for (auto kv : outputs) {
for (auto v : kv.second) {
auto var = block->NewVar(v);
var->SetType(VarDesc::LOD_TENSOR);
var->SetDataType(paddle::framework::DataType::FP32);
}
}
......
......@@ -97,16 +97,26 @@ enum DataType {
FP64 = 6;
}
message LoDTensorDesc {
message TensorDesc {
required DataType data_type = 1;
repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
optional int32 lod_level = 3 [ default = 0 ];
}
message LoDTensorDesc {
required TensorDesc tensor = 1;
optional int32 lod_level = 2 [ default = 0 ];
}
message VarDesc {
enum VarType {
LOD_TENSOR = 1;
SELECTED_ROWS = 2;
}
required string name = 1;
optional LoDTensorDesc lod_tensor = 2;
optional bool persistable = 3 [ default = false ];
required VarType type = 2;
optional LoDTensorDesc lod_tensor = 3;
optional TensorDesc selected_rows = 4;
optional bool persistable = 5 [ default = false ];
}
message BlockDesc {
......
......@@ -13,32 +13,58 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/var_desc.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace framework {
void VarDescBind::SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims());
VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
}
void VarDescBind::SetDataType(DataType data_type) {
desc_.mutable_lod_tensor()->set_data_type(data_type);
mutable_tensor_desc()->set_data_type(data_type);
}
std::vector<int64_t> VarDescBind::Shape() const {
return RepeatedToVector(desc_.lod_tensor().dims());
return RepeatedToVector(tensor_desc().dims());
}
DataType VarDescBind::GetDataType() const {
return desc_.lod_tensor().data_type();
}
DataType VarDescBind::GetDataType() const { return tensor_desc().data_type(); }
void VarDescBind::SetLoDLevel(int32_t lod_level) {
PADDLE_ENFORCE(desc_.type() == VarDesc::LOD_TENSOR);
desc_.mutable_lod_tensor()->set_lod_level(lod_level);
}
int32_t VarDescBind::GetLodLevel() const {
PADDLE_ENFORCE(desc_.type() == VarDesc::LOD_TENSOR);
return desc_.lod_tensor().lod_level();
}
const TensorDesc &VarDescBind::tensor_desc() const {
PADDLE_ENFORCE(desc_.has_type(), "invoke TensorDesc must after set type");
switch (desc_.type()) {
case VarDesc::SELECTED_ROWS:
return desc_.selected_rows();
case VarDesc::LOD_TENSOR:
return desc_.lod_tensor().tensor();
default:
PADDLE_THROW("Unexpected branch.");
}
}
TensorDesc *VarDescBind::mutable_tensor_desc() {
PADDLE_ENFORCE(desc_.has_type(),
"invoke MutableTensorDesc must after set type");
switch (desc_.type()) {
case VarDesc::SELECTED_ROWS:
return desc_.mutable_selected_rows();
case VarDesc::LOD_TENSOR:
return desc_.mutable_lod_tensor()->mutable_tensor();
default:
PADDLE_THROW("Unexpected branch.");
}
}
} // namespace framework
} // namespace paddle
......@@ -72,7 +72,14 @@ class VarDescBind {
int32_t GetLodLevel() const;
VarDesc::VarType GetType() const { return desc_.type(); }
void SetType(VarDesc::VarType type) { desc_.set_type(type); }
private:
const TensorDesc &tensor_desc() const;
TensorDesc *mutable_tensor_desc();
VarDesc desc_;
};
} // namespace framework
......
......@@ -162,7 +162,8 @@ void BindVarDsec(py::module &m) {
.value("FP32", DataType::FP32)
.value("FP64", DataType::FP64);
py::class_<VarDescBind>(m, "VarDesc", "")
py::class_<VarDescBind> var_desc(m, "VarDesc", "");
var_desc
.def("name",
[](const VarDescBind &self) {
py::bytes name = self.Name();
......@@ -174,7 +175,13 @@ void BindVarDsec(py::module &m) {
.def("shape", &VarDescBind::Shape, py::return_value_policy::reference)
.def("data_type", &VarDescBind::GetDataType)
.def("lod_level", &VarDescBind::GetLodLevel)
.def("set_lod_level", &VarDescBind::SetLoDLevel);
.def("set_lod_level", &VarDescBind::SetLoDLevel)
.def("type", &VarDescBind::GetType)
.def("set_type", &VarDescBind::SetType);
py::enum_<VarDesc::VarType>(var_desc, "VarType", "")
.value("LOD_TENSOR", VarDesc::LOD_TENSOR)
.value("SELECTED_ROWS", VarDesc::SELECTED_ROWS);
}
void BindOpDesc(py::module &m) {
......
......@@ -10,6 +10,7 @@ __all__ = ['Block', 'Variable', 'Program', 'Operator']
class Variable(object):
def __init__(self,
block,
type=core.VarDesc.VarType.LOD_TENSOR,
name=None,
shape=None,
dtype=None,
......@@ -26,6 +27,14 @@ class Variable(object):
self.desc = self.block.desc.new_var(name)
is_new_var = True
if is_new_var:
self.desc.set_type(type)
elif self.desc.type() != type:
raise ValueError("Variable {0} has been created before. The "
"previous type is {1}; the new type is {2}. They"
" are not matched".format(self.name,
self.desc.type(), type))
if shape is not None:
if is_new_var:
self.desc.set_shape(shape)
......
......@@ -14,11 +14,14 @@ class TestInferShape(unittest.TestCase):
# prepare input/output
x1 = block.new_var("x1")
x1.set_type(core.VarDesc.VarType.LOD_TENSOR)
x1.set_shape(shape)
x2 = block.new_var("x2")
x2.set_type(core.VarDesc.VarType.LOD_TENSOR)
x2.set_shape(shape)
out = block.new_var("out")
out.set_type(core.VarDesc.VarType.LOD_TENSOR)
# prepare the operator
sum_op_desc = block.append_op()
......@@ -40,11 +43,14 @@ class TestInferShape(unittest.TestCase):
# prepare input/output
x1 = block.new_var("x")
x1.set_type(core.VarDesc.VarType.LOD_TENSOR)
x1.set_shape(x_shape)
x2 = block.new_var("y")
x2.set_type(core.VarDesc.VarType.LOD_TENSOR)
x2.set_shape(y_shape)
out = block.new_var("out")
out.set_type(core.VarDesc.VarType.LOD_TENSOR)
# prepare the operator
mul_op_desc = block.append_op()
......
......@@ -94,17 +94,21 @@ class TestVarDesc(unittest.TestCase):
program_desc = core.ProgramDesc.__create_program_desc__()
block = program_desc.block(0)
var = block.new_var('my_var')
var.set_type(core.VarDesc.VarType.SELECTED_ROWS)
src_shape = [3, 2, 10, 8]
var.set_shape(src_shape)
res_shape = var.shape()
self.assertEqual(src_shape, res_shape)
self.assertEqual(core.VarDesc.VarType.SELECTED_ROWS, var.type())
def test_data_type(self):
program_desc = core.ProgramDesc.__create_program_desc__()
block = program_desc.block(0)
var = block.new_var('my_var')
var.set_type(core.VarDesc.VarType.LOD_TENSOR)
var.set_data_type(core.DataType.INT32)
self.assertEqual(core.DataType.INT32, var.data_type())
self.assertEqual(core.VarDesc.VarType.LOD_TENSOR, var.type())
class TestBlockDesc(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册