提交 c7ad26d6 编写于 作者: A Abhinav Arora 提交者: kavyasrinet

[WIP] Move DataType enum inside VarType (#8447)

* Move Pod Types from DataType enum to Type enum

* Fixed data_type.h

* Fix type in TensorDesc

* Add comment to framework.proto

* Fixed type in data_type.h

* Updated format of type in data_type.h

* Fix var_desc.h

* Fix op_kernel_type.h

* Fixed data_type_transform_test.cc

* Fix operator.h

* Fixed data_type_transform.cc

* Fixed op_kernel_type_test.cc

* Fix operator.cc

* Fixed data_layout_transform_test.cc

* Fix var_desc.cc

* Fixed assign_value_op.cc

* Fixed assign_value_op.h

* fixed protobuf.cc

* Fix data_layout_transform_test.cc and op_kernel_type_test.cc

* Fixed rnn_memory_helper_op.cc

* Fix progrma_desc_test.cc

* Fixed fill_constant_batch_size_like_op.cc

* Fix operator_test.cc

* Fixed fill_constant_op.cc

* Fixed gaussian_random_op.cc

* Fixed uniform_random_op.cc

* Fixed edit_distance_op.cc

* Fixed fill_constant_batch_size_like_op.cc

* Fixed rnn_memory_helper_op.cc

* Fixed chunk_eval_op.cc

* Fixed assign_value_op.cc

* Fixed assign_value_op.h

* Fixed cast_op.h

* Fixed cast_op.h

* Fix fill constant op

* Fixed clang for assign_value_op.cc

* Fix one_hot_op.h

* Fix one_hot_op.cc

* Fix fill_op.cc

* Fixed sum_op.cc

* Fixed sum_op clang

* Fix uniform_random_op.cc

* Fix gaussian_random_op.cc

* Fix backward.cc

* Fix protobuf.cc

* Fixed prune_test.cc

* Fixed op_registry_test.cc

* Fix data_device_transform_test.cu

* Fix travis error

* Fixed one_hot_op.cu

* Fixed op_registry_test.cc

* Fixed nccl_op.cc

* Fixing python tests

* Revert "Fixing python tests"

This reverts commit fccaa4c5.

* Fixing Pybind to remove data type

* Fixing tensor.py

* Updated the new files:

* Resolve error in merge conflict of fill_constant_batch_size_like_op.cc
上级 74e0eb72
...@@ -341,7 +341,7 @@ static void CreateGradVarInBlock( ...@@ -341,7 +341,7 @@ static void CreateGradVarInBlock(
auto* param = block_desc->FindVarRecursive(pname); auto* param = block_desc->FindVarRecursive(pname);
auto* grad = block_desc->FindVar(arg); auto* grad = block_desc->FindVar(arg);
if (param == nullptr) { if (param == nullptr) {
grad->SetDataType(proto::DataType::FP32); grad->SetDataType(proto::VarType::FP32);
} else { } else {
grad->SetDataType(param->GetDataType()); grad->SetDataType(param->GetDataType());
} }
......
...@@ -51,10 +51,10 @@ class TestOpWithKernel : public OperatorWithKernel { ...@@ -51,10 +51,10 @@ class TestOpWithKernel : public OperatorWithKernel {
const ExecutionContext& ctx) const override { const ExecutionContext& ctx) const override {
if (Attr<bool>("use_gpu")) { if (Attr<bool>("use_gpu")) {
VLOG(3) << "force use gpu kernel"; VLOG(3) << "force use gpu kernel";
return OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0)); return OpKernelType(proto::VarType::FP32, platform::CUDAPlace(0));
} else { } else {
VLOG(3) << "use default kernel"; VLOG(3) << "use default kernel";
return OpKernelType(proto::DataType::FP32, return OpKernelType(proto::VarType::FP32,
ctx.Input<Tensor>("input")->place()); ctx.Input<Tensor>("input")->place());
} }
} }
......
...@@ -27,9 +27,9 @@ TEST(DataTransform, DataLayoutFunction) { ...@@ -27,9 +27,9 @@ TEST(DataTransform, DataLayoutFunction) {
in.mutable_data<double>(make_ddim({2, 3, 1, 2}), place); in.mutable_data<double>(make_ddim({2, 3, 1, 2}), place);
in.set_layout(DataLayout::kNHWC); in.set_layout(DataLayout::kNHWC);
auto kernel_nhwc = OpKernelType(proto::DataType::FP32, place, auto kernel_nhwc = OpKernelType(proto::VarType::FP32, place,
DataLayout::kNHWC, LibraryType::kPlain); DataLayout::kNHWC, LibraryType::kPlain);
auto kernel_ncwh = OpKernelType(proto::DataType::FP32, place, auto kernel_ncwh = OpKernelType(proto::VarType::FP32, place,
DataLayout::kNCHW, LibraryType::kPlain); DataLayout::kNCHW, LibraryType::kPlain);
TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out); TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out);
......
...@@ -20,35 +20,35 @@ limitations under the License. */ ...@@ -20,35 +20,35 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
inline proto::DataType ToDataType(std::type_index type) { inline proto::VarType::Type ToDataType(std::type_index type) {
using namespace paddle::framework::proto; using namespace paddle::framework::proto;
if (typeid(float).hash_code() == type.hash_code()) { if (typeid(float).hash_code() == type.hash_code()) {
return DataType::FP32; return proto::VarType::FP32;
} else if (typeid(double).hash_code() == type.hash_code()) { } else if (typeid(double).hash_code() == type.hash_code()) {
return DataType::FP64; return proto::VarType::FP64;
} else if (typeid(int).hash_code() == type.hash_code()) { } else if (typeid(int).hash_code() == type.hash_code()) {
return DataType::INT32; return proto::VarType::INT32;
} else if (typeid(int64_t).hash_code() == type.hash_code()) { } else if (typeid(int64_t).hash_code() == type.hash_code()) {
return DataType::INT64; return proto::VarType::INT64;
} else if (typeid(bool).hash_code() == type.hash_code()) { } else if (typeid(bool).hash_code() == type.hash_code()) {
return DataType::BOOL; return proto::VarType::BOOL;
} else { } else {
PADDLE_THROW("Not supported"); PADDLE_THROW("Not supported");
} }
} }
inline std::type_index ToTypeIndex(proto::DataType type) { inline std::type_index ToTypeIndex(proto::VarType::Type type) {
using namespace paddle::framework::proto; using namespace paddle::framework::proto;
switch (type) { switch (type) {
case DataType::FP32: case proto::VarType::FP32:
return typeid(float); return typeid(float);
case DataType::FP64: case proto::VarType::FP64:
return typeid(double); return typeid(double);
case DataType::INT32: case proto::VarType::INT32:
return typeid(int); return typeid(int);
case DataType::INT64: case proto::VarType::INT64:
return typeid(int64_t); return typeid(int64_t);
case DataType::BOOL: case proto::VarType::BOOL:
return typeid(bool); return typeid(bool);
default: default:
PADDLE_THROW("Not support type %d", type); PADDLE_THROW("Not support type %d", type);
...@@ -56,22 +56,22 @@ inline std::type_index ToTypeIndex(proto::DataType type) { ...@@ -56,22 +56,22 @@ inline std::type_index ToTypeIndex(proto::DataType type) {
} }
template <typename Visitor> template <typename Visitor>
inline void VisitDataType(proto::DataType type, Visitor visitor) { inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
using namespace paddle::framework::proto; using namespace paddle::framework::proto;
switch (type) { switch (type) {
case DataType::FP32: case proto::VarType::FP32:
visitor.template operator()<float>(); visitor.template operator()<float>();
break; break;
case DataType::FP64: case proto::VarType::FP64:
visitor.template operator()<double>(); visitor.template operator()<double>();
break; break;
case DataType::INT32: case proto::VarType::INT32:
visitor.template operator()<int>(); visitor.template operator()<int>();
break; break;
case DataType::INT64: case proto::VarType::INT64:
visitor.template operator()<int64_t>(); visitor.template operator()<int64_t>();
break; break;
case DataType::BOOL: case proto::VarType::BOOL:
visitor.template operator()<bool>(); visitor.template operator()<bool>();
break; break;
default: default:
...@@ -79,22 +79,22 @@ inline void VisitDataType(proto::DataType type, Visitor visitor) { ...@@ -79,22 +79,22 @@ inline void VisitDataType(proto::DataType type, Visitor visitor) {
} }
} }
inline std::string DataTypeToString(const proto::DataType type) { inline std::string DataTypeToString(const proto::VarType::Type type) {
using namespace paddle::framework::proto; using namespace paddle::framework::proto;
switch (type) { switch (type) {
case DataType::FP16: case proto::VarType::FP16:
return "float16"; return "float16";
case DataType::FP32: case proto::VarType::FP32:
return "float32"; return "float32";
case DataType::FP64: case proto::VarType::FP64:
return "float64"; return "float64";
case DataType::INT16: case proto::VarType::INT16:
return "int16"; return "int16";
case DataType::INT32: case proto::VarType::INT32:
return "int32"; return "int32";
case DataType::INT64: case proto::VarType::INT64:
return "int64"; return "int64";
case DataType::BOOL: case proto::VarType::BOOL:
return "bool"; return "bool";
default: default:
PADDLE_THROW("Not support type %d", type); PADDLE_THROW("Not support type %d", type);
...@@ -102,7 +102,7 @@ inline std::string DataTypeToString(const proto::DataType type) { ...@@ -102,7 +102,7 @@ inline std::string DataTypeToString(const proto::DataType type) {
} }
inline std::ostream& operator<<(std::ostream& out, inline std::ostream& operator<<(std::ostream& out,
const proto::DataType& type) { const proto::VarType::Type& type) {
out << DataTypeToString(type); out << DataTypeToString(type);
return out; return out;
} }
......
...@@ -65,19 +65,19 @@ void TransDataType(const OpKernelType& kernel_type_for_var, ...@@ -65,19 +65,19 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
auto ctx = pool.Get(in.place()); auto ctx = pool.Get(in.place());
switch (src_type) { switch (src_type) {
case proto::DataType::FP32: case proto::VarType::FP32:
framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx)); framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx));
break; break;
case proto::DataType::FP64: case proto::VarType::FP64:
framework::VisitDataType(dst_type, CastDataType<double>(in, out, ctx)); framework::VisitDataType(dst_type, CastDataType<double>(in, out, ctx));
break; break;
case proto::DataType::INT32: case proto::VarType::INT32:
framework::VisitDataType(dst_type, CastDataType<int>(in, out, ctx)); framework::VisitDataType(dst_type, CastDataType<int>(in, out, ctx));
break; break;
case proto::DataType::INT64: case proto::VarType::INT64:
framework::VisitDataType(dst_type, CastDataType<int64_t>(in, out, ctx)); framework::VisitDataType(dst_type, CastDataType<int64_t>(in, out, ctx));
break; break;
case proto::DataType::BOOL: case proto::VarType::BOOL:
framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx)); framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx));
break; break;
default: default:
......
...@@ -32,11 +32,11 @@ TEST(DataTypeTransform, CPUTransform) { ...@@ -32,11 +32,11 @@ TEST(DataTypeTransform, CPUTransform) {
ptr[i] = i / 3; ptr[i] = i / 3;
} }
auto kernel_fp32 = OpKernelType(proto::DataType::FP32, place, auto kernel_fp32 = OpKernelType(proto::VarType::FP32, place,
DataLayout::kAnyLayout, LibraryType::kPlain); DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_fp64 = OpKernelType(proto::DataType::FP64, place, auto kernel_fp64 = OpKernelType(proto::VarType::FP64, place,
DataLayout::kAnyLayout, LibraryType::kPlain); DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_int32 = OpKernelType(proto::DataType::INT32, place, auto kernel_int32 = OpKernelType(proto::VarType::INT32, place,
DataLayout::kAnyLayout, LibraryType::kPlain); DataLayout::kAnyLayout, LibraryType::kPlain);
TransDataType(kernel_fp32, kernel_fp64, in, &out); TransDataType(kernel_fp32, kernel_fp64, in, &out);
......
...@@ -91,7 +91,9 @@ message OpProto { ...@@ -91,7 +91,9 @@ message OpProto {
required string comment = 5; required string comment = 5;
} }
enum DataType { message VarType {
enum Type {
// Pod Types
BOOL = 0; BOOL = 0;
INT16 = 1; INT16 = 1;
INT32 = 2; INT32 = 2;
...@@ -99,25 +101,24 @@ enum DataType { ...@@ -99,25 +101,24 @@ enum DataType {
FP16 = 4; FP16 = 4;
FP32 = 5; FP32 = 5;
FP64 = 6; FP64 = 6;
}
message VarType { // Other types that may need additional descriptions
enum Type { LOD_TENSOR = 7;
LOD_TENSOR = 1; SELECTED_ROWS = 8;
SELECTED_ROWS = 2; FEED_MINIBATCH = 9;
FEED_MINIBATCH = 3; FETCH_LIST = 10;
FETCH_LIST = 4; STEP_SCOPES = 11;
STEP_SCOPES = 5; LOD_RANK_TABLE = 12;
LOD_RANK_TABLE = 6; LOD_TENSOR_ARRAY = 13;
LOD_TENSOR_ARRAY = 7; PLACE_LIST = 14;
PLACE_LIST = 8; READER = 15;
READER = 9;
} }
required Type type = 1; required Type type = 1;
message TensorDesc { message TensorDesc {
required DataType data_type = 1; // Should only be PODType. Is enforced in C++
required Type data_type = 1;
repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480] repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
} }
optional TensorDesc selected_rows = 2; optional TensorDesc selected_rows = 2;
......
...@@ -40,12 +40,12 @@ struct OpKernelType { ...@@ -40,12 +40,12 @@ struct OpKernelType {
// place, data_type, library_type kinds less than 2^8 // place, data_type, library_type kinds less than 2^8
constexpr static int LEFT_SHIFT = 8; constexpr static int LEFT_SHIFT = 8;
proto::DataType data_type_; proto::VarType::Type data_type_;
DataLayout data_layout_; DataLayout data_layout_;
platform::Place place_; platform::Place place_;
LibraryType library_type_; LibraryType library_type_;
OpKernelType(proto::DataType data_type, platform::Place place, OpKernelType(proto::VarType::Type data_type, platform::Place place,
DataLayout data_layout = DataLayout::kAnyLayout, DataLayout data_layout = DataLayout::kAnyLayout,
LibraryType library_type = LibraryType::kPlain) LibraryType library_type = LibraryType::kPlain)
: data_type_(data_type), : data_type_(data_type),
...@@ -53,7 +53,7 @@ struct OpKernelType { ...@@ -53,7 +53,7 @@ struct OpKernelType {
place_(place), place_(place),
library_type_(library_type) {} library_type_(library_type) {}
OpKernelType(proto::DataType data_type, OpKernelType(proto::VarType::Type data_type,
const platform::DeviceContext& dev_ctx, const platform::DeviceContext& dev_ctx,
DataLayout data_layout = DataLayout::kAnyLayout, DataLayout data_layout = DataLayout::kAnyLayout,
LibraryType library_type = LibraryType::kPlain) LibraryType library_type = LibraryType::kPlain)
......
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
TEST(OpKernelType, ToString) { TEST(OpKernelType, ToString) {
using OpKernelType = paddle::framework::OpKernelType; using OpKernelType = paddle::framework::OpKernelType;
using DataType = paddle::framework::proto::DataType; using DataType = paddle::framework::proto::VarType;
using CPUPlace = paddle::platform::CPUPlace; using CPUPlace = paddle::platform::CPUPlace;
using DataLayout = paddle::framework::DataLayout; using DataLayout = paddle::framework::DataLayout;
using LibraryType = paddle::framework::LibraryType; using LibraryType = paddle::framework::LibraryType;
...@@ -33,7 +33,7 @@ TEST(OpKernelType, ToString) { ...@@ -33,7 +33,7 @@ TEST(OpKernelType, ToString) {
TEST(OpKernelType, Hash) { TEST(OpKernelType, Hash) {
using OpKernelType = paddle::framework::OpKernelType; using OpKernelType = paddle::framework::OpKernelType;
using DataType = paddle::framework::proto::DataType; using DataType = paddle::framework::proto::VarType;
using CPUPlace = paddle::platform::CPUPlace; using CPUPlace = paddle::platform::CPUPlace;
using CUDAPlace = paddle::platform::CUDAPlace; using CUDAPlace = paddle::platform::CUDAPlace;
using DataLayout = paddle::framework::DataLayout; using DataLayout = paddle::framework::DataLayout;
......
...@@ -226,7 +226,7 @@ class OpWithKernelTest : public OperatorWithKernel { ...@@ -226,7 +226,7 @@ class OpWithKernelTest : public OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(proto::DataType::FP32, ctx.device_context()); return framework::OpKernelType(proto::VarType::FP32, ctx.device_context());
} }
}; };
...@@ -290,8 +290,8 @@ class OpWithMultiKernelTest : public OperatorWithKernel { ...@@ -290,8 +290,8 @@ class OpWithMultiKernelTest : public OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(proto::VarType::FP32, platform::CUDAPlace(0),
proto::DataType::FP32, platform::CUDAPlace(0), DataLayout::kAnyLayout, DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN); framework::LibraryType::kCUDNN);
} }
}; };
......
...@@ -569,7 +569,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -569,7 +569,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
} }
} }
proto::DataType OperatorWithKernel::IndicateDataType( proto::VarType::Type OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
auto& scope = ctx.scope(); auto& scope = ctx.scope();
int data_type = -1; int data_type = -1;
...@@ -595,7 +595,7 @@ proto::DataType OperatorWithKernel::IndicateDataType( ...@@ -595,7 +595,7 @@ proto::DataType OperatorWithKernel::IndicateDataType(
} }
} }
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input"); PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
return static_cast<proto::DataType>(data_type); return static_cast<proto::VarType::Type>(data_type);
} }
OpKernelType OperatorWithKernel::GetExpectedKernelType( OpKernelType OperatorWithKernel::GetExpectedKernelType(
......
...@@ -394,9 +394,9 @@ class OperatorWithKernel : public OperatorBase { ...@@ -394,9 +394,9 @@ class OperatorWithKernel : public OperatorBase {
const OpKernelType& expected_kernel_type) const; const OpKernelType& expected_kernel_type) const;
private: private:
// indicate kernel DataType by input data. Defaultly all input data must be // indicate kernel DataType by input data. By default all input data must be
// same. // same.
proto::DataType IndicateDataType(const ExecutionContext& ctx) const; proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
void RunImpl(const Scope& scope, const platform::Place& place) const final; void RunImpl(const Scope& scope, const platform::Place& place) const final;
}; };
......
...@@ -119,7 +119,7 @@ class OpWithKernelTest : public OperatorWithKernel { ...@@ -119,7 +119,7 @@ class OpWithKernelTest : public OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {} void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType( OpKernelType GetExpectedKernelType(
const ExecutionContext& ctx) const override { const ExecutionContext& ctx) const override {
return OpKernelType(proto::DataType::FP32, ctx.GetPlace()); return OpKernelType(proto::VarType::FP32, ctx.GetPlace());
} }
}; };
......
...@@ -24,13 +24,13 @@ TEST(ProgramDesc, copy_ctor) { ...@@ -24,13 +24,13 @@ TEST(ProgramDesc, copy_ctor) {
auto* x = global_block->Var("X"); auto* x = global_block->Var("X");
x->SetType(proto::VarType::LOD_TENSOR); x->SetType(proto::VarType::LOD_TENSOR);
x->SetLoDLevel(0); x->SetLoDLevel(0);
x->SetDataType(proto::FP32); x->SetDataType(proto::VarType::FP32);
x->SetShape({1000, 784}); x->SetShape({1000, 784});
auto* y = global_block->Var("Y"); auto* y = global_block->Var("Y");
y->SetType(proto::VarType::LOD_TENSOR); y->SetType(proto::VarType::LOD_TENSOR);
y->SetLoDLevel(0); y->SetLoDLevel(0);
y->SetDataType(proto::FP32); y->SetDataType(proto::VarType::FP32);
y->SetShape({784, 100}); y->SetShape({784, 100});
auto* op = global_block->AppendOp(); auto* op = global_block->AppendOp();
...@@ -86,13 +86,13 @@ TEST(ProgramDescBind, serialize_and_deserialize) { ...@@ -86,13 +86,13 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
auto* x = global_block->Var("X"); auto* x = global_block->Var("X");
x->SetType(proto::VarType::LOD_TENSOR); x->SetType(proto::VarType::LOD_TENSOR);
x->SetLoDLevel(0); x->SetLoDLevel(0);
x->SetDataType(proto::FP32); x->SetDataType(proto::VarType::FP32);
x->SetShape({1000, 784}); x->SetShape({1000, 784});
auto* y = global_block->Var("Y"); auto* y = global_block->Var("Y");
y->SetType(proto::VarType::LOD_TENSOR); y->SetType(proto::VarType::LOD_TENSOR);
y->SetLoDLevel(0); y->SetLoDLevel(0);
y->SetDataType(proto::FP32); y->SetDataType(proto::VarType::FP32);
y->SetShape({784, 100}); y->SetShape({784, 100});
auto* op = global_block->AppendOp(); auto* op = global_block->AppendOp();
......
...@@ -34,7 +34,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs, ...@@ -34,7 +34,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
for (auto kv : outputs) { for (auto kv : outputs) {
for (auto v : kv.second) { for (auto v : kv.second) {
auto var = block->Var(v); auto var = block->Var(v);
var->SetDataType(paddle::framework::proto::DataType::FP32); var->SetDataType(paddle::framework::proto::VarType::FP32);
} }
} }
......
...@@ -87,12 +87,12 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const { ...@@ -87,12 +87,12 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
return res; return res;
} }
void VarDesc::SetDataType(proto::DataType data_type) { void VarDesc::SetDataType(proto::VarType::Type data_type) {
mutable_tensor_desc()->set_data_type(data_type); mutable_tensor_desc()->set_data_type(data_type);
} }
void VarDesc::SetDataTypes( void VarDesc::SetDataTypes(
const std::vector<proto::DataType> &multiple_data_type) { const std::vector<proto::VarType::Type> &multiple_data_type) {
if (multiple_data_type.size() != GetTensorDescNum()) { if (multiple_data_type.size() != GetTensorDescNum()) {
VLOG(3) << "WARNING: The number of given data types(" VLOG(3) << "WARNING: The number of given data types("
<< multiple_data_type.size() << multiple_data_type.size()
...@@ -108,13 +108,13 @@ void VarDesc::SetDataTypes( ...@@ -108,13 +108,13 @@ void VarDesc::SetDataTypes(
} }
} }
proto::DataType VarDesc::GetDataType() const { proto::VarType::Type VarDesc::GetDataType() const {
return tensor_desc().data_type(); return tensor_desc().data_type();
} }
std::vector<proto::DataType> VarDesc::GetDataTypes() const { std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
std::vector<proto::VarType::TensorDesc> descs = tensor_descs(); std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
std::vector<proto::DataType> res; std::vector<proto::VarType::Type> res;
res.reserve(descs.size()); res.reserve(descs.size());
for (const auto &tensor_desc : descs) { for (const auto &tensor_desc : descs) {
res.push_back(tensor_desc.data_type()); res.push_back(tensor_desc.data_type());
......
...@@ -80,13 +80,14 @@ class VarDesc { ...@@ -80,13 +80,14 @@ class VarDesc {
std::vector<std::vector<int64_t>> GetShapes() const; std::vector<std::vector<int64_t>> GetShapes() const;
void SetDataType(proto::DataType data_type); void SetDataType(proto::VarType::Type data_type);
void SetDataTypes(const std::vector<proto::DataType> &multiple_data_type); void SetDataTypes(
const std::vector<proto::VarType::Type> &multiple_data_type);
proto::DataType GetDataType() const; proto::VarType::Type GetDataType() const;
std::vector<proto::DataType> GetDataTypes() const; std::vector<proto::VarType::Type> GetDataTypes() const;
void SetLoDLevel(int32_t lod_level); void SetLoDLevel(int32_t lod_level);
......
...@@ -36,7 +36,8 @@ class AssignValueOp : public framework::OperatorWithKernel { ...@@ -36,7 +36,8 @@ class AssignValueOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::proto::DataType(ctx.Attr<int>("dtype")), ctx.GetPlace()); framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace());
} }
}; };
...@@ -49,8 +50,8 @@ class AssignValueOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -49,8 +50,8 @@ class AssignValueOpMaker : public framework::OpProtoAndCheckerMaker {
"(vector<int>) " "(vector<int>) "
"Shape of values."); "Shape of values.");
AddAttr<int>("dtype", "data type of values") AddAttr<int>("dtype", "data type of values")
.InEnum({framework::proto::DataType::INT32, .InEnum({framework::proto::VarType::INT32,
framework::proto::DataType::FP32}); framework::proto::VarType::FP32});
AddAttr<std::vector<float>>("fp32_values", "store the float values") AddAttr<std::vector<float>>("fp32_values", "store the float values")
.SetDefault({}); .SetDefault({});
AddAttr<std::vector<int>>("int32_values", "store the int values") AddAttr<std::vector<int>>("int32_values", "store the int values")
......
...@@ -30,10 +30,10 @@ class AssignValueKernel : public framework::OpKernel<T> { ...@@ -30,10 +30,10 @@ class AssignValueKernel : public framework::OpKernel<T> {
int dtype = ctx.Attr<int>("dtype"); int dtype = ctx.Attr<int>("dtype");
const char* value_name = nullptr; const char* value_name = nullptr;
switch (dtype) { switch (dtype) {
case framework::proto::DataType::INT32: case framework::proto::VarType::INT32:
value_name = "int32_values"; value_name = "int32_values";
break; break;
case framework::proto::DataType::FP32: case framework::proto::VarType::FP32:
value_name = "fp32_values"; value_name = "fp32_values";
break; break;
default: default:
......
...@@ -55,7 +55,8 @@ class CastOpKernel : public framework::OpKernel<InT> { ...@@ -55,7 +55,8 @@ class CastOpKernel : public framework::OpKernel<InT> {
auto* in = context.Input<framework::Tensor>("X"); auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
framework::VisitDataType( framework::VisitDataType(
static_cast<framework::proto::DataType>(context.Attr<int>("out_dtype")), static_cast<framework::proto::VarType::Type>(
context.Attr<int>("out_dtype")),
CastOpFunctor<DeviceContext, InT>( CastOpFunctor<DeviceContext, InT>(
in, out, context.template device_context<DeviceContext>())); in, out, context.template device_context<DeviceContext>()));
} }
......
...@@ -57,7 +57,7 @@ class ChunkEvalOp : public framework::OperatorWithKernel { ...@@ -57,7 +57,7 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(framework::proto::DataType::FP32, return framework::OpKernelType(framework::proto::VarType::FP32,
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -42,7 +42,7 @@ class EditDistanceOp : public framework::OperatorWithKernel { ...@@ -42,7 +42,7 @@ class EditDistanceOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(framework::proto::DataType::FP32, return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -24,7 +24,7 @@ class FillConstantBatchSizeLikeOp : public BatchSizeLikeOp { ...@@ -24,7 +24,7 @@ class FillConstantBatchSizeLikeOp : public BatchSizeLikeOp {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")), static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.device_context()); ctx.device_context());
} }
}; };
...@@ -36,7 +36,7 @@ class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker { ...@@ -36,7 +36,7 @@ class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
AddAttr<int>("dtype", AddAttr<int>("dtype",
"(int, default 5 (FP32)) " "(int, default 5 (FP32)) "
"Output data type") "Output data type")
.SetDefault(framework::proto::DataType::FP32); .SetDefault(framework::proto::VarType::FP32);
AddAttr<float>("value", "(float, default 0) The value to be filled") AddAttr<float>("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f); .SetDefault(0.0f);
AddComment(R"DOC( AddComment(R"DOC(
......
...@@ -38,7 +38,7 @@ class FillConstantOp : public framework::OperatorBase { ...@@ -38,7 +38,7 @@ class FillConstantOp : public framework::OperatorBase {
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
auto data_type = auto data_type =
static_cast<framework::proto::DataType>(Attr<int>("dtype")); static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
auto value = Attr<float>("value"); auto value = Attr<float>("value");
auto force_cpu = Attr<bool>("force_cpu"); auto force_cpu = Attr<bool>("force_cpu");
auto &out = auto &out =
...@@ -64,7 +64,7 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -64,7 +64,7 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("dtype", AddAttr<int>("dtype",
"(int, default 5 (FP32)) " "(int, default 5 (FP32)) "
"Output data type") "Output data type")
.SetDefault(framework::proto::DataType::FP32); .SetDefault(framework::proto::VarType::FP32);
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output"); AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
AddAttr<float>("value", "(float, default 0) The value to be filled") AddAttr<float>("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f); .SetDefault(0.0f);
......
...@@ -51,7 +51,8 @@ class FillOp : public framework::OperatorBase { ...@@ -51,7 +51,8 @@ class FillOp : public framework::OperatorBase {
"Cannot find variable %s", Output("Out")) "Cannot find variable %s", Output("Out"))
.GetMutable<framework::LoDTensor>()); .GetMutable<framework::LoDTensor>());
out.Resize(framework::make_ddim(Attr<std::vector<int>>("shape"))); out.Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
auto dtype = static_cast<framework::proto::DataType>(Attr<int>("dtype")); auto dtype =
static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
platform::CPUPlace cpu; platform::CPUPlace cpu;
auto force_cpu = Attr<bool>("force_cpu"); auto force_cpu = Attr<bool>("force_cpu");
out.mutable_data(force_cpu ? cpu : place, framework::ToTypeIndex(dtype)); out.mutable_data(force_cpu ? cpu : place, framework::ToTypeIndex(dtype));
...@@ -93,7 +94,7 @@ Fill an tensor with `value` and `shape`. The type of the tensor is specify by ...@@ -93,7 +94,7 @@ Fill an tensor with `value` and `shape`. The type of the tensor is specify by
"value", "The float values of tensor, which are flatten in row major"); "value", "The float values of tensor, which are flatten in row major");
AddAttr<std::vector<int>>("shape", "The shape of output tensor"); AddAttr<std::vector<int>>("shape", "The shape of output tensor");
AddAttr<int>("dtype", "The data type of output tensor, Default is float") AddAttr<int>("dtype", "The data type of output tensor, Default is float")
.SetDefault(framework::proto::DataType::FP32); .SetDefault(framework::proto::VarType::FP32);
AddAttr<bool>("force_cpu", AddAttr<bool>("force_cpu",
"Whether the output tensor must be at CPU memory or not. " "Whether the output tensor must be at CPU memory or not. "
"Default is false.") "Default is false.")
......
...@@ -26,7 +26,7 @@ class GaussianRandomBatchSizeLikeOp : public BatchSizeLikeOp { ...@@ -26,7 +26,7 @@ class GaussianRandomBatchSizeLikeOp : public BatchSizeLikeOp {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")), static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.GetPlace()); ctx.GetPlace());
} }
}; };
...@@ -53,7 +53,7 @@ class GaussianRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker { ...@@ -53,7 +53,7 @@ class GaussianRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
AddAttr<int>("dtype", AddAttr<int>("dtype",
"(int, default 5(FP32)) " "(int, default 5(FP32)) "
"Output data type.") "Output data type.")
.SetDefault(framework::proto::DataType::FP32); .SetDefault(framework::proto::VarType::FP32);
AddComment(R"DOC( AddComment(R"DOC(
GaussianRandom Operator. GaussianRandom Operator.
......
...@@ -63,7 +63,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { ...@@ -63,7 +63,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")), static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.device_context()); ctx.device_context());
} }
}; };
...@@ -95,7 +95,7 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -95,7 +95,7 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("dtype", AddAttr<int>("dtype",
"(int, default 5(FP32)) " "(int, default 5(FP32)) "
"Output data type.") "Output data type.")
.SetDefault(framework::proto::DataType::FP32); .SetDefault(framework::proto::VarType::FP32);
AddComment(R"DOC( AddComment(R"DOC(
GaussianRandom Operator. GaussianRandom Operator.
......
...@@ -55,7 +55,7 @@ class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -55,7 +55,7 @@ class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("dtype", AddAttr<int>("dtype",
"(int, default 5 (FP32)) " "(int, default 5 (FP32)) "
"Output data type") "Output data type")
.SetDefault(framework::proto::DataType::FP32); .SetDefault(framework::proto::VarType::FP32);
AddComment(R"DOC( AddComment(R"DOC(
NCCLInit Operator. NCCLInit Operator.
......
...@@ -60,7 +60,7 @@ class OneHotOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -60,7 +60,7 @@ class OneHotOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("dtype", AddAttr<int>("dtype",
"An integer to specify the data type of one-hot " "An integer to specify the data type of one-hot "
"vector. The default value is FP32.") "vector. The default value is FP32.")
.SetDefault(paddle::framework::proto::DataType::FP32); .SetDefault(paddle::framework::proto::VarType::FP32);
AddComment(R"DOC( AddComment(R"DOC(
One Hot Operator. This operator creates the one-hot representations for input One Hot Operator. This operator creates the one-hot representations for input
index values. The following example will help to explain the function of this index values. The following example will help to explain the function of this
......
...@@ -65,7 +65,8 @@ class OneHotCUDAKernel : public framework::OpKernel<T> { ...@@ -65,7 +65,8 @@ class OneHotCUDAKernel : public framework::OpKernel<T> {
int depth = context.Attr<int>("depth"); int depth = context.Attr<int>("depth");
framework::VisitDataType( framework::VisitDataType(
static_cast<framework::proto::DataType>(context.Attr<int>("dtype")), static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype")),
OneHotOpCUDAFunctor<DeviceContext, T>( OneHotOpCUDAFunctor<DeviceContext, T>(
in, out, depth, context.template device_context<DeviceContext>())); in, out, depth, context.template device_context<DeviceContext>()));
} }
......
...@@ -58,7 +58,8 @@ class OneHotKernel : public framework::OpKernel<T> { ...@@ -58,7 +58,8 @@ class OneHotKernel : public framework::OpKernel<T> {
int depth = context.Attr<int>("depth"); int depth = context.Attr<int>("depth");
framework::VisitDataType( framework::VisitDataType(
static_cast<framework::proto::DataType>(context.Attr<int>("dtype")), static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype")),
OneHotOpFunctor<DeviceContext, T>( OneHotOpFunctor<DeviceContext, T>(
in, out, depth, context.template device_context<DeviceContext>())); in, out, depth, context.template device_context<DeviceContext>()));
} }
......
...@@ -66,7 +66,7 @@ class RNNMemoryHelperOpInfoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -66,7 +66,7 @@ class RNNMemoryHelperOpInfoMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("dtype", AddAttr<int>("dtype",
"(int, default 5 (FP32)) " "(int, default 5 (FP32)) "
"Output data type") "Output data type")
.SetDefault(framework::proto::DataType::FP32); .SetDefault(framework::proto::VarType::FP32);
AddComment(""); AddComment("");
} }
}; };
...@@ -126,7 +126,7 @@ class RNNMemoryHelperGradOpInfoMaker ...@@ -126,7 +126,7 @@ class RNNMemoryHelperGradOpInfoMaker
AddAttr<int>("dtype", AddAttr<int>("dtype",
"(int, default 5 (FP32)) " "(int, default 5 (FP32)) "
"Output data type") "Output data type")
.SetDefault(framework::proto::DataType::FP32); .SetDefault(framework::proto::VarType::FP32);
AddComment(""); AddComment("");
} }
}; };
......
...@@ -73,7 +73,8 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -73,7 +73,8 @@ class SumOp : public framework::OperatorWithKernel {
"Sum operator should have at least one tensor"); "Sum operator should have at least one tensor");
return framework::OpKernelType( return framework::OpKernelType(
static_cast<framework::proto::DataType>(dtype), ctx.device_context()); static_cast<framework::proto::VarType::Type>(dtype),
ctx.device_context());
} else if (x_vars[0]->IsType<framework::SelectedRows>()) { } else if (x_vars[0]->IsType<framework::SelectedRows>()) {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( framework::ToDataType(
......
...@@ -26,7 +26,7 @@ class UniformRandomBatchSizeLikeOp : public BatchSizeLikeOp { ...@@ -26,7 +26,7 @@ class UniformRandomBatchSizeLikeOp : public BatchSizeLikeOp {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")), static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.GetPlace()); ctx.GetPlace());
} }
}; };
...@@ -58,7 +58,7 @@ This operator initializes a tensor with the same batch_size as the Input tensor ...@@ -58,7 +58,7 @@ This operator initializes a tensor with the same batch_size as the Input tensor
"generate the same random numbers every time.") "generate the same random numbers every time.")
.SetDefault(0); .SetDefault(0);
AddAttr<int>("dtype", "(int, default 5(FP32)) Output tensor data type") AddAttr<int>("dtype", "(int, default 5(FP32)) Output tensor data type")
.SetDefault(framework::proto::DataType::FP32); .SetDefault(framework::proto::VarType::FP32);
} }
}; };
......
...@@ -66,7 +66,7 @@ class UniformRandomOp : public framework::OperatorWithKernel { ...@@ -66,7 +66,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")), static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.GetPlace()); ctx.GetPlace());
} }
}; };
...@@ -101,7 +101,7 @@ uniform distribution. ...@@ -101,7 +101,7 @@ uniform distribution.
"generate the same random numbers every time.") "generate the same random numbers every time.")
.SetDefault(0); .SetDefault(0);
AddAttr<int>("dtype", "(int, default 5(FP32)) Output tensor data type") AddAttr<int>("dtype", "(int, default 5(FP32)) Output tensor data type")
.SetDefault(framework::proto::DataType::FP32); .SetDefault(framework::proto::VarType::FP32);
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -195,15 +195,6 @@ void BindBlockDesc(py::module &m) { ...@@ -195,15 +195,6 @@ void BindBlockDesc(py::module &m) {
} }
void BindVarDsec(py::module &m) { void BindVarDsec(py::module &m) {
py::enum_<proto::DataType>(m, "DataType", "")
.value("BOOL", proto::DataType::BOOL)
.value("INT16", proto::DataType::INT16)
.value("INT32", proto::DataType::INT32)
.value("INT64", proto::DataType::INT64)
.value("FP16", proto::DataType::FP16)
.value("FP32", proto::DataType::FP32)
.value("FP64", proto::DataType::FP64);
py::class_<VarDesc> var_desc(m, "VarDesc", ""); py::class_<VarDesc> var_desc(m, "VarDesc", "");
var_desc var_desc
.def("name", .def("name",
...@@ -233,6 +224,13 @@ void BindVarDsec(py::module &m) { ...@@ -233,6 +224,13 @@ void BindVarDsec(py::module &m) {
.def("set_persistable", &VarDesc::SetPersistable); .def("set_persistable", &VarDesc::SetPersistable);
py::enum_<proto::VarType::Type>(var_desc, "VarType", "") py::enum_<proto::VarType::Type>(var_desc, "VarType", "")
.value("BOOL", proto::VarType::BOOL)
.value("INT16", proto::VarType::INT16)
.value("INT32", proto::VarType::INT32)
.value("INT64", proto::VarType::INT64)
.value("FP16", proto::VarType::FP16)
.value("FP32", proto::VarType::FP32)
.value("FP64", proto::VarType::FP64)
.value("LOD_TENSOR", proto::VarType::LOD_TENSOR) .value("LOD_TENSOR", proto::VarType::LOD_TENSOR)
.value("SELECTED_ROWS", proto::VarType::SELECTED_ROWS) .value("SELECTED_ROWS", proto::VarType::SELECTED_ROWS)
.value("FEED_MINIBATCH", proto::VarType::FEED_MINIBATCH) .value("FEED_MINIBATCH", proto::VarType::FEED_MINIBATCH)
......
...@@ -68,7 +68,7 @@ def _infer_var_data_type_(grad_var_name, block): ...@@ -68,7 +68,7 @@ def _infer_var_data_type_(grad_var_name, block):
fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii")) fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii"))
grad_var.set_dtype(fwd_var.dtype()) grad_var.set_dtype(fwd_var.dtype())
else: else:
grad_var.set_dtype(core.DataType.FP32) grad_var.set_dtype(core.VarDesc.VarType.FP32)
def _all_in_set_(cands, s): def _all_in_set_(cands, s):
......
...@@ -27,13 +27,13 @@ class DataToLoDTensorConverter(object): ...@@ -27,13 +27,13 @@ class DataToLoDTensorConverter(object):
self.place = place self.place = place
self.lod_level = lod_level self.lod_level = lod_level
self.shape = shape self.shape = shape
if dtype == core.DataType.FP32: if dtype == core.VarDesc.VarType.FP32:
self.dtype = 'float32' self.dtype = 'float32'
elif dtype == core.DataType.INT64: elif dtype == core.VarDesc.VarType.INT64:
self.dtype = 'int64' self.dtype = 'int64'
elif dtype == core.DataType.FP64: elif dtype == core.VarDesc.VarType.FP64:
self.dtype = 'float64' self.dtype = 'float64'
elif dtype == core.DataType.INT32: elif dtype == core.VarDesc.VarType.INT32:
self.dtype = 'int32' self.dtype = 'int32'
else: else:
raise ValueError("dtype must be any of [int32, float32, int64, " raise ValueError("dtype must be any of [int32, float32, int64, "
......
...@@ -89,7 +89,7 @@ class Evaluator(object): ...@@ -89,7 +89,7 @@ class Evaluator(object):
Args: Args:
suffix(str): the state suffix. suffix(str): the state suffix.
dtype(str|core.DataType): the state data type dtype(str|core.VarDesc.VarType): the state data type
shape(tuple|list): the shape of state shape(tuple|list): the shape of state
Returns: State variable Returns: State variable
......
...@@ -67,24 +67,24 @@ def convert_np_dtype_to_dtype_(np_dtype): ...@@ -67,24 +67,24 @@ def convert_np_dtype_to_dtype_(np_dtype):
Args: Args:
np_dtype(np.dtype): the data type in numpy np_dtype(np.dtype): the data type in numpy
Returns(core.DataType): the data type in Paddle Returns(core.VarDesc.VarType): the data type in Paddle
""" """
dtype = np.dtype(np_dtype) dtype = np.dtype(np_dtype)
if dtype == np.float32: if dtype == np.float32:
return core.DataType.FP32 return core.VarDesc.VarType.FP32
elif dtype == np.float64: elif dtype == np.float64:
return core.DataType.FP64 return core.VarDesc.VarType.FP64
elif dtype == np.float16: elif dtype == np.float16:
return core.DataType.FP16 return core.VarDesc.VarType.FP16
elif dtype == np.int32: elif dtype == np.int32:
return core.DataType.INT32 return core.VarDesc.VarType.INT32
elif dtype == np.int16: elif dtype == np.int16:
return core.DataType.INT16 return core.VarDesc.VarType.INT16
elif dtype == np.int64: elif dtype == np.int64:
return core.DataType.INT64 return core.VarDesc.VarType.INT64
elif dtype == np.bool: elif dtype == np.bool:
return core.DataType.BOOL return core.VarDesc.VarType.BOOL
else: else:
raise ValueError("Not supported numpy dtype " + str(dtype)) raise ValueError("Not supported numpy dtype " + str(dtype))
...@@ -93,16 +93,19 @@ def dtype_is_floating(dtype): ...@@ -93,16 +93,19 @@ def dtype_is_floating(dtype):
""" """
Check the data type is floating or not. Check the data type is floating or not.
Args: Args:
dtype(np.dtype|core.DataType): data type. dtype(np.dtype|core.VarDesc.VarType): data type.
Could be numpy format or Paddle format Could be numpy format or Paddle format
Returns(bool): True if data type is a float value Returns(bool): True if data type is a float value
""" """
if not isinstance(dtype, core.DataType): if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
return dtype in [core.DataType.FP16, core.DataType.FP32, core.DataType.FP64] return dtype in [
core.VarDesc.VarType.FP16, core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP64
]
def _debug_string_(proto, throw_on_error=True): def _debug_string_(proto, throw_on_error=True):
...@@ -148,7 +151,7 @@ class Variable(object): ...@@ -148,7 +151,7 @@ class Variable(object):
framework.proto for details. framework.proto for details.
shape(tuple|list|None): The shape of variable. -1 means the batch size. shape(tuple|list|None): The shape of variable. -1 means the batch size.
Some kinds of variable do not contain shape, just set it to None. Some kinds of variable do not contain shape, just set it to None.
dtype(np.dtype|core.DataType|str): The data type of variable. dtype(np.dtype|core.VarDesc.VarType|str): The data type of variable.
lod_level(int): The level of lod tensor. 0 means there is not a time lod_level(int): The level of lod tensor. 0 means there is not a time
series data. series data.
persistable(bool): True if the variable should be saved as check point. persistable(bool): True if the variable should be saved as check point.
...@@ -200,7 +203,7 @@ class Variable(object): ...@@ -200,7 +203,7 @@ class Variable(object):
"shape is {1}; the new shape is {2}. They are not " "shape is {1}; the new shape is {2}. They are not "
"matched.".format(self.name, old_shape, shape)) "matched.".format(self.name, old_shape, shape))
if dtype is not None: if dtype is not None:
if not isinstance(dtype, core.DataType): if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
if is_new_var: if is_new_var:
self.desc.set_dtype(dtype) self.desc.set_dtype(dtype)
......
...@@ -612,7 +612,7 @@ class While(object): ...@@ -612,7 +612,7 @@ class While(object):
if not isinstance(cond, Variable): if not isinstance(cond, Variable):
raise TypeError("condition should be a variable") raise TypeError("condition should be a variable")
assert isinstance(cond, Variable) assert isinstance(cond, Variable)
if cond.dtype != core.DataType.BOOL: if cond.dtype != core.VarDesc.VarType.BOOL:
raise TypeError("condition should be a bool variable") raise TypeError("condition should be a bool variable")
if reduce(lambda a, b: a * b, cond.shape, 1) != 1: if reduce(lambda a, b: a * b, cond.shape, 1) != 1:
raise TypeError("condition should be a bool scalar") raise TypeError("condition should be a bool scalar")
......
...@@ -221,7 +221,7 @@ def embedding(input, ...@@ -221,7 +221,7 @@ def embedding(input,
:math:`padding_idx < 0`, the padding_idx to use in lookup is :math:`padding_idx < 0`, the padding_idx to use in lookup is
:math:`size[0] + dim`. :math:`size[0] + dim`.
param_attr(ParamAttr): Parameters for this layer param_attr(ParamAttr): Parameters for this layer
dtype(np.dtype|core.DataType|str): The type of data : float32, float_16, int etc dtype(np.dtype|core.VarDesc.VarType|str): The type of data : float32, float_16, int etc
Returns: Returns:
Variable: The tensor variable storing the embeddings of the \ Variable: The tensor variable storing the embeddings of the \
......
...@@ -17,7 +17,7 @@ from ..param_attr import ParamAttr ...@@ -17,7 +17,7 @@ from ..param_attr import ParamAttr
from ..framework import convert_np_dtype_to_dtype_ from ..framework import convert_np_dtype_to_dtype_
from ..framework import Variable from ..framework import Variable
from ..initializer import Constant, force_init_on_cpu from ..initializer import Constant, force_init_on_cpu
from ..core import DataType from ..core import VarDesc
import numpy import numpy
__all__ = [ __all__ = [
...@@ -199,10 +199,10 @@ def assign(input, output): ...@@ -199,10 +199,10 @@ def assign(input, output):
attrs={'scale': 1.0}) attrs={'scale': 1.0})
elif isinstance(input, numpy.ndarray): elif isinstance(input, numpy.ndarray):
dtype = convert_np_dtype_to_dtype_(input.dtype) dtype = convert_np_dtype_to_dtype_(input.dtype)
if dtype == DataType.FP32: if dtype == VarDesc.VarType.FP32:
value_name = "fp32_values" value_name = "fp32_values"
values = [float(v) for v in input.flat] values = [float(v) for v in input.flat]
elif dtype == DataType.INT32: elif dtype == VarDesc.VarType.INT32:
value_name = "int32_values" value_name = "int32_values"
values = [int(v) for v in input.flat] values = [int(v) for v in input.flat]
else: else:
...@@ -236,7 +236,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None): ...@@ -236,7 +236,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
Args: Args:
shape(tuple|list|None): Shape of the output tensor. shape(tuple|list|None): Shape of the output tensor.
dtype(np.dtype|core.DataType|str): Data type of the output tensor. dtype(np.dtype|core.VarDesc.VarType|str): Data type of the output tensor.
value(float): The constant value used to initialize the output tensor. value(float): The constant value used to initialize the output tensor.
out(Variable): The output tensor. out(Variable): The output tensor.
force_cpu(True|False): data should be on CPU if set true. force_cpu(True|False): data should be on CPU if set true.
...@@ -285,7 +285,7 @@ def fill_constant_batch_size_like(input, ...@@ -285,7 +285,7 @@ def fill_constant_batch_size_like(input,
Args: Args:
input(Variable): Tensor whose dimensions will be used to get batch size input(Variable): Tensor whose dimensions will be used to get batch size
shape(tuple|list|None): Shape of output tensor shape(tuple|list|None): Shape of output tensor
dtype(np.dtype|core.DataType|str): Data type of output tensor dtype(np.dtype|core.VarDesc.VarType|str): Data type of output tensor
value(float): Constant value to initialize the output tensor value(float): Constant value to initialize the output tensor
input_dim_idx(int): Index of input's batch size dimension input_dim_idx(int): Index of input's batch size dimension
output_dim_idx(int): Index of output's batch size dimension output_dim_idx(int): Index of output's batch size dimension
...@@ -327,7 +327,7 @@ def ones(shape, dtype, force_cpu=False): ...@@ -327,7 +327,7 @@ def ones(shape, dtype, force_cpu=False):
Args: Args:
shape(tuple|list|None): Shape of output tensor shape(tuple|list|None): Shape of output tensor
dtype(np.dtype|core.DataType|str): Data type of output tensor dtype(np.dtype|core.VarDesc.VarType|str): Data type of output tensor
Returns: Returns:
Variable: The tensor variable storing the output Variable: The tensor variable storing the output
...@@ -351,7 +351,7 @@ def zeros(shape, dtype, force_cpu=False): ...@@ -351,7 +351,7 @@ def zeros(shape, dtype, force_cpu=False):
Args: Args:
shape(tuple|list|None): Shape of output tensor shape(tuple|list|None): Shape of output tensor
dtype(np.dtype|core.DataType|str): Data type of output tensor dtype(np.dtype|core.VarDesc.VarType|str): Data type of output tensor
Returns: Returns:
Variable: The tensor variable storing the output Variable: The tensor variable storing the output
......
...@@ -20,13 +20,13 @@ from backward import _rename_arg_ ...@@ -20,13 +20,13 @@ from backward import _rename_arg_
from . import core from . import core
dtype_to_size = { dtype_to_size = {
core.DataType.FP16: 2, core.VarDesc.VarType.FP16: 2,
core.DataType.FP32: 4, core.VarDesc.VarType.FP32: 4,
core.DataType.FP64: 8, core.VarDesc.VarType.FP64: 8,
core.DataType.INT16: 2, core.VarDesc.VarType.INT16: 2,
core.DataType.INT32: 4, core.VarDesc.VarType.INT32: 4,
core.DataType.INT64: 8, core.VarDesc.VarType.INT64: 8,
core.DataType.BOOL: 1 core.VarDesc.VarType.BOOL: 1
} }
......
...@@ -22,7 +22,7 @@ block = prog.current_block() ...@@ -22,7 +22,7 @@ block = prog.current_block()
random_reader = block.create_var( random_reader = block.create_var(
type=fluid.core.VarDesc.VarType.READER, name="RandomDataGenerator") type=fluid.core.VarDesc.VarType.READER, name="RandomDataGenerator")
random_reader.desc.set_dtypes( random_reader.desc.set_dtypes(
[fluid.core.DataType.FP32, fluid.core.DataType.FP32]) [fluid.core.VarDesc.VarType.FP32, fluid.core.VarDesc.VarType.FP32])
create_random_data_generator_op = block.append_op( create_random_data_generator_op = block.append_op(
type="create_random_data_generator", type="create_random_data_generator",
......
...@@ -119,9 +119,9 @@ def get_numeric_gradient(place, ...@@ -119,9 +119,9 @@ def get_numeric_gradient(place,
tensor_to_check = scope.find_var(input_to_check).get_tensor() tensor_to_check = scope.find_var(input_to_check).get_tensor()
tensor_size = product(tensor_to_check.get_dims()) tensor_size = product(tensor_to_check.get_dims())
tensor_to_check_dtype = tensor_to_check.dtype() tensor_to_check_dtype = tensor_to_check.dtype()
if tensor_to_check_dtype == core.DataType.FP32: if tensor_to_check_dtype == core.VarDesc.VarType.FP32:
tensor_to_check_dtype = np.float32 tensor_to_check_dtype = np.float32
elif tensor_to_check_dtype == core.DataType.FP64: elif tensor_to_check_dtype == core.VarDesc.VarType.FP64:
tensor_to_check_dtype = np.float64 tensor_to_check_dtype = np.float64
else: else:
raise ValueError("Not supported data type " + str( raise ValueError("Not supported data type " + str(
......
...@@ -140,9 +140,9 @@ def set_output_grad(scope, outputs, place, feed_dict=None): ...@@ -140,9 +140,9 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
grad_tensor = scope.var(grad_var_name(name)).get_tensor() grad_tensor = scope.var(grad_var_name(name)).get_tensor()
out_dtype = out_tensor.dtype() out_dtype = out_tensor.dtype()
if data is None: if data is None:
if out_dtype == core.DataType.FP64: if out_dtype == core.VarDesc.VarType.FP64:
data = np.ones(out_tensor.shape(), dtype=np.float64) data = np.ones(out_tensor.shape(), dtype=np.float64)
elif out_dtype == core.DataType.FP32: elif out_dtype == core.VarDesc.VarType.FP32:
data = np.ones(out_tensor.shape(), dtype=np.float32) data = np.ones(out_tensor.shape(), dtype=np.float32)
else: else:
raise ValueError("Not supported data type " + str(out_dtype)) raise ValueError("Not supported data type " + str(out_dtype))
......
...@@ -24,8 +24,8 @@ class TestCastOp(op_test.OpTest): ...@@ -24,8 +24,8 @@ class TestCastOp(op_test.OpTest):
self.inputs = {'X': ipt.astype('float32')} self.inputs = {'X': ipt.astype('float32')}
self.outputs = {'Out': ipt.astype('float64')} self.outputs = {'Out': ipt.astype('float64')}
self.attrs = { self.attrs = {
'in_dtype': int(core.DataType.FP32), 'in_dtype': int(core.VarDesc.VarType.FP32),
'out_dtype': int(core.DataType.FP64) 'out_dtype': int(core.VarDesc.VarType.FP64)
} }
self.op_type = 'cast' self.op_type = 'cast'
......
...@@ -26,7 +26,7 @@ class TestFillOp(OpTest): ...@@ -26,7 +26,7 @@ class TestFillOp(OpTest):
self.attrs = { self.attrs = {
'value': val.flatten().tolist(), 'value': val.flatten().tolist(),
'shape': [100, 200], 'shape': [100, 200],
'dtype': int(core.DataType.FP64) 'dtype': int(core.VarDesc.VarType.FP64)
} }
self.outputs = {'Out': val.astype('float64')} self.outputs = {'Out': val.astype('float64')}
......
...@@ -97,9 +97,9 @@ def set_output_grad(scope, outputs, place, feed_dict=None): ...@@ -97,9 +97,9 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
grad_tensor = scope.var(grad_var_name(name)).get_tensor() grad_tensor = scope.var(grad_var_name(name)).get_tensor()
out_dtype = out_tensor.dtype() out_dtype = out_tensor.dtype()
if data is None: if data is None:
if out_dtype == core.DataType.FP64: if out_dtype == core.VarDesc.VarType.FP64:
data = np.ones(out_tensor.shape(), dtype=np.float64) data = np.ones(out_tensor.shape(), dtype=np.float64)
elif out_dtype == core.DataType.FP32: elif out_dtype == core.VarDesc.VarType.FP32:
data = np.ones(out_tensor.shape(), dtype=np.float32) data = np.ones(out_tensor.shape(), dtype=np.float32)
else: else:
raise ValueError("Not supported data type " + str(out_dtype)) raise ValueError("Not supported data type " + str(out_dtype))
......
...@@ -38,7 +38,7 @@ class TestOneHotOp(OpTest): ...@@ -38,7 +38,7 @@ class TestOneHotOp(OpTest):
out[i, x[i]] = 1.0 out[i, x[i]] = 1.0
self.inputs = {'X': (x, x_lod)} self.inputs = {'X': (x, x_lod)}
self.attrs = {'depth': depth, 'dtype': int(core.DataType.FP32)} self.attrs = {'depth': depth, 'dtype': int(core.VarDesc.VarType.FP32)}
self.outputs = {'Out': (out, x_lod)} self.outputs = {'Out': (out, x_lod)}
def test_check_output(self): def test_check_output(self):
......
...@@ -36,7 +36,7 @@ class TestParameter(unittest.TestCase): ...@@ -36,7 +36,7 @@ class TestParameter(unittest.TestCase):
self.assertIsNotNone(param) self.assertIsNotNone(param)
self.assertEqual('fc.w', param.name) self.assertEqual('fc.w', param.name)
self.assertEqual((784, 100), param.shape) self.assertEqual((784, 100), param.shape)
self.assertEqual(core.DataType.FP32, param.dtype) self.assertEqual(core.VarDesc.VarType.FP32, param.dtype)
self.assertEqual(0, param.block.idx) self.assertEqual(0, param.block.idx)
exe = Executor(core.CPUPlace()) exe = Executor(core.CPUPlace())
p = exe.run(main_program, fetch_list=[param])[0] p = exe.run(main_program, fetch_list=[param])[0]
......
...@@ -131,8 +131,8 @@ class TestVarDesc(unittest.TestCase): ...@@ -131,8 +131,8 @@ class TestVarDesc(unittest.TestCase):
block = program_desc.block(0) block = program_desc.block(0)
var = block.var('my_var') var = block.var('my_var')
var.set_type(core.VarDesc.VarType.LOD_TENSOR) var.set_type(core.VarDesc.VarType.LOD_TENSOR)
var.set_dtype(core.DataType.INT32) var.set_dtype(core.VarDesc.VarType.INT32)
self.assertEqual(core.DataType.INT32, var.dtype()) self.assertEqual(core.VarDesc.VarType.INT32, var.dtype())
self.assertEqual(core.VarDesc.VarType.LOD_TENSOR, var.type()) self.assertEqual(core.VarDesc.VarType.LOD_TENSOR, var.type())
def test_multiple_dtype(self): def test_multiple_dtype(self):
...@@ -141,7 +141,8 @@ class TestVarDesc(unittest.TestCase): ...@@ -141,7 +141,8 @@ class TestVarDesc(unittest.TestCase):
var = block.var('my_reader') var = block.var('my_reader')
var.set_type(core.VarDesc.VarType.READER) var.set_type(core.VarDesc.VarType.READER)
src_types = [ src_types = [
core.DataType.INT32, core.DataType.FP64, core.DataType.FP32 core.VarDesc.VarType.INT32, core.VarDesc.VarType.FP64,
core.VarDesc.VarType.FP32
] ]
var.set_dtypes(src_types) var.set_dtypes(src_types)
self.assertEqual(src_types, var.dtypes()) self.assertEqual(src_types, var.dtypes())
......
...@@ -20,7 +20,7 @@ import numpy as np ...@@ -20,7 +20,7 @@ import numpy as np
class TestVariable(unittest.TestCase): class TestVariable(unittest.TestCase):
def test_np_dtype_convert(self): def test_np_dtype_convert(self):
DT = core.DataType DT = core.VarDesc.VarType
convert = convert_np_dtype_to_dtype_ convert = convert_np_dtype_to_dtype_
self.assertEqual(DT.FP32, convert(np.float32)) self.assertEqual(DT.FP32, convert(np.float32))
self.assertEqual(DT.FP16, convert("float16")) self.assertEqual(DT.FP16, convert("float16"))
...@@ -36,13 +36,13 @@ class TestVariable(unittest.TestCase): ...@@ -36,13 +36,13 @@ class TestVariable(unittest.TestCase):
w = b.create_var( w = b.create_var(
dtype="float64", shape=[784, 100], lod_level=0, name="fc.w") dtype="float64", shape=[784, 100], lod_level=0, name="fc.w")
self.assertNotEqual(str(w), "") self.assertNotEqual(str(w), "")
self.assertEqual(core.DataType.FP64, w.dtype) self.assertEqual(core.VarDesc.VarType.FP64, w.dtype)
self.assertEqual((784, 100), w.shape) self.assertEqual((784, 100), w.shape)
self.assertEqual("fc.w", w.name) self.assertEqual("fc.w", w.name)
self.assertEqual(0, w.lod_level) self.assertEqual(0, w.lod_level)
w = b.create_var(name='fc.w') w = b.create_var(name='fc.w')
self.assertEqual(core.DataType.FP64, w.dtype) self.assertEqual(core.VarDesc.VarType.FP64, w.dtype)
self.assertEqual((784, 100), w.shape) self.assertEqual((784, 100), w.shape)
self.assertEqual("fc.w", w.name) self.assertEqual("fc.w", w.name)
self.assertEqual(0, w.lod_level) self.assertEqual(0, w.lod_level)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册