提交 dd70fb43 编写于 作者: S sneaxiy

fix type comparation bugs

上级 19e877ff
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
......@@ -68,9 +69,9 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
// only print first ten elements
int64_t size = t.numel() < 10 ? t.numel() : 10;
for (int64_t i = 0; i < size; ++i) {
if (t.type().hash_code() == typeid(float).hash_code()) {
if (IsType<float>(t.type())) {
os << t.data<float>()[i] << " ";
} else if (t.type().hash_code() == typeid(int64_t).hash_code()) {
} else if (IsType<int64_t>(t.type())) {
os << t.data<int64_t>()[i] << " ";
} else {
PADDLE_THROW("LoDTensor data type not in [float, int64_t]");
......@@ -384,7 +385,7 @@ void LoDTensor::MergeLoDTensor(
LoD new_lod = lod_tensors[0]->lod();
for (size_t i = 1; i < lod_tensors.size(); ++i) {
auto *t = lod_tensors[i];
PADDLE_ENFORCE_EQ(new_type.hash_code(), t->type().hash_code());
PADDLE_ENFORCE_EQ(new_type, t->type());
PADDLE_ENFORCE_EQ(new_layout, t->layout());
PADDLE_ENFORCE_EQ(framework::product(new_dim) / new_dim[0],
......
......@@ -592,8 +592,7 @@ static void CheckTensorNANOrInf(const std::string& name,
if (tensor.memory_size() == 0) {
return;
}
if (tensor.type().hash_code() != typeid(float).hash_code() && // NOLINT
tensor.type().hash_code() != typeid(double).hash_code()) { // NOLINT
if (!IsType<float>(tensor.type()) && !IsType<double>(tensor.type())) {
return;
}
PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
......
......@@ -24,18 +24,24 @@ limitations under the License. */
namespace paddle {
namespace framework {
template <typename T>
bool IsType(const std::type_index& type_index) {
return type_index == std::type_index(typeid(T));
}
inline proto::VarType::Type ToVarType(std::type_index type) {
if (type.hash_code() == typeid(LoDTensor).hash_code()) {
if (IsType<LoDTensor>(type)) {
return proto::VarType_Type_LOD_TENSOR;
} else if (type.hash_code() == typeid(LoDRankTable).hash_code()) {
} else if (IsType<LoDRankTable>(type)) {
return proto::VarType_Type_LOD_RANK_TABLE;
} else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) {
} else if (IsType<LoDTensorArray>(type)) {
return proto::VarType_Type_LOD_TENSOR_ARRAY;
} else if (type.hash_code() == typeid(SelectedRows).hash_code()) {
} else if (IsType<SelectedRows>(type)) {
return proto::VarType_Type_SELECTED_ROWS;
} else if (type.hash_code() == typeid(ReaderHolder).hash_code()) {
} else if (IsType<ReaderHolder>(type)) {
return proto::VarType_Type_READER;
} else if (type.hash_code() == typeid(ChannelHolder).hash_code()) {
} else if (IsType<ChannelHolder>(type)) {
return proto::VarType_Type_CHANNEL;
} else {
PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <cstdio>
#include <string>
#include <typeindex>
#include <unordered_map>
#include <vector>
......@@ -41,7 +42,7 @@ int AccuDims(Vec &&vec, int size) {
return res;
}
#define SET_TYPE(type__) dic_[typeid(type__).hash_code()] = #type__;
#define SET_TYPE(type__) dic_[std::type_index(typeid(type__))] = #type__;
/*
* Map typeid to representation.
*/
......@@ -53,14 +54,14 @@ struct DataTypeNamer {
template <typename T>
const std::string &repr() const {
auto x = typeid(T).hash_code();
auto x = std::type_index(typeid(T));
PADDLE_ENFORCE(dic_.count(x), "unknown type for representation");
return dic_.at(x);
}
const std::string &repr(size_t &hash) const { // NOLINT
PADDLE_ENFORCE(dic_.count(hash), "unknown type for representation");
return dic_.at(hash);
const std::string &repr(const std::type_index &type) const { // NOLINT
PADDLE_ENFORCE(dic_.count(type), "unknown type for representation");
return dic_.at(type);
}
private:
......@@ -71,9 +72,7 @@ struct DataTypeNamer {
SET_TYPE(void *);
}
std::unordered_map<decltype(typeid(int).hash_code()), // NOLINT
std::string>
dic_;
std::unordered_map<std::type_index, std::string> dic_;
};
#undef SET_TYPE
......
......@@ -23,9 +23,9 @@ namespace analysis {
template <>
std::string &NodeAttr::As<std::string>() {
if (data_.empty()) {
type_hash_ = typeid(std::string).hash_code();
type_index_ = std::type_index(typeid(std::string));
}
PADDLE_ENFORCE_EQ(type_hash_, typeid(std::string).hash_code());
PADDLE_ENFORCE_EQ(type_index_, std::type_index(typeid(std::string)));
return data_;
}
......
......@@ -25,6 +25,7 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/inference/analysis/device.h"
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/inference/analysis/helper.h"
......@@ -57,12 +58,12 @@ struct NodeAttr {
// init storage in the first usage.
if (data_.empty()) {
VLOG(4) << "resize data to " << sizeof(T);
type_hash_ = typeid(T).hash_code();
type_index_ = std::type_index(typeid(T));
data_.resize(sizeof(T));
}
PADDLE_ENFORCE(type_hash_ == typeid(T).hash_code(),
PADDLE_ENFORCE(framework::IsType<T>(type_index_),
"type not matched, origin is %s, want %s",
DataTypeNamer::Global().repr(type_hash_),
DataTypeNamer::Global().repr(type_index_),
DataTypeNamer::Global().repr<T>());
PADDLE_ENFORCE_EQ(data_.size(), sizeof(T), "Node attr type recast error");
return *reinterpret_cast<T *>(&data_[0]);
......@@ -70,7 +71,7 @@ struct NodeAttr {
private:
std::string data_;
size_t type_hash_{std::numeric_limits<size_t>::max()};
std::type_index type_index_{typeid(NodeAttr)};
};
/*
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h"
namespace paddle {
namespace operators {
......@@ -47,7 +48,7 @@ class ConditionalOp : public framework::OperatorBase {
if (!(ips.size() == 1UL && ips[0]->IsInitialized())) {
PADDLE_THROW("should have one initialized input as condition");
}
if (!(ips[0]->type().hash_code() == typeid(bool).hash_code() && // NOLINT
if (!(framework::IsType<bool>(ips[0]->type()) && // NOLINT
ips[0]->numel() == 1)) {
PADDLE_THROW(
"condition input's data type should be bool, "
......
......@@ -16,6 +16,7 @@
#include <ctime>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
......@@ -62,7 +63,7 @@ struct Formater {
}
}
void PrintDtype() {
if (dtype.hash_code() != typeid(const char).hash_code()) {
if (!framework::IsType<const char>(dtype)) {
CLOG << "\tdtype: " << dtype.name() << std::endl;
}
}
......@@ -83,15 +84,15 @@ struct Formater {
void PrintData(size_t size) {
PADDLE_ENFORCE_NOT_NULL(data);
// print float
if (dtype.hash_code() == typeid(const float).hash_code()) {
if (framework::IsType<const float>(dtype)) {
Display<float>(size);
} else if (dtype.hash_code() == typeid(const double).hash_code()) {
} else if (framework::IsType<const double>(dtype)) {
Display<double>(size);
} else if (dtype.hash_code() == typeid(const int).hash_code()) {
} else if (framework::IsType<const int>(dtype)) {
Display<int>(size);
} else if (dtype.hash_code() == typeid(const int64_t).hash_code()) {
} else if (framework::IsType<const int64_t>(dtype)) {
Display<int64_t>(size);
} else if (dtype.hash_code() == typeid(const bool).hash_code()) {
} else if (framework::IsType<const bool>(dtype)) {
Display<bool>(size);
} else {
CLOG << "\tdata: unprintable type: " << dtype.name() << std::endl;
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
namespace paddle {
......@@ -135,15 +136,14 @@ class WhileGradOp : public framework::OperatorBase {
auto &og_inside =
detail::Ref(cur_scope.Var(inside_og_name),
"Cannot find inside gradient %s", inside_og_name);
if (og_outside.Type().hash_code() ==
typeid(framework::LoDTensor).hash_code()) {
if (framework::IsType<framework::LoDTensor>(og_outside.Type())) {
auto &outside_tensor = og_outside.Get<framework::LoDTensor>();
auto &inside_tensor =
detail::Ref(og_inside.GetMutable<framework::LoDTensor>());
inside_tensor.set_lod(outside_tensor.lod());
inside_tensor.ShareDataWith(outside_tensor);
} else if (og_outside.Type().hash_code() ==
typeid(framework::LoDTensorArray).hash_code()) {
} else if (framework::IsType<framework::LoDTensorArray>(
og_outside.Type())) {
auto &outside_array = og_outside.Get<framework::LoDTensorArray>();
auto &inside_array =
detail::Ref(og_inside.GetMutable<framework::LoDTensorArray>());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册