未验证 提交 a607b6c8 编写于 作者: Q Qiyang Min 提交者: GitHub

Merge pull request #15216 from velconia/local_release_1_2_x_tensor_type

Feature/tensor type
......@@ -86,6 +86,7 @@ endif(NOT WITH_GOLANG)
if(WITH_GPU)
add_definitions(-DPADDLE_WITH_CUDA)
add_definitions(-DEIGEN_USE_GPU)
FIND_PACKAGE(CUDA REQUIRED)
......
......@@ -85,7 +85,7 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
out->mutable_data(expected_kernel_type.place_, in.type());
framework::VisitDataType(
framework::ToDataType(in.type()),
in.type(),
CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out));
out->set_layout(expected_kernel_type.data_layout_);
......@@ -101,7 +101,7 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
case mkldnn::memory::data_type::f32:
return platform::to_void_cast(tensor.data<float>());
case mkldnn::memory::data_type::s8:
return platform::to_void_cast(tensor.data<char>());
return platform::to_void_cast(tensor.data<int8_t>());
case mkldnn::memory::data_type::u8:
return platform::to_void_cast(tensor.data<unsigned char>());
case mkldnn::memory::data_type::s16:
......@@ -144,7 +144,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
memory::data_type in_type = ToMKLDNNDataType(in.type());
PADDLE_ENFORCE(in_type != memory::data_type::data_undef,
"Input tensor type is not supported: ", in.type().name());
"Input tensor type is not supported: %s", in.type());
memory::data_type out_type = in_type;
auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
......
......@@ -50,14 +50,14 @@ inline DataLayout ToPaddleLayout(const MKLDNNFormat& format) {
}
}
inline MKLDNNDataType ToMKLDNNDataType(const std::type_index type) {
static const std::map<std::type_index, MKLDNNDataType> dict{
{std::type_index(typeid(float)), MKLDNNDataType::f32}, // NOLINT
{std::type_index(typeid(char)), MKLDNNDataType::s8}, // NOLINT
{std::type_index(typeid(unsigned char)), MKLDNNDataType::u8},
{std::type_index(typeid(int16_t)), MKLDNNDataType::s16},
{std::type_index(typeid(int32_t)), MKLDNNDataType::s32}};
auto iter = dict.find(type);
inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
static std::unordered_map<int, MKLDNNDataType> dict{
{DataTypeTrait<float>::DataType, MKLDNNDataType::f32},
{DataTypeTrait<int8_t>::DataType, MKLDNNDataType::s8},
{DataTypeTrait<uint8_t>::DataType, MKLDNNDataType::u8},
{DataTypeTrait<int16_t>::DataType, MKLDNNDataType::s16},
{DataTypeTrait<int32_t>::DataType, MKLDNNDataType::s32}};
auto iter = dict.find(static_cast<int>(type));
if (iter != dict.end()) return iter->second;
return MKLDNNDataType::data_undef;
}
......
......@@ -26,7 +26,7 @@ struct DataTypeMap {
std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_;
std::unordered_map<int, std::type_index> proto_to_cpp_;
std::unordered_map<int, std::string> proto_to_str_;
std::unordered_map<std::type_index, size_t> cpp_to_size_;
std::unordered_map<int, size_t> proto_to_size_;
};
static DataTypeMap* InitDataTypeMap();
......@@ -45,7 +45,7 @@ static inline void RegisterType(DataTypeMap* map,
map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T));
map->cpp_to_proto_.emplace(typeid(T), proto_type);
map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
map->cpp_to_size_.emplace(typeid(T), sizeof(T));
map->proto_to_size_.emplace(static_cast<int>(proto_type), sizeof(T));
}
static DataTypeMap* InitDataTypeMap() {
......@@ -54,17 +54,7 @@ static DataTypeMap* InitDataTypeMap() {
#define RegType(cc_type, proto_type) \
RegisterType<cc_type>(retv, proto_type, #cc_type)
// NOTE: Add your customize type here.
RegType(float16, proto::VarType::FP16);
RegType(float, proto::VarType::FP32);
RegType(double, proto::VarType::FP64);
RegType(int, proto::VarType::INT32);
RegType(int64_t, proto::VarType::INT64);
RegType(bool, proto::VarType::BOOL);
RegType(size_t, proto::VarType::SIZE_T);
RegType(int16_t, proto::VarType::INT16);
RegType(uint8_t, proto::VarType::UINT8);
RegType(int8_t, proto::VarType::INT8);
_ForEachDataType_(RegType);
#undef RegType
return retv;
......@@ -96,12 +86,12 @@ std::string DataTypeToString(const proto::VarType::Type type) {
static_cast<int>(type));
}
size_t SizeOfType(std::type_index type) {
auto it = gDataTypeMap().cpp_to_size_.find(type);
if (it != gDataTypeMap().cpp_to_size_.end()) {
size_t SizeOfType(proto::VarType::Type type) {
auto it = gDataTypeMap().proto_to_size_.find(static_cast<int>(type));
if (it != gDataTypeMap().proto_to_size_.end()) {
return it->second;
}
PADDLE_THROW("Not support %s as tensor type", type.name());
PADDLE_THROW("Not support %s as tensor type", DataTypeToString(type));
}
} // namespace framework
......
......@@ -22,46 +22,59 @@ limitations under the License. */
namespace paddle {
namespace framework {
template <typename T>
struct DataTypeTrait {};
// Stub handle for void
template <>
struct DataTypeTrait<void> {
constexpr static auto DataType = proto::VarType::RAW;
};
#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
#define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, bool, BOOL); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8)
#define DefineDataTypeTrait(cpp_type, proto_type) \
template <> \
struct DataTypeTrait<cpp_type> { \
constexpr static auto DataType = proto_type; \
}
_ForEachDataType_(DefineDataTypeTrait);
#undef DefineDataTypeTrait
extern proto::VarType::Type ToDataType(std::type_index type);
extern std::type_index ToTypeIndex(proto::VarType::Type type);
template <typename Visitor>
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
switch (type) {
case proto::VarType::FP16:
visitor.template apply<platform::float16>();
break;
case proto::VarType::FP32:
visitor.template apply<float>();
break;
case proto::VarType::FP64:
visitor.template apply<double>();
break;
case proto::VarType::INT32:
visitor.template apply<int>();
break;
case proto::VarType::INT64:
visitor.template apply<int64_t>();
break;
case proto::VarType::BOOL:
visitor.template apply<bool>();
break;
case proto::VarType::UINT8:
visitor.template apply<uint8_t>();
break;
case proto::VarType::INT16:
visitor.template apply<int16_t>();
break;
case proto::VarType::INT8:
visitor.template apply<int8_t>();
break;
default:
#define VisitDataTypeCallback(cpp_type, proto_type) \
do { \
if (type == proto_type) { \
visitor.template apply<cpp_type>(); \
return; \
} \
} while (0)
_ForEachDataType_(VisitDataTypeCallback);
#undef VisitDataTypeCallback
PADDLE_THROW("Not supported %d", type);
}
}
extern std::string DataTypeToString(const proto::VarType::Type type);
extern size_t SizeOfType(std::type_index type);
extern size_t SizeOfType(proto::VarType::Type type);
inline std::ostream& operator<<(std::ostream& out,
const proto::VarType::Type& type) {
out << DataTypeToString(type);
......
......@@ -26,15 +26,15 @@ TEST(DataType, float16) {
Tensor tensor;
CPUPlace cpu;
tensor.mutable_data(cpu, f::ToTypeIndex(dtype));
tensor.mutable_data(cpu, dtype);
// test fp16 tensor
EXPECT_EQ(tensor.type(), std::type_index(typeid(float16)));
EXPECT_EQ(tensor.type(), f::ToDataType(typeid(float16)));
// test fp16 size
EXPECT_EQ(f::SizeOfType(f::ToTypeIndex(dtype)), 2u);
EXPECT_EQ(f::SizeOfType(dtype), 2u);
// test debug info
std::string type = "float16";
std::string type = "::paddle::platform::float16";
EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
}
......@@ -120,7 +120,7 @@ void AllReduceOpHandle::RunImpl() {
// Reduce All Tensor to trg in CPU
ReduceLoDTensor func(lod_tensors, &trg);
VisitDataType(ToDataType(lod_tensors[0]->type()), func);
VisitDataType(lod_tensors[0]->type(), func);
for (size_t i = 1; i < local_scopes_.size(); ++i) {
auto &scope =
......
......@@ -33,7 +33,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
FuseVarsOpHandle(ir::Node *node, Scope *local_scope,
const platform::Place &place,
const std::unordered_map<std::string, int64_t> &inputs_numel,
const std::type_index &var_type)
const proto::VarType::Type var_type)
: OpHandleBase(node),
local_scope_(local_scope),
place_(place),
......@@ -57,7 +57,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
Scope *local_scope_;
const platform::Place place_;
const std::unordered_map<std::string, int64_t> inputs_numel_;
const std::type_index type_;
const proto::VarType::Type type_;
int64_t total_numel_;
};
} // namespace details
......
......@@ -106,7 +106,7 @@ void ReduceOpHandle::RunImpl() {
if (!FLAGS_cpu_deterministic) {
ReduceLoDTensor func(lod_tensors,
out_var->GetMutable<framework::LoDTensor>());
VisitDataType(ToDataType(lod_tensors[0]->type()), func);
VisitDataType(lod_tensors[0]->type(), func);
} else {
// We sum lod_tensors to reduce_sum_trg which is in local_scopes_0
// here, but it doesn't mean reduce_sum_trg must be in local_scopes_0.
......@@ -116,7 +116,7 @@ void ReduceOpHandle::RunImpl() {
->FindVar(out_var_handle->name_)
->GetMutable<framework::LoDTensor>();
ReduceLoDTensor func(lod_tensors, &reduce_sum_trg);
VisitDataType(ToDataType(lod_tensors[0]->type()), func);
VisitDataType(lod_tensors[0]->type(), func);
auto trg = out_var->GetMutable<framework::LoDTensor>();
if (reduce_sum_trg.data<void>() != trg->data<void>()) {
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/dlpack_tensor.h"
#include "paddle/fluid/framework/data_type.h"
namespace paddle {
namespace framework {
......@@ -36,26 +36,23 @@ static ::DLDataType GetDLDataTypeCode() {
return dtype;
}
static DLDataType GetDLDataTypeFromTypeIndex(const std::type_index &type) {
#define REG_DL_DATA_TYPE(type) \
{ std::type_index(typeid(type)), GetDLDataTypeCode<type>() }
static const std::unordered_map<std::type_index, ::DLDataType>
type_to_dtype_map({
REG_DL_DATA_TYPE(platform::float16), // NOLINT
REG_DL_DATA_TYPE(float), // NOLINT
REG_DL_DATA_TYPE(double), // NOLINT
REG_DL_DATA_TYPE(int), // NOLINT
REG_DL_DATA_TYPE(int64_t), // NOLINT
REG_DL_DATA_TYPE(bool), // NOLINT
REG_DL_DATA_TYPE(size_t), // NOLINT
REG_DL_DATA_TYPE(int16_t), // NOLINT
REG_DL_DATA_TYPE(uint8_t), // NOLINT
REG_DL_DATA_TYPE(int8_t) // NOLINT
});
static std::unordered_map<int, ::DLDataType> CreateDLDataTypeMap() {
static std::unordered_map<int, ::DLDataType> result;
#define REG_DL_DATA_TYPE(cpp_type, proto_type) \
result[static_cast<int>(proto_type)] = GetDLDataTypeCode<cpp_type>()
_ForEachDataType_(REG_DL_DATA_TYPE);
#undef REG_DL_DATA_TYPE
return result;
}
static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
static auto type_to_dtype_map = CreateDLDataTypeMap();
static auto type_to_dtype_map_end_it = type_to_dtype_map.end();
auto it = type_to_dtype_map.find(type);
PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %s",
type.name());
auto it = type_to_dtype_map.find(static_cast<int>(type));
PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %d",
type);
return it->second;
#undef REG_DL_DATA_TYPE
}
......
......@@ -91,23 +91,11 @@ void TestMainLoop() {
}
}
}
TEST(dlpack, test_all) {
#define TestCallback(cpp_type, proto_type) TestMainLoop<cpp_type>()
#define PADDLE_DLPACK_TEST(type) \
TEST(dlpack, test_##type) { TestMainLoop<type>(); }
using float16 = platform::float16;
PADDLE_DLPACK_TEST(float16);
PADDLE_DLPACK_TEST(float);
PADDLE_DLPACK_TEST(double);
PADDLE_DLPACK_TEST(int);
PADDLE_DLPACK_TEST(int64_t);
PADDLE_DLPACK_TEST(bool);
PADDLE_DLPACK_TEST(size_t);
PADDLE_DLPACK_TEST(int16_t);
PADDLE_DLPACK_TEST(uint8_t);
PADDLE_DLPACK_TEST(int8_t);
#undef PADDLE_DLPACK_TEST
_ForEachDataType_(TestCallback);
}
} // namespace framework
} // namespace paddle
......@@ -138,39 +138,19 @@ void print_lod_tensor(std::string var_name, const LoDTensor& lod_tensor) {
std::cout << sstream.str() << std::endl;
}
void print_fetch_var(Scope* scope, std::string var_name) {
const LoDTensor& tensor = scope->FindVar(var_name)->Get<LoDTensor>();
if (std::type_index(tensor.type()) ==
std::type_index(typeid(platform::float16))) {
print_lod_tensor<platform::float16>(var_name, tensor);
} else if (std::type_index(tensor.type()) == std::type_index(typeid(float))) {
print_lod_tensor<float>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(double))) {
print_lod_tensor<double>(var_name, tensor);
} else if (std::type_index(tensor.type()) == std::type_index(typeid(int))) {
print_lod_tensor<int>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(int64_t))) {
print_lod_tensor<int64_t>(var_name, tensor);
} else if (std::type_index(tensor.type()) == std::type_index(typeid(bool))) {
print_lod_tensor<bool>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(uint8_t))) {
print_lod_tensor<uint8_t>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(int16_t))) {
print_lod_tensor<int16_t>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(int8_t))) {
print_lod_tensor<int8_t>(var_name, tensor);
} else {
VLOG(1) << "print_fetch_var: unrecognized data type:"
<< tensor.type().name();
}
return;
static void print_fetch_var(Scope* scope, const std::string& var_name) {
auto& tensor = scope->FindVar(var_name)->Get<LoDTensor>();
#define PrintLoDTensorCallback(cpp_type, proto_type) \
do { \
if (tensor.type() == proto_type) { \
print_lod_tensor<cpp_type>(var_name, tensor); \
return; \
} \
} while (0)
_ForEachDataType_(PrintLoDTensorCallback);
VLOG(1) << "print_fetch_var: unrecognized data type:" << tensor.type();
}
void ExecutorThreadWorker::TrainFiles() {
......
......@@ -70,9 +70,9 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
// only print first ten elements
int64_t size = t.numel() < 10 ? t.numel() : 10;
for (int64_t i = 0; i < size; ++i) {
if (IsType<float>(t.type())) {
if (t.type() == proto::VarType::FP32) {
os << t.data<float>()[i] << " ";
} else if (IsType<int64_t>(t.type())) {
} else if (t.type() == proto::VarType::INT64) {
os << t.data<int64_t>()[i] << " ";
} else {
PADDLE_THROW("LoDTensor data type not in [float, int64_t]");
......@@ -387,7 +387,7 @@ void LoDTensor::MergeLoDTensor(
PADDLE_ENFORCE(!lod_tensors.empty());
framework::DDim new_dim = lod_tensors[0]->dims();
std::type_index new_type = lod_tensors[0]->type();
auto new_type = lod_tensors[0]->type();
framework::DataLayout new_layout = lod_tensors[0]->layout();
LoD new_lod = lod_tensors[0]->lod();
for (size_t i = 1; i < lod_tensors.size(); ++i) {
......
......@@ -34,7 +34,8 @@ TEST(OpKernelType, ToString) {
OpKernelType op_kernel_type2(DataType::FP16, CUDAPlace(0), DataLayout::kNCHW,
LibraryType::kCUDNN);
ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type2),
"data_type[float16]:data_layout[NCHW]:place[CUDAPlace(0)]:library_"
"data_type[::paddle::platform::float16]:data_layout[NCHW]:place["
"CUDAPlace(0)]:library_"
"type[CUDNN]");
}
......
......@@ -43,10 +43,9 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
if (var->IsType<framework::LoDTensor>()) {
return framework::ToDataType(var->Get<framework::LoDTensor>().type());
return var->Get<framework::LoDTensor>().type();
} else if (var->IsType<framework::SelectedRows>()) {
return framework::ToDataType(
var->Get<framework::SelectedRows>().value().type());
return var->Get<framework::SelectedRows>().value().type();
} else {
PADDLE_THROW("Var should be LoDTensor or SelectedRows");
}
......@@ -93,13 +92,13 @@ static std::string GetDtype(const Scope& scope, const std::string& name) {
if (UNLIKELY(!tensor.IsInitialized())) {
return "";
}
return DataTypeToString(ToDataType(tensor.type()));
return DataTypeToString(tensor.type());
} else if (var->IsType<SelectedRows>()) {
auto tensor = var->Get<SelectedRows>().value();
if (UNLIKELY(!tensor.IsInitialized())) {
return "uninited";
} else {
return DataTypeToString(ToDataType(tensor.type()));
return DataTypeToString(tensor.type());
}
} else {
return "";
......@@ -686,7 +685,8 @@ static void CheckTensorNANOrInf(const std::string& name,
if (tensor.memory_size() == 0) {
return;
}
if (!IsType<float>(tensor.type()) && !IsType<double>(tensor.type())) {
if (tensor.type() != proto::VarType::FP32 &&
tensor.type() != proto::VarType::FP64) {
return;
}
PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
......@@ -873,7 +873,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
t = &(var->Get<SelectedRows>().value());
}
if (t != nullptr) {
int tmp = static_cast<int>(ToDataType(t->type()));
int tmp = static_cast<int>(t->type());
PADDLE_ENFORCE(
tmp == data_type || data_type == -1,
"DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)",
......
......@@ -218,11 +218,11 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
if (index < 0) {
VLOG(5) << "id " << id << " not in the table, return 0";
framework::VisitDataType(
framework::ToDataType(value_->type()),
value_->type(),
TensorFillVisitor(value, i * value_width, value_width, 0.0));
} else {
framework::VisitDataType(
framework::ToDataType(value_->type()),
value_->type(),
TensorCopyVisitor(value, i * value_width, *value_.get(),
index * value_width, value_width));
}
......
......@@ -16,7 +16,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
extern size_t SizeOfType(std::type_index type);
extern size_t SizeOfType(proto::VarType::Type type);
void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
......@@ -31,7 +31,7 @@ size_t Tensor::memory_size() const {
return holder_ == nullptr ? 0UL : holder_->size() - offset_;
}
void* Tensor::mutable_data(platform::Place place, std::type_index type,
void* Tensor::mutable_data(platform::Place place, proto::VarType::Type type,
memory::Allocator::Attr attr,
size_t requested_size) {
type_ = type;
......
......@@ -19,9 +19,9 @@ limitations under the License. */
#include <memory>
#include <typeindex>
#include <vector>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -67,7 +67,7 @@ class Tensor {
friend struct EigenVector;
public:
Tensor() : type_(typeid(float)), offset_(0) {}
Tensor() : type_(proto::VarType::FP32), offset_(0) {}
/*! Return a pointer to mutable memory block. */
template <typename T>
......@@ -88,7 +88,7 @@ class Tensor {
memory::Allocator::Attr attr = memory::Allocator::kDefault,
size_t requested_size = 0);
void* mutable_data(platform::Place place, std::type_index type,
void* mutable_data(platform::Place place, proto::VarType::Type type,
memory::Allocator::Attr attr = memory::Allocator::kDefault,
size_t requested_size = 0);
......@@ -138,7 +138,7 @@ class Tensor {
return holder_->place();
}
std::type_index type() const {
proto::VarType::Type type() const {
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor not initialized yet when Tensor::type() is called.");
return type_;
......@@ -161,7 +161,7 @@ class Tensor {
private:
/*! holds the memory block if allocated. */
std::shared_ptr<memory::Allocation> holder_;
std::type_index type_;
proto::VarType::Type type_;
/**
* @brief points to elements dimensions.
*
......
......@@ -24,9 +24,8 @@ template <typename T>
inline const T* Tensor::data() const {
check_memory_size();
bool valid =
std::is_same<T, void>::value || type_ == std::type_index(typeid(T));
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s",
type_.name());
std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %d", type_);
return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
......@@ -38,9 +37,8 @@ template <typename T>
inline T* Tensor::data() {
check_memory_size();
bool valid =
std::is_same<T, void>::value || type_ == std::type_index(typeid(T));
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s",
type_.name());
std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", type_);
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
......@@ -60,7 +58,7 @@ inline T* Tensor::mutable_data(platform::Place place,
size_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD");
return reinterpret_cast<T*>(
mutable_data(place, typeid(T), attr, requested_size));
mutable_data(place, DataTypeTrait<T>::DataType, attr, requested_size));
}
inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
......
......@@ -186,7 +186,7 @@ struct AnyDTypeVisitor {
template <typename Predicate, typename DevCtx>
inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor,
const DevCtx& ctx, framework::Tensor* out) {
VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor<Predicate, DevCtx>(
VisitDataType(tensor.type(), AnyDTypeVisitor<Predicate, DevCtx>(
predicate, tensor, ctx, out));
}
......@@ -379,7 +379,7 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
// int32_t size
// void* protobuf message
proto::VarType::TensorDesc desc;
desc.set_data_type(framework::ToDataType(tensor.type()));
desc.set_data_type(tensor.type());
auto dims = framework::vectorize(tensor.dims());
auto* pb_dims = desc.mutable_dims();
pb_dims->Resize(static_cast<int>(dims.size()), 0);
......@@ -461,9 +461,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
tensor->Resize(framework::make_ddim(dims));
void* buf;
auto ctx = platform::CPUDeviceContext();
size_t size =
tensor->numel() *
framework::SizeOfType(framework::ToTypeIndex(desc.data_type()));
size_t size = tensor->numel() * framework::SizeOfType(desc.data_type());
if (platform::is_gpu_place(dev_ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
Tensor cpu_tensor;
......
......@@ -289,10 +289,10 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
auto type = fetch.type();
auto output = &(outputs->at(i));
output->name = fetchs_[idx]->Input("X")[0];
if (type == typeid(float)) {
if (type == framework::proto::VarType::FP32) {
GetFetchOne<float>(fetch, output);
output->dtype = PaddleDType::FLOAT32;
} else if (type == typeid(int64_t)) {
} else if (type == framework::proto::VarType::INT64) {
GetFetchOne<int64_t>(fetch, output);
output->dtype = PaddleDType::INT64;
} else {
......
......@@ -266,10 +266,10 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
auto type = fetch.type();
auto output = &(outputs->at(i));
output->name = fetchs_[idx]->Input("X")[0];
if (type == typeid(float)) {
if (type == framework::DataTypeTrait<float>::DataType) {
GetFetchOne<float>(fetch, output);
output->dtype = PaddleDType::FLOAT32;
} else if (type == typeid(int64_t)) {
} else if (type == framework::DataTypeTrait<int64_t>::DataType) {
GetFetchOne<int64_t>(fetch, output);
output->dtype = PaddleDType::INT64;
} else {
......
......@@ -36,10 +36,10 @@ namespace paddle {
PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) {
PaddleTensor pt;
if (t->type() == typeid(int64_t)) {
if (t->type() == framework::proto::VarType::INT64) {
pt.data.Reset(t->data<void>(), t->numel() * sizeof(int64_t));
pt.dtype = PaddleDType::INT64;
} else if (t->type() == typeid(float)) {
} else if (t->type() == framework::proto::VarType::FP32) {
pt.data.Reset(t->data<void>(), t->numel() * sizeof(float));
pt.dtype = PaddleDType::FLOAT32;
} else {
......
......@@ -361,7 +361,7 @@ static bool CompareTensorData(const framework::LoDTensor &a,
}
for (size_t i = 0; i < a_size; i++) {
if (a.type() == typeid(float)) {
if (a.type() == framework::proto::VarType::FP32) {
const auto *a_data = a.data<float>();
const auto *b_data = b.data<float>();
if (std::abs(a_data[i] - b_data[i]) > 1e-3) {
......@@ -370,7 +370,7 @@ static bool CompareTensorData(const framework::LoDTensor &a,
b_data[i]);
return false;
}
} else if (a.type() == typeid(int64_t)) {
} else if (a.type() == framework::proto::VarType::INT64) {
const auto *a_data = a.data<int64_t>();
const auto *b_data = b.data<int64_t>();
if (std::abs(a_data[i] - b_data[i]) > 1e-3) {
......
......@@ -78,7 +78,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
library = framework::LibraryType::kCUDNN;
}
#endif
auto data_type = framework::ToDataType(ctx.Input<Tensor>("Theta")->type());
auto data_type = ctx.Input<Tensor>("Theta")->type();
return framework::OpKernelType(data_type, ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library);
}
......@@ -188,9 +188,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Theta")->type()),
ctx.GetPlace(), framework::DataLayout::kAnyLayout, library_);
return framework::OpKernelType(ctx.Input<Tensor>("Theta")->type(),
ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_);
}
};
......
......@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL(
int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, size_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
uint8_t>);
......@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL(
int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
size_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
uint8_t>);
......@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL(
int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, size_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
uint8_t>);
......@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL(
int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
size_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
uint8_t>);
......@@ -58,7 +58,7 @@ struct ArrayToLoDFunctor : public boost::static_visitor<void> {
ArrayToLoDFunctorImpl<DeviceContext> functor;
functor.dev_ctx_ = dev_ctx;
functor.prev_functor_ = this;
framework::VisitDataType(framework::ToDataType(out->type()), functor);
framework::VisitDataType(out->type(), functor);
}
};
......@@ -91,7 +91,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
PADDLE_ENFORCE(!x.empty(), "There's no element in the input array.");
int rank = x[0].dims().size();
platform::Place place = x[0].place();
std::type_index data_type = x[0].type();
auto data_type = x[0].type();
int64_t batch_size = x[0].dims()[0];
framework::DDim ins_dims = rank > 1
? framework::slice_ddim(x[0].dims(), 1, rank)
......
......@@ -121,8 +121,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
......
......@@ -103,8 +103,7 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("param")->type()),
return framework::OpKernelType(ctx.Input<Tensor>("param")->type(),
ctx.GetPlace());
}
};
......
......@@ -72,8 +72,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("X")->type());
auto input_data_type = ctx.Input<Tensor>("X")->type();
// By default, the type of the scale, bias, mean,
// and var tensors should both be float. (For float or float16 input tensor)
// or double (For double input tensor).
......@@ -81,17 +80,13 @@ class BatchNormOp : public framework::OperatorWithKernel {
if (input_data_type == framework::proto::VarType::FP64) {
bn_param_type = framework::proto::VarType::FP64;
}
PADDLE_ENFORCE_EQ(bn_param_type,
framework::ToDataType(ctx.Input<Tensor>("Scale")->type()),
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Scale")->type(),
"Scale input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type,
framework::ToDataType(ctx.Input<Tensor>("Bias")->type()),
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Bias")->type(),
"Bias input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type,
framework::ToDataType(ctx.Input<Tensor>("Mean")->type()),
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Mean")->type(),
"Mean input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType(
ctx.Input<Tensor>("Variance")->type()),
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Variance")->type(),
"Variance input should be of float type");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
......@@ -387,9 +382,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
layout, library);
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace(), layout, library);
}
};
......
......@@ -145,7 +145,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores");
framework::VisitDataType(
framework::ToDataType(scores->at(0).type()),
scores->at(0).type(),
BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores,
beam_size, end_id));
}
......
......@@ -282,8 +282,7 @@ class BeamSearchOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>("pre_ids")->type()),
ctx.Input<framework::LoDTensor>("pre_ids")->type(),
platform::CPUPlace());
return kt;
}
......
......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/bilinear_tensor_product_op.h"
namespace ops = paddle::operators;
......
......@@ -48,13 +48,12 @@ class ConditionalOp : public framework::OperatorBase {
if (!(ips.size() == 1UL && ips[0]->IsInitialized())) {
PADDLE_THROW("should have one initialized input as condition");
}
if (!(framework::IsType<bool>(ips[0]->type()) && // NOLINT
ips[0]->numel() == 1)) {
PADDLE_THROW(
PADDLE_ENFORCE(ips[0]->type() == framework::proto::VarType::BOOL &&
ips[0]->numel() == 1,
"condition input's data type should be bool, "
"numel should be 1, actual numel is %d",
ips[0]->numel());
}
bool res = false;
if (platform::is_gpu_place(ips[0]->place())) {
#ifdef PADDLE_WITH_CUDA
......
......@@ -237,7 +237,7 @@ class WhileGradOp : public framework::OperatorBase {
if (var->IsType<LoDTensor>()) {
auto &inside_tensor = var->Get<framework::LoDTensor>();
framework::AttributeMap attrs;
attrs["dtype"] = framework::ToDataType(inside_tensor.type());
attrs["dtype"] = inside_tensor.type();
attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
attrs["value"] = 0.0f;
......
......@@ -92,10 +92,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
}
#endif
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Input")->type());
auto filter_data_type =
framework::ToDataType(ctx.Input<Tensor>("Filter")->type());
auto input_data_type = ctx.Input<Tensor>("Input")->type();
auto filter_data_type = ctx.Input<Tensor>("Filter")->type();
PADDLE_ENFORCE_EQ(input_data_type, filter_data_type,
"input and filter data type should be consistent");
......@@ -360,9 +358,8 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout_, library_);
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout_, library_);
}
} // namespace operators
......
......@@ -95,9 +95,8 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout_, library_);
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout_, library_);
}
void Conv2DTransposeOpMaker::Make() {
......@@ -314,9 +313,8 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout_, library_);
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout_, library_);
}
} // namespace operators
......
......@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/cos_sim_op.h"
namespace ops = paddle::operators;
......
......@@ -118,8 +118,7 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
return framework::OpKernelType(ctx.Input<LoDTensor>("Emission")->type(),
platform::CPUPlace());
}
};
......
......@@ -51,8 +51,7 @@ class CropOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
};
......@@ -174,9 +173,7 @@ class CropOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))
->type()),
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
ctx.device_context());
}
};
......
......@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/crop_op.h"
namespace ops = paddle::operators;
......
......@@ -57,8 +57,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
// is determined by its input "X".
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context());
}
};
......@@ -111,8 +110,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
// is determined by its input "X".
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context());
}
};
......
......@@ -36,8 +36,7 @@ class CTCAlignOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.device_context());
}
};
......
......@@ -53,8 +53,7 @@ class AnchorGeneratorOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
ctx.device_context());
ctx.Input<framework::Tensor>("Input")->type(), ctx.device_context());
}
};
......
......@@ -45,8 +45,7 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("DistMat")->type()),
return framework::OpKernelType(ctx.Input<LoDTensor>("DistMat")->type(),
platform::CPUPlace());
}
};
......
......@@ -66,8 +66,7 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
ctx.GetPlace());
ctx.Input<framework::Tensor>("Input")->type(), ctx.GetPlace());
}
};
......
......@@ -66,8 +66,7 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Anchors")->type()),
return framework::OpKernelType(ctx.Input<Tensor>("Anchors")->type(),
ctx.device_context());
}
};
......
......@@ -249,8 +249,7 @@ class MineHardExamplesOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("ClsLoss")->type()),
platform::CPUPlace());
ctx.Input<framework::Tensor>("ClsLoss")->type(), platform::CPUPlace());
}
};
......
......@@ -65,8 +65,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>("Scores")->type()),
ctx.Input<framework::LoDTensor>("Scores")->type(),
platform::CPUPlace());
}
};
......
......@@ -72,8 +72,7 @@ class PriorBoxOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
ctx.device_context());
ctx.Input<framework::Tensor>("Input")->type(), ctx.device_context());
}
};
......
......@@ -498,8 +498,7 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.device_context());
}
};
......@@ -519,8 +518,7 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.device_context());
}
};
......
......@@ -78,8 +78,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>("Anchor")->type()),
ctx.Input<framework::LoDTensor>("Anchor")->type(),
platform::CPUPlace());
}
};
......
......@@ -57,8 +57,7 @@ class TargetAssignOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
};
......
......@@ -71,8 +71,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::Tensor>("DetectRes")->type()),
ctx.Input<framework::Tensor>("DetectRes")->type(),
platform::CPUPlace());
}
};
......
......@@ -122,8 +122,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>();
ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
size_t rows_memory_size =
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
slices[2] = ::grpc::Slice(e2.size());
memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
......
......@@ -61,8 +61,7 @@ TensorPayload GetTensorPayload(framework::Variable* var,
auto tensor = var->Get<framework::LoDTensor>();
// FIXME(wuyi): data types in send_recv.proto is copied from
// framework.proto
request->set_data_type(
static_cast<VarMsg::Type>(framework::ToDataType(tensor.type())));
request->set_data_type(static_cast<VarMsg::Type>(tensor.type()));
for (auto& dim : framework::vectorize(tensor.dims())) {
request->add_dims(dim);
}
......@@ -83,8 +82,7 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var,
const platform::DeviceContext& ctx,
VarMsg* request) {
auto* slr = var->GetMutable<framework::SelectedRows>();
request->set_data_type(
static_cast<VarMsg::Type>(framework::ToDataType(slr->value().type())));
request->set_data_type(static_cast<VarMsg::Type>(slr->value().type()));
request->set_lod_level(0);
request->set_slr_height(slr->height());
......
......@@ -58,18 +58,19 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var,
const platform::DeviceContext& ctx,
VarMsg* request);
inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) {
inline framework::proto::VarType::Type ToVarType(
sendrecv::VariableMessage::Type type) {
switch (type) {
case sendrecv::VariableMessage::FP32:
return typeid(float); // NOLINT
return framework::proto::VarType::FP32; // NOLINT
case sendrecv::VariableMessage::FP64:
return typeid(double); // NOLINT
return framework::proto::VarType::FP64; // NOLINT
case sendrecv::VariableMessage::INT32:
return typeid(int); // NOLINT
return framework::proto::VarType::INT32; // NOLINT
case sendrecv::VariableMessage::INT64:
return typeid(int64_t); // NOLINT
return framework::proto::VarType::INT64; // NOLINT
case sendrecv::VariableMessage::BOOL:
return typeid(bool); // NOLINT
return framework::proto::VarType::BOOL; // NOLINT
default:
PADDLE_THROW("Not support type %d", type);
}
......
......@@ -114,7 +114,7 @@ bool VariableResponse::CopyLodTensorData(
tensor->set_lod(lod);
void* tensor_data =
tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type()));
tensor->mutable_data(ctx.GetPlace(), ToVarType(meta_.data_type()));
VLOG(6) << "Tensor.memory_size = " << tensor->memory_size()
<< ", Buffer Size = " << length;
......@@ -139,13 +139,13 @@ bool VariableResponse::CopySelectRowsTensorData(
slr->set_height(meta_.slr_height());
auto* tensor = slr->mutable_value();
tensor->Resize(dims);
PADDLE_ENFORCE_EQ(static_cast<size_t>(tensor->numel()),
length / framework::SizeOfType(
paddle::operators::distributed::ToTypeIndex(
PADDLE_ENFORCE_EQ(
static_cast<size_t>(tensor->numel()),
length / framework::SizeOfType(paddle::operators::distributed::ToVarType(
meta_.data_type())));
void* tensor_data = tensor->mutable_data(
ctx.GetPlace(),
paddle::operators::distributed::ToTypeIndex(meta_.data_type()));
paddle::operators::distributed::ToVarType(meta_.data_type()));
if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
return false;
......@@ -159,8 +159,7 @@ bool VariableResponse::CopySelectRowsData(
const platform::DeviceContext& ctx, int length) {
auto* slr = GetVar()->GetMutable<framework::SelectedRows>();
slr->mutable_rows()->clear();
slr->mutable_rows()->resize(length /
framework::SizeOfType(typeid(int64_t))); // int64
slr->mutable_rows()->resize(length / sizeof(int64_t)); // int64
int64_t* rows_data = slr->mutable_rows()->data();
// copy rows CPU data, GPU data will be copied lazily.
......
......@@ -108,9 +108,7 @@ class MergeIdsOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.MultiInput<framework::Tensor>("X").front()->type()),
ctx.GetPlace());
ctx.MultiInput<framework::Tensor>("X").front()->type(), ctx.GetPlace());
}
};
......
......@@ -42,9 +42,7 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.MultiInput<framework::Tensor>("X")[0]->type()),
ctx.GetPlace());
ctx.MultiInput<framework::Tensor>("X")[0]->type(), ctx.GetPlace());
}
};
......
......@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
......
......@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/platform/float16.h"
......
......@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
namespace ops = paddle::operators;
......
......@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
namespace ops = paddle::operators;
......
......@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise/elementwise_min_op.h"
namespace ops = paddle::operators;
......
......@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
namespace ops = paddle::operators;
......
......@@ -197,8 +197,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type = framework::ToDataType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type());
auto input_data_type =
ctx.Input<Tensor>(framework::GradVarName("Out"))->type();
#ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) {
......
......@@ -8,8 +8,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise/elementwise_pow_op.h"
namespace ops = paddle::operators;
......
......@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
namespace ops = paddle::operators;
......
......@@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/expand_op.h"
namespace ops = paddle::operators;
......
......@@ -115,8 +115,7 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
};
......@@ -175,8 +174,7 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
};
......
......@@ -79,9 +79,8 @@ framework::OpKernelType FCOp::GetExpectedKernelType(
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout, library);
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout, library);
}
void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
......@@ -111,9 +110,8 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout, library);
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout, library);
}
void FCOpMaker::Make() {
......
......@@ -59,9 +59,9 @@ class FillConstantOp : public framework::OperatorBase {
if (force_cpu) {
auto cpu = platform::CPUPlace();
tensor->mutable_data(cpu, framework::ToTypeIndex(data_type));
tensor->mutable_data(cpu, data_type);
} else {
tensor->mutable_data(dev_place, framework::ToTypeIndex(data_type));
tensor->mutable_data(dev_place, data_type);
}
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
......
......@@ -55,7 +55,7 @@ class FillOp : public framework::OperatorBase {
static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
platform::CPUPlace cpu;
auto force_cpu = Attr<bool>("force_cpu");
out.mutable_data(force_cpu ? cpu : place, framework::ToTypeIndex(dtype));
out.mutable_data(force_cpu ? cpu : place, dtype);
framework::LoDTensor tensor;
......@@ -64,7 +64,7 @@ class FillOp : public framework::OperatorBase {
} else {
// Always make tensor in CPU memory.
tensor.Resize(out.dims());
tensor.mutable_data(cpu, framework::ToTypeIndex(dtype));
tensor.mutable_data(cpu, dtype);
}
framework::VisitDataType(
......
......@@ -135,9 +135,8 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx.Input<framework::Tensor>("X")->type(),
ctx.Input<framework::Tensor>("Y")->type(),
"The element's type of input should be the same.");
auto input_data_type =
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.GetPlace());
}
};
......@@ -324,9 +323,8 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type_index = ctx.Input<framework::Tensor>("Y")->type();
auto input_data_type = framework::ToDataType(input_data_type_index);
return framework::OpKernelType(input_data_type, ctx.GetPlace());
return framework::OpKernelType(ctx.Input<framework::Tensor>("Y")->type(),
ctx.GetPlace());
}
};
} // namespace operators
......
......@@ -115,8 +115,7 @@ void FusedEmbeddingFCLSTMOp::InferShape(
framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>("Embeddings")->type()),
ctx.Input<framework::LoDTensor>("Embeddings")->type(),
ctx.device_context());
}
......
......@@ -93,8 +93,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
......
......@@ -117,8 +117,7 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
......
......@@ -61,8 +61,7 @@ void FusionSeqConvEltAddReluOp::InferShape(
framework::OpKernelType FusionSeqConvEltAddReluOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
......
......@@ -67,8 +67,7 @@ void FusionSeqExpandConcatFCOp::InferShape(
framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.MultiInput<LoDTensor>("X")[0]->type()),
return framework::OpKernelType(ctx.MultiInput<LoDTensor>("X")[0]->type(),
ctx.device_context());
}
......
......@@ -42,8 +42,7 @@ class GatherOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context());
}
};
......@@ -60,8 +59,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context());
}
};
......
......@@ -63,8 +63,8 @@ class GridSampleOp : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_);
}
};
......@@ -159,8 +159,8 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_);
}
};
......
......@@ -141,8 +141,7 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
if (t == nullptr) {
PADDLE_THROW("can't find Y@GRAD");
}
return framework::OpKernelType(framework::ToDataType(t->type()),
ctx.GetPlace());
return framework::OpKernelType(t->type(), ctx.GetPlace());
}
};
......
......@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/gru_unit_op.h"
namespace ops = paddle::operators;
......
......@@ -76,8 +76,7 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.GetPlace());
}
};
......@@ -162,8 +161,7 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.GetPlace());
}
};
......
......@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/hinge_loss_op.h"
namespace ops = paddle::operators;
......
......@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/huber_loss_op.h"
namespace ops = paddle::operators;
......
......@@ -11,8 +11,6 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/im2sequence_op.h"
namespace ops = paddle::operators;
......
......@@ -55,8 +55,8 @@ class InterpolateOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace());
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
}
};
......@@ -124,8 +124,8 @@ class InterpolateOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace());
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
}
};
......
......@@ -35,8 +35,7 @@ class IsEmptyOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
platform::CPUPlace());
ctx.Input<framework::LoDTensor>("X")->type(), platform::CPUPlace());
return kt;
}
};
......
......@@ -40,10 +40,9 @@ class OverflowOp : public framework::OperatorWithKernel {
int dtype = -1;
auto *x_var = ctx.InputVar("X");
if (x_var->IsType<framework::LoDTensor>()) {
dtype = framework::ToDataType(x_var->Get<framework::LoDTensor>().type());
dtype = x_var->Get<framework::LoDTensor>().type();
} else if (x_var->IsType<framework::SelectedRows>()) {
dtype = framework::ToDataType(
x_var->Get<framework::SelectedRows>().value().type());
dtype = x_var->Get<framework::SelectedRows>().value().type();
} else {
PADDLE_THROW("Cannot find the input data type by all input data");
}
......
......@@ -11,8 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/isfinite_op.h"
#include "paddle/fluid/platform/float16.h"
......
......@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/l1_norm_op.h"
namespace ops = paddle::operators;
......
......@@ -153,8 +153,7 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
if (t == nullptr) {
PADDLE_THROW("can't find Y@GRAD");
}
return framework::OpKernelType(framework::ToDataType(t->type()),
ctx.GetPlace());
return framework::OpKernelType(t->type(), ctx.GetPlace());
}
};
......
......@@ -184,8 +184,7 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
// is determined by its input "Emission".
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
return framework::OpKernelType(ctx.Input<LoDTensor>("Emission")->type(),
platform::CPUPlace());
}
};
......@@ -244,9 +243,7 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<LoDTensor>(framework::GradVarName("LogLikelihood"))
->type()),
ctx.Input<LoDTensor>(framework::GradVarName("LogLikelihood"))->type(),
platform::CPUPlace());
}
};
......
......@@ -60,7 +60,7 @@ class LoadCombineOp : public framework::OperatorBase {
// Get data from fin to tensor
DeserializeFromStream(fin, tensor, dev_ctx);
auto in_dtype = framework::ToDataType(tensor->type());
auto in_dtype = tensor->type();
auto out_dtype =
load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
......
......@@ -65,7 +65,7 @@ class LoadOp : public framework::OperatorBase {
DeserializeFromStream(fin, tensor, dev_ctx);
auto load_as_fp16 = Attr<bool>("load_as_fp16");
auto in_dtype = framework::ToDataType(tensor->type());
auto in_dtype = tensor->type();
auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册