Commit dd70fb43 authored by S sneaxiy

fix type comparison bugs

Parent 19e877ff
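Background: the buggy pattern compared std::type_info::hash_code() values. The
standard only guarantees that hash_code() returns equal values for equal types;
distinct types may collide, so an equality test built on hash codes can report
two different types as the same. std::type_index compares via
type_info::operator==, which is exact. A minimal self-contained sketch of the
difference (illustrative only, not part of the diff):

    #include <cassert>
    #include <typeindex>
    #include <typeinfo>

    struct A {};
    struct B {};

    int main() {
      // Guaranteed: equal types hash equal.
      assert(typeid(A).hash_code() == typeid(A).hash_code());
      // NOT guaranteed: typeid(A).hash_code() != typeid(B).hash_code();
      // a collision there would make a hash-based equality test lie.
      // std::type_index equality is exact:
      assert(std::type_index(typeid(A)) != std::type_index(typeid(B)));
      return 0;
    }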
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/var_type.h"
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/memory/memory.h"
@@ -68,9 +69,9 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
   // only print first ten elements
   int64_t size = t.numel() < 10 ? t.numel() : 10;
   for (int64_t i = 0; i < size; ++i) {
-    if (t.type().hash_code() == typeid(float).hash_code()) {
+    if (IsType<float>(t.type())) {
       os << t.data<float>()[i] << " ";
-    } else if (t.type().hash_code() == typeid(int64_t).hash_code()) {
+    } else if (IsType<int64_t>(t.type())) {
       os << t.data<int64_t>()[i] << " ";
     } else {
       PADDLE_THROW("LoDTensor data type not in [float, int64_t]");
@@ -384,7 +385,7 @@ void LoDTensor::MergeLoDTensor(
   LoD new_lod = lod_tensors[0]->lod();
   for (size_t i = 1; i < lod_tensors.size(); ++i) {
     auto *t = lod_tensors[i];
-    PADDLE_ENFORCE_EQ(new_type.hash_code(), t->type().hash_code());
+    PADDLE_ENFORCE_EQ(new_type, t->type());
     PADDLE_ENFORCE_EQ(new_layout, t->layout());
     PADDLE_ENFORCE_EQ(framework::product(new_dim) / new_dim[0],
...
@@ -592,8 +592,7 @@ static void CheckTensorNANOrInf(const std::string& name,
   if (tensor.memory_size() == 0) {
     return;
   }
-  if (tensor.type().hash_code() != typeid(float).hash_code() &&   // NOLINT
-      tensor.type().hash_code() != typeid(double).hash_code()) {  // NOLINT
+  if (!IsType<float>(tensor.type()) && !IsType<double>(tensor.type())) {
     return;
   }
   PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
...
@@ -24,18 +24,24 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
+
+template <typename T>
+bool IsType(const std::type_index& type_index) {
+  return type_index == std::type_index(typeid(T));
+}
+
 inline proto::VarType::Type ToVarType(std::type_index type) {
-  if (type.hash_code() == typeid(LoDTensor).hash_code()) {
+  if (IsType<LoDTensor>(type)) {
     return proto::VarType_Type_LOD_TENSOR;
-  } else if (type.hash_code() == typeid(LoDRankTable).hash_code()) {
+  } else if (IsType<LoDRankTable>(type)) {
     return proto::VarType_Type_LOD_RANK_TABLE;
-  } else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) {
+  } else if (IsType<LoDTensorArray>(type)) {
     return proto::VarType_Type_LOD_TENSOR_ARRAY;
-  } else if (type.hash_code() == typeid(SelectedRows).hash_code()) {
+  } else if (IsType<SelectedRows>(type)) {
     return proto::VarType_Type_SELECTED_ROWS;
-  } else if (type.hash_code() == typeid(ReaderHolder).hash_code()) {
+  } else if (IsType<ReaderHolder>(type)) {
     return proto::VarType_Type_READER;
-  } else if (type.hash_code() == typeid(ChannelHolder).hash_code()) {
+  } else if (IsType<ChannelHolder>(type)) {
     return proto::VarType_Type_CHANNEL;
   } else {
     PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
...
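A hedged usage sketch of the helpers above (the call site is illustrative, not
taken from this commit):

    // Map a runtime-erased type to the proto enum; throws on unknown types.
    std::type_index t(typeid(framework::LoDTensor));
    auto vt = framework::ToVarType(t);  // proto::VarType_Type_LOD_TENSOR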
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <cstdio>
 #include <string>
+#include <typeindex>
 #include <unordered_map>
 #include <vector>
@@ -41,7 +42,7 @@ int AccuDims(Vec &&vec, int size) {
   return res;
 }
-#define SET_TYPE(type__) dic_[typeid(type__).hash_code()] = #type__;
+#define SET_TYPE(type__) dic_[std::type_index(typeid(type__))] = #type__;
 /*
  * Map typeid to representation.
  */
@@ -53,14 +54,14 @@ struct DataTypeNamer {
   template <typename T>
   const std::string &repr() const {
-    auto x = typeid(T).hash_code();
+    auto x = std::type_index(typeid(T));
     PADDLE_ENFORCE(dic_.count(x), "unknown type for representation");
     return dic_.at(x);
   }
-  const std::string &repr(size_t &hash) const {  // NOLINT
-    PADDLE_ENFORCE(dic_.count(hash), "unknown type for representation");
-    return dic_.at(hash);
+  const std::string &repr(const std::type_index &type) const {  // NOLINT
+    PADDLE_ENFORCE(dic_.count(type), "unknown type for representation");
+    return dic_.at(type);
   }
  private:
@@ -71,9 +72,7 @@ struct DataTypeNamer {
     SET_TYPE(void *);
   }
-  std::unordered_map<decltype(typeid(int).hash_code()),  // NOLINT
-                     std::string>
-      dic_;
+  std::unordered_map<std::type_index, std::string> dic_;
 };
 #undef SET_TYPE
...
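Side note on the container change: the standard library specializes
std::hash<std::type_index>, so a type_index can key an unordered_map directly
and the raw size_t key (with its decltype dance) goes away. A self-contained
sketch:

    #include <cassert>
    #include <string>
    #include <typeindex>
    #include <unordered_map>

    int main() {
      std::unordered_map<std::type_index, std::string> dic;
      dic[std::type_index(typeid(int))] = "int";  // hashed via std::hash
      dic[std::type_index(typeid(float))] = "float";
      assert(dic.at(std::type_index(typeid(int))) == "int");
      return 0;
    }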
@@ -23,9 +23,9 @@ namespace analysis {
 template <>
 std::string &NodeAttr::As<std::string>() {
   if (data_.empty()) {
-    type_hash_ = typeid(std::string).hash_code();
+    type_index_ = std::type_index(typeid(std::string));
   }
-  PADDLE_ENFORCE_EQ(type_hash_, typeid(std::string).hash_code());
+  PADDLE_ENFORCE_EQ(type_index_, std::type_index(typeid(std::string)));
   return data_;
 }
...
@@ -25,6 +25,7 @@ limitations under the License. */
 #include <unordered_map>
 #include <vector>
+#include "paddle/fluid/framework/var_type.h"
 #include "paddle/fluid/inference/analysis/device.h"
 #include "paddle/fluid/inference/analysis/dot.h"
 #include "paddle/fluid/inference/analysis/helper.h"
@@ -57,12 +58,12 @@ struct NodeAttr {
     // init storage in the first usage.
     if (data_.empty()) {
       VLOG(4) << "resize data to " << sizeof(T);
-      type_hash_ = typeid(T).hash_code();
+      type_index_ = std::type_index(typeid(T));
       data_.resize(sizeof(T));
     }
-    PADDLE_ENFORCE(type_hash_ == typeid(T).hash_code(),
+    PADDLE_ENFORCE(framework::IsType<T>(type_index_),
                    "type not matched, origin is %s, want %s",
-                   DataTypeNamer::Global().repr(type_hash_),
+                   DataTypeNamer::Global().repr(type_index_),
                    DataTypeNamer::Global().repr<T>());
     PADDLE_ENFORCE_EQ(data_.size(), sizeof(T), "Node attr type recast error");
     return *reinterpret_cast<T *>(&data_[0]);
@@ -70,7 +71,7 @@ struct NodeAttr {
  private:
   std::string data_;
-  size_t type_hash_{std::numeric_limits<size_t>::max()};
+  std::type_index type_index_{typeid(NodeAttr)};
 };
 /*
...
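One detail worth noting in NodeAttr: std::type_index has no default
constructor, so the member needs an explicit initializer; typeid(NodeAttr)
serves as the "unset" sentinel the way std::numeric_limits<size_t>::max() did
for the old hash field. A stripped-down sketch of the tagged-storage pattern
(an assumed simplification, not the real class):

    #include <string>
    #include <typeindex>

    struct TaggedStorage {
      template <typename T>
      T &As() {
        if (data_.empty()) {  // first access fixes the stored type
          type_index_ = std::type_index(typeid(T));
          data_.resize(sizeof(T));
        }
        // NodeAttr::As additionally enforces IsType<T>(type_index_) here.
        return *reinterpret_cast<T *>(&data_[0]);
      }
      std::string data_;
      // no default constructor for type_index, hence the sentinel initializer
      std::type_index type_index_{typeid(TaggedStorage)};
    };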
@@ -14,6 +14,7 @@ limitations under the License. */
 #include <algorithm>
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/var_type.h"
 namespace paddle {
 namespace operators {
@@ -47,7 +48,7 @@ class ConditionalOp : public framework::OperatorBase {
     if (!(ips.size() == 1UL && ips[0]->IsInitialized())) {
       PADDLE_THROW("should have one initialized input as condition");
     }
-    if (!(ips[0]->type().hash_code() == typeid(bool).hash_code() &&  // NOLINT
+    if (!(framework::IsType<bool>(ips[0]->type()) &&  // NOLINT
          ips[0]->numel() == 1)) {
       PADDLE_THROW(
           "condition input's data type should be bool, "
...
@@ -16,6 +16,7 @@
 #include <ctime>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/var_type.h"
 #include "paddle/fluid/framework/variable.h"
 namespace paddle {
@@ -62,7 +63,7 @@ struct Formater {
     }
   }
   void PrintDtype() {
-    if (dtype.hash_code() != typeid(const char).hash_code()) {
+    if (!framework::IsType<const char>(dtype)) {
       CLOG << "\tdtype: " << dtype.name() << std::endl;
     }
   }
@@ -83,15 +84,15 @@ struct Formater {
   void PrintData(size_t size) {
     PADDLE_ENFORCE_NOT_NULL(data);
     // print float
-    if (dtype.hash_code() == typeid(const float).hash_code()) {
+    if (framework::IsType<const float>(dtype)) {
       Display<float>(size);
-    } else if (dtype.hash_code() == typeid(const double).hash_code()) {
+    } else if (framework::IsType<const double>(dtype)) {
       Display<double>(size);
-    } else if (dtype.hash_code() == typeid(const int).hash_code()) {
+    } else if (framework::IsType<const int>(dtype)) {
       Display<int>(size);
-    } else if (dtype.hash_code() == typeid(const int64_t).hash_code()) {
+    } else if (framework::IsType<const int64_t>(dtype)) {
       Display<int64_t>(size);
-    } else if (dtype.hash_code() == typeid(const bool).hash_code()) {
+    } else if (framework::IsType<const bool>(dtype)) {
       Display<bool>(size);
     } else {
       CLOG << "\tdata: unprintable type: " << dtype.name() << std::endl;
...
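Aside on the <const float> style checks above: typeid ignores top-level
cv-qualifiers, so IsType<const float> and IsType<float> test the same thing.
A minimal check of that guarantee (illustrative only):

    #include <cassert>
    #include <typeindex>

    int main() {
      // [expr.typeid]: the result refers to the cv-unqualified type.
      assert(std::type_index(typeid(const float)) ==
             std::type_index(typeid(float)));
      return 0;
    }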
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor_array.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/var_type.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
 namespace paddle {
@@ -135,15 +136,14 @@ class WhileGradOp : public framework::OperatorBase {
         auto &og_inside =
             detail::Ref(cur_scope.Var(inside_og_name),
                         "Cannot find inside gradient %s", inside_og_name);
-        if (og_outside.Type().hash_code() ==
-            typeid(framework::LoDTensor).hash_code()) {
+        if (framework::IsType<framework::LoDTensor>(og_outside.Type())) {
           auto &outside_tensor = og_outside.Get<framework::LoDTensor>();
           auto &inside_tensor =
               detail::Ref(og_inside.GetMutable<framework::LoDTensor>());
           inside_tensor.set_lod(outside_tensor.lod());
           inside_tensor.ShareDataWith(outside_tensor);
-        } else if (og_outside.Type().hash_code() ==
-                   typeid(framework::LoDTensorArray).hash_code()) {
+        } else if (framework::IsType<framework::LoDTensorArray>(
+                       og_outside.Type())) {
           auto &outside_array = og_outside.Get<framework::LoDTensorArray>();
           auto &inside_array =
               detail::Ref(og_inside.GetMutable<framework::LoDTensorArray>());
...
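The WhileGradOp hunk shows the common pattern of branching on a variable's
erased type before calling Get<T>(). A generic toy version of that pattern
(hypothetical types, not Paddle code):

    #include <iostream>
    #include <typeindex>

    template <typename T>
    bool IsType(const std::type_index &idx) {
      return idx == std::type_index(typeid(T));
    }

    void Dispatch(const std::type_index &idx) {
      if (IsType<int>(idx)) {
        std::cout << "int payload\n";     // would call Get<int>() here
      } else if (IsType<double>(idx)) {
        std::cout << "double payload\n";  // would call Get<double>() here
      }
    }

    int main() {
      Dispatch(std::type_index(typeid(int)));
      return 0;
    }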