diff --git a/paddle/fluid/framework/var_desc.cc b/paddle/fluid/framework/var_desc.cc index bb2be1ab50a59c23551c6f150beb33a1d2b4a5a7..7e3f002b53351ba5892aaa50482b21a83db94069 100644 --- a/paddle/fluid/framework/var_desc.cc +++ b/paddle/fluid/framework/var_desc.cc @@ -87,12 +87,12 @@ std::vector> VarDesc::GetShapes() const { return res; } -void VarDesc::SetDataType(proto::DataType data_type) { +void VarDesc::SetDataType(proto::VarType::Type data_type) { mutable_tensor_desc()->set_data_type(data_type); } void VarDesc::SetDataTypes( - const std::vector &multiple_data_type) { + const std::vector &multiple_data_type) { if (multiple_data_type.size() != GetTensorDescNum()) { VLOG(3) << "WARNING: The number of given data types(" << multiple_data_type.size() @@ -108,13 +108,13 @@ void VarDesc::SetDataTypes( } } -proto::DataType VarDesc::GetDataType() const { +proto::VarType::Type VarDesc::GetDataType() const { return tensor_desc().data_type(); } -std::vector VarDesc::GetDataTypes() const { +std::vector VarDesc::GetDataTypes() const { std::vector descs = tensor_descs(); - std::vector res; + std::vector res; res.reserve(descs.size()); for (const auto &tensor_desc : descs) { res.push_back(tensor_desc.data_type());