提交 9bd70a1e 编写于 作者: Y Yu Yang

Change tensor uses proto::VarType::type

test=develop
上级 5e609069
......@@ -85,7 +85,7 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
out->mutable_data(expected_kernel_type.place_, in.type());
framework::VisitDataType(
framework::ToDataType(in.type()),
in.type(),
CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out));
out->set_layout(expected_kernel_type.data_layout_);
......@@ -101,7 +101,7 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
case mkldnn::memory::data_type::f32:
return platform::to_void_cast(tensor.data<float>());
case mkldnn::memory::data_type::s8:
return platform::to_void_cast(tensor.data<char>());
return platform::to_void_cast(tensor.data<int8_t>());
case mkldnn::memory::data_type::u8:
return platform::to_void_cast(tensor.data<unsigned char>());
case mkldnn::memory::data_type::s16:
......@@ -144,7 +144,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
memory::data_type in_type = ToMKLDNNDataType(in.type());
PADDLE_ENFORCE(in_type != memory::data_type::data_undef,
"Input tensor type is not supported: ", in.type().name());
"Input tensor type is not supported: %s", in.type());
memory::data_type out_type = in_type;
auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
......
......@@ -50,14 +50,14 @@ inline DataLayout ToPaddleLayout(const MKLDNNFormat& format) {
}
}
inline MKLDNNDataType ToMKLDNNDataType(const std::type_index type) {
static const std::map<std::type_index, MKLDNNDataType> dict{
{std::type_index(typeid(float)), MKLDNNDataType::f32}, // NOLINT
{std::type_index(typeid(char)), MKLDNNDataType::s8}, // NOLINT
{std::type_index(typeid(unsigned char)), MKLDNNDataType::u8},
{std::type_index(typeid(int16_t)), MKLDNNDataType::s16},
{std::type_index(typeid(int32_t)), MKLDNNDataType::s32}};
auto iter = dict.find(type);
inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
static std::unordered_map<int, MKLDNNDataType> dict{
{DataTypeTrait<float>::DataType, MKLDNNDataType::f32},
{DataTypeTrait<int8_t>::DataType, MKLDNNDataType::s8},
{DataTypeTrait<uint8_t>::DataType, MKLDNNDataType::u8},
{DataTypeTrait<int16_t>::DataType, MKLDNNDataType::s16},
{DataTypeTrait<int32_t>::DataType, MKLDNNDataType::s32}};
auto iter = dict.find(static_cast<int>(type));
if (iter != dict.end()) return iter->second;
return MKLDNNDataType::data_undef;
}
......
......@@ -26,7 +26,7 @@ struct DataTypeMap {
std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_;
std::unordered_map<int, std::type_index> proto_to_cpp_;
std::unordered_map<int, std::string> proto_to_str_;
std::unordered_map<std::type_index, size_t> cpp_to_size_;
std::unordered_map<int, size_t> proto_to_size_;
};
static DataTypeMap* InitDataTypeMap();
......@@ -45,7 +45,7 @@ static inline void RegisterType(DataTypeMap* map,
map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T));
map->cpp_to_proto_.emplace(typeid(T), proto_type);
map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
map->cpp_to_size_.emplace(typeid(T), sizeof(T));
map->proto_to_size_.emplace(static_cast<int>(proto_type), sizeof(T));
}
static DataTypeMap* InitDataTypeMap() {
......@@ -54,17 +54,7 @@ static DataTypeMap* InitDataTypeMap() {
#define RegType(cc_type, proto_type) \
RegisterType<cc_type>(retv, proto_type, #cc_type)
// NOTE: Add your customize type here.
RegType(float16, proto::VarType::FP16);
RegType(float, proto::VarType::FP32);
RegType(double, proto::VarType::FP64);
RegType(int, proto::VarType::INT32);
RegType(int64_t, proto::VarType::INT64);
RegType(bool, proto::VarType::BOOL);
RegType(size_t, proto::VarType::SIZE_T);
RegType(int16_t, proto::VarType::INT16);
RegType(uint8_t, proto::VarType::UINT8);
RegType(int8_t, proto::VarType::INT8);
_ForEachDataType_(RegType);
#undef RegType
return retv;
......@@ -96,12 +86,12 @@ std::string DataTypeToString(const proto::VarType::Type type) {
static_cast<int>(type));
}
size_t SizeOfType(std::type_index type) {
auto it = gDataTypeMap().cpp_to_size_.find(type);
if (it != gDataTypeMap().cpp_to_size_.end()) {
size_t SizeOfType(proto::VarType::Type type) {
auto it = gDataTypeMap().proto_to_size_.find(static_cast<int>(type));
if (it != gDataTypeMap().proto_to_size_.end()) {
return it->second;
}
PADDLE_THROW("Not support %s as tensor type", type.name());
PADDLE_THROW("Not support %s as tensor type", DataTypeToString(type));
}
} // namespace framework
......
......@@ -22,46 +22,59 @@ limitations under the License. */
namespace paddle {
namespace framework {
template <typename T>
struct DataTypeTrait {};
// Stub handle for void
template <>
struct DataTypeTrait<void> {
constexpr static auto DataType = proto::VarType::RAW;
};
#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
#define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, bool, BOOL); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8)
#define DefineDataTypeTrait(cpp_type, proto_type) \
template <> \
struct DataTypeTrait<cpp_type> { \
constexpr static auto DataType = proto_type; \
}
_ForEachDataType_(DefineDataTypeTrait);
#undef DefineDataTypeTrait
extern proto::VarType::Type ToDataType(std::type_index type);
extern std::type_index ToTypeIndex(proto::VarType::Type type);
template <typename Visitor>
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
switch (type) {
case proto::VarType::FP16:
visitor.template apply<platform::float16>();
break;
case proto::VarType::FP32:
visitor.template apply<float>();
break;
case proto::VarType::FP64:
visitor.template apply<double>();
break;
case proto::VarType::INT32:
visitor.template apply<int>();
break;
case proto::VarType::INT64:
visitor.template apply<int64_t>();
break;
case proto::VarType::BOOL:
visitor.template apply<bool>();
break;
case proto::VarType::UINT8:
visitor.template apply<uint8_t>();
break;
case proto::VarType::INT16:
visitor.template apply<int16_t>();
break;
case proto::VarType::INT8:
visitor.template apply<int8_t>();
break;
default:
PADDLE_THROW("Not supported %d", type);
}
#define VisitDataTypeCallback(cpp_type, proto_type) \
do { \
if (type == proto_type) { \
visitor.template apply<cpp_type>(); \
return; \
} \
} while (0)
_ForEachDataType_(VisitDataTypeCallback);
#undef VisitDataTypeCallback
PADDLE_THROW("Not supported %d", type);
}
extern std::string DataTypeToString(const proto::VarType::Type type);
extern size_t SizeOfType(std::type_index type);
extern size_t SizeOfType(proto::VarType::Type type);
inline std::ostream& operator<<(std::ostream& out,
const proto::VarType::Type& type) {
out << DataTypeToString(type);
......
......@@ -26,13 +26,13 @@ TEST(DataType, float16) {
Tensor tensor;
CPUPlace cpu;
tensor.mutable_data(cpu, f::ToTypeIndex(dtype));
tensor.mutable_data(cpu, dtype);
// test fp16 tensor
EXPECT_EQ(tensor.type(), std::type_index(typeid(float16)));
EXPECT_EQ(tensor.type(), f::ToDataType(typeid(float16)));
// test fp16 size
EXPECT_EQ(f::SizeOfType(f::ToTypeIndex(dtype)), 2u);
EXPECT_EQ(f::SizeOfType(dtype), 2u);
// test debug info
std::string type = "float16";
......
......@@ -127,7 +127,7 @@ void AllReduceOpHandle::RunImpl() {
// Reduce All Tensor to trg in CPU
ReduceLoDTensor func(lod_tensors, &trg);
VisitDataType(ToDataType(lod_tensors[0]->type()), func);
VisitDataType(lod_tensors[0]->type(), func);
for (size_t i = 1; i < local_scopes_.size(); ++i) {
auto &scope =
......
......@@ -33,7 +33,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
FuseVarsOpHandle(ir::Node *node, Scope *local_scope,
const platform::Place &place,
const std::unordered_map<std::string, int64_t> &inputs_numel,
const std::type_index &var_type)
const proto::VarType::Type var_type)
: OpHandleBase(node),
local_scope_(local_scope),
place_(place),
......@@ -57,7 +57,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
Scope *local_scope_;
const platform::Place place_;
const std::unordered_map<std::string, int64_t> inputs_numel_;
const std::type_index type_;
const proto::VarType::Type type_;
int64_t total_numel_;
};
} // namespace details
......
......@@ -246,7 +246,7 @@ void ReduceOpHandle::RunImpl() {
if (!FLAGS_cpu_deterministic) {
ReduceLoDTensor func(lod_tensors,
out_var->GetMutable<framework::LoDTensor>());
VisitDataType(ToDataType(lod_tensors[0]->type()), func);
VisitDataType(lod_tensors[0]->type(), func);
} else {
// We sum lod_tensors to reduce_sum_trg which is in local_scopes_0
// here, but it doesn't mean reduce_sum_trg must be in local_scopes_0.
......@@ -256,7 +256,7 @@ void ReduceOpHandle::RunImpl() {
->FindVar(out_var_handle->name_)
->GetMutable<framework::LoDTensor>();
ReduceLoDTensor func(lod_tensors, &reduce_sum_trg);
VisitDataType(ToDataType(lod_tensors[0]->type()), func);
VisitDataType(lod_tensors[0]->type(), func);
auto trg = out_var->GetMutable<framework::LoDTensor>();
if (reduce_sum_trg.data<void>() != trg->data<void>()) {
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/dlpack_tensor.h"
#include "paddle/fluid/framework/data_type.h"
namespace paddle {
namespace framework {
......@@ -36,26 +36,23 @@ static ::DLDataType GetDLDataTypeCode() {
return dtype;
}
static DLDataType GetDLDataTypeFromTypeIndex(const std::type_index &type) {
#define REG_DL_DATA_TYPE(type) \
{ std::type_index(typeid(type)), GetDLDataTypeCode<type>() }
static const std::unordered_map<std::type_index, ::DLDataType>
type_to_dtype_map({
REG_DL_DATA_TYPE(platform::float16), // NOLINT
REG_DL_DATA_TYPE(float), // NOLINT
REG_DL_DATA_TYPE(double), // NOLINT
REG_DL_DATA_TYPE(int), // NOLINT
REG_DL_DATA_TYPE(int64_t), // NOLINT
REG_DL_DATA_TYPE(bool), // NOLINT
REG_DL_DATA_TYPE(size_t), // NOLINT
REG_DL_DATA_TYPE(int16_t), // NOLINT
REG_DL_DATA_TYPE(uint8_t), // NOLINT
REG_DL_DATA_TYPE(int8_t) // NOLINT
});
static std::unordered_map<int, ::DLDataType> CreateDLDataTypeMap() {
static std::unordered_map<int, ::DLDataType> result;
#define REG_DL_DATA_TYPE(cpp_type, proto_type) \
result[static_cast<int>(proto_type)] = GetDLDataTypeCode<cpp_type>()
_ForEachDataType_(REG_DL_DATA_TYPE);
#undef REG_DL_DATA_TYPE
return result;
}
static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
static auto type_to_dtype_map = CreateDLDataTypeMap();
static auto type_to_dtype_map_end_it = type_to_dtype_map.end();
auto it = type_to_dtype_map.find(type);
PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %s",
type.name());
auto it = type_to_dtype_map.find(static_cast<int>(type));
PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %d",
type);
return it->second;
#undef REG_DL_DATA_TYPE
}
......
......@@ -91,23 +91,11 @@ void TestMainLoop() {
}
}
}
TEST(dlpack, test_all) {
#define TestCallback(cpp_type, proto_type) TestMainLoop<cpp_type>()
#define PADDLE_DLPACK_TEST(type) \
TEST(dlpack, test_##type) { TestMainLoop<type>(); }
using float16 = platform::float16;
PADDLE_DLPACK_TEST(float16);
PADDLE_DLPACK_TEST(float);
PADDLE_DLPACK_TEST(double);
PADDLE_DLPACK_TEST(int);
PADDLE_DLPACK_TEST(int64_t);
PADDLE_DLPACK_TEST(bool);
PADDLE_DLPACK_TEST(size_t);
PADDLE_DLPACK_TEST(int16_t);
PADDLE_DLPACK_TEST(uint8_t);
PADDLE_DLPACK_TEST(int8_t);
#undef PADDLE_DLPACK_TEST
_ForEachDataType_(TestCallback);
}
} // namespace framework
} // namespace paddle
......@@ -138,39 +138,19 @@ void print_lod_tensor(std::string var_name, const LoDTensor& lod_tensor) {
std::cout << sstream.str() << std::endl;
}
void print_fetch_var(Scope* scope, std::string var_name) {
const LoDTensor& tensor = scope->FindVar(var_name)->Get<LoDTensor>();
if (std::type_index(tensor.type()) ==
std::type_index(typeid(platform::float16))) {
print_lod_tensor<platform::float16>(var_name, tensor);
} else if (std::type_index(tensor.type()) == std::type_index(typeid(float))) {
print_lod_tensor<float>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(double))) {
print_lod_tensor<double>(var_name, tensor);
} else if (std::type_index(tensor.type()) == std::type_index(typeid(int))) {
print_lod_tensor<int>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(int64_t))) {
print_lod_tensor<int64_t>(var_name, tensor);
} else if (std::type_index(tensor.type()) == std::type_index(typeid(bool))) {
print_lod_tensor<bool>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(uint8_t))) {
print_lod_tensor<uint8_t>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(int16_t))) {
print_lod_tensor<int16_t>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(int8_t))) {
print_lod_tensor<int8_t>(var_name, tensor);
} else {
VLOG(1) << "print_fetch_var: unrecognized data type:"
<< tensor.type().name();
}
return;
static void print_fetch_var(Scope* scope, const std::string& var_name) {
auto& tensor = scope->FindVar(var_name)->Get<LoDTensor>();
#define PrintLoDTensorCallback(cpp_type, proto_type) \
do { \
if (tensor.type() == proto_type) { \
print_lod_tensor<cpp_type>(var_name, tensor); \
return; \
} \
} while (0)
_ForEachDataType_(PrintLoDTensorCallback);
VLOG(1) << "print_fetch_var: unrecognized data type:" << tensor.type();
}
void ExecutorThreadWorker::TrainFiles() {
......
......@@ -70,9 +70,9 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
// only print first ten elements
int64_t size = t.numel() < 10 ? t.numel() : 10;
for (int64_t i = 0; i < size; ++i) {
if (IsType<float>(t.type())) {
if (t.type() == proto::VarType::FP32) {
os << t.data<float>()[i] << " ";
} else if (IsType<int64_t>(t.type())) {
} else if (t.type() == proto::VarType::INT64) {
os << t.data<int64_t>()[i] << " ";
} else {
PADDLE_THROW("LoDTensor data type not in [float, int64_t]");
......@@ -387,7 +387,7 @@ void LoDTensor::MergeLoDTensor(
PADDLE_ENFORCE(!lod_tensors.empty());
framework::DDim new_dim = lod_tensors[0]->dims();
std::type_index new_type = lod_tensors[0]->type();
auto new_type = lod_tensors[0]->type();
framework::DataLayout new_layout = lod_tensors[0]->layout();
LoD new_lod = lod_tensors[0]->lod();
for (size_t i = 1; i < lod_tensors.size(); ++i) {
......
......@@ -43,10 +43,9 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
if (var->IsType<framework::LoDTensor>()) {
return framework::ToDataType(var->Get<framework::LoDTensor>().type());
return var->Get<framework::LoDTensor>().type();
} else if (var->IsType<framework::SelectedRows>()) {
return framework::ToDataType(
var->Get<framework::SelectedRows>().value().type());
return var->Get<framework::SelectedRows>().value().type();
} else {
PADDLE_THROW("Var should be LoDTensor or SelectedRows");
}
......@@ -93,13 +92,13 @@ static std::string GetDtype(const Scope& scope, const std::string& name) {
if (UNLIKELY(!tensor.IsInitialized())) {
return "";
}
return DataTypeToString(ToDataType(tensor.type()));
return DataTypeToString(tensor.type());
} else if (var->IsType<SelectedRows>()) {
auto tensor = var->Get<SelectedRows>().value();
if (UNLIKELY(!tensor.IsInitialized())) {
return "uninited";
} else {
return DataTypeToString(ToDataType(tensor.type()));
return DataTypeToString(tensor.type());
}
} else {
return "";
......@@ -686,7 +685,8 @@ static void CheckTensorNANOrInf(const std::string& name,
if (tensor.memory_size() == 0) {
return;
}
if (!IsType<float>(tensor.type()) && !IsType<double>(tensor.type())) {
if (tensor.type() != proto::VarType::FP32 &&
tensor.type() != proto::VarType::FP64) {
return;
}
PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
......@@ -879,7 +879,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
t = &(var->Get<SelectedRows>().value());
}
if (t != nullptr) {
int tmp = static_cast<int>(ToDataType(t->type()));
int tmp = static_cast<int>(t->type());
PADDLE_ENFORCE(
tmp == data_type || data_type == -1,
"DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)",
......
......@@ -218,11 +218,11 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
if (index < 0) {
VLOG(5) << "id " << id << " not in the table, return 0";
framework::VisitDataType(
framework::ToDataType(value_->type()),
value_->type(),
TensorFillVisitor(value, i * value_width, value_width, 0.0));
} else {
framework::VisitDataType(
framework::ToDataType(value_->type()),
value_->type(),
TensorCopyVisitor(value, i * value_width, *value_.get(),
index * value_width, value_width));
}
......
......@@ -16,7 +16,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
extern size_t SizeOfType(std::type_index type);
extern size_t SizeOfType(proto::VarType::Type type);
void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
......@@ -31,7 +31,7 @@ size_t Tensor::memory_size() const {
return holder_ == nullptr ? 0UL : holder_->size() - offset_;
}
void* Tensor::mutable_data(platform::Place place, std::type_index type,
void* Tensor::mutable_data(platform::Place place, proto::VarType::Type type,
memory::Allocator::Attr attr,
size_t requested_size) {
type_ = type;
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <paddle/fluid/framework/framework.pb.h>
#include <cstdint>
#include <cstring>
#include <memory>
......@@ -67,7 +68,7 @@ class Tensor {
friend struct EigenVector;
public:
Tensor() : type_(typeid(float)), offset_(0) {}
Tensor() : type_(proto::VarType::FP32), offset_(0) {}
/*! Return a pointer to mutable memory block. */
template <typename T>
......@@ -88,7 +89,7 @@ class Tensor {
memory::Allocator::Attr attr = memory::Allocator::kDefault,
size_t requested_size = 0);
void* mutable_data(platform::Place place, std::type_index type,
void* mutable_data(platform::Place place, proto::VarType::Type type,
memory::Allocator::Attr attr = memory::Allocator::kDefault,
size_t requested_size = 0);
......@@ -138,7 +139,7 @@ class Tensor {
return holder_->place();
}
std::type_index type() const {
proto::VarType::Type type() const {
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor not initialized yet when Tensor::type() is called.");
return type_;
......@@ -161,7 +162,7 @@ class Tensor {
private:
/*! holds the memory block if allocated. */
std::shared_ptr<memory::Allocation> holder_;
std::type_index type_;
proto::VarType::Type type_;
/**
* @brief points to elements dimensions.
*
......
......@@ -24,9 +24,8 @@ template <typename T>
inline const T* Tensor::data() const {
check_memory_size();
bool valid =
std::is_same<T, void>::value || type_ == std::type_index(typeid(T));
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s",
type_.name());
std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %d", type_);
return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
......@@ -38,9 +37,8 @@ template <typename T>
inline T* Tensor::data() {
check_memory_size();
bool valid =
std::is_same<T, void>::value || type_ == std::type_index(typeid(T));
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s",
type_.name());
std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", type_);
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
......@@ -60,7 +58,7 @@ inline T* Tensor::mutable_data(platform::Place place,
size_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD");
return reinterpret_cast<T*>(
mutable_data(place, typeid(T), attr, requested_size));
mutable_data(place, DataTypeTrait<T>::DataType, attr, requested_size));
}
inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
......
......@@ -186,8 +186,8 @@ struct AnyDTypeVisitor {
template <typename Predicate, typename DevCtx>
inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor,
const DevCtx& ctx, framework::Tensor* out) {
VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor<Predicate, DevCtx>(
predicate, tensor, ctx, out));
VisitDataType(tensor.type(), AnyDTypeVisitor<Predicate, DevCtx>(
predicate, tensor, ctx, out));
}
template <typename Predicate>
......@@ -379,7 +379,7 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
// int32_t size
// void* protobuf message
proto::VarType::TensorDesc desc;
desc.set_data_type(framework::ToDataType(tensor.type()));
desc.set_data_type(tensor.type());
auto dims = framework::vectorize(tensor.dims());
auto* pb_dims = desc.mutable_dims();
pb_dims->Resize(static_cast<int>(dims.size()), 0);
......@@ -461,9 +461,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
tensor->Resize(framework::make_ddim(dims));
void* buf;
auto ctx = platform::CPUDeviceContext();
size_t size =
tensor->numel() *
framework::SizeOfType(framework::ToTypeIndex(desc.data_type()));
size_t size = tensor->numel() * framework::SizeOfType(desc.data_type());
if (platform::is_gpu_place(dev_ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
Tensor cpu_tensor;
......
......@@ -289,10 +289,10 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
auto type = fetch.type();
auto output = &(outputs->at(i));
output->name = fetchs_[idx]->Input("X")[0];
if (type == typeid(float)) {
if (type == framework::proto::VarType::FP32) {
GetFetchOne<float>(fetch, output);
output->dtype = PaddleDType::FLOAT32;
} else if (type == typeid(int64_t)) {
} else if (type == framework::proto::VarType::INT64) {
GetFetchOne<int64_t>(fetch, output);
output->dtype = PaddleDType::INT64;
} else {
......
......@@ -266,10 +266,10 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
auto type = fetch.type();
auto output = &(outputs->at(i));
output->name = fetchs_[idx]->Input("X")[0];
if (type == typeid(float)) {
if (type == framework::DataTypeTrait<float>::DataType) {
GetFetchOne<float>(fetch, output);
output->dtype = PaddleDType::FLOAT32;
} else if (type == typeid(int64_t)) {
} else if (type == framework::DataTypeTrait<int64_t>::DataType) {
GetFetchOne<int64_t>(fetch, output);
output->dtype = PaddleDType::INT64;
} else {
......
......@@ -36,10 +36,10 @@ namespace paddle {
PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) {
PaddleTensor pt;
if (t->type() == typeid(int64_t)) {
if (t->type() == framework::proto::VarType::INT64) {
pt.data.Reset(t->data<void>(), t->numel() * sizeof(int64_t));
pt.dtype = PaddleDType::INT64;
} else if (t->type() == typeid(float)) {
} else if (t->type() == framework::proto::VarType::INT32) {
pt.data.Reset(t->data<void>(), t->numel() * sizeof(float));
pt.dtype = PaddleDType::FLOAT32;
} else {
......
......@@ -78,7 +78,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
library = framework::LibraryType::kCUDNN;
}
#endif
auto data_type = framework::ToDataType(ctx.Input<Tensor>("Theta")->type());
auto data_type = ctx.Input<Tensor>("Theta")->type();
return framework::OpKernelType(data_type, ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library);
}
......@@ -188,9 +188,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Theta")->type()),
ctx.GetPlace(), framework::DataLayout::kAnyLayout, library_);
return framework::OpKernelType(ctx.Input<Tensor>("Theta")->type(),
ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_);
}
};
......
......@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL(
int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, size_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
uint8_t>);
......@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL(
int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
size_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
uint8_t>);
......@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL(
int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, size_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
uint8_t>);
......@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL(
int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
size_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
uint8_t>);
......@@ -58,7 +58,7 @@ struct ArrayToLoDFunctor : public boost::static_visitor<void> {
ArrayToLoDFunctorImpl<DeviceContext> functor;
functor.dev_ctx_ = dev_ctx;
functor.prev_functor_ = this;
framework::VisitDataType(framework::ToDataType(out->type()), functor);
framework::VisitDataType(out->type(), functor);
}
};
......@@ -91,7 +91,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
PADDLE_ENFORCE(!x.empty(), "There's no element in the input array.");
int rank = x[0].dims().size();
platform::Place place = x[0].place();
std::type_index data_type = x[0].type();
auto data_type = x[0].type();
int64_t batch_size = x[0].dims()[0];
framework::DDim ins_dims = rank > 1
? framework::slice_ddim(x[0].dims(), 1, rank)
......
......@@ -121,9 +121,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
void AttentionLSTMOpMaker::Make() {
......
......@@ -103,9 +103,8 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("param")->type()),
ctx.GetPlace());
return framework::OpKernelType(ctx.Input<Tensor>("param")->type(),
ctx.GetPlace());
}
};
......
......@@ -72,8 +72,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("X")->type());
auto input_data_type = ctx.Input<Tensor>("X")->type();
// By default, the type of the scale, bias, mean,
// and var tensors should both be float. (For float or float16 input tensor)
// or double (For double input tensor).
......@@ -81,17 +80,13 @@ class BatchNormOp : public framework::OperatorWithKernel {
if (input_data_type == framework::proto::VarType::FP64) {
bn_param_type = framework::proto::VarType::FP64;
}
PADDLE_ENFORCE_EQ(bn_param_type,
framework::ToDataType(ctx.Input<Tensor>("Scale")->type()),
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Scale")->type(),
"Scale input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type,
framework::ToDataType(ctx.Input<Tensor>("Bias")->type()),
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Bias")->type(),
"Bias input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type,
framework::ToDataType(ctx.Input<Tensor>("Mean")->type()),
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Mean")->type(),
"Mean input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType(
ctx.Input<Tensor>("Variance")->type()),
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Variance")->type(),
"Variance input should be of float type");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
......@@ -413,9 +408,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
layout, library);
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace(), layout, library);
}
};
......
......@@ -145,7 +145,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores");
framework::VisitDataType(
framework::ToDataType(scores->at(0).type()),
scores->at(0).type(),
BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores,
beam_size, end_id));
}
......
......@@ -282,8 +282,7 @@ class BeamSearchOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>("pre_ids")->type()),
ctx.Input<framework::LoDTensor>("pre_ids")->type(),
platform::CPUPlace());
return kt;
}
......
......@@ -47,9 +47,8 @@ class BprLossOp : public framework::OperatorWithKernel {
// is determined by its input "X".
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
platform::CPUPlace());
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
platform::CPUPlace());
}
};
......@@ -94,9 +93,8 @@ class BprLossGradientOp : public framework::OperatorWithKernel {
// is determined by its input "X".
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
platform::CPUPlace());
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
platform::CPUPlace());
}
};
......
......@@ -48,13 +48,12 @@ class ConditionalOp : public framework::OperatorBase {
if (!(ips.size() == 1UL && ips[0]->IsInitialized())) {
PADDLE_THROW("should have one initialized input as condition");
}
if (!(framework::IsType<bool>(ips[0]->type()) && // NOLINT
ips[0]->numel() == 1)) {
PADDLE_THROW(
"condition input's data type should be bool, "
"numel should be 1, actual numel is %d",
ips[0]->numel());
}
PADDLE_ENFORCE(ips[0]->type() == framework::proto::VarType::BOOL &&
ips[0]->numel() == 1,
"condition input's data type should be bool, "
"numel should be 1, actual numel is %d",
ips[0]->numel());
bool res = false;
if (platform::is_gpu_place(ips[0]->place())) {
#ifdef PADDLE_WITH_CUDA
......
......@@ -237,7 +237,7 @@ class WhileGradOp : public framework::OperatorBase {
if (var->IsType<LoDTensor>()) {
auto &inside_tensor = var->Get<framework::LoDTensor>();
framework::AttributeMap attrs;
attrs["dtype"] = framework::ToDataType(inside_tensor.type());
attrs["dtype"] = inside_tensor.type();
attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
attrs["value"] = 0.0f;
......
......@@ -95,10 +95,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
}
#endif
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Input")->type());
auto filter_data_type =
framework::ToDataType(ctx.Input<Tensor>("Filter")->type());
auto input_data_type = ctx.Input<Tensor>("Input")->type();
auto filter_data_type = ctx.Input<Tensor>("Filter")->type();
PADDLE_ENFORCE_EQ(input_data_type, filter_data_type,
"input and filter data type should be consistent");
......@@ -382,9 +380,9 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout_, library_, customized_type_value);
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout_, library_,
customized_type_value);
}
} // namespace operators
......
......@@ -104,9 +104,8 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout_, library_);
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout_, library_);
}
void Conv2DTransposeOpMaker::Make() {
......@@ -335,9 +334,8 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout_, library_);
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout_, library_);
}
} // namespace operators
......
......@@ -118,9 +118,8 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
platform::CPUPlace());
return framework::OpKernelType(ctx.Input<LoDTensor>("Emission")->type(),
platform::CPUPlace());
}
};
} // namespace operators
......
......@@ -51,9 +51,8 @@ class CropOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
};
......@@ -174,9 +173,7 @@ class CropOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))
->type()),
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
ctx.device_context());
}
};
......
......@@ -57,9 +57,8 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
// is determined by its input "X".
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.device_context());
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context());
}
};
......@@ -111,9 +110,8 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
// is determined by its input "X".
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.device_context());
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context());
}
};
......
......@@ -36,9 +36,8 @@ class CTCAlignOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
ctx.device_context());
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.device_context());
}
};
......
......@@ -53,8 +53,7 @@ class AnchorGeneratorOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
ctx.device_context());
ctx.Input<framework::Tensor>("Input")->type(), ctx.device_context());
}
};
......
......@@ -45,9 +45,8 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("DistMat")->type()),
platform::CPUPlace());
return framework::OpKernelType(ctx.Input<LoDTensor>("DistMat")->type(),
platform::CPUPlace());
}
};
......
......@@ -66,8 +66,7 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
ctx.GetPlace());
ctx.Input<framework::Tensor>("Input")->type(), ctx.GetPlace());
}
};
......
......@@ -66,9 +66,8 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Anchors")->type()),
ctx.device_context());
return framework::OpKernelType(ctx.Input<Tensor>("Anchors")->type(),
ctx.device_context());
}
};
......
......@@ -249,8 +249,7 @@ class MineHardExamplesOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("ClsLoss")->type()),
platform::CPUPlace());
ctx.Input<framework::Tensor>("ClsLoss")->type(), platform::CPUPlace());
}
};
......
......@@ -65,8 +65,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>("Scores")->type()),
ctx.Input<framework::LoDTensor>("Scores")->type(),
platform::CPUPlace());
}
};
......
......@@ -72,8 +72,7 @@ class PriorBoxOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
ctx.device_context());
ctx.Input<framework::Tensor>("Input")->type(), ctx.device_context());
}
};
......
......@@ -498,9 +498,8 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.device_context());
}
};
......@@ -519,9 +518,8 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.device_context());
}
};
......
......@@ -78,8 +78,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>("Anchor")->type()),
ctx.Input<framework::LoDTensor>("Anchor")->type(),
platform::CPUPlace());
}
};
......
......@@ -57,9 +57,8 @@ class TargetAssignOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
};
......
......@@ -71,8 +71,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::Tensor>("DetectRes")->type()),
ctx.Input<framework::Tensor>("DetectRes")->type(),
platform::CPUPlace());
}
};
......
......@@ -197,8 +197,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type = framework::ToDataType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type());
auto input_data_type =
ctx.Input<Tensor>(framework::GradVarName("Out"))->type();
#ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) {
......
......@@ -115,9 +115,8 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
};
......@@ -175,9 +174,8 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
};
......
......@@ -79,9 +79,8 @@ framework::OpKernelType FCOp::GetExpectedKernelType(
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout, library);
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout, library);
}
void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
......@@ -111,9 +110,8 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout, library);
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout, library);
}
void FCOpMaker::Make() {
......
......@@ -59,9 +59,9 @@ class FillConstantOp : public framework::OperatorBase {
if (force_cpu) {
auto cpu = platform::CPUPlace();
tensor->mutable_data(cpu, framework::ToTypeIndex(data_type));
tensor->mutable_data(cpu, data_type);
} else {
tensor->mutable_data(dev_place, framework::ToTypeIndex(data_type));
tensor->mutable_data(dev_place, data_type);
}
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
......
......@@ -55,7 +55,7 @@ class FillOp : public framework::OperatorBase {
static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
platform::CPUPlace cpu;
auto force_cpu = Attr<bool>("force_cpu");
out.mutable_data(force_cpu ? cpu : place, framework::ToTypeIndex(dtype));
out.mutable_data(force_cpu ? cpu : place, dtype);
framework::LoDTensor tensor;
......@@ -64,7 +64,7 @@ class FillOp : public framework::OperatorBase {
} else {
// Always make tensor in CPU memory.
tensor.Resize(out.dims());
tensor.mutable_data(cpu, framework::ToTypeIndex(dtype));
tensor.mutable_data(cpu, dtype);
}
framework::VisitDataType(
......
......@@ -135,9 +135,8 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx.Input<framework::Tensor>("X")->type(),
ctx.Input<framework::Tensor>("Y")->type(),
"The element's type of input should be the same.");
auto input_data_type =
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.GetPlace());
}
};
......@@ -324,9 +323,8 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type_index = ctx.Input<framework::Tensor>("Y")->type();
auto input_data_type = framework::ToDataType(input_data_type_index);
return framework::OpKernelType(input_data_type, ctx.GetPlace());
return framework::OpKernelType(ctx.Input<framework::Tensor>("Y")->type(),
ctx.GetPlace());
}
};
} // namespace operators
......
......@@ -115,8 +115,7 @@ void FusedEmbeddingFCLSTMOp::InferShape(
framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>("Embeddings")->type()),
ctx.Input<framework::LoDTensor>("Embeddings")->type(),
ctx.device_context());
}
......
......@@ -93,9 +93,8 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
void FusionGRUOpMaker::Make() {
......
......@@ -117,9 +117,8 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
void FusionLSTMOpMaker::Make() {
......
......@@ -61,9 +61,8 @@ void FusionSeqConvEltAddReluOp::InferShape(
framework::OpKernelType FusionSeqConvEltAddReluOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
void FusionSeqConvEltAddReluOpMaker::Make() {
......
......@@ -67,9 +67,8 @@ void FusionSeqExpandConcatFCOp::InferShape(
framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.MultiInput<LoDTensor>("X")[0]->type()),
ctx.device_context());
return framework::OpKernelType(ctx.MultiInput<LoDTensor>("X")[0]->type(),
ctx.device_context());
}
void FusionSeqExpandConcatFCOpMaker::Make() {
......
......@@ -42,9 +42,8 @@ class GatherOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.device_context());
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context());
}
};
......@@ -60,9 +59,8 @@ class GatherGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.device_context());
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context());
}
};
......
......@@ -63,9 +63,9 @@ class GridSampleOp : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_);
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_);
}
};
......@@ -159,9 +159,9 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_);
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_);
}
};
......
......@@ -141,8 +141,7 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
if (t == nullptr) {
PADDLE_THROW("can't find Y@GRAD");
}
return framework::OpKernelType(framework::ToDataType(t->type()),
ctx.GetPlace());
return framework::OpKernelType(t->type(), ctx.GetPlace());
}
};
......
......@@ -76,9 +76,8 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.GetPlace());
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.GetPlace());
}
};
......@@ -163,9 +162,8 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.GetPlace());
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.GetPlace());
}
};
......
......@@ -55,8 +55,8 @@ class InterpolateOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace());
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
}
};
......@@ -124,8 +124,8 @@ class InterpolateOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace());
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
}
};
......
......@@ -35,8 +35,7 @@ class IsEmptyOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
platform::CPUPlace());
ctx.Input<framework::LoDTensor>("X")->type(), platform::CPUPlace());
return kt;
}
};
......
......@@ -40,10 +40,9 @@ class OverflowOp : public framework::OperatorWithKernel {
int dtype = -1;
auto *x_var = ctx.InputVar("X");
if (x_var->IsType<framework::LoDTensor>()) {
dtype = framework::ToDataType(x_var->Get<framework::LoDTensor>().type());
dtype = x_var->Get<framework::LoDTensor>().type();
} else if (x_var->IsType<framework::SelectedRows>()) {
dtype = framework::ToDataType(
x_var->Get<framework::SelectedRows>().value().type());
dtype = x_var->Get<framework::SelectedRows>().value().type();
} else {
PADDLE_THROW("Cannot find the input data type by all input data");
}
......
......@@ -153,8 +153,7 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
if (t == nullptr) {
PADDLE_THROW("can't find Y@GRAD");
}
return framework::OpKernelType(framework::ToDataType(t->type()),
ctx.GetPlace());
return framework::OpKernelType(t->type(), ctx.GetPlace());
}
};
......
......@@ -184,9 +184,8 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
// is determined by its input "Emission".
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
platform::CPUPlace());
return framework::OpKernelType(ctx.Input<LoDTensor>("Emission")->type(),
platform::CPUPlace());
}
};
......@@ -244,9 +243,7 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<LoDTensor>(framework::GradVarName("LogLikelihood"))
->type()),
ctx.Input<LoDTensor>(framework::GradVarName("LogLikelihood"))->type(),
platform::CPUPlace());
}
};
......
......@@ -69,7 +69,7 @@ class LoadCombineOp : public framework::OperatorBase {
// Get data from fin to tensor
DeserializeFromStream(*buffer, tensor, dev_ctx);
auto in_dtype = framework::ToDataType(tensor->type());
auto in_dtype = tensor->type();
auto out_dtype =
load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
......
......@@ -65,7 +65,7 @@ class LoadOp : public framework::OperatorBase {
DeserializeFromStream(fin, tensor, dev_ctx);
auto load_as_fp16 = Attr<bool>("load_as_fp16");
auto in_dtype = framework::ToDataType(tensor->type());
auto in_dtype = tensor->type();
auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
......
......@@ -39,9 +39,8 @@ class LoDResetOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
};
......@@ -144,9 +143,8 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
}
};
......
......@@ -72,7 +72,7 @@ struct LoDTensorToArrayFunctor : public boost::static_visitor<void> {
LoDTensorToArrayFunctorImpl<DeviceContext> func;
func.prev_functor_ = this;
func.dev_ctx_ = dev_ctx;
framework::VisitDataType(framework::ToDataType(input_.type()), func);
framework::VisitDataType(input_.type(), func);
}
};
......
......@@ -63,8 +63,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
out_shape[0] = ids_t.numel();
out_t->Resize(out_shape);
out_t->mutable_data(cpu, w_t->value().type());
PADDLE_ENFORCE_EQ(framework::ToDataType(w_t->value().type()),
framework::proto::VarType::FP32,
PADDLE_ENFORCE_EQ(w_t->value().type(), framework::proto::VarType::FP32,
"The sparse table only support FP32");
w_t->Get(ids_t, out_t, true, is_test);
out_t->set_lod(ids_t.lod());
......
......@@ -145,9 +145,8 @@ framework::OpKernelType GetExpectedLRNKernel(
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
layout_, library_);
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), ctx.GetPlace(),
layout_, library_);
}
} // namespace
......
......@@ -96,8 +96,7 @@ class LSTMOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
ctx.device_context());
ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context());
}
};
......@@ -261,8 +260,7 @@ class LSTMGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
ctx.device_context());
ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context());
}
};
......
......@@ -113,8 +113,7 @@ class LSTMPOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
ctx.device_context());
ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context());
}
};
......@@ -312,8 +311,7 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
ctx.device_context());
ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context());
}
};
......
......@@ -77,16 +77,14 @@ template <>
void set_constant_with_place<platform::CPUPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor,
float value) {
framework::VisitDataType(framework::ToDataType(tensor->type()),
TensorSetConstantCPU(tensor, value));
framework::VisitDataType(tensor->type(), TensorSetConstantCPU(tensor, value));
}
template <>
void set_constant_with_place<platform::CUDAPinnedPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor,
float value) {
framework::VisitDataType(framework::ToDataType(tensor->type()),
TensorSetConstantCPU(tensor, value));
framework::VisitDataType(tensor->type(), TensorSetConstantCPU(tensor, value));
}
struct TensorSetConstantWithPlace : public boost::static_visitor<void> {
......
......@@ -65,7 +65,7 @@ template <>
void set_constant_with_place<platform::CUDAPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor,
float value) {
framework::VisitDataType(framework::ToDataType(tensor->type()),
framework::VisitDataType(tensor->type(),
TensorSetConstantGPU(context, tensor, value));
}
......
......@@ -44,9 +44,8 @@ class MeanIoUOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Predictions")->type()),
ctx.GetPlace());
return framework::OpKernelType(ctx.Input<Tensor>("Predictions")->type(),
ctx.GetPlace());
}
};
......
......@@ -61,9 +61,7 @@ class MeanGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("X")->type());
auto input_data_type = ctx.Input<Tensor>("X")->type();
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
......@@ -63,9 +63,7 @@ class MergeLoDTensorOp : public framework::OperatorBase {
platform::Place place = dev_place;
int64_t batch_size = in_true.dims()[0] + in_false.dims()[0];
std::type_index data_type =
in_true.IsInitialized() ? in_true.type() : in_false.type();
auto data_type = in_true.IsInitialized() ? in_true.type() : in_false.type();
int rank;
framework::DDim in_dims;
if (in_true.IsInitialized()) {
......
......@@ -55,9 +55,8 @@ class AccuracyOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
ctx.GetPlace());
return framework::OpKernelType(ctx.Input<Tensor>("Out")->type(),
ctx.GetPlace());
}
};
......
......@@ -51,9 +51,8 @@ class AucOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Predict")->type()),
platform::CPUPlace());
return framework::OpKernelType(ctx.Input<Tensor>("Predict")->type(),
platform::CPUPlace());
}
};
......
......@@ -82,9 +82,8 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("MaxProbs")->type()),
ctx.device_context());
return framework::OpKernelType(ctx.Input<Tensor>("MaxProbs")->type(),
ctx.device_context());
}
};
......
......@@ -53,9 +53,8 @@ class MultiplexOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
ctx.device_context());
return framework::OpKernelType(ctx.MultiInput<Tensor>("X")[0]->type(),
ctx.device_context());
}
};
......@@ -123,9 +122,8 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
ctx.device_context());
return framework::OpKernelType(ctx.MultiInput<Tensor>("X")[0]->type(),
ctx.device_context());
}
};
......
......@@ -69,9 +69,8 @@ class NCEOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
platform::CPUPlace());
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
platform::CPUPlace());
}
};
......@@ -214,9 +213,8 @@ class NCEOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
platform::CPUPlace());
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
platform::CPUPlace());
}
};
......
......@@ -70,9 +70,8 @@ class AdadeltaOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
ctx.GetPlace());
}
};
......
......@@ -59,9 +59,8 @@ class AdagradOp : public framework::OperatorWithKernel {
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
ctx.GetPlace());
}
};
......
......@@ -75,8 +75,7 @@ class AdamOp : public framework::OperatorWithKernel {
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
auto input_data_type = ctx.Input<Tensor>("Param")->type();
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
......@@ -76,9 +76,8 @@ class AdamaxOp : public framework::OperatorWithKernel {
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
ctx.GetPlace());
}
};
......
......@@ -64,9 +64,8 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
ctx.GetPlace());
}
};
......
......@@ -66,8 +66,7 @@ class FTRLOp : public framework::OperatorWithKernel {
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
auto input_data_type = ctx.Input<Tensor>("Param")->type();
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
......@@ -58,9 +58,8 @@ class ProximalAdagradOp : public framework::OperatorWithKernel {
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
ctx.GetPlace());
}
};
......
......@@ -46,9 +46,8 @@ class ProximalGDOp : public framework::OperatorWithKernel {
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
ctx.GetPlace());
}
};
......
......@@ -511,8 +511,8 @@ class Pad2dOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace());
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
}
};
......@@ -612,8 +612,8 @@ class Pad2dOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace());
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
}
};
......
......@@ -47,9 +47,8 @@ class PadConstantLikeOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Y")->type()),
ctx.device_context());
return framework::OpKernelType(ctx.Input<Tensor>("Y")->type(),
ctx.device_context());
}
};
......@@ -171,9 +170,8 @@ class PadConstantLikeOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Y")->type()),
ctx.device_context());
return framework::OpKernelType(ctx.Input<Tensor>("Y")->type(),
ctx.device_context());
}
};
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册