提交 fcadb452 编写于 作者: A Abhinav Arora 提交者: Yi Wang

Separate VarType from VarDesc in framework.proto and fix all related compiler errors (#8414)

* Refine Type system

* Fixing type inference

* Fixed create_reader_op.cc

* Fix var_desc.h

* Fixed executor.cc

* Fix shape_inference.h

* Fixed create_reader_op.cc

* Fix tensor_util.h

* Fixed var_type_inference_test.cc

* Fix shape_inference.cc

* Fixed sum_op.c

* Fixed read_op.cc

* Fix var_type.h

* Fixed beam_search_decode_op.cc

* sendrecvop_utils.cc

* Fix operator.cc

* Fixed lookup_table_op.cc

* Fixed op_desc.cc

* Fixed get_places_op.cc

* Fixed lod_rank_table_op.cc

* Fixed beam_search_op.cc

* Fix var_desc.cc

* Fixed lod_tensor_to_array_op.cc

* Fixed while_op.cc

* Fix program_desc_test.cc

* tensor_array_read_write_op.cc

* Fix assign_op.cc

* Fix executor.cc

* Fix protobuf.cc

* Fix protobuf.cc
上级 f82fa64a
......@@ -36,24 +36,24 @@ namespace framework {
Executor::Executor(const platform::Place& place) : place_(place) {}
static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
if (var_type == proto::VarDesc::LOD_TENSOR) {
static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) {
var->GetMutable<LoDTensor>();
} else if (var_type == proto::VarDesc::SELECTED_ROWS) {
} else if (var_type == proto::VarType::SELECTED_ROWS) {
var->GetMutable<SelectedRows>();
} else if (var_type == proto::VarDesc::FEED_MINIBATCH) {
} else if (var_type == proto::VarType::FEED_MINIBATCH) {
var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarDesc::FETCH_LIST) {
} else if (var_type == proto::VarType::FETCH_LIST) {
var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarDesc::STEP_SCOPES) {
} else if (var_type == proto::VarType::STEP_SCOPES) {
var->GetMutable<std::vector<framework::Scope>>();
} else if (var_type == proto::VarDesc::LOD_RANK_TABLE) {
} else if (var_type == proto::VarType::LOD_RANK_TABLE) {
var->GetMutable<LoDRankTable>();
} else if (var_type == proto::VarDesc::LOD_TENSOR_ARRAY) {
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
var->GetMutable<LoDTensorArray>();
} else if (var_type == proto::VarDesc::PLACE_LIST) {
} else if (var_type == proto::VarType::PLACE_LIST) {
var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarDesc::READER) {
} else if (var_type == proto::VarType::READER) {
var->GetMutable<ReaderHolder>();
} else {
PADDLE_THROW(
......@@ -182,7 +182,7 @@ static bool has_feed_operators(
auto var = block->FindVar(feed_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
feed_holder_name);
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarDesc::FEED_MINIBATCH,
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
"'%s' variable should be 'FEED_MINIBATCH' type",
feed_holder_name);
}
......@@ -222,7 +222,7 @@ static bool has_fetch_operators(
auto var = block->FindVar(fetch_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
fetch_holder_name);
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarDesc::FETCH_LIST,
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
"'%s' variable should be 'FETCH_LIST' type",
fetch_holder_name);
}
......@@ -241,7 +241,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) {
// create feed_holder variable
auto* feed_holder = global_block->Var(feed_holder_name);
feed_holder->SetType(proto::VarDesc::FEED_MINIBATCH);
feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
feed_holder->SetPersistable(true);
int i = 0;
......@@ -274,7 +274,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) {
// create fetch_holder variable
auto* fetch_holder = global_block->Var(fetch_holder_name);
fetch_holder->SetType(proto::VarDesc::FETCH_LIST);
fetch_holder->SetType(proto::VarType::FETCH_LIST);
fetch_holder->SetPersistable(true);
int i = 0;
......
......@@ -101,25 +101,8 @@ enum DataType {
FP64 = 6;
}
message TensorDesc {
required DataType data_type = 1;
repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
}
message LoDTensorDesc {
required TensorDesc tensor = 1;
optional int32 lod_level = 2 [ default = 0 ];
}
message LoDTensorArrayDesc {
required TensorDesc tensor = 1;
optional int32 lod_level = 2 [ default = 0 ];
}
message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; }
message VarDesc {
enum VarType {
message VarType {
enum Type {
LOD_TENSOR = 1;
SELECTED_ROWS = 2;
FEED_MINIBATCH = 3;
......@@ -130,13 +113,35 @@ message VarDesc {
PLACE_LIST = 8;
READER = 9;
}
required Type type = 1;
message TensorDesc {
required DataType data_type = 1;
repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
}
optional TensorDesc selected_rows = 2;
message LoDTensorDesc {
required TensorDesc tensor = 1;
optional int32 lod_level = 2 [ default = 0 ];
}
optional LoDTensorDesc lod_tensor = 3;
message LoDTensorArrayDesc {
required TensorDesc tensor = 1;
optional int32 lod_level = 2 [ default = 0 ];
}
optional LoDTensorArrayDesc tensor_array = 4;
message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; }
optional ReaderDesc reader = 5;
}
message VarDesc {
required string name = 1;
required VarType type = 2;
optional bool persistable = 3 [ default = false ];
optional LoDTensorDesc lod_tensor = 4;
optional TensorDesc selected_rows = 5;
optional LoDTensorArrayDesc tensor_array = 6;
optional ReaderDesc reader = 7;
}
message BlockDesc {
......
......@@ -53,11 +53,11 @@ class CompileTimeInferShapeContext : public InferShapeContext {
PADDLE_ENFORCE_LT(j, Outputs(out).size());
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
if (in_var->GetType() != proto::VarDesc::LOD_TENSOR) {
if (in_var->GetType() != proto::VarType::LOD_TENSOR) {
VLOG(3) << "input " << in << " is not LodTensor";
return;
}
PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarDesc::LOD_TENSOR,
PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarType::LOD_TENSOR,
"The %d-th output of Output(%s) must be LoDTensor.", j,
out);
out_var->SetLoDLevel(in_var->GetLoDLevel());
......@@ -66,7 +66,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool IsRuntime() const override;
protected:
proto::VarDesc::VarType GetVarType(const std::string &name) const override;
proto::VarType::Type GetVarType(const std::string &name) const override;
DDim GetDim(const std::string &name) const override;
......@@ -388,7 +388,7 @@ void OpDesc::InferVarType(BlockDesc *block) const {
for (auto &out_pair : this->outputs_) {
for (auto &out_var_name : out_pair.second) {
block->FindRecursiveOrCreateVar(out_var_name)
.SetType(proto::VarDesc::LOD_TENSOR);
.SetType(proto::VarType::LOD_TENSOR);
}
}
}
......@@ -507,7 +507,7 @@ void CompileTimeInferShapeContext::SetRepeatedDims(
bool CompileTimeInferShapeContext::IsRuntime() const { return false; }
proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
const std::string &name) const {
return block_.FindVarRecursive(name)->GetType();
}
......
......@@ -477,7 +477,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
}
proto::VarDesc::VarType GetVarType(const std::string& name) const override {
proto::VarType::Type GetVarType(const std::string& name) const override {
auto* var = scope_.FindVar(name);
return ToVarType(var->Type());
}
......
......@@ -22,13 +22,13 @@ TEST(ProgramDesc, copy_ctor) {
ProgramDesc program;
auto* global_block = program.MutableBlock(0);
auto* x = global_block->Var("X");
x->SetType(proto::VarDesc_VarType_LOD_TENSOR);
x->SetType(proto::VarType::LOD_TENSOR);
x->SetLoDLevel(0);
x->SetDataType(proto::FP32);
x->SetShape({1000, 784});
auto* y = global_block->Var("Y");
y->SetType(proto::VarDesc_VarType_LOD_TENSOR);
y->SetType(proto::VarType::LOD_TENSOR);
y->SetLoDLevel(0);
y->SetDataType(proto::FP32);
y->SetShape({784, 100});
......@@ -39,7 +39,7 @@ TEST(ProgramDesc, copy_ctor) {
op->SetInput("Y", {y->Name()});
auto* out = global_block->Var("Out");
out->SetType(proto::VarDesc_VarType_LOD_TENSOR);
out->SetType(proto::VarType::LOD_TENSOR);
op->SetOutput("Y", {out->Name()});
ProgramDesc program_copy(program);
......@@ -84,13 +84,13 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
ProgramDesc program_origin;
auto* global_block = program_origin.MutableBlock(0);
auto* x = global_block->Var("X");
x->SetType(proto::VarDesc_VarType_LOD_TENSOR);
x->SetType(proto::VarType::LOD_TENSOR);
x->SetLoDLevel(0);
x->SetDataType(proto::FP32);
x->SetShape({1000, 784});
auto* y = global_block->Var("Y");
y->SetType(proto::VarDesc_VarType_LOD_TENSOR);
y->SetType(proto::VarType::LOD_TENSOR);
y->SetLoDLevel(0);
y->SetDataType(proto::FP32);
y->SetShape({784, 100});
......@@ -101,7 +101,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
op->SetInput("Y", {y->Name()});
auto* out = global_block->Var("Out");
out->SetType(proto::VarDesc_VarType_LOD_TENSOR);
out->SetType(proto::VarType::LOD_TENSOR);
op->SetOutput("Y", {out->Name()});
std::string binary_str;
......
......@@ -116,19 +116,19 @@ void InferShapeContext::SetDims(const std::vector<std::string> &names,
}
}
std::vector<proto::VarDesc::VarType> InferShapeContext::GetInputsVarType(
std::vector<proto::VarType::Type> InferShapeContext::GetInputsVarType(
const std::string &name) const {
return GetVarTypes(Inputs(name));
}
std::vector<proto::VarDesc::VarType> InferShapeContext::GetOutputsVarType(
std::vector<proto::VarType::Type> InferShapeContext::GetOutputsVarType(
const std::string &name) const {
return GetVarTypes(Outputs(name));
}
std::vector<proto::VarDesc::VarType> InferShapeContext::GetVarTypes(
std::vector<proto::VarType::Type> InferShapeContext::GetVarTypes(
const std::vector<std::string> &names) const {
std::vector<proto::VarDesc::VarType> retv;
std::vector<proto::VarType::Type> retv;
retv.resize(names.size());
std::transform(names.begin(), names.end(), retv.begin(),
std::bind(std::mem_fn(&InferShapeContext::GetVarType), this,
......
......@@ -31,9 +31,9 @@ class InferShapeContext {
virtual bool HasInput(const std::string &name) const = 0;
virtual bool HasOutput(const std::string &name) const = 0;
std::vector<proto::VarDesc::VarType> GetInputsVarType(
std::vector<proto::VarType::Type> GetInputsVarType(
const std::string &name) const;
std::vector<proto::VarDesc::VarType> GetOutputsVarType(
std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string &name) const;
virtual bool HasInputs(const std::string &name) const = 0;
......@@ -75,10 +75,10 @@ class InferShapeContext {
std::vector<DDim> GetDims(const std::vector<std::string> &names) const;
std::vector<proto::VarDesc::VarType> GetVarTypes(
std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<std::string> &names) const;
virtual proto::VarDesc::VarType GetVarType(const std::string &name) const = 0;
virtual proto::VarType::Type GetVarType(const std::string &name) const = 0;
virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0;
};
......
......@@ -225,7 +225,7 @@ inline void SerializeToStream(std::ostream& os, const Tensor& tensor,
{ // the 2nd field, tensor description
// int32_t size
// void* protobuf message
proto::TensorDesc desc;
proto::VarType::TensorDesc desc;
desc.set_data_type(framework::ToDataType(tensor.type()));
auto dims = framework::vectorize(tensor.dims());
auto* pb_dims = desc.mutable_dims();
......@@ -290,7 +290,7 @@ inline void DeserializeFromStream(std::istream& is, Tensor* tensor,
uint32_t version;
is.read(reinterpret_cast<char*>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
proto::TensorDesc desc;
proto::VarType::TensorDesc desc;
{ // int32_t size
// proto buffer
int32_t size;
......
......@@ -18,18 +18,21 @@ limitations under the License. */
namespace paddle {
namespace framework {
proto::VarDesc::VarType VarDesc::GetType() const { return desc_.type(); }
proto::VarType::Type VarDesc::GetType() const { return desc_.type().type(); }
void VarDesc::SetType(proto::VarDesc::VarType type) { desc_.set_type(type); }
void VarDesc::SetType(proto::VarType::Type type) {
desc_.mutable_type()->set_type(type);
}
void VarDesc::SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
}
void VarDesc::SetTensorDescNum(size_t num) {
switch (desc_.type()) {
case proto::VarDesc::READER: {
auto *lod_tensors_ptr = desc_.mutable_reader()->mutable_lod_tensor();
switch (desc_.type().type()) {
case proto::VarType::READER: {
auto *lod_tensors_ptr =
desc_.mutable_type()->mutable_reader()->mutable_lod_tensor();
lod_tensors_ptr->Clear();
for (size_t i = 0; i < num; ++i) {
lod_tensors_ptr->Add();
......@@ -44,9 +47,9 @@ void VarDesc::SetTensorDescNum(size_t num) {
}
size_t VarDesc::GetTensorDescNum() const {
switch (desc_.type()) {
case proto::VarDesc::READER:
return desc_.reader().lod_tensor_size();
switch (desc_.type().type()) {
case proto::VarType::READER:
return desc_.type().reader().lod_tensor_size();
break;
default:
PADDLE_THROW(
......@@ -64,7 +67,7 @@ void VarDesc::SetShapes(
<< "). The Reader is going to be reinitialized.";
SetTensorDescNum(multiple_dims.size());
}
std::vector<proto::TensorDesc *> tensors = mutable_tensor_descs();
std::vector<proto::VarType::TensorDesc *> tensors = mutable_tensor_descs();
for (size_t i = 0; i < multiple_dims.size(); ++i) {
VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
}
......@@ -75,7 +78,7 @@ std::vector<int64_t> VarDesc::GetShape() const {
}
std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
std::vector<proto::TensorDesc> descs = tensor_descs();
std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
std::vector<std::vector<int64_t>> res;
res.reserve(descs.size());
for (const auto &tensor_desc : descs) {
......@@ -98,7 +101,8 @@ void VarDesc::SetDataTypes(
<< "). The Reader is going to be reinitialized.";
SetTensorDescNum(multiple_data_type.size());
}
std::vector<proto::TensorDesc *> tensor_descs = mutable_tensor_descs();
std::vector<proto::VarType::TensorDesc *> tensor_descs =
mutable_tensor_descs();
for (size_t i = 0; i < multiple_data_type.size(); ++i) {
tensor_descs[i]->set_data_type(multiple_data_type[i]);
}
......@@ -109,7 +113,7 @@ proto::DataType VarDesc::GetDataType() const {
}
std::vector<proto::DataType> VarDesc::GetDataTypes() const {
std::vector<proto::TensorDesc> descs = tensor_descs();
std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
std::vector<proto::DataType> res;
res.reserve(descs.size());
for (const auto &tensor_desc : descs) {
......@@ -119,12 +123,12 @@ std::vector<proto::DataType> VarDesc::GetDataTypes() const {
}
void VarDesc::SetLoDLevel(int32_t lod_level) {
switch (desc_.type()) {
case proto::VarDesc::LOD_TENSOR:
desc_.mutable_lod_tensor()->set_lod_level(lod_level);
switch (desc_.type().type()) {
case proto::VarType::LOD_TENSOR:
desc_.mutable_type()->mutable_lod_tensor()->set_lod_level(lod_level);
break;
case proto::VarDesc::LOD_TENSOR_ARRAY:
desc_.mutable_tensor_array()->set_lod_level(lod_level);
case proto::VarType::LOD_TENSOR_ARRAY:
desc_.mutable_type()->mutable_tensor_array()->set_lod_level(lod_level);
break;
default:
PADDLE_THROW(
......@@ -142,10 +146,11 @@ void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
<< "). The Reader is going to be reinitialized.";
SetTensorDescNum(multiple_lod_level.size());
}
switch (desc_.type()) {
case proto::VarDesc::READER: {
switch (desc_.type().type()) {
case proto::VarType::READER: {
size_t i = 0;
for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) {
for (auto &lod_tensor :
*desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
lod_tensor.set_lod_level(multiple_lod_level[i++]);
}
} break;
......@@ -157,11 +162,11 @@ void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
}
int32_t VarDesc::GetLoDLevel() const {
switch (desc_.type()) {
case proto::VarDesc::LOD_TENSOR:
return desc_.lod_tensor().lod_level();
case proto::VarDesc::LOD_TENSOR_ARRAY:
return desc_.tensor_array().lod_level();
switch (desc_.type().type()) {
case proto::VarType::LOD_TENSOR:
return desc_.type().lod_tensor().lod_level();
case proto::VarType::LOD_TENSOR_ARRAY:
return desc_.type().tensor_array().lod_level();
default:
PADDLE_THROW(
"Getting 'lod_level' is not supported by the type of var %s.",
......@@ -171,10 +176,10 @@ int32_t VarDesc::GetLoDLevel() const {
std::vector<int32_t> VarDesc::GetLoDLevels() const {
std::vector<int32_t> res;
switch (desc_.type()) {
case proto::VarDesc::READER:
res.reserve(desc_.reader().lod_tensor_size());
for (auto &lod_tensor : desc_.reader().lod_tensor()) {
switch (desc_.type().type()) {
case proto::VarType::READER:
res.reserve(desc_.type().reader().lod_tensor_size());
for (auto &lod_tensor : desc_.type().reader().lod_tensor()) {
res.push_back(lod_tensor.lod_level());
}
return res;
......@@ -186,15 +191,16 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
}
}
const proto::TensorDesc &VarDesc::tensor_desc() const {
const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
switch (desc_.type()) {
case proto::VarDesc::SELECTED_ROWS:
return desc_.selected_rows();
case proto::VarDesc::LOD_TENSOR:
return desc_.lod_tensor().tensor();
case proto::VarDesc::LOD_TENSOR_ARRAY:
return desc_.tensor_array().tensor();
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
switch (desc_.type().type()) {
case proto::VarType::SELECTED_ROWS:
return desc_.type().selected_rows();
case proto::VarType::LOD_TENSOR:
return desc_.type().lod_tensor().tensor();
case proto::VarType::LOD_TENSOR_ARRAY:
return desc_.type().tensor_array().tensor();
default:
PADDLE_THROW(
"Getting 'tensor_desc' is not supported by the type of var %s.",
......@@ -202,13 +208,13 @@ const proto::TensorDesc &VarDesc::tensor_desc() const {
}
}
std::vector<proto::TensorDesc> VarDesc::tensor_descs() const {
std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
std::vector<proto::TensorDesc> res;
std::vector<proto::VarType::TensorDesc> res;
res.reserve(GetTensorDescNum());
switch (desc_.type()) {
case proto::VarDesc::READER:
for (const auto &lod_tensor : desc_.reader().lod_tensor()) {
switch (desc_.type().type()) {
case proto::VarType::READER:
for (const auto &lod_tensor : desc_.type().reader().lod_tensor()) {
res.push_back(lod_tensor.tensor());
}
return res;
......@@ -220,15 +226,16 @@ std::vector<proto::TensorDesc> VarDesc::tensor_descs() const {
}
}
proto::TensorDesc *VarDesc::mutable_tensor_desc() {
proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
switch (desc_.type()) {
case proto::VarDesc::SELECTED_ROWS:
return desc_.mutable_selected_rows();
case proto::VarDesc::LOD_TENSOR:
return desc_.mutable_lod_tensor()->mutable_tensor();
case proto::VarDesc::LOD_TENSOR_ARRAY:
return desc_.mutable_tensor_array()->mutable_tensor();
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
switch (desc_.type().type()) {
case proto::VarType::SELECTED_ROWS:
return desc_.mutable_type()->mutable_selected_rows();
case proto::VarType::LOD_TENSOR:
return desc_.mutable_type()->mutable_lod_tensor()->mutable_tensor();
case proto::VarType::LOD_TENSOR_ARRAY:
return desc_.mutable_type()->mutable_tensor_array()->mutable_tensor();
default:
PADDLE_THROW(
"Getting 'mutable_tensor_desc' is not supported by the type of var "
......@@ -237,13 +244,15 @@ proto::TensorDesc *VarDesc::mutable_tensor_desc() {
}
}
std::vector<proto::TensorDesc *> VarDesc::mutable_tensor_descs() {
std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
std::vector<proto::TensorDesc *> res;
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
std::vector<proto::VarType::TensorDesc *> res;
res.reserve(GetTensorDescNum());
switch (desc_.type()) {
case proto::VarDesc::READER:
for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) {
switch (desc_.type().type()) {
case proto::VarType::READER:
for (auto &lod_tensor :
*desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
res.push_back(lod_tensor.mutable_tensor());
}
return res;
......
......@@ -57,7 +57,7 @@ class VarDesc {
public:
explicit VarDesc(const std::string &name) {
desc_.set_name(name);
desc_.set_type(proto::VarDesc::LOD_TENSOR);
desc_.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
}
explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {}
......@@ -96,19 +96,19 @@ class VarDesc {
std::vector<int32_t> GetLoDLevels() const;
proto::VarDesc::VarType GetType() const;
proto::VarType::Type GetType() const;
void SetType(proto::VarDesc::VarType type);
void SetType(proto::VarType::Type type);
bool Persistable() const { return desc_.persistable(); }
void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }
private:
const proto::TensorDesc &tensor_desc() const;
std::vector<proto::TensorDesc> tensor_descs() const;
proto::TensorDesc *mutable_tensor_desc();
std::vector<proto::TensorDesc *> mutable_tensor_descs();
const proto::VarType::TensorDesc &tensor_desc() const;
std::vector<proto::VarType::TensorDesc> tensor_descs() const;
proto::VarType::TensorDesc *mutable_tensor_desc();
std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs();
proto::VarDesc desc_;
};
......
......@@ -23,17 +23,17 @@ limitations under the License. */
namespace paddle {
namespace framework {
inline proto::VarDesc::VarType ToVarType(std::type_index type) {
inline proto::VarType::Type ToVarType(std::type_index type) {
if (type.hash_code() == typeid(LoDTensor).hash_code()) {
return proto::VarDesc_VarType_LOD_TENSOR;
return proto::VarType_Type_LOD_TENSOR;
} else if (type.hash_code() == typeid(LoDRankTable).hash_code()) {
return proto::VarDesc_VarType_LOD_RANK_TABLE;
return proto::VarType_Type_LOD_RANK_TABLE;
} else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) {
return proto::VarDesc_VarType_LOD_TENSOR_ARRAY;
return proto::VarType_Type_LOD_TENSOR_ARRAY;
} else if (type.hash_code() == typeid(SelectedRows).hash_code()) {
return proto::VarDesc_VarType_SELECTED_ROWS;
return proto::VarType_Type_SELECTED_ROWS;
} else if (type.hash_code() == typeid(ReaderHolder).hash_code()) {
return proto::VarDesc_VarType_READER;
return proto::VarType_Type_READER;
} else {
PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
}
......@@ -42,19 +42,19 @@ inline proto::VarDesc::VarType ToVarType(std::type_index type) {
template <typename Visitor>
inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
switch (ToVarType(var.Type())) {
case proto::VarDesc_VarType_LOD_TENSOR:
case proto::VarType_Type_LOD_TENSOR:
visitor(var.Get<LoDTensor>());
return;
case proto::VarDesc_VarType_LOD_RANK_TABLE:
case proto::VarType_Type_LOD_RANK_TABLE:
visitor(var.Get<LoDRankTable>());
return;
case proto::VarDesc_VarType_LOD_TENSOR_ARRAY:
case proto::VarType_Type_LOD_TENSOR_ARRAY:
visitor(var.Get<LoDTensorArray>());
return;
case proto::VarDesc_VarType_SELECTED_ROWS:
case proto::VarType_Type_SELECTED_ROWS:
visitor(var.Get<SelectedRows>());
return;
case proto::VarDesc_VarType_READER:
case proto::VarType_Type_READER:
visitor(var.Get<ReaderHolder>());
return;
default:
......
......@@ -35,14 +35,14 @@ class SumOpVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
auto &inputs = op_desc.Input("X");
auto default_var_type = proto::VarDesc::SELECTED_ROWS;
auto default_var_type = proto::VarType::SELECTED_ROWS;
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string &name) {
return block->Var(name)->GetType() == proto::VarDesc::LOD_TENSOR;
return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
});
if (any_input_is_lod_tensor) {
default_var_type = proto::VarDesc::LOD_TENSOR;
default_var_type = proto::VarType::LOD_TENSOR;
}
auto out_var_name = op_desc.Output("Out").front();
......@@ -67,19 +67,19 @@ TEST(InferVarType, sum_op) {
op->SetInput("X", {"test_a", "test_b", "test_c"});
op->SetOutput("Out", {"test_out"});
prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_out");
op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(proto::VarDesc::SELECTED_ROWS,
ASSERT_EQ(proto::VarType::SELECTED_ROWS,
prog.MutableBlock(0)->Var("test_out")->GetType());
prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarDesc::LOD_TENSOR);
prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR);
op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(proto::VarDesc::LOD_TENSOR,
ASSERT_EQ(proto::VarType::LOD_TENSOR,
prog.MutableBlock(0)->Var("test_out")->GetType());
}
......@@ -90,14 +90,14 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
op->SetOutput("Out", {"test2_out"});
prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarType::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test2_out");
op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(proto::VarDesc_VarType_LOD_TENSOR,
ASSERT_EQ(proto::VarType_Type_LOD_TENSOR,
prog.MutableBlock(0)->Var("test2_out")->GetType());
}
......
......@@ -115,8 +115,8 @@ class AssignInferShape : public framework::InferShapeBase {
void operator()(framework::InferShapeContext *context) const override {
if (context->HasInput("X")) {
auto type = context->GetInputsVarType("X")[0];
if (type == framework::proto::VarDesc_VarType_SELECTED_ROWS ||
type == framework::proto::VarDesc_VarType_LOD_TENSOR) {
if (type == framework::proto::VarType::SELECTED_ROWS ||
type == framework::proto::VarType::LOD_TENSOR) {
context->SetOutputDim("Out", context->GetInputDim("X"));
}
}
......
......@@ -128,10 +128,10 @@ class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
for (auto& o : op_desc.Output("SentenceIds")) {
block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR);
block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
}
for (auto& o : op_desc.Output("SentenceScores")) {
block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR);
block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
}
}
};
......
......@@ -240,10 +240,10 @@ class BeamSearchInferVarType : public framework::VarTypeInference {
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
for (auto &o : op_desc.Output("selected_ids")) {
block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR);
block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
}
for (auto &o : op_desc.Output("selected_scores")) {
block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR);
block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
}
}
};
......
......@@ -84,7 +84,7 @@ class CreateFileReaderInferVarType : public framework::VarTypeInference {
framework::BlockDesc* block) const override {
std::string reader_name = op_desc.Output("Out")[0];
framework::VarDesc* reader = block->FindVarRecursive(reader_name);
reader->SetType(framework::proto::VarDesc::READER);
reader->SetType(framework::proto::VarType::READER);
}
};
......@@ -97,7 +97,7 @@ class CreateDecoratedReaderInferVarType : public framework::VarTypeInference {
framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name);
std::string out_reader_name = op_desc.Output("Out")[0];
framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name);
out_reader->SetType(framework::proto::VarDesc::READER);
out_reader->SetType(framework::proto::VarType::READER);
out_reader->SetDataTypes(in_reader->GetDataTypes());
}
};
......@@ -147,7 +147,7 @@ class CreateRandomDataGeneratorOpMaker
AddComment(R"DOC(
CreateRandomDataGenerator Operator
This Op creates a random reader.
This Op creates a random reader.
The reader generates random data instead of really reading from files.
Generated data follow an uniform distribution between 'min' and 'max'.
)DOC");
......@@ -183,7 +183,7 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
CreateShuffleReader Operator
A shuffle reader takes another reader as its 'underlying reader'
and yields the underlying reader's outputs in a shuffled order.
and yields the underlying reader's outputs in a shuffled order.
)DOC");
}
};
......@@ -218,8 +218,8 @@ class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
CreateBatchReader Operator
A batch reader takes another reader as its 'underlying reader',
gathers the underlying reader's outputs and then yields them in batches.
A batch reader takes another reader as its 'underlying reader',
gathers the underlying reader's outputs and then yields them in batches.
)DOC");
}
};
......
......@@ -24,11 +24,11 @@ void SerializeToMessage(const std::string& name, const framework::Variable* var,
msg->set_varname(name);
std::ostringstream oss;
switch (framework::ToVarType(var->Type())) {
case framework::proto::VarDesc_VarType_LOD_TENSOR:
case framework::proto::VarType_Type_LOD_TENSOR:
msg->set_type(sendrecv::VarType::LOD_TENSOR);
framework::SerializeToStream(oss, var->Get<framework::LoDTensor>(), ctx);
break;
case framework::proto::VarDesc_VarType_SELECTED_ROWS:
case framework::proto::VarType_Type_SELECTED_ROWS:
msg->set_type(sendrecv::VarType::SELECTED_ROWS);
framework::SerializeToStream(oss, var->Get<framework::SelectedRows>(),
ctx);
......
......@@ -98,7 +98,7 @@ class GetPlacesInferVarType : public framework::VarTypeInference {
framework::BlockDesc *block) const override {
for (auto &o_name : op_desc.Output("Out")) {
block->FindRecursiveOrCreateVar(o_name).SetType(
framework::proto::VarDesc::PLACE_LIST);
framework::proto::VarType::PLACE_LIST);
}
}
};
......
......@@ -69,7 +69,7 @@ class LoDRankTableInferVarType : public framework::VarTypeInference {
framework::BlockDesc *block) const override {
for (auto &o : op_desc.Output("Out")) {
block->FindRecursiveOrCreateVar(o).SetType(
framework::proto::VarDesc::LOD_RANK_TABLE);
framework::proto::VarType::LOD_RANK_TABLE);
}
}
};
......
......@@ -138,7 +138,7 @@ class LoDTensorToArrayInferVarType : public framework::VarTypeInference {
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
for (auto &out_var : op_desc.Output("Out")) {
block->Var(out_var)->SetType(framework::proto::VarDesc::LOD_TENSOR_ARRAY);
block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
}
}
};
......
......@@ -123,11 +123,11 @@ class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
<< " is set to SelectedRows";
block->Var(out_var_name)
->SetType(framework::proto::VarDesc::SELECTED_ROWS);
->SetType(framework::proto::VarType::SELECTED_ROWS);
} else {
VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
<< " is set to LoDTensor";
block->Var(out_var_name)->SetType(framework::proto::VarDesc::LOD_TENSOR);
block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
}
}
};
......
......@@ -45,7 +45,7 @@ class ReadInferVarType : public framework::VarTypeInference {
PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
for (size_t i = 0; i < dtypes.size(); ++i) {
framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]);
out.SetType(framework::proto::VarDesc::LOD_TENSOR);
out.SetType(framework::proto::VarType::LOD_TENSOR);
out.SetDataType(dtypes[i]);
}
}
......
......@@ -29,7 +29,7 @@ class SumOp : public framework::OperatorWithKernel {
"Output(Out) of SumOp should not be null.");
if (ctx->IsRuntime() &&
ctx->GetOutputsVarType("Out")[0] ==
framework::proto::VarDesc::LOD_TENSOR_ARRAY) {
framework::proto::VarType::LOD_TENSOR_ARRAY) {
return; // skip runtime infershape when is tensor array;
}
......@@ -118,7 +118,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto& inputs = op_desc.Input("X");
auto var_type = framework::proto::VarDesc::SELECTED_ROWS;
auto var_type = framework::proto::VarType::SELECTED_ROWS;
for (auto& name : op_desc.Input("X")) {
VLOG(10) << name << " "
......@@ -128,12 +128,12 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string& name) {
return block->FindRecursiveOrCreateVar(name).GetType() ==
framework::proto::VarDesc::LOD_TENSOR;
framework::proto::VarType::LOD_TENSOR;
});
auto is_tensor_array = [block](const std::string& name) {
return block->FindRecursiveOrCreateVar(name).GetType() ==
framework::proto::VarDesc::LOD_TENSOR_ARRAY;
framework::proto::VarType::LOD_TENSOR_ARRAY;
};
bool any_input_is_tensor_array =
......@@ -151,9 +151,9 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
PADDLE_ENFORCE(all_inputs_are_tensor_array,
"Not all inputs are tensor array:\n%s", os.str());
}
var_type = framework::proto::VarDesc::LOD_TENSOR_ARRAY;
var_type = framework::proto::VarType::LOD_TENSOR_ARRAY;
} else if (any_input_is_lod_tensor) {
var_type = framework::proto::VarDesc::LOD_TENSOR;
var_type = framework::proto::VarType::LOD_TENSOR;
}
auto out_var_name = op_desc.Output("Out").front();
......
......@@ -108,7 +108,7 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
auto out_name = op_desc.Output("Out")[0];
VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
auto &out = block->FindRecursiveOrCreateVar(out_name);
out.SetType(framework::proto::VarDesc::LOD_TENSOR_ARRAY);
out.SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
auto *x = block->FindVarRecursive(x_name);
if (x != nullptr) {
out.SetDataType(x->GetDataType());
......
......@@ -330,10 +330,10 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
continue;
}
auto dims = ctx->GetInputsElementDim(kX, i);
if (var_types[i] == framework::proto::VarDesc::LOD_TENSOR) {
if (var_types[i] == framework::proto::VarType::LOD_TENSOR) {
names_to_set.push_back(pg_names[i]);
dims_to_set.push_back(dims);
} else if (var_types[i] == framework::proto::VarDesc::LOD_TENSOR_ARRAY) {
} else if (var_types[i] == framework::proto::VarType::LOD_TENSOR_ARRAY) {
// not sure how to set the dim of LOD_TENSOR_ARRAY
names_to_set.push_back(pg_names[i]);
dims_to_set.push_back(dims);
......
......@@ -232,16 +232,16 @@ void BindVarDsec(py::module &m) {
.def("persistable", &VarDesc::Persistable)
.def("set_persistable", &VarDesc::SetPersistable);
py::enum_<proto::VarDesc::VarType>(var_desc, "VarType", "")
.value("LOD_TENSOR", proto::VarDesc::LOD_TENSOR)
.value("SELECTED_ROWS", proto::VarDesc::SELECTED_ROWS)
.value("FEED_MINIBATCH", proto::VarDesc::FEED_MINIBATCH)
.value("FETCH_LIST", proto::VarDesc::FETCH_LIST)
.value("STEP_SCOPES", proto::VarDesc::STEP_SCOPES)
.value("LOD_RANK_TABLE", proto::VarDesc::LOD_RANK_TABLE)
.value("LOD_TENSOR_ARRAY", proto::VarDesc::LOD_TENSOR_ARRAY)
.value("PLACE_LIST", proto::VarDesc::PLACE_LIST)
.value("READER", proto::VarDesc::READER);
py::enum_<proto::VarType::Type>(var_desc, "VarType", "")
.value("LOD_TENSOR", proto::VarType::LOD_TENSOR)
.value("SELECTED_ROWS", proto::VarType::SELECTED_ROWS)
.value("FEED_MINIBATCH", proto::VarType::FEED_MINIBATCH)
.value("FETCH_LIST", proto::VarType::FETCH_LIST)
.value("STEP_SCOPES", proto::VarType::STEP_SCOPES)
.value("LOD_RANK_TABLE", proto::VarType::LOD_RANK_TABLE)
.value("LOD_TENSOR_ARRAY", proto::VarType::LOD_TENSOR_ARRAY)
.value("PLACE_LIST", proto::VarType::PLACE_LIST)
.value("READER", proto::VarType::READER);
}
void BindOpDesc(py::module &m) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册