提交 c5909c77 编写于 作者: M minqiyang

Feature/tensor type

test=release/1.2
上级 847cbdce
...@@ -85,7 +85,7 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var, ...@@ -85,7 +85,7 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
out->mutable_data(expected_kernel_type.place_, in.type()); out->mutable_data(expected_kernel_type.place_, in.type());
framework::VisitDataType( framework::VisitDataType(
framework::ToDataType(in.type()), in.type(),
CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out)); CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out));
out->set_layout(expected_kernel_type.data_layout_); out->set_layout(expected_kernel_type.data_layout_);
...@@ -101,7 +101,7 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) { ...@@ -101,7 +101,7 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
case mkldnn::memory::data_type::f32: case mkldnn::memory::data_type::f32:
return platform::to_void_cast(tensor.data<float>()); return platform::to_void_cast(tensor.data<float>());
case mkldnn::memory::data_type::s8: case mkldnn::memory::data_type::s8:
return platform::to_void_cast(tensor.data<char>()); return platform::to_void_cast(tensor.data<int8_t>());
case mkldnn::memory::data_type::u8: case mkldnn::memory::data_type::u8:
return platform::to_void_cast(tensor.data<unsigned char>()); return platform::to_void_cast(tensor.data<unsigned char>());
case mkldnn::memory::data_type::s16: case mkldnn::memory::data_type::s16:
...@@ -144,7 +144,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, ...@@ -144,7 +144,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
memory::data_type in_type = ToMKLDNNDataType(in.type()); memory::data_type in_type = ToMKLDNNDataType(in.type());
PADDLE_ENFORCE(in_type != memory::data_type::data_undef, PADDLE_ENFORCE(in_type != memory::data_type::data_undef,
"Input tensor type is not supported: ", in.type().name()); "Input tensor type is not supported: %s", in.type());
memory::data_type out_type = in_type; memory::data_type out_type = in_type;
auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format()); auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
......
...@@ -50,14 +50,14 @@ inline DataLayout ToPaddleLayout(const MKLDNNFormat& format) { ...@@ -50,14 +50,14 @@ inline DataLayout ToPaddleLayout(const MKLDNNFormat& format) {
} }
} }
inline MKLDNNDataType ToMKLDNNDataType(const std::type_index type) { inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
static const std::map<std::type_index, MKLDNNDataType> dict{ static std::unordered_map<int, MKLDNNDataType> dict{
{std::type_index(typeid(float)), MKLDNNDataType::f32}, // NOLINT {DataTypeTrait<float>::DataType, MKLDNNDataType::f32},
{std::type_index(typeid(char)), MKLDNNDataType::s8}, // NOLINT {DataTypeTrait<int8_t>::DataType, MKLDNNDataType::s8},
{std::type_index(typeid(unsigned char)), MKLDNNDataType::u8}, {DataTypeTrait<uint8_t>::DataType, MKLDNNDataType::u8},
{std::type_index(typeid(int16_t)), MKLDNNDataType::s16}, {DataTypeTrait<int16_t>::DataType, MKLDNNDataType::s16},
{std::type_index(typeid(int32_t)), MKLDNNDataType::s32}}; {DataTypeTrait<int32_t>::DataType, MKLDNNDataType::s32}};
auto iter = dict.find(type); auto iter = dict.find(static_cast<int>(type));
if (iter != dict.end()) return iter->second; if (iter != dict.end()) return iter->second;
return MKLDNNDataType::data_undef; return MKLDNNDataType::data_undef;
} }
......
...@@ -26,7 +26,7 @@ struct DataTypeMap { ...@@ -26,7 +26,7 @@ struct DataTypeMap {
std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_; std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_;
std::unordered_map<int, std::type_index> proto_to_cpp_; std::unordered_map<int, std::type_index> proto_to_cpp_;
std::unordered_map<int, std::string> proto_to_str_; std::unordered_map<int, std::string> proto_to_str_;
std::unordered_map<std::type_index, size_t> cpp_to_size_; std::unordered_map<int, size_t> proto_to_size_;
}; };
static DataTypeMap* InitDataTypeMap(); static DataTypeMap* InitDataTypeMap();
...@@ -45,7 +45,7 @@ static inline void RegisterType(DataTypeMap* map, ...@@ -45,7 +45,7 @@ static inline void RegisterType(DataTypeMap* map,
map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T)); map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T));
map->cpp_to_proto_.emplace(typeid(T), proto_type); map->cpp_to_proto_.emplace(typeid(T), proto_type);
map->proto_to_str_.emplace(static_cast<int>(proto_type), name); map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
map->cpp_to_size_.emplace(typeid(T), sizeof(T)); map->proto_to_size_.emplace(static_cast<int>(proto_type), sizeof(T));
} }
static DataTypeMap* InitDataTypeMap() { static DataTypeMap* InitDataTypeMap() {
...@@ -54,17 +54,7 @@ static DataTypeMap* InitDataTypeMap() { ...@@ -54,17 +54,7 @@ static DataTypeMap* InitDataTypeMap() {
#define RegType(cc_type, proto_type) \ #define RegType(cc_type, proto_type) \
RegisterType<cc_type>(retv, proto_type, #cc_type) RegisterType<cc_type>(retv, proto_type, #cc_type)
// NOTE: Add your customize type here. _ForEachDataType_(RegType);
RegType(float16, proto::VarType::FP16);
RegType(float, proto::VarType::FP32);
RegType(double, proto::VarType::FP64);
RegType(int, proto::VarType::INT32);
RegType(int64_t, proto::VarType::INT64);
RegType(bool, proto::VarType::BOOL);
RegType(size_t, proto::VarType::SIZE_T);
RegType(int16_t, proto::VarType::INT16);
RegType(uint8_t, proto::VarType::UINT8);
RegType(int8_t, proto::VarType::INT8);
#undef RegType #undef RegType
return retv; return retv;
...@@ -96,12 +86,12 @@ std::string DataTypeToString(const proto::VarType::Type type) { ...@@ -96,12 +86,12 @@ std::string DataTypeToString(const proto::VarType::Type type) {
static_cast<int>(type)); static_cast<int>(type));
} }
size_t SizeOfType(std::type_index type) { size_t SizeOfType(proto::VarType::Type type) {
auto it = gDataTypeMap().cpp_to_size_.find(type); auto it = gDataTypeMap().proto_to_size_.find(static_cast<int>(type));
if (it != gDataTypeMap().cpp_to_size_.end()) { if (it != gDataTypeMap().proto_to_size_.end()) {
return it->second; return it->second;
} }
PADDLE_THROW("Not support %s as tensor type", type.name()); PADDLE_THROW("Not support %s as tensor type", DataTypeToString(type));
} }
} // namespace framework } // namespace framework
......
...@@ -22,46 +22,59 @@ limitations under the License. */ ...@@ -22,46 +22,59 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
template <typename T>
struct DataTypeTrait {};
// Stub handle for void
template <>
struct DataTypeTrait<void> {
constexpr static auto DataType = proto::VarType::RAW;
};
#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
#define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, bool, BOOL); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8)
#define DefineDataTypeTrait(cpp_type, proto_type) \
template <> \
struct DataTypeTrait<cpp_type> { \
constexpr static auto DataType = proto_type; \
}
_ForEachDataType_(DefineDataTypeTrait);
#undef DefineDataTypeTrait
extern proto::VarType::Type ToDataType(std::type_index type); extern proto::VarType::Type ToDataType(std::type_index type);
extern std::type_index ToTypeIndex(proto::VarType::Type type); extern std::type_index ToTypeIndex(proto::VarType::Type type);
template <typename Visitor> template <typename Visitor>
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
switch (type) { #define VisitDataTypeCallback(cpp_type, proto_type) \
case proto::VarType::FP16: do { \
visitor.template apply<platform::float16>(); if (type == proto_type) { \
break; visitor.template apply<cpp_type>(); \
case proto::VarType::FP32: return; \
visitor.template apply<float>(); } \
break; } while (0)
case proto::VarType::FP64:
visitor.template apply<double>(); _ForEachDataType_(VisitDataTypeCallback);
break; #undef VisitDataTypeCallback
case proto::VarType::INT32:
visitor.template apply<int>();
break;
case proto::VarType::INT64:
visitor.template apply<int64_t>();
break;
case proto::VarType::BOOL:
visitor.template apply<bool>();
break;
case proto::VarType::UINT8:
visitor.template apply<uint8_t>();
break;
case proto::VarType::INT16:
visitor.template apply<int16_t>();
break;
case proto::VarType::INT8:
visitor.template apply<int8_t>();
break;
default:
PADDLE_THROW("Not supported %d", type); PADDLE_THROW("Not supported %d", type);
}
} }
extern std::string DataTypeToString(const proto::VarType::Type type); extern std::string DataTypeToString(const proto::VarType::Type type);
extern size_t SizeOfType(std::type_index type); extern size_t SizeOfType(proto::VarType::Type type);
inline std::ostream& operator<<(std::ostream& out, inline std::ostream& operator<<(std::ostream& out,
const proto::VarType::Type& type) { const proto::VarType::Type& type) {
out << DataTypeToString(type); out << DataTypeToString(type);
......
...@@ -26,15 +26,15 @@ TEST(DataType, float16) { ...@@ -26,15 +26,15 @@ TEST(DataType, float16) {
Tensor tensor; Tensor tensor;
CPUPlace cpu; CPUPlace cpu;
tensor.mutable_data(cpu, f::ToTypeIndex(dtype)); tensor.mutable_data(cpu, dtype);
// test fp16 tensor // test fp16 tensor
EXPECT_EQ(tensor.type(), std::type_index(typeid(float16))); EXPECT_EQ(tensor.type(), f::ToDataType(typeid(float16)));
// test fp16 size // test fp16 size
EXPECT_EQ(f::SizeOfType(f::ToTypeIndex(dtype)), 2u); EXPECT_EQ(f::SizeOfType(dtype), 2u);
// test debug info // test debug info
std::string type = "float16"; std::string type = "::paddle::platform::float16";
EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str()); EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
} }
...@@ -120,7 +120,7 @@ void AllReduceOpHandle::RunImpl() { ...@@ -120,7 +120,7 @@ void AllReduceOpHandle::RunImpl() {
// Reduce All Tensor to trg in CPU // Reduce All Tensor to trg in CPU
ReduceLoDTensor func(lod_tensors, &trg); ReduceLoDTensor func(lod_tensors, &trg);
VisitDataType(ToDataType(lod_tensors[0]->type()), func); VisitDataType(lod_tensors[0]->type(), func);
for (size_t i = 1; i < local_scopes_.size(); ++i) { for (size_t i = 1; i < local_scopes_.size(); ++i) {
auto &scope = auto &scope =
......
...@@ -33,7 +33,7 @@ struct FuseVarsOpHandle : public OpHandleBase { ...@@ -33,7 +33,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
FuseVarsOpHandle(ir::Node *node, Scope *local_scope, FuseVarsOpHandle(ir::Node *node, Scope *local_scope,
const platform::Place &place, const platform::Place &place,
const std::unordered_map<std::string, int64_t> &inputs_numel, const std::unordered_map<std::string, int64_t> &inputs_numel,
const std::type_index &var_type) const proto::VarType::Type var_type)
: OpHandleBase(node), : OpHandleBase(node),
local_scope_(local_scope), local_scope_(local_scope),
place_(place), place_(place),
...@@ -57,7 +57,7 @@ struct FuseVarsOpHandle : public OpHandleBase { ...@@ -57,7 +57,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
Scope *local_scope_; Scope *local_scope_;
const platform::Place place_; const platform::Place place_;
const std::unordered_map<std::string, int64_t> inputs_numel_; const std::unordered_map<std::string, int64_t> inputs_numel_;
const std::type_index type_; const proto::VarType::Type type_;
int64_t total_numel_; int64_t total_numel_;
}; };
} // namespace details } // namespace details
......
...@@ -106,7 +106,7 @@ void ReduceOpHandle::RunImpl() { ...@@ -106,7 +106,7 @@ void ReduceOpHandle::RunImpl() {
if (!FLAGS_cpu_deterministic) { if (!FLAGS_cpu_deterministic) {
ReduceLoDTensor func(lod_tensors, ReduceLoDTensor func(lod_tensors,
out_var->GetMutable<framework::LoDTensor>()); out_var->GetMutable<framework::LoDTensor>());
VisitDataType(ToDataType(lod_tensors[0]->type()), func); VisitDataType(lod_tensors[0]->type(), func);
} else { } else {
// We sum lod_tensors to reduce_sum_trg which is in local_scopes_0 // We sum lod_tensors to reduce_sum_trg which is in local_scopes_0
// here, but it doesn't mean reduce_sum_trg must be in local_scopes_0. // here, but it doesn't mean reduce_sum_trg must be in local_scopes_0.
...@@ -116,7 +116,7 @@ void ReduceOpHandle::RunImpl() { ...@@ -116,7 +116,7 @@ void ReduceOpHandle::RunImpl() {
->FindVar(out_var_handle->name_) ->FindVar(out_var_handle->name_)
->GetMutable<framework::LoDTensor>(); ->GetMutable<framework::LoDTensor>();
ReduceLoDTensor func(lod_tensors, &reduce_sum_trg); ReduceLoDTensor func(lod_tensors, &reduce_sum_trg);
VisitDataType(ToDataType(lod_tensors[0]->type()), func); VisitDataType(lod_tensors[0]->type(), func);
auto trg = out_var->GetMutable<framework::LoDTensor>(); auto trg = out_var->GetMutable<framework::LoDTensor>();
if (reduce_sum_trg.data<void>() != trg->data<void>()) { if (reduce_sum_trg.data<void>() != trg->data<void>()) {
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/dlpack_tensor.h" #include "paddle/fluid/framework/dlpack_tensor.h"
#include "paddle/fluid/framework/data_type.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -36,26 +36,23 @@ static ::DLDataType GetDLDataTypeCode() { ...@@ -36,26 +36,23 @@ static ::DLDataType GetDLDataTypeCode() {
return dtype; return dtype;
} }
static DLDataType GetDLDataTypeFromTypeIndex(const std::type_index &type) { static std::unordered_map<int, ::DLDataType> CreateDLDataTypeMap() {
#define REG_DL_DATA_TYPE(type) \ static std::unordered_map<int, ::DLDataType> result;
{ std::type_index(typeid(type)), GetDLDataTypeCode<type>() }
static const std::unordered_map<std::type_index, ::DLDataType> #define REG_DL_DATA_TYPE(cpp_type, proto_type) \
type_to_dtype_map({ result[static_cast<int>(proto_type)] = GetDLDataTypeCode<cpp_type>()
REG_DL_DATA_TYPE(platform::float16), // NOLINT
REG_DL_DATA_TYPE(float), // NOLINT _ForEachDataType_(REG_DL_DATA_TYPE);
REG_DL_DATA_TYPE(double), // NOLINT #undef REG_DL_DATA_TYPE
REG_DL_DATA_TYPE(int), // NOLINT return result;
REG_DL_DATA_TYPE(int64_t), // NOLINT }
REG_DL_DATA_TYPE(bool), // NOLINT
REG_DL_DATA_TYPE(size_t), // NOLINT static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
REG_DL_DATA_TYPE(int16_t), // NOLINT static auto type_to_dtype_map = CreateDLDataTypeMap();
REG_DL_DATA_TYPE(uint8_t), // NOLINT
REG_DL_DATA_TYPE(int8_t) // NOLINT
});
static auto type_to_dtype_map_end_it = type_to_dtype_map.end(); static auto type_to_dtype_map_end_it = type_to_dtype_map.end();
auto it = type_to_dtype_map.find(type); auto it = type_to_dtype_map.find(static_cast<int>(type));
PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %s", PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %d",
type.name()); type);
return it->second; return it->second;
#undef REG_DL_DATA_TYPE #undef REG_DL_DATA_TYPE
} }
......
...@@ -91,23 +91,11 @@ void TestMainLoop() { ...@@ -91,23 +91,11 @@ void TestMainLoop() {
} }
} }
} }
TEST(dlpack, test_all) {
#define TestCallback(cpp_type, proto_type) TestMainLoop<cpp_type>()
#define PADDLE_DLPACK_TEST(type) \ _ForEachDataType_(TestCallback);
TEST(dlpack, test_##type) { TestMainLoop<type>(); } }
using float16 = platform::float16;
PADDLE_DLPACK_TEST(float16);
PADDLE_DLPACK_TEST(float);
PADDLE_DLPACK_TEST(double);
PADDLE_DLPACK_TEST(int);
PADDLE_DLPACK_TEST(int64_t);
PADDLE_DLPACK_TEST(bool);
PADDLE_DLPACK_TEST(size_t);
PADDLE_DLPACK_TEST(int16_t);
PADDLE_DLPACK_TEST(uint8_t);
PADDLE_DLPACK_TEST(int8_t);
#undef PADDLE_DLPACK_TEST
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -138,39 +138,19 @@ void print_lod_tensor(std::string var_name, const LoDTensor& lod_tensor) { ...@@ -138,39 +138,19 @@ void print_lod_tensor(std::string var_name, const LoDTensor& lod_tensor) {
std::cout << sstream.str() << std::endl; std::cout << sstream.str() << std::endl;
} }
void print_fetch_var(Scope* scope, std::string var_name) { static void print_fetch_var(Scope* scope, const std::string& var_name) {
const LoDTensor& tensor = scope->FindVar(var_name)->Get<LoDTensor>(); auto& tensor = scope->FindVar(var_name)->Get<LoDTensor>();
if (std::type_index(tensor.type()) == #define PrintLoDTensorCallback(cpp_type, proto_type) \
std::type_index(typeid(platform::float16))) { do { \
print_lod_tensor<platform::float16>(var_name, tensor); if (tensor.type() == proto_type) { \
} else if (std::type_index(tensor.type()) == std::type_index(typeid(float))) { print_lod_tensor<cpp_type>(var_name, tensor); \
print_lod_tensor<float>(var_name, tensor); return; \
} else if (std::type_index(tensor.type()) == } \
std::type_index(typeid(double))) { } while (0)
print_lod_tensor<double>(var_name, tensor);
} else if (std::type_index(tensor.type()) == std::type_index(typeid(int))) { _ForEachDataType_(PrintLoDTensorCallback);
print_lod_tensor<int>(var_name, tensor); VLOG(1) << "print_fetch_var: unrecognized data type:" << tensor.type();
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(int64_t))) {
print_lod_tensor<int64_t>(var_name, tensor);
} else if (std::type_index(tensor.type()) == std::type_index(typeid(bool))) {
print_lod_tensor<bool>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(uint8_t))) {
print_lod_tensor<uint8_t>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(int16_t))) {
print_lod_tensor<int16_t>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(int8_t))) {
print_lod_tensor<int8_t>(var_name, tensor);
} else {
VLOG(1) << "print_fetch_var: unrecognized data type:"
<< tensor.type().name();
}
return;
} }
void ExecutorThreadWorker::TrainFiles() { void ExecutorThreadWorker::TrainFiles() {
......
...@@ -70,9 +70,9 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) { ...@@ -70,9 +70,9 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
// only print first ten elements // only print first ten elements
int64_t size = t.numel() < 10 ? t.numel() : 10; int64_t size = t.numel() < 10 ? t.numel() : 10;
for (int64_t i = 0; i < size; ++i) { for (int64_t i = 0; i < size; ++i) {
if (IsType<float>(t.type())) { if (t.type() == proto::VarType::FP32) {
os << t.data<float>()[i] << " "; os << t.data<float>()[i] << " ";
} else if (IsType<int64_t>(t.type())) { } else if (t.type() == proto::VarType::INT64) {
os << t.data<int64_t>()[i] << " "; os << t.data<int64_t>()[i] << " ";
} else { } else {
PADDLE_THROW("LoDTensor data type not in [float, int64_t]"); PADDLE_THROW("LoDTensor data type not in [float, int64_t]");
...@@ -387,7 +387,7 @@ void LoDTensor::MergeLoDTensor( ...@@ -387,7 +387,7 @@ void LoDTensor::MergeLoDTensor(
PADDLE_ENFORCE(!lod_tensors.empty()); PADDLE_ENFORCE(!lod_tensors.empty());
framework::DDim new_dim = lod_tensors[0]->dims(); framework::DDim new_dim = lod_tensors[0]->dims();
std::type_index new_type = lod_tensors[0]->type(); auto new_type = lod_tensors[0]->type();
framework::DataLayout new_layout = lod_tensors[0]->layout(); framework::DataLayout new_layout = lod_tensors[0]->layout();
LoD new_lod = lod_tensors[0]->lod(); LoD new_lod = lod_tensors[0]->lod();
for (size_t i = 1; i < lod_tensors.size(); ++i) { for (size_t i = 1; i < lod_tensors.size(); ++i) {
......
...@@ -34,7 +34,8 @@ TEST(OpKernelType, ToString) { ...@@ -34,7 +34,8 @@ TEST(OpKernelType, ToString) {
OpKernelType op_kernel_type2(DataType::FP16, CUDAPlace(0), DataLayout::kNCHW, OpKernelType op_kernel_type2(DataType::FP16, CUDAPlace(0), DataLayout::kNCHW,
LibraryType::kCUDNN); LibraryType::kCUDNN);
ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type2), ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type2),
"data_type[float16]:data_layout[NCHW]:place[CUDAPlace(0)]:library_" "data_type[::paddle::platform::float16]:data_layout[NCHW]:place["
"CUDAPlace(0)]:library_"
"type[CUDNN]"); "type[CUDNN]");
} }
......
...@@ -43,10 +43,9 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = { ...@@ -43,10 +43,9 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
proto::VarType::Type GetDataTypeOfVar(const Variable* var) { proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
return framework::ToDataType(var->Get<framework::LoDTensor>().type()); return var->Get<framework::LoDTensor>().type();
} else if (var->IsType<framework::SelectedRows>()) { } else if (var->IsType<framework::SelectedRows>()) {
return framework::ToDataType( return var->Get<framework::SelectedRows>().value().type();
var->Get<framework::SelectedRows>().value().type());
} else { } else {
PADDLE_THROW("Var should be LoDTensor or SelectedRows"); PADDLE_THROW("Var should be LoDTensor or SelectedRows");
} }
...@@ -93,13 +92,13 @@ static std::string GetDtype(const Scope& scope, const std::string& name) { ...@@ -93,13 +92,13 @@ static std::string GetDtype(const Scope& scope, const std::string& name) {
if (UNLIKELY(!tensor.IsInitialized())) { if (UNLIKELY(!tensor.IsInitialized())) {
return ""; return "";
} }
return DataTypeToString(ToDataType(tensor.type())); return DataTypeToString(tensor.type());
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
auto tensor = var->Get<SelectedRows>().value(); auto tensor = var->Get<SelectedRows>().value();
if (UNLIKELY(!tensor.IsInitialized())) { if (UNLIKELY(!tensor.IsInitialized())) {
return "uninited"; return "uninited";
} else { } else {
return DataTypeToString(ToDataType(tensor.type())); return DataTypeToString(tensor.type());
} }
} else { } else {
return ""; return "";
...@@ -686,7 +685,8 @@ static void CheckTensorNANOrInf(const std::string& name, ...@@ -686,7 +685,8 @@ static void CheckTensorNANOrInf(const std::string& name,
if (tensor.memory_size() == 0) { if (tensor.memory_size() == 0) {
return; return;
} }
if (!IsType<float>(tensor.type()) && !IsType<double>(tensor.type())) { if (tensor.type() != proto::VarType::FP32 &&
tensor.type() != proto::VarType::FP64) {
return; return;
} }
PADDLE_ENFORCE(!framework::TensorContainsInf(tensor), PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
...@@ -873,7 +873,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( ...@@ -873,7 +873,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
t = &(var->Get<SelectedRows>().value()); t = &(var->Get<SelectedRows>().value());
} }
if (t != nullptr) { if (t != nullptr) {
int tmp = static_cast<int>(ToDataType(t->type())); int tmp = static_cast<int>(t->type());
PADDLE_ENFORCE( PADDLE_ENFORCE(
tmp == data_type || data_type == -1, tmp == data_type || data_type == -1,
"DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)", "DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)",
......
...@@ -218,11 +218,11 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value, ...@@ -218,11 +218,11 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
if (index < 0) { if (index < 0) {
VLOG(5) << "id " << id << " not in the table, return 0"; VLOG(5) << "id " << id << " not in the table, return 0";
framework::VisitDataType( framework::VisitDataType(
framework::ToDataType(value_->type()), value_->type(),
TensorFillVisitor(value, i * value_width, value_width, 0.0)); TensorFillVisitor(value, i * value_width, value_width, 0.0));
} else { } else {
framework::VisitDataType( framework::VisitDataType(
framework::ToDataType(value_->type()), value_->type(),
TensorCopyVisitor(value, i * value_width, *value_.get(), TensorCopyVisitor(value, i * value_width, *value_.get(),
index * value_width, value_width)); index * value_width, value_width));
} }
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
extern size_t SizeOfType(std::type_index type); extern size_t SizeOfType(proto::VarType::Type type);
void Tensor::check_memory_size() const { void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor holds no memory. Call Tensor::mutable_data first."); holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
...@@ -31,7 +31,7 @@ size_t Tensor::memory_size() const { ...@@ -31,7 +31,7 @@ size_t Tensor::memory_size() const {
return holder_ == nullptr ? 0UL : holder_->size() - offset_; return holder_ == nullptr ? 0UL : holder_->size() - offset_;
} }
void* Tensor::mutable_data(platform::Place place, std::type_index type, void* Tensor::mutable_data(platform::Place place, proto::VarType::Type type,
memory::Allocator::Attr attr, memory::Allocator::Attr attr,
size_t requested_size) { size_t requested_size) {
type_ = type; type_ = type;
......
...@@ -19,9 +19,9 @@ limitations under the License. */ ...@@ -19,9 +19,9 @@ limitations under the License. */
#include <memory> #include <memory>
#include <typeindex> #include <typeindex>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/memory/memory.h" #include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -67,7 +67,7 @@ class Tensor { ...@@ -67,7 +67,7 @@ class Tensor {
friend struct EigenVector; friend struct EigenVector;
public: public:
Tensor() : type_(typeid(float)), offset_(0) {} Tensor() : type_(proto::VarType::FP32), offset_(0) {}
/*! Return a pointer to mutable memory block. */ /*! Return a pointer to mutable memory block. */
template <typename T> template <typename T>
...@@ -88,7 +88,7 @@ class Tensor { ...@@ -88,7 +88,7 @@ class Tensor {
memory::Allocator::Attr attr = memory::Allocator::kDefault, memory::Allocator::Attr attr = memory::Allocator::kDefault,
size_t requested_size = 0); size_t requested_size = 0);
void* mutable_data(platform::Place place, std::type_index type, void* mutable_data(platform::Place place, proto::VarType::Type type,
memory::Allocator::Attr attr = memory::Allocator::kDefault, memory::Allocator::Attr attr = memory::Allocator::kDefault,
size_t requested_size = 0); size_t requested_size = 0);
...@@ -138,7 +138,7 @@ class Tensor { ...@@ -138,7 +138,7 @@ class Tensor {
return holder_->place(); return holder_->place();
} }
std::type_index type() const { proto::VarType::Type type() const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor not initialized yet when Tensor::type() is called."); holder_, "Tensor not initialized yet when Tensor::type() is called.");
return type_; return type_;
...@@ -161,7 +161,7 @@ class Tensor { ...@@ -161,7 +161,7 @@ class Tensor {
private: private:
/*! holds the memory block if allocated. */ /*! holds the memory block if allocated. */
std::shared_ptr<memory::Allocation> holder_; std::shared_ptr<memory::Allocation> holder_;
std::type_index type_; proto::VarType::Type type_;
/** /**
* @brief points to elements dimensions. * @brief points to elements dimensions.
* *
......
...@@ -24,9 +24,8 @@ template <typename T> ...@@ -24,9 +24,8 @@ template <typename T>
inline const T* Tensor::data() const { inline const T* Tensor::data() const {
check_memory_size(); check_memory_size();
bool valid = bool valid =
std::is_same<T, void>::value || type_ == std::type_index(typeid(T)); std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %d", type_);
type_.name());
return reinterpret_cast<const T*>( return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_); reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
...@@ -38,9 +37,8 @@ template <typename T> ...@@ -38,9 +37,8 @@ template <typename T>
inline T* Tensor::data() { inline T* Tensor::data() {
check_memory_size(); check_memory_size();
bool valid = bool valid =
std::is_same<T, void>::value || type_ == std::type_index(typeid(T)); std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", type_);
type_.name());
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) + return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_); offset_);
} }
...@@ -60,7 +58,7 @@ inline T* Tensor::mutable_data(platform::Place place, ...@@ -60,7 +58,7 @@ inline T* Tensor::mutable_data(platform::Place place,
size_t requested_size) { size_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD"); static_assert(std::is_pod<T>::value, "T must be POD");
return reinterpret_cast<T*>( return reinterpret_cast<T*>(
mutable_data(place, typeid(T), attr, requested_size)); mutable_data(place, DataTypeTrait<T>::DataType, attr, requested_size));
} }
inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) { inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
......
...@@ -186,7 +186,7 @@ struct AnyDTypeVisitor { ...@@ -186,7 +186,7 @@ struct AnyDTypeVisitor {
template <typename Predicate, typename DevCtx> template <typename Predicate, typename DevCtx>
inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor, inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor,
const DevCtx& ctx, framework::Tensor* out) { const DevCtx& ctx, framework::Tensor* out) {
VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor<Predicate, DevCtx>( VisitDataType(tensor.type(), AnyDTypeVisitor<Predicate, DevCtx>(
predicate, tensor, ctx, out)); predicate, tensor, ctx, out));
} }
...@@ -379,7 +379,7 @@ void TensorToStream(std::ostream& os, const Tensor& tensor, ...@@ -379,7 +379,7 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
// int32_t size // int32_t size
// void* protobuf message // void* protobuf message
proto::VarType::TensorDesc desc; proto::VarType::TensorDesc desc;
desc.set_data_type(framework::ToDataType(tensor.type())); desc.set_data_type(tensor.type());
auto dims = framework::vectorize(tensor.dims()); auto dims = framework::vectorize(tensor.dims());
auto* pb_dims = desc.mutable_dims(); auto* pb_dims = desc.mutable_dims();
pb_dims->Resize(static_cast<int>(dims.size()), 0); pb_dims->Resize(static_cast<int>(dims.size()), 0);
...@@ -461,9 +461,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor, ...@@ -461,9 +461,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
tensor->Resize(framework::make_ddim(dims)); tensor->Resize(framework::make_ddim(dims));
void* buf; void* buf;
auto ctx = platform::CPUDeviceContext(); auto ctx = platform::CPUDeviceContext();
size_t size = size_t size = tensor->numel() * framework::SizeOfType(desc.data_type());
tensor->numel() *
framework::SizeOfType(framework::ToTypeIndex(desc.data_type()));
if (platform::is_gpu_place(dev_ctx.GetPlace())) { if (platform::is_gpu_place(dev_ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
Tensor cpu_tensor; Tensor cpu_tensor;
......
...@@ -289,10 +289,10 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs, ...@@ -289,10 +289,10 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
auto type = fetch.type(); auto type = fetch.type();
auto output = &(outputs->at(i)); auto output = &(outputs->at(i));
output->name = fetchs_[idx]->Input("X")[0]; output->name = fetchs_[idx]->Input("X")[0];
if (type == typeid(float)) { if (type == framework::proto::VarType::FP32) {
GetFetchOne<float>(fetch, output); GetFetchOne<float>(fetch, output);
output->dtype = PaddleDType::FLOAT32; output->dtype = PaddleDType::FLOAT32;
} else if (type == typeid(int64_t)) { } else if (type == framework::proto::VarType::INT64) {
GetFetchOne<int64_t>(fetch, output); GetFetchOne<int64_t>(fetch, output);
output->dtype = PaddleDType::INT64; output->dtype = PaddleDType::INT64;
} else { } else {
......
...@@ -266,10 +266,10 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs, ...@@ -266,10 +266,10 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
auto type = fetch.type(); auto type = fetch.type();
auto output = &(outputs->at(i)); auto output = &(outputs->at(i));
output->name = fetchs_[idx]->Input("X")[0]; output->name = fetchs_[idx]->Input("X")[0];
if (type == typeid(float)) { if (type == framework::DataTypeTrait<float>::DataType) {
GetFetchOne<float>(fetch, output); GetFetchOne<float>(fetch, output);
output->dtype = PaddleDType::FLOAT32; output->dtype = PaddleDType::FLOAT32;
} else if (type == typeid(int64_t)) { } else if (type == framework::DataTypeTrait<int64_t>::DataType) {
GetFetchOne<int64_t>(fetch, output); GetFetchOne<int64_t>(fetch, output);
output->dtype = PaddleDType::INT64; output->dtype = PaddleDType::INT64;
} else { } else {
......
...@@ -36,10 +36,10 @@ namespace paddle { ...@@ -36,10 +36,10 @@ namespace paddle {
PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) { PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) {
PaddleTensor pt; PaddleTensor pt;
if (t->type() == typeid(int64_t)) { if (t->type() == framework::proto::VarType::INT64) {
pt.data.Reset(t->data<void>(), t->numel() * sizeof(int64_t)); pt.data.Reset(t->data<void>(), t->numel() * sizeof(int64_t));
pt.dtype = PaddleDType::INT64; pt.dtype = PaddleDType::INT64;
} else if (t->type() == typeid(float)) { } else if (t->type() == framework::proto::VarType::FP32) {
pt.data.Reset(t->data<void>(), t->numel() * sizeof(float)); pt.data.Reset(t->data<void>(), t->numel() * sizeof(float));
pt.dtype = PaddleDType::FLOAT32; pt.dtype = PaddleDType::FLOAT32;
} else { } else {
......
...@@ -361,7 +361,7 @@ static bool CompareTensorData(const framework::LoDTensor &a, ...@@ -361,7 +361,7 @@ static bool CompareTensorData(const framework::LoDTensor &a,
} }
for (size_t i = 0; i < a_size; i++) { for (size_t i = 0; i < a_size; i++) {
if (a.type() == typeid(float)) { if (a.type() == framework::proto::VarType::FP32) {
const auto *a_data = a.data<float>(); const auto *a_data = a.data<float>();
const auto *b_data = b.data<float>(); const auto *b_data = b.data<float>();
if (std::abs(a_data[i] - b_data[i]) > 1e-3) { if (std::abs(a_data[i] - b_data[i]) > 1e-3) {
...@@ -370,7 +370,7 @@ static bool CompareTensorData(const framework::LoDTensor &a, ...@@ -370,7 +370,7 @@ static bool CompareTensorData(const framework::LoDTensor &a,
b_data[i]); b_data[i]);
return false; return false;
} }
} else if (a.type() == typeid(int64_t)) { } else if (a.type() == framework::proto::VarType::INT64) {
const auto *a_data = a.data<int64_t>(); const auto *a_data = a.data<int64_t>();
const auto *b_data = b.data<int64_t>(); const auto *b_data = b.data<int64_t>();
if (std::abs(a_data[i] - b_data[i]) > 1e-3) { if (std::abs(a_data[i] - b_data[i]) > 1e-3) {
......
...@@ -78,7 +78,7 @@ class AffineGridOp : public framework::OperatorWithKernel { ...@@ -78,7 +78,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
library = framework::LibraryType::kCUDNN; library = framework::LibraryType::kCUDNN;
} }
#endif #endif
auto data_type = framework::ToDataType(ctx.Input<Tensor>("Theta")->type()); auto data_type = ctx.Input<Tensor>("Theta")->type();
return framework::OpKernelType(data_type, ctx.GetPlace(), return framework::OpKernelType(data_type, ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library); framework::DataLayout::kAnyLayout, library);
} }
...@@ -188,9 +188,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel { ...@@ -188,9 +188,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Theta")->type(),
framework::ToDataType(ctx.Input<Tensor>("Theta")->type()), ctx.GetPlace(),
ctx.GetPlace(), framework::DataLayout::kAnyLayout, library_); framework::DataLayout::kAnyLayout, library_);
} }
}; };
......
...@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL( ...@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL(
int32_t>, int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int16_t>, int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, size_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
uint8_t>); uint8_t>);
...@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL(
int32_t>, int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
int16_t>, int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
size_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
uint8_t>); uint8_t>);
...@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL( ...@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL(
int32_t>, int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int16_t>, int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, size_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
uint8_t>); uint8_t>);
...@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL(
int32_t>, int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
int16_t>, int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
size_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
uint8_t>); uint8_t>);
...@@ -58,7 +58,7 @@ struct ArrayToLoDFunctor : public boost::static_visitor<void> { ...@@ -58,7 +58,7 @@ struct ArrayToLoDFunctor : public boost::static_visitor<void> {
ArrayToLoDFunctorImpl<DeviceContext> functor; ArrayToLoDFunctorImpl<DeviceContext> functor;
functor.dev_ctx_ = dev_ctx; functor.dev_ctx_ = dev_ctx;
functor.prev_functor_ = this; functor.prev_functor_ = this;
framework::VisitDataType(framework::ToDataType(out->type()), functor); framework::VisitDataType(out->type(), functor);
} }
}; };
...@@ -91,7 +91,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase { ...@@ -91,7 +91,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
PADDLE_ENFORCE(!x.empty(), "There's no element in the input array."); PADDLE_ENFORCE(!x.empty(), "There's no element in the input array.");
int rank = x[0].dims().size(); int rank = x[0].dims().size();
platform::Place place = x[0].place(); platform::Place place = x[0].place();
std::type_index data_type = x[0].type(); auto data_type = x[0].type();
int64_t batch_size = x[0].dims()[0]; int64_t batch_size = x[0].dims()[0];
framework::DDim ins_dims = rank > 1 framework::DDim ins_dims = rank > 1
? framework::slice_ddim(x[0].dims(), 1, rank) ? framework::slice_ddim(x[0].dims(), 1, rank)
......
...@@ -121,8 +121,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -121,8 +121,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType( framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context()); ctx.device_context());
} }
......
...@@ -103,8 +103,7 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel { ...@@ -103,8 +103,7 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("param")->type(),
framework::ToDataType(ctx.Input<Tensor>("param")->type()),
ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -72,8 +72,7 @@ class BatchNormOp : public framework::OperatorWithKernel { ...@@ -72,8 +72,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto input_data_type = ctx.Input<Tensor>("X")->type();
framework::ToDataType(ctx.Input<Tensor>("X")->type());
// By default, the type of the scale, bias, mean, // By default, the type of the scale, bias, mean,
// and var tensors should both be float. (For float or float16 input tensor) // and var tensors should both be float. (For float or float16 input tensor)
// or double (For double input tensor). // or double (For double input tensor).
...@@ -81,17 +80,13 @@ class BatchNormOp : public framework::OperatorWithKernel { ...@@ -81,17 +80,13 @@ class BatchNormOp : public framework::OperatorWithKernel {
if (input_data_type == framework::proto::VarType::FP64) { if (input_data_type == framework::proto::VarType::FP64) {
bn_param_type = framework::proto::VarType::FP64; bn_param_type = framework::proto::VarType::FP64;
} }
PADDLE_ENFORCE_EQ(bn_param_type, PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Scale")->type(),
framework::ToDataType(ctx.Input<Tensor>("Scale")->type()),
"Scale input should be of float type"); "Scale input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type, PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Bias")->type(),
framework::ToDataType(ctx.Input<Tensor>("Bias")->type()),
"Bias input should be of float type"); "Bias input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type, PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Mean")->type(),
framework::ToDataType(ctx.Input<Tensor>("Mean")->type()),
"Mean input should be of float type"); "Mean input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType( PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Variance")->type(),
ctx.Input<Tensor>("Variance")->type()),
"Variance input should be of float type"); "Variance input should be of float type");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
...@@ -387,9 +382,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel { ...@@ -387,9 +382,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(), ctx.GetPlace(), layout, library);
layout, library);
} }
}; };
......
...@@ -145,7 +145,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase { ...@@ -145,7 +145,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores"); LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores");
framework::VisitDataType( framework::VisitDataType(
framework::ToDataType(scores->at(0).type()), scores->at(0).type(),
BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores, BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores,
beam_size, end_id)); beam_size, end_id));
} }
......
...@@ -282,8 +282,7 @@ class BeamSearchOp : public framework::OperatorWithKernel { ...@@ -282,8 +282,7 @@ class BeamSearchOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = framework::OpKernelType( framework::OpKernelType kt = framework::OpKernelType(
framework::ToDataType( ctx.Input<framework::LoDTensor>("pre_ids")->type(),
ctx.Input<framework::LoDTensor>("pre_ids")->type()),
platform::CPUPlace()); platform::CPUPlace());
return kt; return kt;
} }
......
...@@ -48,13 +48,12 @@ class ConditionalOp : public framework::OperatorBase { ...@@ -48,13 +48,12 @@ class ConditionalOp : public framework::OperatorBase {
if (!(ips.size() == 1UL && ips[0]->IsInitialized())) { if (!(ips.size() == 1UL && ips[0]->IsInitialized())) {
PADDLE_THROW("should have one initialized input as condition"); PADDLE_THROW("should have one initialized input as condition");
} }
if (!(framework::IsType<bool>(ips[0]->type()) && // NOLINT
ips[0]->numel() == 1)) { PADDLE_ENFORCE(ips[0]->type() == framework::proto::VarType::BOOL &&
PADDLE_THROW( ips[0]->numel() == 1,
"condition input's data type should be bool, " "condition input's data type should be bool, "
"numel should be 1, actual numel is %d", "numel should be 1, actual numel is %d",
ips[0]->numel()); ips[0]->numel());
}
bool res = false; bool res = false;
if (platform::is_gpu_place(ips[0]->place())) { if (platform::is_gpu_place(ips[0]->place())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
......
...@@ -237,7 +237,7 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -237,7 +237,7 @@ class WhileGradOp : public framework::OperatorBase {
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
auto &inside_tensor = var->Get<framework::LoDTensor>(); auto &inside_tensor = var->Get<framework::LoDTensor>();
framework::AttributeMap attrs; framework::AttributeMap attrs;
attrs["dtype"] = framework::ToDataType(inside_tensor.type()); attrs["dtype"] = inside_tensor.type();
attrs["shape"] = framework::vectorize2int(inside_tensor.dims()); attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
attrs["value"] = 0.0f; attrs["value"] = 0.0f;
......
...@@ -92,10 +92,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -92,10 +92,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
} }
#endif #endif
auto input_data_type = auto input_data_type = ctx.Input<Tensor>("Input")->type();
framework::ToDataType(ctx.Input<Tensor>("Input")->type()); auto filter_data_type = ctx.Input<Tensor>("Filter")->type();
auto filter_data_type =
framework::ToDataType(ctx.Input<Tensor>("Filter")->type());
PADDLE_ENFORCE_EQ(input_data_type, filter_data_type, PADDLE_ENFORCE_EQ(input_data_type, filter_data_type,
"input and filter data type should be consistent"); "input and filter data type should be consistent");
...@@ -360,9 +358,8 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( ...@@ -360,9 +358,8 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(), ctx.GetPlace(), layout_, library_);
layout_, library_);
} }
} // namespace operators } // namespace operators
......
...@@ -95,9 +95,8 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( ...@@ -95,9 +95,8 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(), ctx.GetPlace(), layout_, library_);
layout_, library_);
} }
void Conv2DTransposeOpMaker::Make() { void Conv2DTransposeOpMaker::Make() {
...@@ -314,9 +313,8 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType( ...@@ -314,9 +313,8 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(), ctx.GetPlace(), layout_, library_);
layout_, library_);
} }
} // namespace operators } // namespace operators
......
...@@ -118,8 +118,7 @@ class CRFDecodingOp : public framework::OperatorWithKernel { ...@@ -118,8 +118,7 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<LoDTensor>("Emission")->type(),
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -51,8 +51,7 @@ class CropOp : public framework::OperatorWithKernel { ...@@ -51,8 +51,7 @@ class CropOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
...@@ -174,9 +173,7 @@ class CropOpGrad : public framework::OperatorWithKernel { ...@@ -174,9 +173,7 @@ class CropOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))
->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -57,8 +57,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel { ...@@ -57,8 +57,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
// is determined by its input "X". // is determined by its input "X".
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
...@@ -111,8 +110,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { ...@@ -111,8 +110,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
// is determined by its input "X". // is determined by its input "X".
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -36,8 +36,7 @@ class CTCAlignOp : public framework::OperatorWithKernel { ...@@ -36,8 +36,7 @@ class CTCAlignOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -53,8 +53,7 @@ class AnchorGeneratorOp : public framework::OperatorWithKernel { ...@@ -53,8 +53,7 @@ class AnchorGeneratorOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()), ctx.Input<framework::Tensor>("Input")->type(), ctx.device_context());
ctx.device_context());
} }
}; };
......
...@@ -45,8 +45,7 @@ class BipartiteMatchOp : public framework::OperatorWithKernel { ...@@ -45,8 +45,7 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<LoDTensor>("DistMat")->type(),
framework::ToDataType(ctx.Input<LoDTensor>("DistMat")->type()),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -66,8 +66,7 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel { ...@@ -66,8 +66,7 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()), ctx.Input<framework::Tensor>("Input")->type(), ctx.GetPlace());
ctx.GetPlace());
} }
}; };
......
...@@ -66,8 +66,7 @@ class GenerateProposalsOp : public framework::OperatorWithKernel { ...@@ -66,8 +66,7 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Anchors")->type(),
framework::ToDataType(ctx.Input<Tensor>("Anchors")->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -249,8 +249,7 @@ class MineHardExamplesOp : public framework::OperatorWithKernel { ...@@ -249,8 +249,7 @@ class MineHardExamplesOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("ClsLoss")->type()), ctx.Input<framework::Tensor>("ClsLoss")->type(), platform::CPUPlace());
platform::CPUPlace());
} }
}; };
......
...@@ -65,8 +65,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { ...@@ -65,8 +65,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( ctx.Input<framework::LoDTensor>("Scores")->type(),
ctx.Input<framework::LoDTensor>("Scores")->type()),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -72,8 +72,7 @@ class PriorBoxOp : public framework::OperatorWithKernel { ...@@ -72,8 +72,7 @@ class PriorBoxOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()), ctx.Input<framework::Tensor>("Input")->type(), ctx.device_context());
ctx.device_context());
} }
}; };
......
...@@ -498,8 +498,7 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel { ...@@ -498,8 +498,7 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
...@@ -519,8 +518,7 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel { ...@@ -519,8 +518,7 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -78,8 +78,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { ...@@ -78,8 +78,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( ctx.Input<framework::LoDTensor>("Anchor")->type(),
ctx.Input<framework::LoDTensor>("Anchor")->type()),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -57,8 +57,7 @@ class TargetAssignOp : public framework::OperatorWithKernel { ...@@ -57,8 +57,7 @@ class TargetAssignOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -71,8 +71,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel { ...@@ -71,8 +71,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( ctx.Input<framework::Tensor>("DetectRes")->type(),
ctx.Input<framework::Tensor>("DetectRes")->type()),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -122,8 +122,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -122,8 +122,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
if (var->IsType<framework::SelectedRows>()) { if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>(); auto* slr = var->GetMutable<framework::SelectedRows>();
ProtoEncodeHelper e2(static_cast<char*>(buf), 128); ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
size_t rows_memory_size = size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size); e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
slices[2] = ::grpc::Slice(e2.size()); slices[2] = ::grpc::Slice(e2.size());
memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size()); memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
......
...@@ -61,8 +61,7 @@ TensorPayload GetTensorPayload(framework::Variable* var, ...@@ -61,8 +61,7 @@ TensorPayload GetTensorPayload(framework::Variable* var,
auto tensor = var->Get<framework::LoDTensor>(); auto tensor = var->Get<framework::LoDTensor>();
// FIXME(wuyi): data types in send_recv.proto is copied from // FIXME(wuyi): data types in send_recv.proto is copied from
// framework.proto // framework.proto
request->set_data_type( request->set_data_type(static_cast<VarMsg::Type>(tensor.type()));
static_cast<VarMsg::Type>(framework::ToDataType(tensor.type())));
for (auto& dim : framework::vectorize(tensor.dims())) { for (auto& dim : framework::vectorize(tensor.dims())) {
request->add_dims(dim); request->add_dims(dim);
} }
...@@ -83,8 +82,7 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var, ...@@ -83,8 +82,7 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
VarMsg* request) { VarMsg* request) {
auto* slr = var->GetMutable<framework::SelectedRows>(); auto* slr = var->GetMutable<framework::SelectedRows>();
request->set_data_type( request->set_data_type(static_cast<VarMsg::Type>(slr->value().type()));
static_cast<VarMsg::Type>(framework::ToDataType(slr->value().type())));
request->set_lod_level(0); request->set_lod_level(0);
request->set_slr_height(slr->height()); request->set_slr_height(slr->height());
......
...@@ -58,18 +58,19 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var, ...@@ -58,18 +58,19 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
VarMsg* request); VarMsg* request);
inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) { inline framework::proto::VarType::Type ToVarType(
sendrecv::VariableMessage::Type type) {
switch (type) { switch (type) {
case sendrecv::VariableMessage::FP32: case sendrecv::VariableMessage::FP32:
return typeid(float); // NOLINT return framework::proto::VarType::FP32; // NOLINT
case sendrecv::VariableMessage::FP64: case sendrecv::VariableMessage::FP64:
return typeid(double); // NOLINT return framework::proto::VarType::FP64; // NOLINT
case sendrecv::VariableMessage::INT32: case sendrecv::VariableMessage::INT32:
return typeid(int); // NOLINT return framework::proto::VarType::INT32; // NOLINT
case sendrecv::VariableMessage::INT64: case sendrecv::VariableMessage::INT64:
return typeid(int64_t); // NOLINT return framework::proto::VarType::INT64; // NOLINT
case sendrecv::VariableMessage::BOOL: case sendrecv::VariableMessage::BOOL:
return typeid(bool); // NOLINT return framework::proto::VarType::BOOL; // NOLINT
default: default:
PADDLE_THROW("Not support type %d", type); PADDLE_THROW("Not support type %d", type);
} }
......
...@@ -114,7 +114,7 @@ bool VariableResponse::CopyLodTensorData( ...@@ -114,7 +114,7 @@ bool VariableResponse::CopyLodTensorData(
tensor->set_lod(lod); tensor->set_lod(lod);
void* tensor_data = void* tensor_data =
tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type())); tensor->mutable_data(ctx.GetPlace(), ToVarType(meta_.data_type()));
VLOG(6) << "Tensor.memory_size = " << tensor->memory_size() VLOG(6) << "Tensor.memory_size = " << tensor->memory_size()
<< ", Buffer Size = " << length; << ", Buffer Size = " << length;
...@@ -139,13 +139,13 @@ bool VariableResponse::CopySelectRowsTensorData( ...@@ -139,13 +139,13 @@ bool VariableResponse::CopySelectRowsTensorData(
slr->set_height(meta_.slr_height()); slr->set_height(meta_.slr_height());
auto* tensor = slr->mutable_value(); auto* tensor = slr->mutable_value();
tensor->Resize(dims); tensor->Resize(dims);
PADDLE_ENFORCE_EQ(static_cast<size_t>(tensor->numel()), PADDLE_ENFORCE_EQ(
length / framework::SizeOfType( static_cast<size_t>(tensor->numel()),
paddle::operators::distributed::ToTypeIndex( length / framework::SizeOfType(paddle::operators::distributed::ToVarType(
meta_.data_type()))); meta_.data_type())));
void* tensor_data = tensor->mutable_data( void* tensor_data = tensor->mutable_data(
ctx.GetPlace(), ctx.GetPlace(),
paddle::operators::distributed::ToTypeIndex(meta_.data_type())); paddle::operators::distributed::ToVarType(meta_.data_type()));
if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) { if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
return false; return false;
...@@ -159,8 +159,7 @@ bool VariableResponse::CopySelectRowsData( ...@@ -159,8 +159,7 @@ bool VariableResponse::CopySelectRowsData(
const platform::DeviceContext& ctx, int length) { const platform::DeviceContext& ctx, int length) {
auto* slr = GetVar()->GetMutable<framework::SelectedRows>(); auto* slr = GetVar()->GetMutable<framework::SelectedRows>();
slr->mutable_rows()->clear(); slr->mutable_rows()->clear();
slr->mutable_rows()->resize(length / slr->mutable_rows()->resize(length / sizeof(int64_t)); // int64
framework::SizeOfType(typeid(int64_t))); // int64
int64_t* rows_data = slr->mutable_rows()->data(); int64_t* rows_data = slr->mutable_rows()->data();
// copy rows CPU data, GPU data will be copied lazily. // copy rows CPU data, GPU data will be copied lazily.
......
...@@ -108,9 +108,7 @@ class MergeIdsOp : public framework::OperatorWithKernel { ...@@ -108,9 +108,7 @@ class MergeIdsOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( ctx.MultiInput<framework::Tensor>("X").front()->type(), ctx.GetPlace());
ctx.MultiInput<framework::Tensor>("X").front()->type()),
ctx.GetPlace());
} }
}; };
......
...@@ -42,9 +42,7 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel { ...@@ -42,9 +42,7 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( ctx.MultiInput<framework::Tensor>("X")[0]->type(), ctx.GetPlace());
ctx.MultiInput<framework::Tensor>("X")[0]->type()),
ctx.GetPlace());
} }
}; };
......
...@@ -197,8 +197,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { ...@@ -197,8 +197,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = framework::ToDataType( auto input_data_type =
ctx.Input<Tensor>(framework::GradVarName("Out"))->type()); ctx.Input<Tensor>(framework::GradVarName("Out"))->type();
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) { if (platform::CanMKLDNNBeUsed(ctx)) {
......
...@@ -115,8 +115,7 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel { ...@@ -115,8 +115,7 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
...@@ -175,8 +174,7 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel { ...@@ -175,8 +174,7 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -79,9 +79,8 @@ framework::OpKernelType FCOp::GetExpectedKernelType( ...@@ -79,9 +79,8 @@ framework::OpKernelType FCOp::GetExpectedKernelType(
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(), ctx.GetPlace(), layout, library);
layout, library);
} }
void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const { void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
...@@ -111,9 +110,8 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType( ...@@ -111,9 +110,8 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(), ctx.GetPlace(), layout, library);
layout, library);
} }
void FCOpMaker::Make() { void FCOpMaker::Make() {
......
...@@ -59,9 +59,9 @@ class FillConstantOp : public framework::OperatorBase { ...@@ -59,9 +59,9 @@ class FillConstantOp : public framework::OperatorBase {
if (force_cpu) { if (force_cpu) {
auto cpu = platform::CPUPlace(); auto cpu = platform::CPUPlace();
tensor->mutable_data(cpu, framework::ToTypeIndex(data_type)); tensor->mutable_data(cpu, data_type);
} else { } else {
tensor->mutable_data(dev_place, framework::ToTypeIndex(data_type)); tensor->mutable_data(dev_place, data_type);
} }
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
......
...@@ -55,7 +55,7 @@ class FillOp : public framework::OperatorBase { ...@@ -55,7 +55,7 @@ class FillOp : public framework::OperatorBase {
static_cast<framework::proto::VarType::Type>(Attr<int>("dtype")); static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
platform::CPUPlace cpu; platform::CPUPlace cpu;
auto force_cpu = Attr<bool>("force_cpu"); auto force_cpu = Attr<bool>("force_cpu");
out.mutable_data(force_cpu ? cpu : place, framework::ToTypeIndex(dtype)); out.mutable_data(force_cpu ? cpu : place, dtype);
framework::LoDTensor tensor; framework::LoDTensor tensor;
...@@ -64,7 +64,7 @@ class FillOp : public framework::OperatorBase { ...@@ -64,7 +64,7 @@ class FillOp : public framework::OperatorBase {
} else { } else {
// Always make tensor in CPU memory. // Always make tensor in CPU memory.
tensor.Resize(out.dims()); tensor.Resize(out.dims());
tensor.mutable_data(cpu, framework::ToTypeIndex(dtype)); tensor.mutable_data(cpu, dtype);
} }
framework::VisitDataType( framework::VisitDataType(
......
...@@ -135,9 +135,8 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel { ...@@ -135,9 +135,8 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx.Input<framework::Tensor>("X")->type(), PADDLE_ENFORCE_EQ(ctx.Input<framework::Tensor>("X")->type(),
ctx.Input<framework::Tensor>("Y")->type(), ctx.Input<framework::Tensor>("Y")->type(),
"The element's type of input should be the same."); "The element's type of input should be the same.");
auto input_data_type = return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()); ctx.GetPlace());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
...@@ -324,9 +323,8 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel { ...@@ -324,9 +323,8 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type_index = ctx.Input<framework::Tensor>("Y")->type(); return framework::OpKernelType(ctx.Input<framework::Tensor>("Y")->type(),
auto input_data_type = framework::ToDataType(input_data_type_index); ctx.GetPlace());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -115,8 +115,7 @@ void FusedEmbeddingFCLSTMOp::InferShape( ...@@ -115,8 +115,7 @@ void FusedEmbeddingFCLSTMOp::InferShape(
framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType( framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( ctx.Input<framework::LoDTensor>("Embeddings")->type(),
ctx.Input<framework::LoDTensor>("Embeddings")->type()),
ctx.device_context()); ctx.device_context());
} }
......
...@@ -93,8 +93,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -93,8 +93,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FusionGRUOp::GetExpectedKernelType( framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context()); ctx.device_context());
} }
......
...@@ -117,8 +117,7 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -117,8 +117,7 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FusionLSTMOp::GetExpectedKernelType( framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context()); ctx.device_context());
} }
......
...@@ -61,8 +61,7 @@ void FusionSeqConvEltAddReluOp::InferShape( ...@@ -61,8 +61,7 @@ void FusionSeqConvEltAddReluOp::InferShape(
framework::OpKernelType FusionSeqConvEltAddReluOp::GetExpectedKernelType( framework::OpKernelType FusionSeqConvEltAddReluOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context()); ctx.device_context());
} }
......
...@@ -67,8 +67,7 @@ void FusionSeqExpandConcatFCOp::InferShape( ...@@ -67,8 +67,7 @@ void FusionSeqExpandConcatFCOp::InferShape(
framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType( framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType( return framework::OpKernelType(ctx.MultiInput<LoDTensor>("X")[0]->type(),
framework::ToDataType(ctx.MultiInput<LoDTensor>("X")[0]->type()),
ctx.device_context()); ctx.device_context());
} }
......
...@@ -42,8 +42,7 @@ class GatherOp : public framework::OperatorWithKernel { ...@@ -42,8 +42,7 @@ class GatherOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
...@@ -60,8 +59,7 @@ class GatherGradOp : public framework::OperatorWithKernel { ...@@ -60,8 +59,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -63,8 +63,8 @@ class GridSampleOp : public framework::OperatorWithKernel { ...@@ -63,8 +63,8 @@ class GridSampleOp : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(), ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_); framework::DataLayout::kAnyLayout, library_);
} }
}; };
...@@ -159,8 +159,8 @@ class GridSampleOpGrad : public framework::OperatorWithKernel { ...@@ -159,8 +159,8 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(), ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_); framework::DataLayout::kAnyLayout, library_);
} }
}; };
......
...@@ -141,8 +141,7 @@ class GroupNormGradOp : public framework::OperatorWithKernel { ...@@ -141,8 +141,7 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
if (t == nullptr) { if (t == nullptr) {
PADDLE_THROW("can't find Y@GRAD"); PADDLE_THROW("can't find Y@GRAD");
} }
return framework::OpKernelType(framework::ToDataType(t->type()), return framework::OpKernelType(t->type(), ctx.GetPlace());
ctx.GetPlace());
} }
}; };
......
...@@ -76,8 +76,7 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel { ...@@ -76,8 +76,7 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.GetPlace()); ctx.GetPlace());
} }
}; };
...@@ -162,8 +161,7 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { ...@@ -162,8 +161,7 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -55,8 +55,8 @@ class InterpolateOp : public framework::OperatorWithKernel { ...@@ -55,8 +55,8 @@ class InterpolateOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace()); ctx.GetPlace());
} }
}; };
...@@ -124,8 +124,8 @@ class InterpolateOpGrad : public framework::OperatorWithKernel { ...@@ -124,8 +124,8 @@ class InterpolateOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -35,8 +35,7 @@ class IsEmptyOp : public framework::OperatorWithKernel { ...@@ -35,8 +35,7 @@ class IsEmptyOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = framework::OpKernelType( framework::OpKernelType kt = framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), ctx.Input<framework::LoDTensor>("X")->type(), platform::CPUPlace());
platform::CPUPlace());
return kt; return kt;
} }
}; };
......
...@@ -40,10 +40,9 @@ class OverflowOp : public framework::OperatorWithKernel { ...@@ -40,10 +40,9 @@ class OverflowOp : public framework::OperatorWithKernel {
int dtype = -1; int dtype = -1;
auto *x_var = ctx.InputVar("X"); auto *x_var = ctx.InputVar("X");
if (x_var->IsType<framework::LoDTensor>()) { if (x_var->IsType<framework::LoDTensor>()) {
dtype = framework::ToDataType(x_var->Get<framework::LoDTensor>().type()); dtype = x_var->Get<framework::LoDTensor>().type();
} else if (x_var->IsType<framework::SelectedRows>()) { } else if (x_var->IsType<framework::SelectedRows>()) {
dtype = framework::ToDataType( dtype = x_var->Get<framework::SelectedRows>().value().type();
x_var->Get<framework::SelectedRows>().value().type());
} else { } else {
PADDLE_THROW("Cannot find the input data type by all input data"); PADDLE_THROW("Cannot find the input data type by all input data");
} }
......
...@@ -153,8 +153,7 @@ class LayerNormGradOp : public framework::OperatorWithKernel { ...@@ -153,8 +153,7 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
if (t == nullptr) { if (t == nullptr) {
PADDLE_THROW("can't find Y@GRAD"); PADDLE_THROW("can't find Y@GRAD");
} }
return framework::OpKernelType(framework::ToDataType(t->type()), return framework::OpKernelType(t->type(), ctx.GetPlace());
ctx.GetPlace());
} }
}; };
......
...@@ -184,8 +184,7 @@ class LinearChainCRFOp : public framework::OperatorWithKernel { ...@@ -184,8 +184,7 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
// is determined by its input "Emission". // is determined by its input "Emission".
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<LoDTensor>("Emission")->type(),
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
...@@ -244,9 +243,7 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel { ...@@ -244,9 +243,7 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( ctx.Input<LoDTensor>(framework::GradVarName("LogLikelihood"))->type(),
ctx.Input<LoDTensor>(framework::GradVarName("LogLikelihood"))
->type()),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -60,7 +60,7 @@ class LoadCombineOp : public framework::OperatorBase { ...@@ -60,7 +60,7 @@ class LoadCombineOp : public framework::OperatorBase {
// Get data from fin to tensor // Get data from fin to tensor
DeserializeFromStream(fin, tensor, dev_ctx); DeserializeFromStream(fin, tensor, dev_ctx);
auto in_dtype = framework::ToDataType(tensor->type()); auto in_dtype = tensor->type();
auto out_dtype = auto out_dtype =
load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
......
...@@ -65,7 +65,7 @@ class LoadOp : public framework::OperatorBase { ...@@ -65,7 +65,7 @@ class LoadOp : public framework::OperatorBase {
DeserializeFromStream(fin, tensor, dev_ctx); DeserializeFromStream(fin, tensor, dev_ctx);
auto load_as_fp16 = Attr<bool>("load_as_fp16"); auto load_as_fp16 = Attr<bool>("load_as_fp16");
auto in_dtype = framework::ToDataType(tensor->type()); auto in_dtype = tensor->type();
auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) { if (in_dtype != out_dtype) {
......
...@@ -39,8 +39,7 @@ class LoDResetOp : public framework::OperatorWithKernel { ...@@ -39,8 +39,7 @@ class LoDResetOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
...@@ -144,8 +143,7 @@ class LoDResetGradOp : public framework::OperatorWithKernel { ...@@ -144,8 +143,7 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -72,7 +72,7 @@ struct LoDTensorToArrayFunctor : public boost::static_visitor<void> { ...@@ -72,7 +72,7 @@ struct LoDTensorToArrayFunctor : public boost::static_visitor<void> {
LoDTensorToArrayFunctorImpl<DeviceContext> func; LoDTensorToArrayFunctorImpl<DeviceContext> func;
func.prev_functor_ = this; func.prev_functor_ = this;
func.dev_ctx_ = dev_ctx; func.dev_ctx_ = dev_ctx;
framework::VisitDataType(framework::ToDataType(input_.type()), func); framework::VisitDataType(input_.type(), func);
} }
}; };
......
...@@ -63,8 +63,7 @@ class LookupSparseTableOp : public framework::OperatorBase { ...@@ -63,8 +63,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
out_shape[0] = ids_t.numel(); out_shape[0] = ids_t.numel();
out_t->Resize(out_shape); out_t->Resize(out_shape);
out_t->mutable_data(cpu, w_t->value().type()); out_t->mutable_data(cpu, w_t->value().type());
PADDLE_ENFORCE_EQ(framework::ToDataType(w_t->value().type()), PADDLE_ENFORCE_EQ(w_t->value().type(), framework::proto::VarType::FP32,
framework::proto::VarType::FP32,
"The sparse table only support FP32"); "The sparse table only support FP32");
w_t->Get(ids_t, out_t, true, is_test); w_t->Get(ids_t, out_t, true, is_test);
out_t->set_lod(ids_t.lod()); out_t->set_lod(ids_t.lod());
......
...@@ -145,8 +145,7 @@ framework::OpKernelType GetExpectedLRNKernel( ...@@ -145,8 +145,7 @@ framework::OpKernelType GetExpectedLRNKernel(
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), ctx.GetPlace(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
layout_, library_); layout_, library_);
} }
} // namespace } // namespace
......
...@@ -96,8 +96,7 @@ class LSTMOp : public framework::OperatorWithKernel { ...@@ -96,8 +96,7 @@ class LSTMOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()), ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context());
ctx.device_context());
} }
}; };
...@@ -261,8 +260,7 @@ class LSTMGradOp : public framework::OperatorWithKernel { ...@@ -261,8 +260,7 @@ class LSTMGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()), ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context());
ctx.device_context());
} }
}; };
......
...@@ -113,8 +113,7 @@ class LSTMPOp : public framework::OperatorWithKernel { ...@@ -113,8 +113,7 @@ class LSTMPOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()), ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context());
ctx.device_context());
} }
}; };
...@@ -312,8 +311,7 @@ class LSTMPGradOp : public framework::OperatorWithKernel { ...@@ -312,8 +311,7 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()), ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context());
ctx.device_context());
} }
}; };
......
...@@ -77,16 +77,14 @@ template <> ...@@ -77,16 +77,14 @@ template <>
void set_constant_with_place<platform::CPUPlace>( void set_constant_with_place<platform::CPUPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor, const platform::DeviceContext& context, framework::Tensor* tensor,
float value) { float value) {
framework::VisitDataType(framework::ToDataType(tensor->type()), framework::VisitDataType(tensor->type(), TensorSetConstantCPU(tensor, value));
TensorSetConstantCPU(tensor, value));
} }
template <> template <>
void set_constant_with_place<platform::CUDAPinnedPlace>( void set_constant_with_place<platform::CUDAPinnedPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor, const platform::DeviceContext& context, framework::Tensor* tensor,
float value) { float value) {
framework::VisitDataType(framework::ToDataType(tensor->type()), framework::VisitDataType(tensor->type(), TensorSetConstantCPU(tensor, value));
TensorSetConstantCPU(tensor, value));
} }
struct TensorSetConstantWithPlace : public boost::static_visitor<void> { struct TensorSetConstantWithPlace : public boost::static_visitor<void> {
......
...@@ -67,7 +67,7 @@ template <> ...@@ -67,7 +67,7 @@ template <>
void set_constant_with_place<platform::CUDAPlace>( void set_constant_with_place<platform::CUDAPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor, const platform::DeviceContext& context, framework::Tensor* tensor,
float value) { float value) {
framework::VisitDataType(framework::ToDataType(tensor->type()), framework::VisitDataType(tensor->type(),
TensorSetConstantGPU(context, tensor, value)); TensorSetConstantGPU(context, tensor, value));
} }
......
...@@ -44,8 +44,7 @@ class MeanIoUOp : public framework::OperatorWithKernel { ...@@ -44,8 +44,7 @@ class MeanIoUOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Predictions")->type(),
framework::ToDataType(ctx.Input<Tensor>("Predictions")->type()),
ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -61,9 +61,7 @@ class MeanGradOp : public framework::OperatorWithKernel { ...@@ -61,9 +61,7 @@ class MeanGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = auto input_data_type = ctx.Input<Tensor>("X")->type();
framework::ToDataType(ctx.Input<Tensor>("X")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
......
...@@ -63,9 +63,7 @@ class MergeLoDTensorOp : public framework::OperatorBase { ...@@ -63,9 +63,7 @@ class MergeLoDTensorOp : public framework::OperatorBase {
platform::Place place = dev_place; platform::Place place = dev_place;
int64_t batch_size = in_true.dims()[0] + in_false.dims()[0]; int64_t batch_size = in_true.dims()[0] + in_false.dims()[0];
auto data_type = in_true.IsInitialized() ? in_true.type() : in_false.type();
std::type_index data_type =
in_true.IsInitialized() ? in_true.type() : in_false.type();
int rank; int rank;
framework::DDim in_dims; framework::DDim in_dims;
if (in_true.IsInitialized()) { if (in_true.IsInitialized()) {
......
...@@ -55,8 +55,7 @@ class AccuracyOp : public framework::OperatorWithKernel { ...@@ -55,8 +55,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Out")->type(),
framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -51,8 +51,7 @@ class AucOp : public framework::OperatorWithKernel { ...@@ -51,8 +51,7 @@ class AucOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Predict")->type(),
framework::ToDataType(ctx.Input<Tensor>("Predict")->type()),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -82,8 +82,7 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { ...@@ -82,8 +82,7 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("MaxProbs")->type(),
framework::ToDataType(ctx.Input<Tensor>("MaxProbs")->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -53,8 +53,7 @@ class MultiplexOp : public framework::OperatorWithKernel { ...@@ -53,8 +53,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.MultiInput<Tensor>("X")[0]->type(),
framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
...@@ -123,8 +122,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel { ...@@ -123,8 +122,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.MultiInput<Tensor>("X")[0]->type(),
framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -69,8 +69,7 @@ class NCEOp : public framework::OperatorWithKernel { ...@@ -69,8 +69,7 @@ class NCEOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
...@@ -214,8 +213,7 @@ class NCEOpGrad : public framework::OperatorWithKernel { ...@@ -214,8 +213,7 @@ class NCEOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -70,9 +70,8 @@ class AdadeltaOp : public framework::OperatorWithKernel { ...@@ -70,9 +70,8 @@ class AdadeltaOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
framework::ToDataType(ctx.Input<Tensor>("Param")->type()); ctx.GetPlace());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
......
...@@ -59,9 +59,8 @@ class AdagradOp : public framework::OperatorWithKernel { ...@@ -59,9 +59,8 @@ class AdagradOp : public framework::OperatorWithKernel {
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
framework::ToDataType(ctx.Input<Tensor>("Param")->type()); ctx.GetPlace());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
......
...@@ -75,8 +75,7 @@ class AdamOp : public framework::OperatorWithKernel { ...@@ -75,8 +75,7 @@ class AdamOp : public framework::OperatorWithKernel {
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto input_data_type = ctx.Input<Tensor>("Param")->type();
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册