提交 c7ad26d6 编写于 作者: A Abhinav Arora 提交者: kavyasrinet

[WIP] Move DataType enum inside VarType (#8447)

* Move Pod Types from DataType enum to Type enum

* Fixed data_type.h

* Fix type in TensorDesc

* Add comment to framework.proto

* Fixed type in data_type.h

* Updated format of type in data_type.h

* Fix var_desc.h

* Fix op_kernel_type.h

* Fixed data_type_transform_test.cc

* Fix operator.h

* Fixed data_type_transform.cc

* Fixed op_kernel_type_test.cc

* Fix operator.cc

* Fixed data_layout_transform_test.cc

* Fix var_desc.cc

* Fixed assign_value_op.cc

* Fixed assign_value_op.h

* fixed protobuf.cc

* Fix data_layout_transform_test.cc and op_kernel_type_test.cc

* Fixed rnn_memory_helper_op.cc

* Fix progrma_desc_test.cc

* Fixed fill_constant_batch_size_like_op.cc

* Fix operator_test.cc

* Fixed fill_constant_op.cc

* Fixed gaussian_random_op.cc

* Fixed uniform_random_op.cc

* Fixed edit_distance_op.cc

* Fixed fill_constant_batch_size_like_op.cc

* Fixed rnn_memory_helper_op.cc

* Fixed chunk_eval_op.cc

* Fixed assign_value_op.cc

* Fixed assign_value_op.h

* Fixed cast_op.h

* Fixed cast_op.h

* Fix fill constant op

* Fixed clang for assign_value_op.cc

* Fix one_hot_op.h

* Fix one_hot_op.cc

* Fix fill_op.cc

* Fixed sum_op.cc

* Fixed sum_op clang

* Fix uniform_random_op.cc

* Fix gaussian_random_op.cc

* Fix backward.cc

* Fix protobuf.cc

* Fixed prune_test.cc

* Fixed op_registry_test.cc

* Fix data_device_transform_test.cu

* Fix travis error

* Fixed one_hot_op.cu

* Fixed op_registry_test.cc

* Fixed nccl_op.cc

* Fixing python tests

* Revert "Fixing python tests"

This reverts commit fccaa4c5.

* Fixing Pybind to remove data type

* Fixing tensor.py

* Updated the new files:

* Resolve error in merge conflict of fill_constant_batch_size_like_op.cc
上级 74e0eb72
......@@ -341,7 +341,7 @@ static void CreateGradVarInBlock(
auto* param = block_desc->FindVarRecursive(pname);
auto* grad = block_desc->FindVar(arg);
if (param == nullptr) {
grad->SetDataType(proto::DataType::FP32);
grad->SetDataType(proto::VarType::FP32);
} else {
grad->SetDataType(param->GetDataType());
}
......
......@@ -51,10 +51,10 @@ class TestOpWithKernel : public OperatorWithKernel {
const ExecutionContext& ctx) const override {
if (Attr<bool>("use_gpu")) {
VLOG(3) << "force use gpu kernel";
return OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0));
return OpKernelType(proto::VarType::FP32, platform::CUDAPlace(0));
} else {
VLOG(3) << "use default kernel";
return OpKernelType(proto::DataType::FP32,
return OpKernelType(proto::VarType::FP32,
ctx.Input<Tensor>("input")->place());
}
}
......
......@@ -27,9 +27,9 @@ TEST(DataTransform, DataLayoutFunction) {
in.mutable_data<double>(make_ddim({2, 3, 1, 2}), place);
in.set_layout(DataLayout::kNHWC);
auto kernel_nhwc = OpKernelType(proto::DataType::FP32, place,
auto kernel_nhwc = OpKernelType(proto::VarType::FP32, place,
DataLayout::kNHWC, LibraryType::kPlain);
auto kernel_ncwh = OpKernelType(proto::DataType::FP32, place,
auto kernel_ncwh = OpKernelType(proto::VarType::FP32, place,
DataLayout::kNCHW, LibraryType::kPlain);
TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out);
......
......@@ -20,35 +20,35 @@ limitations under the License. */
namespace paddle {
namespace framework {
inline proto::DataType ToDataType(std::type_index type) {
inline proto::VarType::Type ToDataType(std::type_index type) {
using namespace paddle::framework::proto;
if (typeid(float).hash_code() == type.hash_code()) {
return DataType::FP32;
return proto::VarType::FP32;
} else if (typeid(double).hash_code() == type.hash_code()) {
return DataType::FP64;
return proto::VarType::FP64;
} else if (typeid(int).hash_code() == type.hash_code()) {
return DataType::INT32;
return proto::VarType::INT32;
} else if (typeid(int64_t).hash_code() == type.hash_code()) {
return DataType::INT64;
return proto::VarType::INT64;
} else if (typeid(bool).hash_code() == type.hash_code()) {
return DataType::BOOL;
return proto::VarType::BOOL;
} else {
PADDLE_THROW("Not supported");
}
}
inline std::type_index ToTypeIndex(proto::DataType type) {
inline std::type_index ToTypeIndex(proto::VarType::Type type) {
using namespace paddle::framework::proto;
switch (type) {
case DataType::FP32:
case proto::VarType::FP32:
return typeid(float);
case DataType::FP64:
case proto::VarType::FP64:
return typeid(double);
case DataType::INT32:
case proto::VarType::INT32:
return typeid(int);
case DataType::INT64:
case proto::VarType::INT64:
return typeid(int64_t);
case DataType::BOOL:
case proto::VarType::BOOL:
return typeid(bool);
default:
PADDLE_THROW("Not support type %d", type);
......@@ -56,22 +56,22 @@ inline std::type_index ToTypeIndex(proto::DataType type) {
}
template <typename Visitor>
inline void VisitDataType(proto::DataType type, Visitor visitor) {
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
using namespace paddle::framework::proto;
switch (type) {
case DataType::FP32:
case proto::VarType::FP32:
visitor.template operator()<float>();
break;
case DataType::FP64:
case proto::VarType::FP64:
visitor.template operator()<double>();
break;
case DataType::INT32:
case proto::VarType::INT32:
visitor.template operator()<int>();
break;
case DataType::INT64:
case proto::VarType::INT64:
visitor.template operator()<int64_t>();
break;
case DataType::BOOL:
case proto::VarType::BOOL:
visitor.template operator()<bool>();
break;
default:
......@@ -79,22 +79,22 @@ inline void VisitDataType(proto::DataType type, Visitor visitor) {
}
}
inline std::string DataTypeToString(const proto::DataType type) {
inline std::string DataTypeToString(const proto::VarType::Type type) {
using namespace paddle::framework::proto;
switch (type) {
case DataType::FP16:
case proto::VarType::FP16:
return "float16";
case DataType::FP32:
case proto::VarType::FP32:
return "float32";
case DataType::FP64:
case proto::VarType::FP64:
return "float64";
case DataType::INT16:
case proto::VarType::INT16:
return "int16";
case DataType::INT32:
case proto::VarType::INT32:
return "int32";
case DataType::INT64:
case proto::VarType::INT64:
return "int64";
case DataType::BOOL:
case proto::VarType::BOOL:
return "bool";
default:
PADDLE_THROW("Not support type %d", type);
......@@ -102,7 +102,7 @@ inline std::string DataTypeToString(const proto::DataType type) {
}
inline std::ostream& operator<<(std::ostream& out,
const proto::DataType& type) {
const proto::VarType::Type& type) {
out << DataTypeToString(type);
return out;
}
......
......@@ -65,19 +65,19 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
auto ctx = pool.Get(in.place());
switch (src_type) {
case proto::DataType::FP32:
case proto::VarType::FP32:
framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx));
break;
case proto::DataType::FP64:
case proto::VarType::FP64:
framework::VisitDataType(dst_type, CastDataType<double>(in, out, ctx));
break;
case proto::DataType::INT32:
case proto::VarType::INT32:
framework::VisitDataType(dst_type, CastDataType<int>(in, out, ctx));
break;
case proto::DataType::INT64:
case proto::VarType::INT64:
framework::VisitDataType(dst_type, CastDataType<int64_t>(in, out, ctx));
break;
case proto::DataType::BOOL:
case proto::VarType::BOOL:
framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx));
break;
default:
......
......@@ -32,11 +32,11 @@ TEST(DataTypeTransform, CPUTransform) {
ptr[i] = i / 3;
}
auto kernel_fp32 = OpKernelType(proto::DataType::FP32, place,
auto kernel_fp32 = OpKernelType(proto::VarType::FP32, place,
DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_fp64 = OpKernelType(proto::DataType::FP64, place,
auto kernel_fp64 = OpKernelType(proto::VarType::FP64, place,
DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_int32 = OpKernelType(proto::DataType::INT32, place,
auto kernel_int32 = OpKernelType(proto::VarType::INT32, place,
DataLayout::kAnyLayout, LibraryType::kPlain);
TransDataType(kernel_fp32, kernel_fp64, in, &out);
......
......@@ -91,7 +91,9 @@ message OpProto {
required string comment = 5;
}
enum DataType {
message VarType {
enum Type {
// Pod Types
BOOL = 0;
INT16 = 1;
INT32 = 2;
......@@ -99,25 +101,24 @@ enum DataType {
FP16 = 4;
FP32 = 5;
FP64 = 6;
}
message VarType {
enum Type {
LOD_TENSOR = 1;
SELECTED_ROWS = 2;
FEED_MINIBATCH = 3;
FETCH_LIST = 4;
STEP_SCOPES = 5;
LOD_RANK_TABLE = 6;
LOD_TENSOR_ARRAY = 7;
PLACE_LIST = 8;
READER = 9;
// Other types that may need additional descriptions
LOD_TENSOR = 7;
SELECTED_ROWS = 8;
FEED_MINIBATCH = 9;
FETCH_LIST = 10;
STEP_SCOPES = 11;
LOD_RANK_TABLE = 12;
LOD_TENSOR_ARRAY = 13;
PLACE_LIST = 14;
READER = 15;
}
required Type type = 1;
message TensorDesc {
required DataType data_type = 1;
// Should only be PODType. Is enforced in C++
required Type data_type = 1;
repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
}
optional TensorDesc selected_rows = 2;
......
......@@ -40,12 +40,12 @@ struct OpKernelType {
// place, data_type, library_type kinds less than 2^8
constexpr static int LEFT_SHIFT = 8;
proto::DataType data_type_;
proto::VarType::Type data_type_;
DataLayout data_layout_;
platform::Place place_;
LibraryType library_type_;
OpKernelType(proto::DataType data_type, platform::Place place,
OpKernelType(proto::VarType::Type data_type, platform::Place place,
DataLayout data_layout = DataLayout::kAnyLayout,
LibraryType library_type = LibraryType::kPlain)
: data_type_(data_type),
......@@ -53,7 +53,7 @@ struct OpKernelType {
place_(place),
library_type_(library_type) {}
OpKernelType(proto::DataType data_type,
OpKernelType(proto::VarType::Type data_type,
const platform::DeviceContext& dev_ctx,
DataLayout data_layout = DataLayout::kAnyLayout,
LibraryType library_type = LibraryType::kPlain)
......
......@@ -18,7 +18,7 @@ limitations under the License. */
TEST(OpKernelType, ToString) {
using OpKernelType = paddle::framework::OpKernelType;
using DataType = paddle::framework::proto::DataType;
using DataType = paddle::framework::proto::VarType;
using CPUPlace = paddle::platform::CPUPlace;
using DataLayout = paddle::framework::DataLayout;
using LibraryType = paddle::framework::LibraryType;
......@@ -33,7 +33,7 @@ TEST(OpKernelType, ToString) {
TEST(OpKernelType, Hash) {
using OpKernelType = paddle::framework::OpKernelType;
using DataType = paddle::framework::proto::DataType;
using DataType = paddle::framework::proto::VarType;
using CPUPlace = paddle::platform::CPUPlace;
using CUDAPlace = paddle::platform::CUDAPlace;
using DataLayout = paddle::framework::DataLayout;
......
......@@ -226,7 +226,7 @@ class OpWithKernelTest : public OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
return framework::OpKernelType(proto::VarType::FP32, ctx.device_context());
}
};
......@@ -290,8 +290,8 @@ class OpWithMultiKernelTest : public OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
proto::DataType::FP32, platform::CUDAPlace(0), DataLayout::kAnyLayout,
return framework::OpKernelType(proto::VarType::FP32, platform::CUDAPlace(0),
DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
}
};
......
......@@ -569,7 +569,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
}
}
proto::DataType OperatorWithKernel::IndicateDataType(
proto::VarType::Type OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const {
auto& scope = ctx.scope();
int data_type = -1;
......@@ -595,7 +595,7 @@ proto::DataType OperatorWithKernel::IndicateDataType(
}
}
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
return static_cast<proto::DataType>(data_type);
return static_cast<proto::VarType::Type>(data_type);
}
OpKernelType OperatorWithKernel::GetExpectedKernelType(
......
......@@ -394,9 +394,9 @@ class OperatorWithKernel : public OperatorBase {
const OpKernelType& expected_kernel_type) const;
private:
// indicate kernel DataType by input data. Defaultly all input data must be
// indicate kernel DataType by input data. By default all input data must be
// same.
proto::DataType IndicateDataType(const ExecutionContext& ctx) const;
proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
void RunImpl(const Scope& scope, const platform::Place& place) const final;
};
......
......@@ -119,7 +119,7 @@ class OpWithKernelTest : public OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType(
const ExecutionContext& ctx) const override {
return OpKernelType(proto::DataType::FP32, ctx.GetPlace());
return OpKernelType(proto::VarType::FP32, ctx.GetPlace());
}
};
......
......@@ -24,13 +24,13 @@ TEST(ProgramDesc, copy_ctor) {
auto* x = global_block->Var("X");
x->SetType(proto::VarType::LOD_TENSOR);
x->SetLoDLevel(0);
x->SetDataType(proto::FP32);
x->SetDataType(proto::VarType::FP32);
x->SetShape({1000, 784});
auto* y = global_block->Var("Y");
y->SetType(proto::VarType::LOD_TENSOR);
y->SetLoDLevel(0);
y->SetDataType(proto::FP32);
y->SetDataType(proto::VarType::FP32);
y->SetShape({784, 100});
auto* op = global_block->AppendOp();
......@@ -86,13 +86,13 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
auto* x = global_block->Var("X");
x->SetType(proto::VarType::LOD_TENSOR);
x->SetLoDLevel(0);
x->SetDataType(proto::FP32);
x->SetDataType(proto::VarType::FP32);
x->SetShape({1000, 784});
auto* y = global_block->Var("Y");
y->SetType(proto::VarType::LOD_TENSOR);
y->SetLoDLevel(0);
y->SetDataType(proto::FP32);
y->SetDataType(proto::VarType::FP32);
y->SetShape({784, 100});
auto* op = global_block->AppendOp();
......
......@@ -34,7 +34,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
for (auto kv : outputs) {
for (auto v : kv.second) {
auto var = block->Var(v);
var->SetDataType(paddle::framework::proto::DataType::FP32);
var->SetDataType(paddle::framework::proto::VarType::FP32);
}
}
......
......@@ -87,12 +87,12 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
return res;
}
void VarDesc::SetDataType(proto::DataType data_type) {
void VarDesc::SetDataType(proto::VarType::Type data_type) {
mutable_tensor_desc()->set_data_type(data_type);
}
void VarDesc::SetDataTypes(
const std::vector<proto::DataType> &multiple_data_type) {
const std::vector<proto::VarType::Type> &multiple_data_type) {
if (multiple_data_type.size() != GetTensorDescNum()) {
VLOG(3) << "WARNING: The number of given data types("
<< multiple_data_type.size()
......@@ -108,13 +108,13 @@ void VarDesc::SetDataTypes(
}
}
proto::DataType VarDesc::GetDataType() const {
proto::VarType::Type VarDesc::GetDataType() const {
return tensor_desc().data_type();
}
std::vector<proto::DataType> VarDesc::GetDataTypes() const {
std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
std::vector<proto::DataType> res;
std::vector<proto::VarType::Type> res;
res.reserve(descs.size());
for (const auto &tensor_desc : descs) {
res.push_back(tensor_desc.data_type());
......
......@@ -80,13 +80,14 @@ class VarDesc {
std::vector<std::vector<int64_t>> GetShapes() const;
void SetDataType(proto::DataType data_type);
void SetDataType(proto::VarType::Type data_type);
void SetDataTypes(const std::vector<proto::DataType> &multiple_data_type);
void SetDataTypes(
const std::vector<proto::VarType::Type> &multiple_data_type);
proto::DataType GetDataType() const;
proto::VarType::Type GetDataType() const;
std::vector<proto::DataType> GetDataTypes() const;
std::vector<proto::VarType::Type> GetDataTypes() const;
void SetLoDLevel(int32_t lod_level);
......
......@@ -36,7 +36,8 @@ class AssignValueOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::proto::DataType(ctx.Attr<int>("dtype")), ctx.GetPlace());
framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
......@@ -49,8 +50,8 @@ class AssignValueOpMaker : public framework::OpProtoAndCheckerMaker {
"(vector<int>) "
"Shape of values.");
AddAttr<int>("dtype", "data type of values")
.InEnum({framework::proto::DataType::INT32,
framework::proto::DataType::FP32});
.InEnum({framework::proto::VarType::INT32,
framework::proto::VarType::FP32});
AddAttr<std::vector<float>>("fp32_values", "store the float values")
.SetDefault({});
AddAttr<std::vector<int>>("int32_values", "store the int values")
......
......@@ -30,10 +30,10 @@ class AssignValueKernel : public framework::OpKernel<T> {
int dtype = ctx.Attr<int>("dtype");
const char* value_name = nullptr;
switch (dtype) {
case framework::proto::DataType::INT32:
case framework::proto::VarType::INT32:
value_name = "int32_values";
break;
case framework::proto::DataType::FP32:
case framework::proto::VarType::FP32:
value_name = "fp32_values";
break;
default:
......
......@@ -55,7 +55,8 @@ class CastOpKernel : public framework::OpKernel<InT> {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
framework::VisitDataType(
static_cast<framework::proto::DataType>(context.Attr<int>("out_dtype")),
static_cast<framework::proto::VarType::Type>(
context.Attr<int>("out_dtype")),
CastOpFunctor<DeviceContext, InT>(
in, out, context.template device_context<DeviceContext>()));
}
......
......@@ -57,7 +57,7 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(framework::proto::DataType::FP32,
return framework::OpKernelType(framework::proto::VarType::FP32,
platform::CPUPlace());
}
};
......
......@@ -42,7 +42,7 @@ class EditDistanceOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(framework::proto::DataType::FP32,
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.device_context());
}
};
......
......@@ -24,7 +24,7 @@ class FillConstantBatchSizeLikeOp : public BatchSizeLikeOp {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.device_context());
}
};
......@@ -36,7 +36,7 @@ class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::proto::DataType::FP32);
.SetDefault(framework::proto::VarType::FP32);
AddAttr<float>("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f);
AddComment(R"DOC(
......
......@@ -38,7 +38,7 @@ class FillConstantOp : public framework::OperatorBase {
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto data_type =
static_cast<framework::proto::DataType>(Attr<int>("dtype"));
static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
auto value = Attr<float>("value");
auto force_cpu = Attr<bool>("force_cpu");
auto &out =
......@@ -64,7 +64,7 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::proto::DataType::FP32);
.SetDefault(framework::proto::VarType::FP32);
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
AddAttr<float>("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f);
......
......@@ -51,7 +51,8 @@ class FillOp : public framework::OperatorBase {
"Cannot find variable %s", Output("Out"))
.GetMutable<framework::LoDTensor>());
out.Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
auto dtype = static_cast<framework::proto::DataType>(Attr<int>("dtype"));
auto dtype =
static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
platform::CPUPlace cpu;
auto force_cpu = Attr<bool>("force_cpu");
out.mutable_data(force_cpu ? cpu : place, framework::ToTypeIndex(dtype));
......@@ -93,7 +94,7 @@ Fill an tensor with `value` and `shape`. The type of the tensor is specify by
"value", "The float values of tensor, which are flatten in row major");
AddAttr<std::vector<int>>("shape", "The shape of output tensor");
AddAttr<int>("dtype", "The data type of output tensor, Default is float")
.SetDefault(framework::proto::DataType::FP32);
.SetDefault(framework::proto::VarType::FP32);
AddAttr<bool>("force_cpu",
"Whether the output tensor must be at CPU memory or not. "
"Default is false.")
......
......@@ -26,7 +26,7 @@ class GaussianRandomBatchSizeLikeOp : public BatchSizeLikeOp {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
......@@ -53,7 +53,7 @@ class GaussianRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
AddAttr<int>("dtype",
"(int, default 5(FP32)) "
"Output data type.")
.SetDefault(framework::proto::DataType::FP32);
.SetDefault(framework::proto::VarType::FP32);
AddComment(R"DOC(
GaussianRandom Operator.
......
......@@ -63,7 +63,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.device_context());
}
};
......@@ -95,7 +95,7 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("dtype",
"(int, default 5(FP32)) "
"Output data type.")
.SetDefault(framework::proto::DataType::FP32);
.SetDefault(framework::proto::VarType::FP32);
AddComment(R"DOC(
GaussianRandom Operator.
......
......@@ -55,7 +55,7 @@ class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::proto::DataType::FP32);
.SetDefault(framework::proto::VarType::FP32);
AddComment(R"DOC(
NCCLInit Operator.
......
......@@ -60,7 +60,7 @@ class OneHotOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("dtype",
"An integer to specify the data type of one-hot "
"vector. The default value is FP32.")
.SetDefault(paddle::framework::proto::DataType::FP32);
.SetDefault(paddle::framework::proto::VarType::FP32);
AddComment(R"DOC(
One Hot Operator. This operator creates the one-hot representations for input
index values. The following example will help to explain the function of this
......
......@@ -65,7 +65,8 @@ class OneHotCUDAKernel : public framework::OpKernel<T> {
int depth = context.Attr<int>("depth");
framework::VisitDataType(
static_cast<framework::proto::DataType>(context.Attr<int>("dtype")),
static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype")),
OneHotOpCUDAFunctor<DeviceContext, T>(
in, out, depth, context.template device_context<DeviceContext>()));
}
......
......@@ -58,7 +58,8 @@ class OneHotKernel : public framework::OpKernel<T> {
int depth = context.Attr<int>("depth");
framework::VisitDataType(
static_cast<framework::proto::DataType>(context.Attr<int>("dtype")),
static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype")),
OneHotOpFunctor<DeviceContext, T>(
in, out, depth, context.template device_context<DeviceContext>()));
}
......
......@@ -66,7 +66,7 @@ class RNNMemoryHelperOpInfoMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::proto::DataType::FP32);
.SetDefault(framework::proto::VarType::FP32);
AddComment("");
}
};
......@@ -126,7 +126,7 @@ class RNNMemoryHelperGradOpInfoMaker
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::proto::DataType::FP32);
.SetDefault(framework::proto::VarType::FP32);
AddComment("");
}
};
......
......@@ -73,7 +73,8 @@ class SumOp : public framework::OperatorWithKernel {
"Sum operator should have at least one tensor");
return framework::OpKernelType(
static_cast<framework::proto::DataType>(dtype), ctx.device_context());
static_cast<framework::proto::VarType::Type>(dtype),
ctx.device_context());
} else if (x_vars[0]->IsType<framework::SelectedRows>()) {
return framework::OpKernelType(
framework::ToDataType(
......
......@@ -26,7 +26,7 @@ class UniformRandomBatchSizeLikeOp : public BatchSizeLikeOp {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
......@@ -58,7 +58,7 @@ This operator initializes a tensor with the same batch_size as the Input tensor
"generate the same random numbers every time.")
.SetDefault(0);
AddAttr<int>("dtype", "(int, default 5(FP32)) Output tensor data type")
.SetDefault(framework::proto::DataType::FP32);
.SetDefault(framework::proto::VarType::FP32);
}
};
......
......@@ -66,7 +66,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
......@@ -101,7 +101,7 @@ uniform distribution.
"generate the same random numbers every time.")
.SetDefault(0);
AddAttr<int>("dtype", "(int, default 5(FP32)) Output tensor data type")
.SetDefault(framework::proto::DataType::FP32);
.SetDefault(framework::proto::VarType::FP32);
}
};
} // namespace operators
......
......@@ -195,15 +195,6 @@ void BindBlockDesc(py::module &m) {
}
void BindVarDsec(py::module &m) {
py::enum_<proto::DataType>(m, "DataType", "")
.value("BOOL", proto::DataType::BOOL)
.value("INT16", proto::DataType::INT16)
.value("INT32", proto::DataType::INT32)
.value("INT64", proto::DataType::INT64)
.value("FP16", proto::DataType::FP16)
.value("FP32", proto::DataType::FP32)
.value("FP64", proto::DataType::FP64);
py::class_<VarDesc> var_desc(m, "VarDesc", "");
var_desc
.def("name",
......@@ -233,6 +224,13 @@ void BindVarDsec(py::module &m) {
.def("set_persistable", &VarDesc::SetPersistable);
py::enum_<proto::VarType::Type>(var_desc, "VarType", "")
.value("BOOL", proto::VarType::BOOL)
.value("INT16", proto::VarType::INT16)
.value("INT32", proto::VarType::INT32)
.value("INT64", proto::VarType::INT64)
.value("FP16", proto::VarType::FP16)
.value("FP32", proto::VarType::FP32)
.value("FP64", proto::VarType::FP64)
.value("LOD_TENSOR", proto::VarType::LOD_TENSOR)
.value("SELECTED_ROWS", proto::VarType::SELECTED_ROWS)
.value("FEED_MINIBATCH", proto::VarType::FEED_MINIBATCH)
......
......@@ -68,7 +68,7 @@ def _infer_var_data_type_(grad_var_name, block):
fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii"))
grad_var.set_dtype(fwd_var.dtype())
else:
grad_var.set_dtype(core.DataType.FP32)
grad_var.set_dtype(core.VarDesc.VarType.FP32)
def _all_in_set_(cands, s):
......
......@@ -27,13 +27,13 @@ class DataToLoDTensorConverter(object):
self.place = place
self.lod_level = lod_level
self.shape = shape
if dtype == core.DataType.FP32:
if dtype == core.VarDesc.VarType.FP32:
self.dtype = 'float32'
elif dtype == core.DataType.INT64:
elif dtype == core.VarDesc.VarType.INT64:
self.dtype = 'int64'
elif dtype == core.DataType.FP64:
elif dtype == core.VarDesc.VarType.FP64:
self.dtype = 'float64'
elif dtype == core.DataType.INT32:
elif dtype == core.VarDesc.VarType.INT32:
self.dtype = 'int32'
else:
raise ValueError("dtype must be any of [int32, float32, int64, "
......
......@@ -89,7 +89,7 @@ class Evaluator(object):
Args:
suffix(str): the state suffix.
dtype(str|core.DataType): the state data type
dtype(str|core.VarDesc.VarType): the state data type
shape(tuple|list): the shape of state
Returns: State variable
......
......@@ -67,24 +67,24 @@ def convert_np_dtype_to_dtype_(np_dtype):
Args:
np_dtype(np.dtype): the data type in numpy
Returns(core.DataType): the data type in Paddle
Returns(core.VarDesc.VarType): the data type in Paddle
"""
dtype = np.dtype(np_dtype)
if dtype == np.float32:
return core.DataType.FP32
return core.VarDesc.VarType.FP32
elif dtype == np.float64:
return core.DataType.FP64
return core.VarDesc.VarType.FP64
elif dtype == np.float16:
return core.DataType.FP16
return core.VarDesc.VarType.FP16
elif dtype == np.int32:
return core.DataType.INT32
return core.VarDesc.VarType.INT32
elif dtype == np.int16:
return core.DataType.INT16
return core.VarDesc.VarType.INT16
elif dtype == np.int64:
return core.DataType.INT64
return core.VarDesc.VarType.INT64
elif dtype == np.bool:
return core.DataType.BOOL
return core.VarDesc.VarType.BOOL
else:
raise ValueError("Not supported numpy dtype " + str(dtype))
......@@ -93,16 +93,19 @@ def dtype_is_floating(dtype):
"""
Check the data type is floating or not.
Args:
dtype(np.dtype|core.DataType): data type.
dtype(np.dtype|core.VarDesc.VarType): data type.
Could be numpy format or Paddle format
Returns(bool): True if data type is a float value
"""
if not isinstance(dtype, core.DataType):
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
return dtype in [core.DataType.FP16, core.DataType.FP32, core.DataType.FP64]
return dtype in [
core.VarDesc.VarType.FP16, core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP64
]
def _debug_string_(proto, throw_on_error=True):
......@@ -148,7 +151,7 @@ class Variable(object):
framework.proto for details.
shape(tuple|list|None): The shape of variable. -1 means the batch size.
Some kinds of variable do not contain shape, just set it to None.
dtype(np.dtype|core.DataType|str): The data type of variable.
dtype(np.dtype|core.VarDesc.VarType|str): The data type of variable.
lod_level(int): The level of lod tensor. 0 means there is not a time
series data.
persistable(bool): True if the variable should be saved as check point.
......@@ -200,7 +203,7 @@ class Variable(object):
"shape is {1}; the new shape is {2}. They are not "
"matched.".format(self.name, old_shape, shape))
if dtype is not None:
if not isinstance(dtype, core.DataType):
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
if is_new_var:
self.desc.set_dtype(dtype)
......
......@@ -612,7 +612,7 @@ class While(object):
if not isinstance(cond, Variable):
raise TypeError("condition should be a variable")
assert isinstance(cond, Variable)
if cond.dtype != core.DataType.BOOL:
if cond.dtype != core.VarDesc.VarType.BOOL:
raise TypeError("condition should be a bool variable")
if reduce(lambda a, b: a * b, cond.shape, 1) != 1:
raise TypeError("condition should be a bool scalar")
......
......@@ -221,7 +221,7 @@ def embedding(input,
:math:`padding_idx < 0`, the padding_idx to use in lookup is
:math:`size[0] + dim`.
param_attr(ParamAttr): Parameters for this layer
dtype(np.dtype|core.DataType|str): The type of data : float32, float_16, int etc
dtype(np.dtype|core.VarDesc.VarType|str): The type of data : float32, float_16, int etc
Returns:
Variable: The tensor variable storing the embeddings of the \
......
......@@ -17,7 +17,7 @@ from ..param_attr import ParamAttr
from ..framework import convert_np_dtype_to_dtype_
from ..framework import Variable
from ..initializer import Constant, force_init_on_cpu
from ..core import DataType
from ..core import VarDesc
import numpy
__all__ = [
......@@ -199,10 +199,10 @@ def assign(input, output):
attrs={'scale': 1.0})
elif isinstance(input, numpy.ndarray):
dtype = convert_np_dtype_to_dtype_(input.dtype)
if dtype == DataType.FP32:
if dtype == VarDesc.VarType.FP32:
value_name = "fp32_values"
values = [float(v) for v in input.flat]
elif dtype == DataType.INT32:
elif dtype == VarDesc.VarType.INT32:
value_name = "int32_values"
values = [int(v) for v in input.flat]
else:
......@@ -236,7 +236,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
Args:
shape(tuple|list|None): Shape of the output tensor.
dtype(np.dtype|core.DataType|str): Data type of the output tensor.
dtype(np.dtype|core.VarDesc.VarType|str): Data type of the output tensor.
value(float): The constant value used to initialize the output tensor.
out(Variable): The output tensor.
force_cpu(True|False): data should be on CPU if set true.
......@@ -285,7 +285,7 @@ def fill_constant_batch_size_like(input,
Args:
input(Variable): Tensor whose dimensions will be used to get batch size
shape(tuple|list|None): Shape of output tensor
dtype(np.dtype|core.DataType|str): Data type of output tensor
dtype(np.dtype|core.VarDesc.VarType|str): Data type of output tensor
value(float): Constant value to initialize the output tensor
input_dim_idx(int): Index of input's batch size dimension
output_dim_idx(int): Index of output's batch size dimension
......@@ -327,7 +327,7 @@ def ones(shape, dtype, force_cpu=False):
Args:
shape(tuple|list|None): Shape of output tensor
dtype(np.dtype|core.DataType|str): Data type of output tensor
dtype(np.dtype|core.VarDesc.VarType|str): Data type of output tensor
Returns:
Variable: The tensor variable storing the output
......@@ -351,7 +351,7 @@ def zeros(shape, dtype, force_cpu=False):
Args:
shape(tuple|list|None): Shape of output tensor
dtype(np.dtype|core.DataType|str): Data type of output tensor
dtype(np.dtype|core.VarDesc.VarType|str): Data type of output tensor
Returns:
Variable: The tensor variable storing the output
......
......@@ -20,13 +20,13 @@ from backward import _rename_arg_
from . import core
dtype_to_size = {
core.DataType.FP16: 2,
core.DataType.FP32: 4,
core.DataType.FP64: 8,
core.DataType.INT16: 2,
core.DataType.INT32: 4,
core.DataType.INT64: 8,
core.DataType.BOOL: 1
core.VarDesc.VarType.FP16: 2,
core.VarDesc.VarType.FP32: 4,
core.VarDesc.VarType.FP64: 8,
core.VarDesc.VarType.INT16: 2,
core.VarDesc.VarType.INT32: 4,
core.VarDesc.VarType.INT64: 8,
core.VarDesc.VarType.BOOL: 1
}
......
......@@ -22,7 +22,7 @@ block = prog.current_block()
random_reader = block.create_var(
type=fluid.core.VarDesc.VarType.READER, name="RandomDataGenerator")
random_reader.desc.set_dtypes(
[fluid.core.DataType.FP32, fluid.core.DataType.FP32])
[fluid.core.VarDesc.VarType.FP32, fluid.core.VarDesc.VarType.FP32])
create_random_data_generator_op = block.append_op(
type="create_random_data_generator",
......
......@@ -119,9 +119,9 @@ def get_numeric_gradient(place,
tensor_to_check = scope.find_var(input_to_check).get_tensor()
tensor_size = product(tensor_to_check.get_dims())
tensor_to_check_dtype = tensor_to_check.dtype()
if tensor_to_check_dtype == core.DataType.FP32:
if tensor_to_check_dtype == core.VarDesc.VarType.FP32:
tensor_to_check_dtype = np.float32
elif tensor_to_check_dtype == core.DataType.FP64:
elif tensor_to_check_dtype == core.VarDesc.VarType.FP64:
tensor_to_check_dtype = np.float64
else:
raise ValueError("Not supported data type " + str(
......
......@@ -140,9 +140,9 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
grad_tensor = scope.var(grad_var_name(name)).get_tensor()
out_dtype = out_tensor.dtype()
if data is None:
if out_dtype == core.DataType.FP64:
if out_dtype == core.VarDesc.VarType.FP64:
data = np.ones(out_tensor.shape(), dtype=np.float64)
elif out_dtype == core.DataType.FP32:
elif out_dtype == core.VarDesc.VarType.FP32:
data = np.ones(out_tensor.shape(), dtype=np.float32)
else:
raise ValueError("Not supported data type " + str(out_dtype))
......
......@@ -24,8 +24,8 @@ class TestCastOp(op_test.OpTest):
self.inputs = {'X': ipt.astype('float32')}
self.outputs = {'Out': ipt.astype('float64')}
self.attrs = {
'in_dtype': int(core.DataType.FP32),
'out_dtype': int(core.DataType.FP64)
'in_dtype': int(core.VarDesc.VarType.FP32),
'out_dtype': int(core.VarDesc.VarType.FP64)
}
self.op_type = 'cast'
......
......@@ -26,7 +26,7 @@ class TestFillOp(OpTest):
self.attrs = {
'value': val.flatten().tolist(),
'shape': [100, 200],
'dtype': int(core.DataType.FP64)
'dtype': int(core.VarDesc.VarType.FP64)
}
self.outputs = {'Out': val.astype('float64')}
......
......@@ -97,9 +97,9 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
grad_tensor = scope.var(grad_var_name(name)).get_tensor()
out_dtype = out_tensor.dtype()
if data is None:
if out_dtype == core.DataType.FP64:
if out_dtype == core.VarDesc.VarType.FP64:
data = np.ones(out_tensor.shape(), dtype=np.float64)
elif out_dtype == core.DataType.FP32:
elif out_dtype == core.VarDesc.VarType.FP32:
data = np.ones(out_tensor.shape(), dtype=np.float32)
else:
raise ValueError("Not supported data type " + str(out_dtype))
......
......@@ -38,7 +38,7 @@ class TestOneHotOp(OpTest):
out[i, x[i]] = 1.0
self.inputs = {'X': (x, x_lod)}
self.attrs = {'depth': depth, 'dtype': int(core.DataType.FP32)}
self.attrs = {'depth': depth, 'dtype': int(core.VarDesc.VarType.FP32)}
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
......
......@@ -36,7 +36,7 @@ class TestParameter(unittest.TestCase):
self.assertIsNotNone(param)
self.assertEqual('fc.w', param.name)
self.assertEqual((784, 100), param.shape)
self.assertEqual(core.DataType.FP32, param.dtype)
self.assertEqual(core.VarDesc.VarType.FP32, param.dtype)
self.assertEqual(0, param.block.idx)
exe = Executor(core.CPUPlace())
p = exe.run(main_program, fetch_list=[param])[0]
......
......@@ -131,8 +131,8 @@ class TestVarDesc(unittest.TestCase):
block = program_desc.block(0)
var = block.var('my_var')
var.set_type(core.VarDesc.VarType.LOD_TENSOR)
var.set_dtype(core.DataType.INT32)
self.assertEqual(core.DataType.INT32, var.dtype())
var.set_dtype(core.VarDesc.VarType.INT32)
self.assertEqual(core.VarDesc.VarType.INT32, var.dtype())
self.assertEqual(core.VarDesc.VarType.LOD_TENSOR, var.type())
def test_multiple_dtype(self):
......@@ -141,7 +141,8 @@ class TestVarDesc(unittest.TestCase):
var = block.var('my_reader')
var.set_type(core.VarDesc.VarType.READER)
src_types = [
core.DataType.INT32, core.DataType.FP64, core.DataType.FP32
core.VarDesc.VarType.INT32, core.VarDesc.VarType.FP64,
core.VarDesc.VarType.FP32
]
var.set_dtypes(src_types)
self.assertEqual(src_types, var.dtypes())
......
......@@ -20,7 +20,7 @@ import numpy as np
class TestVariable(unittest.TestCase):
def test_np_dtype_convert(self):
DT = core.DataType
DT = core.VarDesc.VarType
convert = convert_np_dtype_to_dtype_
self.assertEqual(DT.FP32, convert(np.float32))
self.assertEqual(DT.FP16, convert("float16"))
......@@ -36,13 +36,13 @@ class TestVariable(unittest.TestCase):
w = b.create_var(
dtype="float64", shape=[784, 100], lod_level=0, name="fc.w")
self.assertNotEqual(str(w), "")
self.assertEqual(core.DataType.FP64, w.dtype)
self.assertEqual(core.VarDesc.VarType.FP64, w.dtype)
self.assertEqual((784, 100), w.shape)
self.assertEqual("fc.w", w.name)
self.assertEqual(0, w.lod_level)
w = b.create_var(name='fc.w')
self.assertEqual(core.DataType.FP64, w.dtype)
self.assertEqual(core.VarDesc.VarType.FP64, w.dtype)
self.assertEqual((784, 100), w.shape)
self.assertEqual("fc.w", w.name)
self.assertEqual(0, w.lod_level)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册