Commit 7129fa3c authored by Yang Yang

merge develop

@@ -60,6 +60,7 @@ option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF)
 option(WITH_DISTRIBUTE "Compile with grpc distributed support" OFF)
 option(USE_EIGEN_FOR_BLAS "Use matrix multiplication in Eigen" OFF)
 option(WITH_ARM_FP16 "Use half precision support on armv8.2-a cpu" OFF)
+option(WITH_FAST_BUNDLE_TEST "Bundle tests that can be run in a single process together to reduce launch overhead" ON)

 # CMAKE_BUILD_TYPE
 if(NOT CMAKE_BUILD_TYPE)
......
@@ -36,24 +36,24 @@ namespace framework {

 Executor::Executor(const platform::Place& place) : place_(place) {}

-static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
-  if (var_type == proto::VarDesc::LOD_TENSOR) {
+static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
+  if (var_type == proto::VarType::LOD_TENSOR) {
     var->GetMutable<LoDTensor>();
-  } else if (var_type == proto::VarDesc::SELECTED_ROWS) {
+  } else if (var_type == proto::VarType::SELECTED_ROWS) {
     var->GetMutable<SelectedRows>();
-  } else if (var_type == proto::VarDesc::FEED_MINIBATCH) {
+  } else if (var_type == proto::VarType::FEED_MINIBATCH) {
     var->GetMutable<FeedFetchList>();
-  } else if (var_type == proto::VarDesc::FETCH_LIST) {
+  } else if (var_type == proto::VarType::FETCH_LIST) {
     var->GetMutable<FeedFetchList>();
-  } else if (var_type == proto::VarDesc::STEP_SCOPES) {
+  } else if (var_type == proto::VarType::STEP_SCOPES) {
     var->GetMutable<std::vector<framework::Scope>>();
-  } else if (var_type == proto::VarDesc::LOD_RANK_TABLE) {
+  } else if (var_type == proto::VarType::LOD_RANK_TABLE) {
     var->GetMutable<LoDRankTable>();
-  } else if (var_type == proto::VarDesc::LOD_TENSOR_ARRAY) {
+  } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
     var->GetMutable<LoDTensorArray>();
-  } else if (var_type == proto::VarDesc::PLACE_LIST) {
+  } else if (var_type == proto::VarType::PLACE_LIST) {
     var->GetMutable<platform::PlaceList>();
-  } else if (var_type == proto::VarDesc::READER) {
+  } else if (var_type == proto::VarType::READER) {
     var->GetMutable<ReaderHolder>();
   } else if (var_type == proto::VarDesc::NCCL_COM) {
     // GetMutable will be called in ncclInit
@@ -184,7 +184,7 @@ static bool has_feed_operators(
     auto var = block->FindVar(feed_holder_name);
     PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
                             feed_holder_name);
-    PADDLE_ENFORCE_EQ(var->GetType(), proto::VarDesc::FEED_MINIBATCH,
+    PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
                       "'%s' variable should be 'FEED_MINIBATCH' type",
                       feed_holder_name);
   }
@@ -224,7 +224,7 @@ static bool has_fetch_operators(
     auto var = block->FindVar(fetch_holder_name);
     PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
                             fetch_holder_name);
-    PADDLE_ENFORCE_EQ(var->GetType(), proto::VarDesc::FETCH_LIST,
+    PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
                       "'%s' variable should be 'FETCH_LIST' type",
                       fetch_holder_name);
   }
@@ -243,7 +243,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
   if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) {
     // create feed_holder variable
     auto* feed_holder = global_block->Var(feed_holder_name);
-    feed_holder->SetType(proto::VarDesc::FEED_MINIBATCH);
+    feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
     feed_holder->SetPersistable(true);

     int i = 0;
@@ -276,7 +276,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
   if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) {
     // create fetch_holder variable
     auto* fetch_holder = global_block->Var(fetch_holder_name);
-    fetch_holder->SetType(proto::VarDesc::FETCH_LIST);
+    fetch_holder->SetType(proto::VarType::FETCH_LIST);
     fetch_holder->SetPersistable(true);

     int i = 0;
......
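The feed and fetch hunks above follow one pattern: when the program has no feed (or fetch) operators yet, the executor first materializes a persistable holder variable of the matching type. A minimal sketch of that pattern, assuming the PaddlePaddle tree of this commit (the helper name and include path are hypothetical, not part of the diff):

```cpp
// Hypothetical helper mirroring how Executor::Run creates the feed holder
// before it inserts feed operators into the global block.
#include "paddle/framework/block_desc.h"  // assumed path in this tree

void EnsureFeedHolder(paddle::framework::BlockDesc* global_block,
                      const std::string& feed_holder_name) {
  // Var() creates the variable if it does not exist yet.
  auto* feed_holder = global_block->Var(feed_holder_name);
  // After this commit the variable-kind enum lives on proto::VarType.
  feed_holder->SetType(paddle::framework::proto::VarType::FEED_MINIBATCH);
  // The holder must survive across a single Run() call.
  feed_holder->SetPersistable(true);
}
```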
@@ -101,25 +101,8 @@ enum DataType {
   FP64 = 6;
 }

-message TensorDesc {
-  required DataType data_type = 1;
-  repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
-}
-
-message LoDTensorDesc {
-  required TensorDesc tensor = 1;
-  optional int32 lod_level = 2 [ default = 0 ];
-}
-
-message LoDTensorArrayDesc {
-  required TensorDesc tensor = 1;
-  optional int32 lod_level = 2 [ default = 0 ];
-}
-
-message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; }
-
-message VarDesc {
-  enum VarType {
+message VarType {
+  enum Type {
     LOD_TENSOR = 1;
     SELECTED_ROWS = 2;
     FEED_MINIBATCH = 3;
@@ -131,13 +114,35 @@ message VarDesc {
     READER = 9;
     NCCL_COM = 10;
   }
+
+  required Type type = 1;
+
+  message TensorDesc {
+    required DataType data_type = 1;
+    repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
+  }
+  optional TensorDesc selected_rows = 2;
+
+  message LoDTensorDesc {
+    required TensorDesc tensor = 1;
+    optional int32 lod_level = 2 [ default = 0 ];
+  }
+  optional LoDTensorDesc lod_tensor = 3;
+
+  message LoDTensorArrayDesc {
+    required TensorDesc tensor = 1;
+    optional int32 lod_level = 2 [ default = 0 ];
+  }
+  optional LoDTensorArrayDesc tensor_array = 4;
+
+  message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; }
+  optional ReaderDesc reader = 5;
+}
+
+message VarDesc {
   required string name = 1;
   required VarType type = 2;
  optional bool persistable = 3 [ default = false ];
-  optional LoDTensorDesc lod_tensor = 4;
-  optional TensorDesc selected_rows = 5;
-  optional LoDTensorArrayDesc tensor_array = 6;
-  optional ReaderDesc reader = 7;
 }

 message BlockDesc {
......
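Read together, the two proto hunks move TensorDesc, LoDTensorDesc, LoDTensorArrayDesc, and ReaderDesc inside a new top-level VarType message, so a variable's kind now sits one level deeper, at desc.type().type(). A sketch of what that means for the generated C++ API, using only fields visible in this diff (the include path is an assumption about this source tree):

```cpp
// Sketch only: describe an FP32 LoDTensor variable through the reshaped proto.
#include "paddle/framework/framework.pb.h"  // assumed generated header path

namespace proto = paddle::framework::proto;

void DescribeLoDTensor(proto::VarDesc* desc) {
  desc->set_name("X");
  // Old API in this commit's "before" state: desc->set_type(proto::VarDesc::LOD_TENSOR);
  desc->mutable_type()->set_type(proto::VarType::LOD_TENSOR);
  auto* lod_tensor = desc->mutable_type()->mutable_lod_tensor();
  lod_tensor->set_lod_level(0);
  lod_tensor->mutable_tensor()->set_data_type(proto::FP32);
  lod_tensor->mutable_tensor()->add_dims(-1);  // UNK batch dim is saved as -1
  lod_tensor->mutable_tensor()->add_dims(784);
}
```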
@@ -53,11 +53,11 @@ class CompileTimeInferShapeContext : public InferShapeContext {
     PADDLE_ENFORCE_LT(j, Outputs(out).size());
     auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
     auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
-    if (in_var->GetType() != proto::VarDesc::LOD_TENSOR) {
+    if (in_var->GetType() != proto::VarType::LOD_TENSOR) {
       VLOG(3) << "input " << in << " is not LodTensor";
       return;
     }
-    PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarDesc::LOD_TENSOR,
+    PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarType::LOD_TENSOR,
                       "The %d-th output of Output(%s) must be LoDTensor.", j,
                       out);
     out_var->SetLoDLevel(in_var->GetLoDLevel());
@@ -66,7 +66,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
   bool IsRuntime() const override;

 protected:
-  proto::VarDesc::VarType GetVarType(const std::string &name) const override;
+  proto::VarType::Type GetVarType(const std::string &name) const override;

   DDim GetDim(const std::string &name) const override;
@@ -388,7 +388,7 @@ void OpDesc::InferVarType(BlockDesc *block) const {
   for (auto &out_pair : this->outputs_) {
     for (auto &out_var_name : out_pair.second) {
       block->FindRecursiveOrCreateVar(out_var_name)
-          .SetType(proto::VarDesc::LOD_TENSOR);
+          .SetType(proto::VarType::LOD_TENSOR);
     }
   }
 }
@@ -507,7 +507,7 @@ void CompileTimeInferShapeContext::SetRepeatedDims(

 bool CompileTimeInferShapeContext::IsRuntime() const { return false; }

-proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
+proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
     const std::string &name) const {
   return block_.FindVarRecursive(name)->GetType();
 }
......
@@ -477,7 +477,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
     }
   }

-  proto::VarDesc::VarType GetVarType(const std::string& name) const override {
+  proto::VarType::Type GetVarType(const std::string& name) const override {
    auto* var = scope_.FindVar(name);
    return ToVarType(var->Type());
  }
......
@@ -22,13 +22,13 @@ TEST(ProgramDesc, copy_ctor) {
   ProgramDesc program;
   auto* global_block = program.MutableBlock(0);
   auto* x = global_block->Var("X");
-  x->SetType(proto::VarDesc_VarType_LOD_TENSOR);
+  x->SetType(proto::VarType::LOD_TENSOR);
   x->SetLoDLevel(0);
   x->SetDataType(proto::FP32);
   x->SetShape({1000, 784});

   auto* y = global_block->Var("Y");
-  y->SetType(proto::VarDesc_VarType_LOD_TENSOR);
+  y->SetType(proto::VarType::LOD_TENSOR);
   y->SetLoDLevel(0);
   y->SetDataType(proto::FP32);
   y->SetShape({784, 100});
@@ -39,7 +39,7 @@ TEST(ProgramDesc, copy_ctor) {
   op->SetInput("Y", {y->Name()});

   auto* out = global_block->Var("Out");
-  out->SetType(proto::VarDesc_VarType_LOD_TENSOR);
+  out->SetType(proto::VarType::LOD_TENSOR);
   op->SetOutput("Y", {out->Name()});

   ProgramDesc program_copy(program);
@@ -84,13 +84,13 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
   ProgramDesc program_origin;
   auto* global_block = program_origin.MutableBlock(0);
   auto* x = global_block->Var("X");
-  x->SetType(proto::VarDesc_VarType_LOD_TENSOR);
+  x->SetType(proto::VarType::LOD_TENSOR);
   x->SetLoDLevel(0);
   x->SetDataType(proto::FP32);
   x->SetShape({1000, 784});

   auto* y = global_block->Var("Y");
-  y->SetType(proto::VarDesc_VarType_LOD_TENSOR);
+  y->SetType(proto::VarType::LOD_TENSOR);
   y->SetLoDLevel(0);
   y->SetDataType(proto::FP32);
   y->SetShape({784, 100});
@@ -101,7 +101,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
   op->SetInput("Y", {y->Name()});

   auto* out = global_block->Var("Out");
-  out->SetType(proto::VarDesc_VarType_LOD_TENSOR);
+  out->SetType(proto::VarType::LOD_TENSOR);
   op->SetOutput("Y", {out->Name()});

   std::string binary_str;
......
@@ -116,19 +116,19 @@ void InferShapeContext::SetDims(const std::vector<std::string> &names,
   }
 }

-std::vector<proto::VarDesc::VarType> InferShapeContext::GetInputsVarType(
+std::vector<proto::VarType::Type> InferShapeContext::GetInputsVarType(
     const std::string &name) const {
   return GetVarTypes(Inputs(name));
 }

-std::vector<proto::VarDesc::VarType> InferShapeContext::GetOutputsVarType(
+std::vector<proto::VarType::Type> InferShapeContext::GetOutputsVarType(
     const std::string &name) const {
   return GetVarTypes(Outputs(name));
 }

-std::vector<proto::VarDesc::VarType> InferShapeContext::GetVarTypes(
+std::vector<proto::VarType::Type> InferShapeContext::GetVarTypes(
     const std::vector<std::string> &names) const {
-  std::vector<proto::VarDesc::VarType> retv;
+  std::vector<proto::VarType::Type> retv;
   retv.resize(names.size());
   std::transform(names.begin(), names.end(), retv.begin(),
                  std::bind(std::mem_fn(&InferShapeContext::GetVarType), this,
......
@@ -31,9 +31,9 @@ class InferShapeContext {
   virtual bool HasInput(const std::string &name) const = 0;
   virtual bool HasOutput(const std::string &name) const = 0;

-  std::vector<proto::VarDesc::VarType> GetInputsVarType(
+  std::vector<proto::VarType::Type> GetInputsVarType(
       const std::string &name) const;
-  std::vector<proto::VarDesc::VarType> GetOutputsVarType(
+  std::vector<proto::VarType::Type> GetOutputsVarType(
       const std::string &name) const;

   virtual bool HasInputs(const std::string &name) const = 0;
@@ -75,10 +75,10 @@ class InferShapeContext {
   std::vector<DDim> GetDims(const std::vector<std::string> &names) const;

-  std::vector<proto::VarDesc::VarType> GetVarTypes(
+  std::vector<proto::VarType::Type> GetVarTypes(
       const std::vector<std::string> &names) const;

-  virtual proto::VarDesc::VarType GetVarType(const std::string &name) const = 0;
+  virtual proto::VarType::Type GetVarType(const std::string &name) const = 0;

   virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0;
 };
......
@@ -225,7 +225,7 @@ inline void SerializeToStream(std::ostream& os, const Tensor& tensor,
   {  // the 2nd field, tensor description
     // int32_t  size
     // void*    protobuf message
-    proto::TensorDesc desc;
+    proto::VarType::TensorDesc desc;
     desc.set_data_type(framework::ToDataType(tensor.type()));
     auto dims = framework::vectorize(tensor.dims());
     auto* pb_dims = desc.mutable_dims();
@@ -290,7 +290,7 @@ inline void DeserializeFromStream(std::istream& is, Tensor* tensor,
   uint32_t version;
   is.read(reinterpret_cast<char*>(&version), sizeof(version));
   PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
-  proto::TensorDesc desc;
+  proto::VarType::TensorDesc desc;
   {  // int32_t size
     // proto buffer
     int32_t size;
......
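These two hunks only retarget the descriptor's message type; the stream framing stays as the comments describe it: a uint32 version word, then an int32 byte count followed by the serialized TensorDesc. A minimal standard-library sketch of writing that framing (hypothetical helper, not from the commit; the proto bytes are passed in pre-serialized):

```cpp
// Sketch of the on-stream layout the serialize/deserialize hunks agree on.
#include <cstdint>
#include <ostream>
#include <string>

void WriteTensorDescField(std::ostream& os, const std::string& desc_bytes) {
  const uint32_t version = 0;  // the reader enforces "Only version 0 is supported"
  os.write(reinterpret_cast<const char*>(&version), sizeof(version));
  const int32_t size = static_cast<int32_t>(desc_bytes.size());
  os.write(reinterpret_cast<const char*>(&size), sizeof(size));
  os.write(desc_bytes.data(), size);  // serialized proto::VarType::TensorDesc
}
```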
@@ -18,18 +18,21 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

-proto::VarDesc::VarType VarDesc::GetType() const { return desc_.type(); }
+proto::VarType::Type VarDesc::GetType() const { return desc_.type().type(); }

-void VarDesc::SetType(proto::VarDesc::VarType type) { desc_.set_type(type); }
+void VarDesc::SetType(proto::VarType::Type type) {
+  desc_.mutable_type()->set_type(type);
+}

 void VarDesc::SetShape(const std::vector<int64_t> &dims) {
   VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
 }

 void VarDesc::SetTensorDescNum(size_t num) {
-  switch (desc_.type()) {
-    case proto::VarDesc::READER: {
-      auto *lod_tensors_ptr = desc_.mutable_reader()->mutable_lod_tensor();
+  switch (desc_.type().type()) {
+    case proto::VarType::READER: {
+      auto *lod_tensors_ptr =
+          desc_.mutable_type()->mutable_reader()->mutable_lod_tensor();
       lod_tensors_ptr->Clear();
       for (size_t i = 0; i < num; ++i) {
         lod_tensors_ptr->Add();
@@ -44,9 +47,9 @@ void VarDesc::SetTensorDescNum(size_t num) {
 }

 size_t VarDesc::GetTensorDescNum() const {
-  switch (desc_.type()) {
-    case proto::VarDesc::READER:
-      return desc_.reader().lod_tensor_size();
+  switch (desc_.type().type()) {
+    case proto::VarType::READER:
+      return desc_.type().reader().lod_tensor_size();
       break;
     default:
       PADDLE_THROW(
@@ -64,7 +67,7 @@ void VarDesc::SetShapes(
         << "). The Reader is going to be reinitialized.";
     SetTensorDescNum(multiple_dims.size());
   }
-  std::vector<proto::TensorDesc *> tensors = mutable_tensor_descs();
+  std::vector<proto::VarType::TensorDesc *> tensors = mutable_tensor_descs();
   for (size_t i = 0; i < multiple_dims.size(); ++i) {
     VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
   }
@@ -75,7 +78,7 @@ std::vector<int64_t> VarDesc::GetShape() const {
 }

 std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
-  std::vector<proto::TensorDesc> descs = tensor_descs();
+  std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
   std::vector<std::vector<int64_t>> res;
   res.reserve(descs.size());
   for (const auto &tensor_desc : descs) {
@@ -98,7 +101,8 @@ void VarDesc::SetDataTypes(
         << "). The Reader is going to be reinitialized.";
     SetTensorDescNum(multiple_data_type.size());
   }
-  std::vector<proto::TensorDesc *> tensor_descs = mutable_tensor_descs();
+  std::vector<proto::VarType::TensorDesc *> tensor_descs =
+      mutable_tensor_descs();
   for (size_t i = 0; i < multiple_data_type.size(); ++i) {
     tensor_descs[i]->set_data_type(multiple_data_type[i]);
   }
@@ -109,7 +113,7 @@ proto::DataType VarDesc::GetDataType() const {
 }

 std::vector<proto::DataType> VarDesc::GetDataTypes() const {
-  std::vector<proto::TensorDesc> descs = tensor_descs();
+  std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
   std::vector<proto::DataType> res;
   res.reserve(descs.size());
   for (const auto &tensor_desc : descs) {
@@ -119,12 +123,12 @@ std::vector<proto::DataType> VarDesc::GetDataTypes() const {
 }

 void VarDesc::SetLoDLevel(int32_t lod_level) {
-  switch (desc_.type()) {
-    case proto::VarDesc::LOD_TENSOR:
-      desc_.mutable_lod_tensor()->set_lod_level(lod_level);
+  switch (desc_.type().type()) {
+    case proto::VarType::LOD_TENSOR:
+      desc_.mutable_type()->mutable_lod_tensor()->set_lod_level(lod_level);
       break;
-    case proto::VarDesc::LOD_TENSOR_ARRAY:
-      desc_.mutable_tensor_array()->set_lod_level(lod_level);
+    case proto::VarType::LOD_TENSOR_ARRAY:
+      desc_.mutable_type()->mutable_tensor_array()->set_lod_level(lod_level);
       break;
     default:
       PADDLE_THROW(
@@ -142,10 +146,11 @@ void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
         << "). The Reader is going to be reinitialized.";
     SetTensorDescNum(multiple_lod_level.size());
   }
-  switch (desc_.type()) {
-    case proto::VarDesc::READER: {
+  switch (desc_.type().type()) {
+    case proto::VarType::READER: {
       size_t i = 0;
-      for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) {
+      for (auto &lod_tensor :
+           *desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
         lod_tensor.set_lod_level(multiple_lod_level[i++]);
       }
     } break;
@@ -157,11 +162,11 @@ void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
 }

 int32_t VarDesc::GetLoDLevel() const {
-  switch (desc_.type()) {
-    case proto::VarDesc::LOD_TENSOR:
-      return desc_.lod_tensor().lod_level();
-    case proto::VarDesc::LOD_TENSOR_ARRAY:
-      return desc_.tensor_array().lod_level();
+  switch (desc_.type().type()) {
+    case proto::VarType::LOD_TENSOR:
+      return desc_.type().lod_tensor().lod_level();
+    case proto::VarType::LOD_TENSOR_ARRAY:
+      return desc_.type().tensor_array().lod_level();
     default:
       PADDLE_THROW(
           "Getting 'lod_level' is not supported by the type of var %s.",
@@ -171,10 +176,10 @@ int32_t VarDesc::GetLoDLevel() const {

 std::vector<int32_t> VarDesc::GetLoDLevels() const {
   std::vector<int32_t> res;
-  switch (desc_.type()) {
-    case proto::VarDesc::READER:
-      res.reserve(desc_.reader().lod_tensor_size());
-      for (auto &lod_tensor : desc_.reader().lod_tensor()) {
+  switch (desc_.type().type()) {
+    case proto::VarType::READER:
+      res.reserve(desc_.type().reader().lod_tensor_size());
+      for (auto &lod_tensor : desc_.type().reader().lod_tensor()) {
         res.push_back(lod_tensor.lod_level());
       }
       return res;
@@ -186,15 +191,16 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
   }
 }

-const proto::TensorDesc &VarDesc::tensor_desc() const {
+const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
   PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
-  switch (desc_.type()) {
-    case proto::VarDesc::SELECTED_ROWS:
-      return desc_.selected_rows();
-    case proto::VarDesc::LOD_TENSOR:
-      return desc_.lod_tensor().tensor();
-    case proto::VarDesc::LOD_TENSOR_ARRAY:
-      return desc_.tensor_array().tensor();
+  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
+  switch (desc_.type().type()) {
+    case proto::VarType::SELECTED_ROWS:
+      return desc_.type().selected_rows();
+    case proto::VarType::LOD_TENSOR:
+      return desc_.type().lod_tensor().tensor();
+    case proto::VarType::LOD_TENSOR_ARRAY:
+      return desc_.type().tensor_array().tensor();
     default:
       PADDLE_THROW(
           "Getting 'tensor_desc' is not supported by the type of var %s.",
@@ -202,13 +208,13 @@ const proto::TensorDesc &VarDesc::tensor_desc() const {
   }
 }

-std::vector<proto::TensorDesc> VarDesc::tensor_descs() const {
+std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
   PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
-  std::vector<proto::TensorDesc> res;
+  std::vector<proto::VarType::TensorDesc> res;
   res.reserve(GetTensorDescNum());
-  switch (desc_.type()) {
-    case proto::VarDesc::READER:
-      for (const auto &lod_tensor : desc_.reader().lod_tensor()) {
+  switch (desc_.type().type()) {
+    case proto::VarType::READER:
+      for (const auto &lod_tensor : desc_.type().reader().lod_tensor()) {
         res.push_back(lod_tensor.tensor());
       }
       return res;
@@ -220,15 +226,16 @@ std::vector<proto::TensorDesc> VarDesc::tensor_descs() const {
   }
 }

-proto::TensorDesc *VarDesc::mutable_tensor_desc() {
+proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
   PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
-  switch (desc_.type()) {
-    case proto::VarDesc::SELECTED_ROWS:
-      return desc_.mutable_selected_rows();
-    case proto::VarDesc::LOD_TENSOR:
-      return desc_.mutable_lod_tensor()->mutable_tensor();
-    case proto::VarDesc::LOD_TENSOR_ARRAY:
-      return desc_.mutable_tensor_array()->mutable_tensor();
+  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
+  switch (desc_.type().type()) {
+    case proto::VarType::SELECTED_ROWS:
+      return desc_.mutable_type()->mutable_selected_rows();
+    case proto::VarType::LOD_TENSOR:
+      return desc_.mutable_type()->mutable_lod_tensor()->mutable_tensor();
+    case proto::VarType::LOD_TENSOR_ARRAY:
+      return desc_.mutable_type()->mutable_tensor_array()->mutable_tensor();
     default:
       PADDLE_THROW(
           "Getting 'mutable_tensor_desc' is not supported by the type of var "
@@ -237,13 +244,15 @@ proto::TensorDesc *VarDesc::mutable_tensor_desc() {
   }
 }

-std::vector<proto::TensorDesc *> VarDesc::mutable_tensor_descs() {
+std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
   PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
-  std::vector<proto::TensorDesc *> res;
+  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
+  std::vector<proto::VarType::TensorDesc *> res;
   res.reserve(GetTensorDescNum());
-  switch (desc_.type()) {
-    case proto::VarDesc::READER:
-      for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) {
+  switch (desc_.type().type()) {
+    case proto::VarType::READER:
+      for (auto &lod_tensor :
+           *desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
         res.push_back(lod_tensor.mutable_tensor());
       }
       return res;
......
@@ -57,7 +57,7 @@ class VarDesc {
 public:
   explicit VarDesc(const std::string &name) {
     desc_.set_name(name);
-    desc_.set_type(proto::VarDesc::LOD_TENSOR);
+    desc_.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
   }

   explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {}
@@ -96,19 +96,19 @@ class VarDesc {
   std::vector<int32_t> GetLoDLevels() const;

-  proto::VarDesc::VarType GetType() const;
+  proto::VarType::Type GetType() const;

-  void SetType(proto::VarDesc::VarType type);
+  void SetType(proto::VarType::Type type);

   bool Persistable() const { return desc_.persistable(); }

   void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }

 private:
-  const proto::TensorDesc &tensor_desc() const;
-  std::vector<proto::TensorDesc> tensor_descs() const;
-  proto::TensorDesc *mutable_tensor_desc();
-  std::vector<proto::TensorDesc *> mutable_tensor_descs();
+  const proto::VarType::TensorDesc &tensor_desc() const;
+  std::vector<proto::VarType::TensorDesc> tensor_descs() const;
+  proto::VarType::TensorDesc *mutable_tensor_desc();
+  std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs();

   proto::VarDesc desc_;
 };
......
@@ -23,17 +23,17 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

-inline proto::VarDesc::VarType ToVarType(std::type_index type) {
+inline proto::VarType::Type ToVarType(std::type_index type) {
   if (type.hash_code() == typeid(LoDTensor).hash_code()) {
-    return proto::VarDesc_VarType_LOD_TENSOR;
+    return proto::VarType_Type_LOD_TENSOR;
   } else if (type.hash_code() == typeid(LoDRankTable).hash_code()) {
-    return proto::VarDesc_VarType_LOD_RANK_TABLE;
+    return proto::VarType_Type_LOD_RANK_TABLE;
   } else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) {
-    return proto::VarDesc_VarType_LOD_TENSOR_ARRAY;
+    return proto::VarType_Type_LOD_TENSOR_ARRAY;
   } else if (type.hash_code() == typeid(SelectedRows).hash_code()) {
-    return proto::VarDesc_VarType_SELECTED_ROWS;
+    return proto::VarType_Type_SELECTED_ROWS;
   } else if (type.hash_code() == typeid(ReaderHolder).hash_code()) {
-    return proto::VarDesc_VarType_READER;
+    return proto::VarType_Type_READER;
   } else {
     PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
   }
@@ -42,19 +42,19 @@ inline proto::VarDesc::VarType ToVarType(std::type_index type) {
 template <typename Visitor>
 inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
   switch (ToVarType(var.Type())) {
-    case proto::VarDesc_VarType_LOD_TENSOR:
+    case proto::VarType_Type_LOD_TENSOR:
       visitor(var.Get<LoDTensor>());
       return;
-    case proto::VarDesc_VarType_LOD_RANK_TABLE:
+    case proto::VarType_Type_LOD_RANK_TABLE:
       visitor(var.Get<LoDRankTable>());
       return;
-    case proto::VarDesc_VarType_LOD_TENSOR_ARRAY:
+    case proto::VarType_Type_LOD_TENSOR_ARRAY:
       visitor(var.Get<LoDTensorArray>());
       return;
-    case proto::VarDesc_VarType_SELECTED_ROWS:
+    case proto::VarType_Type_SELECTED_ROWS:
       visitor(var.Get<SelectedRows>());
       return;
-    case proto::VarDesc_VarType_READER:
+    case proto::VarType_Type_READER:
       visitor(var.Get<ReaderHolder>());
       return;
     default:
......
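VisitVarType dispatches a visitor over the variable kind resolved by ToVarType, so every overload reachable from the switch must be callable on the visitor. A hypothetical caller, assuming the framework types named in the hunk above (not part of the commit; the numel() accessor is an assumption about this tree):

```cpp
// Hypothetical visitor: one concrete overload plus a template catch-all
// covering the remaining variants (rank table, tensor array, reader).
#include <iostream>

struct VarLogger {
  void operator()(const paddle::framework::LoDTensor& t) const {
    std::cout << "LoDTensor, numel = " << t.numel() << std::endl;
  }
  template <typename T>
  void operator()(const T&) const {
    std::cout << "other variable kind" << std::endl;
  }
};

void Log(const paddle::framework::Variable& var) {
  paddle::framework::VisitVarType(var, VarLogger{});
}
```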
@@ -35,14 +35,14 @@ class SumOpVarTypeInference : public VarTypeInference {
 public:
   void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
     auto &inputs = op_desc.Input("X");
-    auto default_var_type = proto::VarDesc::SELECTED_ROWS;
+    auto default_var_type = proto::VarType::SELECTED_ROWS;

     bool any_input_is_lod_tensor = std::any_of(
         inputs.begin(), inputs.end(), [block](const std::string &name) {
-          return block->Var(name)->GetType() == proto::VarDesc::LOD_TENSOR;
+          return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
         });
     if (any_input_is_lod_tensor) {
-      default_var_type = proto::VarDesc::LOD_TENSOR;
+      default_var_type = proto::VarType::LOD_TENSOR;
     }

     auto out_var_name = op_desc.Output("Out").front();
@@ -67,19 +67,19 @@ TEST(InferVarType, sum_op) {
   op->SetInput("X", {"test_a", "test_b", "test_c"});
   op->SetOutput("Out", {"test_out"});

-  prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarDesc::SELECTED_ROWS);
-  prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarDesc::SELECTED_ROWS);
-  prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarDesc::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS);
   prog.MutableBlock(0)->Var("test_out");

   op->InferVarType(prog.MutableBlock(0));

-  ASSERT_EQ(proto::VarDesc::SELECTED_ROWS,
+  ASSERT_EQ(proto::VarType::SELECTED_ROWS,
             prog.MutableBlock(0)->Var("test_out")->GetType());

-  prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarDesc::LOD_TENSOR);
+  prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR);
   op->InferVarType(prog.MutableBlock(0));
-  ASSERT_EQ(proto::VarDesc::LOD_TENSOR,
+  ASSERT_EQ(proto::VarType::LOD_TENSOR,
             prog.MutableBlock(0)->Var("test_out")->GetType());
 }
@@ -90,14 +90,14 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
   op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
   op->SetOutput("Out", {"test2_out"});

-  prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarDesc::SELECTED_ROWS);
-  prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarDesc::SELECTED_ROWS);
-  prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarDesc::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarType::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarType::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarType::SELECTED_ROWS);
   prog.MutableBlock(0)->Var("test2_out");

   op->InferVarType(prog.MutableBlock(0));

-  ASSERT_EQ(proto::VarDesc_VarType_LOD_TENSOR,
+  ASSERT_EQ(proto::VarType_Type_LOD_TENSOR,
             prog.MutableBlock(0)->Var("test2_out")->GetType());
 }
......
@@ -115,8 +115,8 @@ class AssignInferShape : public framework::InferShapeBase {
   void operator()(framework::InferShapeContext *context) const override {
     if (context->HasInput("X")) {
       auto type = context->GetInputsVarType("X")[0];
-      if (type == framework::proto::VarDesc_VarType_SELECTED_ROWS ||
-          type == framework::proto::VarDesc_VarType_LOD_TENSOR) {
+      if (type == framework::proto::VarType::SELECTED_ROWS ||
+          type == framework::proto::VarType::LOD_TENSOR) {
         context->SetOutputDim("Out", context->GetInputDim("X"));
       }
     }
......
@@ -128,10 +128,10 @@ class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
   void operator()(const framework::OpDesc& op_desc,
                   framework::BlockDesc* block) const override {
     for (auto& o : op_desc.Output("SentenceIds")) {
-      block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR);
+      block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
     }
     for (auto& o : op_desc.Output("SentenceScores")) {
-      block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR);
+      block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
     }
   }
 };
......
@@ -240,10 +240,10 @@ class BeamSearchInferVarType : public framework::VarTypeInference {
   void operator()(const framework::OpDesc &op_desc,
                   framework::BlockDesc *block) const override {
     for (auto &o : op_desc.Output("selected_ids")) {
-      block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR);
+      block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
     }
     for (auto &o : op_desc.Output("selected_scores")) {
-      block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR);
+      block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
     }
   }
 };
......
@@ -84,7 +84,7 @@ class CreateFileReaderInferVarType : public framework::VarTypeInference {
                   framework::BlockDesc* block) const override {
     std::string reader_name = op_desc.Output("Out")[0];
     framework::VarDesc* reader = block->FindVarRecursive(reader_name);
-    reader->SetType(framework::proto::VarDesc::READER);
+    reader->SetType(framework::proto::VarType::READER);
   }
 };
@@ -97,7 +97,7 @@ class CreateDecoratedReaderInferVarType : public framework::VarTypeInference {
     framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name);
     std::string out_reader_name = op_desc.Output("Out")[0];
     framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name);
-    out_reader->SetType(framework::proto::VarDesc::READER);
+    out_reader->SetType(framework::proto::VarType::READER);
     out_reader->SetDataTypes(in_reader->GetDataTypes());
   }
 };
@@ -147,7 +147,7 @@ class CreateRandomDataGeneratorOpMaker
     AddComment(R"DOC(
 CreateRandomDataGenerator Operator
 This Op creates a random reader.
 The reader generates random data instead of really reading from files.
 Generated data follow an uniform distribution between 'min' and 'max'.
 )DOC");
@@ -183,7 +183,7 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
 CreateShuffleReader Operator
 A shuffle reader takes another reader as its 'underlying reader'
 and yields the underlying reader's outputs in a shuffled order.
 )DOC");
   }
 };
@@ -218,8 +218,8 @@ class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(
 CreateBatchReader Operator
 A batch reader takes another reader as its 'underlying reader',
 gathers the underlying reader's outputs and then yields them in batches.
 )DOC");
   }
 };
......
@@ -24,11 +24,11 @@ void SerializeToMessage(const std::string& name, const framework::Variable* var,
   msg->set_varname(name);
   std::ostringstream oss;
   switch (framework::ToVarType(var->Type())) {
-    case framework::proto::VarDesc_VarType_LOD_TENSOR:
+    case framework::proto::VarType_Type_LOD_TENSOR:
       msg->set_type(sendrecv::VarType::LOD_TENSOR);
       framework::SerializeToStream(oss, var->Get<framework::LoDTensor>(), ctx);
       break;
-    case framework::proto::VarDesc_VarType_SELECTED_ROWS:
+    case framework::proto::VarType_Type_SELECTED_ROWS:
       msg->set_type(sendrecv::VarType::SELECTED_ROWS);
       framework::SerializeToStream(oss, var->Get<framework::SelectedRows>(),
                                    ctx);
......
@@ -314,7 +314,6 @@ EIGEN_FUNCTOR(Div, EIGEN_DIV);
 template <typename DeviceContext, typename T, typename functor,
           typename broadcastfunctor, typename broadcast2functor>
 void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
-                            const framework::Tensor* x,
                             const framework::Tensor* x,
                             const framework::Tensor* y,
                             const framework::Tensor* out,
......
@@ -98,7 +98,7 @@ class GetPlacesInferVarType : public framework::VarTypeInference {
                   framework::BlockDesc *block) const override {
     for (auto &o_name : op_desc.Output("Out")) {
       block->FindRecursiveOrCreateVar(o_name).SetType(
-          framework::proto::VarDesc::PLACE_LIST);
+          framework::proto::VarType::PLACE_LIST);
     }
   }
 };
......
@@ -69,7 +69,7 @@ class LoDRankTableInferVarType : public framework::VarTypeInference {
                   framework::BlockDesc *block) const override {
     for (auto &o : op_desc.Output("Out")) {
       block->FindRecursiveOrCreateVar(o).SetType(
-          framework::proto::VarDesc::LOD_RANK_TABLE);
+          framework::proto::VarType::LOD_RANK_TABLE);
     }
   }
 };
......
@@ -138,7 +138,7 @@ class LoDTensorToArrayInferVarType : public framework::VarTypeInference {
   void operator()(const framework::OpDesc &op_desc,
                   framework::BlockDesc *block) const override {
     for (auto &out_var : op_desc.Output("Out")) {
-      block->Var(out_var)->SetType(framework::proto::VarDesc::LOD_TENSOR_ARRAY);
+      block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
     }
   }
 };
......
@@ -123,11 +123,11 @@ class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
       VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
               << " is set to SelectedRows";
       block->Var(out_var_name)
-          ->SetType(framework::proto::VarDesc::SELECTED_ROWS);
+          ->SetType(framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
               << " is set to LoDTensor";
-      block->Var(out_var_name)->SetType(framework::proto::VarDesc::LOD_TENSOR);
+      block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
     }
   }
 };
......
@@ -46,7 +46,7 @@ struct Formater {
   }

 private:
-  void PrintMessage() { CLOG << std::time(nullptr) << "\t" << message; }
+  void PrintMessage() { CLOG << std::time(nullptr) << "\t" << message << "\t"; }
   void PrintName() {
     if (!name.empty()) {
       CLOG << "Tensor[" << name << "]" << std::endl;
@@ -85,15 +85,16 @@ struct Formater {
     // print float
     if (dtype.hash_code() == typeid(float).hash_code()) {
       Display<float>(size);
-    }
-    if (dtype.hash_code() == typeid(double).hash_code()) {
+    } else if (dtype.hash_code() == typeid(double).hash_code()) {
       Display<double>(size);
-    }
-    if (dtype.hash_code() == typeid(int).hash_code()) {
+    } else if (dtype.hash_code() == typeid(int).hash_code()) {
       Display<int>(size);
-    }
-    if (dtype.hash_code() == typeid(int64_t).hash_code()) {
+    } else if (dtype.hash_code() == typeid(int64_t).hash_code()) {
       Display<int64_t>(size);
+    } else if (dtype.hash_code() == typeid(bool).hash_code()) {
+      Display<bool>(size);
+    } else {
+      CLOG << "\tdata: unprintable type: " << dtype.name() << std::endl;
     }
   }
@@ -182,6 +183,7 @@ class TensorPrintOp : public framework::OperatorBase {
     }

     Formater formater;
+    formater.message = Attr<std::string>("message");
     if (Attr<bool>("print_tensor_name")) {
       formater.name = printed_var_name;
     }
......
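The second hunk above turns four independent if blocks into a single else-if chain and adds a bool branch plus an "unprintable type" fallback, so exactly one branch runs per tensor. A self-contained restatement of that dispatch (simplified sketch: data is passed explicitly instead of through the Formater members, and std::cout stands in for CLOG):

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <typeindex>

// Print n elements of a raw buffer interpreted as T, tab-separated.
template <typename T>
void Display(const void* data, std::size_t n) {
  const T* p = static_cast<const T*>(data);
  for (std::size_t i = 0; i < n; ++i) std::cout << "\t" << p[i];
  std::cout << std::endl;
}

// Exactly one branch fires; unknown element types hit the fallback.
void DisplayTyped(std::type_index dtype, const void* data, std::size_t n) {
  if (dtype == typeid(float)) {
    Display<float>(data, n);
  } else if (dtype == typeid(double)) {
    Display<double>(data, n);
  } else if (dtype == typeid(int)) {
    Display<int>(data, n);
  } else if (dtype == typeid(int64_t)) {
    Display<int64_t>(data, n);
  } else if (dtype == typeid(bool)) {
    Display<bool>(data, n);
  } else {
    std::cout << "\tdata: unprintable type: " << dtype.name() << std::endl;
  }
}
```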
@@ -45,7 +45,7 @@ class ReadInferVarType : public framework::VarTypeInference {
     PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
     for (size_t i = 0; i < dtypes.size(); ++i) {
       framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]);
-      out.SetType(framework::proto::VarDesc::LOD_TENSOR);
+      out.SetType(framework::proto::VarType::LOD_TENSOR);
       out.SetDataType(dtypes[i]);
     }
   }
......
@@ -29,7 +29,7 @@ class SumOp : public framework::OperatorWithKernel {
                    "Output(Out) of SumOp should not be null.");
     if (ctx->IsRuntime() &&
         ctx->GetOutputsVarType("Out")[0] ==
-            framework::proto::VarDesc::LOD_TENSOR_ARRAY) {
+            framework::proto::VarType::LOD_TENSOR_ARRAY) {
       return;  // skip runtime infershape when is tensor array;
     }
@@ -118,7 +118,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
   void operator()(const framework::OpDesc& op_desc,
                   framework::BlockDesc* block) const override {
     auto& inputs = op_desc.Input("X");
-    auto var_type = framework::proto::VarDesc::SELECTED_ROWS;
+    auto var_type = framework::proto::VarType::SELECTED_ROWS;

     for (auto& name : op_desc.Input("X")) {
       VLOG(10) << name << " "
@@ -128,12 +128,12 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
     bool any_input_is_lod_tensor = std::any_of(
         inputs.begin(), inputs.end(), [block](const std::string& name) {
           return block->FindRecursiveOrCreateVar(name).GetType() ==
-                 framework::proto::VarDesc::LOD_TENSOR;
+                 framework::proto::VarType::LOD_TENSOR;
         });

     auto is_tensor_array = [block](const std::string& name) {
       return block->FindRecursiveOrCreateVar(name).GetType() ==
-             framework::proto::VarDesc::LOD_TENSOR_ARRAY;
+             framework::proto::VarType::LOD_TENSOR_ARRAY;
     };

     bool any_input_is_tensor_array =
@@ -151,9 +151,9 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
         PADDLE_ENFORCE(all_inputs_are_tensor_array,
                        "Not all inputs are tensor array:\n%s", os.str());
       }
-      var_type = framework::proto::VarDesc::LOD_TENSOR_ARRAY;
+      var_type = framework::proto::VarType::LOD_TENSOR_ARRAY;
     } else if (any_input_is_lod_tensor) {
-      var_type = framework::proto::VarDesc::LOD_TENSOR;
+      var_type = framework::proto::VarType::LOD_TENSOR;
     }

     auto out_var_name = op_desc.Output("Out").front();
......
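The type-inference rule this file encodes reads: if any input is a LoDTensorArray then all inputs must be, and the output is an array; otherwise the output is a LoDTensor if any input is one, and SelectedRows only when every input is sparse. A self-contained restatement with a locally mocked enum (the real code uses the proto::VarType values above):

```cpp
#include <stdexcept>
#include <vector>

// Mocked for illustration; stands in for proto::VarType::Type.
enum class VarKind { SELECTED_ROWS, LOD_TENSOR, LOD_TENSOR_ARRAY };

VarKind InferSumOutputKind(const std::vector<VarKind>& inputs) {
  bool any_lod_tensor = false;
  bool any_array = false;
  bool all_array = true;
  for (VarKind k : inputs) {
    any_lod_tensor = any_lod_tensor || (k == VarKind::LOD_TENSOR);
    any_array = any_array || (k == VarKind::LOD_TENSOR_ARRAY);
    all_array = all_array && (k == VarKind::LOD_TENSOR_ARRAY);
  }
  if (any_array) {
    // Mirrors the PADDLE_ENFORCE(all_inputs_are_tensor_array, ...) check.
    if (!all_array) throw std::invalid_argument("Not all inputs are tensor array");
    return VarKind::LOD_TENSOR_ARRAY;
  }
  if (any_lod_tensor) return VarKind::LOD_TENSOR;
  return VarKind::SELECTED_ROWS;  // default when every input is SelectedRows
}
```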
@@ -108,7 +108,7 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
     auto out_name = op_desc.Output("Out")[0];
     VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
     auto &out = block->FindRecursiveOrCreateVar(out_name);
-    out.SetType(framework::proto::VarDesc::LOD_TENSOR_ARRAY);
+    out.SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
     auto *x = block->FindVarRecursive(x_name);
     if (x != nullptr) {
       out.SetDataType(x->GetDataType());
......
@@ -330,10 +330,10 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
         continue;
       }
       auto dims = ctx->GetInputsElementDim(kX, i);
-      if (var_types[i] == framework::proto::VarDesc::LOD_TENSOR) {
+      if (var_types[i] == framework::proto::VarType::LOD_TENSOR) {
         names_to_set.push_back(pg_names[i]);
         dims_to_set.push_back(dims);
-      } else if (var_types[i] == framework::proto::VarDesc::LOD_TENSOR_ARRAY) {
+      } else if (var_types[i] == framework::proto::VarType::LOD_TENSOR_ARRAY) {
         // not sure how to set the dim of LOD_TENSOR_ARRAY
         names_to_set.push_back(pg_names[i]);
         dims_to_set.push_back(dims);
......
@@ -232,17 +232,17 @@ void BindVarDsec(py::module &m) {
       .def("persistable", &VarDesc::Persistable)
       .def("set_persistable", &VarDesc::SetPersistable);

-  py::enum_<proto::VarDesc::VarType>(var_desc, "VarType", "")
-      .value("LOD_TENSOR", proto::VarDesc::LOD_TENSOR)
-      .value("SELECTED_ROWS", proto::VarDesc::SELECTED_ROWS)
-      .value("FEED_MINIBATCH", proto::VarDesc::FEED_MINIBATCH)
-      .value("FETCH_LIST", proto::VarDesc::FETCH_LIST)
-      .value("STEP_SCOPES", proto::VarDesc::STEP_SCOPES)
-      .value("LOD_RANK_TABLE", proto::VarDesc::LOD_RANK_TABLE)
-      .value("LOD_TENSOR_ARRAY", proto::VarDesc::LOD_TENSOR_ARRAY)
-      .value("PLACE_LIST", proto::VarDesc::PLACE_LIST)
-      .value("READER", proto::VarDesc::READER)
-      .value("NCCL_COM", proto::VarDesc::NCCL_COM);
+  py::enum_<proto::VarType::Type>(var_desc, "VarType", "")
+      .value("LOD_TENSOR", proto::VarType::LOD_TENSOR)
+      .value("SELECTED_ROWS", proto::VarType::SELECTED_ROWS)
+      .value("FEED_MINIBATCH", proto::VarType::FEED_MINIBATCH)
+      .value("FETCH_LIST", proto::VarType::FETCH_LIST)
+      .value("STEP_SCOPES", proto::VarType::STEP_SCOPES)
+      .value("LOD_RANK_TABLE", proto::VarType::LOD_RANK_TABLE)
+      .value("LOD_TENSOR_ARRAY", proto::VarType::LOD_TENSOR_ARRAY)
+      .value("PLACE_LIST", proto::VarType::PLACE_LIST)
+      .value("READER", proto::VarType::READER)
+      .value("NCCL_COM", proto::VarType::NCCL_COM);
 }

 void BindOpDesc(py::module &m) {
......
...@@ -174,7 +174,7 @@ def Print(input, ...@@ -174,7 +174,7 @@ def Print(input,
print_tensor_type (bool): Print the tensor type. print_tensor_type (bool): Print the tensor type.
print_tensor_shape (bool): Print the tensor shape. print_tensor_shape (bool): Print the tensor shape.
print_tensor_lod (bool): Print the tensor lod. print_tensor_lod (bool): Print the tensor lod.
        print_phase (bool): Which phase to display, including 'forward', print_phase (str): Which phase to display, including 'forward',
'backward' and 'both'. If set to 'backward' or 'both', will 'backward' and 'both'. If set to 'backward' or 'both', will
print the gradients of input tensor. print the gradients of input tensor.
......
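For reference, a short hedged usage sketch of `Print` with the corrected `print_phase` type (the input tensor and message text are illustrative, not from this commit):

import paddle.fluid as fluid

x = fluid.layers.fill_constant(shape=[2], dtype='float32', value=3.0)
# print_phase takes a string: 'forward', 'backward', or 'both'
fluid.layers.Print(x, message="The content of x: ", print_phase='forward')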
...@@ -1579,7 +1579,7 @@ def layer_norm(input, ...@@ -1579,7 +1579,7 @@ def layer_norm(input,
""" """
**Layer Normalization** **Layer Normalization**
Assume feature vectors exist on dimensions Assume feature vectors exist on dimensions
:attr:`begin_norm_axis ... rank(input)` and calculate the moment statistics :attr:`begin_norm_axis ... rank(input)` and calculate the moment statistics
along these dimensions for each feature vector :math:`a` with size along these dimensions for each feature vector :math:`a` with size
:math:`H`, then normalize each feature vector using the corresponding :math:`H`, then normalize each feature vector using the corresponding
...@@ -1600,13 +1600,13 @@ def layer_norm(input, ...@@ -1600,13 +1600,13 @@ def layer_norm(input,
Args: Args:
input(Variable): The input tensor variable. input(Variable): The input tensor variable.
scale(bool): Whether to learn the adaptive gain :math:`g` after scale(bool): Whether to learn the adaptive gain :math:`g` after
normalization. normalization.
shift(bool): Whether to learn the adaptive bias :math:`b` after shift(bool): Whether to learn the adaptive bias :math:`b` after
normalization. normalization.
        begin_norm_axis(int): The normalization will be performed along begin_norm_axis(int): The normalization will be performed along
dimensions from :attr:`begin_norm_axis` to :attr:`rank(input)`. dimensions from :attr:`begin_norm_axis` to :attr:`rank(input)`.
epsilon(float): The small value added to the variance to prevent epsilon(float): The small value added to the variance to prevent
division by zero. division by zero.
param_attr(ParamAttr|None): The parameter attribute for the learnable param_attr(ParamAttr|None): The parameter attribute for the learnable
gain :math:`g`. gain :math:`g`.
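A brief hedged usage sketch matching these arguments (the data shape is illustrative):

import paddle.fluid as fluid

data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
# begin_norm_axis is an int axis index; normalization runs over
# dimensions begin_norm_axis .. rank(input)
out = fluid.layers.layer_norm(input=data, begin_norm_axis=1)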
...@@ -2070,7 +2070,7 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None): ...@@ -2070,7 +2070,7 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
Tensor variable with a single element, otherwise must be in the Tensor variable with a single element, otherwise must be in the
range :math:`[-rank(input), rank(input))`. If :math:`dim < 0`, range :math:`[-rank(input), rank(input))`. If :math:`dim < 0`,
the dimension to reduce is :math:`rank + dim`. the dimension to reduce is :math:`rank + dim`.
keep_dim (bool): Whether to reserve the reduced dimension in the keep_dim (bool|False): Whether to reserve the reduced dimension in the
output Tensor. The result tensor will have one fewer dimension output Tensor. The result tensor will have one fewer dimension
than the :attr:`input` unless :attr:`keep_dim` is true. than the :attr:`input` unless :attr:`keep_dim` is true.
name(str|None): A name for this layer(optional). If set None, the layer name(str|None): A name for this layer(optional). If set None, the layer
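A short hedged sketch of the `keep_dim` behavior (shapes illustrative):

import paddle.fluid as fluid

# With keep_dim=False the reduced axis is squeezed away;
# with keep_dim=True it is kept with size 1.
x = fluid.layers.fill_constant(shape=[2, 4], dtype='float32', value=1.0)
s0 = fluid.layers.reduce_sum(x, dim=0)                 # shape [4]
s1 = fluid.layers.reduce_sum(x, dim=0, keep_dim=True)  # shape [1, 4]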
...@@ -3098,33 +3098,33 @@ def multiplex(inputs, index): ...@@ -3098,33 +3098,33 @@ def multiplex(inputs, index):
def softmax_with_cross_entropy(logits, label, soft_label=False): def softmax_with_cross_entropy(logits, label, soft_label=False):
""" """
**Softmax With Cross Entropy Operator.** **Softmax With Cross Entropy Operator.**
    Cross entropy loss with softmax is used extensively as the output layer. This Cross entropy loss with softmax is used extensively as the output layer. This
operator computes the softmax normalized values for each row of the input operator computes the softmax normalized values for each row of the input
tensor, after which cross-entropy loss is computed. This provides a more tensor, after which cross-entropy loss is computed. This provides a more
numerically stable gradient. numerically stable gradient.
Because this operator performs a softmax on logits internally, it expects Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results. softmax operator since that would produce incorrect results.
    When the attribute soft_label is set false, this operator expects mutually When the attribute soft_label is set false, this operator expects mutually
    exclusive hard labels: each sample in a batch is in exactly one class with a exclusive hard labels: each sample in a batch is in exactly one class with a
probability of 1.0. Each sample in the batch will have a single label. probability of 1.0. Each sample in the batch will have a single label.
The equation is as follows: The equation is as follows:
1) Hard label (one-hot label, so every sample has exactly one class) 1) Hard label (one-hot label, so every sample has exactly one class)
.. math:: .. math::
loss_j = -\\text{logit}_{label_j} + loss_j = -\\text{logit}_{label_j} +
\\log\\left(\\sum_{i=0}^{K}\\exp(\\text{logit}_i)\\right), j = 1,..., K \\log\\left(\\sum_{i=0}^{K}\\exp(\\text{logit}_i)\\right), j = 1,..., K
2) Soft label (each sample can have a distribution over all classes) 2) Soft label (each sample can have a distribution over all classes)
.. math:: .. math::
loss_j = -\\sum_{i=0}^{K}\\text{label}_i loss_j = -\\sum_{i=0}^{K}\\text{label}_i
\\left(\\text{logit}_i - \\log\\left(\\sum_{i=0}^{K} \\left(\\text{logit}_i - \\log\\left(\\sum_{i=0}^{K}
\\exp(\\text{logit}_i)\\right)\\right), j = 1,...,K \\exp(\\text{logit}_i)\\right)\\right), j = 1,...,K
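A minimal hedged sketch of the hard-label case (layer sizes and names are illustrative):

import paddle.fluid as fluid

data = fluid.layers.data(name='data', shape=[128], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
fc = fluid.layers.fc(input=data, size=100)
# Pass unscaled logits; the op applies softmax internally
out = fluid.layers.softmax_with_cross_entropy(logits=fc, label=label)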
...@@ -3169,7 +3169,7 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None): ...@@ -3169,7 +3169,7 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None):
The operator takes the first dimension of X and Y as batch size. The operator takes the first dimension of X and Y as batch size.
For each instance, it computes the smooth l1 loss element by element first For each instance, it computes the smooth l1 loss element by element first
and then sums all the losses. So the shape of Out is [batch_size, 1]. and then sums all the losses. So the shape of Out is [batch_size, 1].
Args: Args:
x (Variable): A tensor with rank at least 2. The input value of smooth x (Variable): A tensor with rank at least 2. The input value of smooth
l1 loss op with shape [batch_size, dim1, ..., dimN]. l1 loss op with shape [batch_size, dim1, ..., dimN].
......
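A short hedged sketch (shapes illustrative; y must match the shape of x):

import paddle.fluid as fluid

data = fluid.layers.data(name='data', shape=[128], dtype='float32')
label = fluid.layers.data(name='label', shape=[100], dtype='float32')
fc = fluid.layers.fc(input=data, size=100)
# Out has shape [batch_size, 1]: per-instance losses summed over elements
out = fluid.layers.smooth_l1(x=fc, y=label)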
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
if(NOT WITH_DISTRIBUTE)
list(REMOVE_ITEM TEST_OPS test_recv_op)
endif(NOT WITH_DISTRIBUTE)
list(REMOVE_ITEM TEST_OPS test_warpctc_op)
foreach(src ${TEST_OPS}) foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py) py_test(${src} SRCS ${src}.py)
endforeach() endforeach()
py_test(test_warpctc_op SRCS test_warpctc_op.py ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR})
add_subdirectory(unittests)
add_subdirectory(book) add_subdirectory(book)
add_subdirectory(book_distribute) add_subdirectory(book_distribute)
add_subdirectory(book_memory_optimization) add_subdirectory(book_memory_optimization)
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
if(NOT WITH_DISTRIBUTE)
list(REMOVE_ITEM TEST_OPS test_recv_op)
endif(NOT WITH_DISTRIBUTE)
list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290
list(REMOVE_ITEM TEST_OPS test_modified_huber_loss_op) # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5184
list(REMOVE_ITEM TEST_OPS test_lstm_unit_op) # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5185
list(REMOVE_ITEM TEST_OPS test_nce) # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/7778
list(REMOVE_ITEM TEST_OPS test_recurrent_op) # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/6152
list(REMOVE_ITEM TEST_OPS test_cond_op) # FIXME(qijun): https://github.com/PaddlePaddle/Paddle/issues/5101#issuecomment-339814957
list(REMOVE_ITEM TEST_OPS test_detection_output_op) # FIXME: detection_output_op will be rewritten. This unittest should be enabled after rewriting.
list(REMOVE_ITEM TEST_OPS op_test) # op_test is a helper python file, not a test
list(REMOVE_ITEM TEST_OPS decorators) # decorators is a helper python file, not a test
function(py_test_modules TARGET_NAME)
if(WITH_TESTING)
set(options "")
set(oneValueArgs "")
set(multiValueArgs MODULES DEPS ARGS ENVS)
cmake_parse_arguments(py_test_modules "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_test(NAME ${TARGET_NAME}
COMMAND env PYTHONPATH=${PADDLE_PYTHON_BUILD_DIR}/lib-python ${py_test_modules_ENVS}
${PYTHON_EXECUTABLE} -u -m unittest --verbose ${py_test_modules_MODULES} ${py_test_modules_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif()
endfunction()
# test time-consuming OPs in separate processes to exploit parallelism
list(REMOVE_ITEM TEST_OPS test_warpctc_op)
list(REMOVE_ITEM TEST_OPS test_dyn_rnn)
list(REMOVE_ITEM TEST_OPS test_mul_op)
# tests that need to be run in a separate process.
list(REMOVE_ITEM TEST_OPS test_multihead_attention)
list(REMOVE_ITEM TEST_OPS test_calc_gradient)
list(REMOVE_ITEM TEST_OPS test_while_op)
list(REMOVE_ITEM TEST_OPS test_lod_array_length_op)
list(REMOVE_ITEM TEST_OPS test_reorder_lod_tensor)
list(REMOVE_ITEM TEST_OPS test_profiler)
list(REMOVE_ITEM TEST_OPS test_normalization_wrapper)
list(REMOVE_ITEM TEST_OPS test_executor_and_mul)
list(REMOVE_ITEM TEST_OPS test_assign_value_op)
list(REMOVE_ITEM TEST_OPS test_array_read_write_op)
list(REMOVE_ITEM TEST_OPS test_lod_rank_table)
list(REMOVE_ITEM TEST_OPS test_weight_normalization)
list(REMOVE_ITEM TEST_OPS test_conditional_block)
list(REMOVE_ITEM TEST_OPS test_parameter)
list(REMOVE_ITEM TEST_OPS test_registry)
list(REMOVE_ITEM TEST_OPS test_fetch_var)
list(REMOVE_ITEM TEST_OPS test_parallel_op)
list(REMOVE_ITEM TEST_OPS test_dynrnn_static_input)
# tests that can be bundled together in one python process for speed.
if(WITH_FAST_BUNDLE_TEST)
py_test_modules("test_all_ops" MODULES ${TEST_OPS})
else()
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
endif(WITH_FAST_BUNDLE_TEST)
# tests with high overhead
py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR})
py_test_modules(test_train_dyn_rnn MODULES test_dyn_rnn)
py_test_modules(test_mul_op MODULES test_mul_op)
# tests that need to be run in a separate process.
py_test_modules(test_multihead_attention MODULES test_multihead_attention)
py_test_modules(test_calc_gradient MODULES test_calc_gradient)
py_test_modules(test_while_op MODULES test_while_op)
py_test_modules(test_lod_array_length_op MODULES test_lod_array_length_op)
py_test_modules(test_reorder_lod_tensor MODULES test_reorder_lod_tensor)
py_test_modules(test_profiler MODULES test_profiler)
py_test_modules(test_normalization_wrapper MODULES test_normalization_wrapper)
py_test_modules(test_executor_and_mul MODULES test_executor_and_mul)
py_test_modules(test_assign_value_op MODULES test_assign_value_op)
py_test_modules(test_array_read_write_op MODULES test_array_read_write_op)
py_test_modules(test_lod_rank_table MODULES test_lod_rank_table)
py_test_modules(test_weight_normalization MODULES test_weight_normalization)
py_test_modules(test_conditional_block MODULES test_conditional_block)
py_test_modules(test_parameter MODULES test_parameter)
py_test_modules(test_registry MODULES test_registry)
py_test_modules(test_fetch_var MODULES test_fetch_var)
py_test_modules(test_dynrnn_static_input MODULES test_dynrnn_static_input)
py_test_modules(test_parallel_op MODULES test_parallel_op)
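What WITH_FAST_BUNDLE_TEST buys, sketched at the Python level: `python -m unittest mod_a mod_b ...` loads every listed module into one suite and runs it in a single interpreter, so per-process startup cost is paid once. A hedged stdlib-only equivalent (module names are a hypothetical subset of TEST_OPS):

import unittest

# Each name must be an importable test module on PYTHONPATH
module_names = ["test_fetch_var", "test_parameter"]
suite = unittest.defaultTestLoader.loadTestsFromNames(module_names)
unittest.TextTestRunner(verbosity=2).run(suite)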
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...@@ -125,7 +125,4 @@ class TestCondOp(unittest.TestCase): ...@@ -125,7 +125,4 @@ class TestCondOp(unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
exit(
0
) # FIXME(qijun): https://github.com/PaddlePaddle/Paddle/issues/5101#issuecomment-339814957
unittest.main() unittest.main()
...@@ -68,6 +68,4 @@ class TestUnpoolOp(OpTest): ...@@ -68,6 +68,4 @@ class TestUnpoolOp(OpTest):
if __name__ == '__main__': if __name__ == '__main__':
# FIXME: detection_output_op will be rewritten. This unittest should be unittest.main()
# enabled after rewriting.
exit(0) # temporary disable this unittest