Unverified  Commit a607b6c8  authored by Qiyang Min, committed by GitHub

Merge pull request #15216 from velconia/local_release_1_2_x_tensor_type

Feature/tensor type
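
This diff retags tensors: Tensor::type() changes from returning std::type_index (an RTTI handle) to returning the protobuf enum proto::VarType::Type, and call sites drop the ToDataType()/ToTypeIndex() conversions. As a minimal sketch of the before/after shape of a call site, distilled from the hunks below (IsFloatTensor is a hypothetical helper, not part of the diff):

#include "paddle/fluid/framework/tensor.h"

namespace fw = paddle::framework;

bool IsFloatTensor(const fw::Tensor& t) {
  // Before this PR: t.type() was a std::type_index, so call sites wrote
  //   t.type() == typeid(float), or ToDataType(t.type()) == proto::VarType::FP32.
  // After this PR: t.type() is the proto enum itself, compared directly.
  return t.type() == fw::proto::VarType::FP32;
}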
@@ -86,6 +86,7 @@ endif(NOT WITH_GOLANG)
 if(WITH_GPU)
   add_definitions(-DPADDLE_WITH_CUDA)
+  add_definitions(-DEIGEN_USE_GPU)
   FIND_PACKAGE(CUDA REQUIRED)
......
@@ -85,7 +85,7 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
   out->mutable_data(expected_kernel_type.place_, in.type());
   framework::VisitDataType(
-      framework::ToDataType(in.type()),
+      in.type(),
       CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out));
   out->set_layout(expected_kernel_type.data_layout_);
@@ -101,7 +101,7 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
     case mkldnn::memory::data_type::f32:
       return platform::to_void_cast(tensor.data<float>());
     case mkldnn::memory::data_type::s8:
-      return platform::to_void_cast(tensor.data<char>());
+      return platform::to_void_cast(tensor.data<int8_t>());
     case mkldnn::memory::data_type::u8:
       return platform::to_void_cast(tensor.data<unsigned char>());
     case mkldnn::memory::data_type::s16:
@@ -144,7 +144,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
   memory::data_type in_type = ToMKLDNNDataType(in.type());
   PADDLE_ENFORCE(in_type != memory::data_type::data_undef,
-                 "Input tensor type is not supported: ", in.type().name());
+                 "Input tensor type is not supported: %s", in.type());
   memory::data_type out_type = in_type;
   auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
......
@@ -50,14 +50,14 @@ inline DataLayout ToPaddleLayout(const MKLDNNFormat& format) {
   }
 }
-inline MKLDNNDataType ToMKLDNNDataType(const std::type_index type) {
-  static const std::map<std::type_index, MKLDNNDataType> dict{
-      {std::type_index(typeid(float)), MKLDNNDataType::f32},  // NOLINT
-      {std::type_index(typeid(char)), MKLDNNDataType::s8},    // NOLINT
-      {std::type_index(typeid(unsigned char)), MKLDNNDataType::u8},
-      {std::type_index(typeid(int16_t)), MKLDNNDataType::s16},
-      {std::type_index(typeid(int32_t)), MKLDNNDataType::s32}};
-  auto iter = dict.find(type);
+inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
+  static std::unordered_map<int, MKLDNNDataType> dict{
+      {DataTypeTrait<float>::DataType, MKLDNNDataType::f32},
+      {DataTypeTrait<int8_t>::DataType, MKLDNNDataType::s8},
+      {DataTypeTrait<uint8_t>::DataType, MKLDNNDataType::u8},
+      {DataTypeTrait<int16_t>::DataType, MKLDNNDataType::s16},
+      {DataTypeTrait<int32_t>::DataType, MKLDNNDataType::s32}};
+  auto iter = dict.find(static_cast<int>(type));
   if (iter != dict.end()) return iter->second;
   return MKLDNNDataType::data_undef;
 }
......
@@ -26,7 +26,7 @@ struct DataTypeMap {
   std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_;
   std::unordered_map<int, std::type_index> proto_to_cpp_;
   std::unordered_map<int, std::string> proto_to_str_;
-  std::unordered_map<std::type_index, size_t> cpp_to_size_;
+  std::unordered_map<int, size_t> proto_to_size_;
 };
 static DataTypeMap* InitDataTypeMap();
@@ -45,7 +45,7 @@ static inline void RegisterType(DataTypeMap* map,
   map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T));
   map->cpp_to_proto_.emplace(typeid(T), proto_type);
   map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
-  map->cpp_to_size_.emplace(typeid(T), sizeof(T));
+  map->proto_to_size_.emplace(static_cast<int>(proto_type), sizeof(T));
 }
 static DataTypeMap* InitDataTypeMap() {
@@ -54,17 +54,7 @@ static DataTypeMap* InitDataTypeMap() {
 #define RegType(cc_type, proto_type) \
   RegisterType<cc_type>(retv, proto_type, #cc_type)
-  // NOTE: Add your customize type here.
-  RegType(float16, proto::VarType::FP16);
-  RegType(float, proto::VarType::FP32);
-  RegType(double, proto::VarType::FP64);
-  RegType(int, proto::VarType::INT32);
-  RegType(int64_t, proto::VarType::INT64);
-  RegType(bool, proto::VarType::BOOL);
-  RegType(size_t, proto::VarType::SIZE_T);
-  RegType(int16_t, proto::VarType::INT16);
-  RegType(uint8_t, proto::VarType::UINT8);
-  RegType(int8_t, proto::VarType::INT8);
+  _ForEachDataType_(RegType);
 #undef RegType
   return retv;
@@ -96,12 +86,12 @@ std::string DataTypeToString(const proto::VarType::Type type) {
                static_cast<int>(type));
 }
-size_t SizeOfType(std::type_index type) {
-  auto it = gDataTypeMap().cpp_to_size_.find(type);
-  if (it != gDataTypeMap().cpp_to_size_.end()) {
+size_t SizeOfType(proto::VarType::Type type) {
+  auto it = gDataTypeMap().proto_to_size_.find(static_cast<int>(type));
+  if (it != gDataTypeMap().proto_to_size_.end()) {
     return it->second;
   }
-  PADDLE_THROW("Not support %s as tensor type", type.name());
+  PADDLE_THROW("Not support %s as tensor type", DataTypeToString(type));
 }
 }  // namespace framework
......
@@ -22,46 +22,59 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
+template <typename T>
+struct DataTypeTrait {};
+
+// Stub handle for void
+template <>
+struct DataTypeTrait<void> {
+  constexpr static auto DataType = proto::VarType::RAW;
+};
+
+#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
+  callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
+
+#define _ForEachDataType_(callback)                                      \
+  _ForEachDataTypeHelper_(callback, float, FP32);                        \
+  _ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16);  \
+  _ForEachDataTypeHelper_(callback, double, FP64);                       \
+  _ForEachDataTypeHelper_(callback, int, INT32);                         \
+  _ForEachDataTypeHelper_(callback, int64_t, INT64);                     \
+  _ForEachDataTypeHelper_(callback, bool, BOOL);                         \
+  _ForEachDataTypeHelper_(callback, uint8_t, UINT8);                     \
+  _ForEachDataTypeHelper_(callback, int16_t, INT16);                     \
+  _ForEachDataTypeHelper_(callback, int8_t, INT8)
+
+#define DefineDataTypeTrait(cpp_type, proto_type) \
+  template <>                                     \
+  struct DataTypeTrait<cpp_type> {                \
+    constexpr static auto DataType = proto_type;  \
+  }
+
+_ForEachDataType_(DefineDataTypeTrait);
+
+#undef DefineDataTypeTrait
+
 extern proto::VarType::Type ToDataType(std::type_index type);
 extern std::type_index ToTypeIndex(proto::VarType::Type type);
 template <typename Visitor>
 inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
-  switch (type) {
-    case proto::VarType::FP16:
-      visitor.template apply<platform::float16>();
-      break;
-    case proto::VarType::FP32:
-      visitor.template apply<float>();
-      break;
-    case proto::VarType::FP64:
-      visitor.template apply<double>();
-      break;
-    case proto::VarType::INT32:
-      visitor.template apply<int>();
-      break;
-    case proto::VarType::INT64:
-      visitor.template apply<int64_t>();
-      break;
-    case proto::VarType::BOOL:
-      visitor.template apply<bool>();
-      break;
-    case proto::VarType::UINT8:
-      visitor.template apply<uint8_t>();
-      break;
-    case proto::VarType::INT16:
-      visitor.template apply<int16_t>();
-      break;
-    case proto::VarType::INT8:
-      visitor.template apply<int8_t>();
-      break;
-    default:
-      PADDLE_THROW("Not supported %d", type);
-  }
+#define VisitDataTypeCallback(cpp_type, proto_type) \
+  do {                                              \
+    if (type == proto_type) {                       \
+      visitor.template apply<cpp_type>();           \
+      return;                                       \
+    }                                               \
+  } while (0)
+
+  _ForEachDataType_(VisitDataTypeCallback);
+#undef VisitDataTypeCallback
+  PADDLE_THROW("Not supported %d", type);
 }
 extern std::string DataTypeToString(const proto::VarType::Type type);
-extern size_t SizeOfType(std::type_index type);
+extern size_t SizeOfType(proto::VarType::Type type);
 inline std::ostream& operator<<(std::ostream& out,
                                 const proto::VarType::Type& type) {
   out << DataTypeToString(type);
......
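The data_type.h hunk above is the core of the change: DataTypeTrait<T> gives a compile-time mapping from a C++ type to its proto enum, and the _ForEachDataType_ X-macro both generates the trait specializations and drives VisitDataType. A minimal usage sketch under those definitions (SizeVisitor and SizeViaVisit are hypothetical names, not part of the diff; note VisitDataType takes the visitor by value, so results must be written back through a pointer):

#include <cstddef>
#include "paddle/fluid/framework/data_type.h"

namespace fw = paddle::framework;

// Compile-time mapping: DataTypeTrait<T>::DataType is the proto enum for T.
static_assert(fw::DataTypeTrait<float>::DataType == fw::proto::VarType::FP32,
              "float maps to FP32");

// Runtime dispatch: the macro expansion tries each registered dtype in turn
// and calls visitor.apply<T>() for the one that matches.
struct SizeVisitor {
  size_t* out;
  template <typename T>
  void apply() const { *out = sizeof(T); }
};

size_t SizeViaVisit(fw::proto::VarType::Type type) {
  size_t size = 0;
  fw::VisitDataType(type, SizeVisitor{&size});
  return size;  // e.g. SizeViaVisit(fw::proto::VarType::FP16) == 2
}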
@@ -26,15 +26,15 @@ TEST(DataType, float16) {
   Tensor tensor;
   CPUPlace cpu;
-  tensor.mutable_data(cpu, f::ToTypeIndex(dtype));
+  tensor.mutable_data(cpu, dtype);
   // test fp16 tensor
-  EXPECT_EQ(tensor.type(), std::type_index(typeid(float16)));
+  EXPECT_EQ(tensor.type(), f::ToDataType(typeid(float16)));
   // test fp16 size
-  EXPECT_EQ(f::SizeOfType(f::ToTypeIndex(dtype)), 2u);
+  EXPECT_EQ(f::SizeOfType(dtype), 2u);
   // test debug info
-  std::string type = "float16";
+  std::string type = "::paddle::platform::float16";
   EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
 }
@@ -120,7 +120,7 @@ void AllReduceOpHandle::RunImpl() {
       // Reduce All Tensor to trg in CPU
       ReduceLoDTensor func(lod_tensors, &trg);
-      VisitDataType(ToDataType(lod_tensors[0]->type()), func);
+      VisitDataType(lod_tensors[0]->type(), func);
       for (size_t i = 1; i < local_scopes_.size(); ++i) {
         auto &scope =
......
@@ -33,7 +33,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
   FuseVarsOpHandle(ir::Node *node, Scope *local_scope,
                    const platform::Place &place,
                    const std::unordered_map<std::string, int64_t> &inputs_numel,
-                   const std::type_index &var_type)
+                   const proto::VarType::Type var_type)
       : OpHandleBase(node),
         local_scope_(local_scope),
         place_(place),
@@ -57,7 +57,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
   Scope *local_scope_;
   const platform::Place place_;
   const std::unordered_map<std::string, int64_t> inputs_numel_;
-  const std::type_index type_;
+  const proto::VarType::Type type_;
   int64_t total_numel_;
 };
 }  // namespace details
......
@@ -106,7 +106,7 @@ void ReduceOpHandle::RunImpl() {
     if (!FLAGS_cpu_deterministic) {
       ReduceLoDTensor func(lod_tensors,
                            out_var->GetMutable<framework::LoDTensor>());
-      VisitDataType(ToDataType(lod_tensors[0]->type()), func);
+      VisitDataType(lod_tensors[0]->type(), func);
     } else {
       // We sum lod_tensors to reduce_sum_trg which is in local_scopes_0
       // here, but it doesn't mean reduce_sum_trg must be in local_scopes_0.
@@ -116,7 +116,7 @@ void ReduceOpHandle::RunImpl() {
                             ->FindVar(out_var_handle->name_)
                             ->GetMutable<framework::LoDTensor>();
       ReduceLoDTensor func(lod_tensors, &reduce_sum_trg);
-      VisitDataType(ToDataType(lod_tensors[0]->type()), func);
+      VisitDataType(lod_tensors[0]->type(), func);
       auto trg = out_var->GetMutable<framework::LoDTensor>();
       if (reduce_sum_trg.data<void>() != trg->data<void>()) {
......
@@ -13,7 +13,7 @@
 // limitations under the License.
 #include "paddle/fluid/framework/dlpack_tensor.h"
+#include "paddle/fluid/framework/data_type.h"
 namespace paddle {
 namespace framework {
@@ -36,26 +36,23 @@ static ::DLDataType GetDLDataTypeCode() {
   return dtype;
 }
-static DLDataType GetDLDataTypeFromTypeIndex(const std::type_index &type) {
-#define REG_DL_DATA_TYPE(type) \
-  { std::type_index(typeid(type)), GetDLDataTypeCode<type>() }
-  static const std::unordered_map<std::type_index, ::DLDataType>
-      type_to_dtype_map({
-          REG_DL_DATA_TYPE(platform::float16),  // NOLINT
-          REG_DL_DATA_TYPE(float),              // NOLINT
-          REG_DL_DATA_TYPE(double),             // NOLINT
-          REG_DL_DATA_TYPE(int),                // NOLINT
-          REG_DL_DATA_TYPE(int64_t),            // NOLINT
-          REG_DL_DATA_TYPE(bool),               // NOLINT
-          REG_DL_DATA_TYPE(size_t),             // NOLINT
-          REG_DL_DATA_TYPE(int16_t),            // NOLINT
-          REG_DL_DATA_TYPE(uint8_t),            // NOLINT
-          REG_DL_DATA_TYPE(int8_t)              // NOLINT
-      });
+static std::unordered_map<int, ::DLDataType> CreateDLDataTypeMap() {
+  static std::unordered_map<int, ::DLDataType> result;
+
+#define REG_DL_DATA_TYPE(cpp_type, proto_type) \
+  result[static_cast<int>(proto_type)] = GetDLDataTypeCode<cpp_type>()
+
+  _ForEachDataType_(REG_DL_DATA_TYPE);
+#undef REG_DL_DATA_TYPE
+  return result;
+}
+
+static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
+  static auto type_to_dtype_map = CreateDLDataTypeMap();
   static auto type_to_dtype_map_end_it = type_to_dtype_map.end();
-  auto it = type_to_dtype_map.find(type);
-  PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %s",
-                 type.name());
+  auto it = type_to_dtype_map.find(static_cast<int>(type));
+  PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %d",
+                 type);
   return it->second;
 #undef REG_DL_DATA_TYPE
 }
......
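CreateDLDataTypeMap above shows the recurring replacement pattern in this PR: a hand-written typeid-keyed table becomes an int-keyed table populated once through _ForEachDataType_, so the table can no longer drift out of sync with the registered dtypes. Keying on int via static_cast, rather than on the enum itself, plausibly sidesteps the lack of a std::hash specialization for enum types in pre-C++14 standard libraries. A hedged sketch of the same idiom with hypothetical names (CreateNameMap/NameOf):

#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/data_type.h"

namespace fw = paddle::framework;

// Build a proto-type -> name table once, reusing _ForEachDataType_.
static std::unordered_map<int, std::string> CreateNameMap() {
  std::unordered_map<int, std::string> result;
#define REG_NAME(cpp_type, proto_type) \
  result[static_cast<int>(proto_type)] = #cpp_type
  _ForEachDataType_(REG_NAME);
#undef REG_NAME
  return result;
}

const std::string& NameOf(fw::proto::VarType::Type type) {
  static auto map = CreateNameMap();  // initialized exactly once
  return map.at(static_cast<int>(type));
}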
@@ -91,23 +91,11 @@ void TestMainLoop() {
     }
   }
 }
-#define PADDLE_DLPACK_TEST(type) \
-  TEST(dlpack, test_##type) { TestMainLoop<type>(); }
-
-using float16 = platform::float16;
-PADDLE_DLPACK_TEST(float16);
-PADDLE_DLPACK_TEST(float);
-PADDLE_DLPACK_TEST(double);
-PADDLE_DLPACK_TEST(int);
-PADDLE_DLPACK_TEST(int64_t);
-PADDLE_DLPACK_TEST(bool);
-PADDLE_DLPACK_TEST(size_t);
-PADDLE_DLPACK_TEST(int16_t);
-PADDLE_DLPACK_TEST(uint8_t);
-PADDLE_DLPACK_TEST(int8_t);
-#undef PADDLE_DLPACK_TEST
+TEST(dlpack, test_all) {
+#define TestCallback(cpp_type, proto_type) TestMainLoop<cpp_type>()
+  _ForEachDataType_(TestCallback);
+}
 }  // namespace framework
 }  // namespace paddle
@@ -138,39 +138,19 @@ void print_lod_tensor(std::string var_name, const LoDTensor& lod_tensor) {
   std::cout << sstream.str() << std::endl;
 }
-void print_fetch_var(Scope* scope, std::string var_name) {
-  const LoDTensor& tensor = scope->FindVar(var_name)->Get<LoDTensor>();
-
-  if (std::type_index(tensor.type()) ==
-      std::type_index(typeid(platform::float16))) {
-    print_lod_tensor<platform::float16>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) == std::type_index(typeid(float))) {
-    print_lod_tensor<float>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) ==
-             std::type_index(typeid(double))) {
-    print_lod_tensor<double>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) == std::type_index(typeid(int))) {
-    print_lod_tensor<int>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) ==
-             std::type_index(typeid(int64_t))) {
-    print_lod_tensor<int64_t>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) == std::type_index(typeid(bool))) {
-    print_lod_tensor<bool>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) ==
-             std::type_index(typeid(uint8_t))) {
-    print_lod_tensor<uint8_t>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) ==
-             std::type_index(typeid(int16_t))) {
-    print_lod_tensor<int16_t>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) ==
-             std::type_index(typeid(int8_t))) {
-    print_lod_tensor<int8_t>(var_name, tensor);
-  } else {
-    VLOG(1) << "print_fetch_var: unrecognized data type:"
-            << tensor.type().name();
-  }
-  return;
+static void print_fetch_var(Scope* scope, const std::string& var_name) {
+  auto& tensor = scope->FindVar(var_name)->Get<LoDTensor>();
+
+#define PrintLoDTensorCallback(cpp_type, proto_type) \
+  do {                                               \
+    if (tensor.type() == proto_type) {               \
+      print_lod_tensor<cpp_type>(var_name, tensor);  \
+      return;                                        \
+    }                                                \
+  } while (0)
+
+  _ForEachDataType_(PrintLoDTensorCallback);
+  VLOG(1) << "print_fetch_var: unrecognized data type:" << tensor.type();
 }
 void ExecutorThreadWorker::TrainFiles() {
......
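PrintLoDTensorCallback, like VisitDataTypeCallback in data_type.h, wraps its body in do { ... } while (0). That is the standard trick for making a multi-statement macro expand to exactly one statement; a minimal sketch outside Paddle (log_prefix/log_value are hypothetical):

void log_prefix();
void log_value(int v);

// Without the wrapper, `if (c) LOG_BOTH(x); else ...` would splice two
// statements into the if-branch and detach the else. With it, the expansion
// is a single statement and the trailing semicolon parses cleanly.
#define LOG_BOTH(x) \
  do {              \
    log_prefix();   \
    log_value(x);   \
  } while (0)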
@@ -70,9 +70,9 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
   // only print first ten elements
   int64_t size = t.numel() < 10 ? t.numel() : 10;
   for (int64_t i = 0; i < size; ++i) {
-    if (IsType<float>(t.type())) {
+    if (t.type() == proto::VarType::FP32) {
       os << t.data<float>()[i] << " ";
-    } else if (IsType<int64_t>(t.type())) {
+    } else if (t.type() == proto::VarType::INT64) {
       os << t.data<int64_t>()[i] << " ";
     } else {
       PADDLE_THROW("LoDTensor data type not in [float, int64_t]");
@@ -387,7 +387,7 @@ void LoDTensor::MergeLoDTensor(
   PADDLE_ENFORCE(!lod_tensors.empty());
   framework::DDim new_dim = lod_tensors[0]->dims();
-  std::type_index new_type = lod_tensors[0]->type();
+  auto new_type = lod_tensors[0]->type();
   framework::DataLayout new_layout = lod_tensors[0]->layout();
   LoD new_lod = lod_tensors[0]->lod();
   for (size_t i = 1; i < lod_tensors.size(); ++i) {
......
@@ -34,7 +34,8 @@ TEST(OpKernelType, ToString) {
   OpKernelType op_kernel_type2(DataType::FP16, CUDAPlace(0), DataLayout::kNCHW,
                                LibraryType::kCUDNN);
   ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type2),
-            "data_type[float16]:data_layout[NCHW]:place[CUDAPlace(0)]:library_"
+            "data_type[::paddle::platform::float16]:data_layout[NCHW]:place["
+            "CUDAPlace(0)]:library_"
             "type[CUDNN]");
 }
......
@@ -43,10 +43,9 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
 proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
   if (var->IsType<framework::LoDTensor>()) {
-    return framework::ToDataType(var->Get<framework::LoDTensor>().type());
+    return var->Get<framework::LoDTensor>().type();
   } else if (var->IsType<framework::SelectedRows>()) {
-    return framework::ToDataType(
-        var->Get<framework::SelectedRows>().value().type());
+    return var->Get<framework::SelectedRows>().value().type();
   } else {
     PADDLE_THROW("Var should be LoDTensor or SelectedRows");
   }
@@ -93,13 +92,13 @@ static std::string GetDtype(const Scope& scope, const std::string& name) {
     if (UNLIKELY(!tensor.IsInitialized())) {
       return "";
     }
-    return DataTypeToString(ToDataType(tensor.type()));
+    return DataTypeToString(tensor.type());
   } else if (var->IsType<SelectedRows>()) {
     auto tensor = var->Get<SelectedRows>().value();
     if (UNLIKELY(!tensor.IsInitialized())) {
       return "uninited";
     } else {
-      return DataTypeToString(ToDataType(tensor.type()));
+      return DataTypeToString(tensor.type());
     }
   } else {
     return "";
@@ -686,7 +685,8 @@ static void CheckTensorNANOrInf(const std::string& name,
   if (tensor.memory_size() == 0) {
     return;
   }
-  if (!IsType<float>(tensor.type()) && !IsType<double>(tensor.type())) {
+  if (tensor.type() != proto::VarType::FP32 &&
+      tensor.type() != proto::VarType::FP64) {
     return;
   }
   PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
@@ -873,7 +873,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
       t = &(var->Get<SelectedRows>().value());
     }
     if (t != nullptr) {
-      int tmp = static_cast<int>(ToDataType(t->type()));
+      int tmp = static_cast<int>(t->type());
       PADDLE_ENFORCE(
           tmp == data_type || data_type == -1,
           "DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)",
......
@@ -218,11 +218,11 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
       if (index < 0) {
         VLOG(5) << "id " << id << " not in the table, return 0";
         framework::VisitDataType(
-            framework::ToDataType(value_->type()),
+            value_->type(),
             TensorFillVisitor(value, i * value_width, value_width, 0.0));
       } else {
         framework::VisitDataType(
-            framework::ToDataType(value_->type()),
+            value_->type(),
             TensorCopyVisitor(value, i * value_width, *value_.get(),
                               index * value_width, value_width));
       }
......
@@ -16,7 +16,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
-extern size_t SizeOfType(std::type_index type);
+extern size_t SizeOfType(proto::VarType::Type type);
 void Tensor::check_memory_size() const {
   PADDLE_ENFORCE_NOT_NULL(
       holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
@@ -31,7 +31,7 @@ size_t Tensor::memory_size() const {
   return holder_ == nullptr ? 0UL : holder_->size() - offset_;
 }
-void* Tensor::mutable_data(platform::Place place, std::type_index type,
+void* Tensor::mutable_data(platform::Place place, proto::VarType::Type type,
                            memory::Allocator::Attr attr,
                            size_t requested_size) {
   type_ = type;
......
@@ -19,9 +19,9 @@ limitations under the License. */
 #include <memory>
 #include <typeindex>
 #include <vector>
 #include "paddle/fluid/framework/data_layout.h"
 #include "paddle/fluid/framework/ddim.h"
+#include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/memory/memory.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -67,7 +67,7 @@ class Tensor {
   friend struct EigenVector;
  public:
-  Tensor() : type_(typeid(float)), offset_(0) {}
+  Tensor() : type_(proto::VarType::FP32), offset_(0) {}
   /*! Return a pointer to mutable memory block. */
   template <typename T>
@@ -88,7 +88,7 @@ class Tensor {
                    memory::Allocator::Attr attr = memory::Allocator::kDefault,
                    size_t requested_size = 0);
-  void* mutable_data(platform::Place place, std::type_index type,
+  void* mutable_data(platform::Place place, proto::VarType::Type type,
                      memory::Allocator::Attr attr = memory::Allocator::kDefault,
                      size_t requested_size = 0);
@@ -138,7 +138,7 @@ class Tensor {
     return holder_->place();
   }
-  std::type_index type() const {
+  proto::VarType::Type type() const {
     PADDLE_ENFORCE_NOT_NULL(
         holder_, "Tensor not initialized yet when Tensor::type() is called.");
     return type_;
@@ -161,7 +161,7 @@ class Tensor {
  private:
   /*! holds the memory block if allocated. */
   std::shared_ptr<memory::Allocation> holder_;
-  std::type_index type_;
+  proto::VarType::Type type_;
   /**
    * @brief points to elements dimensions.
   *
......
@@ -24,9 +24,8 @@ template <typename T>
 inline const T* Tensor::data() const {
   check_memory_size();
   bool valid =
-      std::is_same<T, void>::value || type_ == std::type_index(typeid(T));
-  PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s",
-                 type_.name());
+      std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
+  PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %d", type_);
   return reinterpret_cast<const T*>(
       reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
@@ -38,9 +37,8 @@ template <typename T>
 inline T* Tensor::data() {
   check_memory_size();
   bool valid =
-      std::is_same<T, void>::value || type_ == std::type_index(typeid(T));
-  PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s",
-                 type_.name());
+      std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
+  PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", type_);
   return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
                               offset_);
 }
@@ -60,7 +58,7 @@ inline T* Tensor::mutable_data(platform::Place place,
                                size_t requested_size) {
   static_assert(std::is_pod<T>::value, "T must be POD");
   return reinterpret_cast<T*>(
-      mutable_data(place, typeid(T), attr, requested_size));
+      mutable_data(place, DataTypeTrait<T>::DataType, attr, requested_size));
 }
 inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
......
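With the tensor_impl.h hunk above, the validity check in Tensor::data<T>() becomes a plain enum comparison against DataTypeTrait<T>::DataType, and mutable_data<T> forwards the same enum instead of typeid(T). A small sketch of the resulting round trip (Fill is a hypothetical helper, not part of the diff):

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/place.h"

namespace fw = paddle::framework;

void Fill(fw::Tensor* t) {
  t->Resize(fw::make_ddim({2, 3}));
  // mutable_data<float> now records proto::VarType::FP32 as the tensor's type_.
  float* p = t->mutable_data<float>(paddle::platform::CPUPlace());
  for (int64_t i = 0; i < t->numel(); ++i) p[i] = 0.f;
  // t->type() == fw::proto::VarType::FP32 here; a later t->data<double>()
  // would fail the DataTypeTrait<double>::DataType check and throw.
}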
@@ -186,8 +186,8 @@ struct AnyDTypeVisitor {
 template <typename Predicate, typename DevCtx>
 inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor,
                     const DevCtx& ctx, framework::Tensor* out) {
-  VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor<Predicate, DevCtx>(
-                                               predicate, tensor, ctx, out));
+  VisitDataType(tensor.type(), AnyDTypeVisitor<Predicate, DevCtx>(
+                                   predicate, tensor, ctx, out));
 }
 template <typename Predicate>
@@ -379,7 +379,7 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
     // int32_t  size
     // void*    protobuf message
     proto::VarType::TensorDesc desc;
-    desc.set_data_type(framework::ToDataType(tensor.type()));
+    desc.set_data_type(tensor.type());
     auto dims = framework::vectorize(tensor.dims());
     auto* pb_dims = desc.mutable_dims();
     pb_dims->Resize(static_cast<int>(dims.size()), 0);
@@ -461,9 +461,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
     tensor->Resize(framework::make_ddim(dims));
     void* buf;
     auto ctx = platform::CPUDeviceContext();
-    size_t size =
-        tensor->numel() *
-        framework::SizeOfType(framework::ToTypeIndex(desc.data_type()));
+    size_t size = tensor->numel() * framework::SizeOfType(desc.data_type());
     if (platform::is_gpu_place(dev_ctx.GetPlace())) {
 #ifdef PADDLE_WITH_CUDA
       Tensor cpu_tensor;
......
@@ -289,10 +289,10 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
     auto type = fetch.type();
     auto output = &(outputs->at(i));
     output->name = fetchs_[idx]->Input("X")[0];
-    if (type == typeid(float)) {
+    if (type == framework::proto::VarType::FP32) {
       GetFetchOne<float>(fetch, output);
       output->dtype = PaddleDType::FLOAT32;
-    } else if (type == typeid(int64_t)) {
+    } else if (type == framework::proto::VarType::INT64) {
       GetFetchOne<int64_t>(fetch, output);
       output->dtype = PaddleDType::INT64;
     } else {
......
@@ -266,10 +266,10 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
     auto type = fetch.type();
     auto output = &(outputs->at(i));
     output->name = fetchs_[idx]->Input("X")[0];
-    if (type == typeid(float)) {
+    if (type == framework::DataTypeTrait<float>::DataType) {
      GetFetchOne<float>(fetch, output);
       output->dtype = PaddleDType::FLOAT32;
-    } else if (type == typeid(int64_t)) {
+    } else if (type == framework::DataTypeTrait<int64_t>::DataType) {
       GetFetchOne<int64_t>(fetch, output);
       output->dtype = PaddleDType::INT64;
     } else {
......
@@ -36,10 +36,10 @@ namespace paddle {
 PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) {
   PaddleTensor pt;
-  if (t->type() == typeid(int64_t)) {
+  if (t->type() == framework::proto::VarType::INT64) {
     pt.data.Reset(t->data<void>(), t->numel() * sizeof(int64_t));
     pt.dtype = PaddleDType::INT64;
-  } else if (t->type() == typeid(float)) {
+  } else if (t->type() == framework::proto::VarType::FP32) {
     pt.data.Reset(t->data<void>(), t->numel() * sizeof(float));
     pt.dtype = PaddleDType::FLOAT32;
   } else {
......
@@ -361,7 +361,7 @@ static bool CompareTensorData(const framework::LoDTensor &a,
   }
   for (size_t i = 0; i < a_size; i++) {
-    if (a.type() == typeid(float)) {
+    if (a.type() == framework::proto::VarType::FP32) {
       const auto *a_data = a.data<float>();
       const auto *b_data = b.data<float>();
       if (std::abs(a_data[i] - b_data[i]) > 1e-3) {
@@ -370,7 +370,7 @@ static bool CompareTensorData(const framework::LoDTensor &a,
                 b_data[i]);
         return false;
       }
-    } else if (a.type() == typeid(int64_t)) {
+    } else if (a.type() == framework::proto::VarType::INT64) {
       const auto *a_data = a.data<int64_t>();
       const auto *b_data = b.data<int64_t>();
       if (std::abs(a_data[i] - b_data[i]) > 1e-3) {
......
@@ -78,7 +78,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
       library = framework::LibraryType::kCUDNN;
     }
 #endif
-    auto data_type = framework::ToDataType(ctx.Input<Tensor>("Theta")->type());
+    auto data_type = ctx.Input<Tensor>("Theta")->type();
     return framework::OpKernelType(data_type, ctx.GetPlace(),
                                    framework::DataLayout::kAnyLayout, library);
   }
@@ -188,9 +188,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
       library_ = framework::LibraryType::kCUDNN;
     }
 #endif
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Theta")->type()),
-        ctx.GetPlace(), framework::DataLayout::kAnyLayout, library_);
+    return framework::OpKernelType(ctx.Input<Tensor>("Theta")->type(),
+                                   ctx.GetPlace(),
+                                   framework::DataLayout::kAnyLayout, library_);
   }
 };
......
@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL(
                                     int32_t>,
     paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
                                     int16_t>,
-    paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, size_t>,
     paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
                                     uint8_t>);
@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL(
                                     int32_t>,
     paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
                                     int16_t>,
-    paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
-                                    size_t>,
     paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
                                     uint8_t>);
@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL(
                                     int32_t>,
     paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
                                     int16_t>,
-    paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, size_t>,
     paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
                                     uint8_t>);
@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL(
                                     int32_t>,
     paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
                                     int16_t>,
-    paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
-                                    size_t>,
     paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
                                     uint8_t>);
@@ -58,7 +58,7 @@ struct ArrayToLoDFunctor : public boost::static_visitor<void> {
     ArrayToLoDFunctorImpl<DeviceContext> functor;
     functor.dev_ctx_ = dev_ctx;
     functor.prev_functor_ = this;
-    framework::VisitDataType(framework::ToDataType(out->type()), functor);
+    framework::VisitDataType(out->type(), functor);
   }
 };
@@ -91,7 +91,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
     PADDLE_ENFORCE(!x.empty(), "There's no element in the input array.");
     int rank = x[0].dims().size();
     platform::Place place = x[0].place();
-    std::type_index data_type = x[0].type();
+    auto data_type = x[0].type();
     int64_t batch_size = x[0].dims()[0];
     framework::DDim ins_dims = rank > 1
                                    ? framework::slice_ddim(x[0].dims(), 1, rank)
......
@@ -121,9 +121,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
 framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
-      ctx.device_context());
+  return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                 ctx.device_context());
 }
 void AttentionLSTMOpMaker::Make() {
......
@@ -103,9 +103,8 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("param")->type()),
-        ctx.GetPlace());
+    return framework::OpKernelType(ctx.Input<Tensor>("param")->type(),
+                                   ctx.GetPlace());
   }
 };
......
@@ -72,8 +72,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    auto input_data_type =
-        framework::ToDataType(ctx.Input<Tensor>("X")->type());
+    auto input_data_type = ctx.Input<Tensor>("X")->type();
     // By default, the type of the scale, bias, mean,
     // and var tensors should both be float. (For float or float16 input tensor)
     // or double (For double input tensor).
@@ -81,17 +80,13 @@ class BatchNormOp : public framework::OperatorWithKernel {
     if (input_data_type == framework::proto::VarType::FP64) {
       bn_param_type = framework::proto::VarType::FP64;
     }
-    PADDLE_ENFORCE_EQ(bn_param_type,
-                      framework::ToDataType(ctx.Input<Tensor>("Scale")->type()),
+    PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Scale")->type(),
                       "Scale input should be of float type");
-    PADDLE_ENFORCE_EQ(bn_param_type,
-                      framework::ToDataType(ctx.Input<Tensor>("Bias")->type()),
+    PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Bias")->type(),
                       "Bias input should be of float type");
-    PADDLE_ENFORCE_EQ(bn_param_type,
-                      framework::ToDataType(ctx.Input<Tensor>("Mean")->type()),
+    PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Mean")->type(),
                       "Mean input should be of float type");
-    PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType(
-                                         ctx.Input<Tensor>("Variance")->type()),
+    PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Variance")->type(),
                       "Variance input should be of float type");
     // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
@@ -387,9 +382,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
     }
 #endif
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
-        layout, library);
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.GetPlace(), layout, library);
   }
 };
......
@@ -145,7 +145,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
     LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores");
     framework::VisitDataType(
-        framework::ToDataType(scores->at(0).type()),
+        scores->at(0).type(),
         BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores,
                                 beam_size, end_id));
   }
......
@@ -282,8 +282,7 @@ class BeamSearchOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     framework::OpKernelType kt = framework::OpKernelType(
-        framework::ToDataType(
-            ctx.Input<framework::LoDTensor>("pre_ids")->type()),
+        ctx.Input<framework::LoDTensor>("pre_ids")->type(),
         platform::CPUPlace());
     return kt;
   }
......
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#define EIGEN_USE_GPU
 #include "paddle/fluid/operators/bilinear_tensor_product_op.h"
 namespace ops = paddle::operators;
......
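The deleted #define EIGEN_USE_GPU here (and in the cos_sim_op.cu and crop_op.cu hunks below) pairs with the CMake hunk at the top of this diff: add_definitions(-DEIGEN_USE_GPU) now defines the macro globally for WITH_GPU builds, so the per-file definitions become redundant.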
@@ -48,13 +48,12 @@ class ConditionalOp : public framework::OperatorBase {
     if (!(ips.size() == 1UL && ips[0]->IsInitialized())) {
       PADDLE_THROW("should have one initialized input as condition");
     }
-    if (!(framework::IsType<bool>(ips[0]->type()) &&  // NOLINT
-          ips[0]->numel() == 1)) {
-      PADDLE_THROW(
-          "condition input's data type should be bool, "
-          "numel should be 1, actual numel is %d",
-          ips[0]->numel());
-    }
+
+    PADDLE_ENFORCE(ips[0]->type() == framework::proto::VarType::BOOL &&
+                       ips[0]->numel() == 1,
+                   "condition input's data type should be bool, "
+                   "numel should be 1, actual numel is %d",
+                   ips[0]->numel());
     bool res = false;
     if (platform::is_gpu_place(ips[0]->place())) {
 #ifdef PADDLE_WITH_CUDA
......
@@ -237,7 +237,7 @@ class WhileGradOp : public framework::OperatorBase {
         if (var->IsType<LoDTensor>()) {
           auto &inside_tensor = var->Get<framework::LoDTensor>();
           framework::AttributeMap attrs;
-          attrs["dtype"] = framework::ToDataType(inside_tensor.type());
+          attrs["dtype"] = inside_tensor.type();
           attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
           attrs["value"] = 0.0f;
......
@@ -92,10 +92,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
   }
 #endif
-  auto input_data_type =
-      framework::ToDataType(ctx.Input<Tensor>("Input")->type());
-  auto filter_data_type =
-      framework::ToDataType(ctx.Input<Tensor>("Filter")->type());
+  auto input_data_type = ctx.Input<Tensor>("Input")->type();
+  auto filter_data_type = ctx.Input<Tensor>("Filter")->type();
   PADDLE_ENFORCE_EQ(input_data_type, filter_data_type,
                     "input and filter data type should be consistent");
@@ -360,9 +358,8 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
   }
 #endif
-  return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
-      layout_, library_);
+  return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
+                                 ctx.GetPlace(), layout_, library_);
 }
 }  // namespace operators
......
@@ -95,9 +95,8 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
   std::string data_format = ctx.Attr<std::string>("data_format");
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
-  return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
-      layout_, library_);
+  return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
+                                 ctx.GetPlace(), layout_, library_);
 }
 void Conv2DTransposeOpMaker::Make() {
@@ -314,9 +313,8 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
   std::string data_format = ctx.Attr<std::string>("data_format");
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
-  return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
-      layout_, library_);
+  return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
+                                 ctx.GetPlace(), layout_, library_);
 }
 }  // namespace operators
......
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#define EIGEN_USE_GPU
-
 #include "paddle/fluid/operators/cos_sim_op.h"
 namespace ops = paddle::operators;
......
@@ -118,9 +118,8 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
-        platform::CPUPlace());
+    return framework::OpKernelType(ctx.Input<LoDTensor>("Emission")->type(),
+                                   platform::CPUPlace());
   }
 };
 }  // namespace operators
......
@@ -51,9 +51,8 @@ class CropOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.device_context());
   }
 };
@@ -174,9 +173,7 @@ class CropOpGrad : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
-        framework::ToDataType(
-            ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))
-                ->type()),
+        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
         ctx.device_context());
   }
 };
......
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#define EIGEN_USE_GPU
-
 #include "paddle/fluid/operators/crop_op.h"
 namespace ops = paddle::operators;
......
@@ -57,9 +57,8 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
   // is determined by its input "X".
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.device_context());
   }
 };
@@ -111,9 +110,8 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
   // is determined by its input "X".
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.device_context());
   }
 };
......
@@ -36,9 +36,8 @@ class CTCAlignOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
+                                   ctx.device_context());
   }
 };
......
@@ -53,8 +53,7 @@ class AnchorGeneratorOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
-        ctx.device_context());
+        ctx.Input<framework::Tensor>("Input")->type(), ctx.device_context());
   }
 };
......
...@@ -45,9 +45,8 @@ class BipartiteMatchOp : public framework::OperatorWithKernel { ...@@ -45,9 +45,8 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<LoDTensor>("DistMat")->type()),
-        platform::CPUPlace());
+    return framework::OpKernelType(ctx.Input<LoDTensor>("DistMat")->type(),
+                                   platform::CPUPlace());
} }
}; };
......
...@@ -66,8 +66,7 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel { ...@@ -66,8 +66,7 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
-        ctx.GetPlace());
+        ctx.Input<framework::Tensor>("Input")->type(), ctx.GetPlace());
} }
}; };
......
...@@ -66,9 +66,8 @@ class GenerateProposalsOp : public framework::OperatorWithKernel { ...@@ -66,9 +66,8 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Anchors")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<Tensor>("Anchors")->type(),
+                                   ctx.device_context());
} }
}; };
......
...@@ -249,8 +249,7 @@ class MineHardExamplesOp : public framework::OperatorWithKernel { ...@@ -249,8 +249,7 @@ class MineHardExamplesOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("ClsLoss")->type()),
-        platform::CPUPlace());
+        ctx.Input<framework::Tensor>("ClsLoss")->type(), platform::CPUPlace());
} }
}; };
......
...@@ -65,8 +65,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { ...@@ -65,8 +65,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
-        framework::ToDataType(
-            ctx.Input<framework::LoDTensor>("Scores")->type()),
+        ctx.Input<framework::LoDTensor>("Scores")->type(),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -72,8 +72,7 @@ class PriorBoxOp : public framework::OperatorWithKernel { ...@@ -72,8 +72,7 @@ class PriorBoxOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
-        ctx.device_context());
+        ctx.Input<framework::Tensor>("Input")->type(), ctx.device_context());
} }
}; };
......
...@@ -498,9 +498,8 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel { ...@@ -498,9 +498,8 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
+                                   ctx.device_context());
} }
}; };
...@@ -519,9 +518,8 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel { ...@@ -519,9 +518,8 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
+                                   ctx.device_context());
} }
}; };
......
...@@ -78,8 +78,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { ...@@ -78,8 +78,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
-        framework::ToDataType(
-            ctx.Input<framework::LoDTensor>("Anchor")->type()),
+        ctx.Input<framework::LoDTensor>("Anchor")->type(),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -57,9 +57,8 @@ class TargetAssignOp : public framework::OperatorWithKernel { ...@@ -57,9 +57,8 @@ class TargetAssignOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.device_context());
} }
}; };
......
...@@ -71,8 +71,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel { ...@@ -71,8 +71,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
-        framework::ToDataType(
-            ctx.Input<framework::Tensor>("DetectRes")->type()),
+        ctx.Input<framework::Tensor>("DetectRes")->type(),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -122,8 +122,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -122,8 +122,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
if (var->IsType<framework::SelectedRows>()) { if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>(); auto* slr = var->GetMutable<framework::SelectedRows>();
ProtoEncodeHelper e2(static_cast<char*>(buf), 128); ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
-      size_t rows_memory_size =
-          slr->rows().size() * framework::SizeOfType(typeid(int64_t));
+      size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size); e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
slices[2] = ::grpc::Slice(e2.size()); slices[2] = ::grpc::Slice(e2.size());
memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size()); memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
......
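framework::SizeOfType(typeid(int64_t)) resolved the element size through a run-time type registry; for a type fixed at compile time, sizeof(int64_t) yields the same byte count with no lookup. A standalone sketch of the equivalence (the registry below is a stand-in, not Paddle's):

#include <cassert>
#include <cstdint>
#include <typeindex>
#include <unordered_map>

// Stand-in for the old registry-backed SizeOfType(std::type_index).
size_t SizeOfType(std::type_index idx) {
  static const std::unordered_map<std::type_index, size_t> sizes{
      {typeid(float), 4}, {typeid(double), 8}, {typeid(int64_t), 8}};
  return sizes.at(idx);
}

int main() {
  const size_t num_rows = 1024;
  // Old: run-time lookup. New: compile-time constant. Same result.
  assert(num_rows * SizeOfType(typeid(int64_t)) == num_rows * sizeof(int64_t));
  return 0;
}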
...@@ -61,8 +61,7 @@ TensorPayload GetTensorPayload(framework::Variable* var, ...@@ -61,8 +61,7 @@ TensorPayload GetTensorPayload(framework::Variable* var,
auto tensor = var->Get<framework::LoDTensor>(); auto tensor = var->Get<framework::LoDTensor>();
// FIXME(wuyi): data types in send_recv.proto is copied from // FIXME(wuyi): data types in send_recv.proto is copied from
// framework.proto // framework.proto
-  request->set_data_type(
-      static_cast<VarMsg::Type>(framework::ToDataType(tensor.type())));
+  request->set_data_type(static_cast<VarMsg::Type>(tensor.type()));
for (auto& dim : framework::vectorize(tensor.dims())) { for (auto& dim : framework::vectorize(tensor.dims())) {
request->add_dims(dim); request->add_dims(dim);
} }
...@@ -83,8 +82,7 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var, ...@@ -83,8 +82,7 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
VarMsg* request) { VarMsg* request) {
auto* slr = var->GetMutable<framework::SelectedRows>(); auto* slr = var->GetMutable<framework::SelectedRows>();
-  request->set_data_type(
-      static_cast<VarMsg::Type>(framework::ToDataType(slr->value().type())));
+  request->set_data_type(static_cast<VarMsg::Type>(slr->value().type()));
request->set_lod_level(0); request->set_lod_level(0);
request->set_slr_height(slr->height()); request->set_slr_height(slr->height());
......
...@@ -58,18 +58,19 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var, ...@@ -58,18 +58,19 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
VarMsg* request); VarMsg* request);
-inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) {
+inline framework::proto::VarType::Type ToVarType(
+    sendrecv::VariableMessage::Type type) {
   switch (type) {
     case sendrecv::VariableMessage::FP32:
-      return typeid(float);  // NOLINT
+      return framework::proto::VarType::FP32;  // NOLINT
     case sendrecv::VariableMessage::FP64:
-      return typeid(double);  // NOLINT
+      return framework::proto::VarType::FP64;  // NOLINT
     case sendrecv::VariableMessage::INT32:
-      return typeid(int);  // NOLINT
+      return framework::proto::VarType::INT32;  // NOLINT
     case sendrecv::VariableMessage::INT64:
-      return typeid(int64_t);  // NOLINT
+      return framework::proto::VarType::INT64;  // NOLINT
     case sendrecv::VariableMessage::BOOL:
-      return typeid(bool);  // NOLINT
+      return framework::proto::VarType::BOOL;  // NOLINT
default: default:
PADDLE_THROW("Not support type %d", type); PADDLE_THROW("Not support type %d", type);
} }
......
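ToVarType maps the wire-format enum from send_recv.proto onto the framework enum one case at a time, where the old ToTypeIndex had to go through typeid. A self-contained sketch of the shape of that mapping (both enums below are stand-ins for the generated protobuf types):

#include <cassert>
#include <stdexcept>

enum class MsgType { FP32, FP64, INT32, INT64, BOOL };  // wire format
enum class VarType { FP32, FP64, INT32, INT64, BOOL };  // framework type

VarType ToVarType(MsgType type) {
  switch (type) {
    case MsgType::FP32:  return VarType::FP32;
    case MsgType::FP64:  return VarType::FP64;
    case MsgType::INT32: return VarType::INT32;
    case MsgType::INT64: return VarType::INT64;
    case MsgType::BOOL:  return VarType::BOOL;
  }
  throw std::invalid_argument("unsupported message type");
}

int main() {
  // The receiver converts the tag before allocating the tensor buffer,
  // mirroring tensor->mutable_data(place, ToVarType(meta.data_type())).
  assert(ToVarType(MsgType::INT64) == VarType::INT64);
  return 0;
}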
...@@ -114,7 +114,7 @@ bool VariableResponse::CopyLodTensorData( ...@@ -114,7 +114,7 @@ bool VariableResponse::CopyLodTensorData(
tensor->set_lod(lod); tensor->set_lod(lod);
void* tensor_data = void* tensor_data =
-      tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type()));
+      tensor->mutable_data(ctx.GetPlace(), ToVarType(meta_.data_type()));
VLOG(6) << "Tensor.memory_size = " << tensor->memory_size() VLOG(6) << "Tensor.memory_size = " << tensor->memory_size()
<< ", Buffer Size = " << length; << ", Buffer Size = " << length;
...@@ -139,13 +139,13 @@ bool VariableResponse::CopySelectRowsTensorData( ...@@ -139,13 +139,13 @@ bool VariableResponse::CopySelectRowsTensorData(
slr->set_height(meta_.slr_height()); slr->set_height(meta_.slr_height());
auto* tensor = slr->mutable_value(); auto* tensor = slr->mutable_value();
tensor->Resize(dims); tensor->Resize(dims);
-  PADDLE_ENFORCE_EQ(static_cast<size_t>(tensor->numel()),
-                    length / framework::SizeOfType(
-                                 paddle::operators::distributed::ToTypeIndex(
-                                     meta_.data_type())));
+  PADDLE_ENFORCE_EQ(
+      static_cast<size_t>(tensor->numel()),
+      length / framework::SizeOfType(paddle::operators::distributed::ToVarType(
+                   meta_.data_type())));
   void* tensor_data = tensor->mutable_data(
       ctx.GetPlace(),
-      paddle::operators::distributed::ToTypeIndex(meta_.data_type()));
+      paddle::operators::distributed::ToVarType(meta_.data_type()));
if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) { if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
return false; return false;
...@@ -159,8 +159,7 @@ bool VariableResponse::CopySelectRowsData( ...@@ -159,8 +159,7 @@ bool VariableResponse::CopySelectRowsData(
const platform::DeviceContext& ctx, int length) { const platform::DeviceContext& ctx, int length) {
auto* slr = GetVar()->GetMutable<framework::SelectedRows>(); auto* slr = GetVar()->GetMutable<framework::SelectedRows>();
slr->mutable_rows()->clear(); slr->mutable_rows()->clear();
-  slr->mutable_rows()->resize(length /
-                              framework::SizeOfType(typeid(int64_t)));  // int64
+  slr->mutable_rows()->resize(length / sizeof(int64_t));  // int64
int64_t* rows_data = slr->mutable_rows()->data(); int64_t* rows_data = slr->mutable_rows()->data();
// copy rows CPU data, GPU data will be copied lazily. // copy rows CPU data, GPU data will be copied lazily.
......
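Both enforce sites above implement the same sanity check on a received payload: the byte length on the wire, divided by the per-element size, must equal the element (or row) count implied by the metadata. In sketch form, with an assumed float payload:

#include <cassert>
#include <cstddef>

int main() {
  const size_t numel = 6;                   // product of the received dims
  const size_t elem_size = sizeof(float);   // SizeOfType(ToVarType(dtype))
  const size_t length = numel * elem_size;  // payload bytes on the wire
  // Mirrors PADDLE_ENFORCE_EQ(tensor->numel(), length / elem_size).
  assert(numel == length / elem_size);
  return 0;
}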
...@@ -108,9 +108,7 @@ class MergeIdsOp : public framework::OperatorWithKernel { ...@@ -108,9 +108,7 @@ class MergeIdsOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
-        framework::ToDataType(
-            ctx.MultiInput<framework::Tensor>("X").front()->type()),
-        ctx.GetPlace());
+        ctx.MultiInput<framework::Tensor>("X").front()->type(), ctx.GetPlace());
} }
}; };
......
...@@ -42,9 +42,7 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel { ...@@ -42,9 +42,7 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
-        framework::ToDataType(
-            ctx.MultiInput<framework::Tensor>("X")[0]->type()),
-        ctx.GetPlace());
+        ctx.MultiInput<framework::Tensor>("X")[0]->type(), ctx.GetPlace());
} }
}; };
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
-#define EIGEN_USE_GPU
#include <thrust/device_ptr.h> #include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h> #include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h> #include <thrust/random.h>
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h" #include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h" #include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise/elementwise_min_op.h" #include "paddle/fluid/operators/elementwise/elementwise_min_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -197,8 +197,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { ...@@ -197,8 +197,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
-    auto input_data_type = framework::ToDataType(
-        ctx.Input<Tensor>(framework::GradVarName("Out"))->type());
+    auto input_data_type =
+        ctx.Input<Tensor>(framework::GradVarName("Out"))->type();
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) { if (platform::CanMKLDNNBeUsed(ctx)) {
......
...@@ -8,8 +8,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -8,8 +8,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise/elementwise_pow_op.h" #include "paddle/fluid/operators/elementwise/elementwise_pow_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h" #include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/expand_op.h" #include "paddle/fluid/operators/expand_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -115,9 +115,8 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel { ...@@ -115,9 +115,8 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.device_context());
} }
}; };
...@@ -175,9 +174,8 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel { ...@@ -175,9 +174,8 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.device_context());
} }
}; };
......
...@@ -79,9 +79,8 @@ framework::OpKernelType FCOp::GetExpectedKernelType( ...@@ -79,9 +79,8 @@ framework::OpKernelType FCOp::GetExpectedKernelType(
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
-  return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
-      layout, library);
+  return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
+                                 ctx.GetPlace(), layout, library);
} }
void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const { void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
...@@ -111,9 +110,8 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType( ...@@ -111,9 +110,8 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
-  return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
-      layout, library);
+  return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
+                                 ctx.GetPlace(), layout, library);
} }
void FCOpMaker::Make() { void FCOpMaker::Make() {
......
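The FC hunks above (and the LRN and grid_sample ones below) pass four arguments to framework::OpKernelType rather than two: beyond dtype and place, the extra parameters pick a data layout and a kernel library such as MKLDNN or cuDNN. A stand-in model of that key type; the defaulted parameters are an assumption for illustration, not the real signature:

#include <cstdio>

enum class VarType { FP32 };
enum class Place { CPU, CUDA };
enum class DataLayout { kAnyLayout, kMKLDNN };
enum class LibraryType { kPlain, kMKLDNN, kCUDNN };

// Stand-in for framework::OpKernelType: the key the executor matches against
// registered kernels when choosing an implementation.
struct OpKernelType {
  VarType dtype;
  Place place;
  DataLayout layout;
  LibraryType library;
  OpKernelType(VarType d, Place p, DataLayout l = DataLayout::kAnyLayout,
               LibraryType lib = LibraryType::kPlain)
      : dtype(d), place(p), layout(l), library(lib) {}
};

int main() {
  OpKernelType plain(VarType::FP32, Place::CPU);  // the common two-arg form
  OpKernelType mkldnn(VarType::FP32, Place::CPU,  // opting into MKLDNN
                      DataLayout::kMKLDNN, LibraryType::kMKLDNN);
  std::printf("libraries: %d vs %d\n", static_cast<int>(plain.library),
              static_cast<int>(mkldnn.library));
  return 0;
}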
...@@ -59,9 +59,9 @@ class FillConstantOp : public framework::OperatorBase { ...@@ -59,9 +59,9 @@ class FillConstantOp : public framework::OperatorBase {
if (force_cpu) { if (force_cpu) {
auto cpu = platform::CPUPlace(); auto cpu = platform::CPUPlace();
-      tensor->mutable_data(cpu, framework::ToTypeIndex(data_type));
+      tensor->mutable_data(cpu, data_type);
     } else {
-      tensor->mutable_data(dev_place, framework::ToTypeIndex(data_type));
+      tensor->mutable_data(dev_place, data_type);
} }
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
......
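Tensor::mutable_data previously took a std::type_index (hence the framework::ToTypeIndex(data_type) conversions being deleted here); it now accepts the proto enum itself. A stand-in sketch of an allocation keyed off the enum; the Tensor below is illustrative, not Paddle's:

#include <cstdio>
#include <cstdlib>

enum class VarType { FP32, INT64 };

size_t SizeOf(VarType t) { return t == VarType::FP32 ? 4 : 8; }

struct Tensor {
  void* data = nullptr;
  size_t numel = 4;
  // The dtype argument is the enum directly; no type_index round trip.
  void* mutable_data(VarType type) {
    data = std::realloc(data, numel * SizeOf(type));
    return data;
  }
};

int main() {
  Tensor t;
  void* buf = t.mutable_data(VarType::FP32);  // fill_constant-style call
  std::printf("allocated %zu bytes at %p\n", t.numel * SizeOf(VarType::FP32),
              buf);
  std::free(t.data);
  return 0;
}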
...@@ -55,7 +55,7 @@ class FillOp : public framework::OperatorBase { ...@@ -55,7 +55,7 @@ class FillOp : public framework::OperatorBase {
static_cast<framework::proto::VarType::Type>(Attr<int>("dtype")); static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
platform::CPUPlace cpu; platform::CPUPlace cpu;
auto force_cpu = Attr<bool>("force_cpu"); auto force_cpu = Attr<bool>("force_cpu");
-    out.mutable_data(force_cpu ? cpu : place, framework::ToTypeIndex(dtype));
+    out.mutable_data(force_cpu ? cpu : place, dtype);
framework::LoDTensor tensor; framework::LoDTensor tensor;
...@@ -64,7 +64,7 @@ class FillOp : public framework::OperatorBase { ...@@ -64,7 +64,7 @@ class FillOp : public framework::OperatorBase {
} else { } else {
// Always make tensor in CPU memory. // Always make tensor in CPU memory.
tensor.Resize(out.dims()); tensor.Resize(out.dims());
-      tensor.mutable_data(cpu, framework::ToTypeIndex(dtype));
+      tensor.mutable_data(cpu, dtype);
} }
framework::VisitDataType( framework::VisitDataType(
......
...@@ -135,9 +135,8 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel { ...@@ -135,9 +135,8 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx.Input<framework::Tensor>("X")->type(), PADDLE_ENFORCE_EQ(ctx.Input<framework::Tensor>("X")->type(),
ctx.Input<framework::Tensor>("Y")->type(), ctx.Input<framework::Tensor>("Y")->type(),
"The element's type of input should be the same."); "The element's type of input should be the same.");
-    auto input_data_type =
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type());
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
+                                   ctx.GetPlace());
} }
}; };
...@@ -324,9 +323,8 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel { ...@@ -324,9 +323,8 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
-    auto input_data_type_index = ctx.Input<framework::Tensor>("Y")->type();
-    auto input_data_type = framework::ToDataType(input_data_type_index);
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return framework::OpKernelType(ctx.Input<framework::Tensor>("Y")->type(),
+                                   ctx.GetPlace());
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -115,8 +115,7 @@ void FusedEmbeddingFCLSTMOp::InferShape( ...@@ -115,8 +115,7 @@ void FusedEmbeddingFCLSTMOp::InferShape(
framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType( framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType( return framework::OpKernelType(
-      framework::ToDataType(
-          ctx.Input<framework::LoDTensor>("Embeddings")->type()),
+      ctx.Input<framework::LoDTensor>("Embeddings")->type(),
ctx.device_context()); ctx.device_context());
} }
......
...@@ -93,9 +93,8 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -93,9 +93,8 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FusionGRUOp::GetExpectedKernelType( framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
-  return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
-      ctx.device_context());
+  return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                 ctx.device_context());
} }
void FusionGRUOpMaker::Make() { void FusionGRUOpMaker::Make() {
......
...@@ -117,9 +117,8 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -117,9 +117,8 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FusionLSTMOp::GetExpectedKernelType( framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
-  return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
-      ctx.device_context());
+  return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                 ctx.device_context());
} }
void FusionLSTMOpMaker::Make() { void FusionLSTMOpMaker::Make() {
......
...@@ -61,9 +61,8 @@ void FusionSeqConvEltAddReluOp::InferShape( ...@@ -61,9 +61,8 @@ void FusionSeqConvEltAddReluOp::InferShape(
framework::OpKernelType FusionSeqConvEltAddReluOp::GetExpectedKernelType( framework::OpKernelType FusionSeqConvEltAddReluOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
-  return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
-      ctx.device_context());
+  return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                 ctx.device_context());
} }
void FusionSeqConvEltAddReluOpMaker::Make() { void FusionSeqConvEltAddReluOpMaker::Make() {
......
...@@ -67,9 +67,8 @@ void FusionSeqExpandConcatFCOp::InferShape( ...@@ -67,9 +67,8 @@ void FusionSeqExpandConcatFCOp::InferShape(
framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType( framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
-  return framework::OpKernelType(
-      framework::ToDataType(ctx.MultiInput<LoDTensor>("X")[0]->type()),
-      ctx.device_context());
+  return framework::OpKernelType(ctx.MultiInput<LoDTensor>("X")[0]->type(),
+                                 ctx.device_context());
} }
void FusionSeqExpandConcatFCOpMaker::Make() { void FusionSeqExpandConcatFCOpMaker::Make() {
......
...@@ -42,9 +42,8 @@ class GatherOp : public framework::OperatorWithKernel { ...@@ -42,9 +42,8 @@ class GatherOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.device_context());
} }
}; };
...@@ -60,9 +59,8 @@ class GatherGradOp : public framework::OperatorWithKernel { ...@@ -60,9 +59,8 @@ class GatherGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.device_context());
} }
}; };
......
...@@ -63,9 +63,9 @@ class GridSampleOp : public framework::OperatorWithKernel { ...@@ -63,9 +63,9 @@ class GridSampleOp : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
} }
#endif #endif
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_); framework::DataLayout::kAnyLayout, library_);
} }
}; };
...@@ -159,9 +159,9 @@ class GridSampleOpGrad : public framework::OperatorWithKernel { ...@@ -159,9 +159,9 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
} }
#endif #endif
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_); framework::DataLayout::kAnyLayout, library_);
} }
}; };
......
...@@ -141,8 +141,7 @@ class GroupNormGradOp : public framework::OperatorWithKernel { ...@@ -141,8 +141,7 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
if (t == nullptr) { if (t == nullptr) {
PADDLE_THROW("can't find Y@GRAD"); PADDLE_THROW("can't find Y@GRAD");
} }
-    return framework::OpKernelType(framework::ToDataType(t->type()),
-                                   ctx.GetPlace());
+    return framework::OpKernelType(t->type(), ctx.GetPlace());
} }
}; };
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/gru_unit_op.h" #include "paddle/fluid/operators/gru_unit_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -76,9 +76,8 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel { ...@@ -76,9 +76,8 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
-        ctx.GetPlace());
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.GetPlace());
} }
}; };
...@@ -162,9 +161,8 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { ...@@ -162,9 +161,8 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
-        ctx.GetPlace());
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.GetPlace());
} }
}; };
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/hinge_loss_op.h" #include "paddle/fluid/operators/hinge_loss_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/huber_loss_op.h" #include "paddle/fluid/operators/huber_loss_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -11,8 +11,6 @@ ...@@ -11,8 +11,6 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/im2sequence_op.h" #include "paddle/fluid/operators/im2sequence_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -55,8 +55,8 @@ class InterpolateOp : public framework::OperatorWithKernel { ...@@ -55,8 +55,8 @@ class InterpolateOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace());
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.GetPlace());
} }
}; };
...@@ -124,8 +124,8 @@ class InterpolateOpGrad : public framework::OperatorWithKernel { ...@@ -124,8 +124,8 @@ class InterpolateOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace());
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.GetPlace());
} }
}; };
......
...@@ -35,8 +35,7 @@ class IsEmptyOp : public framework::OperatorWithKernel { ...@@ -35,8 +35,7 @@ class IsEmptyOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = framework::OpKernelType( framework::OpKernelType kt = framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
-        platform::CPUPlace());
+        ctx.Input<framework::LoDTensor>("X")->type(), platform::CPUPlace());
return kt; return kt;
} }
}; };
......
...@@ -40,10 +40,9 @@ class OverflowOp : public framework::OperatorWithKernel { ...@@ -40,10 +40,9 @@ class OverflowOp : public framework::OperatorWithKernel {
int dtype = -1; int dtype = -1;
auto *x_var = ctx.InputVar("X"); auto *x_var = ctx.InputVar("X");
if (x_var->IsType<framework::LoDTensor>()) { if (x_var->IsType<framework::LoDTensor>()) {
-      dtype = framework::ToDataType(x_var->Get<framework::LoDTensor>().type());
+      dtype = x_var->Get<framework::LoDTensor>().type();
     } else if (x_var->IsType<framework::SelectedRows>()) {
-      dtype = framework::ToDataType(
-          x_var->Get<framework::SelectedRows>().value().type());
+      dtype = x_var->Get<framework::SelectedRows>().value().type();
} else { } else {
PADDLE_THROW("Cannot find the input data type by all input data"); PADDLE_THROW("Cannot find the input data type by all input data");
} }
......
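dtype in the hunk above is a plain int sentinel (-1 until an input variable is found); the assignments compile without a cast because protobuf generates unscoped enums, which convert to int implicitly. A minimal illustration (the enum values are stand-ins, not the real framework.proto numbering):

#include <cassert>

enum VarType : int { INT64 = 3, FP32 = 5 };  // unscoped, protobuf-style

int main() {
  int dtype = -1;  // sentinel, as in OverflowOp
  dtype = FP32;    // implicit enum -> int conversion
  assert(dtype == 5);
  return 0;
}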
...@@ -11,8 +11,6 @@ ...@@ -11,8 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/isfinite_op.h" #include "paddle/fluid/operators/isfinite_op.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/l1_norm_op.h" #include "paddle/fluid/operators/l1_norm_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -153,8 +153,7 @@ class LayerNormGradOp : public framework::OperatorWithKernel { ...@@ -153,8 +153,7 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
if (t == nullptr) { if (t == nullptr) {
PADDLE_THROW("can't find Y@GRAD"); PADDLE_THROW("can't find Y@GRAD");
} }
-    return framework::OpKernelType(framework::ToDataType(t->type()),
-                                   ctx.GetPlace());
+    return framework::OpKernelType(t->type(), ctx.GetPlace());
} }
}; };
......
...@@ -184,9 +184,8 @@ class LinearChainCRFOp : public framework::OperatorWithKernel { ...@@ -184,9 +184,8 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
// is determined by its input "Emission". // is determined by its input "Emission".
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
-        platform::CPUPlace());
+    return framework::OpKernelType(ctx.Input<LoDTensor>("Emission")->type(),
+                                   platform::CPUPlace());
} }
}; };
...@@ -244,9 +243,7 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel { ...@@ -244,9 +243,7 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
-        framework::ToDataType(
-            ctx.Input<LoDTensor>(framework::GradVarName("LogLikelihood"))
-                ->type()),
+        ctx.Input<LoDTensor>(framework::GradVarName("LogLikelihood"))->type(),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -60,7 +60,7 @@ class LoadCombineOp : public framework::OperatorBase { ...@@ -60,7 +60,7 @@ class LoadCombineOp : public framework::OperatorBase {
// Get data from fin to tensor // Get data from fin to tensor
DeserializeFromStream(fin, tensor, dev_ctx); DeserializeFromStream(fin, tensor, dev_ctx);
-        auto in_dtype = framework::ToDataType(tensor->type());
+        auto in_dtype = tensor->type();
auto out_dtype = auto out_dtype =
load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
......
...@@ -65,7 +65,7 @@ class LoadOp : public framework::OperatorBase { ...@@ -65,7 +65,7 @@ class LoadOp : public framework::OperatorBase {
DeserializeFromStream(fin, tensor, dev_ctx); DeserializeFromStream(fin, tensor, dev_ctx);
auto load_as_fp16 = Attr<bool>("load_as_fp16"); auto load_as_fp16 = Attr<bool>("load_as_fp16");
-    auto in_dtype = framework::ToDataType(tensor->type());
+    auto in_dtype = tensor->type();
auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) { if (in_dtype != out_dtype) {
......
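With type() returning the proto enum, the load path compares the on-disk dtype against the requested one directly and only inserts a cast when they differ. A sketch of that control flow; TransformDtype is an assumed stand-in for the framework's data-type transform step:

#include <cstdio>

enum class VarType { FP16, FP32 };

void TransformDtype(VarType from, VarType to) {
  std::printf("casting %d -> %d\n", static_cast<int>(from),
              static_cast<int>(to));
}

int main() {
  const bool load_as_fp16 = true;
  VarType in_dtype = VarType::FP32;  // dtype recorded in the checkpoint file
  VarType out_dtype = load_as_fp16 ? VarType::FP16 : in_dtype;
  if (in_dtype != out_dtype) {       // enums of the same type compare directly
    TransformDtype(in_dtype, out_dtype);
  }
  return 0;
}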
...@@ -39,9 +39,8 @@ class LoDResetOp : public framework::OperatorWithKernel { ...@@ -39,9 +39,8 @@ class LoDResetOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.device_context());
} }
}; };
...@@ -144,9 +143,8 @@ class LoDResetGradOp : public framework::OperatorWithKernel { ...@@ -144,9 +143,8 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.device_context());
} }
}; };
......
...@@ -72,7 +72,7 @@ struct LoDTensorToArrayFunctor : public boost::static_visitor<void> { ...@@ -72,7 +72,7 @@ struct LoDTensorToArrayFunctor : public boost::static_visitor<void> {
LoDTensorToArrayFunctorImpl<DeviceContext> func; LoDTensorToArrayFunctorImpl<DeviceContext> func;
func.prev_functor_ = this; func.prev_functor_ = this;
func.dev_ctx_ = dev_ctx; func.dev_ctx_ = dev_ctx;
-    framework::VisitDataType(framework::ToDataType(input_.type()), func);
+    framework::VisitDataType(input_.type(), func);
} }
}; };
......
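VisitDataType is the piece that turns the run-time enum back into a compile-time type: it switches on the enum and invokes a functor instantiated for the matching C++ type, which is why the call sites in this PR can pass tensor->type() straight in. A minimal model of the mechanism (the dispatcher and functor protocol are stand-ins, not Paddle's exact ones):

#include <cstdint>
#include <cstdio>

enum class VarType { FP32, INT64 };

// Map the run-time enum onto a compile-time type and call the functor.
template <typename Visitor>
void VisitDataType(VarType type, Visitor visitor) {
  switch (type) {
    case VarType::FP32:  visitor.template apply<float>();   break;
    case VarType::INT64: visitor.template apply<int64_t>(); break;
  }
}

struct PrintElementSize {
  template <typename T>
  void apply() const { std::printf("element size: %zu\n", sizeof(T)); }
};

int main() {
  VisitDataType(VarType::FP32, PrintElementSize{});   // prints 4
  VisitDataType(VarType::INT64, PrintElementSize{});  // prints 8
  return 0;
}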
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/log_loss_op.h" #include "paddle/fluid/operators/log_loss_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -63,8 +63,7 @@ class LookupSparseTableOp : public framework::OperatorBase { ...@@ -63,8 +63,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
out_shape[0] = ids_t.numel(); out_shape[0] = ids_t.numel();
out_t->Resize(out_shape); out_t->Resize(out_shape);
out_t->mutable_data(cpu, w_t->value().type()); out_t->mutable_data(cpu, w_t->value().type());
-    PADDLE_ENFORCE_EQ(framework::ToDataType(w_t->value().type()),
-                      framework::proto::VarType::FP32,
+    PADDLE_ENFORCE_EQ(w_t->value().type(), framework::proto::VarType::FP32,
"The sparse table only support FP32"); "The sparse table only support FP32");
w_t->Get(ids_t, out_t, true, is_test); w_t->Get(ids_t, out_t, true, is_test);
out_t->set_lod(ids_t.lod()); out_t->set_lod(ids_t.lod());
......
...@@ -145,9 +145,8 @@ framework::OpKernelType GetExpectedLRNKernel( ...@@ -145,9 +145,8 @@ framework::OpKernelType GetExpectedLRNKernel(
} }
#endif #endif
-  return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
-      layout_, library_);
+  return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), ctx.GetPlace(),
+                                 layout_, library_);
} }
} // namespace } // namespace
......
...@@ -96,8 +96,7 @@ class LSTMOp : public framework::OperatorWithKernel { ...@@ -96,8 +96,7 @@ class LSTMOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
-        ctx.device_context());
+        ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context());
} }
}; };
...@@ -261,8 +260,7 @@ class LSTMGradOp : public framework::OperatorWithKernel { ...@@ -261,8 +260,7 @@ class LSTMGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
-        ctx.device_context());
+        ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context());
} }
}; };
......
...@@ -113,8 +113,7 @@ class LSTMPOp : public framework::OperatorWithKernel { ...@@ -113,8 +113,7 @@ class LSTMPOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
-        ctx.device_context());
+        ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context());
} }
}; };
...@@ -312,8 +311,7 @@ class LSTMPGradOp : public framework::OperatorWithKernel { ...@@ -312,8 +311,7 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
-        ctx.device_context());
+        ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context());
} }
}; };
......
...@@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/math/context_project.h" #include "paddle/fluid/operators/math/context_project.h"
namespace paddle { namespace paddle {
......
...@@ -77,16 +77,14 @@ template <> ...@@ -77,16 +77,14 @@ template <>
void set_constant_with_place<platform::CPUPlace>( void set_constant_with_place<platform::CPUPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor, const platform::DeviceContext& context, framework::Tensor* tensor,
float value) { float value) {
-  framework::VisitDataType(framework::ToDataType(tensor->type()),
-                           TensorSetConstantCPU(tensor, value));
+  framework::VisitDataType(tensor->type(), TensorSetConstantCPU(tensor, value));
} }
template <> template <>
void set_constant_with_place<platform::CUDAPinnedPlace>( void set_constant_with_place<platform::CUDAPinnedPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor, const platform::DeviceContext& context, framework::Tensor* tensor,
float value) { float value) {
-  framework::VisitDataType(framework::ToDataType(tensor->type()),
-                           TensorSetConstantCPU(tensor, value));
+  framework::VisitDataType(tensor->type(), TensorSetConstantCPU(tensor, value));
} }
struct TensorSetConstantWithPlace : public boost::static_visitor<void> { struct TensorSetConstantWithPlace : public boost::static_visitor<void> {
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
-#define EIGEN_USE_GPU
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
...@@ -67,7 +65,7 @@ template <> ...@@ -67,7 +65,7 @@ template <>
void set_constant_with_place<platform::CUDAPlace>( void set_constant_with_place<platform::CUDAPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor, const platform::DeviceContext& context, framework::Tensor* tensor,
float value) { float value) {
-  framework::VisitDataType(framework::ToDataType(tensor->type()),
+  framework::VisitDataType(tensor->type(),
TensorSetConstantGPU(context, tensor, value)); TensorSetConstantGPU(context, tensor, value));
} }
......
...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/operators/math/sequence2batch.h"
namespace paddle { namespace paddle {
......
...@@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
-#define EIGEN_USE_GPU
#include <vector> #include <vector>
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
......
...@@ -44,9 +44,8 @@ class MeanIoUOp : public framework::OperatorWithKernel { ...@@ -44,9 +44,8 @@ class MeanIoUOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Predictions")->type()),
-        ctx.GetPlace());
+    return framework::OpKernelType(ctx.Input<Tensor>("Predictions")->type(),
+                                   ctx.GetPlace());
} }
}; };
......
...@@ -61,9 +61,7 @@ class MeanGradOp : public framework::OperatorWithKernel { ...@@ -61,9 +61,7 @@ class MeanGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
-    auto input_data_type =
-        framework::ToDataType(ctx.Input<Tensor>("X")->type());
+    auto input_data_type = ctx.Input<Tensor>("X")->type();
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
......
...@@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/mean_op.h" #include "paddle/fluid/operators/mean_op.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
......
...@@ -63,9 +63,7 @@ class MergeLoDTensorOp : public framework::OperatorBase { ...@@ -63,9 +63,7 @@ class MergeLoDTensorOp : public framework::OperatorBase {
platform::Place place = dev_place; platform::Place place = dev_place;
int64_t batch_size = in_true.dims()[0] + in_false.dims()[0]; int64_t batch_size = in_true.dims()[0] + in_false.dims()[0];
auto data_type = in_true.IsInitialized() ? in_true.type() : in_false.type();
std::type_index data_type =
in_true.IsInitialized() ? in_true.type() : in_false.type();
int rank; int rank;
framework::DDim in_dims; framework::DDim in_dims;
if (in_true.IsInitialized()) { if (in_true.IsInitialized()) {
......
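`auto` here now deduces `framework::proto::VarType::Type` instead of `std::type_index`, and the enum is a cheap value type that compares directly. A hypothetical consistency check (not in this PR) that a call site like this could add, since the merged output can only carry one dtype:

    if (in_true.IsInitialized() && in_false.IsInitialized()) {
      // proto::VarType::Type compares as a plain enum.
      PADDLE_ENFORCE_EQ(in_true.type(), in_false.type(),
                        "merge_lod_tensor inputs must share one dtype");
    }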
@@ -55,9 +55,8 @@ class AccuracyOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
-        ctx.GetPlace());
+    return framework::OpKernelType(ctx.Input<Tensor>("Out")->type(),
+                                   ctx.GetPlace());
  }
};
...
@@ -51,9 +51,8 @@ class AucOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Predict")->type()),
-        platform::CPUPlace());
+    return framework::OpKernelType(ctx.Input<Tensor>("Predict")->type(),
+                                   platform::CPUPlace());
  }
};
...
@@ -82,9 +82,8 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("MaxProbs")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<Tensor>("MaxProbs")->type(),
+                                   ctx.device_context());
  }
};
...
@@ -53,9 +53,8 @@ class MultiplexOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.MultiInput<Tensor>("X")[0]->type(),
+                                   ctx.device_context());
  }
};
@@ -123,9 +122,8 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.MultiInput<Tensor>("X")[0]->type(),
+                                   ctx.device_context());
  }
};
...
@@ -69,9 +69,8 @@ class NCEOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
-        platform::CPUPlace());
+    return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
+                                   platform::CPUPlace());
  }
};
@@ -214,9 +213,8 @@ class NCEOpGrad : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
-        platform::CPUPlace());
+    return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
+                                   platform::CPUPlace());
  }
};
...
@@ -70,9 +70,8 @@ class AdadeltaOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    auto input_data_type =
-        framework::ToDataType(ctx.Input<Tensor>("Param")->type());
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
+                                   ctx.GetPlace());
  }
};
...
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/optimizers/adadelta_op.h"
namespace ops = paddle::operators;
...
@@ -59,9 +59,8 @@ class AdagradOp : public framework::OperatorWithKernel {
  }
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    auto input_data_type =
-        framework::ToDataType(ctx.Input<Tensor>("Param")->type());
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
+                                   ctx.GetPlace());
  }
};
...
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/operators/optimizers/adagrad_op.h"
...
@@ -75,8 +75,7 @@ class AdamOp : public framework::OperatorWithKernel {
  }
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    auto input_data_type =
-        framework::ToDataType(ctx.Input<Tensor>("Param")->type());
+    auto input_data_type = ctx.Input<Tensor>("Param")->type();
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};
...
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/optimizers/adam_op.h"
namespace ops = paddle::operators;
...
@@ -76,9 +76,8 @@ class AdamaxOp : public framework::OperatorWithKernel {
  }
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    auto input_data_type =
-        framework::ToDataType(ctx.Input<Tensor>("Param")->type());
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
+                                   ctx.GetPlace());
  }
};
...
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/optimizers/adamax_op.h"
namespace ops = paddle::operators;
...
@@ -64,9 +64,8 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
  }
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    auto input_data_type =
-        framework::ToDataType(ctx.Input<Tensor>("Param")->type());
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
+                                   ctx.GetPlace());
  }
};
...
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/optimizers/decayed_adagrad_op.h"
namespace ops = paddle::operators;
...
@@ -66,8 +66,7 @@ class FTRLOp : public framework::OperatorWithKernel {
  }
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    auto input_data_type =
-        framework::ToDataType(ctx.Input<Tensor>("Param")->type());
+    auto input_data_type = ctx.Input<Tensor>("Param")->type();
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};
...
@@ -10,8 +10,6 @@ Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/optimizers/ftrl_op.h"
namespace ops = paddle::operators;
...
@@ -58,9 +58,8 @@ class ProximalAdagradOp : public framework::OperatorWithKernel {
  }
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    auto input_data_type =
-        framework::ToDataType(ctx.Input<Tensor>("Param")->type());
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
+                                   ctx.GetPlace());
  }
};
...
@@ -10,8 +10,6 @@ Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/optimizers/proximal_adagrad_op.h"
namespace ops = paddle::operators;
...
@@ -46,9 +46,8 @@ class ProximalGDOp : public framework::OperatorWithKernel {
  }
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    auto input_data_type =
-        framework::ToDataType(ctx.Input<Tensor>("Param")->type());
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+    return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
+                                   ctx.GetPlace());
  }
};
...
@@ -10,8 +10,6 @@ Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/optimizers/proximal_gd_op.h"
namespace ops = paddle::operators;
...
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/optimizers/rmsprop_op.h"
namespace ops = paddle::operators;
...
@@ -47,9 +47,8 @@ class PadConstantLikeOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Y")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<Tensor>("Y")->type(),
+                                   ctx.device_context());
  }
};
@@ -171,9 +170,8 @@ class PadConstantLikeOpGrad : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Y")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<Tensor>("Y")->type(),
+                                   ctx.device_context());
  }
};
...
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/pad_constant_like_op.h"
namespace ops = paddle::operators;
...
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/pad_op.h"
namespace ops = paddle::operators;
...
@@ -99,9 +99,8 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
  }
#endif
-  return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
-      layout_, library_);
+  return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), ctx.GetPlace(),
+                                 layout_, library_);
}
void PoolOpGrad::InferShape(framework::InferShapeContext* ctx) const {
@@ -130,7 +129,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
  }
#endif
-  auto input_data_type = framework::ToDataType(ctx.Input<Tensor>("X")->type());
+  auto input_data_type = ctx.Input<Tensor>("X")->type();
  if (input_data_type == framework::proto::VarType::FP16) {
    PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
                      "float16 can only be used when CUDNN is used");
...
@@ -71,9 +71,8 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
+                                   ctx.device_context());
  }
};
@@ -92,9 +91,8 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
+                                   ctx.device_context());
  }
};
...
@@ -87,9 +87,8 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Score")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<Tensor>("Score")->type(),
+                                   ctx.device_context());
  }
};
...
@@ -56,9 +56,8 @@ class PReluOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
-        platform::CPUPlace());
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   platform::CPUPlace());
  }
};
@@ -113,9 +112,8 @@ class PReluGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
-        platform::CPUPlace());
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   platform::CPUPlace());
  }
};
...
@@ -172,7 +172,7 @@ class TensorPrintOp : public framework::OperatorBase {
      formater.name = printed_var_name;
    }
    if (Attr<bool>("print_tensor_type")) {
-      formater.dtype = printed_tensor.type();
+      formater.dtype = framework::ToTypeIndex(printed_tensor.type());
    }
    if (Attr<bool>("print_tensor_shape")) {
      auto &dims = printed_tensor.dims();
...
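`TensorPrintOp` is the one call site that moves in the opposite direction: its formatter still stores a `std::type_index`, so `framework::ToTypeIndex` bridges from the proto enum back to RTTI. Roughly what such a bridge looks like (an illustrative sketch, not the framework's exact implementation):

    // Sketch of an enum -> std::type_index bridge for legacy consumers.
    inline std::type_index ToTypeIndexSketch(framework::proto::VarType::Type t) {
      switch (t) {
        case framework::proto::VarType::FP32:
          return typeid(float);
        case framework::proto::VarType::FP64:
          return typeid(double);
        case framework::proto::VarType::INT32:
          return typeid(int);
        case framework::proto::VarType::INT64:
          return typeid(int64_t);
        default:
          PADDLE_THROW("unsupported data type");
      }
    }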
@@ -22,9 +22,8 @@ class RandomCropOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.device_context());
  }
};
...
@@ -99,10 +99,10 @@ void BatchReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
  out->reserve(out_num);
  for (size_t j = 0; j < out_num; ++j) {
    // Merge shape and check date type
-    std::type_index batch_type = buffer_[0][j].type();
+    auto batch_type = buffer_[0][j].type();
    framework::DDim batch_shape = buffer_[0][j].dims();
    for (size_t i = 1; i < buffer_.size(); ++i) {
-      std::type_index ins_type = buffer_[i][j].type();
+      auto ins_type = buffer_[i][j].type();
      framework::DDim ins_shape = buffer_[i][j].dims();
      PADDLE_ENFORCE_EQ(batch_type, ins_type);
      PADDLE_ENFORCE_EQ(slice_ddim(batch_shape, 1, batch_shape.size()),
...
@@ -414,7 +414,7 @@ class RecurrentGradOp : public RecurrentBase {
        auto &inside_tensor = cur_scope.FindVar(inside_grad_name)
                                  ->Get<framework::LoDTensor>();
        framework::AttributeMap attrs;
-        attrs["dtype"] = framework::ToDataType(inside_tensor.type());
+        attrs["dtype"] = inside_tensor.type();
        attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
        attrs["value"] = 0.0f;
...
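The zero-gradient fill in `RecurrentGradOp` builds an attribute map for a constant-filling op; with this change the `dtype` attribute takes the proto enum with no conversion. The full attribute block from the surrounding context, condensed:

    framework::AttributeMap attrs;
    attrs["dtype"] = inside_tensor.type();  // proto::VarType::Type, stored directly
    attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
    attrs["value"] = 0.0f;
    // These attrs typically parameterize a fill_constant-style op that
    // produces the zero gradient with the matching shape and dtype.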
@@ -108,9 +108,8 @@ class ReshapeOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.device_context());
  }
};
@@ -189,9 +188,8 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.device_context());
  }
};
@@ -322,9 +320,7 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
-        framework::ToDataType(
-            ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))
-                ->type()),
+        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
        ctx.device_context());
  }
};
...
@@ -99,7 +99,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
      auto &in_var_tensor = in_var->Get<framework::LoDTensor>();
      framework::AttributeMap attrs;
-      attrs["dtype"] = framework::ToDataType(in_var_tensor.type());
+      attrs["dtype"] = in_var_tensor.type();
      attrs["shape"] = framework::vectorize2int(in_var_tensor.dims());
      attrs["value"] = 0.0f;
...
@@ -62,9 +62,8 @@ class ROIAlignOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
+                                   ctx.device_context());
  }
};
@@ -83,9 +82,8 @@ class ROIAlignGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
+                                   ctx.device_context());
  }
};
...
@@ -69,9 +69,8 @@ class ROIPoolOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
+                                   ctx.device_context());
  }
};
@@ -90,9 +89,8 @@ class ROIPoolGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
+                                   ctx.device_context());
  }
};
...
@@ -75,7 +75,7 @@ class SaveCombineOp : public framework::OperatorBase {
      // Serialize tensors one by one
      // Check types to see if a fp16 transformation is required
-      auto in_dtype = framework::ToDataType(tensor.type());
+      auto in_dtype = tensor.type();
      auto out_dtype =
          save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
...
@@ -85,7 +85,7 @@ class SaveOp : public framework::OperatorBase {
                   filename);
    auto save_as_fp16 = Attr<bool>("save_as_fp16");
-    auto in_dtype = framework::ToDataType(tensor.type());
+    auto in_dtype = tensor.type();
    auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
    if (in_dtype != out_dtype) {
...
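Both save ops now compare the tensor's dtype enum against `proto::VarType::FP16` to decide whether a cast must happen before serialization. The decision logic, condensed (the actual cast-and-write step is elided here):

    auto in_dtype = tensor.type();  // proto::VarType::Type
    auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
    if (in_dtype != out_dtype) {
      // transform the tensor from in_dtype to out_dtype, then serialize
    }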
@@ -51,9 +51,8 @@ class ScatterOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.device_context());
  }
};
@@ -70,9 +69,8 @@ class ScatterGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.device_context());
  }
};
...
@@ -114,9 +114,8 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.device_context());
  }
};
...
@@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/sequence_ops/sequence_pool_op.h"
namespace ops = paddle::operators;
...
@@ -112,9 +112,8 @@ class SequenceScatterOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
-        platform::CPUPlace());
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   platform::CPUPlace());
  }
};
@@ -131,9 +130,8 @@ class SequenceScatterGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
-        platform::CPUPlace());
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   platform::CPUPlace());
  }
};
...
@@ -50,9 +50,8 @@ class SequenceSliceOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.device_context());
  }
};
@@ -71,9 +70,8 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.device_context());
  }
};
...
@@ -51,7 +51,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
    }
    std::string data_format = ctx.Attr<std::string>("data_format");
    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
+        ctx.Input<Tensor>("X")->type(), ctx.GetPlace(),
        framework::StringToDataLayout(data_format), library_);
  }
};
@@ -146,7 +146,7 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
    }
    std::string data_format = ctx.Attr<std::string>("data_format");
    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
+        ctx.Input<Tensor>("X")->type(), ctx.GetPlace(),
        framework::StringToDataLayout(data_format), library_);
  }
};
...
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"
namespace ops = paddle::operators;
...
@@ -70,9 +70,8 @@ class SimilarityFocusOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
-        platform::CPUPlace());
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   platform::CPUPlace());
  }
};
...
@@ -59,9 +59,8 @@ class SliceOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
-        ctx.GetPlace());
+    return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
+                                   ctx.GetPlace());
  }
};
...
@@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/smooth_l1_loss_op.h"
namespace ops = paddle::operators;
...
@@ -62,8 +62,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
    }
#endif
-    auto input_data_type =
-        framework::ToDataType(ctx.Input<Tensor>("X")->type());
+    auto input_data_type = ctx.Input<Tensor>("X")->type();
    if (input_data_type == framework::proto::VarType::FP16) {
      PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                     "float16 can only be used on GPU place");
@@ -169,8 +168,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
      layout_ = framework::DataLayout::kMKLDNN;
    }
#endif
-    auto input_data_type = framework::ToDataType(
-        ctx.Input<Tensor>(framework::GradVarName("Out"))->type());
+    auto input_data_type =
+        ctx.Input<Tensor>(framework::GradVarName("Out"))->type();
    if (input_data_type == framework::proto::VarType::FP16) {
      PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                     "float16 can only be used on GPU place");
...
@@ -131,9 +131,8 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Logits")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<Tensor>("Logits")->type(),
+                                   ctx.device_context());
  }
};
@@ -173,8 +172,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
-        framework::ToDataType(
-            ctx.Input<Tensor>(framework::GradVarName("Loss"))->type()),
+        ctx.Input<Tensor>(framework::GradVarName("Loss"))->type(),
        ctx.device_context());
  }
};
...
@@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#define EIGEN_USE_GPU
#include <cub/cub.cuh>
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
...
@@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/squared_l2_distance_op.h"
namespace ops = paddle::operators;
...
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/squared_l2_norm_op.h"
namespace ops = paddle::operators;
...
@@ -91,9 +91,9 @@ class SumOp : public framework::OperatorWithKernel {
          continue;
        }
        if (dtype == -1) {
-          dtype = framework::ToDataType(tensor->type());
+          dtype = tensor->type();
        } else {
-          PADDLE_ENFORCE_EQ(dtype, framework::ToDataType(tensor->type()));
+          PADDLE_ENFORCE_EQ(dtype, tensor->type());
        }
      }
      PADDLE_ENFORCE_NE(dtype, -1,
@@ -106,8 +106,8 @@ class SumOp : public framework::OperatorWithKernel {
      for (auto& var : x_vars) {
        auto& value = var->Get<framework::SelectedRows>().value();
        if (value.IsInitialized()) {
-          return framework::OpKernelType(framework::ToDataType(value.type()),
-                                         ctx.device_context(), layout, library);
+          return framework::OpKernelType(value.type(), ctx.device_context(),
+                                         layout, library);
        }
      }
      // if input sparse vars are not initialized, use an default kernel type.
@@ -118,9 +118,8 @@ class SumOp : public framework::OperatorWithKernel {
      auto& array = x_var->Get<framework::LoDTensorArray>();
      for (auto& each : array) {
        if (each.numel() != 0) {
-          return framework::OpKernelType(framework::ToDataType(each.type()),
-                                         ctx.device_context(), layout,
-                                         library);
+          return framework::OpKernelType(each.type(), ctx.device_context(),
+                                         layout, library);
        }
      }
    }
...
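`SumOp` accumulates one dtype across inputs that may be uninitialized, keeping `-1` as the "nothing seen yet" sentinel; that still works because the proto enum converts to the sentinel's integer type for assignment and comparison. A condensed sketch of the loop, assuming `dtype` is a plain `int` as the `-1` checks imply (`tensors` stands in for the op's gathered inputs):

    int dtype = -1;  // sentinel: no initialized input seen yet
    for (auto* tensor : tensors) {
      if (tensor == nullptr || !tensor->IsInitialized()) continue;
      if (dtype == -1) {
        dtype = tensor->type();  // the enum converts to int here
      } else {
        PADDLE_ENFORCE_EQ(dtype, tensor->type());
      }
    }
    PADDLE_ENFORCE_NE(dtype, -1, "Sum op needs at least one initialized input");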
@@ -8,8 +8,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#define EIGEN_USE_GPU
#include "paddle/fluid/operators/sum_op.h"
#include "paddle/fluid/platform/float16.h"
...
@@ -144,9 +144,8 @@ class Transpose2Op : public TransposeOp {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.device_context());
  }
};
@@ -194,9 +193,7 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
-        framework::ToDataType(
-            ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))
-                ->type()),
+        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
        ctx.device_context());
  }
};
...
@@ -74,9 +74,8 @@ class UnpoolOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
+                                   ctx.device_context());
  }
 public:
@@ -113,9 +112,8 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
+                                   ctx.device_context());
  }
 public:
...
@@ -56,9 +56,8 @@ class WarpCTCOp : public framework::OperatorWithKernel {
    }
#endif
    framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Logits")->type()),
-        ctx.device_context(), layout_, library_);
+    return framework::OpKernelType(ctx.Input<Tensor>("Logits")->type(),
+                                   ctx.device_context(), layout_, library_);
  }
};
@@ -136,9 +135,8 @@ class WarpCTCGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Logits")->type()),
-        ctx.device_context());
+    return framework::OpKernelType(ctx.Input<Tensor>("Logits")->type(),
+                                   ctx.device_context());
  }
};
...
@@ -93,7 +93,7 @@ TEST(CudaAtomic, float16) {
// unalignment of uint8
void TestUnalign(size_t num, const int shift_bit) {
-  PADDLE_ENFORCE(num % 2 == 0, "must be a multiple of 2");
+  ASSERT_EQ(num % 2, 0);
  float16 *in1, *in2, *out;
  float16 *d_in1, *d_in2;
  size_t size = sizeof(uint8_t) * (num + shift_bit);
...
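In the CUDA test, the runtime `PADDLE_ENFORCE` precondition becomes a gtest assertion, so a violated precondition fails the test with a report rather than throwing out of the test body. The gtest form can also carry a streamed message if one is wanted:

    ASSERT_EQ(num % 2, 0) << "num must be a multiple of 2";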
@@ -21,7 +21,6 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/gpu_info.h"
-#define EIGEN_USE_GPU
#endif
#ifdef PADDLE_WITH_MKLDNN
...
@@ -71,9 +71,6 @@ struct float16;
}  // namespace platform
}  // namespace paddle
-// NOTE():
-// Do not move the eigen.h header, otherwise the eigen_vector<bool> will failed.
-#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/platform/hostdevice.h"
#include "unsupported/Eigen/CXX11/Tensor"
...
@@ -20,6 +20,7 @@
#include <thread>  // NOLINT
#include <typeindex>
#include <vector>
+#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/enforce.h"
@@ -28,14 +29,14 @@
namespace paddle {
namespace platform {
-inline ncclDataType_t ToNCCLDataType(std::type_index type) {
-  if (type == typeid(float)) {  // NOLINT
+inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
+  if (type == framework::proto::VarType::FP32) {
    return ncclFloat;
-  } else if (type == typeid(double)) {  // NOLINT
+  } else if (type == framework::proto::VarType::FP64) {
    return ncclDouble;
-  } else if (type == typeid(int)) {  // NOLINT
+  } else if (type == framework::proto::VarType::INT32) {
    return ncclInt;
-  } else if (type == typeid(int64_t)) {  // NOLINT
+  } else if (type == framework::proto::VarType::INT64) {
    return ncclInt64;
  } else {
    PADDLE_THROW("Not supported");
...
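`ToNCCLDataType` now keys off the proto enum, so collective code can pass `tensor.type()` through unchanged. Typical usage, sketched (the allreduce call is shown only as a comment):

    // Pick the NCCL dtype for a collective from the tensor's own dtype.
    ncclDataType_t nccl_dtype = platform::ToNCCLDataType(tensor.type());
    // e.g. ncclAllReduce(send_buf, recv_buf, tensor.numel(), nccl_dtype,
    //                    ncclSum, comm, stream);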
@@ -168,7 +168,7 @@ PYBIND11_MODULE(core, m) {
      .def("_get_float_element", TensorGetElement<float>)
      .def("_set_double_element", TensorSetElement<double>)
      .def("_get_double_element", TensorGetElement<double>)
-      .def("_dtype", [](Tensor &self) { return ToDataType(self.type()); });
+      .def("_dtype", [](Tensor &self) { return self.type(); });
  py::class_<LoDTensor, Tensor>(m, "LoDTensor", R"DOC(
    LoDTensor is a Tensor with optional LoD information.
...
@@ -43,7 +43,7 @@ template <size_t I, typename... ARGS>
struct CastToPyBufferImpl<true, I, ARGS...> {
  using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
  pybind11::buffer_info operator()(const framework::Tensor &tensor) {
-    if (std::type_index(typeid(CUR_TYPE)) == tensor.type()) {
+    if (framework::DataTypeTrait<CUR_TYPE>::DataType == tensor.type()) {
      auto dim_vec = framework::vectorize(tensor.dims());
      std::vector<size_t> dims_outside;
      std::vector<size_t> strides;
...
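`DataTypeTrait<T>::DataType` is the compile-time inverse of `Tensor::type()`: it maps each supported C++ type to its `proto::VarType` constant, which is what makes this equality test against the runtime enum possible. A rough sketch of the trait's shape (illustrative; the framework defines the real one for all supported dtypes):

    template <typename T>
    struct DataTypeTraitSketch;

    template <>
    struct DataTypeTraitSketch<float> {
      static constexpr auto DataType = framework::proto::VarType::FP32;
    };

    template <>
    struct DataTypeTraitSketch<int64_t> {
      static constexpr auto DataType = framework::proto::VarType::INT64;
    };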