提交 fcadb452 编写于 作者: A Abhinav Arora 提交者: Yi Wang

Separate VarType from VarDesc in framework.proto and fix all related compiler errors (#8414)

* Refine Type system

* Fixing type inference

* Fixed create_reader_op.cc

* Fix var_desc.h

* Fixed executor.cc

* Fix shape_inference.h

* Fixed create_reader_op.cc

* Fix tensor_util.h

* Fixed var_type_inference_test.cc

* Fix shape_inference.cc

* Fixed sum_op.c

* Fixed read_op.cc

* Fix var_type.h

* Fixed beam_search_decode_op.cc

* sendrecvop_utils.cc

* Fix operator.cc

* Fixed lookup_table_op.cc

* Fixed op_desc.cc

* Fixed get_places_op.cc

* Fixed lod_rank_table_op.cc

* Fixed beam_search_op.cc

* Fix var_desc.cc

* Fixed lod_tensor_to_array_op.cc

* Fixed while_op.cc

* Fix program_desc_test.cc

* tensor_array_read_write_op.cc

* Fix assign_op.cc

* Fix executor.cc

* Fix protobuf.cc

* Fix protobuf.cc
上级 f82fa64a
...@@ -36,24 +36,24 @@ namespace framework { ...@@ -36,24 +36,24 @@ namespace framework {
Executor::Executor(const platform::Place& place) : place_(place) {} Executor::Executor(const platform::Place& place) : place_(place) {}
static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) { static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarDesc::LOD_TENSOR) { if (var_type == proto::VarType::LOD_TENSOR) {
var->GetMutable<LoDTensor>(); var->GetMutable<LoDTensor>();
} else if (var_type == proto::VarDesc::SELECTED_ROWS) { } else if (var_type == proto::VarType::SELECTED_ROWS) {
var->GetMutable<SelectedRows>(); var->GetMutable<SelectedRows>();
} else if (var_type == proto::VarDesc::FEED_MINIBATCH) { } else if (var_type == proto::VarType::FEED_MINIBATCH) {
var->GetMutable<FeedFetchList>(); var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarDesc::FETCH_LIST) { } else if (var_type == proto::VarType::FETCH_LIST) {
var->GetMutable<FeedFetchList>(); var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarDesc::STEP_SCOPES) { } else if (var_type == proto::VarType::STEP_SCOPES) {
var->GetMutable<std::vector<framework::Scope>>(); var->GetMutable<std::vector<framework::Scope>>();
} else if (var_type == proto::VarDesc::LOD_RANK_TABLE) { } else if (var_type == proto::VarType::LOD_RANK_TABLE) {
var->GetMutable<LoDRankTable>(); var->GetMutable<LoDRankTable>();
} else if (var_type == proto::VarDesc::LOD_TENSOR_ARRAY) { } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
var->GetMutable<LoDTensorArray>(); var->GetMutable<LoDTensorArray>();
} else if (var_type == proto::VarDesc::PLACE_LIST) { } else if (var_type == proto::VarType::PLACE_LIST) {
var->GetMutable<platform::PlaceList>(); var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarDesc::READER) { } else if (var_type == proto::VarType::READER) {
var->GetMutable<ReaderHolder>(); var->GetMutable<ReaderHolder>();
} else { } else {
PADDLE_THROW( PADDLE_THROW(
...@@ -182,7 +182,7 @@ static bool has_feed_operators( ...@@ -182,7 +182,7 @@ static bool has_feed_operators(
auto var = block->FindVar(feed_holder_name); auto var = block->FindVar(feed_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
feed_holder_name); feed_holder_name);
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarDesc::FEED_MINIBATCH, PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
"'%s' variable should be 'FEED_MINIBATCH' type", "'%s' variable should be 'FEED_MINIBATCH' type",
feed_holder_name); feed_holder_name);
} }
...@@ -222,7 +222,7 @@ static bool has_fetch_operators( ...@@ -222,7 +222,7 @@ static bool has_fetch_operators(
auto var = block->FindVar(fetch_holder_name); auto var = block->FindVar(fetch_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
fetch_holder_name); fetch_holder_name);
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarDesc::FETCH_LIST, PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
"'%s' variable should be 'FETCH_LIST' type", "'%s' variable should be 'FETCH_LIST' type",
fetch_holder_name); fetch_holder_name);
} }
...@@ -241,7 +241,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -241,7 +241,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) { if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) {
// create feed_holder variable // create feed_holder variable
auto* feed_holder = global_block->Var(feed_holder_name); auto* feed_holder = global_block->Var(feed_holder_name);
feed_holder->SetType(proto::VarDesc::FEED_MINIBATCH); feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
feed_holder->SetPersistable(true); feed_holder->SetPersistable(true);
int i = 0; int i = 0;
...@@ -274,7 +274,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -274,7 +274,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) { if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) {
// create fetch_holder variable // create fetch_holder variable
auto* fetch_holder = global_block->Var(fetch_holder_name); auto* fetch_holder = global_block->Var(fetch_holder_name);
fetch_holder->SetType(proto::VarDesc::FETCH_LIST); fetch_holder->SetType(proto::VarType::FETCH_LIST);
fetch_holder->SetPersistable(true); fetch_holder->SetPersistable(true);
int i = 0; int i = 0;
......
...@@ -101,42 +101,47 @@ enum DataType { ...@@ -101,42 +101,47 @@ enum DataType {
FP64 = 6; FP64 = 6;
} }
message TensorDesc { message VarType {
enum Type {
LOD_TENSOR = 1;
SELECTED_ROWS = 2;
FEED_MINIBATCH = 3;
FETCH_LIST = 4;
STEP_SCOPES = 5;
LOD_RANK_TABLE = 6;
LOD_TENSOR_ARRAY = 7;
PLACE_LIST = 8;
READER = 9;
}
required Type type = 1;
message TensorDesc {
required DataType data_type = 1; required DataType data_type = 1;
repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480] repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
} }
optional TensorDesc selected_rows = 2;
message LoDTensorDesc { message LoDTensorDesc {
required TensorDesc tensor = 1; required TensorDesc tensor = 1;
optional int32 lod_level = 2 [ default = 0 ]; optional int32 lod_level = 2 [ default = 0 ];
} }
optional LoDTensorDesc lod_tensor = 3;
message LoDTensorArrayDesc { message LoDTensorArrayDesc {
required TensorDesc tensor = 1; required TensorDesc tensor = 1;
optional int32 lod_level = 2 [ default = 0 ]; optional int32 lod_level = 2 [ default = 0 ];
} }
optional LoDTensorArrayDesc tensor_array = 4;
message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; } message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; }
optional ReaderDesc reader = 5;
}
message VarDesc { message VarDesc {
enum VarType {
LOD_TENSOR = 1;
SELECTED_ROWS = 2;
FEED_MINIBATCH = 3;
FETCH_LIST = 4;
STEP_SCOPES = 5;
LOD_RANK_TABLE = 6;
LOD_TENSOR_ARRAY = 7;
PLACE_LIST = 8;
READER = 9;
}
required string name = 1; required string name = 1;
required VarType type = 2; required VarType type = 2;
optional bool persistable = 3 [ default = false ]; optional bool persistable = 3 [ default = false ];
optional LoDTensorDesc lod_tensor = 4;
optional TensorDesc selected_rows = 5;
optional LoDTensorArrayDesc tensor_array = 6;
optional ReaderDesc reader = 7;
} }
message BlockDesc { message BlockDesc {
......
...@@ -53,11 +53,11 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -53,11 +53,11 @@ class CompileTimeInferShapeContext : public InferShapeContext {
PADDLE_ENFORCE_LT(j, Outputs(out).size()); PADDLE_ENFORCE_LT(j, Outputs(out).size());
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]); auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]); auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
if (in_var->GetType() != proto::VarDesc::LOD_TENSOR) { if (in_var->GetType() != proto::VarType::LOD_TENSOR) {
VLOG(3) << "input " << in << " is not LodTensor"; VLOG(3) << "input " << in << " is not LodTensor";
return; return;
} }
PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarDesc::LOD_TENSOR, PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarType::LOD_TENSOR,
"The %d-th output of Output(%s) must be LoDTensor.", j, "The %d-th output of Output(%s) must be LoDTensor.", j,
out); out);
out_var->SetLoDLevel(in_var->GetLoDLevel()); out_var->SetLoDLevel(in_var->GetLoDLevel());
...@@ -66,7 +66,7 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -66,7 +66,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool IsRuntime() const override; bool IsRuntime() const override;
protected: protected:
proto::VarDesc::VarType GetVarType(const std::string &name) const override; proto::VarType::Type GetVarType(const std::string &name) const override;
DDim GetDim(const std::string &name) const override; DDim GetDim(const std::string &name) const override;
...@@ -388,7 +388,7 @@ void OpDesc::InferVarType(BlockDesc *block) const { ...@@ -388,7 +388,7 @@ void OpDesc::InferVarType(BlockDesc *block) const {
for (auto &out_pair : this->outputs_) { for (auto &out_pair : this->outputs_) {
for (auto &out_var_name : out_pair.second) { for (auto &out_var_name : out_pair.second) {
block->FindRecursiveOrCreateVar(out_var_name) block->FindRecursiveOrCreateVar(out_var_name)
.SetType(proto::VarDesc::LOD_TENSOR); .SetType(proto::VarType::LOD_TENSOR);
} }
} }
} }
...@@ -507,7 +507,7 @@ void CompileTimeInferShapeContext::SetRepeatedDims( ...@@ -507,7 +507,7 @@ void CompileTimeInferShapeContext::SetRepeatedDims(
bool CompileTimeInferShapeContext::IsRuntime() const { return false; } bool CompileTimeInferShapeContext::IsRuntime() const { return false; }
proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType( proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
const std::string &name) const { const std::string &name) const {
return block_.FindVarRecursive(name)->GetType(); return block_.FindVarRecursive(name)->GetType();
} }
......
...@@ -477,7 +477,7 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -477,7 +477,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
} }
} }
proto::VarDesc::VarType GetVarType(const std::string& name) const override { proto::VarType::Type GetVarType(const std::string& name) const override {
auto* var = scope_.FindVar(name); auto* var = scope_.FindVar(name);
return ToVarType(var->Type()); return ToVarType(var->Type());
} }
......
...@@ -22,13 +22,13 @@ TEST(ProgramDesc, copy_ctor) { ...@@ -22,13 +22,13 @@ TEST(ProgramDesc, copy_ctor) {
ProgramDesc program; ProgramDesc program;
auto* global_block = program.MutableBlock(0); auto* global_block = program.MutableBlock(0);
auto* x = global_block->Var("X"); auto* x = global_block->Var("X");
x->SetType(proto::VarDesc_VarType_LOD_TENSOR); x->SetType(proto::VarType::LOD_TENSOR);
x->SetLoDLevel(0); x->SetLoDLevel(0);
x->SetDataType(proto::FP32); x->SetDataType(proto::FP32);
x->SetShape({1000, 784}); x->SetShape({1000, 784});
auto* y = global_block->Var("Y"); auto* y = global_block->Var("Y");
y->SetType(proto::VarDesc_VarType_LOD_TENSOR); y->SetType(proto::VarType::LOD_TENSOR);
y->SetLoDLevel(0); y->SetLoDLevel(0);
y->SetDataType(proto::FP32); y->SetDataType(proto::FP32);
y->SetShape({784, 100}); y->SetShape({784, 100});
...@@ -39,7 +39,7 @@ TEST(ProgramDesc, copy_ctor) { ...@@ -39,7 +39,7 @@ TEST(ProgramDesc, copy_ctor) {
op->SetInput("Y", {y->Name()}); op->SetInput("Y", {y->Name()});
auto* out = global_block->Var("Out"); auto* out = global_block->Var("Out");
out->SetType(proto::VarDesc_VarType_LOD_TENSOR); out->SetType(proto::VarType::LOD_TENSOR);
op->SetOutput("Y", {out->Name()}); op->SetOutput("Y", {out->Name()});
ProgramDesc program_copy(program); ProgramDesc program_copy(program);
...@@ -84,13 +84,13 @@ TEST(ProgramDescBind, serialize_and_deserialize) { ...@@ -84,13 +84,13 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
ProgramDesc program_origin; ProgramDesc program_origin;
auto* global_block = program_origin.MutableBlock(0); auto* global_block = program_origin.MutableBlock(0);
auto* x = global_block->Var("X"); auto* x = global_block->Var("X");
x->SetType(proto::VarDesc_VarType_LOD_TENSOR); x->SetType(proto::VarType::LOD_TENSOR);
x->SetLoDLevel(0); x->SetLoDLevel(0);
x->SetDataType(proto::FP32); x->SetDataType(proto::FP32);
x->SetShape({1000, 784}); x->SetShape({1000, 784});
auto* y = global_block->Var("Y"); auto* y = global_block->Var("Y");
y->SetType(proto::VarDesc_VarType_LOD_TENSOR); y->SetType(proto::VarType::LOD_TENSOR);
y->SetLoDLevel(0); y->SetLoDLevel(0);
y->SetDataType(proto::FP32); y->SetDataType(proto::FP32);
y->SetShape({784, 100}); y->SetShape({784, 100});
...@@ -101,7 +101,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) { ...@@ -101,7 +101,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
op->SetInput("Y", {y->Name()}); op->SetInput("Y", {y->Name()});
auto* out = global_block->Var("Out"); auto* out = global_block->Var("Out");
out->SetType(proto::VarDesc_VarType_LOD_TENSOR); out->SetType(proto::VarType::LOD_TENSOR);
op->SetOutput("Y", {out->Name()}); op->SetOutput("Y", {out->Name()});
std::string binary_str; std::string binary_str;
......
...@@ -116,19 +116,19 @@ void InferShapeContext::SetDims(const std::vector<std::string> &names, ...@@ -116,19 +116,19 @@ void InferShapeContext::SetDims(const std::vector<std::string> &names,
} }
} }
std::vector<proto::VarDesc::VarType> InferShapeContext::GetInputsVarType( std::vector<proto::VarType::Type> InferShapeContext::GetInputsVarType(
const std::string &name) const { const std::string &name) const {
return GetVarTypes(Inputs(name)); return GetVarTypes(Inputs(name));
} }
std::vector<proto::VarDesc::VarType> InferShapeContext::GetOutputsVarType( std::vector<proto::VarType::Type> InferShapeContext::GetOutputsVarType(
const std::string &name) const { const std::string &name) const {
return GetVarTypes(Outputs(name)); return GetVarTypes(Outputs(name));
} }
std::vector<proto::VarDesc::VarType> InferShapeContext::GetVarTypes( std::vector<proto::VarType::Type> InferShapeContext::GetVarTypes(
const std::vector<std::string> &names) const { const std::vector<std::string> &names) const {
std::vector<proto::VarDesc::VarType> retv; std::vector<proto::VarType::Type> retv;
retv.resize(names.size()); retv.resize(names.size());
std::transform(names.begin(), names.end(), retv.begin(), std::transform(names.begin(), names.end(), retv.begin(),
std::bind(std::mem_fn(&InferShapeContext::GetVarType), this, std::bind(std::mem_fn(&InferShapeContext::GetVarType), this,
......
...@@ -31,9 +31,9 @@ class InferShapeContext { ...@@ -31,9 +31,9 @@ class InferShapeContext {
virtual bool HasInput(const std::string &name) const = 0; virtual bool HasInput(const std::string &name) const = 0;
virtual bool HasOutput(const std::string &name) const = 0; virtual bool HasOutput(const std::string &name) const = 0;
std::vector<proto::VarDesc::VarType> GetInputsVarType( std::vector<proto::VarType::Type> GetInputsVarType(
const std::string &name) const; const std::string &name) const;
std::vector<proto::VarDesc::VarType> GetOutputsVarType( std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string &name) const; const std::string &name) const;
virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasInputs(const std::string &name) const = 0;
...@@ -75,10 +75,10 @@ class InferShapeContext { ...@@ -75,10 +75,10 @@ class InferShapeContext {
std::vector<DDim> GetDims(const std::vector<std::string> &names) const; std::vector<DDim> GetDims(const std::vector<std::string> &names) const;
std::vector<proto::VarDesc::VarType> GetVarTypes( std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<std::string> &names) const; const std::vector<std::string> &names) const;
virtual proto::VarDesc::VarType GetVarType(const std::string &name) const = 0; virtual proto::VarType::Type GetVarType(const std::string &name) const = 0;
virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0; virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0;
}; };
......
...@@ -225,7 +225,7 @@ inline void SerializeToStream(std::ostream& os, const Tensor& tensor, ...@@ -225,7 +225,7 @@ inline void SerializeToStream(std::ostream& os, const Tensor& tensor,
{ // the 2nd field, tensor description { // the 2nd field, tensor description
// int32_t size // int32_t size
// void* protobuf message // void* protobuf message
proto::TensorDesc desc; proto::VarType::TensorDesc desc;
desc.set_data_type(framework::ToDataType(tensor.type())); desc.set_data_type(framework::ToDataType(tensor.type()));
auto dims = framework::vectorize(tensor.dims()); auto dims = framework::vectorize(tensor.dims());
auto* pb_dims = desc.mutable_dims(); auto* pb_dims = desc.mutable_dims();
...@@ -290,7 +290,7 @@ inline void DeserializeFromStream(std::istream& is, Tensor* tensor, ...@@ -290,7 +290,7 @@ inline void DeserializeFromStream(std::istream& is, Tensor* tensor,
uint32_t version; uint32_t version;
is.read(reinterpret_cast<char*>(&version), sizeof(version)); is.read(reinterpret_cast<char*>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported"); PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
proto::TensorDesc desc; proto::VarType::TensorDesc desc;
{ // int32_t size { // int32_t size
// proto buffer // proto buffer
int32_t size; int32_t size;
......
...@@ -18,18 +18,21 @@ limitations under the License. */ ...@@ -18,18 +18,21 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
proto::VarDesc::VarType VarDesc::GetType() const { return desc_.type(); } proto::VarType::Type VarDesc::GetType() const { return desc_.type().type(); }
void VarDesc::SetType(proto::VarDesc::VarType type) { desc_.set_type(type); } void VarDesc::SetType(proto::VarType::Type type) {
desc_.mutable_type()->set_type(type);
}
void VarDesc::SetShape(const std::vector<int64_t> &dims) { void VarDesc::SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims()); VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
} }
void VarDesc::SetTensorDescNum(size_t num) { void VarDesc::SetTensorDescNum(size_t num) {
switch (desc_.type()) { switch (desc_.type().type()) {
case proto::VarDesc::READER: { case proto::VarType::READER: {
auto *lod_tensors_ptr = desc_.mutable_reader()->mutable_lod_tensor(); auto *lod_tensors_ptr =
desc_.mutable_type()->mutable_reader()->mutable_lod_tensor();
lod_tensors_ptr->Clear(); lod_tensors_ptr->Clear();
for (size_t i = 0; i < num; ++i) { for (size_t i = 0; i < num; ++i) {
lod_tensors_ptr->Add(); lod_tensors_ptr->Add();
...@@ -44,9 +47,9 @@ void VarDesc::SetTensorDescNum(size_t num) { ...@@ -44,9 +47,9 @@ void VarDesc::SetTensorDescNum(size_t num) {
} }
size_t VarDesc::GetTensorDescNum() const { size_t VarDesc::GetTensorDescNum() const {
switch (desc_.type()) { switch (desc_.type().type()) {
case proto::VarDesc::READER: case proto::VarType::READER:
return desc_.reader().lod_tensor_size(); return desc_.type().reader().lod_tensor_size();
break; break;
default: default:
PADDLE_THROW( PADDLE_THROW(
...@@ -64,7 +67,7 @@ void VarDesc::SetShapes( ...@@ -64,7 +67,7 @@ void VarDesc::SetShapes(
<< "). The Reader is going to be reinitialized."; << "). The Reader is going to be reinitialized.";
SetTensorDescNum(multiple_dims.size()); SetTensorDescNum(multiple_dims.size());
} }
std::vector<proto::TensorDesc *> tensors = mutable_tensor_descs(); std::vector<proto::VarType::TensorDesc *> tensors = mutable_tensor_descs();
for (size_t i = 0; i < multiple_dims.size(); ++i) { for (size_t i = 0; i < multiple_dims.size(); ++i) {
VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims()); VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
} }
...@@ -75,7 +78,7 @@ std::vector<int64_t> VarDesc::GetShape() const { ...@@ -75,7 +78,7 @@ std::vector<int64_t> VarDesc::GetShape() const {
} }
std::vector<std::vector<int64_t>> VarDesc::GetShapes() const { std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
std::vector<proto::TensorDesc> descs = tensor_descs(); std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
std::vector<std::vector<int64_t>> res; std::vector<std::vector<int64_t>> res;
res.reserve(descs.size()); res.reserve(descs.size());
for (const auto &tensor_desc : descs) { for (const auto &tensor_desc : descs) {
...@@ -98,7 +101,8 @@ void VarDesc::SetDataTypes( ...@@ -98,7 +101,8 @@ void VarDesc::SetDataTypes(
<< "). The Reader is going to be reinitialized."; << "). The Reader is going to be reinitialized.";
SetTensorDescNum(multiple_data_type.size()); SetTensorDescNum(multiple_data_type.size());
} }
std::vector<proto::TensorDesc *> tensor_descs = mutable_tensor_descs(); std::vector<proto::VarType::TensorDesc *> tensor_descs =
mutable_tensor_descs();
for (size_t i = 0; i < multiple_data_type.size(); ++i) { for (size_t i = 0; i < multiple_data_type.size(); ++i) {
tensor_descs[i]->set_data_type(multiple_data_type[i]); tensor_descs[i]->set_data_type(multiple_data_type[i]);
} }
...@@ -109,7 +113,7 @@ proto::DataType VarDesc::GetDataType() const { ...@@ -109,7 +113,7 @@ proto::DataType VarDesc::GetDataType() const {
} }
std::vector<proto::DataType> VarDesc::GetDataTypes() const { std::vector<proto::DataType> VarDesc::GetDataTypes() const {
std::vector<proto::TensorDesc> descs = tensor_descs(); std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
std::vector<proto::DataType> res; std::vector<proto::DataType> res;
res.reserve(descs.size()); res.reserve(descs.size());
for (const auto &tensor_desc : descs) { for (const auto &tensor_desc : descs) {
...@@ -119,12 +123,12 @@ std::vector<proto::DataType> VarDesc::GetDataTypes() const { ...@@ -119,12 +123,12 @@ std::vector<proto::DataType> VarDesc::GetDataTypes() const {
} }
void VarDesc::SetLoDLevel(int32_t lod_level) { void VarDesc::SetLoDLevel(int32_t lod_level) {
switch (desc_.type()) { switch (desc_.type().type()) {
case proto::VarDesc::LOD_TENSOR: case proto::VarType::LOD_TENSOR:
desc_.mutable_lod_tensor()->set_lod_level(lod_level); desc_.mutable_type()->mutable_lod_tensor()->set_lod_level(lod_level);
break; break;
case proto::VarDesc::LOD_TENSOR_ARRAY: case proto::VarType::LOD_TENSOR_ARRAY:
desc_.mutable_tensor_array()->set_lod_level(lod_level); desc_.mutable_type()->mutable_tensor_array()->set_lod_level(lod_level);
break; break;
default: default:
PADDLE_THROW( PADDLE_THROW(
...@@ -142,10 +146,11 @@ void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) { ...@@ -142,10 +146,11 @@ void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
<< "). The Reader is going to be reinitialized."; << "). The Reader is going to be reinitialized.";
SetTensorDescNum(multiple_lod_level.size()); SetTensorDescNum(multiple_lod_level.size());
} }
switch (desc_.type()) { switch (desc_.type().type()) {
case proto::VarDesc::READER: { case proto::VarType::READER: {
size_t i = 0; size_t i = 0;
for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) { for (auto &lod_tensor :
*desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
lod_tensor.set_lod_level(multiple_lod_level[i++]); lod_tensor.set_lod_level(multiple_lod_level[i++]);
} }
} break; } break;
...@@ -157,11 +162,11 @@ void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) { ...@@ -157,11 +162,11 @@ void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
} }
int32_t VarDesc::GetLoDLevel() const { int32_t VarDesc::GetLoDLevel() const {
switch (desc_.type()) { switch (desc_.type().type()) {
case proto::VarDesc::LOD_TENSOR: case proto::VarType::LOD_TENSOR:
return desc_.lod_tensor().lod_level(); return desc_.type().lod_tensor().lod_level();
case proto::VarDesc::LOD_TENSOR_ARRAY: case proto::VarType::LOD_TENSOR_ARRAY:
return desc_.tensor_array().lod_level(); return desc_.type().tensor_array().lod_level();
default: default:
PADDLE_THROW( PADDLE_THROW(
"Getting 'lod_level' is not supported by the type of var %s.", "Getting 'lod_level' is not supported by the type of var %s.",
...@@ -171,10 +176,10 @@ int32_t VarDesc::GetLoDLevel() const { ...@@ -171,10 +176,10 @@ int32_t VarDesc::GetLoDLevel() const {
std::vector<int32_t> VarDesc::GetLoDLevels() const { std::vector<int32_t> VarDesc::GetLoDLevels() const {
std::vector<int32_t> res; std::vector<int32_t> res;
switch (desc_.type()) { switch (desc_.type().type()) {
case proto::VarDesc::READER: case proto::VarType::READER:
res.reserve(desc_.reader().lod_tensor_size()); res.reserve(desc_.type().reader().lod_tensor_size());
for (auto &lod_tensor : desc_.reader().lod_tensor()) { for (auto &lod_tensor : desc_.type().reader().lod_tensor()) {
res.push_back(lod_tensor.lod_level()); res.push_back(lod_tensor.lod_level());
} }
return res; return res;
...@@ -186,15 +191,16 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const { ...@@ -186,15 +191,16 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
} }
} }
const proto::TensorDesc &VarDesc::tensor_desc() const { const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set."); PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
switch (desc_.type()) { PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
case proto::VarDesc::SELECTED_ROWS: switch (desc_.type().type()) {
return desc_.selected_rows(); case proto::VarType::SELECTED_ROWS:
case proto::VarDesc::LOD_TENSOR: return desc_.type().selected_rows();
return desc_.lod_tensor().tensor(); case proto::VarType::LOD_TENSOR:
case proto::VarDesc::LOD_TENSOR_ARRAY: return desc_.type().lod_tensor().tensor();
return desc_.tensor_array().tensor(); case proto::VarType::LOD_TENSOR_ARRAY:
return desc_.type().tensor_array().tensor();
default: default:
PADDLE_THROW( PADDLE_THROW(
"Getting 'tensor_desc' is not supported by the type of var %s.", "Getting 'tensor_desc' is not supported by the type of var %s.",
...@@ -202,13 +208,13 @@ const proto::TensorDesc &VarDesc::tensor_desc() const { ...@@ -202,13 +208,13 @@ const proto::TensorDesc &VarDesc::tensor_desc() const {
} }
} }
std::vector<proto::TensorDesc> VarDesc::tensor_descs() const { std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set."); PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
std::vector<proto::TensorDesc> res; std::vector<proto::VarType::TensorDesc> res;
res.reserve(GetTensorDescNum()); res.reserve(GetTensorDescNum());
switch (desc_.type()) { switch (desc_.type().type()) {
case proto::VarDesc::READER: case proto::VarType::READER:
for (const auto &lod_tensor : desc_.reader().lod_tensor()) { for (const auto &lod_tensor : desc_.type().reader().lod_tensor()) {
res.push_back(lod_tensor.tensor()); res.push_back(lod_tensor.tensor());
} }
return res; return res;
...@@ -220,15 +226,16 @@ std::vector<proto::TensorDesc> VarDesc::tensor_descs() const { ...@@ -220,15 +226,16 @@ std::vector<proto::TensorDesc> VarDesc::tensor_descs() const {
} }
} }
proto::TensorDesc *VarDesc::mutable_tensor_desc() { proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set."); PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
switch (desc_.type()) { PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
case proto::VarDesc::SELECTED_ROWS: switch (desc_.type().type()) {
return desc_.mutable_selected_rows(); case proto::VarType::SELECTED_ROWS:
case proto::VarDesc::LOD_TENSOR: return desc_.mutable_type()->mutable_selected_rows();
return desc_.mutable_lod_tensor()->mutable_tensor(); case proto::VarType::LOD_TENSOR:
case proto::VarDesc::LOD_TENSOR_ARRAY: return desc_.mutable_type()->mutable_lod_tensor()->mutable_tensor();
return desc_.mutable_tensor_array()->mutable_tensor(); case proto::VarType::LOD_TENSOR_ARRAY:
return desc_.mutable_type()->mutable_tensor_array()->mutable_tensor();
default: default:
PADDLE_THROW( PADDLE_THROW(
"Getting 'mutable_tensor_desc' is not supported by the type of var " "Getting 'mutable_tensor_desc' is not supported by the type of var "
...@@ -237,13 +244,15 @@ proto::TensorDesc *VarDesc::mutable_tensor_desc() { ...@@ -237,13 +244,15 @@ proto::TensorDesc *VarDesc::mutable_tensor_desc() {
} }
} }
std::vector<proto::TensorDesc *> VarDesc::mutable_tensor_descs() { std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set."); PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
std::vector<proto::TensorDesc *> res; PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
std::vector<proto::VarType::TensorDesc *> res;
res.reserve(GetTensorDescNum()); res.reserve(GetTensorDescNum());
switch (desc_.type()) { switch (desc_.type().type()) {
case proto::VarDesc::READER: case proto::VarType::READER:
for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) { for (auto &lod_tensor :
*desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
res.push_back(lod_tensor.mutable_tensor()); res.push_back(lod_tensor.mutable_tensor());
} }
return res; return res;
......
...@@ -57,7 +57,7 @@ class VarDesc { ...@@ -57,7 +57,7 @@ class VarDesc {
public: public:
explicit VarDesc(const std::string &name) { explicit VarDesc(const std::string &name) {
desc_.set_name(name); desc_.set_name(name);
desc_.set_type(proto::VarDesc::LOD_TENSOR); desc_.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
} }
explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {} explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {}
...@@ -96,19 +96,19 @@ class VarDesc { ...@@ -96,19 +96,19 @@ class VarDesc {
std::vector<int32_t> GetLoDLevels() const; std::vector<int32_t> GetLoDLevels() const;
proto::VarDesc::VarType GetType() const; proto::VarType::Type GetType() const;
void SetType(proto::VarDesc::VarType type); void SetType(proto::VarType::Type type);
bool Persistable() const { return desc_.persistable(); } bool Persistable() const { return desc_.persistable(); }
void SetPersistable(bool persistable) { desc_.set_persistable(persistable); } void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }
private: private:
const proto::TensorDesc &tensor_desc() const; const proto::VarType::TensorDesc &tensor_desc() const;
std::vector<proto::TensorDesc> tensor_descs() const; std::vector<proto::VarType::TensorDesc> tensor_descs() const;
proto::TensorDesc *mutable_tensor_desc(); proto::VarType::TensorDesc *mutable_tensor_desc();
std::vector<proto::TensorDesc *> mutable_tensor_descs(); std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs();
proto::VarDesc desc_; proto::VarDesc desc_;
}; };
......
...@@ -23,17 +23,17 @@ limitations under the License. */ ...@@ -23,17 +23,17 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
inline proto::VarDesc::VarType ToVarType(std::type_index type) { inline proto::VarType::Type ToVarType(std::type_index type) {
if (type.hash_code() == typeid(LoDTensor).hash_code()) { if (type.hash_code() == typeid(LoDTensor).hash_code()) {
return proto::VarDesc_VarType_LOD_TENSOR; return proto::VarType_Type_LOD_TENSOR;
} else if (type.hash_code() == typeid(LoDRankTable).hash_code()) { } else if (type.hash_code() == typeid(LoDRankTable).hash_code()) {
return proto::VarDesc_VarType_LOD_RANK_TABLE; return proto::VarType_Type_LOD_RANK_TABLE;
} else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) { } else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) {
return proto::VarDesc_VarType_LOD_TENSOR_ARRAY; return proto::VarType_Type_LOD_TENSOR_ARRAY;
} else if (type.hash_code() == typeid(SelectedRows).hash_code()) { } else if (type.hash_code() == typeid(SelectedRows).hash_code()) {
return proto::VarDesc_VarType_SELECTED_ROWS; return proto::VarType_Type_SELECTED_ROWS;
} else if (type.hash_code() == typeid(ReaderHolder).hash_code()) { } else if (type.hash_code() == typeid(ReaderHolder).hash_code()) {
return proto::VarDesc_VarType_READER; return proto::VarType_Type_READER;
} else { } else {
PADDLE_THROW("ToVarType:Unsupported type %s", type.name()); PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
} }
...@@ -42,19 +42,19 @@ inline proto::VarDesc::VarType ToVarType(std::type_index type) { ...@@ -42,19 +42,19 @@ inline proto::VarDesc::VarType ToVarType(std::type_index type) {
template <typename Visitor> template <typename Visitor>
inline void VisitVarType(const framework::Variable& var, Visitor visitor) { inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
switch (ToVarType(var.Type())) { switch (ToVarType(var.Type())) {
case proto::VarDesc_VarType_LOD_TENSOR: case proto::VarType_Type_LOD_TENSOR:
visitor(var.Get<LoDTensor>()); visitor(var.Get<LoDTensor>());
return; return;
case proto::VarDesc_VarType_LOD_RANK_TABLE: case proto::VarType_Type_LOD_RANK_TABLE:
visitor(var.Get<LoDRankTable>()); visitor(var.Get<LoDRankTable>());
return; return;
case proto::VarDesc_VarType_LOD_TENSOR_ARRAY: case proto::VarType_Type_LOD_TENSOR_ARRAY:
visitor(var.Get<LoDTensorArray>()); visitor(var.Get<LoDTensorArray>());
return; return;
case proto::VarDesc_VarType_SELECTED_ROWS: case proto::VarType_Type_SELECTED_ROWS:
visitor(var.Get<SelectedRows>()); visitor(var.Get<SelectedRows>());
return; return;
case proto::VarDesc_VarType_READER: case proto::VarType_Type_READER:
visitor(var.Get<ReaderHolder>()); visitor(var.Get<ReaderHolder>());
return; return;
default: default:
......
...@@ -35,14 +35,14 @@ class SumOpVarTypeInference : public VarTypeInference { ...@@ -35,14 +35,14 @@ class SumOpVarTypeInference : public VarTypeInference {
public: public:
void operator()(const OpDesc &op_desc, BlockDesc *block) const override { void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
auto &inputs = op_desc.Input("X"); auto &inputs = op_desc.Input("X");
auto default_var_type = proto::VarDesc::SELECTED_ROWS; auto default_var_type = proto::VarType::SELECTED_ROWS;
bool any_input_is_lod_tensor = std::any_of( bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string &name) { inputs.begin(), inputs.end(), [block](const std::string &name) {
return block->Var(name)->GetType() == proto::VarDesc::LOD_TENSOR; return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
}); });
if (any_input_is_lod_tensor) { if (any_input_is_lod_tensor) {
default_var_type = proto::VarDesc::LOD_TENSOR; default_var_type = proto::VarType::LOD_TENSOR;
} }
auto out_var_name = op_desc.Output("Out").front(); auto out_var_name = op_desc.Output("Out").front();
...@@ -67,19 +67,19 @@ TEST(InferVarType, sum_op) { ...@@ -67,19 +67,19 @@ TEST(InferVarType, sum_op) {
op->SetInput("X", {"test_a", "test_b", "test_c"}); op->SetInput("X", {"test_a", "test_b", "test_c"});
op->SetOutput("Out", {"test_out"}); op->SetOutput("Out", {"test_out"});
prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarDesc::SELECTED_ROWS); prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarDesc::SELECTED_ROWS); prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarDesc::SELECTED_ROWS); prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_out"); prog.MutableBlock(0)->Var("test_out");
op->InferVarType(prog.MutableBlock(0)); op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(proto::VarDesc::SELECTED_ROWS, ASSERT_EQ(proto::VarType::SELECTED_ROWS,
prog.MutableBlock(0)->Var("test_out")->GetType()); prog.MutableBlock(0)->Var("test_out")->GetType());
prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarDesc::LOD_TENSOR); prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR);
op->InferVarType(prog.MutableBlock(0)); op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(proto::VarDesc::LOD_TENSOR, ASSERT_EQ(proto::VarType::LOD_TENSOR,
prog.MutableBlock(0)->Var("test_out")->GetType()); prog.MutableBlock(0)->Var("test_out")->GetType());
} }
...@@ -90,14 +90,14 @@ TEST(InferVarType, sum_op_without_infer_var_type) { ...@@ -90,14 +90,14 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
op->SetInput("X", {"test2_a", "test2_b", "test2_c"}); op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
op->SetOutput("Out", {"test2_out"}); op->SetOutput("Out", {"test2_out"});
prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarDesc::SELECTED_ROWS); prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarDesc::SELECTED_ROWS); prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarDesc::SELECTED_ROWS); prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test2_out"); prog.MutableBlock(0)->Var("test2_out");
op->InferVarType(prog.MutableBlock(0)); op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(proto::VarDesc_VarType_LOD_TENSOR, ASSERT_EQ(proto::VarType_Type_LOD_TENSOR,
prog.MutableBlock(0)->Var("test2_out")->GetType()); prog.MutableBlock(0)->Var("test2_out")->GetType());
} }
......
...@@ -115,8 +115,8 @@ class AssignInferShape : public framework::InferShapeBase { ...@@ -115,8 +115,8 @@ class AssignInferShape : public framework::InferShapeBase {
void operator()(framework::InferShapeContext *context) const override { void operator()(framework::InferShapeContext *context) const override {
if (context->HasInput("X")) { if (context->HasInput("X")) {
auto type = context->GetInputsVarType("X")[0]; auto type = context->GetInputsVarType("X")[0];
if (type == framework::proto::VarDesc_VarType_SELECTED_ROWS || if (type == framework::proto::VarType::SELECTED_ROWS ||
type == framework::proto::VarDesc_VarType_LOD_TENSOR) { type == framework::proto::VarType::LOD_TENSOR) {
context->SetOutputDim("Out", context->GetInputDim("X")); context->SetOutputDim("Out", context->GetInputDim("X"));
} }
} }
......
...@@ -128,10 +128,10 @@ class BeamSearchDecodeInferVarType : public framework::VarTypeInference { ...@@ -128,10 +128,10 @@ class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
void operator()(const framework::OpDesc& op_desc, void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override { framework::BlockDesc* block) const override {
for (auto& o : op_desc.Output("SentenceIds")) { for (auto& o : op_desc.Output("SentenceIds")) {
block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR); block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
} }
for (auto& o : op_desc.Output("SentenceScores")) { for (auto& o : op_desc.Output("SentenceScores")) {
block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR); block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
} }
} }
}; };
......
...@@ -240,10 +240,10 @@ class BeamSearchInferVarType : public framework::VarTypeInference { ...@@ -240,10 +240,10 @@ class BeamSearchInferVarType : public framework::VarTypeInference {
void operator()(const framework::OpDesc &op_desc, void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override { framework::BlockDesc *block) const override {
for (auto &o : op_desc.Output("selected_ids")) { for (auto &o : op_desc.Output("selected_ids")) {
block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR); block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
} }
for (auto &o : op_desc.Output("selected_scores")) { for (auto &o : op_desc.Output("selected_scores")) {
block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR); block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
} }
} }
}; };
......
...@@ -84,7 +84,7 @@ class CreateFileReaderInferVarType : public framework::VarTypeInference { ...@@ -84,7 +84,7 @@ class CreateFileReaderInferVarType : public framework::VarTypeInference {
framework::BlockDesc* block) const override { framework::BlockDesc* block) const override {
std::string reader_name = op_desc.Output("Out")[0]; std::string reader_name = op_desc.Output("Out")[0];
framework::VarDesc* reader = block->FindVarRecursive(reader_name); framework::VarDesc* reader = block->FindVarRecursive(reader_name);
reader->SetType(framework::proto::VarDesc::READER); reader->SetType(framework::proto::VarType::READER);
} }
}; };
...@@ -97,7 +97,7 @@ class CreateDecoratedReaderInferVarType : public framework::VarTypeInference { ...@@ -97,7 +97,7 @@ class CreateDecoratedReaderInferVarType : public framework::VarTypeInference {
framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name); framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name);
std::string out_reader_name = op_desc.Output("Out")[0]; std::string out_reader_name = op_desc.Output("Out")[0];
framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name); framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name);
out_reader->SetType(framework::proto::VarDesc::READER); out_reader->SetType(framework::proto::VarType::READER);
out_reader->SetDataTypes(in_reader->GetDataTypes()); out_reader->SetDataTypes(in_reader->GetDataTypes());
} }
}; };
......
...@@ -24,11 +24,11 @@ void SerializeToMessage(const std::string& name, const framework::Variable* var, ...@@ -24,11 +24,11 @@ void SerializeToMessage(const std::string& name, const framework::Variable* var,
msg->set_varname(name); msg->set_varname(name);
std::ostringstream oss; std::ostringstream oss;
switch (framework::ToVarType(var->Type())) { switch (framework::ToVarType(var->Type())) {
case framework::proto::VarDesc_VarType_LOD_TENSOR: case framework::proto::VarType_Type_LOD_TENSOR:
msg->set_type(sendrecv::VarType::LOD_TENSOR); msg->set_type(sendrecv::VarType::LOD_TENSOR);
framework::SerializeToStream(oss, var->Get<framework::LoDTensor>(), ctx); framework::SerializeToStream(oss, var->Get<framework::LoDTensor>(), ctx);
break; break;
case framework::proto::VarDesc_VarType_SELECTED_ROWS: case framework::proto::VarType_Type_SELECTED_ROWS:
msg->set_type(sendrecv::VarType::SELECTED_ROWS); msg->set_type(sendrecv::VarType::SELECTED_ROWS);
framework::SerializeToStream(oss, var->Get<framework::SelectedRows>(), framework::SerializeToStream(oss, var->Get<framework::SelectedRows>(),
ctx); ctx);
......
...@@ -98,7 +98,7 @@ class GetPlacesInferVarType : public framework::VarTypeInference { ...@@ -98,7 +98,7 @@ class GetPlacesInferVarType : public framework::VarTypeInference {
framework::BlockDesc *block) const override { framework::BlockDesc *block) const override {
for (auto &o_name : op_desc.Output("Out")) { for (auto &o_name : op_desc.Output("Out")) {
block->FindRecursiveOrCreateVar(o_name).SetType( block->FindRecursiveOrCreateVar(o_name).SetType(
framework::proto::VarDesc::PLACE_LIST); framework::proto::VarType::PLACE_LIST);
} }
} }
}; };
......
...@@ -69,7 +69,7 @@ class LoDRankTableInferVarType : public framework::VarTypeInference { ...@@ -69,7 +69,7 @@ class LoDRankTableInferVarType : public framework::VarTypeInference {
framework::BlockDesc *block) const override { framework::BlockDesc *block) const override {
for (auto &o : op_desc.Output("Out")) { for (auto &o : op_desc.Output("Out")) {
block->FindRecursiveOrCreateVar(o).SetType( block->FindRecursiveOrCreateVar(o).SetType(
framework::proto::VarDesc::LOD_RANK_TABLE); framework::proto::VarType::LOD_RANK_TABLE);
} }
} }
}; };
......
...@@ -138,7 +138,7 @@ class LoDTensorToArrayInferVarType : public framework::VarTypeInference { ...@@ -138,7 +138,7 @@ class LoDTensorToArrayInferVarType : public framework::VarTypeInference {
void operator()(const framework::OpDesc &op_desc, void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override { framework::BlockDesc *block) const override {
for (auto &out_var : op_desc.Output("Out")) { for (auto &out_var : op_desc.Output("Out")) {
block->Var(out_var)->SetType(framework::proto::VarDesc::LOD_TENSOR_ARRAY); block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
} }
} }
}; };
......
...@@ -123,11 +123,11 @@ class LookupTableOpGradVarTypeInference : public framework::VarTypeInference { ...@@ -123,11 +123,11 @@ class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W") VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
<< " is set to SelectedRows"; << " is set to SelectedRows";
block->Var(out_var_name) block->Var(out_var_name)
->SetType(framework::proto::VarDesc::SELECTED_ROWS); ->SetType(framework::proto::VarType::SELECTED_ROWS);
} else { } else {
VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W") VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
<< " is set to LoDTensor"; << " is set to LoDTensor";
block->Var(out_var_name)->SetType(framework::proto::VarDesc::LOD_TENSOR); block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
} }
} }
}; };
......
...@@ -45,7 +45,7 @@ class ReadInferVarType : public framework::VarTypeInference { ...@@ -45,7 +45,7 @@ class ReadInferVarType : public framework::VarTypeInference {
PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size()); PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
for (size_t i = 0; i < dtypes.size(); ++i) { for (size_t i = 0; i < dtypes.size(); ++i) {
framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]); framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]);
out.SetType(framework::proto::VarDesc::LOD_TENSOR); out.SetType(framework::proto::VarType::LOD_TENSOR);
out.SetDataType(dtypes[i]); out.SetDataType(dtypes[i]);
} }
} }
......
...@@ -29,7 +29,7 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -29,7 +29,7 @@ class SumOp : public framework::OperatorWithKernel {
"Output(Out) of SumOp should not be null."); "Output(Out) of SumOp should not be null.");
if (ctx->IsRuntime() && if (ctx->IsRuntime() &&
ctx->GetOutputsVarType("Out")[0] == ctx->GetOutputsVarType("Out")[0] ==
framework::proto::VarDesc::LOD_TENSOR_ARRAY) { framework::proto::VarType::LOD_TENSOR_ARRAY) {
return; // skip runtime infershape when is tensor array; return; // skip runtime infershape when is tensor array;
} }
...@@ -118,7 +118,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference { ...@@ -118,7 +118,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
void operator()(const framework::OpDesc& op_desc, void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override { framework::BlockDesc* block) const override {
auto& inputs = op_desc.Input("X"); auto& inputs = op_desc.Input("X");
auto var_type = framework::proto::VarDesc::SELECTED_ROWS; auto var_type = framework::proto::VarType::SELECTED_ROWS;
for (auto& name : op_desc.Input("X")) { for (auto& name : op_desc.Input("X")) {
VLOG(10) << name << " " VLOG(10) << name << " "
...@@ -128,12 +128,12 @@ class SumOpVarTypeInference : public framework::VarTypeInference { ...@@ -128,12 +128,12 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
bool any_input_is_lod_tensor = std::any_of( bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string& name) { inputs.begin(), inputs.end(), [block](const std::string& name) {
return block->FindRecursiveOrCreateVar(name).GetType() == return block->FindRecursiveOrCreateVar(name).GetType() ==
framework::proto::VarDesc::LOD_TENSOR; framework::proto::VarType::LOD_TENSOR;
}); });
auto is_tensor_array = [block](const std::string& name) { auto is_tensor_array = [block](const std::string& name) {
return block->FindRecursiveOrCreateVar(name).GetType() == return block->FindRecursiveOrCreateVar(name).GetType() ==
framework::proto::VarDesc::LOD_TENSOR_ARRAY; framework::proto::VarType::LOD_TENSOR_ARRAY;
}; };
bool any_input_is_tensor_array = bool any_input_is_tensor_array =
...@@ -151,9 +151,9 @@ class SumOpVarTypeInference : public framework::VarTypeInference { ...@@ -151,9 +151,9 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
PADDLE_ENFORCE(all_inputs_are_tensor_array, PADDLE_ENFORCE(all_inputs_are_tensor_array,
"Not all inputs are tensor array:\n%s", os.str()); "Not all inputs are tensor array:\n%s", os.str());
} }
var_type = framework::proto::VarDesc::LOD_TENSOR_ARRAY; var_type = framework::proto::VarType::LOD_TENSOR_ARRAY;
} else if (any_input_is_lod_tensor) { } else if (any_input_is_lod_tensor) {
var_type = framework::proto::VarDesc::LOD_TENSOR; var_type = framework::proto::VarType::LOD_TENSOR;
} }
auto out_var_name = op_desc.Output("Out").front(); auto out_var_name = op_desc.Output("Out").front();
......
...@@ -108,7 +108,7 @@ class WriteToArrayInferVarType : public framework::VarTypeInference { ...@@ -108,7 +108,7 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
auto out_name = op_desc.Output("Out")[0]; auto out_name = op_desc.Output("Out")[0];
VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY"; VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
auto &out = block->FindRecursiveOrCreateVar(out_name); auto &out = block->FindRecursiveOrCreateVar(out_name);
out.SetType(framework::proto::VarDesc::LOD_TENSOR_ARRAY); out.SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
auto *x = block->FindVarRecursive(x_name); auto *x = block->FindVarRecursive(x_name);
if (x != nullptr) { if (x != nullptr) {
out.SetDataType(x->GetDataType()); out.SetDataType(x->GetDataType());
......
...@@ -330,10 +330,10 @@ class WhileGradOpShapeInference : public framework::InferShapeBase { ...@@ -330,10 +330,10 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
continue; continue;
} }
auto dims = ctx->GetInputsElementDim(kX, i); auto dims = ctx->GetInputsElementDim(kX, i);
if (var_types[i] == framework::proto::VarDesc::LOD_TENSOR) { if (var_types[i] == framework::proto::VarType::LOD_TENSOR) {
names_to_set.push_back(pg_names[i]); names_to_set.push_back(pg_names[i]);
dims_to_set.push_back(dims); dims_to_set.push_back(dims);
} else if (var_types[i] == framework::proto::VarDesc::LOD_TENSOR_ARRAY) { } else if (var_types[i] == framework::proto::VarType::LOD_TENSOR_ARRAY) {
// not sure how to set the dim of LOD_TENSOR_ARRAY // not sure how to set the dim of LOD_TENSOR_ARRAY
names_to_set.push_back(pg_names[i]); names_to_set.push_back(pg_names[i]);
dims_to_set.push_back(dims); dims_to_set.push_back(dims);
......
...@@ -232,16 +232,16 @@ void BindVarDsec(py::module &m) { ...@@ -232,16 +232,16 @@ void BindVarDsec(py::module &m) {
.def("persistable", &VarDesc::Persistable) .def("persistable", &VarDesc::Persistable)
.def("set_persistable", &VarDesc::SetPersistable); .def("set_persistable", &VarDesc::SetPersistable);
py::enum_<proto::VarDesc::VarType>(var_desc, "VarType", "") py::enum_<proto::VarType::Type>(var_desc, "VarType", "")
.value("LOD_TENSOR", proto::VarDesc::LOD_TENSOR) .value("LOD_TENSOR", proto::VarType::LOD_TENSOR)
.value("SELECTED_ROWS", proto::VarDesc::SELECTED_ROWS) .value("SELECTED_ROWS", proto::VarType::SELECTED_ROWS)
.value("FEED_MINIBATCH", proto::VarDesc::FEED_MINIBATCH) .value("FEED_MINIBATCH", proto::VarType::FEED_MINIBATCH)
.value("FETCH_LIST", proto::VarDesc::FETCH_LIST) .value("FETCH_LIST", proto::VarType::FETCH_LIST)
.value("STEP_SCOPES", proto::VarDesc::STEP_SCOPES) .value("STEP_SCOPES", proto::VarType::STEP_SCOPES)
.value("LOD_RANK_TABLE", proto::VarDesc::LOD_RANK_TABLE) .value("LOD_RANK_TABLE", proto::VarType::LOD_RANK_TABLE)
.value("LOD_TENSOR_ARRAY", proto::VarDesc::LOD_TENSOR_ARRAY) .value("LOD_TENSOR_ARRAY", proto::VarType::LOD_TENSOR_ARRAY)
.value("PLACE_LIST", proto::VarDesc::PLACE_LIST) .value("PLACE_LIST", proto::VarType::PLACE_LIST)
.value("READER", proto::VarDesc::READER); .value("READER", proto::VarType::READER);
} }
void BindOpDesc(py::module &m) { void BindOpDesc(py::module &m) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册