diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index 179f9194a9dec07489871606f39569b0a67b2c52..ebfd54fdc557148dda81ba5e4936c0cd5f23a887 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -36,24 +36,24 @@ namespace framework {
 
 Executor::Executor(const platform::Place& place) : place_(place) {}
 
-static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
-  if (var_type == proto::VarDesc::LOD_TENSOR) {
+static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
+  if (var_type == proto::VarType::LOD_TENSOR) {
     var->GetMutable<LoDTensor>();
-  } else if (var_type == proto::VarDesc::SELECTED_ROWS) {
+  } else if (var_type == proto::VarType::SELECTED_ROWS) {
     var->GetMutable<SelectedRows>();
-  } else if (var_type == proto::VarDesc::FEED_MINIBATCH) {
+  } else if (var_type == proto::VarType::FEED_MINIBATCH) {
    var->GetMutable<FeedFetchList>();
-  } else if (var_type == proto::VarDesc::FETCH_LIST) {
+  } else if (var_type == proto::VarType::FETCH_LIST) {
    var->GetMutable<FeedFetchList>();
-  } else if (var_type == proto::VarDesc::STEP_SCOPES) {
+  } else if (var_type == proto::VarType::STEP_SCOPES) {
     var->GetMutable<std::vector<framework::Scope>>();
-  } else if (var_type == proto::VarDesc::LOD_RANK_TABLE) {
+  } else if (var_type == proto::VarType::LOD_RANK_TABLE) {
     var->GetMutable<LoDRankTable>();
-  } else if (var_type == proto::VarDesc::LOD_TENSOR_ARRAY) {
+  } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
     var->GetMutable<LoDTensorArray>();
-  } else if (var_type == proto::VarDesc::PLACE_LIST) {
+  } else if (var_type == proto::VarType::PLACE_LIST) {
     var->GetMutable<platform::PlaceList>();
-  } else if (var_type == proto::VarDesc::READER) {
+  } else if (var_type == proto::VarType::READER) {
     var->GetMutable<ReaderHolder>();
   } else {
     PADDLE_THROW(
@@ -182,7 +182,7 @@ static bool has_feed_operators(
     auto var = block->FindVar(feed_holder_name);
     PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
                             feed_holder_name);
-    PADDLE_ENFORCE_EQ(var->GetType(), proto::VarDesc::FEED_MINIBATCH,
+    PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
                       "'%s' variable should be 'FEED_MINIBATCH' type",
                       feed_holder_name);
   }
@@ -222,7 +222,7 @@ static bool has_fetch_operators(
     auto var = block->FindVar(fetch_holder_name);
     PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
                             fetch_holder_name);
-    PADDLE_ENFORCE_EQ(var->GetType(), proto::VarDesc::FETCH_LIST,
+    PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
                       "'%s' variable should be 'FETCH_LIST' type",
                       fetch_holder_name);
   }
@@ -241,7 +241,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
   if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) {
     // create feed_holder variable
     auto* feed_holder = global_block->Var(feed_holder_name);
-    feed_holder->SetType(proto::VarDesc::FEED_MINIBATCH);
+    feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
     feed_holder->SetPersistable(true);
 
     int i = 0;
@@ -274,7 +274,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
   if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) {
     // create fetch_holder variable
     auto* fetch_holder = global_block->Var(fetch_holder_name);
-    fetch_holder->SetType(proto::VarDesc::FETCH_LIST);
+    fetch_holder->SetType(proto::VarType::FETCH_LIST);
     fetch_holder->SetPersistable(true);
 
     int i = 0;
diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto
index ad8da21ae0f1df0ede4dba6c9564a506ee8463db..fa7f437851141e99e2c33a854e316e3ed2bedadc 100644
--- a/paddle/fluid/framework/framework.proto
+++ b/paddle/fluid/framework/framework.proto
@@ -101,25 +101,8 @@ enum DataType {
   FP64 = 6;
 }
 
-message TensorDesc {
-  required DataType data_type = 1;
-  repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
-}
-
-message LoDTensorDesc {
-  required TensorDesc tensor = 1;
-  optional int32 lod_level = 2 [ default = 0 ];
-}
-
-message LoDTensorArrayDesc {
-  required TensorDesc tensor = 1;
-  optional int32 lod_level = 2 [ default = 0 ];
-}
-
-message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; }
-
-message VarDesc {
-  enum VarType {
+message VarType {
+  enum Type {
     LOD_TENSOR = 1;
     SELECTED_ROWS = 2;
     FEED_MINIBATCH = 3;
@@ -130,13 +113,35 @@ message VarDesc {
     PLACE_LIST = 8;
     READER = 9;
   }
+
+  required Type type = 1;
+
+  message TensorDesc {
+    required DataType data_type = 1;
+    repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
+  }
+  optional TensorDesc selected_rows = 2;
+
+  message LoDTensorDesc {
+    required TensorDesc tensor = 1;
+    optional int32 lod_level = 2 [ default = 0 ];
+  }
+  optional LoDTensorDesc lod_tensor = 3;
+
+  message LoDTensorArrayDesc {
+    required TensorDesc tensor = 1;
+    optional int32 lod_level = 2 [ default = 0 ];
+  }
+  optional LoDTensorArrayDesc tensor_array = 4;
+
+  message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; }
+  optional ReaderDesc reader = 5;
+}
+
+message VarDesc {
   required string name = 1;
   required VarType type = 2;
   optional bool persistable = 3 [ default = false ];
-  optional LoDTensorDesc lod_tensor = 4;
-  optional TensorDesc selected_rows = 5;
-  optional LoDTensorArrayDesc tensor_array = 6;
-  optional ReaderDesc reader = 7;
 }
 
 message BlockDesc {
diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc
index e740010c63ce339527fcf58ebbdff0441fdf1467..eabfdc11a8b314c4af9626ded3edd1bcba212de1 100644
--- a/paddle/fluid/framework/op_desc.cc
+++ b/paddle/fluid/framework/op_desc.cc
@@ -53,11 +53,11 @@ class CompileTimeInferShapeContext : public InferShapeContext {
     PADDLE_ENFORCE_LT(j, Outputs(out).size());
     auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
     auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
-    if (in_var->GetType() != proto::VarDesc::LOD_TENSOR) {
+    if (in_var->GetType() != proto::VarType::LOD_TENSOR) {
       VLOG(3) << "input " << in << " is not LodTensor";
       return;
     }
-    PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarDesc::LOD_TENSOR,
+    PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarType::LOD_TENSOR,
                       "The %d-th output of Output(%s) must be LoDTensor.", j,
                       out);
     out_var->SetLoDLevel(in_var->GetLoDLevel());
@@ -66,7 +66,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
   bool IsRuntime() const override;
 
  protected:
-  proto::VarDesc::VarType GetVarType(const std::string &name) const override;
+  proto::VarType::Type GetVarType(const std::string &name) const override;
 
   DDim GetDim(const std::string &name) const override;
 
@@ -388,7 +388,7 @@ void OpDesc::InferVarType(BlockDesc *block) const {
   for (auto &out_pair : this->outputs_) {
     for (auto &out_var_name : out_pair.second) {
       block->FindRecursiveOrCreateVar(out_var_name)
-          .SetType(proto::VarDesc::LOD_TENSOR);
+          .SetType(proto::VarType::LOD_TENSOR);
     }
   }
 }
@@ -507,7 +507,7 @@ void CompileTimeInferShapeContext::SetRepeatedDims(
 
 bool CompileTimeInferShapeContext::IsRuntime() const { return false; }
 
-proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
+proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
     const std::string &name) const {
   return block_.FindVarRecursive(name)->GetType();
 }
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index bc529b8269a89c791eab1de816cb513e6084d8b8..ff90aba10baf734db4412b7b31f89dda8dbf7667 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -477,7 +477,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
     }
   }
 
-  proto::VarDesc::VarType GetVarType(const std::string& name) const override {
+  proto::VarType::Type GetVarType(const std::string& name) const override {
     auto* var = scope_.FindVar(name);
     return ToVarType(var->Type());
   }
diff --git a/paddle/fluid/framework/program_desc_test.cc b/paddle/fluid/framework/program_desc_test.cc
index 3a4a87cfa5a0675ecdb84ed6843970cfff0b71b6..d9c4331da1d8a8a1091073b9654d6f785a27c78e 100644
--- a/paddle/fluid/framework/program_desc_test.cc
+++ b/paddle/fluid/framework/program_desc_test.cc
@@ -22,13 +22,13 @@ TEST(ProgramDesc, copy_ctor) {
   ProgramDesc program;
   auto* global_block = program.MutableBlock(0);
   auto* x = global_block->Var("X");
-  x->SetType(proto::VarDesc_VarType_LOD_TENSOR);
+  x->SetType(proto::VarType::LOD_TENSOR);
   x->SetLoDLevel(0);
   x->SetDataType(proto::FP32);
   x->SetShape({1000, 784});
 
   auto* y = global_block->Var("Y");
-  y->SetType(proto::VarDesc_VarType_LOD_TENSOR);
+  y->SetType(proto::VarType::LOD_TENSOR);
   y->SetLoDLevel(0);
   y->SetDataType(proto::FP32);
   y->SetShape({784, 100});
@@ -39,7 +39,7 @@ TEST(ProgramDesc, copy_ctor) {
   op->SetInput("Y", {y->Name()});
 
   auto* out = global_block->Var("Out");
-  out->SetType(proto::VarDesc_VarType_LOD_TENSOR);
+  out->SetType(proto::VarType::LOD_TENSOR);
   op->SetOutput("Y", {out->Name()});
 
   ProgramDesc program_copy(program);
@@ -84,13 +84,13 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
   ProgramDesc program_origin;
   auto* global_block = program_origin.MutableBlock(0);
   auto* x = global_block->Var("X");
-  x->SetType(proto::VarDesc_VarType_LOD_TENSOR);
+  x->SetType(proto::VarType::LOD_TENSOR);
   x->SetLoDLevel(0);
   x->SetDataType(proto::FP32);
   x->SetShape({1000, 784});
 
   auto* y = global_block->Var("Y");
-  y->SetType(proto::VarDesc_VarType_LOD_TENSOR);
+  y->SetType(proto::VarType::LOD_TENSOR);
   y->SetLoDLevel(0);
   y->SetDataType(proto::FP32);
   y->SetShape({784, 100});
@@ -101,7 +101,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
   op->SetInput("Y", {y->Name()});
 
   auto* out = global_block->Var("Out");
-  out->SetType(proto::VarDesc_VarType_LOD_TENSOR);
+  out->SetType(proto::VarType::LOD_TENSOR);
   op->SetOutput("Y", {out->Name()});
 
   std::string binary_str;
diff --git a/paddle/fluid/framework/shape_inference.cc b/paddle/fluid/framework/shape_inference.cc
index 1b518970acd16384bf7e19c9c8c17874171ad962..dc9a79020f103dadfd9837cffb18ad5946f95f31 100644
--- a/paddle/fluid/framework/shape_inference.cc
+++ b/paddle/fluid/framework/shape_inference.cc
@@ -116,19 +116,19 @@ void InferShapeContext::SetDims(const std::vector<std::string> &names,
   }
 }
 
-std::vector<proto::VarDesc::VarType> InferShapeContext::GetInputsVarType(
+std::vector<proto::VarType::Type> InferShapeContext::GetInputsVarType(
     const std::string &name) const {
   return GetVarTypes(Inputs(name));
 }
 
-std::vector<proto::VarDesc::VarType> InferShapeContext::GetOutputsVarType(
+std::vector<proto::VarType::Type> InferShapeContext::GetOutputsVarType(
    const std::string &name) const {
  return GetVarTypes(Outputs(name));
 }
 
-std::vector<proto::VarDesc::VarType> InferShapeContext::GetVarTypes(
+std::vector<proto::VarType::Type> InferShapeContext::GetVarTypes(
     const std::vector<std::string> &names) const {
-  std::vector<proto::VarDesc::VarType> retv;
+  std::vector<proto::VarType::Type> retv;
   retv.resize(names.size());
   std::transform(names.begin(), names.end(), retv.begin(),
                  std::bind(std::mem_fn(&InferShapeContext::GetVarType), this,
diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h
index 3739d640fe981bd174f7955bbdf61d72ffebe738..bc02d700da5186cea5f370b9676e408f62a66a68 100644
--- a/paddle/fluid/framework/shape_inference.h
+++ b/paddle/fluid/framework/shape_inference.h
@@ -31,9 +31,9 @@ class InferShapeContext {
   virtual bool HasInput(const std::string &name) const = 0;
   virtual bool HasOutput(const std::string &name) const = 0;
 
-  std::vector<proto::VarDesc::VarType> GetInputsVarType(
+  std::vector<proto::VarType::Type> GetInputsVarType(
       const std::string &name) const;
-  std::vector<proto::VarDesc::VarType> GetOutputsVarType(
+  std::vector<proto::VarType::Type> GetOutputsVarType(
       const std::string &name) const;
 
   virtual bool HasInputs(const std::string &name) const = 0;
@@ -75,10 +75,10 @@ class InferShapeContext {
   std::vector<DDim> GetDims(
       const std::vector<std::string> &names) const;
 
-  std::vector<proto::VarDesc::VarType> GetVarTypes(
+  std::vector<proto::VarType::Type> GetVarTypes(
       const std::vector<std::string> &names) const;
 
-  virtual proto::VarDesc::VarType GetVarType(const std::string &name) const = 0;
+  virtual proto::VarType::Type GetVarType(const std::string &name) const = 0;
 
   virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0;
 };
diff --git a/paddle/fluid/framework/tensor_util.h b/paddle/fluid/framework/tensor_util.h
index 22519013cc463c46a03dd02be3f90e6ea0012689..f0464d480785a5c27ba8386165a8a0093ff0ed38 100644
--- a/paddle/fluid/framework/tensor_util.h
+++ b/paddle/fluid/framework/tensor_util.h
@@ -225,7 +225,7 @@ inline void SerializeToStream(std::ostream& os, const Tensor& tensor,
   {  // the 2nd field, tensor description
      // int32_t  size
      // void*    protobuf message
-    proto::TensorDesc desc;
+    proto::VarType::TensorDesc desc;
     desc.set_data_type(framework::ToDataType(tensor.type()));
     auto dims = framework::vectorize(tensor.dims());
     auto* pb_dims = desc.mutable_dims();
@@ -290,7 +290,7 @@ inline void DeserializeFromStream(std::istream& is, Tensor* tensor,
   uint32_t version;
   is.read(reinterpret_cast<char*>(&version), sizeof(version));
   PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
-  proto::TensorDesc desc;
+  proto::VarType::TensorDesc desc;
   {  // int32_t size
      // proto buffer
     int32_t size;
diff --git a/paddle/fluid/framework/var_desc.cc b/paddle/fluid/framework/var_desc.cc
index eb881469697e1d5b44479c53091a37f8b18e4ab9..bb2be1ab50a59c23551c6f150beb33a1d2b4a5a7 100644
--- a/paddle/fluid/framework/var_desc.cc
+++ b/paddle/fluid/framework/var_desc.cc
@@ -18,18 +18,21 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-proto::VarDesc::VarType VarDesc::GetType() const { return desc_.type(); }
+proto::VarType::Type VarDesc::GetType() const { return desc_.type().type(); }
 
-void VarDesc::SetType(proto::VarDesc::VarType type) { desc_.set_type(type); }
+void VarDesc::SetType(proto::VarType::Type type) {
+  desc_.mutable_type()->set_type(type);
+}
 
 void VarDesc::SetShape(const std::vector<int64_t> &dims) {
   VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
 }
 
 void VarDesc::SetTensorDescNum(size_t num) {
-  switch (desc_.type()) {
-    case proto::VarDesc::READER: {
-      auto *lod_tensors_ptr = desc_.mutable_reader()->mutable_lod_tensor();
+  switch (desc_.type().type()) {
+    case proto::VarType::READER: {
+      auto *lod_tensors_ptr =
+          desc_.mutable_type()->mutable_reader()->mutable_lod_tensor();
       lod_tensors_ptr->Clear();
       for (size_t i = 0; i < num; ++i) {
         lod_tensors_ptr->Add();
@@ -44,9 +47,9 @@ void VarDesc::SetTensorDescNum(size_t num) {
 }
 
 size_t VarDesc::GetTensorDescNum() const {
-  switch (desc_.type()) {
-    case proto::VarDesc::READER:
-      return desc_.reader().lod_tensor_size();
+  switch (desc_.type().type()) {
+    case proto::VarType::READER:
+      return desc_.type().reader().lod_tensor_size();
       break;
     default:
       PADDLE_THROW(
@@ -64,7 +67,7 @@ void VarDesc::SetShapes(
         << "). The Reader is going to be reinitialized.";
     SetTensorDescNum(multiple_dims.size());
   }
-  std::vector<proto::TensorDesc *> tensors = mutable_tensor_descs();
+  std::vector<proto::VarType::TensorDesc *> tensors = mutable_tensor_descs();
   for (size_t i = 0; i < multiple_dims.size(); ++i) {
     VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
   }
@@ -75,7 +78,7 @@ std::vector<int64_t> VarDesc::GetShape() const {
 }
 
 std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
-  std::vector<proto::TensorDesc> descs = tensor_descs();
+  std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
   std::vector<std::vector<int64_t>> res;
   res.reserve(descs.size());
   for (const auto &tensor_desc : descs) {
@@ -98,7 +101,8 @@ void VarDesc::SetDataTypes(
         << "). The Reader is going to be reinitialized.";
     SetTensorDescNum(multiple_data_type.size());
   }
-  std::vector<proto::TensorDesc *> tensor_descs = mutable_tensor_descs();
+  std::vector<proto::VarType::TensorDesc *> tensor_descs =
+      mutable_tensor_descs();
   for (size_t i = 0; i < multiple_data_type.size(); ++i) {
     tensor_descs[i]->set_data_type(multiple_data_type[i]);
   }
@@ -109,7 +113,7 @@ proto::DataType VarDesc::GetDataType() const {
 }
 
 std::vector<proto::DataType> VarDesc::GetDataTypes() const {
-  std::vector<proto::TensorDesc> descs = tensor_descs();
+  std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
   std::vector<proto::DataType> res;
   res.reserve(descs.size());
   for (const auto &tensor_desc : descs) {
@@ -119,12 +123,12 @@ std::vector<proto::DataType> VarDesc::GetDataTypes() const {
 }
 
 void VarDesc::SetLoDLevel(int32_t lod_level) {
-  switch (desc_.type()) {
-    case proto::VarDesc::LOD_TENSOR:
-      desc_.mutable_lod_tensor()->set_lod_level(lod_level);
+  switch (desc_.type().type()) {
+    case proto::VarType::LOD_TENSOR:
+      desc_.mutable_type()->mutable_lod_tensor()->set_lod_level(lod_level);
       break;
-    case proto::VarDesc::LOD_TENSOR_ARRAY:
-      desc_.mutable_tensor_array()->set_lod_level(lod_level);
+    case proto::VarType::LOD_TENSOR_ARRAY:
+      desc_.mutable_type()->mutable_tensor_array()->set_lod_level(lod_level);
       break;
     default:
       PADDLE_THROW(
@@ -142,10 +146,11 @@ void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
         << "). The Reader is going to be reinitialized.";
     SetTensorDescNum(multiple_lod_level.size());
   }
-  switch (desc_.type()) {
-    case proto::VarDesc::READER: {
+  switch (desc_.type().type()) {
+    case proto::VarType::READER: {
       size_t i = 0;
-      for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) {
+      for (auto &lod_tensor :
+           *desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
         lod_tensor.set_lod_level(multiple_lod_level[i++]);
       }
     } break;
@@ -157,11 +162,11 @@ void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
 }
 
 int32_t VarDesc::GetLoDLevel() const {
-  switch (desc_.type()) {
-    case proto::VarDesc::LOD_TENSOR:
-      return desc_.lod_tensor().lod_level();
-    case proto::VarDesc::LOD_TENSOR_ARRAY:
-      return desc_.tensor_array().lod_level();
+  switch (desc_.type().type()) {
+    case proto::VarType::LOD_TENSOR:
+      return desc_.type().lod_tensor().lod_level();
+    case proto::VarType::LOD_TENSOR_ARRAY:
+      return desc_.type().tensor_array().lod_level();
     default:
       PADDLE_THROW(
           "Getting 'lod_level' is not supported by the type of var %s.",
@@ -171,10 +176,10 @@ int32_t VarDesc::GetLoDLevel() const {
 
 std::vector<int32_t> VarDesc::GetLoDLevels() const {
   std::vector<int32_t> res;
-  switch (desc_.type()) {
-    case proto::VarDesc::READER:
-      res.reserve(desc_.reader().lod_tensor_size());
-      for (auto &lod_tensor : desc_.reader().lod_tensor()) {
+  switch (desc_.type().type()) {
+    case proto::VarType::READER:
+      res.reserve(desc_.type().reader().lod_tensor_size());
+      for (auto &lod_tensor : desc_.type().reader().lod_tensor()) {
         res.push_back(lod_tensor.lod_level());
       }
       return res;
@@ -186,15 +191,16 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
   }
 }
 
-const proto::TensorDesc &VarDesc::tensor_desc() const {
+const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
   PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
-  switch (desc_.type()) {
-    case proto::VarDesc::SELECTED_ROWS:
-      return desc_.selected_rows();
-    case proto::VarDesc::LOD_TENSOR:
-      return desc_.lod_tensor().tensor();
-    case proto::VarDesc::LOD_TENSOR_ARRAY:
-      return desc_.tensor_array().tensor();
+  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
+  switch (desc_.type().type()) {
+    case proto::VarType::SELECTED_ROWS:
+      return desc_.type().selected_rows();
+    case proto::VarType::LOD_TENSOR:
+      return desc_.type().lod_tensor().tensor();
+    case proto::VarType::LOD_TENSOR_ARRAY:
+      return desc_.type().tensor_array().tensor();
     default:
       PADDLE_THROW(
           "Getting 'tensor_desc' is not supported by the type of var %s.",
@@ -202,13 +208,13 @@ const proto::TensorDesc &VarDesc::tensor_desc() const {
   }
 }
 
-std::vector<proto::TensorDesc> VarDesc::tensor_descs() const {
+std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
   PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
-  std::vector<proto::TensorDesc> res;
+  std::vector<proto::VarType::TensorDesc> res;
   res.reserve(GetTensorDescNum());
-  switch (desc_.type()) {
-    case proto::VarDesc::READER:
-      for (const auto &lod_tensor : desc_.reader().lod_tensor()) {
+  switch (desc_.type().type()) {
+    case proto::VarType::READER:
+      for (const auto &lod_tensor : desc_.type().reader().lod_tensor()) {
         res.push_back(lod_tensor.tensor());
       }
       return res;
@@ -220,15 +226,16 @@ std::vector<proto::TensorDesc> VarDesc::tensor_descs() const {
   }
 }
 
-proto::TensorDesc *VarDesc::mutable_tensor_desc() {
+proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
   PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
-  switch (desc_.type()) {
-    case proto::VarDesc::SELECTED_ROWS:
-      return desc_.mutable_selected_rows();
-    case proto::VarDesc::LOD_TENSOR:
-      return desc_.mutable_lod_tensor()->mutable_tensor();
-    case proto::VarDesc::LOD_TENSOR_ARRAY:
-      return desc_.mutable_tensor_array()->mutable_tensor();
+  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
+  switch (desc_.type().type()) {
+    case proto::VarType::SELECTED_ROWS:
+      return desc_.mutable_type()->mutable_selected_rows();
+    case proto::VarType::LOD_TENSOR:
+      return desc_.mutable_type()->mutable_lod_tensor()->mutable_tensor();
+    case proto::VarType::LOD_TENSOR_ARRAY:
+      return desc_.mutable_type()->mutable_tensor_array()->mutable_tensor();
     default:
       PADDLE_THROW(
           "Getting 'mutable_tensor_desc' is not supported by the type of var "
@@ -237,13 +244,15 @@ proto::TensorDesc *VarDesc::mutable_tensor_desc() {
   }
 }
 
-std::vector<proto::TensorDesc *> VarDesc::mutable_tensor_descs() {
+std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
   PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
-  std::vector<proto::TensorDesc *> res;
+  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
+  std::vector<proto::VarType::TensorDesc *> res;
   res.reserve(GetTensorDescNum());
-  switch (desc_.type()) {
-    case proto::VarDesc::READER:
-      for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) {
+  switch (desc_.type().type()) {
+    case proto::VarType::READER:
+      for (auto &lod_tensor :
+           *desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
         res.push_back(lod_tensor.mutable_tensor());
       }
       return res;
diff --git a/paddle/fluid/framework/var_desc.h b/paddle/fluid/framework/var_desc.h
index b272e5063e1543880affae6c9bb0a9aa5911d42d..013ba446b9463ce96608bb1f2b15499a3c96fec9 100644
--- a/paddle/fluid/framework/var_desc.h
+++ b/paddle/fluid/framework/var_desc.h
@@ -57,7 +57,7 @@ class VarDesc {
  public:
   explicit VarDesc(const std::string &name) {
     desc_.set_name(name);
-    desc_.set_type(proto::VarDesc::LOD_TENSOR);
+    desc_.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
   }
 
   explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {}
@@ -96,19 +96,19 @@ class VarDesc {
 
   std::vector<int32_t> GetLoDLevels() const;
 
-  proto::VarDesc::VarType GetType() const;
+  proto::VarType::Type GetType() const;
 
-  void SetType(proto::VarDesc::VarType type);
+  void SetType(proto::VarType::Type type);
 
   bool Persistable() const { return desc_.persistable(); }
 
   void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }
 
  private:
-  const proto::TensorDesc &tensor_desc() const;
-  std::vector<proto::TensorDesc> tensor_descs() const;
-  proto::TensorDesc *mutable_tensor_desc();
-  std::vector<proto::TensorDesc *> mutable_tensor_descs();
+  const proto::VarType::TensorDesc &tensor_desc() const;
+  std::vector<proto::VarType::TensorDesc> tensor_descs() const;
+  proto::VarType::TensorDesc *mutable_tensor_desc();
+  std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs();
 
   proto::VarDesc desc_;
 };
diff --git a/paddle/fluid/framework/var_type.h b/paddle/fluid/framework/var_type.h
index b5a6183892aa140b6803a951df395acb756fecfa..960ebff9d7d8a522cf37c6c413e4caa1655ea86e 100644
--- a/paddle/fluid/framework/var_type.h
+++ b/paddle/fluid/framework/var_type.h
@@ -23,17 +23,17 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-inline proto::VarDesc::VarType ToVarType(std::type_index type) {
+inline proto::VarType::Type ToVarType(std::type_index type) {
   if (type.hash_code() == typeid(LoDTensor).hash_code()) {
-    return proto::VarDesc_VarType_LOD_TENSOR;
+    return proto::VarType_Type_LOD_TENSOR;
   } else if (type.hash_code() == typeid(LoDRankTable).hash_code()) {
-    return proto::VarDesc_VarType_LOD_RANK_TABLE;
+    return proto::VarType_Type_LOD_RANK_TABLE;
   } else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) {
-    return proto::VarDesc_VarType_LOD_TENSOR_ARRAY;
+    return proto::VarType_Type_LOD_TENSOR_ARRAY;
   } else if (type.hash_code() == typeid(SelectedRows).hash_code()) {
-    return proto::VarDesc_VarType_SELECTED_ROWS;
+    return proto::VarType_Type_SELECTED_ROWS;
   } else if (type.hash_code() == typeid(ReaderHolder).hash_code()) {
-    return proto::VarDesc_VarType_READER;
+    return proto::VarType_Type_READER;
   } else {
     PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
   }
@@ -42,19 +42,19 @@ inline proto::VarDesc::VarType ToVarType(std::type_index type) {
 template <typename Visitor>
 inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
   switch (ToVarType(var.Type())) {
-    case proto::VarDesc_VarType_LOD_TENSOR:
+    case proto::VarType_Type_LOD_TENSOR:
       visitor(var.Get<LoDTensor>());
       return;
-    case proto::VarDesc_VarType_LOD_RANK_TABLE:
+    case proto::VarType_Type_LOD_RANK_TABLE:
       visitor(var.Get<LoDRankTable>());
       return;
-    case proto::VarDesc_VarType_LOD_TENSOR_ARRAY:
+    case proto::VarType_Type_LOD_TENSOR_ARRAY:
      visitor(var.Get<LoDTensorArray>());
       return;
-    case proto::VarDesc_VarType_SELECTED_ROWS:
+    case proto::VarType_Type_SELECTED_ROWS:
      visitor(var.Get<SelectedRows>());
       return;
-    case proto::VarDesc_VarType_READER:
+    case proto::VarType_Type_READER:
       visitor(var.Get<ReaderHolder>());
       return;
     default:
diff --git a/paddle/fluid/framework/var_type_inference_test.cc b/paddle/fluid/framework/var_type_inference_test.cc
index 961f209ee1c9dab9ffeb3a44d7eb4329dbacb85e..1dced845ed7849d9f5a6de16dfe627d52fdb5488 100644
--- a/paddle/fluid/framework/var_type_inference_test.cc
+++ b/paddle/fluid/framework/var_type_inference_test.cc
@@ -35,14 +35,14 @@ class SumOpVarTypeInference : public VarTypeInference {
  public:
   void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
     auto &inputs = op_desc.Input("X");
-    auto default_var_type = proto::VarDesc::SELECTED_ROWS;
+    auto default_var_type = proto::VarType::SELECTED_ROWS;
 
     bool any_input_is_lod_tensor = std::any_of(
         inputs.begin(), inputs.end(), [block](const std::string &name) {
-          return block->Var(name)->GetType() == proto::VarDesc::LOD_TENSOR;
+          return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
         });
     if (any_input_is_lod_tensor) {
-      default_var_type = proto::VarDesc::LOD_TENSOR;
+      default_var_type = proto::VarType::LOD_TENSOR;
     }
 
     auto out_var_name = op_desc.Output("Out").front();
@@ -67,19 +67,19 @@ TEST(InferVarType, sum_op) {
   op->SetInput("X", {"test_a", "test_b", "test_c"});
   op->SetOutput("Out", {"test_out"});
 
-  prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarDesc::SELECTED_ROWS);
-  prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarDesc::SELECTED_ROWS);
-  prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarDesc::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS);
   prog.MutableBlock(0)->Var("test_out");
 
   op->InferVarType(prog.MutableBlock(0));
 
-  ASSERT_EQ(proto::VarDesc::SELECTED_ROWS,
+  ASSERT_EQ(proto::VarType::SELECTED_ROWS,
             prog.MutableBlock(0)->Var("test_out")->GetType());
 
-  prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarDesc::LOD_TENSOR);
+  prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR);
   op->InferVarType(prog.MutableBlock(0));
-  ASSERT_EQ(proto::VarDesc::LOD_TENSOR,
+  ASSERT_EQ(proto::VarType::LOD_TENSOR,
             prog.MutableBlock(0)->Var("test_out")->GetType());
 }
 
@@ -90,14 +90,14 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
   op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
   op->SetOutput("Out", {"test2_out"});
 
-  prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarDesc::SELECTED_ROWS);
-  prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarDesc::SELECTED_ROWS);
-  prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarDesc::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarType::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarType::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarType::SELECTED_ROWS);
   prog.MutableBlock(0)->Var("test2_out");
 
   op->InferVarType(prog.MutableBlock(0));
 
-  ASSERT_EQ(proto::VarDesc_VarType_LOD_TENSOR,
+  ASSERT_EQ(proto::VarType_Type_LOD_TENSOR,
             prog.MutableBlock(0)->Var("test2_out")->GetType());
 }
 
diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc
index eedf6b8c66f391d2e70f231cfdc6c3cf0980600f..e21dc6d77f3f7481c7059cdc50c8df8f78c3f7e8 100644
--- a/paddle/fluid/operators/assign_op.cc
+++ b/paddle/fluid/operators/assign_op.cc
@@ -115,8 +115,8 @@ class AssignInferShape : public framework::InferShapeBase {
   void operator()(framework::InferShapeContext *context) const override {
     if (context->HasInput("X")) {
       auto type = context->GetInputsVarType("X")[0];
-      if (type == framework::proto::VarDesc_VarType_SELECTED_ROWS ||
-          type == framework::proto::VarDesc_VarType_LOD_TENSOR) {
+      if (type == framework::proto::VarType::SELECTED_ROWS ||
+          type == framework::proto::VarType::LOD_TENSOR) {
         context->SetOutputDim("Out", context->GetInputDim("X"));
       }
     }
diff --git a/paddle/fluid/operators/beam_search_decode_op.cc b/paddle/fluid/operators/beam_search_decode_op.cc
index dacb0e2681d8ccee7cd1d034a1655f0af4467999..718f469d38c3c6b7272c1531fae0a1e9ad2e8e3e 100644
--- a/paddle/fluid/operators/beam_search_decode_op.cc
+++ b/paddle/fluid/operators/beam_search_decode_op.cc
@@ -128,10 +128,10 @@ class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
   void operator()(const framework::OpDesc& op_desc,
                   framework::BlockDesc* block) const override {
     for (auto& o : op_desc.Output("SentenceIds")) {
-      block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR);
+      block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
     }
     for (auto& o : op_desc.Output("SentenceScores")) {
-      block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR);
+      block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
     }
   }
 };
diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc
index 76985ea9c2696a5bfcbc88eb00bc2cfce96c0c71..e848b1f12cb9f1ce1d37e0e0233bfc361dc35a33 100644
--- a/paddle/fluid/operators/beam_search_op.cc
+++ b/paddle/fluid/operators/beam_search_op.cc
@@ -240,10 +240,10 @@ class BeamSearchInferVarType : public framework::VarTypeInference {
   void operator()(const framework::OpDesc &op_desc,
                   framework::BlockDesc *block) const override {
     for (auto &o : op_desc.Output("selected_ids")) {
-      block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR);
+      block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
     }
     for (auto &o : op_desc.Output("selected_scores")) {
-      block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR);
+      block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
     }
   }
 };
diff --git a/paddle/fluid/operators/create_reader_op.cc b/paddle/fluid/operators/create_reader_op.cc
index 1393f1a66baaf3b53f797aa61fd42ac3cf54f8db..17ed7e24eca9ae923ef71cca7d4c0005eda271f1 100644
--- a/paddle/fluid/operators/create_reader_op.cc
+++ b/paddle/fluid/operators/create_reader_op.cc
@@ -84,7 +84,7 @@ class CreateFileReaderInferVarType : public framework::VarTypeInference {
                   framework::BlockDesc* block) const override {
     std::string reader_name = op_desc.Output("Out")[0];
     framework::VarDesc* reader = block->FindVarRecursive(reader_name);
-    reader->SetType(framework::proto::VarDesc::READER);
+    reader->SetType(framework::proto::VarType::READER);
   }
 };
 
@@ -97,7 +97,7 @@ class CreateDecoratedReaderInferVarType : public framework::VarTypeInference {
     framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name);
     std::string out_reader_name = op_desc.Output("Out")[0];
     framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name);
-    out_reader->SetType(framework::proto::VarDesc::READER);
+    out_reader->SetType(framework::proto::VarType::READER);
     out_reader->SetDataTypes(in_reader->GetDataTypes());
   }
 };
@@ -147,7 +147,7 @@ class CreateRandomDataGeneratorOpMaker
     AddComment(R"DOC(
       CreateRandomDataGenerator Operator
 
-      This Op creates a random reader. 
+      This Op creates a random reader.
       The reader generates random data instead of really reading from files.
       Generated data follow an uniform distribution between 'min' and 'max'.
     )DOC");
@@ -183,7 +183,7 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
       CreateShuffleReader Operator
 
      A shuffle reader takes another reader as its 'underlying reader'
-      and yields the underlying reader's outputs in a shuffled order. 
+      and yields the underlying reader's outputs in a shuffled order.
     )DOC");
   }
 };
@@ -218,8 +218,8 @@ class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(
       CreateBatchReader Operator
 
-      A batch reader takes another reader as its 'underlying reader', 
-      gathers the underlying reader's outputs and then yields them in batches. 
+      A batch reader takes another reader as its 'underlying reader',
+      gathers the underlying reader's outputs and then yields them in batches.
)DOC"); } }; diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index 5403dbc2a0525d75ef2af5312b397ed3daec8cda..169fd40fd950a74e61a4ed06a370f25b533957db 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -24,11 +24,11 @@ void SerializeToMessage(const std::string& name, const framework::Variable* var, msg->set_varname(name); std::ostringstream oss; switch (framework::ToVarType(var->Type())) { - case framework::proto::VarDesc_VarType_LOD_TENSOR: + case framework::proto::VarType_Type_LOD_TENSOR: msg->set_type(sendrecv::VarType::LOD_TENSOR); framework::SerializeToStream(oss, var->Get(), ctx); break; - case framework::proto::VarDesc_VarType_SELECTED_ROWS: + case framework::proto::VarType_Type_SELECTED_ROWS: msg->set_type(sendrecv::VarType::SELECTED_ROWS); framework::SerializeToStream(oss, var->Get(), ctx); diff --git a/paddle/fluid/operators/get_places_op.cc b/paddle/fluid/operators/get_places_op.cc index 8555b0778fa7b072a950a3011354764feab75f5f..9002ce4717c6e75e7204ef62094e4680bba3f88b 100644 --- a/paddle/fluid/operators/get_places_op.cc +++ b/paddle/fluid/operators/get_places_op.cc @@ -98,7 +98,7 @@ class GetPlacesInferVarType : public framework::VarTypeInference { framework::BlockDesc *block) const override { for (auto &o_name : op_desc.Output("Out")) { block->FindRecursiveOrCreateVar(o_name).SetType( - framework::proto::VarDesc::PLACE_LIST); + framework::proto::VarType::PLACE_LIST); } } }; diff --git a/paddle/fluid/operators/lod_rank_table_op.cc b/paddle/fluid/operators/lod_rank_table_op.cc index 2d01ed673763086798276ef780667e75fcd056ef..590b44e14f518c3c60c141c9a0dfe7f2b96f69c6 100644 --- a/paddle/fluid/operators/lod_rank_table_op.cc +++ b/paddle/fluid/operators/lod_rank_table_op.cc @@ -69,7 +69,7 @@ class LoDRankTableInferVarType : public framework::VarTypeInference { framework::BlockDesc *block) const override { for (auto &o : op_desc.Output("Out")) { block->FindRecursiveOrCreateVar(o).SetType( - framework::proto::VarDesc::LOD_RANK_TABLE); + framework::proto::VarType::LOD_RANK_TABLE); } } }; diff --git a/paddle/fluid/operators/lod_tensor_to_array_op.cc b/paddle/fluid/operators/lod_tensor_to_array_op.cc index be47fdfd048cfa9c8903f797f18a9b831f03055c..b5e778a58114633c96a1401df501bb2cb10022c1 100644 --- a/paddle/fluid/operators/lod_tensor_to_array_op.cc +++ b/paddle/fluid/operators/lod_tensor_to_array_op.cc @@ -138,7 +138,7 @@ class LoDTensorToArrayInferVarType : public framework::VarTypeInference { void operator()(const framework::OpDesc &op_desc, framework::BlockDesc *block) const override { for (auto &out_var : op_desc.Output("Out")) { - block->Var(out_var)->SetType(framework::proto::VarDesc::LOD_TENSOR_ARRAY); + block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR_ARRAY); } } }; diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index d338553f7c4821431fb9afcd5559bda97fa6181b..3acdca17afc2fea05fb81871e6e03d72691fe91e 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -123,11 +123,11 @@ class LookupTableOpGradVarTypeInference : public framework::VarTypeInference { VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W") << " is set to SelectedRows"; block->Var(out_var_name) - ->SetType(framework::proto::VarDesc::SELECTED_ROWS); + ->SetType(framework::proto::VarType::SELECTED_ROWS); } else { VLOG(3) << "lookup_table_grad op " << 
               << " is set to LoDTensor";
-      block->Var(out_var_name)->SetType(framework::proto::VarDesc::LOD_TENSOR);
+      block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
     }
   }
 };
diff --git a/paddle/fluid/operators/read_op.cc b/paddle/fluid/operators/read_op.cc
index 127df82ff13b89de42e45113a21d6f5e7c2f20ed..62beab82d4f2b0b795d5d32f50352172de6870cc 100644
--- a/paddle/fluid/operators/read_op.cc
+++ b/paddle/fluid/operators/read_op.cc
@@ -45,7 +45,7 @@ class ReadInferVarType : public framework::VarTypeInference {
     PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
     for (size_t i = 0; i < dtypes.size(); ++i) {
       framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]);
-      out.SetType(framework::proto::VarDesc::LOD_TENSOR);
+      out.SetType(framework::proto::VarType::LOD_TENSOR);
       out.SetDataType(dtypes[i]);
     }
   }
diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc
index bfc5709c4b7f48fc76c2a8a29cb2017d6b449c12..7b88387c3384feabdcb6b169bf2978e70a06c3fd 100644
--- a/paddle/fluid/operators/sum_op.cc
+++ b/paddle/fluid/operators/sum_op.cc
@@ -29,7 +29,7 @@ class SumOp : public framework::OperatorWithKernel {
                    "Output(Out) of SumOp should not be null.");
     if (ctx->IsRuntime() &&
         ctx->GetOutputsVarType("Out")[0] ==
-            framework::proto::VarDesc::LOD_TENSOR_ARRAY) {
+            framework::proto::VarType::LOD_TENSOR_ARRAY) {
       return;  // skip runtime infershape when is tensor array;
     }
 
@@ -118,7 +118,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
   void operator()(const framework::OpDesc& op_desc,
                   framework::BlockDesc* block) const override {
     auto& inputs = op_desc.Input("X");
-    auto var_type = framework::proto::VarDesc::SELECTED_ROWS;
+    auto var_type = framework::proto::VarType::SELECTED_ROWS;
 
     for (auto& name : op_desc.Input("X")) {
       VLOG(10) << name << " "
@@ -128,12 +128,12 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
     bool any_input_is_lod_tensor = std::any_of(
         inputs.begin(), inputs.end(), [block](const std::string& name) {
           return block->FindRecursiveOrCreateVar(name).GetType() ==
-                 framework::proto::VarDesc::LOD_TENSOR;
+                 framework::proto::VarType::LOD_TENSOR;
         });
 
     auto is_tensor_array = [block](const std::string& name) {
       return block->FindRecursiveOrCreateVar(name).GetType() ==
-             framework::proto::VarDesc::LOD_TENSOR_ARRAY;
+             framework::proto::VarType::LOD_TENSOR_ARRAY;
     };
 
     bool any_input_is_tensor_array =
@@ -151,9 +151,9 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
         PADDLE_ENFORCE(all_inputs_are_tensor_array,
                        "Not all inputs are tensor array:\n%s", os.str());
       }
-      var_type = framework::proto::VarDesc::LOD_TENSOR_ARRAY;
+      var_type = framework::proto::VarType::LOD_TENSOR_ARRAY;
     } else if (any_input_is_lod_tensor) {
-      var_type = framework::proto::VarDesc::LOD_TENSOR;
+      var_type = framework::proto::VarType::LOD_TENSOR;
     }
 
     auto out_var_name = op_desc.Output("Out").front();
diff --git a/paddle/fluid/operators/tensor_array_read_write_op.cc b/paddle/fluid/operators/tensor_array_read_write_op.cc
index 278b3481176e8dd82dc5a6fed376b1cb59e102c6..9b484cda1216ee2138ed2dc6cbedd5f825fe0be9 100644
--- a/paddle/fluid/operators/tensor_array_read_write_op.cc
+++ b/paddle/fluid/operators/tensor_array_read_write_op.cc
@@ -108,7 +108,7 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
     auto out_name = op_desc.Output("Out")[0];
     VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
     auto &out = block->FindRecursiveOrCreateVar(out_name);
-    out.SetType(framework::proto::VarDesc::LOD_TENSOR_ARRAY);
+    out.SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
     auto *x = block->FindVarRecursive(x_name);
     if (x != nullptr) {
       out.SetDataType(x->GetDataType());
diff --git a/paddle/fluid/operators/while_op.cc b/paddle/fluid/operators/while_op.cc
index 94a11eaf78249c0ca44411abdbc65e0a74a54c78..3d5cdeda26ada94fbd8e6a7c25995aa7de93fb3d 100644
--- a/paddle/fluid/operators/while_op.cc
+++ b/paddle/fluid/operators/while_op.cc
@@ -330,10 +330,10 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
         continue;
       }
      auto dims = ctx->GetInputsElementDim(kX, i);
-      if (var_types[i] == framework::proto::VarDesc::LOD_TENSOR) {
+      if (var_types[i] == framework::proto::VarType::LOD_TENSOR) {
         names_to_set.push_back(pg_names[i]);
         dims_to_set.push_back(dims);
-      } else if (var_types[i] == framework::proto::VarDesc::LOD_TENSOR_ARRAY) {
+      } else if (var_types[i] == framework::proto::VarType::LOD_TENSOR_ARRAY) {
         // not sure how to set the dim of LOD_TENSOR_ARRAY
         names_to_set.push_back(pg_names[i]);
         dims_to_set.push_back(dims);
diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc
index 3341edb370f3c2eed2bf34a5fd7411f1c608d9fe..9f97cc5007ec00d0ad28ba755d24896017c9003f 100644
--- a/paddle/fluid/pybind/protobuf.cc
+++ b/paddle/fluid/pybind/protobuf.cc
@@ -232,16 +232,16 @@ void BindVarDsec(py::module &m) {
       .def("persistable", &VarDesc::Persistable)
       .def("set_persistable", &VarDesc::SetPersistable);
 
-  py::enum_<proto::VarDesc::VarType>(var_desc, "VarType", "")
-      .value("LOD_TENSOR", proto::VarDesc::LOD_TENSOR)
-      .value("SELECTED_ROWS", proto::VarDesc::SELECTED_ROWS)
-      .value("FEED_MINIBATCH", proto::VarDesc::FEED_MINIBATCH)
-      .value("FETCH_LIST", proto::VarDesc::FETCH_LIST)
-      .value("STEP_SCOPES", proto::VarDesc::STEP_SCOPES)
-      .value("LOD_RANK_TABLE", proto::VarDesc::LOD_RANK_TABLE)
-      .value("LOD_TENSOR_ARRAY", proto::VarDesc::LOD_TENSOR_ARRAY)
-      .value("PLACE_LIST", proto::VarDesc::PLACE_LIST)
-      .value("READER", proto::VarDesc::READER);
+  py::enum_<proto::VarType::Type>(var_desc, "VarType", "")
+      .value("LOD_TENSOR", proto::VarType::LOD_TENSOR)
+      .value("SELECTED_ROWS", proto::VarType::SELECTED_ROWS)
+      .value("FEED_MINIBATCH", proto::VarType::FEED_MINIBATCH)
+      .value("FETCH_LIST", proto::VarType::FETCH_LIST)
+      .value("STEP_SCOPES", proto::VarType::STEP_SCOPES)
+      .value("LOD_RANK_TABLE", proto::VarType::LOD_RANK_TABLE)
+      .value("LOD_TENSOR_ARRAY", proto::VarType::LOD_TENSOR_ARRAY)
+      .value("PLACE_LIST", proto::VarType::PLACE_LIST)
+      .value("READER", proto::VarType::READER);
 }
 
 void BindOpDesc(py::module &m) {
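
Reviewer note (illustrative sketch, not part of the patch): the refactor nests
the old top-level TensorDesc/LoDTensorDesc/LoDTensorArrayDesc/ReaderDesc
messages inside the new VarType message, so VarDesc now reaches both the enum
value and the per-type tensor metadata through one extra level of indirection.
Below is a minimal C++ sketch of the before/after access pattern, using only
the generated protobuf accessors that the diff itself relies on (it assumes a
framework.pb.h generated from the new framework.proto):

    #include "paddle/fluid/framework/framework.pb.h"

    using paddle::framework::proto::VarDesc;
    using paddle::framework::proto::VarType;

    void Sketch() {
      VarDesc desc;
      desc.set_name("X");

      // Before this patch, the enum value lived directly on VarDesc:
      //   desc.set_type(proto::VarDesc::LOD_TENSOR);
      //   proto::VarDesc::VarType t = desc.type();

      // After: VarDesc holds a VarType submessage whose 'type' field is the
      // enum, hence the desc_.type().type() / mutable_type()->set_type(...)
      // pattern that now appears throughout var_desc.cc.
      desc.mutable_type()->set_type(VarType::LOD_TENSOR);
      VarType::Type t = desc.type().type();
      (void)t;

      // Tensor metadata moves with it: lod_tensor is now a field of VarType.
      desc.mutable_type()->mutable_lod_tensor()->set_lod_level(0);
    }

This is why every call site changes from desc_.mutable_lod_tensor() to
desc_.mutable_type()->mutable_lod_tensor(), and why VarDesc::GetType() becomes
desc_.type().type().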