From fcadb452515b6acabf9daf987b4a92dcb8e62d73 Mon Sep 17 00:00:00 2001 From: Abhinav Arora Date: Mon, 12 Feb 2018 16:05:41 -0800 Subject: [PATCH] Separate VarType from VarDesc in framework.proto and fix all related compiler errors (#8414) * Refine Type system * Fixing type inference * Fixed create_reader_op.cc * Fix var_desc.h * Fixed executor.cc * Fix shape_inference.h * Fixed create_reader_op.cc * Fix tensor_util.h * Fixed var_type_inference_test.cc * Fix shape_inference.cc * Fixed sum_op.c * Fixed read_op.cc * Fix var_type.h * Fixed beam_search_decode_op.cc * sendrecvop_utils.cc * Fix operator.cc * Fixed lookup_table_op.cc * Fixed op_desc.cc * Fixed get_places_op.cc * Fixed lod_rank_table_op.cc * Fixed beam_search_op.cc * Fix var_desc.cc * Fixed lod_tensor_to_array_op.cc * Fixed while_op.cc * Fix program_desc_test.cc * tensor_array_read_write_op.cc * Fix assign_op.cc * Fix executor.cc * Fix protobuf.cc * Fix protobuf.cc --- paddle/fluid/framework/executor.cc | 28 ++--- paddle/fluid/framework/framework.proto | 51 ++++---- paddle/fluid/framework/op_desc.cc | 10 +- paddle/fluid/framework/operator.cc | 2 +- paddle/fluid/framework/program_desc_test.cc | 12 +- paddle/fluid/framework/shape_inference.cc | 8 +- paddle/fluid/framework/shape_inference.h | 8 +- paddle/fluid/framework/tensor_util.h | 4 +- paddle/fluid/framework/var_desc.cc | 119 ++++++++++-------- paddle/fluid/framework/var_desc.h | 14 +-- paddle/fluid/framework/var_type.h | 22 ++-- .../framework/var_type_inference_test.cc | 26 ++-- paddle/fluid/operators/assign_op.cc | 4 +- .../fluid/operators/beam_search_decode_op.cc | 4 +- paddle/fluid/operators/beam_search_op.cc | 4 +- paddle/fluid/operators/create_reader_op.cc | 12 +- .../operators/detail/sendrecvop_utils.cc | 4 +- paddle/fluid/operators/get_places_op.cc | 2 +- paddle/fluid/operators/lod_rank_table_op.cc | 2 +- .../fluid/operators/lod_tensor_to_array_op.cc | 2 +- paddle/fluid/operators/lookup_table_op.cc | 4 +- paddle/fluid/operators/read_op.cc | 2 +- paddle/fluid/operators/sum_op.cc | 12 +- .../operators/tensor_array_read_write_op.cc | 2 +- paddle/fluid/operators/while_op.cc | 4 +- paddle/fluid/pybind/protobuf.cc | 20 +-- 26 files changed, 198 insertions(+), 184 deletions(-) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 179f9194a9..ebfd54fdc5 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -36,24 +36,24 @@ namespace framework { Executor::Executor(const platform::Place& place) : place_(place) {} -static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) { - if (var_type == proto::VarDesc::LOD_TENSOR) { +static void CreateTensor(Variable* var, proto::VarType::Type var_type) { + if (var_type == proto::VarType::LOD_TENSOR) { var->GetMutable(); - } else if (var_type == proto::VarDesc::SELECTED_ROWS) { + } else if (var_type == proto::VarType::SELECTED_ROWS) { var->GetMutable(); - } else if (var_type == proto::VarDesc::FEED_MINIBATCH) { + } else if (var_type == proto::VarType::FEED_MINIBATCH) { var->GetMutable(); - } else if (var_type == proto::VarDesc::FETCH_LIST) { + } else if (var_type == proto::VarType::FETCH_LIST) { var->GetMutable(); - } else if (var_type == proto::VarDesc::STEP_SCOPES) { + } else if (var_type == proto::VarType::STEP_SCOPES) { var->GetMutable>(); - } else if (var_type == proto::VarDesc::LOD_RANK_TABLE) { + } else if (var_type == proto::VarType::LOD_RANK_TABLE) { var->GetMutable(); - } else if (var_type == proto::VarDesc::LOD_TENSOR_ARRAY) { + } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) { var->GetMutable(); - } else if (var_type == proto::VarDesc::PLACE_LIST) { + } else if (var_type == proto::VarType::PLACE_LIST) { var->GetMutable(); - } else if (var_type == proto::VarDesc::READER) { + } else if (var_type == proto::VarType::READER) { var->GetMutable(); } else { PADDLE_THROW( @@ -182,7 +182,7 @@ static bool has_feed_operators( auto var = block->FindVar(feed_holder_name); PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", feed_holder_name); - PADDLE_ENFORCE_EQ(var->GetType(), proto::VarDesc::FEED_MINIBATCH, + PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH, "'%s' variable should be 'FEED_MINIBATCH' type", feed_holder_name); } @@ -222,7 +222,7 @@ static bool has_fetch_operators( auto var = block->FindVar(fetch_holder_name); PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", fetch_holder_name); - PADDLE_ENFORCE_EQ(var->GetType(), proto::VarDesc::FETCH_LIST, + PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST, "'%s' variable should be 'FETCH_LIST' type", fetch_holder_name); } @@ -241,7 +241,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) { // create feed_holder variable auto* feed_holder = global_block->Var(feed_holder_name); - feed_holder->SetType(proto::VarDesc::FEED_MINIBATCH); + feed_holder->SetType(proto::VarType::FEED_MINIBATCH); feed_holder->SetPersistable(true); int i = 0; @@ -274,7 +274,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) { // create fetch_holder variable auto* fetch_holder = global_block->Var(fetch_holder_name); - fetch_holder->SetType(proto::VarDesc::FETCH_LIST); + fetch_holder->SetType(proto::VarType::FETCH_LIST); fetch_holder->SetPersistable(true); int i = 0; diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index ad8da21ae0..fa7f437851 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -101,25 +101,8 @@ enum DataType { FP64 = 6; } -message TensorDesc { - required DataType data_type = 1; - repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480] -} - -message LoDTensorDesc { - required TensorDesc tensor = 1; - optional int32 lod_level = 2 [ default = 0 ]; -} - -message LoDTensorArrayDesc { - required TensorDesc tensor = 1; - optional int32 lod_level = 2 [ default = 0 ]; -} - -message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; } - -message VarDesc { - enum VarType { +message VarType { + enum Type { LOD_TENSOR = 1; SELECTED_ROWS = 2; FEED_MINIBATCH = 3; @@ -130,13 +113,35 @@ message VarDesc { PLACE_LIST = 8; READER = 9; } + + required Type type = 1; + + message TensorDesc { + required DataType data_type = 1; + repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480] + } + optional TensorDesc selected_rows = 2; + + message LoDTensorDesc { + required TensorDesc tensor = 1; + optional int32 lod_level = 2 [ default = 0 ]; + } + optional LoDTensorDesc lod_tensor = 3; + + message LoDTensorArrayDesc { + required TensorDesc tensor = 1; + optional int32 lod_level = 2 [ default = 0 ]; + } + optional LoDTensorArrayDesc tensor_array = 4; + + message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; } + optional ReaderDesc reader = 5; +} + +message VarDesc { required string name = 1; required VarType type = 2; optional bool persistable = 3 [ default = false ]; - optional LoDTensorDesc lod_tensor = 4; - optional TensorDesc selected_rows = 5; - optional LoDTensorArrayDesc tensor_array = 6; - optional ReaderDesc reader = 7; } message BlockDesc { diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index e740010c63..eabfdc11a8 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -53,11 +53,11 @@ class CompileTimeInferShapeContext : public InferShapeContext { PADDLE_ENFORCE_LT(j, Outputs(out).size()); auto *in_var = block_.FindVarRecursive(Inputs(in)[i]); auto *out_var = block_.FindVarRecursive(Outputs(out)[j]); - if (in_var->GetType() != proto::VarDesc::LOD_TENSOR) { + if (in_var->GetType() != proto::VarType::LOD_TENSOR) { VLOG(3) << "input " << in << " is not LodTensor"; return; } - PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarDesc::LOD_TENSOR, + PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarType::LOD_TENSOR, "The %d-th output of Output(%s) must be LoDTensor.", j, out); out_var->SetLoDLevel(in_var->GetLoDLevel()); @@ -66,7 +66,7 @@ class CompileTimeInferShapeContext : public InferShapeContext { bool IsRuntime() const override; protected: - proto::VarDesc::VarType GetVarType(const std::string &name) const override; + proto::VarType::Type GetVarType(const std::string &name) const override; DDim GetDim(const std::string &name) const override; @@ -388,7 +388,7 @@ void OpDesc::InferVarType(BlockDesc *block) const { for (auto &out_pair : this->outputs_) { for (auto &out_var_name : out_pair.second) { block->FindRecursiveOrCreateVar(out_var_name) - .SetType(proto::VarDesc::LOD_TENSOR); + .SetType(proto::VarType::LOD_TENSOR); } } } @@ -507,7 +507,7 @@ void CompileTimeInferShapeContext::SetRepeatedDims( bool CompileTimeInferShapeContext::IsRuntime() const { return false; } -proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType( +proto::VarType::Type CompileTimeInferShapeContext::GetVarType( const std::string &name) const { return block_.FindVarRecursive(name)->GetType(); } diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index bc529b8269..ff90aba10b 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -477,7 +477,7 @@ class RuntimeInferShapeContext : public InferShapeContext { } } - proto::VarDesc::VarType GetVarType(const std::string& name) const override { + proto::VarType::Type GetVarType(const std::string& name) const override { auto* var = scope_.FindVar(name); return ToVarType(var->Type()); } diff --git a/paddle/fluid/framework/program_desc_test.cc b/paddle/fluid/framework/program_desc_test.cc index 3a4a87cfa5..d9c4331da1 100644 --- a/paddle/fluid/framework/program_desc_test.cc +++ b/paddle/fluid/framework/program_desc_test.cc @@ -22,13 +22,13 @@ TEST(ProgramDesc, copy_ctor) { ProgramDesc program; auto* global_block = program.MutableBlock(0); auto* x = global_block->Var("X"); - x->SetType(proto::VarDesc_VarType_LOD_TENSOR); + x->SetType(proto::VarType::LOD_TENSOR); x->SetLoDLevel(0); x->SetDataType(proto::FP32); x->SetShape({1000, 784}); auto* y = global_block->Var("Y"); - y->SetType(proto::VarDesc_VarType_LOD_TENSOR); + y->SetType(proto::VarType::LOD_TENSOR); y->SetLoDLevel(0); y->SetDataType(proto::FP32); y->SetShape({784, 100}); @@ -39,7 +39,7 @@ TEST(ProgramDesc, copy_ctor) { op->SetInput("Y", {y->Name()}); auto* out = global_block->Var("Out"); - out->SetType(proto::VarDesc_VarType_LOD_TENSOR); + out->SetType(proto::VarType::LOD_TENSOR); op->SetOutput("Y", {out->Name()}); ProgramDesc program_copy(program); @@ -84,13 +84,13 @@ TEST(ProgramDescBind, serialize_and_deserialize) { ProgramDesc program_origin; auto* global_block = program_origin.MutableBlock(0); auto* x = global_block->Var("X"); - x->SetType(proto::VarDesc_VarType_LOD_TENSOR); + x->SetType(proto::VarType::LOD_TENSOR); x->SetLoDLevel(0); x->SetDataType(proto::FP32); x->SetShape({1000, 784}); auto* y = global_block->Var("Y"); - y->SetType(proto::VarDesc_VarType_LOD_TENSOR); + y->SetType(proto::VarType::LOD_TENSOR); y->SetLoDLevel(0); y->SetDataType(proto::FP32); y->SetShape({784, 100}); @@ -101,7 +101,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) { op->SetInput("Y", {y->Name()}); auto* out = global_block->Var("Out"); - out->SetType(proto::VarDesc_VarType_LOD_TENSOR); + out->SetType(proto::VarType::LOD_TENSOR); op->SetOutput("Y", {out->Name()}); std::string binary_str; diff --git a/paddle/fluid/framework/shape_inference.cc b/paddle/fluid/framework/shape_inference.cc index 1b518970ac..dc9a79020f 100644 --- a/paddle/fluid/framework/shape_inference.cc +++ b/paddle/fluid/framework/shape_inference.cc @@ -116,19 +116,19 @@ void InferShapeContext::SetDims(const std::vector &names, } } -std::vector InferShapeContext::GetInputsVarType( +std::vector InferShapeContext::GetInputsVarType( const std::string &name) const { return GetVarTypes(Inputs(name)); } -std::vector InferShapeContext::GetOutputsVarType( +std::vector InferShapeContext::GetOutputsVarType( const std::string &name) const { return GetVarTypes(Outputs(name)); } -std::vector InferShapeContext::GetVarTypes( +std::vector InferShapeContext::GetVarTypes( const std::vector &names) const { - std::vector retv; + std::vector retv; retv.resize(names.size()); std::transform(names.begin(), names.end(), retv.begin(), std::bind(std::mem_fn(&InferShapeContext::GetVarType), this, diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index 3739d640fe..bc02d700da 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -31,9 +31,9 @@ class InferShapeContext { virtual bool HasInput(const std::string &name) const = 0; virtual bool HasOutput(const std::string &name) const = 0; - std::vector GetInputsVarType( + std::vector GetInputsVarType( const std::string &name) const; - std::vector GetOutputsVarType( + std::vector GetOutputsVarType( const std::string &name) const; virtual bool HasInputs(const std::string &name) const = 0; @@ -75,10 +75,10 @@ class InferShapeContext { std::vector GetDims(const std::vector &names) const; - std::vector GetVarTypes( + std::vector GetVarTypes( const std::vector &names) const; - virtual proto::VarDesc::VarType GetVarType(const std::string &name) const = 0; + virtual proto::VarType::Type GetVarType(const std::string &name) const = 0; virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0; }; diff --git a/paddle/fluid/framework/tensor_util.h b/paddle/fluid/framework/tensor_util.h index 22519013cc..f0464d4807 100644 --- a/paddle/fluid/framework/tensor_util.h +++ b/paddle/fluid/framework/tensor_util.h @@ -225,7 +225,7 @@ inline void SerializeToStream(std::ostream& os, const Tensor& tensor, { // the 2nd field, tensor description // int32_t size // void* protobuf message - proto::TensorDesc desc; + proto::VarType::TensorDesc desc; desc.set_data_type(framework::ToDataType(tensor.type())); auto dims = framework::vectorize(tensor.dims()); auto* pb_dims = desc.mutable_dims(); @@ -290,7 +290,7 @@ inline void DeserializeFromStream(std::istream& is, Tensor* tensor, uint32_t version; is.read(reinterpret_cast(&version), sizeof(version)); PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported"); - proto::TensorDesc desc; + proto::VarType::TensorDesc desc; { // int32_t size // proto buffer int32_t size; diff --git a/paddle/fluid/framework/var_desc.cc b/paddle/fluid/framework/var_desc.cc index eb88146969..bb2be1ab50 100644 --- a/paddle/fluid/framework/var_desc.cc +++ b/paddle/fluid/framework/var_desc.cc @@ -18,18 +18,21 @@ limitations under the License. */ namespace paddle { namespace framework { -proto::VarDesc::VarType VarDesc::GetType() const { return desc_.type(); } +proto::VarType::Type VarDesc::GetType() const { return desc_.type().type(); } -void VarDesc::SetType(proto::VarDesc::VarType type) { desc_.set_type(type); } +void VarDesc::SetType(proto::VarType::Type type) { + desc_.mutable_type()->set_type(type); +} void VarDesc::SetShape(const std::vector &dims) { VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims()); } void VarDesc::SetTensorDescNum(size_t num) { - switch (desc_.type()) { - case proto::VarDesc::READER: { - auto *lod_tensors_ptr = desc_.mutable_reader()->mutable_lod_tensor(); + switch (desc_.type().type()) { + case proto::VarType::READER: { + auto *lod_tensors_ptr = + desc_.mutable_type()->mutable_reader()->mutable_lod_tensor(); lod_tensors_ptr->Clear(); for (size_t i = 0; i < num; ++i) { lod_tensors_ptr->Add(); @@ -44,9 +47,9 @@ void VarDesc::SetTensorDescNum(size_t num) { } size_t VarDesc::GetTensorDescNum() const { - switch (desc_.type()) { - case proto::VarDesc::READER: - return desc_.reader().lod_tensor_size(); + switch (desc_.type().type()) { + case proto::VarType::READER: + return desc_.type().reader().lod_tensor_size(); break; default: PADDLE_THROW( @@ -64,7 +67,7 @@ void VarDesc::SetShapes( << "). The Reader is going to be reinitialized."; SetTensorDescNum(multiple_dims.size()); } - std::vector tensors = mutable_tensor_descs(); + std::vector tensors = mutable_tensor_descs(); for (size_t i = 0; i < multiple_dims.size(); ++i) { VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims()); } @@ -75,7 +78,7 @@ std::vector VarDesc::GetShape() const { } std::vector> VarDesc::GetShapes() const { - std::vector descs = tensor_descs(); + std::vector descs = tensor_descs(); std::vector> res; res.reserve(descs.size()); for (const auto &tensor_desc : descs) { @@ -98,7 +101,8 @@ void VarDesc::SetDataTypes( << "). The Reader is going to be reinitialized."; SetTensorDescNum(multiple_data_type.size()); } - std::vector tensor_descs = mutable_tensor_descs(); + std::vector tensor_descs = + mutable_tensor_descs(); for (size_t i = 0; i < multiple_data_type.size(); ++i) { tensor_descs[i]->set_data_type(multiple_data_type[i]); } @@ -109,7 +113,7 @@ proto::DataType VarDesc::GetDataType() const { } std::vector VarDesc::GetDataTypes() const { - std::vector descs = tensor_descs(); + std::vector descs = tensor_descs(); std::vector res; res.reserve(descs.size()); for (const auto &tensor_desc : descs) { @@ -119,12 +123,12 @@ std::vector VarDesc::GetDataTypes() const { } void VarDesc::SetLoDLevel(int32_t lod_level) { - switch (desc_.type()) { - case proto::VarDesc::LOD_TENSOR: - desc_.mutable_lod_tensor()->set_lod_level(lod_level); + switch (desc_.type().type()) { + case proto::VarType::LOD_TENSOR: + desc_.mutable_type()->mutable_lod_tensor()->set_lod_level(lod_level); break; - case proto::VarDesc::LOD_TENSOR_ARRAY: - desc_.mutable_tensor_array()->set_lod_level(lod_level); + case proto::VarType::LOD_TENSOR_ARRAY: + desc_.mutable_type()->mutable_tensor_array()->set_lod_level(lod_level); break; default: PADDLE_THROW( @@ -142,10 +146,11 @@ void VarDesc::SetLoDLevels(const std::vector &multiple_lod_level) { << "). The Reader is going to be reinitialized."; SetTensorDescNum(multiple_lod_level.size()); } - switch (desc_.type()) { - case proto::VarDesc::READER: { + switch (desc_.type().type()) { + case proto::VarType::READER: { size_t i = 0; - for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) { + for (auto &lod_tensor : + *desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) { lod_tensor.set_lod_level(multiple_lod_level[i++]); } } break; @@ -157,11 +162,11 @@ void VarDesc::SetLoDLevels(const std::vector &multiple_lod_level) { } int32_t VarDesc::GetLoDLevel() const { - switch (desc_.type()) { - case proto::VarDesc::LOD_TENSOR: - return desc_.lod_tensor().lod_level(); - case proto::VarDesc::LOD_TENSOR_ARRAY: - return desc_.tensor_array().lod_level(); + switch (desc_.type().type()) { + case proto::VarType::LOD_TENSOR: + return desc_.type().lod_tensor().lod_level(); + case proto::VarType::LOD_TENSOR_ARRAY: + return desc_.type().tensor_array().lod_level(); default: PADDLE_THROW( "Getting 'lod_level' is not supported by the type of var %s.", @@ -171,10 +176,10 @@ int32_t VarDesc::GetLoDLevel() const { std::vector VarDesc::GetLoDLevels() const { std::vector res; - switch (desc_.type()) { - case proto::VarDesc::READER: - res.reserve(desc_.reader().lod_tensor_size()); - for (auto &lod_tensor : desc_.reader().lod_tensor()) { + switch (desc_.type().type()) { + case proto::VarType::READER: + res.reserve(desc_.type().reader().lod_tensor_size()); + for (auto &lod_tensor : desc_.type().reader().lod_tensor()) { res.push_back(lod_tensor.lod_level()); } return res; @@ -186,15 +191,16 @@ std::vector VarDesc::GetLoDLevels() const { } } -const proto::TensorDesc &VarDesc::tensor_desc() const { +const proto::VarType::TensorDesc &VarDesc::tensor_desc() const { PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set."); - switch (desc_.type()) { - case proto::VarDesc::SELECTED_ROWS: - return desc_.selected_rows(); - case proto::VarDesc::LOD_TENSOR: - return desc_.lod_tensor().tensor(); - case proto::VarDesc::LOD_TENSOR_ARRAY: - return desc_.tensor_array().tensor(); + PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set."); + switch (desc_.type().type()) { + case proto::VarType::SELECTED_ROWS: + return desc_.type().selected_rows(); + case proto::VarType::LOD_TENSOR: + return desc_.type().lod_tensor().tensor(); + case proto::VarType::LOD_TENSOR_ARRAY: + return desc_.type().tensor_array().tensor(); default: PADDLE_THROW( "Getting 'tensor_desc' is not supported by the type of var %s.", @@ -202,13 +208,13 @@ const proto::TensorDesc &VarDesc::tensor_desc() const { } } -std::vector VarDesc::tensor_descs() const { +std::vector VarDesc::tensor_descs() const { PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set."); - std::vector res; + std::vector res; res.reserve(GetTensorDescNum()); - switch (desc_.type()) { - case proto::VarDesc::READER: - for (const auto &lod_tensor : desc_.reader().lod_tensor()) { + switch (desc_.type().type()) { + case proto::VarType::READER: + for (const auto &lod_tensor : desc_.type().reader().lod_tensor()) { res.push_back(lod_tensor.tensor()); } return res; @@ -220,15 +226,16 @@ std::vector VarDesc::tensor_descs() const { } } -proto::TensorDesc *VarDesc::mutable_tensor_desc() { +proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() { PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set."); - switch (desc_.type()) { - case proto::VarDesc::SELECTED_ROWS: - return desc_.mutable_selected_rows(); - case proto::VarDesc::LOD_TENSOR: - return desc_.mutable_lod_tensor()->mutable_tensor(); - case proto::VarDesc::LOD_TENSOR_ARRAY: - return desc_.mutable_tensor_array()->mutable_tensor(); + PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set."); + switch (desc_.type().type()) { + case proto::VarType::SELECTED_ROWS: + return desc_.mutable_type()->mutable_selected_rows(); + case proto::VarType::LOD_TENSOR: + return desc_.mutable_type()->mutable_lod_tensor()->mutable_tensor(); + case proto::VarType::LOD_TENSOR_ARRAY: + return desc_.mutable_type()->mutable_tensor_array()->mutable_tensor(); default: PADDLE_THROW( "Getting 'mutable_tensor_desc' is not supported by the type of var " @@ -237,13 +244,15 @@ proto::TensorDesc *VarDesc::mutable_tensor_desc() { } } -std::vector VarDesc::mutable_tensor_descs() { +std::vector VarDesc::mutable_tensor_descs() { PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set."); - std::vector res; + PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set."); + std::vector res; res.reserve(GetTensorDescNum()); - switch (desc_.type()) { - case proto::VarDesc::READER: - for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) { + switch (desc_.type().type()) { + case proto::VarType::READER: + for (auto &lod_tensor : + *desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) { res.push_back(lod_tensor.mutable_tensor()); } return res; diff --git a/paddle/fluid/framework/var_desc.h b/paddle/fluid/framework/var_desc.h index b272e5063e..013ba446b9 100644 --- a/paddle/fluid/framework/var_desc.h +++ b/paddle/fluid/framework/var_desc.h @@ -57,7 +57,7 @@ class VarDesc { public: explicit VarDesc(const std::string &name) { desc_.set_name(name); - desc_.set_type(proto::VarDesc::LOD_TENSOR); + desc_.mutable_type()->set_type(proto::VarType::LOD_TENSOR); } explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {} @@ -96,19 +96,19 @@ class VarDesc { std::vector GetLoDLevels() const; - proto::VarDesc::VarType GetType() const; + proto::VarType::Type GetType() const; - void SetType(proto::VarDesc::VarType type); + void SetType(proto::VarType::Type type); bool Persistable() const { return desc_.persistable(); } void SetPersistable(bool persistable) { desc_.set_persistable(persistable); } private: - const proto::TensorDesc &tensor_desc() const; - std::vector tensor_descs() const; - proto::TensorDesc *mutable_tensor_desc(); - std::vector mutable_tensor_descs(); + const proto::VarType::TensorDesc &tensor_desc() const; + std::vector tensor_descs() const; + proto::VarType::TensorDesc *mutable_tensor_desc(); + std::vector mutable_tensor_descs(); proto::VarDesc desc_; }; diff --git a/paddle/fluid/framework/var_type.h b/paddle/fluid/framework/var_type.h index b5a6183892..960ebff9d7 100644 --- a/paddle/fluid/framework/var_type.h +++ b/paddle/fluid/framework/var_type.h @@ -23,17 +23,17 @@ limitations under the License. */ namespace paddle { namespace framework { -inline proto::VarDesc::VarType ToVarType(std::type_index type) { +inline proto::VarType::Type ToVarType(std::type_index type) { if (type.hash_code() == typeid(LoDTensor).hash_code()) { - return proto::VarDesc_VarType_LOD_TENSOR; + return proto::VarType_Type_LOD_TENSOR; } else if (type.hash_code() == typeid(LoDRankTable).hash_code()) { - return proto::VarDesc_VarType_LOD_RANK_TABLE; + return proto::VarType_Type_LOD_RANK_TABLE; } else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) { - return proto::VarDesc_VarType_LOD_TENSOR_ARRAY; + return proto::VarType_Type_LOD_TENSOR_ARRAY; } else if (type.hash_code() == typeid(SelectedRows).hash_code()) { - return proto::VarDesc_VarType_SELECTED_ROWS; + return proto::VarType_Type_SELECTED_ROWS; } else if (type.hash_code() == typeid(ReaderHolder).hash_code()) { - return proto::VarDesc_VarType_READER; + return proto::VarType_Type_READER; } else { PADDLE_THROW("ToVarType:Unsupported type %s", type.name()); } @@ -42,19 +42,19 @@ inline proto::VarDesc::VarType ToVarType(std::type_index type) { template inline void VisitVarType(const framework::Variable& var, Visitor visitor) { switch (ToVarType(var.Type())) { - case proto::VarDesc_VarType_LOD_TENSOR: + case proto::VarType_Type_LOD_TENSOR: visitor(var.Get()); return; - case proto::VarDesc_VarType_LOD_RANK_TABLE: + case proto::VarType_Type_LOD_RANK_TABLE: visitor(var.Get()); return; - case proto::VarDesc_VarType_LOD_TENSOR_ARRAY: + case proto::VarType_Type_LOD_TENSOR_ARRAY: visitor(var.Get()); return; - case proto::VarDesc_VarType_SELECTED_ROWS: + case proto::VarType_Type_SELECTED_ROWS: visitor(var.Get()); return; - case proto::VarDesc_VarType_READER: + case proto::VarType_Type_READER: visitor(var.Get()); return; default: diff --git a/paddle/fluid/framework/var_type_inference_test.cc b/paddle/fluid/framework/var_type_inference_test.cc index 961f209ee1..1dced845ed 100644 --- a/paddle/fluid/framework/var_type_inference_test.cc +++ b/paddle/fluid/framework/var_type_inference_test.cc @@ -35,14 +35,14 @@ class SumOpVarTypeInference : public VarTypeInference { public: void operator()(const OpDesc &op_desc, BlockDesc *block) const override { auto &inputs = op_desc.Input("X"); - auto default_var_type = proto::VarDesc::SELECTED_ROWS; + auto default_var_type = proto::VarType::SELECTED_ROWS; bool any_input_is_lod_tensor = std::any_of( inputs.begin(), inputs.end(), [block](const std::string &name) { - return block->Var(name)->GetType() == proto::VarDesc::LOD_TENSOR; + return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR; }); if (any_input_is_lod_tensor) { - default_var_type = proto::VarDesc::LOD_TENSOR; + default_var_type = proto::VarType::LOD_TENSOR; } auto out_var_name = op_desc.Output("Out").front(); @@ -67,19 +67,19 @@ TEST(InferVarType, sum_op) { op->SetInput("X", {"test_a", "test_b", "test_c"}); op->SetOutput("Out", {"test_out"}); - prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarDesc::SELECTED_ROWS); - prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarDesc::SELECTED_ROWS); - prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarDesc::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS); prog.MutableBlock(0)->Var("test_out"); op->InferVarType(prog.MutableBlock(0)); - ASSERT_EQ(proto::VarDesc::SELECTED_ROWS, + ASSERT_EQ(proto::VarType::SELECTED_ROWS, prog.MutableBlock(0)->Var("test_out")->GetType()); - prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarDesc::LOD_TENSOR); + prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR); op->InferVarType(prog.MutableBlock(0)); - ASSERT_EQ(proto::VarDesc::LOD_TENSOR, + ASSERT_EQ(proto::VarType::LOD_TENSOR, prog.MutableBlock(0)->Var("test_out")->GetType()); } @@ -90,14 +90,14 @@ TEST(InferVarType, sum_op_without_infer_var_type) { op->SetInput("X", {"test2_a", "test2_b", "test2_c"}); op->SetOutput("Out", {"test2_out"}); - prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarDesc::SELECTED_ROWS); - prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarDesc::SELECTED_ROWS); - prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarDesc::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarType::SELECTED_ROWS); prog.MutableBlock(0)->Var("test2_out"); op->InferVarType(prog.MutableBlock(0)); - ASSERT_EQ(proto::VarDesc_VarType_LOD_TENSOR, + ASSERT_EQ(proto::VarType_Type_LOD_TENSOR, prog.MutableBlock(0)->Var("test2_out")->GetType()); } diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc index eedf6b8c66..e21dc6d77f 100644 --- a/paddle/fluid/operators/assign_op.cc +++ b/paddle/fluid/operators/assign_op.cc @@ -115,8 +115,8 @@ class AssignInferShape : public framework::InferShapeBase { void operator()(framework::InferShapeContext *context) const override { if (context->HasInput("X")) { auto type = context->GetInputsVarType("X")[0]; - if (type == framework::proto::VarDesc_VarType_SELECTED_ROWS || - type == framework::proto::VarDesc_VarType_LOD_TENSOR) { + if (type == framework::proto::VarType::SELECTED_ROWS || + type == framework::proto::VarType::LOD_TENSOR) { context->SetOutputDim("Out", context->GetInputDim("X")); } } diff --git a/paddle/fluid/operators/beam_search_decode_op.cc b/paddle/fluid/operators/beam_search_decode_op.cc index dacb0e2681..718f469d38 100644 --- a/paddle/fluid/operators/beam_search_decode_op.cc +++ b/paddle/fluid/operators/beam_search_decode_op.cc @@ -128,10 +128,10 @@ class BeamSearchDecodeInferVarType : public framework::VarTypeInference { void operator()(const framework::OpDesc& op_desc, framework::BlockDesc* block) const override { for (auto& o : op_desc.Output("SentenceIds")) { - block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR); + block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR); } for (auto& o : op_desc.Output("SentenceScores")) { - block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR); + block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR); } } }; diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index 76985ea9c2..e848b1f12c 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -240,10 +240,10 @@ class BeamSearchInferVarType : public framework::VarTypeInference { void operator()(const framework::OpDesc &op_desc, framework::BlockDesc *block) const override { for (auto &o : op_desc.Output("selected_ids")) { - block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR); + block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR); } for (auto &o : op_desc.Output("selected_scores")) { - block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR); + block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR); } } }; diff --git a/paddle/fluid/operators/create_reader_op.cc b/paddle/fluid/operators/create_reader_op.cc index 1393f1a66b..17ed7e24ec 100644 --- a/paddle/fluid/operators/create_reader_op.cc +++ b/paddle/fluid/operators/create_reader_op.cc @@ -84,7 +84,7 @@ class CreateFileReaderInferVarType : public framework::VarTypeInference { framework::BlockDesc* block) const override { std::string reader_name = op_desc.Output("Out")[0]; framework::VarDesc* reader = block->FindVarRecursive(reader_name); - reader->SetType(framework::proto::VarDesc::READER); + reader->SetType(framework::proto::VarType::READER); } }; @@ -97,7 +97,7 @@ class CreateDecoratedReaderInferVarType : public framework::VarTypeInference { framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name); std::string out_reader_name = op_desc.Output("Out")[0]; framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name); - out_reader->SetType(framework::proto::VarDesc::READER); + out_reader->SetType(framework::proto::VarType::READER); out_reader->SetDataTypes(in_reader->GetDataTypes()); } }; @@ -147,7 +147,7 @@ class CreateRandomDataGeneratorOpMaker AddComment(R"DOC( CreateRandomDataGenerator Operator - This Op creates a random reader. + This Op creates a random reader. The reader generates random data instead of really reading from files. Generated data follow an uniform distribution between 'min' and 'max'. )DOC"); @@ -183,7 +183,7 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker { CreateShuffleReader Operator A shuffle reader takes another reader as its 'underlying reader' - and yields the underlying reader's outputs in a shuffled order. + and yields the underlying reader's outputs in a shuffled order. )DOC"); } }; @@ -218,8 +218,8 @@ class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( CreateBatchReader Operator - A batch reader takes another reader as its 'underlying reader', - gathers the underlying reader's outputs and then yields them in batches. + A batch reader takes another reader as its 'underlying reader', + gathers the underlying reader's outputs and then yields them in batches. )DOC"); } }; diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index 5403dbc2a0..169fd40fd9 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -24,11 +24,11 @@ void SerializeToMessage(const std::string& name, const framework::Variable* var, msg->set_varname(name); std::ostringstream oss; switch (framework::ToVarType(var->Type())) { - case framework::proto::VarDesc_VarType_LOD_TENSOR: + case framework::proto::VarType_Type_LOD_TENSOR: msg->set_type(sendrecv::VarType::LOD_TENSOR); framework::SerializeToStream(oss, var->Get(), ctx); break; - case framework::proto::VarDesc_VarType_SELECTED_ROWS: + case framework::proto::VarType_Type_SELECTED_ROWS: msg->set_type(sendrecv::VarType::SELECTED_ROWS); framework::SerializeToStream(oss, var->Get(), ctx); diff --git a/paddle/fluid/operators/get_places_op.cc b/paddle/fluid/operators/get_places_op.cc index 8555b0778f..9002ce4717 100644 --- a/paddle/fluid/operators/get_places_op.cc +++ b/paddle/fluid/operators/get_places_op.cc @@ -98,7 +98,7 @@ class GetPlacesInferVarType : public framework::VarTypeInference { framework::BlockDesc *block) const override { for (auto &o_name : op_desc.Output("Out")) { block->FindRecursiveOrCreateVar(o_name).SetType( - framework::proto::VarDesc::PLACE_LIST); + framework::proto::VarType::PLACE_LIST); } } }; diff --git a/paddle/fluid/operators/lod_rank_table_op.cc b/paddle/fluid/operators/lod_rank_table_op.cc index 2d01ed6737..590b44e14f 100644 --- a/paddle/fluid/operators/lod_rank_table_op.cc +++ b/paddle/fluid/operators/lod_rank_table_op.cc @@ -69,7 +69,7 @@ class LoDRankTableInferVarType : public framework::VarTypeInference { framework::BlockDesc *block) const override { for (auto &o : op_desc.Output("Out")) { block->FindRecursiveOrCreateVar(o).SetType( - framework::proto::VarDesc::LOD_RANK_TABLE); + framework::proto::VarType::LOD_RANK_TABLE); } } }; diff --git a/paddle/fluid/operators/lod_tensor_to_array_op.cc b/paddle/fluid/operators/lod_tensor_to_array_op.cc index be47fdfd04..b5e778a581 100644 --- a/paddle/fluid/operators/lod_tensor_to_array_op.cc +++ b/paddle/fluid/operators/lod_tensor_to_array_op.cc @@ -138,7 +138,7 @@ class LoDTensorToArrayInferVarType : public framework::VarTypeInference { void operator()(const framework::OpDesc &op_desc, framework::BlockDesc *block) const override { for (auto &out_var : op_desc.Output("Out")) { - block->Var(out_var)->SetType(framework::proto::VarDesc::LOD_TENSOR_ARRAY); + block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR_ARRAY); } } }; diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index d338553f7c..3acdca17af 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -123,11 +123,11 @@ class LookupTableOpGradVarTypeInference : public framework::VarTypeInference { VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W") << " is set to SelectedRows"; block->Var(out_var_name) - ->SetType(framework::proto::VarDesc::SELECTED_ROWS); + ->SetType(framework::proto::VarType::SELECTED_ROWS); } else { VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W") << " is set to LoDTensor"; - block->Var(out_var_name)->SetType(framework::proto::VarDesc::LOD_TENSOR); + block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR); } } }; diff --git a/paddle/fluid/operators/read_op.cc b/paddle/fluid/operators/read_op.cc index 127df82ff1..62beab82d4 100644 --- a/paddle/fluid/operators/read_op.cc +++ b/paddle/fluid/operators/read_op.cc @@ -45,7 +45,7 @@ class ReadInferVarType : public framework::VarTypeInference { PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size()); for (size_t i = 0; i < dtypes.size(); ++i) { framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]); - out.SetType(framework::proto::VarDesc::LOD_TENSOR); + out.SetType(framework::proto::VarType::LOD_TENSOR); out.SetDataType(dtypes[i]); } } diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index bfc5709c4b..7b88387c33 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -29,7 +29,7 @@ class SumOp : public framework::OperatorWithKernel { "Output(Out) of SumOp should not be null."); if (ctx->IsRuntime() && ctx->GetOutputsVarType("Out")[0] == - framework::proto::VarDesc::LOD_TENSOR_ARRAY) { + framework::proto::VarType::LOD_TENSOR_ARRAY) { return; // skip runtime infershape when is tensor array; } @@ -118,7 +118,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference { void operator()(const framework::OpDesc& op_desc, framework::BlockDesc* block) const override { auto& inputs = op_desc.Input("X"); - auto var_type = framework::proto::VarDesc::SELECTED_ROWS; + auto var_type = framework::proto::VarType::SELECTED_ROWS; for (auto& name : op_desc.Input("X")) { VLOG(10) << name << " " @@ -128,12 +128,12 @@ class SumOpVarTypeInference : public framework::VarTypeInference { bool any_input_is_lod_tensor = std::any_of( inputs.begin(), inputs.end(), [block](const std::string& name) { return block->FindRecursiveOrCreateVar(name).GetType() == - framework::proto::VarDesc::LOD_TENSOR; + framework::proto::VarType::LOD_TENSOR; }); auto is_tensor_array = [block](const std::string& name) { return block->FindRecursiveOrCreateVar(name).GetType() == - framework::proto::VarDesc::LOD_TENSOR_ARRAY; + framework::proto::VarType::LOD_TENSOR_ARRAY; }; bool any_input_is_tensor_array = @@ -151,9 +151,9 @@ class SumOpVarTypeInference : public framework::VarTypeInference { PADDLE_ENFORCE(all_inputs_are_tensor_array, "Not all inputs are tensor array:\n%s", os.str()); } - var_type = framework::proto::VarDesc::LOD_TENSOR_ARRAY; + var_type = framework::proto::VarType::LOD_TENSOR_ARRAY; } else if (any_input_is_lod_tensor) { - var_type = framework::proto::VarDesc::LOD_TENSOR; + var_type = framework::proto::VarType::LOD_TENSOR; } auto out_var_name = op_desc.Output("Out").front(); diff --git a/paddle/fluid/operators/tensor_array_read_write_op.cc b/paddle/fluid/operators/tensor_array_read_write_op.cc index 278b348117..9b484cda12 100644 --- a/paddle/fluid/operators/tensor_array_read_write_op.cc +++ b/paddle/fluid/operators/tensor_array_read_write_op.cc @@ -108,7 +108,7 @@ class WriteToArrayInferVarType : public framework::VarTypeInference { auto out_name = op_desc.Output("Out")[0]; VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY"; auto &out = block->FindRecursiveOrCreateVar(out_name); - out.SetType(framework::proto::VarDesc::LOD_TENSOR_ARRAY); + out.SetType(framework::proto::VarType::LOD_TENSOR_ARRAY); auto *x = block->FindVarRecursive(x_name); if (x != nullptr) { out.SetDataType(x->GetDataType()); diff --git a/paddle/fluid/operators/while_op.cc b/paddle/fluid/operators/while_op.cc index 94a11eaf78..3d5cdeda26 100644 --- a/paddle/fluid/operators/while_op.cc +++ b/paddle/fluid/operators/while_op.cc @@ -330,10 +330,10 @@ class WhileGradOpShapeInference : public framework::InferShapeBase { continue; } auto dims = ctx->GetInputsElementDim(kX, i); - if (var_types[i] == framework::proto::VarDesc::LOD_TENSOR) { + if (var_types[i] == framework::proto::VarType::LOD_TENSOR) { names_to_set.push_back(pg_names[i]); dims_to_set.push_back(dims); - } else if (var_types[i] == framework::proto::VarDesc::LOD_TENSOR_ARRAY) { + } else if (var_types[i] == framework::proto::VarType::LOD_TENSOR_ARRAY) { // not sure how to set the dim of LOD_TENSOR_ARRAY names_to_set.push_back(pg_names[i]); dims_to_set.push_back(dims); diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 3341edb370..9f97cc5007 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -232,16 +232,16 @@ void BindVarDsec(py::module &m) { .def("persistable", &VarDesc::Persistable) .def("set_persistable", &VarDesc::SetPersistable); - py::enum_(var_desc, "VarType", "") - .value("LOD_TENSOR", proto::VarDesc::LOD_TENSOR) - .value("SELECTED_ROWS", proto::VarDesc::SELECTED_ROWS) - .value("FEED_MINIBATCH", proto::VarDesc::FEED_MINIBATCH) - .value("FETCH_LIST", proto::VarDesc::FETCH_LIST) - .value("STEP_SCOPES", proto::VarDesc::STEP_SCOPES) - .value("LOD_RANK_TABLE", proto::VarDesc::LOD_RANK_TABLE) - .value("LOD_TENSOR_ARRAY", proto::VarDesc::LOD_TENSOR_ARRAY) - .value("PLACE_LIST", proto::VarDesc::PLACE_LIST) - .value("READER", proto::VarDesc::READER); + py::enum_(var_desc, "VarType", "") + .value("LOD_TENSOR", proto::VarType::LOD_TENSOR) + .value("SELECTED_ROWS", proto::VarType::SELECTED_ROWS) + .value("FEED_MINIBATCH", proto::VarType::FEED_MINIBATCH) + .value("FETCH_LIST", proto::VarType::FETCH_LIST) + .value("STEP_SCOPES", proto::VarType::STEP_SCOPES) + .value("LOD_RANK_TABLE", proto::VarType::LOD_RANK_TABLE) + .value("LOD_TENSOR_ARRAY", proto::VarType::LOD_TENSOR_ARRAY) + .value("PLACE_LIST", proto::VarType::PLACE_LIST) + .value("READER", proto::VarType::READER); } void BindOpDesc(py::module &m) { -- GitLab