未验证 提交 a607b6c8 编写于 作者: Q Qiyang Min 提交者: GitHub

Merge pull request #15216 from velconia/local_release_1_2_x_tensor_type

Feature/tensor type
...@@ -86,6 +86,7 @@ endif(NOT WITH_GOLANG) ...@@ -86,6 +86,7 @@ endif(NOT WITH_GOLANG)
if(WITH_GPU) if(WITH_GPU)
add_definitions(-DPADDLE_WITH_CUDA) add_definitions(-DPADDLE_WITH_CUDA)
add_definitions(-DEIGEN_USE_GPU)
FIND_PACKAGE(CUDA REQUIRED) FIND_PACKAGE(CUDA REQUIRED)
......
...@@ -85,7 +85,7 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var, ...@@ -85,7 +85,7 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
out->mutable_data(expected_kernel_type.place_, in.type()); out->mutable_data(expected_kernel_type.place_, in.type());
framework::VisitDataType( framework::VisitDataType(
framework::ToDataType(in.type()), in.type(),
CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out)); CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out));
out->set_layout(expected_kernel_type.data_layout_); out->set_layout(expected_kernel_type.data_layout_);
...@@ -101,7 +101,7 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) { ...@@ -101,7 +101,7 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
case mkldnn::memory::data_type::f32: case mkldnn::memory::data_type::f32:
return platform::to_void_cast(tensor.data<float>()); return platform::to_void_cast(tensor.data<float>());
case mkldnn::memory::data_type::s8: case mkldnn::memory::data_type::s8:
return platform::to_void_cast(tensor.data<char>()); return platform::to_void_cast(tensor.data<int8_t>());
case mkldnn::memory::data_type::u8: case mkldnn::memory::data_type::u8:
return platform::to_void_cast(tensor.data<unsigned char>()); return platform::to_void_cast(tensor.data<unsigned char>());
case mkldnn::memory::data_type::s16: case mkldnn::memory::data_type::s16:
...@@ -144,7 +144,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, ...@@ -144,7 +144,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
memory::data_type in_type = ToMKLDNNDataType(in.type()); memory::data_type in_type = ToMKLDNNDataType(in.type());
PADDLE_ENFORCE(in_type != memory::data_type::data_undef, PADDLE_ENFORCE(in_type != memory::data_type::data_undef,
"Input tensor type is not supported: ", in.type().name()); "Input tensor type is not supported: %s", in.type());
memory::data_type out_type = in_type; memory::data_type out_type = in_type;
auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format()); auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
......
...@@ -50,14 +50,14 @@ inline DataLayout ToPaddleLayout(const MKLDNNFormat& format) { ...@@ -50,14 +50,14 @@ inline DataLayout ToPaddleLayout(const MKLDNNFormat& format) {
} }
} }
inline MKLDNNDataType ToMKLDNNDataType(const std::type_index type) { inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
static const std::map<std::type_index, MKLDNNDataType> dict{ static std::unordered_map<int, MKLDNNDataType> dict{
{std::type_index(typeid(float)), MKLDNNDataType::f32}, // NOLINT {DataTypeTrait<float>::DataType, MKLDNNDataType::f32},
{std::type_index(typeid(char)), MKLDNNDataType::s8}, // NOLINT {DataTypeTrait<int8_t>::DataType, MKLDNNDataType::s8},
{std::type_index(typeid(unsigned char)), MKLDNNDataType::u8}, {DataTypeTrait<uint8_t>::DataType, MKLDNNDataType::u8},
{std::type_index(typeid(int16_t)), MKLDNNDataType::s16}, {DataTypeTrait<int16_t>::DataType, MKLDNNDataType::s16},
{std::type_index(typeid(int32_t)), MKLDNNDataType::s32}}; {DataTypeTrait<int32_t>::DataType, MKLDNNDataType::s32}};
auto iter = dict.find(type); auto iter = dict.find(static_cast<int>(type));
if (iter != dict.end()) return iter->second; if (iter != dict.end()) return iter->second;
return MKLDNNDataType::data_undef; return MKLDNNDataType::data_undef;
} }
......
...@@ -26,7 +26,7 @@ struct DataTypeMap { ...@@ -26,7 +26,7 @@ struct DataTypeMap {
std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_; std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_;
std::unordered_map<int, std::type_index> proto_to_cpp_; std::unordered_map<int, std::type_index> proto_to_cpp_;
std::unordered_map<int, std::string> proto_to_str_; std::unordered_map<int, std::string> proto_to_str_;
std::unordered_map<std::type_index, size_t> cpp_to_size_; std::unordered_map<int, size_t> proto_to_size_;
}; };
static DataTypeMap* InitDataTypeMap(); static DataTypeMap* InitDataTypeMap();
...@@ -45,7 +45,7 @@ static inline void RegisterType(DataTypeMap* map, ...@@ -45,7 +45,7 @@ static inline void RegisterType(DataTypeMap* map,
map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T)); map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T));
map->cpp_to_proto_.emplace(typeid(T), proto_type); map->cpp_to_proto_.emplace(typeid(T), proto_type);
map->proto_to_str_.emplace(static_cast<int>(proto_type), name); map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
map->cpp_to_size_.emplace(typeid(T), sizeof(T)); map->proto_to_size_.emplace(static_cast<int>(proto_type), sizeof(T));
} }
static DataTypeMap* InitDataTypeMap() { static DataTypeMap* InitDataTypeMap() {
...@@ -54,17 +54,7 @@ static DataTypeMap* InitDataTypeMap() { ...@@ -54,17 +54,7 @@ static DataTypeMap* InitDataTypeMap() {
#define RegType(cc_type, proto_type) \ #define RegType(cc_type, proto_type) \
RegisterType<cc_type>(retv, proto_type, #cc_type) RegisterType<cc_type>(retv, proto_type, #cc_type)
// NOTE: Add your customize type here. _ForEachDataType_(RegType);
RegType(float16, proto::VarType::FP16);
RegType(float, proto::VarType::FP32);
RegType(double, proto::VarType::FP64);
RegType(int, proto::VarType::INT32);
RegType(int64_t, proto::VarType::INT64);
RegType(bool, proto::VarType::BOOL);
RegType(size_t, proto::VarType::SIZE_T);
RegType(int16_t, proto::VarType::INT16);
RegType(uint8_t, proto::VarType::UINT8);
RegType(int8_t, proto::VarType::INT8);
#undef RegType #undef RegType
return retv; return retv;
...@@ -96,12 +86,12 @@ std::string DataTypeToString(const proto::VarType::Type type) { ...@@ -96,12 +86,12 @@ std::string DataTypeToString(const proto::VarType::Type type) {
static_cast<int>(type)); static_cast<int>(type));
} }
size_t SizeOfType(std::type_index type) { size_t SizeOfType(proto::VarType::Type type) {
auto it = gDataTypeMap().cpp_to_size_.find(type); auto it = gDataTypeMap().proto_to_size_.find(static_cast<int>(type));
if (it != gDataTypeMap().cpp_to_size_.end()) { if (it != gDataTypeMap().proto_to_size_.end()) {
return it->second; return it->second;
} }
PADDLE_THROW("Not support %s as tensor type", type.name()); PADDLE_THROW("Not support %s as tensor type", DataTypeToString(type));
} }
} // namespace framework } // namespace framework
......
...@@ -22,46 +22,59 @@ limitations under the License. */ ...@@ -22,46 +22,59 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
template <typename T>
struct DataTypeTrait {};
// Stub handle for void
template <>
struct DataTypeTrait<void> {
constexpr static auto DataType = proto::VarType::RAW;
};
#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
#define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, bool, BOOL); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8)
#define DefineDataTypeTrait(cpp_type, proto_type) \
template <> \
struct DataTypeTrait<cpp_type> { \
constexpr static auto DataType = proto_type; \
}
_ForEachDataType_(DefineDataTypeTrait);
#undef DefineDataTypeTrait
extern proto::VarType::Type ToDataType(std::type_index type); extern proto::VarType::Type ToDataType(std::type_index type);
extern std::type_index ToTypeIndex(proto::VarType::Type type); extern std::type_index ToTypeIndex(proto::VarType::Type type);
template <typename Visitor> template <typename Visitor>
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
switch (type) { #define VisitDataTypeCallback(cpp_type, proto_type) \
case proto::VarType::FP16: do { \
visitor.template apply<platform::float16>(); if (type == proto_type) { \
break; visitor.template apply<cpp_type>(); \
case proto::VarType::FP32: return; \
visitor.template apply<float>(); } \
break; } while (0)
case proto::VarType::FP64:
visitor.template apply<double>(); _ForEachDataType_(VisitDataTypeCallback);
break; #undef VisitDataTypeCallback
case proto::VarType::INT32: PADDLE_THROW("Not supported %d", type);
visitor.template apply<int>();
break;
case proto::VarType::INT64:
visitor.template apply<int64_t>();
break;
case proto::VarType::BOOL:
visitor.template apply<bool>();
break;
case proto::VarType::UINT8:
visitor.template apply<uint8_t>();
break;
case proto::VarType::INT16:
visitor.template apply<int16_t>();
break;
case proto::VarType::INT8:
visitor.template apply<int8_t>();
break;
default:
PADDLE_THROW("Not supported %d", type);
}
} }
extern std::string DataTypeToString(const proto::VarType::Type type); extern std::string DataTypeToString(const proto::VarType::Type type);
extern size_t SizeOfType(std::type_index type); extern size_t SizeOfType(proto::VarType::Type type);
inline std::ostream& operator<<(std::ostream& out, inline std::ostream& operator<<(std::ostream& out,
const proto::VarType::Type& type) { const proto::VarType::Type& type) {
out << DataTypeToString(type); out << DataTypeToString(type);
......
...@@ -26,15 +26,15 @@ TEST(DataType, float16) { ...@@ -26,15 +26,15 @@ TEST(DataType, float16) {
Tensor tensor; Tensor tensor;
CPUPlace cpu; CPUPlace cpu;
tensor.mutable_data(cpu, f::ToTypeIndex(dtype)); tensor.mutable_data(cpu, dtype);
// test fp16 tensor // test fp16 tensor
EXPECT_EQ(tensor.type(), std::type_index(typeid(float16))); EXPECT_EQ(tensor.type(), f::ToDataType(typeid(float16)));
// test fp16 size // test fp16 size
EXPECT_EQ(f::SizeOfType(f::ToTypeIndex(dtype)), 2u); EXPECT_EQ(f::SizeOfType(dtype), 2u);
// test debug info // test debug info
std::string type = "float16"; std::string type = "::paddle::platform::float16";
EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str()); EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
} }
...@@ -120,7 +120,7 @@ void AllReduceOpHandle::RunImpl() { ...@@ -120,7 +120,7 @@ void AllReduceOpHandle::RunImpl() {
// Reduce All Tensor to trg in CPU // Reduce All Tensor to trg in CPU
ReduceLoDTensor func(lod_tensors, &trg); ReduceLoDTensor func(lod_tensors, &trg);
VisitDataType(ToDataType(lod_tensors[0]->type()), func); VisitDataType(lod_tensors[0]->type(), func);
for (size_t i = 1; i < local_scopes_.size(); ++i) { for (size_t i = 1; i < local_scopes_.size(); ++i) {
auto &scope = auto &scope =
......
...@@ -33,7 +33,7 @@ struct FuseVarsOpHandle : public OpHandleBase { ...@@ -33,7 +33,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
FuseVarsOpHandle(ir::Node *node, Scope *local_scope, FuseVarsOpHandle(ir::Node *node, Scope *local_scope,
const platform::Place &place, const platform::Place &place,
const std::unordered_map<std::string, int64_t> &inputs_numel, const std::unordered_map<std::string, int64_t> &inputs_numel,
const std::type_index &var_type) const proto::VarType::Type var_type)
: OpHandleBase(node), : OpHandleBase(node),
local_scope_(local_scope), local_scope_(local_scope),
place_(place), place_(place),
...@@ -57,7 +57,7 @@ struct FuseVarsOpHandle : public OpHandleBase { ...@@ -57,7 +57,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
Scope *local_scope_; Scope *local_scope_;
const platform::Place place_; const platform::Place place_;
const std::unordered_map<std::string, int64_t> inputs_numel_; const std::unordered_map<std::string, int64_t> inputs_numel_;
const std::type_index type_; const proto::VarType::Type type_;
int64_t total_numel_; int64_t total_numel_;
}; };
} // namespace details } // namespace details
......
...@@ -106,7 +106,7 @@ void ReduceOpHandle::RunImpl() { ...@@ -106,7 +106,7 @@ void ReduceOpHandle::RunImpl() {
if (!FLAGS_cpu_deterministic) { if (!FLAGS_cpu_deterministic) {
ReduceLoDTensor func(lod_tensors, ReduceLoDTensor func(lod_tensors,
out_var->GetMutable<framework::LoDTensor>()); out_var->GetMutable<framework::LoDTensor>());
VisitDataType(ToDataType(lod_tensors[0]->type()), func); VisitDataType(lod_tensors[0]->type(), func);
} else { } else {
// We sum lod_tensors to reduce_sum_trg which is in local_scopes_0 // We sum lod_tensors to reduce_sum_trg which is in local_scopes_0
// here, but it doesn't mean reduce_sum_trg must be in local_scopes_0. // here, but it doesn't mean reduce_sum_trg must be in local_scopes_0.
...@@ -116,7 +116,7 @@ void ReduceOpHandle::RunImpl() { ...@@ -116,7 +116,7 @@ void ReduceOpHandle::RunImpl() {
->FindVar(out_var_handle->name_) ->FindVar(out_var_handle->name_)
->GetMutable<framework::LoDTensor>(); ->GetMutable<framework::LoDTensor>();
ReduceLoDTensor func(lod_tensors, &reduce_sum_trg); ReduceLoDTensor func(lod_tensors, &reduce_sum_trg);
VisitDataType(ToDataType(lod_tensors[0]->type()), func); VisitDataType(lod_tensors[0]->type(), func);
auto trg = out_var->GetMutable<framework::LoDTensor>(); auto trg = out_var->GetMutable<framework::LoDTensor>();
if (reduce_sum_trg.data<void>() != trg->data<void>()) { if (reduce_sum_trg.data<void>() != trg->data<void>()) {
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/dlpack_tensor.h" #include "paddle/fluid/framework/dlpack_tensor.h"
#include "paddle/fluid/framework/data_type.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -36,26 +36,23 @@ static ::DLDataType GetDLDataTypeCode() { ...@@ -36,26 +36,23 @@ static ::DLDataType GetDLDataTypeCode() {
return dtype; return dtype;
} }
static DLDataType GetDLDataTypeFromTypeIndex(const std::type_index &type) { static std::unordered_map<int, ::DLDataType> CreateDLDataTypeMap() {
#define REG_DL_DATA_TYPE(type) \ static std::unordered_map<int, ::DLDataType> result;
{ std::type_index(typeid(type)), GetDLDataTypeCode<type>() }
static const std::unordered_map<std::type_index, ::DLDataType> #define REG_DL_DATA_TYPE(cpp_type, proto_type) \
type_to_dtype_map({ result[static_cast<int>(proto_type)] = GetDLDataTypeCode<cpp_type>()
REG_DL_DATA_TYPE(platform::float16), // NOLINT
REG_DL_DATA_TYPE(float), // NOLINT _ForEachDataType_(REG_DL_DATA_TYPE);
REG_DL_DATA_TYPE(double), // NOLINT #undef REG_DL_DATA_TYPE
REG_DL_DATA_TYPE(int), // NOLINT return result;
REG_DL_DATA_TYPE(int64_t), // NOLINT }
REG_DL_DATA_TYPE(bool), // NOLINT
REG_DL_DATA_TYPE(size_t), // NOLINT static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
REG_DL_DATA_TYPE(int16_t), // NOLINT static auto type_to_dtype_map = CreateDLDataTypeMap();
REG_DL_DATA_TYPE(uint8_t), // NOLINT
REG_DL_DATA_TYPE(int8_t) // NOLINT
});
static auto type_to_dtype_map_end_it = type_to_dtype_map.end(); static auto type_to_dtype_map_end_it = type_to_dtype_map.end();
auto it = type_to_dtype_map.find(type); auto it = type_to_dtype_map.find(static_cast<int>(type));
PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %s", PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %d",
type.name()); type);
return it->second; return it->second;
#undef REG_DL_DATA_TYPE #undef REG_DL_DATA_TYPE
} }
......
...@@ -91,23 +91,11 @@ void TestMainLoop() { ...@@ -91,23 +91,11 @@ void TestMainLoop() {
} }
} }
} }
TEST(dlpack, test_all) {
#define TestCallback(cpp_type, proto_type) TestMainLoop<cpp_type>()
#define PADDLE_DLPACK_TEST(type) \ _ForEachDataType_(TestCallback);
TEST(dlpack, test_##type) { TestMainLoop<type>(); } }
using float16 = platform::float16;
PADDLE_DLPACK_TEST(float16);
PADDLE_DLPACK_TEST(float);
PADDLE_DLPACK_TEST(double);
PADDLE_DLPACK_TEST(int);
PADDLE_DLPACK_TEST(int64_t);
PADDLE_DLPACK_TEST(bool);
PADDLE_DLPACK_TEST(size_t);
PADDLE_DLPACK_TEST(int16_t);
PADDLE_DLPACK_TEST(uint8_t);
PADDLE_DLPACK_TEST(int8_t);
#undef PADDLE_DLPACK_TEST
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -138,39 +138,19 @@ void print_lod_tensor(std::string var_name, const LoDTensor& lod_tensor) { ...@@ -138,39 +138,19 @@ void print_lod_tensor(std::string var_name, const LoDTensor& lod_tensor) {
std::cout << sstream.str() << std::endl; std::cout << sstream.str() << std::endl;
} }
void print_fetch_var(Scope* scope, std::string var_name) { static void print_fetch_var(Scope* scope, const std::string& var_name) {
const LoDTensor& tensor = scope->FindVar(var_name)->Get<LoDTensor>(); auto& tensor = scope->FindVar(var_name)->Get<LoDTensor>();
if (std::type_index(tensor.type()) == #define PrintLoDTensorCallback(cpp_type, proto_type) \
std::type_index(typeid(platform::float16))) { do { \
print_lod_tensor<platform::float16>(var_name, tensor); if (tensor.type() == proto_type) { \
} else if (std::type_index(tensor.type()) == std::type_index(typeid(float))) { print_lod_tensor<cpp_type>(var_name, tensor); \
print_lod_tensor<float>(var_name, tensor); return; \
} else if (std::type_index(tensor.type()) == } \
std::type_index(typeid(double))) { } while (0)
print_lod_tensor<double>(var_name, tensor);
} else if (std::type_index(tensor.type()) == std::type_index(typeid(int))) { _ForEachDataType_(PrintLoDTensorCallback);
print_lod_tensor<int>(var_name, tensor); VLOG(1) << "print_fetch_var: unrecognized data type:" << tensor.type();
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(int64_t))) {
print_lod_tensor<int64_t>(var_name, tensor);
} else if (std::type_index(tensor.type()) == std::type_index(typeid(bool))) {
print_lod_tensor<bool>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(uint8_t))) {
print_lod_tensor<uint8_t>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(int16_t))) {
print_lod_tensor<int16_t>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(int8_t))) {
print_lod_tensor<int8_t>(var_name, tensor);
} else {
VLOG(1) << "print_fetch_var: unrecognized data type:"
<< tensor.type().name();
}
return;
} }
void ExecutorThreadWorker::TrainFiles() { void ExecutorThreadWorker::TrainFiles() {
......
...@@ -70,9 +70,9 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) { ...@@ -70,9 +70,9 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
// only print first ten elements // only print first ten elements
int64_t size = t.numel() < 10 ? t.numel() : 10; int64_t size = t.numel() < 10 ? t.numel() : 10;
for (int64_t i = 0; i < size; ++i) { for (int64_t i = 0; i < size; ++i) {
if (IsType<float>(t.type())) { if (t.type() == proto::VarType::FP32) {
os << t.data<float>()[i] << " "; os << t.data<float>()[i] << " ";
} else if (IsType<int64_t>(t.type())) { } else if (t.type() == proto::VarType::INT64) {
os << t.data<int64_t>()[i] << " "; os << t.data<int64_t>()[i] << " ";
} else { } else {
PADDLE_THROW("LoDTensor data type not in [float, int64_t]"); PADDLE_THROW("LoDTensor data type not in [float, int64_t]");
...@@ -387,7 +387,7 @@ void LoDTensor::MergeLoDTensor( ...@@ -387,7 +387,7 @@ void LoDTensor::MergeLoDTensor(
PADDLE_ENFORCE(!lod_tensors.empty()); PADDLE_ENFORCE(!lod_tensors.empty());
framework::DDim new_dim = lod_tensors[0]->dims(); framework::DDim new_dim = lod_tensors[0]->dims();
std::type_index new_type = lod_tensors[0]->type(); auto new_type = lod_tensors[0]->type();
framework::DataLayout new_layout = lod_tensors[0]->layout(); framework::DataLayout new_layout = lod_tensors[0]->layout();
LoD new_lod = lod_tensors[0]->lod(); LoD new_lod = lod_tensors[0]->lod();
for (size_t i = 1; i < lod_tensors.size(); ++i) { for (size_t i = 1; i < lod_tensors.size(); ++i) {
......
...@@ -34,7 +34,8 @@ TEST(OpKernelType, ToString) { ...@@ -34,7 +34,8 @@ TEST(OpKernelType, ToString) {
OpKernelType op_kernel_type2(DataType::FP16, CUDAPlace(0), DataLayout::kNCHW, OpKernelType op_kernel_type2(DataType::FP16, CUDAPlace(0), DataLayout::kNCHW,
LibraryType::kCUDNN); LibraryType::kCUDNN);
ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type2), ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type2),
"data_type[float16]:data_layout[NCHW]:place[CUDAPlace(0)]:library_" "data_type[::paddle::platform::float16]:data_layout[NCHW]:place["
"CUDAPlace(0)]:library_"
"type[CUDNN]"); "type[CUDNN]");
} }
......
...@@ -43,10 +43,9 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = { ...@@ -43,10 +43,9 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
proto::VarType::Type GetDataTypeOfVar(const Variable* var) { proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
return framework::ToDataType(var->Get<framework::LoDTensor>().type()); return var->Get<framework::LoDTensor>().type();
} else if (var->IsType<framework::SelectedRows>()) { } else if (var->IsType<framework::SelectedRows>()) {
return framework::ToDataType( return var->Get<framework::SelectedRows>().value().type();
var->Get<framework::SelectedRows>().value().type());
} else { } else {
PADDLE_THROW("Var should be LoDTensor or SelectedRows"); PADDLE_THROW("Var should be LoDTensor or SelectedRows");
} }
...@@ -93,13 +92,13 @@ static std::string GetDtype(const Scope& scope, const std::string& name) { ...@@ -93,13 +92,13 @@ static std::string GetDtype(const Scope& scope, const std::string& name) {
if (UNLIKELY(!tensor.IsInitialized())) { if (UNLIKELY(!tensor.IsInitialized())) {
return ""; return "";
} }
return DataTypeToString(ToDataType(tensor.type())); return DataTypeToString(tensor.type());
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
auto tensor = var->Get<SelectedRows>().value(); auto tensor = var->Get<SelectedRows>().value();
if (UNLIKELY(!tensor.IsInitialized())) { if (UNLIKELY(!tensor.IsInitialized())) {
return "uninited"; return "uninited";
} else { } else {
return DataTypeToString(ToDataType(tensor.type())); return DataTypeToString(tensor.type());
} }
} else { } else {
return ""; return "";
...@@ -686,7 +685,8 @@ static void CheckTensorNANOrInf(const std::string& name, ...@@ -686,7 +685,8 @@ static void CheckTensorNANOrInf(const std::string& name,
if (tensor.memory_size() == 0) { if (tensor.memory_size() == 0) {
return; return;
} }
if (!IsType<float>(tensor.type()) && !IsType<double>(tensor.type())) { if (tensor.type() != proto::VarType::FP32 &&
tensor.type() != proto::VarType::FP64) {
return; return;
} }
PADDLE_ENFORCE(!framework::TensorContainsInf(tensor), PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
...@@ -873,7 +873,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( ...@@ -873,7 +873,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
t = &(var->Get<SelectedRows>().value()); t = &(var->Get<SelectedRows>().value());
} }
if (t != nullptr) { if (t != nullptr) {
int tmp = static_cast<int>(ToDataType(t->type())); int tmp = static_cast<int>(t->type());
PADDLE_ENFORCE( PADDLE_ENFORCE(
tmp == data_type || data_type == -1, tmp == data_type || data_type == -1,
"DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)", "DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)",
......
...@@ -218,11 +218,11 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value, ...@@ -218,11 +218,11 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
if (index < 0) { if (index < 0) {
VLOG(5) << "id " << id << " not in the table, return 0"; VLOG(5) << "id " << id << " not in the table, return 0";
framework::VisitDataType( framework::VisitDataType(
framework::ToDataType(value_->type()), value_->type(),
TensorFillVisitor(value, i * value_width, value_width, 0.0)); TensorFillVisitor(value, i * value_width, value_width, 0.0));
} else { } else {
framework::VisitDataType( framework::VisitDataType(
framework::ToDataType(value_->type()), value_->type(),
TensorCopyVisitor(value, i * value_width, *value_.get(), TensorCopyVisitor(value, i * value_width, *value_.get(),
index * value_width, value_width)); index * value_width, value_width));
} }
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
extern size_t SizeOfType(std::type_index type); extern size_t SizeOfType(proto::VarType::Type type);
void Tensor::check_memory_size() const { void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor holds no memory. Call Tensor::mutable_data first."); holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
...@@ -31,7 +31,7 @@ size_t Tensor::memory_size() const { ...@@ -31,7 +31,7 @@ size_t Tensor::memory_size() const {
return holder_ == nullptr ? 0UL : holder_->size() - offset_; return holder_ == nullptr ? 0UL : holder_->size() - offset_;
} }
void* Tensor::mutable_data(platform::Place place, std::type_index type, void* Tensor::mutable_data(platform::Place place, proto::VarType::Type type,
memory::Allocator::Attr attr, memory::Allocator::Attr attr,
size_t requested_size) { size_t requested_size) {
type_ = type; type_ = type;
......
...@@ -19,9 +19,9 @@ limitations under the License. */ ...@@ -19,9 +19,9 @@ limitations under the License. */
#include <memory> #include <memory>
#include <typeindex> #include <typeindex>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/memory/memory.h" #include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -67,7 +67,7 @@ class Tensor { ...@@ -67,7 +67,7 @@ class Tensor {
friend struct EigenVector; friend struct EigenVector;
public: public:
Tensor() : type_(typeid(float)), offset_(0) {} Tensor() : type_(proto::VarType::FP32), offset_(0) {}
/*! Return a pointer to mutable memory block. */ /*! Return a pointer to mutable memory block. */
template <typename T> template <typename T>
...@@ -88,7 +88,7 @@ class Tensor { ...@@ -88,7 +88,7 @@ class Tensor {
memory::Allocator::Attr attr = memory::Allocator::kDefault, memory::Allocator::Attr attr = memory::Allocator::kDefault,
size_t requested_size = 0); size_t requested_size = 0);
void* mutable_data(platform::Place place, std::type_index type, void* mutable_data(platform::Place place, proto::VarType::Type type,
memory::Allocator::Attr attr = memory::Allocator::kDefault, memory::Allocator::Attr attr = memory::Allocator::kDefault,
size_t requested_size = 0); size_t requested_size = 0);
...@@ -138,7 +138,7 @@ class Tensor { ...@@ -138,7 +138,7 @@ class Tensor {
return holder_->place(); return holder_->place();
} }
std::type_index type() const { proto::VarType::Type type() const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor not initialized yet when Tensor::type() is called."); holder_, "Tensor not initialized yet when Tensor::type() is called.");
return type_; return type_;
...@@ -161,7 +161,7 @@ class Tensor { ...@@ -161,7 +161,7 @@ class Tensor {
private: private:
/*! holds the memory block if allocated. */ /*! holds the memory block if allocated. */
std::shared_ptr<memory::Allocation> holder_; std::shared_ptr<memory::Allocation> holder_;
std::type_index type_; proto::VarType::Type type_;
/** /**
* @brief points to elements dimensions. * @brief points to elements dimensions.
* *
......
...@@ -24,9 +24,8 @@ template <typename T> ...@@ -24,9 +24,8 @@ template <typename T>
inline const T* Tensor::data() const { inline const T* Tensor::data() const {
check_memory_size(); check_memory_size();
bool valid = bool valid =
std::is_same<T, void>::value || type_ == std::type_index(typeid(T)); std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %d", type_);
type_.name());
return reinterpret_cast<const T*>( return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_); reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
...@@ -38,9 +37,8 @@ template <typename T> ...@@ -38,9 +37,8 @@ template <typename T>
inline T* Tensor::data() { inline T* Tensor::data() {
check_memory_size(); check_memory_size();
bool valid = bool valid =
std::is_same<T, void>::value || type_ == std::type_index(typeid(T)); std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", type_);
type_.name());
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) + return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_); offset_);
} }
...@@ -60,7 +58,7 @@ inline T* Tensor::mutable_data(platform::Place place, ...@@ -60,7 +58,7 @@ inline T* Tensor::mutable_data(platform::Place place,
size_t requested_size) { size_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD"); static_assert(std::is_pod<T>::value, "T must be POD");
return reinterpret_cast<T*>( return reinterpret_cast<T*>(
mutable_data(place, typeid(T), attr, requested_size)); mutable_data(place, DataTypeTrait<T>::DataType, attr, requested_size));
} }
inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) { inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
......
...@@ -186,8 +186,8 @@ struct AnyDTypeVisitor { ...@@ -186,8 +186,8 @@ struct AnyDTypeVisitor {
template <typename Predicate, typename DevCtx> template <typename Predicate, typename DevCtx>
inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor, inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor,
const DevCtx& ctx, framework::Tensor* out) { const DevCtx& ctx, framework::Tensor* out) {
VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor<Predicate, DevCtx>( VisitDataType(tensor.type(), AnyDTypeVisitor<Predicate, DevCtx>(
predicate, tensor, ctx, out)); predicate, tensor, ctx, out));
} }
template <typename Predicate> template <typename Predicate>
...@@ -379,7 +379,7 @@ void TensorToStream(std::ostream& os, const Tensor& tensor, ...@@ -379,7 +379,7 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
// int32_t size // int32_t size
// void* protobuf message // void* protobuf message
proto::VarType::TensorDesc desc; proto::VarType::TensorDesc desc;
desc.set_data_type(framework::ToDataType(tensor.type())); desc.set_data_type(tensor.type());
auto dims = framework::vectorize(tensor.dims()); auto dims = framework::vectorize(tensor.dims());
auto* pb_dims = desc.mutable_dims(); auto* pb_dims = desc.mutable_dims();
pb_dims->Resize(static_cast<int>(dims.size()), 0); pb_dims->Resize(static_cast<int>(dims.size()), 0);
...@@ -461,9 +461,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor, ...@@ -461,9 +461,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
tensor->Resize(framework::make_ddim(dims)); tensor->Resize(framework::make_ddim(dims));
void* buf; void* buf;
auto ctx = platform::CPUDeviceContext(); auto ctx = platform::CPUDeviceContext();
size_t size = size_t size = tensor->numel() * framework::SizeOfType(desc.data_type());
tensor->numel() *
framework::SizeOfType(framework::ToTypeIndex(desc.data_type()));
if (platform::is_gpu_place(dev_ctx.GetPlace())) { if (platform::is_gpu_place(dev_ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
Tensor cpu_tensor; Tensor cpu_tensor;
......
...@@ -289,10 +289,10 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs, ...@@ -289,10 +289,10 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
auto type = fetch.type(); auto type = fetch.type();
auto output = &(outputs->at(i)); auto output = &(outputs->at(i));
output->name = fetchs_[idx]->Input("X")[0]; output->name = fetchs_[idx]->Input("X")[0];
if (type == typeid(float)) { if (type == framework::proto::VarType::FP32) {
GetFetchOne<float>(fetch, output); GetFetchOne<float>(fetch, output);
output->dtype = PaddleDType::FLOAT32; output->dtype = PaddleDType::FLOAT32;
} else if (type == typeid(int64_t)) { } else if (type == framework::proto::VarType::INT64) {
GetFetchOne<int64_t>(fetch, output); GetFetchOne<int64_t>(fetch, output);
output->dtype = PaddleDType::INT64; output->dtype = PaddleDType::INT64;
} else { } else {
......
...@@ -266,10 +266,10 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs, ...@@ -266,10 +266,10 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
auto type = fetch.type(); auto type = fetch.type();
auto output = &(outputs->at(i)); auto output = &(outputs->at(i));
output->name = fetchs_[idx]->Input("X")[0]; output->name = fetchs_[idx]->Input("X")[0];
if (type == typeid(float)) { if (type == framework::DataTypeTrait<float>::DataType) {
GetFetchOne<float>(fetch, output); GetFetchOne<float>(fetch, output);
output->dtype = PaddleDType::FLOAT32; output->dtype = PaddleDType::FLOAT32;
} else if (type == typeid(int64_t)) { } else if (type == framework::DataTypeTrait<int64_t>::DataType) {
GetFetchOne<int64_t>(fetch, output); GetFetchOne<int64_t>(fetch, output);
output->dtype = PaddleDType::INT64; output->dtype = PaddleDType::INT64;
} else { } else {
......
...@@ -36,10 +36,10 @@ namespace paddle { ...@@ -36,10 +36,10 @@ namespace paddle {
PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) { PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) {
PaddleTensor pt; PaddleTensor pt;
if (t->type() == typeid(int64_t)) { if (t->type() == framework::proto::VarType::INT64) {
pt.data.Reset(t->data<void>(), t->numel() * sizeof(int64_t)); pt.data.Reset(t->data<void>(), t->numel() * sizeof(int64_t));
pt.dtype = PaddleDType::INT64; pt.dtype = PaddleDType::INT64;
} else if (t->type() == typeid(float)) { } else if (t->type() == framework::proto::VarType::FP32) {
pt.data.Reset(t->data<void>(), t->numel() * sizeof(float)); pt.data.Reset(t->data<void>(), t->numel() * sizeof(float));
pt.dtype = PaddleDType::FLOAT32; pt.dtype = PaddleDType::FLOAT32;
} else { } else {
......
...@@ -361,7 +361,7 @@ static bool CompareTensorData(const framework::LoDTensor &a, ...@@ -361,7 +361,7 @@ static bool CompareTensorData(const framework::LoDTensor &a,
} }
for (size_t i = 0; i < a_size; i++) { for (size_t i = 0; i < a_size; i++) {
if (a.type() == typeid(float)) { if (a.type() == framework::proto::VarType::FP32) {
const auto *a_data = a.data<float>(); const auto *a_data = a.data<float>();
const auto *b_data = b.data<float>(); const auto *b_data = b.data<float>();
if (std::abs(a_data[i] - b_data[i]) > 1e-3) { if (std::abs(a_data[i] - b_data[i]) > 1e-3) {
...@@ -370,7 +370,7 @@ static bool CompareTensorData(const framework::LoDTensor &a, ...@@ -370,7 +370,7 @@ static bool CompareTensorData(const framework::LoDTensor &a,
b_data[i]); b_data[i]);
return false; return false;
} }
} else if (a.type() == typeid(int64_t)) { } else if (a.type() == framework::proto::VarType::INT64) {
const auto *a_data = a.data<int64_t>(); const auto *a_data = a.data<int64_t>();
const auto *b_data = b.data<int64_t>(); const auto *b_data = b.data<int64_t>();
if (std::abs(a_data[i] - b_data[i]) > 1e-3) { if (std::abs(a_data[i] - b_data[i]) > 1e-3) {
......
...@@ -78,7 +78,7 @@ class AffineGridOp : public framework::OperatorWithKernel { ...@@ -78,7 +78,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
library = framework::LibraryType::kCUDNN; library = framework::LibraryType::kCUDNN;
} }
#endif #endif
auto data_type = framework::ToDataType(ctx.Input<Tensor>("Theta")->type()); auto data_type = ctx.Input<Tensor>("Theta")->type();
return framework::OpKernelType(data_type, ctx.GetPlace(), return framework::OpKernelType(data_type, ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library); framework::DataLayout::kAnyLayout, library);
} }
...@@ -188,9 +188,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel { ...@@ -188,9 +188,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Theta")->type(),
framework::ToDataType(ctx.Input<Tensor>("Theta")->type()), ctx.GetPlace(),
ctx.GetPlace(), framework::DataLayout::kAnyLayout, library_); framework::DataLayout::kAnyLayout, library_);
} }
}; };
......
...@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL( ...@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL(
int32_t>, int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int16_t>, int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, size_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
uint8_t>); uint8_t>);
...@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL(
int32_t>, int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
int16_t>, int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
size_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
uint8_t>); uint8_t>);
...@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL( ...@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL(
int32_t>, int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int16_t>, int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, size_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
uint8_t>); uint8_t>);
...@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL(
int32_t>, int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
int16_t>, int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
size_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
uint8_t>); uint8_t>);
...@@ -58,7 +58,7 @@ struct ArrayToLoDFunctor : public boost::static_visitor<void> { ...@@ -58,7 +58,7 @@ struct ArrayToLoDFunctor : public boost::static_visitor<void> {
ArrayToLoDFunctorImpl<DeviceContext> functor; ArrayToLoDFunctorImpl<DeviceContext> functor;
functor.dev_ctx_ = dev_ctx; functor.dev_ctx_ = dev_ctx;
functor.prev_functor_ = this; functor.prev_functor_ = this;
framework::VisitDataType(framework::ToDataType(out->type()), functor); framework::VisitDataType(out->type(), functor);
} }
}; };
...@@ -91,7 +91,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase { ...@@ -91,7 +91,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
PADDLE_ENFORCE(!x.empty(), "There's no element in the input array."); PADDLE_ENFORCE(!x.empty(), "There's no element in the input array.");
int rank = x[0].dims().size(); int rank = x[0].dims().size();
platform::Place place = x[0].place(); platform::Place place = x[0].place();
std::type_index data_type = x[0].type(); auto data_type = x[0].type();
int64_t batch_size = x[0].dims()[0]; int64_t batch_size = x[0].dims()[0];
framework::DDim ins_dims = rank > 1 framework::DDim ins_dims = rank > 1
? framework::slice_ddim(x[0].dims(), 1, rank) ? framework::slice_ddim(x[0].dims(), 1, rank)
......
...@@ -121,9 +121,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -121,9 +121,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType( framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), ctx.device_context());
ctx.device_context());
} }
void AttentionLSTMOpMaker::Make() { void AttentionLSTMOpMaker::Make() {
......
...@@ -103,9 +103,8 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel { ...@@ -103,9 +103,8 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("param")->type(),
framework::ToDataType(ctx.Input<Tensor>("param")->type()), ctx.GetPlace());
ctx.GetPlace());
} }
}; };
......
...@@ -72,8 +72,7 @@ class BatchNormOp : public framework::OperatorWithKernel { ...@@ -72,8 +72,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto input_data_type = ctx.Input<Tensor>("X")->type();
framework::ToDataType(ctx.Input<Tensor>("X")->type());
// By default, the type of the scale, bias, mean, // By default, the type of the scale, bias, mean,
// and var tensors should both be float. (For float or float16 input tensor) // and var tensors should both be float. (For float or float16 input tensor)
// or double (For double input tensor). // or double (For double input tensor).
...@@ -81,17 +80,13 @@ class BatchNormOp : public framework::OperatorWithKernel { ...@@ -81,17 +80,13 @@ class BatchNormOp : public framework::OperatorWithKernel {
if (input_data_type == framework::proto::VarType::FP64) { if (input_data_type == framework::proto::VarType::FP64) {
bn_param_type = framework::proto::VarType::FP64; bn_param_type = framework::proto::VarType::FP64;
} }
PADDLE_ENFORCE_EQ(bn_param_type, PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Scale")->type(),
framework::ToDataType(ctx.Input<Tensor>("Scale")->type()),
"Scale input should be of float type"); "Scale input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type, PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Bias")->type(),
framework::ToDataType(ctx.Input<Tensor>("Bias")->type()),
"Bias input should be of float type"); "Bias input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type, PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Mean")->type(),
framework::ToDataType(ctx.Input<Tensor>("Mean")->type()),
"Mean input should be of float type"); "Mean input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType( PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Variance")->type(),
ctx.Input<Tensor>("Variance")->type()),
"Variance input should be of float type"); "Variance input should be of float type");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
...@@ -387,9 +382,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel { ...@@ -387,9 +382,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(), ctx.GetPlace(), layout, library);
layout, library);
} }
}; };
......
...@@ -145,7 +145,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase { ...@@ -145,7 +145,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores"); LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores");
framework::VisitDataType( framework::VisitDataType(
framework::ToDataType(scores->at(0).type()), scores->at(0).type(),
BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores, BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores,
beam_size, end_id)); beam_size, end_id));
} }
......
...@@ -282,8 +282,7 @@ class BeamSearchOp : public framework::OperatorWithKernel { ...@@ -282,8 +282,7 @@ class BeamSearchOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = framework::OpKernelType( framework::OpKernelType kt = framework::OpKernelType(
framework::ToDataType( ctx.Input<framework::LoDTensor>("pre_ids")->type(),
ctx.Input<framework::LoDTensor>("pre_ids")->type()),
platform::CPUPlace()); platform::CPUPlace());
return kt; return kt;
} }
......
...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/bilinear_tensor_product_op.h" #include "paddle/fluid/operators/bilinear_tensor_product_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -48,13 +48,12 @@ class ConditionalOp : public framework::OperatorBase { ...@@ -48,13 +48,12 @@ class ConditionalOp : public framework::OperatorBase {
if (!(ips.size() == 1UL && ips[0]->IsInitialized())) { if (!(ips.size() == 1UL && ips[0]->IsInitialized())) {
PADDLE_THROW("should have one initialized input as condition"); PADDLE_THROW("should have one initialized input as condition");
} }
if (!(framework::IsType<bool>(ips[0]->type()) && // NOLINT
ips[0]->numel() == 1)) { PADDLE_ENFORCE(ips[0]->type() == framework::proto::VarType::BOOL &&
PADDLE_THROW( ips[0]->numel() == 1,
"condition input's data type should be bool, " "condition input's data type should be bool, "
"numel should be 1, actual numel is %d", "numel should be 1, actual numel is %d",
ips[0]->numel()); ips[0]->numel());
}
bool res = false; bool res = false;
if (platform::is_gpu_place(ips[0]->place())) { if (platform::is_gpu_place(ips[0]->place())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
......
...@@ -237,7 +237,7 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -237,7 +237,7 @@ class WhileGradOp : public framework::OperatorBase {
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
auto &inside_tensor = var->Get<framework::LoDTensor>(); auto &inside_tensor = var->Get<framework::LoDTensor>();
framework::AttributeMap attrs; framework::AttributeMap attrs;
attrs["dtype"] = framework::ToDataType(inside_tensor.type()); attrs["dtype"] = inside_tensor.type();
attrs["shape"] = framework::vectorize2int(inside_tensor.dims()); attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
attrs["value"] = 0.0f; attrs["value"] = 0.0f;
......
...@@ -92,10 +92,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -92,10 +92,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
} }
#endif #endif
auto input_data_type = auto input_data_type = ctx.Input<Tensor>("Input")->type();
framework::ToDataType(ctx.Input<Tensor>("Input")->type()); auto filter_data_type = ctx.Input<Tensor>("Filter")->type();
auto filter_data_type =
framework::ToDataType(ctx.Input<Tensor>("Filter")->type());
PADDLE_ENFORCE_EQ(input_data_type, filter_data_type, PADDLE_ENFORCE_EQ(input_data_type, filter_data_type,
"input and filter data type should be consistent"); "input and filter data type should be consistent");
...@@ -360,9 +358,8 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( ...@@ -360,9 +358,8 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(), ctx.GetPlace(), layout_, library_);
layout_, library_);
} }
} // namespace operators } // namespace operators
......
...@@ -95,9 +95,8 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( ...@@ -95,9 +95,8 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(), ctx.GetPlace(), layout_, library_);
layout_, library_);
} }
void Conv2DTransposeOpMaker::Make() { void Conv2DTransposeOpMaker::Make() {
...@@ -314,9 +313,8 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType( ...@@ -314,9 +313,8 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(), ctx.GetPlace(), layout_, library_);
layout_, library_);
} }
} // namespace operators } // namespace operators
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/cos_sim_op.h" #include "paddle/fluid/operators/cos_sim_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -118,9 +118,8 @@ class CRFDecodingOp : public framework::OperatorWithKernel { ...@@ -118,9 +118,8 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<LoDTensor>("Emission")->type(),
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()), platform::CPUPlace());
platform::CPUPlace());
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -51,9 +51,8 @@ class CropOp : public framework::OperatorWithKernel { ...@@ -51,9 +51,8 @@ class CropOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), ctx.device_context());
ctx.device_context());
} }
}; };
...@@ -174,9 +173,7 @@ class CropOpGrad : public framework::OperatorWithKernel { ...@@ -174,9 +173,7 @@ class CropOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))
->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/crop_op.h" #include "paddle/fluid/operators/crop_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -57,9 +57,8 @@ class CrossEntropyOp : public framework::OperatorWithKernel { ...@@ -57,9 +57,8 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
// is determined by its input "X". // is determined by its input "X".
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.device_context());
ctx.device_context());
} }
}; };
...@@ -111,9 +110,8 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { ...@@ -111,9 +110,8 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
// is determined by its input "X". // is determined by its input "X".
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.device_context());
ctx.device_context());
} }
}; };
......
...@@ -36,9 +36,8 @@ class CTCAlignOp : public framework::OperatorWithKernel { ...@@ -36,9 +36,8 @@ class CTCAlignOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.device_context());
ctx.device_context());
} }
}; };
......
...@@ -53,8 +53,7 @@ class AnchorGeneratorOp : public framework::OperatorWithKernel { ...@@ -53,8 +53,7 @@ class AnchorGeneratorOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()), ctx.Input<framework::Tensor>("Input")->type(), ctx.device_context());
ctx.device_context());
} }
}; };
......
...@@ -45,9 +45,8 @@ class BipartiteMatchOp : public framework::OperatorWithKernel { ...@@ -45,9 +45,8 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<LoDTensor>("DistMat")->type(),
framework::ToDataType(ctx.Input<LoDTensor>("DistMat")->type()), platform::CPUPlace());
platform::CPUPlace());
} }
}; };
......
...@@ -66,8 +66,7 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel { ...@@ -66,8 +66,7 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()), ctx.Input<framework::Tensor>("Input")->type(), ctx.GetPlace());
ctx.GetPlace());
} }
}; };
......
...@@ -66,9 +66,8 @@ class GenerateProposalsOp : public framework::OperatorWithKernel { ...@@ -66,9 +66,8 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Anchors")->type(),
framework::ToDataType(ctx.Input<Tensor>("Anchors")->type()), ctx.device_context());
ctx.device_context());
} }
}; };
......
...@@ -249,8 +249,7 @@ class MineHardExamplesOp : public framework::OperatorWithKernel { ...@@ -249,8 +249,7 @@ class MineHardExamplesOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("ClsLoss")->type()), ctx.Input<framework::Tensor>("ClsLoss")->type(), platform::CPUPlace());
platform::CPUPlace());
} }
}; };
......
...@@ -65,8 +65,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { ...@@ -65,8 +65,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( ctx.Input<framework::LoDTensor>("Scores")->type(),
ctx.Input<framework::LoDTensor>("Scores")->type()),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -72,8 +72,7 @@ class PriorBoxOp : public framework::OperatorWithKernel { ...@@ -72,8 +72,7 @@ class PriorBoxOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()), ctx.Input<framework::Tensor>("Input")->type(), ctx.device_context());
ctx.device_context());
} }
}; };
......
...@@ -498,9 +498,8 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel { ...@@ -498,9 +498,8 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()), ctx.device_context());
ctx.device_context());
} }
}; };
...@@ -519,9 +518,8 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel { ...@@ -519,9 +518,8 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()), ctx.device_context());
ctx.device_context());
} }
}; };
......
...@@ -78,8 +78,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { ...@@ -78,8 +78,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( ctx.Input<framework::LoDTensor>("Anchor")->type(),
ctx.Input<framework::LoDTensor>("Anchor")->type()),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -57,9 +57,8 @@ class TargetAssignOp : public framework::OperatorWithKernel { ...@@ -57,9 +57,8 @@ class TargetAssignOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), ctx.device_context());
ctx.device_context());
} }
}; };
......
...@@ -71,8 +71,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel { ...@@ -71,8 +71,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( ctx.Input<framework::Tensor>("DetectRes")->type(),
ctx.Input<framework::Tensor>("DetectRes")->type()),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -122,8 +122,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -122,8 +122,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
if (var->IsType<framework::SelectedRows>()) { if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>(); auto* slr = var->GetMutable<framework::SelectedRows>();
ProtoEncodeHelper e2(static_cast<char*>(buf), 128); ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
size_t rows_memory_size = size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size); e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
slices[2] = ::grpc::Slice(e2.size()); slices[2] = ::grpc::Slice(e2.size());
memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size()); memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
......
...@@ -61,8 +61,7 @@ TensorPayload GetTensorPayload(framework::Variable* var, ...@@ -61,8 +61,7 @@ TensorPayload GetTensorPayload(framework::Variable* var,
auto tensor = var->Get<framework::LoDTensor>(); auto tensor = var->Get<framework::LoDTensor>();
// FIXME(wuyi): data types in send_recv.proto is copied from // FIXME(wuyi): data types in send_recv.proto is copied from
// framework.proto // framework.proto
request->set_data_type( request->set_data_type(static_cast<VarMsg::Type>(tensor.type()));
static_cast<VarMsg::Type>(framework::ToDataType(tensor.type())));
for (auto& dim : framework::vectorize(tensor.dims())) { for (auto& dim : framework::vectorize(tensor.dims())) {
request->add_dims(dim); request->add_dims(dim);
} }
...@@ -83,8 +82,7 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var, ...@@ -83,8 +82,7 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
VarMsg* request) { VarMsg* request) {
auto* slr = var->GetMutable<framework::SelectedRows>(); auto* slr = var->GetMutable<framework::SelectedRows>();
request->set_data_type( request->set_data_type(static_cast<VarMsg::Type>(slr->value().type()));
static_cast<VarMsg::Type>(framework::ToDataType(slr->value().type())));
request->set_lod_level(0); request->set_lod_level(0);
request->set_slr_height(slr->height()); request->set_slr_height(slr->height());
......
...@@ -58,18 +58,19 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var, ...@@ -58,18 +58,19 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
VarMsg* request); VarMsg* request);
inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) { inline framework::proto::VarType::Type ToVarType(
sendrecv::VariableMessage::Type type) {
switch (type) { switch (type) {
case sendrecv::VariableMessage::FP32: case sendrecv::VariableMessage::FP32:
return typeid(float); // NOLINT return framework::proto::VarType::FP32; // NOLINT
case sendrecv::VariableMessage::FP64: case sendrecv::VariableMessage::FP64:
return typeid(double); // NOLINT return framework::proto::VarType::FP64; // NOLINT
case sendrecv::VariableMessage::INT32: case sendrecv::VariableMessage::INT32:
return typeid(int); // NOLINT return framework::proto::VarType::INT32; // NOLINT
case sendrecv::VariableMessage::INT64: case sendrecv::VariableMessage::INT64:
return typeid(int64_t); // NOLINT return framework::proto::VarType::INT64; // NOLINT
case sendrecv::VariableMessage::BOOL: case sendrecv::VariableMessage::BOOL:
return typeid(bool); // NOLINT return framework::proto::VarType::BOOL; // NOLINT
default: default:
PADDLE_THROW("Not support type %d", type); PADDLE_THROW("Not support type %d", type);
} }
......
...@@ -114,7 +114,7 @@ bool VariableResponse::CopyLodTensorData( ...@@ -114,7 +114,7 @@ bool VariableResponse::CopyLodTensorData(
tensor->set_lod(lod); tensor->set_lod(lod);
void* tensor_data = void* tensor_data =
tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type())); tensor->mutable_data(ctx.GetPlace(), ToVarType(meta_.data_type()));
VLOG(6) << "Tensor.memory_size = " << tensor->memory_size() VLOG(6) << "Tensor.memory_size = " << tensor->memory_size()
<< ", Buffer Size = " << length; << ", Buffer Size = " << length;
...@@ -139,13 +139,13 @@ bool VariableResponse::CopySelectRowsTensorData( ...@@ -139,13 +139,13 @@ bool VariableResponse::CopySelectRowsTensorData(
slr->set_height(meta_.slr_height()); slr->set_height(meta_.slr_height());
auto* tensor = slr->mutable_value(); auto* tensor = slr->mutable_value();
tensor->Resize(dims); tensor->Resize(dims);
PADDLE_ENFORCE_EQ(static_cast<size_t>(tensor->numel()), PADDLE_ENFORCE_EQ(
length / framework::SizeOfType( static_cast<size_t>(tensor->numel()),
paddle::operators::distributed::ToTypeIndex( length / framework::SizeOfType(paddle::operators::distributed::ToVarType(
meta_.data_type()))); meta_.data_type())));
void* tensor_data = tensor->mutable_data( void* tensor_data = tensor->mutable_data(
ctx.GetPlace(), ctx.GetPlace(),
paddle::operators::distributed::ToTypeIndex(meta_.data_type())); paddle::operators::distributed::ToVarType(meta_.data_type()));
if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) { if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
return false; return false;
...@@ -159,8 +159,7 @@ bool VariableResponse::CopySelectRowsData( ...@@ -159,8 +159,7 @@ bool VariableResponse::CopySelectRowsData(
const platform::DeviceContext& ctx, int length) { const platform::DeviceContext& ctx, int length) {
auto* slr = GetVar()->GetMutable<framework::SelectedRows>(); auto* slr = GetVar()->GetMutable<framework::SelectedRows>();
slr->mutable_rows()->clear(); slr->mutable_rows()->clear();
slr->mutable_rows()->resize(length / slr->mutable_rows()->resize(length / sizeof(int64_t)); // int64
framework::SizeOfType(typeid(int64_t))); // int64
int64_t* rows_data = slr->mutable_rows()->data(); int64_t* rows_data = slr->mutable_rows()->data();
// copy rows CPU data, GPU data will be copied lazily. // copy rows CPU data, GPU data will be copied lazily.
......
...@@ -108,9 +108,7 @@ class MergeIdsOp : public framework::OperatorWithKernel { ...@@ -108,9 +108,7 @@ class MergeIdsOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( ctx.MultiInput<framework::Tensor>("X").front()->type(), ctx.GetPlace());
ctx.MultiInput<framework::Tensor>("X").front()->type()),
ctx.GetPlace());
} }
}; };
......
...@@ -42,9 +42,7 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel { ...@@ -42,9 +42,7 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( ctx.MultiInput<framework::Tensor>("X")[0]->type(), ctx.GetPlace());
ctx.MultiInput<framework::Tensor>("X")[0]->type()),
ctx.GetPlace());
} }
}; };
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include <thrust/device_ptr.h> #include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h> #include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h> #include <thrust/random.h>
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h" #include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h" #include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise/elementwise_min_op.h" #include "paddle/fluid/operators/elementwise/elementwise_min_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -197,8 +197,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { ...@@ -197,8 +197,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = framework::ToDataType( auto input_data_type =
ctx.Input<Tensor>(framework::GradVarName("Out"))->type()); ctx.Input<Tensor>(framework::GradVarName("Out"))->type();
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) { if (platform::CanMKLDNNBeUsed(ctx)) {
......
...@@ -8,8 +8,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -8,8 +8,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise/elementwise_pow_op.h" #include "paddle/fluid/operators/elementwise/elementwise_pow_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h" #include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/expand_op.h" #include "paddle/fluid/operators/expand_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -115,9 +115,8 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel { ...@@ -115,9 +115,8 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), ctx.device_context());
ctx.device_context());
} }
}; };
...@@ -175,9 +174,8 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel { ...@@ -175,9 +174,8 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), ctx.device_context());
ctx.device_context());
} }
}; };
......
...@@ -79,9 +79,8 @@ framework::OpKernelType FCOp::GetExpectedKernelType( ...@@ -79,9 +79,8 @@ framework::OpKernelType FCOp::GetExpectedKernelType(
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(), ctx.GetPlace(), layout, library);
layout, library);
} }
void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const { void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
...@@ -111,9 +110,8 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType( ...@@ -111,9 +110,8 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(), ctx.GetPlace(), layout, library);
layout, library);
} }
void FCOpMaker::Make() { void FCOpMaker::Make() {
......
...@@ -59,9 +59,9 @@ class FillConstantOp : public framework::OperatorBase { ...@@ -59,9 +59,9 @@ class FillConstantOp : public framework::OperatorBase {
if (force_cpu) { if (force_cpu) {
auto cpu = platform::CPUPlace(); auto cpu = platform::CPUPlace();
tensor->mutable_data(cpu, framework::ToTypeIndex(data_type)); tensor->mutable_data(cpu, data_type);
} else { } else {
tensor->mutable_data(dev_place, framework::ToTypeIndex(data_type)); tensor->mutable_data(dev_place, data_type);
} }
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
......
...@@ -55,7 +55,7 @@ class FillOp : public framework::OperatorBase { ...@@ -55,7 +55,7 @@ class FillOp : public framework::OperatorBase {
static_cast<framework::proto::VarType::Type>(Attr<int>("dtype")); static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
platform::CPUPlace cpu; platform::CPUPlace cpu;
auto force_cpu = Attr<bool>("force_cpu"); auto force_cpu = Attr<bool>("force_cpu");
out.mutable_data(force_cpu ? cpu : place, framework::ToTypeIndex(dtype)); out.mutable_data(force_cpu ? cpu : place, dtype);
framework::LoDTensor tensor; framework::LoDTensor tensor;
...@@ -64,7 +64,7 @@ class FillOp : public framework::OperatorBase { ...@@ -64,7 +64,7 @@ class FillOp : public framework::OperatorBase {
} else { } else {
// Always make tensor in CPU memory. // Always make tensor in CPU memory.
tensor.Resize(out.dims()); tensor.Resize(out.dims());
tensor.mutable_data(cpu, framework::ToTypeIndex(dtype)); tensor.mutable_data(cpu, dtype);
} }
framework::VisitDataType( framework::VisitDataType(
......
...@@ -135,9 +135,8 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel { ...@@ -135,9 +135,8 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx.Input<framework::Tensor>("X")->type(), PADDLE_ENFORCE_EQ(ctx.Input<framework::Tensor>("X")->type(),
ctx.Input<framework::Tensor>("Y")->type(), ctx.Input<framework::Tensor>("Y")->type(),
"The element's type of input should be the same."); "The element's type of input should be the same.");
auto input_data_type = return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()); ctx.GetPlace());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
...@@ -324,9 +323,8 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel { ...@@ -324,9 +323,8 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type_index = ctx.Input<framework::Tensor>("Y")->type(); return framework::OpKernelType(ctx.Input<framework::Tensor>("Y")->type(),
auto input_data_type = framework::ToDataType(input_data_type_index); ctx.GetPlace());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -115,8 +115,7 @@ void FusedEmbeddingFCLSTMOp::InferShape( ...@@ -115,8 +115,7 @@ void FusedEmbeddingFCLSTMOp::InferShape(
framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType( framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( ctx.Input<framework::LoDTensor>("Embeddings")->type(),
ctx.Input<framework::LoDTensor>("Embeddings")->type()),
ctx.device_context()); ctx.device_context());
} }
......
...@@ -93,9 +93,8 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -93,9 +93,8 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FusionGRUOp::GetExpectedKernelType( framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), ctx.device_context());
ctx.device_context());
} }
void FusionGRUOpMaker::Make() { void FusionGRUOpMaker::Make() {
......
...@@ -117,9 +117,8 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -117,9 +117,8 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FusionLSTMOp::GetExpectedKernelType( framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), ctx.device_context());
ctx.device_context());
} }
void FusionLSTMOpMaker::Make() { void FusionLSTMOpMaker::Make() {
......
...@@ -61,9 +61,8 @@ void FusionSeqConvEltAddReluOp::InferShape( ...@@ -61,9 +61,8 @@ void FusionSeqConvEltAddReluOp::InferShape(
framework::OpKernelType FusionSeqConvEltAddReluOp::GetExpectedKernelType( framework::OpKernelType FusionSeqConvEltAddReluOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), ctx.device_context());
ctx.device_context());
} }
void FusionSeqConvEltAddReluOpMaker::Make() { void FusionSeqConvEltAddReluOpMaker::Make() {
......
...@@ -67,9 +67,8 @@ void FusionSeqExpandConcatFCOp::InferShape( ...@@ -67,9 +67,8 @@ void FusionSeqExpandConcatFCOp::InferShape(
framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType( framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType( return framework::OpKernelType(ctx.MultiInput<LoDTensor>("X")[0]->type(),
framework::ToDataType(ctx.MultiInput<LoDTensor>("X")[0]->type()), ctx.device_context());
ctx.device_context());
} }
void FusionSeqExpandConcatFCOpMaker::Make() { void FusionSeqExpandConcatFCOpMaker::Make() {
......
...@@ -42,9 +42,8 @@ class GatherOp : public framework::OperatorWithKernel { ...@@ -42,9 +42,8 @@ class GatherOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.device_context());
ctx.device_context());
} }
}; };
...@@ -60,9 +59,8 @@ class GatherGradOp : public framework::OperatorWithKernel { ...@@ -60,9 +59,8 @@ class GatherGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.device_context());
ctx.device_context());
} }
}; };
......
...@@ -63,9 +63,9 @@ class GridSampleOp : public framework::OperatorWithKernel { ...@@ -63,9 +63,9 @@ class GridSampleOp : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(), ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_); framework::DataLayout::kAnyLayout, library_);
} }
}; };
...@@ -159,9 +159,9 @@ class GridSampleOpGrad : public framework::OperatorWithKernel { ...@@ -159,9 +159,9 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(), ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_); framework::DataLayout::kAnyLayout, library_);
} }
}; };
......
...@@ -141,8 +141,7 @@ class GroupNormGradOp : public framework::OperatorWithKernel { ...@@ -141,8 +141,7 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
if (t == nullptr) { if (t == nullptr) {
PADDLE_THROW("can't find Y@GRAD"); PADDLE_THROW("can't find Y@GRAD");
} }
return framework::OpKernelType(framework::ToDataType(t->type()), return framework::OpKernelType(t->type(), ctx.GetPlace());
ctx.GetPlace());
} }
}; };
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/gru_unit_op.h" #include "paddle/fluid/operators/gru_unit_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -76,9 +76,8 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel { ...@@ -76,9 +76,8 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), ctx.GetPlace());
ctx.GetPlace());
} }
}; };
...@@ -162,9 +161,8 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { ...@@ -162,9 +161,8 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), ctx.GetPlace());
ctx.GetPlace());
} }
}; };
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/hinge_loss_op.h" #include "paddle/fluid/operators/hinge_loss_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/huber_loss_op.h" #include "paddle/fluid/operators/huber_loss_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -11,8 +11,6 @@ ...@@ -11,8 +11,6 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/im2sequence_op.h" #include "paddle/fluid/operators/im2sequence_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -55,8 +55,8 @@ class InterpolateOp : public framework::OperatorWithKernel { ...@@ -55,8 +55,8 @@ class InterpolateOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace()); ctx.GetPlace());
} }
}; };
...@@ -124,8 +124,8 @@ class InterpolateOpGrad : public framework::OperatorWithKernel { ...@@ -124,8 +124,8 @@ class InterpolateOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -35,8 +35,7 @@ class IsEmptyOp : public framework::OperatorWithKernel { ...@@ -35,8 +35,7 @@ class IsEmptyOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = framework::OpKernelType( framework::OpKernelType kt = framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), ctx.Input<framework::LoDTensor>("X")->type(), platform::CPUPlace());
platform::CPUPlace());
return kt; return kt;
} }
}; };
......
...@@ -40,10 +40,9 @@ class OverflowOp : public framework::OperatorWithKernel { ...@@ -40,10 +40,9 @@ class OverflowOp : public framework::OperatorWithKernel {
int dtype = -1; int dtype = -1;
auto *x_var = ctx.InputVar("X"); auto *x_var = ctx.InputVar("X");
if (x_var->IsType<framework::LoDTensor>()) { if (x_var->IsType<framework::LoDTensor>()) {
dtype = framework::ToDataType(x_var->Get<framework::LoDTensor>().type()); dtype = x_var->Get<framework::LoDTensor>().type();
} else if (x_var->IsType<framework::SelectedRows>()) { } else if (x_var->IsType<framework::SelectedRows>()) {
dtype = framework::ToDataType( dtype = x_var->Get<framework::SelectedRows>().value().type();
x_var->Get<framework::SelectedRows>().value().type());
} else { } else {
PADDLE_THROW("Cannot find the input data type by all input data"); PADDLE_THROW("Cannot find the input data type by all input data");
} }
......
...@@ -11,8 +11,6 @@ ...@@ -11,8 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/isfinite_op.h" #include "paddle/fluid/operators/isfinite_op.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/l1_norm_op.h" #include "paddle/fluid/operators/l1_norm_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -153,8 +153,7 @@ class LayerNormGradOp : public framework::OperatorWithKernel { ...@@ -153,8 +153,7 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
if (t == nullptr) { if (t == nullptr) {
PADDLE_THROW("can't find Y@GRAD"); PADDLE_THROW("can't find Y@GRAD");
} }
return framework::OpKernelType(framework::ToDataType(t->type()), return framework::OpKernelType(t->type(), ctx.GetPlace());
ctx.GetPlace());
} }
}; };
......
...@@ -184,9 +184,8 @@ class LinearChainCRFOp : public framework::OperatorWithKernel { ...@@ -184,9 +184,8 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
// is determined by its input "Emission". // is determined by its input "Emission".
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<LoDTensor>("Emission")->type(),
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()), platform::CPUPlace());
platform::CPUPlace());
} }
}; };
...@@ -244,9 +243,7 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel { ...@@ -244,9 +243,7 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( ctx.Input<LoDTensor>(framework::GradVarName("LogLikelihood"))->type(),
ctx.Input<LoDTensor>(framework::GradVarName("LogLikelihood"))
->type()),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -60,7 +60,7 @@ class LoadCombineOp : public framework::OperatorBase { ...@@ -60,7 +60,7 @@ class LoadCombineOp : public framework::OperatorBase {
// Get data from fin to tensor // Get data from fin to tensor
DeserializeFromStream(fin, tensor, dev_ctx); DeserializeFromStream(fin, tensor, dev_ctx);
auto in_dtype = framework::ToDataType(tensor->type()); auto in_dtype = tensor->type();
auto out_dtype = auto out_dtype =
load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
......
...@@ -65,7 +65,7 @@ class LoadOp : public framework::OperatorBase { ...@@ -65,7 +65,7 @@ class LoadOp : public framework::OperatorBase {
DeserializeFromStream(fin, tensor, dev_ctx); DeserializeFromStream(fin, tensor, dev_ctx);
auto load_as_fp16 = Attr<bool>("load_as_fp16"); auto load_as_fp16 = Attr<bool>("load_as_fp16");
auto in_dtype = framework::ToDataType(tensor->type()); auto in_dtype = tensor->type();
auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) { if (in_dtype != out_dtype) {
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册