diff --git a/paddle/fluid/framework/data_layout_transform.h b/paddle/fluid/framework/data_layout_transform.h
index 2479de4fd46802148af09d34b627a8804276cacf..2c0a34b881176adf5f2a24a227ca114cc3b4721c 100644
--- a/paddle/fluid/framework/data_layout_transform.h
+++ b/paddle/fluid/framework/data_layout_transform.h
@@ -15,6 +15,7 @@
 #pragma once
 
 #include <map>
+#include <unordered_map>
 #include <vector>
 #include "paddle/fluid/framework/op_kernel_type.h"
 #include "paddle/fluid/framework/tensor.h"
@@ -52,11 +53,11 @@ inline DataLayout ToPaddleLayout(const MKLDNNFormat& format) {
 
 inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
   static std::unordered_map<int, MKLDNNDataType> dict{
-      {DataTypeTrait<float>::DataType, MKLDNNDataType::f32},
-      {DataTypeTrait<int8_t>::DataType, MKLDNNDataType::s8},
-      {DataTypeTrait<uint8_t>::DataType, MKLDNNDataType::u8},
-      {DataTypeTrait<int16_t>::DataType, MKLDNNDataType::s16},
-      {DataTypeTrait<int32_t>::DataType, MKLDNNDataType::s32}};
+      {DataTypeTrait<float>::DataType(), MKLDNNDataType::f32},
+      {DataTypeTrait<int8_t>::DataType(), MKLDNNDataType::s8},
+      {DataTypeTrait<uint8_t>::DataType(), MKLDNNDataType::u8},
+      {DataTypeTrait<int16_t>::DataType(), MKLDNNDataType::s16},
+      {DataTypeTrait<int32_t>::DataType(), MKLDNNDataType::s32}};
   auto iter = dict.find(static_cast<int>(type));
   if (iter != dict.end()) return iter->second;
   return MKLDNNDataType::data_undef;
diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h
index 76df78ea5e17c7eaf1e8ce7a7dc2282a5a4ed579..60644820df7cd4133c5fd8f24fe693245d68a5f3 100644
--- a/paddle/fluid/framework/data_type.h
+++ b/paddle/fluid/framework/data_type.h
@@ -28,7 +28,9 @@ struct DataTypeTrait {};
 // Stub handle for void
 template <>
 struct DataTypeTrait<void> {
-  constexpr static auto DataType = proto::VarType::RAW;
+  constexpr static proto::VarType::Type DataType() {
+    return proto::VarType::RAW;
+  }
 };
 
 #define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
@@ -45,10 +47,10 @@ struct DataTypeTrait<void> {
   _ForEachDataTypeHelper_(callback, int16_t, INT16);  \
   _ForEachDataTypeHelper_(callback, int8_t, INT8)
 
-#define DefineDataTypeTrait(cpp_type, proto_type) \
-  template <>                                     \
-  struct DataTypeTrait<cpp_type> {                \
-    constexpr static auto DataType = proto_type;  \
+#define DefineDataTypeTrait(cpp_type, proto_type)                           \
+  template <>                                                               \
+  struct DataTypeTrait<cpp_type> {                                          \
+    constexpr static proto::VarType::Type DataType() { return proto_type; } \
   }
 
 _ForEachDataType_(DefineDataTypeTrait);
diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h
index a4b1457ad567cf5f1f2788a5c24889c3066c84b0..a5c39b7e923e24e82996402489ea537df08a7d5d 100644
--- a/paddle/fluid/framework/tensor_impl.h
+++ b/paddle/fluid/framework/tensor_impl.h
@@ -24,10 +24,10 @@ template <typename T>
 inline const T* Tensor::data() const {
   check_memory_size();
   bool valid =
-      std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
+      std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType();
   PADDLE_ENFORCE(
       valid, "Tensor holds the wrong type, it holds %s, but desires to be %s",
-      DataTypeToString(type_), DataTypeToString(DataTypeTrait<T>::DataType));
+      DataTypeToString(type_), DataTypeToString(DataTypeTrait<T>::DataType()));
 
   return reinterpret_cast<const T*>(
       reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
@@ -39,10 +39,10 @@ template <typename T>
 inline T* Tensor::data() {
   check_memory_size();
   bool valid =
-      std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
+      std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType();
   PADDLE_ENFORCE(
       valid, "Tensor holds the wrong type, it holds %s, but desires to be %s",
-      DataTypeToString(type_), DataTypeToString(DataTypeTrait<T>::DataType));
+      DataTypeToString(type_), DataTypeToString(DataTypeTrait<T>::DataType()));
   return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
                               offset_);
 }
@@ -59,7 +59,7 @@ template <typename T>
 inline T* Tensor::mutable_data(platform::Place place, size_t requested_size) {
   static_assert(std::is_pod<T>::value, "T must be POD");
   return reinterpret_cast<T*>(
-      mutable_data(place, DataTypeTrait<T>::DataType, requested_size));
+      mutable_data(place, DataTypeTrait<T>::DataType(), requested_size));
 }
 
 inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc
index 56996c5cff88f5b4a9094291a09996f8b8d70a23..628817c6f4614026566f74510426efb65f740ea5 100644
--- a/paddle/fluid/inference/api/api_impl.cc
+++ b/paddle/fluid/inference/api/api_impl.cc
@@ -280,13 +280,13 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
     auto type = fetch.type();
     auto output = &(outputs->at(i));
     output->name = fetchs_[idx]->Input("X")[0];
-    if (type == framework::DataTypeTrait<float>::DataType) {
+    if (type == framework::DataTypeTrait<float>::DataType()) {
       GetFetchOne<float>(fetch, output);
       output->dtype = PaddleDType::FLOAT32;
-    } else if (type == framework::DataTypeTrait<int64_t>::DataType) {
+    } else if (type == framework::DataTypeTrait<int64_t>::DataType()) {
       GetFetchOne<int64_t>(fetch, output);
       output->dtype = PaddleDType::INT64;
-    } else if (type == framework::DataTypeTrait<int32_t>::DataType) {
+    } else if (type == framework::DataTypeTrait<int32_t>::DataType()) {
       GetFetchOne<int32_t>(fetch, output);
       output->dtype = PaddleDType::INT32;
     } else {
diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc
index 5c48a8ee8f56e8010a9263d8a2a460a1d22d420c..d2036c611edc69a5cd671165b20377a95c009ac3 100644
--- a/paddle/fluid/operators/conv_op.cc
+++ b/paddle/fluid/operators/conv_op.cc
@@ -103,8 +103,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
     library = framework::LibraryType::kMKLDNN;
     layout = framework::DataLayout::kMKLDNN;
     customized_type_value =
-        (input_data_type == framework::DataTypeTrait<int8_t>::DataType ||
-         input_data_type == framework::DataTypeTrait<uint8_t>::DataType)
+        (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
+         input_data_type == framework::DataTypeTrait<uint8_t>::DataType())
             ? kConvMKLDNNINT8
            : kConvMKLDNNFP32;
   }
diff --git a/paddle/fluid/operators/detection/prior_box_op.cc b/paddle/fluid/operators/detection/prior_box_op.cc
index 34ddacb6f5415e18930cbdb711ce3d0182f31308..9709cbc058900cfc64839b450484957b18604583 100644
--- a/paddle/fluid/operators/detection/prior_box_op.cc
+++ b/paddle/fluid/operators/detection/prior_box_op.cc
@@ -87,10 +87,10 @@ class PriorBoxOp : public framework::OperatorWithKernel {
     auto input_image_type = ctx.Input<framework::Tensor>("Image")->type();
     int customized_type_value =
         framework::OpKernelType::kDefaultCustomizedTypeValue;
-    if (input_image_type == framework::DataTypeTrait<float>::DataType) {
+    if (input_image_type == framework::DataTypeTrait<float>::DataType()) {
       customized_type_value = kPriorBoxFLOAT;
     } else if (input_image_type ==
-               framework::DataTypeTrait<double>::DataType) {
+               framework::DataTypeTrait<double>::DataType()) {
       customized_type_value = kPriorBoxDOUBLE;
     }
     return framework::OpKernelType(input_input_type, ctx.GetPlace(), layout_,
diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
index 824a605f6b1e283240e3a2b71eec071ca51b9fce..876a0b8b60cfc440040370bea680b9a47b20832a 100644
--- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -358,13 +358,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 
       auto dst_dt = unsigned_output
                         ? paddle::framework::ToMKLDNNDataType(
-                              framework::DataTypeTrait<uint8_t>::DataType)
+                              framework::DataTypeTrait<uint8_t>::DataType())
                         : paddle::framework::ToMKLDNNDataType(
-                              framework::DataTypeTrait<int8_t>::DataType);
+                              framework::DataTypeTrait<int8_t>::DataType());
 
       if (force_fp32_output) {
         dst_dt = paddle::framework::ToMKLDNNDataType(
-            framework::DataTypeTrait<float>::DataType);
+            framework::DataTypeTrait<float>::DataType());
       }
 
       if (fuse_residual_conn) {
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index ff4533d9c525c2f5887519edd153d90147115c6e..4697f8c916177bb6bc1bb9ccea32cd73269aeb5d 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -917,7 +917,7 @@ static void SetDstMemoryQuantized(
 
   auto dst_md = platform::MKLDNNMemDesc(
       {dst_tz}, paddle::framework::ToMKLDNNDataType(
-                    framework::DataTypeTrait<T>::DataType),
+                    framework::DataTypeTrait<T>::DataType()),
       dst_fmt);
   dst_pd.reset(new mkldnn::memory::primitive_desc(dst_md, engine));
   dst_memory.reset(new mkldnn::memory(*dst_pd, to_void_cast<T>(output_data)));
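
Note on the pattern applied above: every in-class `constexpr static auto DataType` data member of `DataTypeTrait<T>` becomes a `constexpr static` member function `DataType()`, and each call site gains a pair of parentheses. The diff does not record its rationale; a plausible one (an assumption, not stated in the patch) is the pre-C++17 ODR-use rule, where binding a static constexpr data member to a reference, as `std::unordered_map` initialization or variadic argument forwarding can do, requires an out-of-class definition to link, whereas a constexpr function returns the enum value by copy. A minimal standalone sketch of the two forms, using hypothetical names rather than Paddle's:

```cpp
#include <unordered_map>

enum class ProtoType { FP32 = 0, INT8 = 1 };  // stand-in for proto::VarType::Type

// Before: in-class initialized static data member. Passing it by const
// reference (e.g. during std::pair construction inside an unordered_map
// initializer) ODR-uses it and, before C++17, also requires an out-of-class
// definition such as `constexpr ProtoType TraitMember<float>::DataType;`.
template <typename T>
struct TraitMember;
template <>
struct TraitMember<float> {
  constexpr static ProtoType DataType = ProtoType::FP32;
};

// After: a constexpr member function. The value is produced by the call, so
// there is no object to ODR-use and no extra definition is needed.
template <typename T>
struct TraitFunc;
template <>
struct TraitFunc<float> {
  constexpr static ProtoType DataType() { return ProtoType::FP32; }
};

int main() {
  // Mirrors the ToMKLDNNDataType-style lookup table touched by the diff.
  std::unordered_map<int, const char*> dict{
      {static_cast<int>(TraitFunc<float>::DataType()), "f32"}};
  return dict.count(static_cast<int>(ProtoType::FP32)) ? 0 : 1;
}
```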