提交 fcadb452 编写于 作者: A Abhinav Arora 提交者: Yi Wang

Separate VarType from VarDesc in framework.proto and fix all related compiler errors (#8414)

* Refine Type system

* Fixing type inference

* Fixed create_reader_op.cc

* Fix var_desc.h

* Fixed executor.cc

* Fix shape_inference.h

* Fixed create_reader_op.cc

* Fix tensor_util.h

* Fixed var_type_inference_test.cc

* Fix shape_inference.cc

* Fixed sum_op.c

* Fixed read_op.cc

* Fix var_type.h

* Fixed beam_search_decode_op.cc

* sendrecvop_utils.cc

* Fix operator.cc

* Fixed lookup_table_op.cc

* Fixed op_desc.cc

* Fixed get_places_op.cc

* Fixed lod_rank_table_op.cc

* Fixed beam_search_op.cc

* Fix var_desc.cc

* Fixed lod_tensor_to_array_op.cc

* Fixed while_op.cc

* Fix program_desc_test.cc

* tensor_array_read_write_op.cc

* Fix assign_op.cc

* Fix executor.cc

* Fix protobuf.cc

* Fix protobuf.cc
上级 f82fa64a
...@@ -36,24 +36,24 @@ namespace framework { ...@@ -36,24 +36,24 @@ namespace framework {
Executor::Executor(const platform::Place& place) : place_(place) {} Executor::Executor(const platform::Place& place) : place_(place) {}
static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) { static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarDesc::LOD_TENSOR) { if (var_type == proto::VarType::LOD_TENSOR) {
var->GetMutable<LoDTensor>(); var->GetMutable<LoDTensor>();
} else if (var_type == proto::VarDesc::SELECTED_ROWS) { } else if (var_type == proto::VarType::SELECTED_ROWS) {
var->GetMutable<SelectedRows>(); var->GetMutable<SelectedRows>();
} else if (var_type == proto::VarDesc::FEED_MINIBATCH) { } else if (var_type == proto::VarType::FEED_MINIBATCH) {
var->GetMutable<FeedFetchList>(); var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarDesc::FETCH_LIST) { } else if (var_type == proto::VarType::FETCH_LIST) {
var->GetMutable<FeedFetchList>(); var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarDesc::STEP_SCOPES) { } else if (var_type == proto::VarType::STEP_SCOPES) {
var->GetMutable<std::vector<framework::Scope>>(); var->GetMutable<std::vector<framework::Scope>>();
} else if (var_type == proto::VarDesc::LOD_RANK_TABLE) { } else if (var_type == proto::VarType::LOD_RANK_TABLE) {
var->GetMutable<LoDRankTable>(); var->GetMutable<LoDRankTable>();
} else if (var_type == proto::VarDesc::LOD_TENSOR_ARRAY) { } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
var->GetMutable<LoDTensorArray>(); var->GetMutable<LoDTensorArray>();
} else if (var_type == proto::VarDesc::PLACE_LIST) { } else if (var_type == proto::VarType::PLACE_LIST) {
var->GetMutable<platform::PlaceList>(); var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarDesc::READER) { } else if (var_type == proto::VarType::READER) {
var->GetMutable<ReaderHolder>(); var->GetMutable<ReaderHolder>();
} else { } else {
PADDLE_THROW( PADDLE_THROW(
...@@ -182,7 +182,7 @@ static bool has_feed_operators( ...@@ -182,7 +182,7 @@ static bool has_feed_operators(
auto var = block->FindVar(feed_holder_name); auto var = block->FindVar(feed_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
feed_holder_name); feed_holder_name);
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarDesc::FEED_MINIBATCH, PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
"'%s' variable should be 'FEED_MINIBATCH' type", "'%s' variable should be 'FEED_MINIBATCH' type",
feed_holder_name); feed_holder_name);
} }
...@@ -222,7 +222,7 @@ static bool has_fetch_operators( ...@@ -222,7 +222,7 @@ static bool has_fetch_operators(
auto var = block->FindVar(fetch_holder_name); auto var = block->FindVar(fetch_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
fetch_holder_name); fetch_holder_name);
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarDesc::FETCH_LIST, PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
"'%s' variable should be 'FETCH_LIST' type", "'%s' variable should be 'FETCH_LIST' type",
fetch_holder_name); fetch_holder_name);
} }
...@@ -241,7 +241,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -241,7 +241,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) { if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) {
// create feed_holder variable // create feed_holder variable
auto* feed_holder = global_block->Var(feed_holder_name); auto* feed_holder = global_block->Var(feed_holder_name);
feed_holder->SetType(proto::VarDesc::FEED_MINIBATCH); feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
feed_holder->SetPersistable(true); feed_holder->SetPersistable(true);
int i = 0; int i = 0;
...@@ -274,7 +274,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -274,7 +274,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) { if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) {
// create fetch_holder variable // create fetch_holder variable
auto* fetch_holder = global_block->Var(fetch_holder_name); auto* fetch_holder = global_block->Var(fetch_holder_name);
fetch_holder->SetType(proto::VarDesc::FETCH_LIST); fetch_holder->SetType(proto::VarType::FETCH_LIST);
fetch_holder->SetPersistable(true); fetch_holder->SetPersistable(true);
int i = 0; int i = 0;
......
...@@ -101,25 +101,8 @@ enum DataType { ...@@ -101,25 +101,8 @@ enum DataType {
FP64 = 6; FP64 = 6;
} }
message TensorDesc { message VarType {
required DataType data_type = 1; enum Type {
repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
}
message LoDTensorDesc {
required TensorDesc tensor = 1;
optional int32 lod_level = 2 [ default = 0 ];
}
message LoDTensorArrayDesc {
required TensorDesc tensor = 1;
optional int32 lod_level = 2 [ default = 0 ];
}
message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; }
message VarDesc {
enum VarType {
LOD_TENSOR = 1; LOD_TENSOR = 1;
SELECTED_ROWS = 2; SELECTED_ROWS = 2;
FEED_MINIBATCH = 3; FEED_MINIBATCH = 3;
...@@ -130,13 +113,35 @@ message VarDesc { ...@@ -130,13 +113,35 @@ message VarDesc {
PLACE_LIST = 8; PLACE_LIST = 8;
READER = 9; READER = 9;
} }
required Type type = 1;
message TensorDesc {
required DataType data_type = 1;
repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
}
optional TensorDesc selected_rows = 2;
message LoDTensorDesc {
required TensorDesc tensor = 1;
optional int32 lod_level = 2 [ default = 0 ];
}
optional LoDTensorDesc lod_tensor = 3;
message LoDTensorArrayDesc {
required TensorDesc tensor = 1;
optional int32 lod_level = 2 [ default = 0 ];
}
optional LoDTensorArrayDesc tensor_array = 4;
message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; }
optional ReaderDesc reader = 5;
}
message VarDesc {
required string name = 1; required string name = 1;
required VarType type = 2; required VarType type = 2;
optional bool persistable = 3 [ default = false ]; optional bool persistable = 3 [ default = false ];
optional LoDTensorDesc lod_tensor = 4;
optional TensorDesc selected_rows = 5;
optional LoDTensorArrayDesc tensor_array = 6;
optional ReaderDesc reader = 7;
} }
message BlockDesc { message BlockDesc {
......
...@@ -53,11 +53,11 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -53,11 +53,11 @@ class CompileTimeInferShapeContext : public InferShapeContext {
PADDLE_ENFORCE_LT(j, Outputs(out).size()); PADDLE_ENFORCE_LT(j, Outputs(out).size());
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]); auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]); auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
if (in_var->GetType() != proto::VarDesc::LOD_TENSOR) { if (in_var->GetType() != proto::VarType::LOD_TENSOR) {
VLOG(3) << "input " << in << " is not LodTensor"; VLOG(3) << "input " << in << " is not LodTensor";
return; return;
} }
PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarDesc::LOD_TENSOR, PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarType::LOD_TENSOR,
"The %d-th output of Output(%s) must be LoDTensor.", j, "The %d-th output of Output(%s) must be LoDTensor.", j,
out); out);
out_var->SetLoDLevel(in_var->GetLoDLevel()); out_var->SetLoDLevel(in_var->GetLoDLevel());
...@@ -66,7 +66,7 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -66,7 +66,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool IsRuntime() const override; bool IsRuntime() const override;
protected: protected:
proto::VarDesc::VarType GetVarType(const std::string &name) const override; proto::VarType::Type GetVarType(const std::string &name) const override;
DDim GetDim(const std::string &name) const override; DDim GetDim(const std::string &name) const override;
...@@ -388,7 +388,7 @@ void OpDesc::InferVarType(BlockDesc *block) const { ...@@ -388,7 +388,7 @@ void OpDesc::InferVarType(BlockDesc *block) const {
for (auto &out_pair : this->outputs_) { for (auto &out_pair : this->outputs_) {
for (auto &out_var_name : out_pair.second) { for (auto &out_var_name : out_pair.second) {
block->FindRecursiveOrCreateVar(out_var_name) block->FindRecursiveOrCreateVar(out_var_name)
.SetType(proto::VarDesc::LOD_TENSOR); .SetType(proto::VarType::LOD_TENSOR);
} }
} }
} }
...@@ -507,7 +507,7 @@ void CompileTimeInferShapeContext::SetRepeatedDims( ...@@ -507,7 +507,7 @@ void CompileTimeInferShapeContext::SetRepeatedDims(
bool CompileTimeInferShapeContext::IsRuntime() const { return false; } bool CompileTimeInferShapeContext::IsRuntime() const { return false; }
proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType( proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
const std::string &name) const { const std::string &name) const {
return block_.FindVarRecursive(name)->GetType(); return block_.FindVarRecursive(name)->GetType();
} }
......
...@@ -477,7 +477,7 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -477,7 +477,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
} }
} }
proto::VarDesc::VarType GetVarType(const std::string& name) const override { proto::VarType::Type GetVarType(const std::string& name) const override {
auto* var = scope_.FindVar(name); auto* var = scope_.FindVar(name);
return ToVarType(var->Type()); return ToVarType(var->Type());
} }
......
...@@ -22,13 +22,13 @@ TEST(ProgramDesc, copy_ctor) { ...@@ -22,13 +22,13 @@ TEST(ProgramDesc, copy_ctor) {
ProgramDesc program; ProgramDesc program;
auto* global_block = program.MutableBlock(0); auto* global_block = program.MutableBlock(0);
auto* x = global_block->Var("X"); auto* x = global_block->Var("X");
x->SetType(proto::VarDesc_VarType_LOD_TENSOR); x->SetType(proto::VarType::LOD_TENSOR);
x->SetLoDLevel(0); x->SetLoDLevel(0);
x->SetDataType(proto::FP32); x->SetDataType(proto::FP32);
x->SetShape({1000, 784}); x->SetShape({1000, 784});
auto* y = global_block->Var("Y"); auto* y = global_block->Var("Y");
y->SetType(proto::VarDesc_VarType_LOD_TENSOR); y->SetType(proto::VarType::LOD_TENSOR);
y->SetLoDLevel(0); y->SetLoDLevel(0);
y->SetDataType(proto::FP32); y->SetDataType(proto::FP32);
y->SetShape({784, 100}); y->SetShape({784, 100});
...@@ -39,7 +39,7 @@ TEST(ProgramDesc, copy_ctor) { ...@@ -39,7 +39,7 @@ TEST(ProgramDesc, copy_ctor) {
op->SetInput("Y", {y->Name()}); op->SetInput("Y", {y->Name()});
auto* out = global_block->Var("Out"); auto* out = global_block->Var("Out");
out->SetType(proto::VarDesc_VarType_LOD_TENSOR); out->SetType(proto::VarType::LOD_TENSOR);
op->SetOutput("Y", {out->Name()}); op->SetOutput("Y", {out->Name()});
ProgramDesc program_copy(program); ProgramDesc program_copy(program);
...@@ -84,13 +84,13 @@ TEST(ProgramDescBind, serialize_and_deserialize) { ...@@ -84,13 +84,13 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
ProgramDesc program_origin; ProgramDesc program_origin;
auto* global_block = program_origin.MutableBlock(0); auto* global_block = program_origin.MutableBlock(0);
auto* x = global_block->Var("X"); auto* x = global_block->Var("X");
x->SetType(proto::VarDesc_VarType_LOD_TENSOR); x->SetType(proto::VarType::LOD_TENSOR);
x->SetLoDLevel(0); x->SetLoDLevel(0);
x->SetDataType(proto::FP32); x->SetDataType(proto::FP32);
x->SetShape({1000, 784}); x->SetShape({1000, 784});
auto* y = global_block->Var("Y"); auto* y = global_block->Var("Y");
y->SetType(proto::VarDesc_VarType_LOD_TENSOR); y->SetType(proto::VarType::LOD_TENSOR);
y->SetLoDLevel(0); y->SetLoDLevel(0);
y->SetDataType(proto::FP32); y->SetDataType(proto::FP32);
y->SetShape({784, 100}); y->SetShape({784, 100});
...@@ -101,7 +101,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) { ...@@ -101,7 +101,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
op->SetInput("Y", {y->Name()}); op->SetInput("Y", {y->Name()});
auto* out = global_block->Var("Out"); auto* out = global_block->Var("Out");
out->SetType(proto::VarDesc_VarType_LOD_TENSOR); out->SetType(proto::VarType::LOD_TENSOR);
op->SetOutput("Y", {out->Name()}); op->SetOutput("Y", {out->Name()});
std::string binary_str; std::string binary_str;
......
...@@ -116,19 +116,19 @@ void InferShapeContext::SetDims(const std::vector<std::string> &names, ...@@ -116,19 +116,19 @@ void InferShapeContext::SetDims(const std::vector<std::string> &names,
} }
} }
std::vector<proto::VarDesc::VarType> InferShapeContext::GetInputsVarType( std::vector<proto::VarType::Type> InferShapeContext::GetInputsVarType(
const std::string &name) const { const std::string &name) const {
return GetVarTypes(Inputs(name)); return GetVarTypes(Inputs(name));
} }
std::vector<proto::VarDesc::VarType> InferShapeContext::GetOutputsVarType( std::vector<proto::VarType::Type> InferShapeContext::GetOutputsVarType(
const std::string &name) const { const std::string &name) const {
return GetVarTypes(Outputs(name)); return GetVarTypes(Outputs(name));
} }
std::vector<proto::VarDesc::VarType> InferShapeContext::GetVarTypes( std::vector<proto::VarType::Type> InferShapeContext::GetVarTypes(
const std::vector<std::string> &names) const { const std::vector<std::string> &names) const {
std::vector<proto::VarDesc::VarType> retv; std::vector<proto::VarType::Type> retv;
retv.resize(names.size()); retv.resize(names.size());
std::transform(names.begin(), names.end(), retv.begin(), std::transform(names.begin(), names.end(), retv.begin(),
std::bind(std::mem_fn(&InferShapeContext::GetVarType), this, std::bind(std::mem_fn(&InferShapeContext::GetVarType), this,
......
...@@ -31,9 +31,9 @@ class InferShapeContext { ...@@ -31,9 +31,9 @@ class InferShapeContext {
virtual bool HasInput(const std::string &name) const = 0; virtual bool HasInput(const std::string &name) const = 0;
virtual bool HasOutput(const std::string &name) const = 0; virtual bool HasOutput(const std::string &name) const = 0;
std::vector<proto::VarDesc::VarType> GetInputsVarType( std::vector<proto::VarType::Type> GetInputsVarType(
const std::string &name) const; const std::string &name) const;
std::vector<proto::VarDesc::VarType> GetOutputsVarType( std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string &name) const; const std::string &name) const;
virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasInputs(const std::string &name) const = 0;
...@@ -75,10 +75,10 @@ class InferShapeContext { ...@@ -75,10 +75,10 @@ class InferShapeContext {
std::vector<DDim> GetDims(const std::vector<std::string> &names) const; std::vector<DDim> GetDims(const std::vector<std::string> &names) const;
std::vector<proto::VarDesc::VarType> GetVarTypes( std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<std::string> &names) const; const std::vector<std::string> &names) const;
virtual proto::VarDesc::VarType GetVarType(const std::string &name) const = 0; virtual proto::VarType::Type GetVarType(const std::string &name) const = 0;
virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0; virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0;
}; };
......
...@@ -225,7 +225,7 @@ inline void SerializeToStream(std::ostream& os, const Tensor& tensor, ...@@ -225,7 +225,7 @@ inline void SerializeToStream(std::ostream& os, const Tensor& tensor,
{ // the 2nd field, tensor description { // the 2nd field, tensor description
// int32_t size // int32_t size
// void* protobuf message // void* protobuf message
proto::TensorDesc desc; proto::VarType::TensorDesc desc;
desc.set_data_type(framework::ToDataType(tensor.type())); desc.set_data_type(framework::ToDataType(tensor.type()));
auto dims = framework::vectorize(tensor.dims()); auto dims = framework::vectorize(tensor.dims());
auto* pb_dims = desc.mutable_dims(); auto* pb_dims = desc.mutable_dims();
...@@ -290,7 +290,7 @@ inline void DeserializeFromStream(std::istream& is, Tensor* tensor, ...@@ -290,7 +290,7 @@ inline void DeserializeFromStream(std::istream& is, Tensor* tensor,
uint32_t version; uint32_t version;
is.read(reinterpret_cast<char*>(&version), sizeof(version)); is.read(reinterpret_cast<char*>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported"); PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
proto::TensorDesc desc; proto::VarType::TensorDesc desc;
{ // int32_t size { // int32_t size
// proto buffer // proto buffer
int32_t size; int32_t size;
......
...@@ -18,18 +18,21 @@ limitations under the License. */ ...@@ -18,18 +18,21 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
proto::VarDesc::VarType VarDesc::GetType() const { return desc_.type(); } proto::VarType::Type VarDesc::GetType() const { return desc_.type().type(); }
void VarDesc::SetType(proto::VarDesc::VarType type) { desc_.set_type(type); } void VarDesc::SetType(proto::VarType::Type type) {
desc_.mutable_type()->set_type(type);
}
void VarDesc::SetShape(const std::vector<int64_t> &dims) { void VarDesc::SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims()); VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
} }
void VarDesc::SetTensorDescNum(size_t num) { void VarDesc::SetTensorDescNum(size_t num) {
switch (desc_.type()) { switch (desc_.type().type()) {
case proto::VarDesc::READER: { case proto::VarType::READER: {
auto *lod_tensors_ptr = desc_.mutable_reader()->mutable_lod_tensor(); auto *lod_tensors_ptr =
desc_.mutable_type()->mutable_reader()->mutable_lod_tensor();
lod_tensors_ptr->Clear(); lod_tensors_ptr->Clear();
for (size_t i = 0; i < num; ++i) { for (size_t i = 0; i < num; ++i) {
lod_tensors_ptr->Add(); lod_tensors_ptr->Add();
...@@ -44,9 +47,9 @@ void VarDesc::SetTensorDescNum(size_t num) { ...@@ -44,9 +47,9 @@ void VarDesc::SetTensorDescNum(size_t num) {
} }
size_t VarDesc::GetTensorDescNum() const { size_t VarDesc::GetTensorDescNum() const {
switch (desc_.type()) { switch (desc_.type().type()) {
case proto::VarDesc::READER: case proto::VarType::READER:
return desc_.reader().lod_tensor_size(); return desc_.type().reader().lod_tensor_size();
break; break;
default: default:
PADDLE_THROW( PADDLE_THROW(
...@@ -64,7 +67,7 @@ void VarDesc::SetShapes( ...@@ -64,7 +67,7 @@ void VarDesc::SetShapes(
<< "). The Reader is going to be reinitialized."; << "). The Reader is going to be reinitialized.";
SetTensorDescNum(multiple_dims.size()); SetTensorDescNum(multiple_dims.size());
} }
std::vector<proto::TensorDesc *> tensors = mutable_tensor_descs(); std::vector<proto::VarType::TensorDesc *> tensors = mutable_tensor_descs();
for (size_t i = 0; i < multiple_dims.size(); ++i) { for (size_t i = 0; i < multiple_dims.size(); ++i) {
VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims()); VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
} }
...@@ -75,7 +78,7 @@ std::vector<int64_t> VarDesc::GetShape() const { ...@@ -75,7 +78,7 @@ std::vector<int64_t> VarDesc::GetShape() const {
} }
std::vector<std::vector<int64_t>> VarDesc::GetShapes() const { std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
std::vector<proto::TensorDesc> descs = tensor_descs(); std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
std::vector<std::vector<int64_t>> res; std::vector<std::vector<int64_t>> res;
res.reserve(descs.size()); res.reserve(descs.size());
for (const auto &tensor_desc : descs) { for (const auto &tensor_desc : descs) {
...@@ -98,7 +101,8 @@ void VarDesc::SetDataTypes( ...@@ -98,7 +101,8 @@ void VarDesc::SetDataTypes(
<< "). The Reader is going to be reinitialized."; << "). The Reader is going to be reinitialized.";
SetTensorDescNum(multiple_data_type.size()); SetTensorDescNum(multiple_data_type.size());
} }
std::vector<proto::TensorDesc *> tensor_descs = mutable_tensor_descs(); std::vector<proto::VarType::TensorDesc *> tensor_descs =
mutable_tensor_descs();
for (size_t i = 0; i < multiple_data_type.size(); ++i) { for (size_t i = 0; i < multiple_data_type.size(); ++i) {
tensor_descs[i]->set_data_type(multiple_data_type[i]); tensor_descs[i]->set_data_type(multiple_data_type[i]);
} }
...@@ -109,7 +113,7 @@ proto::DataType VarDesc::GetDataType() const { ...@@ -109,7 +113,7 @@ proto::DataType VarDesc::GetDataType() const {
} }
std::vector<proto::DataType> VarDesc::GetDataTypes() const { std::vector<proto::DataType> VarDesc::GetDataTypes() const {
std::vector<proto::TensorDesc> descs = tensor_descs(); std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
std::vector<proto::DataType> res; std::vector<proto::DataType> res;
res.reserve(descs.size()); res.reserve(descs.size());
for (const auto &tensor_desc : descs) { for (const auto &tensor_desc : descs) {
...@@ -119,12 +123,12 @@ std::vector<proto::DataType> VarDesc::GetDataTypes() const { ...@@ -119,12 +123,12 @@ std::vector<proto::DataType> VarDesc::GetDataTypes() const {
} }
void VarDesc::SetLoDLevel(int32_t lod_level) { void VarDesc::SetLoDLevel(int32_t lod_level) {
switch (desc_.type()) { switch (desc_.type().type()) {
case proto::VarDesc::LOD_TENSOR: case proto::VarType::LOD_TENSOR:
desc_.mutable_lod_tensor()->set_lod_level(lod_level); desc_.mutable_type()->mutable_lod_tensor()->set_lod_level(lod_level);
break; break;
case proto::VarDesc::LOD_TENSOR_ARRAY: case proto::VarType::LOD_TENSOR_ARRAY:
desc_.mutable_tensor_array()->set_lod_level(lod_level); desc_.mutable_type()->mutable_tensor_array()->set_lod_level(lod_level);
break; break;
default: default:
PADDLE_THROW( PADDLE_THROW(
...@@ -142,10 +146,11 @@ void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) { ...@@ -142,10 +146,11 @@ void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
<< "). The Reader is going to be reinitialized."; << "). The Reader is going to be reinitialized.";
SetTensorDescNum(multiple_lod_level.size()); SetTensorDescNum(multiple_lod_level.size());
} }
switch (desc_.type()) { switch (desc_.type().type()) {
case proto::VarDesc::READER: { case proto::VarType::READER: {
size_t i = 0; size_t i = 0;
for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) { for (auto &lod_tensor :
*desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
lod_tensor.set_lod_level(multiple_lod_level[i++]); lod_tensor.set_lod_level(multiple_lod_level[i++]);
} }
} break; } break;
...@@ -157,11 +162,11 @@ void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) { ...@@ -157,11 +162,11 @@ void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
} }
int32_t VarDesc::GetLoDLevel() const { int32_t VarDesc::GetLoDLevel() const {
switch (desc_.type()) { switch (desc_.type().type()) {
case proto::VarDesc::LOD_TENSOR: case proto::VarType::LOD_TENSOR:
return desc_.lod_tensor().lod_level(); return desc_.type().lod_tensor().lod_level();
case proto::VarDesc::LOD_TENSOR_ARRAY: case proto::VarType::LOD_TENSOR_ARRAY:
return desc_.tensor_array().lod_level(); return desc_.type().tensor_array().lod_level();
default: default:
PADDLE_THROW( PADDLE_THROW(
"Getting 'lod_level' is not supported by the type of var %s.", "Getting 'lod_level' is not supported by the type of var %s.",
...@@ -171,10 +176,10 @@ int32_t VarDesc::GetLoDLevel() const { ...@@ -171,10 +176,10 @@ int32_t VarDesc::GetLoDLevel() const {
std::vector<int32_t> VarDesc::GetLoDLevels() const { std::vector<int32_t> VarDesc::GetLoDLevels() const {
std::vector<int32_t> res; std::vector<int32_t> res;
switch (desc_.type()) { switch (desc_.type().type()) {
case proto::VarDesc::READER: case proto::VarType::READER:
res.reserve(desc_.reader().lod_tensor_size()); res.reserve(desc_.type().reader().lod_tensor_size());
for (auto &lod_tensor : desc_.reader().lod_tensor()) { for (auto &lod_tensor : desc_.type().reader().lod_tensor()) {
res.push_back(lod_tensor.lod_level()); res.push_back(lod_tensor.lod_level());
} }
return res; return res;
...@@ -186,15 +191,16 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const { ...@@ -186,15 +191,16 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
} }
} }
const proto::TensorDesc &VarDesc::tensor_desc() const { const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set."); PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
switch (desc_.type()) { PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
case proto::VarDesc::SELECTED_ROWS: switch (desc_.type().type()) {
return desc_.selected_rows(); case proto::VarType::SELECTED_ROWS:
case proto::VarDesc::LOD_TENSOR: return desc_.type().selected_rows();
return desc_.lod_tensor().tensor(); case proto::VarType::LOD_TENSOR:
case proto::VarDesc::LOD_TENSOR_ARRAY: return desc_.type().lod_tensor().tensor();
return desc_.tensor_array().tensor(); case proto::VarType::LOD_TENSOR_ARRAY:
return desc_.type().tensor_array().tensor();
default: default:
PADDLE_THROW( PADDLE_THROW(
"Getting 'tensor_desc' is not supported by the type of var %s.", "Getting 'tensor_desc' is not supported by the type of var %s.",
...@@ -202,13 +208,13 @@ const proto::TensorDesc &VarDesc::tensor_desc() const { ...@@ -202,13 +208,13 @@ const proto::TensorDesc &VarDesc::tensor_desc() const {
} }
} }
std::vector<proto::TensorDesc> VarDesc::tensor_descs() const { std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set."); PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
std::vector<proto::TensorDesc> res; std::vector<proto::VarType::TensorDesc> res;
res.reserve(GetTensorDescNum()); res.reserve(GetTensorDescNum());
switch (desc_.type()) { switch (desc_.type().type()) {
case proto::VarDesc::READER: case proto::VarType::READER:
for (const auto &lod_tensor : desc_.reader().lod_tensor()) { for (const auto &lod_tensor : desc_.type().reader().lod_tensor()) {
res.push_back(lod_tensor.tensor()); res.push_back(lod_tensor.tensor());
} }
return res; return res;
...@@ -220,15 +226,16 @@ std::vector<proto::TensorDesc> VarDesc::tensor_descs() const { ...@@ -220,15 +226,16 @@ std::vector<proto::TensorDesc> VarDesc::tensor_descs() const {
} }
} }
proto::TensorDesc *VarDesc::mutable_tensor_desc() { proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set."); PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
switch (desc_.type()) { PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
case proto::VarDesc::SELECTED_ROWS: switch (desc_.type().type()) {
return desc_.mutable_selected_rows(); case proto::VarType::SELECTED_ROWS:
case proto::VarDesc::LOD_TENSOR: return desc_.mutable_type()->mutable_selected_rows();
return desc_.mutable_lod_tensor()->mutable_tensor(); case proto::VarType::LOD_TENSOR:
case proto::VarDesc::LOD_TENSOR_ARRAY: return desc_.mutable_type()->mutable_lod_tensor()->mutable_tensor();
return desc_.mutable_tensor_array()->mutable_tensor(); case proto::VarType::LOD_TENSOR_ARRAY:
return desc_.mutable_type()->mutable_tensor_array()->mutable_tensor();
default: default:
PADDLE_THROW( PADDLE_THROW(
"Getting 'mutable_tensor_desc' is not supported by the type of var " "Getting 'mutable_tensor_desc' is not supported by the type of var "
...@@ -237,13 +244,15 @@ proto::TensorDesc *VarDesc::mutable_tensor_desc() { ...@@ -237,13 +244,15 @@ proto::TensorDesc *VarDesc::mutable_tensor_desc() {
} }
} }
std::vector<proto::TensorDesc *> VarDesc::mutable_tensor_descs() { std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set."); PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
std::vector<proto::TensorDesc *> res; PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
std::vector<proto::VarType::TensorDesc *> res;
res.reserve(GetTensorDescNum()); res.reserve(GetTensorDescNum());
switch (desc_.type()) { switch (desc_.type().type()) {
case proto::VarDesc::READER: case proto::VarType::READER:
for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) { for (auto &lod_tensor :
*desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
res.push_back(lod_tensor.mutable_tensor()); res.push_back(lod_tensor.mutable_tensor());
} }
return res; return res;
......
...@@ -57,7 +57,7 @@ class VarDesc { ...@@ -57,7 +57,7 @@ class VarDesc {
public: public:
explicit VarDesc(const std::string &name) { explicit VarDesc(const std::string &name) {
desc_.set_name(name); desc_.set_name(name);
desc_.set_type(proto::VarDesc::LOD_TENSOR); desc_.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
} }
explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {} explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {}
...@@ -96,19 +96,19 @@ class VarDesc { ...@@ -96,19 +96,19 @@ class VarDesc {
std::vector<int32_t> GetLoDLevels() const; std::vector<int32_t> GetLoDLevels() const;
proto::VarDesc::VarType GetType() const; proto::VarType::Type GetType() const;
void SetType(proto::VarDesc::VarType type); void SetType(proto::VarType::Type type);
bool Persistable() const { return desc_.persistable(); } bool Persistable() const { return desc_.persistable(); }
void SetPersistable(bool persistable) { desc_.set_persistable(persistable); } void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }
private: private:
const proto::TensorDesc &tensor_desc() const; const proto::VarType::TensorDesc &tensor_desc() const;
std::vector<proto::TensorDesc> tensor_descs() const; std::vector<proto::VarType::TensorDesc> tensor_descs() const;
proto::TensorDesc *mutable_tensor_desc(); proto::VarType::TensorDesc *mutable_tensor_desc();
std::vector<proto::TensorDesc *> mutable_tensor_descs(); std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs();
proto::VarDesc desc_; proto::VarDesc desc_;
}; };
......
...@@ -23,17 +23,17 @@ limitations under the License. */ ...@@ -23,17 +23,17 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
inline proto::VarDesc::VarType ToVarType(std::type_index type) { inline proto::VarType::Type ToVarType(std::type_index type) {
if (type.hash_code() == typeid(LoDTensor).hash_code()) { if (type.hash_code() == typeid(LoDTensor).hash_code()) {
return proto::VarDesc_VarType_LOD_TENSOR; return proto::VarType_Type_LOD_TENSOR;
} else if (type.hash_code() == typeid(LoDRankTable).hash_code()) { } else if (type.hash_code() == typeid(LoDRankTable).hash_code()) {
return proto::VarDesc_VarType_LOD_RANK_TABLE; return proto::VarType_Type_LOD_RANK_TABLE;
} else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) { } else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) {
return proto::VarDesc_VarType_LOD_TENSOR_ARRAY; return proto::VarType_Type_LOD_TENSOR_ARRAY;
} else if (type.hash_code() == typeid(SelectedRows).hash_code()) { } else if (type.hash_code() == typeid(SelectedRows).hash_code()) {
return proto::VarDesc_VarType_SELECTED_ROWS; return proto::VarType_Type_SELECTED_ROWS;
} else if (type.hash_code() == typeid(ReaderHolder).hash_code()) { } else if (type.hash_code() == typeid(ReaderHolder).hash_code()) {
return proto::VarDesc_VarType_READER; return proto::VarType_Type_READER;
} else { } else {
PADDLE_THROW("ToVarType:Unsupported type %s", type.name()); PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
} }
...@@ -42,19 +42,19 @@ inline proto::VarDesc::VarType ToVarType(std::type_index type) { ...@@ -42,19 +42,19 @@ inline proto::VarDesc::VarType ToVarType(std::type_index type) {
template <typename Visitor> template <typename Visitor>
inline void VisitVarType(const framework::Variable& var, Visitor visitor) { inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
switch (ToVarType(var.Type())) { switch (ToVarType(var.Type())) {
case proto::VarDesc_VarType_LOD_TENSOR: case proto::VarType_Type_LOD_TENSOR:
visitor(var.Get<LoDTensor>()); visitor(var.Get<LoDTensor>());
return; return;
case proto::VarDesc_VarType_LOD_RANK_TABLE: case proto::VarType_Type_LOD_RANK_TABLE:
visitor(var.Get<LoDRankTable>()); visitor(var.Get<LoDRankTable>());
return; return;
case proto::VarDesc_VarType_LOD_TENSOR_ARRAY: case proto::VarType_Type_LOD_TENSOR_ARRAY:
visitor(var.Get<LoDTensorArray>()); visitor(var.Get<LoDTensorArray>());
return; return;
case proto::VarDesc_VarType_SELECTED_ROWS: case proto::VarType_Type_SELECTED_ROWS:
visitor(var.Get<SelectedRows>()); visitor(var.Get<SelectedRows>());
return; return;
case proto::VarDesc_VarType_READER: case proto::VarType_Type_READER:
visitor(var.Get<ReaderHolder>()); visitor(var.Get<ReaderHolder>());
return; return;
default: default:
......
...@@ -35,14 +35,14 @@ class SumOpVarTypeInference : public VarTypeInference { ...@@ -35,14 +35,14 @@ class SumOpVarTypeInference : public VarTypeInference {
public: public:
void operator()(const OpDesc &op_desc, BlockDesc *block) const override { void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
auto &inputs = op_desc.Input("X"); auto &inputs = op_desc.Input("X");
auto default_var_type = proto::VarDesc::SELECTED_ROWS; auto default_var_type = proto::VarType::SELECTED_ROWS;
bool any_input_is_lod_tensor = std::any_of( bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string &name) { inputs.begin(), inputs.end(), [block](const std::string &name) {
return block->Var(name)->GetType() == proto::VarDesc::LOD_TENSOR; return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
}); });
if (any_input_is_lod_tensor) { if (any_input_is_lod_tensor) {
default_var_type = proto::VarDesc::LOD_TENSOR; default_var_type = proto::VarType::LOD_TENSOR;
} }
auto out_var_name = op_desc.Output("Out").front(); auto out_var_name = op_desc.Output("Out").front();
...@@ -67,19 +67,19 @@ TEST(InferVarType, sum_op) { ...@@ -67,19 +67,19 @@ TEST(InferVarType, sum_op) {
op->SetInput("X", {"test_a", "test_b", "test_c"}); op->SetInput("X", {"test_a", "test_b", "test_c"});
op->SetOutput("Out", {"test_out"}); op->SetOutput("Out", {"test_out"});
prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarDesc::SELECTED_ROWS); prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarDesc::SELECTED_ROWS); prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarDesc::SELECTED_ROWS); prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_out"); prog.MutableBlock(0)->Var("test_out");
op->InferVarType(prog.MutableBlock(0)); op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(proto::VarDesc::SELECTED_ROWS, ASSERT_EQ(proto::VarType::SELECTED_ROWS,
prog.MutableBlock(0)->Var("test_out")->GetType()); prog.MutableBlock(0)->Var("test_out")->GetType());
prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarDesc::LOD_TENSOR); prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR);
op->InferVarType(prog.MutableBlock(0)); op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(proto::VarDesc::LOD_TENSOR, ASSERT_EQ(proto::VarType::LOD_TENSOR,
prog.MutableBlock(0)->Var("test_out")->GetType()); prog.MutableBlock(0)->Var("test_out")->GetType());
} }
...@@ -90,14 +90,14 @@ TEST(InferVarType, sum_op_without_infer_var_type) { ...@@ -90,14 +90,14 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
op->SetInput("X", {"test2_a", "test2_b", "test2_c"}); op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
op->SetOutput("Out", {"test2_out"}); op->SetOutput("Out", {"test2_out"});
prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarDesc::SELECTED_ROWS); prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarDesc::SELECTED_ROWS); prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarDesc::SELECTED_ROWS); prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test2_out"); prog.MutableBlock(0)->Var("test2_out");
op->InferVarType(prog.MutableBlock(0)); op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(proto::VarDesc_VarType_LOD_TENSOR, ASSERT_EQ(proto::VarType_Type_LOD_TENSOR,
prog.MutableBlock(0)->Var("test2_out")->GetType()); prog.MutableBlock(0)->Var("test2_out")->GetType());
} }
......
...@@ -115,8 +115,8 @@ class AssignInferShape : public framework::InferShapeBase { ...@@ -115,8 +115,8 @@ class AssignInferShape : public framework::InferShapeBase {
void operator()(framework::InferShapeContext *context) const override { void operator()(framework::InferShapeContext *context) const override {
if (context->HasInput("X")) { if (context->HasInput("X")) {
auto type = context->GetInputsVarType("X")[0]; auto type = context->GetInputsVarType("X")[0];
if (type == framework::proto::VarDesc_VarType_SELECTED_ROWS || if (type == framework::proto::VarType::SELECTED_ROWS ||
type == framework::proto::VarDesc_VarType_LOD_TENSOR) { type == framework::proto::VarType::LOD_TENSOR) {
context->SetOutputDim("Out", context->GetInputDim("X")); context->SetOutputDim("Out", context->GetInputDim("X"));
} }
} }
......
...@@ -128,10 +128,10 @@ class BeamSearchDecodeInferVarType : public framework::VarTypeInference { ...@@ -128,10 +128,10 @@ class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
void operator()(const framework::OpDesc& op_desc, void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override { framework::BlockDesc* block) const override {
for (auto& o : op_desc.Output("SentenceIds")) { for (auto& o : op_desc.Output("SentenceIds")) {
block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR); block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
} }
for (auto& o : op_desc.Output("SentenceScores")) { for (auto& o : op_desc.Output("SentenceScores")) {
block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR); block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
} }
} }
}; };
......
...@@ -240,10 +240,10 @@ class BeamSearchInferVarType : public framework::VarTypeInference { ...@@ -240,10 +240,10 @@ class BeamSearchInferVarType : public framework::VarTypeInference {
void operator()(const framework::OpDesc &op_desc, void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override { framework::BlockDesc *block) const override {
for (auto &o : op_desc.Output("selected_ids")) { for (auto &o : op_desc.Output("selected_ids")) {
block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR); block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
} }
for (auto &o : op_desc.Output("selected_scores")) { for (auto &o : op_desc.Output("selected_scores")) {
block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR); block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
} }
} }
}; };
......
...@@ -84,7 +84,7 @@ class CreateFileReaderInferVarType : public framework::VarTypeInference { ...@@ -84,7 +84,7 @@ class CreateFileReaderInferVarType : public framework::VarTypeInference {
framework::BlockDesc* block) const override { framework::BlockDesc* block) const override {
std::string reader_name = op_desc.Output("Out")[0]; std::string reader_name = op_desc.Output("Out")[0];
framework::VarDesc* reader = block->FindVarRecursive(reader_name); framework::VarDesc* reader = block->FindVarRecursive(reader_name);
reader->SetType(framework::proto::VarDesc::READER); reader->SetType(framework::proto::VarType::READER);
} }
}; };
...@@ -97,7 +97,7 @@ class CreateDecoratedReaderInferVarType : public framework::VarTypeInference { ...@@ -97,7 +97,7 @@ class CreateDecoratedReaderInferVarType : public framework::VarTypeInference {
framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name); framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name);
std::string out_reader_name = op_desc.Output("Out")[0]; std::string out_reader_name = op_desc.Output("Out")[0];
framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name); framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name);
out_reader->SetType(framework::proto::VarDesc::READER); out_reader->SetType(framework::proto::VarType::READER);
out_reader->SetDataTypes(in_reader->GetDataTypes()); out_reader->SetDataTypes(in_reader->GetDataTypes());
} }
}; };
...@@ -147,7 +147,7 @@ class CreateRandomDataGeneratorOpMaker ...@@ -147,7 +147,7 @@ class CreateRandomDataGeneratorOpMaker
AddComment(R"DOC( AddComment(R"DOC(
CreateRandomDataGenerator Operator CreateRandomDataGenerator Operator
This Op creates a random reader. This Op creates a random reader.
The reader generates random data instead of really reading from files. The reader generates random data instead of really reading from files.
Generated data follow an uniform distribution between 'min' and 'max'. Generated data follow an uniform distribution between 'min' and 'max'.
)DOC"); )DOC");
...@@ -183,7 +183,7 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -183,7 +183,7 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
CreateShuffleReader Operator CreateShuffleReader Operator
A shuffle reader takes another reader as its 'underlying reader' A shuffle reader takes another reader as its 'underlying reader'
and yields the underlying reader's outputs in a shuffled order. and yields the underlying reader's outputs in a shuffled order.
)DOC"); )DOC");
} }
}; };
...@@ -218,8 +218,8 @@ class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -218,8 +218,8 @@ class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC( AddComment(R"DOC(
CreateBatchReader Operator CreateBatchReader Operator
A batch reader takes another reader as its 'underlying reader', A batch reader takes another reader as its 'underlying reader',
gathers the underlying reader's outputs and then yields them in batches. gathers the underlying reader's outputs and then yields them in batches.
)DOC"); )DOC");
} }
}; };
......
...@@ -24,11 +24,11 @@ void SerializeToMessage(const std::string& name, const framework::Variable* var, ...@@ -24,11 +24,11 @@ void SerializeToMessage(const std::string& name, const framework::Variable* var,
msg->set_varname(name); msg->set_varname(name);
std::ostringstream oss; std::ostringstream oss;
switch (framework::ToVarType(var->Type())) { switch (framework::ToVarType(var->Type())) {
case framework::proto::VarDesc_VarType_LOD_TENSOR: case framework::proto::VarType_Type_LOD_TENSOR:
msg->set_type(sendrecv::VarType::LOD_TENSOR); msg->set_type(sendrecv::VarType::LOD_TENSOR);
framework::SerializeToStream(oss, var->Get<framework::LoDTensor>(), ctx); framework::SerializeToStream(oss, var->Get<framework::LoDTensor>(), ctx);
break; break;
case framework::proto::VarDesc_VarType_SELECTED_ROWS: case framework::proto::VarType_Type_SELECTED_ROWS:
msg->set_type(sendrecv::VarType::SELECTED_ROWS); msg->set_type(sendrecv::VarType::SELECTED_ROWS);
framework::SerializeToStream(oss, var->Get<framework::SelectedRows>(), framework::SerializeToStream(oss, var->Get<framework::SelectedRows>(),
ctx); ctx);
......
...@@ -98,7 +98,7 @@ class GetPlacesInferVarType : public framework::VarTypeInference { ...@@ -98,7 +98,7 @@ class GetPlacesInferVarType : public framework::VarTypeInference {
framework::BlockDesc *block) const override { framework::BlockDesc *block) const override {
for (auto &o_name : op_desc.Output("Out")) { for (auto &o_name : op_desc.Output("Out")) {
block->FindRecursiveOrCreateVar(o_name).SetType( block->FindRecursiveOrCreateVar(o_name).SetType(
framework::proto::VarDesc::PLACE_LIST); framework::proto::VarType::PLACE_LIST);
} }
} }
}; };
......
...@@ -69,7 +69,7 @@ class LoDRankTableInferVarType : public framework::VarTypeInference { ...@@ -69,7 +69,7 @@ class LoDRankTableInferVarType : public framework::VarTypeInference {
framework::BlockDesc *block) const override { framework::BlockDesc *block) const override {
for (auto &o : op_desc.Output("Out")) { for (auto &o : op_desc.Output("Out")) {
block->FindRecursiveOrCreateVar(o).SetType( block->FindRecursiveOrCreateVar(o).SetType(
framework::proto::VarDesc::LOD_RANK_TABLE); framework::proto::VarType::LOD_RANK_TABLE);
} }
} }
}; };
......
...@@ -138,7 +138,7 @@ class LoDTensorToArrayInferVarType : public framework::VarTypeInference { ...@@ -138,7 +138,7 @@ class LoDTensorToArrayInferVarType : public framework::VarTypeInference {
void operator()(const framework::OpDesc &op_desc, void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override { framework::BlockDesc *block) const override {
for (auto &out_var : op_desc.Output("Out")) { for (auto &out_var : op_desc.Output("Out")) {
block->Var(out_var)->SetType(framework::proto::VarDesc::LOD_TENSOR_ARRAY); block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
} }
} }
}; };
......
...@@ -123,11 +123,11 @@ class LookupTableOpGradVarTypeInference : public framework::VarTypeInference { ...@@ -123,11 +123,11 @@ class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W") VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
<< " is set to SelectedRows"; << " is set to SelectedRows";
block->Var(out_var_name) block->Var(out_var_name)
->SetType(framework::proto::VarDesc::SELECTED_ROWS); ->SetType(framework::proto::VarType::SELECTED_ROWS);
} else { } else {
VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W") VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
<< " is set to LoDTensor"; << " is set to LoDTensor";
block->Var(out_var_name)->SetType(framework::proto::VarDesc::LOD_TENSOR); block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
} }
} }
}; };
......
...@@ -45,7 +45,7 @@ class ReadInferVarType : public framework::VarTypeInference { ...@@ -45,7 +45,7 @@ class ReadInferVarType : public framework::VarTypeInference {
PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size()); PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
for (size_t i = 0; i < dtypes.size(); ++i) { for (size_t i = 0; i < dtypes.size(); ++i) {
framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]); framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]);
out.SetType(framework::proto::VarDesc::LOD_TENSOR); out.SetType(framework::proto::VarType::LOD_TENSOR);
out.SetDataType(dtypes[i]); out.SetDataType(dtypes[i]);
} }
} }
......
...@@ -29,7 +29,7 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -29,7 +29,7 @@ class SumOp : public framework::OperatorWithKernel {
"Output(Out) of SumOp should not be null."); "Output(Out) of SumOp should not be null.");
if (ctx->IsRuntime() && if (ctx->IsRuntime() &&
ctx->GetOutputsVarType("Out")[0] == ctx->GetOutputsVarType("Out")[0] ==
framework::proto::VarDesc::LOD_TENSOR_ARRAY) { framework::proto::VarType::LOD_TENSOR_ARRAY) {
return; // skip runtime infershape when is tensor array; return; // skip runtime infershape when is tensor array;
} }
...@@ -118,7 +118,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference { ...@@ -118,7 +118,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
void operator()(const framework::OpDesc& op_desc, void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override { framework::BlockDesc* block) const override {
auto& inputs = op_desc.Input("X"); auto& inputs = op_desc.Input("X");
auto var_type = framework::proto::VarDesc::SELECTED_ROWS; auto var_type = framework::proto::VarType::SELECTED_ROWS;
for (auto& name : op_desc.Input("X")) { for (auto& name : op_desc.Input("X")) {
VLOG(10) << name << " " VLOG(10) << name << " "
...@@ -128,12 +128,12 @@ class SumOpVarTypeInference : public framework::VarTypeInference { ...@@ -128,12 +128,12 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
bool any_input_is_lod_tensor = std::any_of( bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string& name) { inputs.begin(), inputs.end(), [block](const std::string& name) {
return block->FindRecursiveOrCreateVar(name).GetType() == return block->FindRecursiveOrCreateVar(name).GetType() ==
framework::proto::VarDesc::LOD_TENSOR; framework::proto::VarType::LOD_TENSOR;
}); });
auto is_tensor_array = [block](const std::string& name) { auto is_tensor_array = [block](const std::string& name) {
return block->FindRecursiveOrCreateVar(name).GetType() == return block->FindRecursiveOrCreateVar(name).GetType() ==
framework::proto::VarDesc::LOD_TENSOR_ARRAY; framework::proto::VarType::LOD_TENSOR_ARRAY;
}; };
bool any_input_is_tensor_array = bool any_input_is_tensor_array =
...@@ -151,9 +151,9 @@ class SumOpVarTypeInference : public framework::VarTypeInference { ...@@ -151,9 +151,9 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
PADDLE_ENFORCE(all_inputs_are_tensor_array, PADDLE_ENFORCE(all_inputs_are_tensor_array,
"Not all inputs are tensor array:\n%s", os.str()); "Not all inputs are tensor array:\n%s", os.str());
} }
var_type = framework::proto::VarDesc::LOD_TENSOR_ARRAY; var_type = framework::proto::VarType::LOD_TENSOR_ARRAY;
} else if (any_input_is_lod_tensor) { } else if (any_input_is_lod_tensor) {
var_type = framework::proto::VarDesc::LOD_TENSOR; var_type = framework::proto::VarType::LOD_TENSOR;
} }
auto out_var_name = op_desc.Output("Out").front(); auto out_var_name = op_desc.Output("Out").front();
......
...@@ -108,7 +108,7 @@ class WriteToArrayInferVarType : public framework::VarTypeInference { ...@@ -108,7 +108,7 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
auto out_name = op_desc.Output("Out")[0]; auto out_name = op_desc.Output("Out")[0];
VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY"; VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
auto &out = block->FindRecursiveOrCreateVar(out_name); auto &out = block->FindRecursiveOrCreateVar(out_name);
out.SetType(framework::proto::VarDesc::LOD_TENSOR_ARRAY); out.SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
auto *x = block->FindVarRecursive(x_name); auto *x = block->FindVarRecursive(x_name);
if (x != nullptr) { if (x != nullptr) {
out.SetDataType(x->GetDataType()); out.SetDataType(x->GetDataType());
......
...@@ -330,10 +330,10 @@ class WhileGradOpShapeInference : public framework::InferShapeBase { ...@@ -330,10 +330,10 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
continue; continue;
} }
auto dims = ctx->GetInputsElementDim(kX, i); auto dims = ctx->GetInputsElementDim(kX, i);
if (var_types[i] == framework::proto::VarDesc::LOD_TENSOR) { if (var_types[i] == framework::proto::VarType::LOD_TENSOR) {
names_to_set.push_back(pg_names[i]); names_to_set.push_back(pg_names[i]);
dims_to_set.push_back(dims); dims_to_set.push_back(dims);
} else if (var_types[i] == framework::proto::VarDesc::LOD_TENSOR_ARRAY) { } else if (var_types[i] == framework::proto::VarType::LOD_TENSOR_ARRAY) {
// not sure how to set the dim of LOD_TENSOR_ARRAY // not sure how to set the dim of LOD_TENSOR_ARRAY
names_to_set.push_back(pg_names[i]); names_to_set.push_back(pg_names[i]);
dims_to_set.push_back(dims); dims_to_set.push_back(dims);
......
...@@ -232,16 +232,16 @@ void BindVarDsec(py::module &m) { ...@@ -232,16 +232,16 @@ void BindVarDsec(py::module &m) {
.def("persistable", &VarDesc::Persistable) .def("persistable", &VarDesc::Persistable)
.def("set_persistable", &VarDesc::SetPersistable); .def("set_persistable", &VarDesc::SetPersistable);
py::enum_<proto::VarDesc::VarType>(var_desc, "VarType", "") py::enum_<proto::VarType::Type>(var_desc, "VarType", "")
.value("LOD_TENSOR", proto::VarDesc::LOD_TENSOR) .value("LOD_TENSOR", proto::VarType::LOD_TENSOR)
.value("SELECTED_ROWS", proto::VarDesc::SELECTED_ROWS) .value("SELECTED_ROWS", proto::VarType::SELECTED_ROWS)
.value("FEED_MINIBATCH", proto::VarDesc::FEED_MINIBATCH) .value("FEED_MINIBATCH", proto::VarType::FEED_MINIBATCH)
.value("FETCH_LIST", proto::VarDesc::FETCH_LIST) .value("FETCH_LIST", proto::VarType::FETCH_LIST)
.value("STEP_SCOPES", proto::VarDesc::STEP_SCOPES) .value("STEP_SCOPES", proto::VarType::STEP_SCOPES)
.value("LOD_RANK_TABLE", proto::VarDesc::LOD_RANK_TABLE) .value("LOD_RANK_TABLE", proto::VarType::LOD_RANK_TABLE)
.value("LOD_TENSOR_ARRAY", proto::VarDesc::LOD_TENSOR_ARRAY) .value("LOD_TENSOR_ARRAY", proto::VarType::LOD_TENSOR_ARRAY)
.value("PLACE_LIST", proto::VarDesc::PLACE_LIST) .value("PLACE_LIST", proto::VarType::PLACE_LIST)
.value("READER", proto::VarDesc::READER); .value("READER", proto::VarType::READER);
} }
void BindOpDesc(py::module &m) { void BindOpDesc(py::module &m) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册