提交 b927e72f 编写于 作者: S superjomn

add type compatible check

上级 12db9f3c
...@@ -36,7 +36,9 @@ void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) { ...@@ -36,7 +36,9 @@ void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
inst.place, inst.op_type, tmp); inst.place, inst.op_type, tmp);
CHECK(type) << "no param type found for " << inst.op_type << ":" << name CHECK(type) << "no param type found for " << inst.op_type << ":" << name
<< " " << inst.place; << " " << inst.place;
if (type->tensor_place != in->AsArgument().place) { CHECK(type->type);
CHECK(in->AsArgument().type);
if (!TypeCompatible(*type->type, *in->AsArgument().type)) {
LOG(INFO) << "found IO unmatched tensor: " << in->AsArgument().name; LOG(INFO) << "found IO unmatched tensor: " << in->AsArgument().name;
} }
} }
......
...@@ -58,6 +58,7 @@ class Node { ...@@ -58,6 +58,7 @@ class Node {
struct Argument { struct Argument {
std::string name; std::string name;
Place place; Place place;
const Type* type;
// Weight is a special kind of argument, it is marked as weight explicitly // Weight is a special kind of argument, it is marked as weight explicitly
// so that some weight related optimization can take place. // so that some weight related optimization can take place.
bool is_weight{false}; bool is_weight{false};
......
...@@ -68,6 +68,7 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -68,6 +68,7 @@ class VariablePlaceInferencePass : public DebugPass {
auto& arg_node = node->AsArgument(); auto& arg_node = node->AsArgument();
if (arg_node.place.is_valid()) continue; if (arg_node.place.is_valid()) continue;
UpdatePlace(&arg_node.place, type->tensor_place); UpdatePlace(&arg_node.place, type->tensor_place);
arg_node.type = type->type;
} }
} }
...@@ -86,6 +87,7 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -86,6 +87,7 @@ class VariablePlaceInferencePass : public DebugPass {
CHECK(node) << "argument " << arg_name << " not exists in the graph"; CHECK(node) << "argument " << arg_name << " not exists in the graph";
auto& arg_node = node->AsArgument(); auto& arg_node = node->AsArgument();
if (arg_node.place.is_valid()) continue; if (arg_node.place.is_valid()) continue;
node->AsArgument().type = type->type;
UpdatePlace(&arg_node.place, type->tensor_place); UpdatePlace(&arg_node.place, type->tensor_place);
} }
} }
......
...@@ -94,6 +94,7 @@ class Type : public DataTypeBase { ...@@ -94,6 +94,7 @@ class Type : public DataTypeBase {
TargetType target() const { return place_.target; } TargetType target() const { return place_.target; }
PrecisionType precision() const { return place_.precision; } PrecisionType precision() const { return place_.precision; }
DataLayoutType layout() const { return place_.layout; } DataLayoutType layout() const { return place_.layout; }
short device() const { return place().device; }
const Place& place() const { return place_; } const Place& place() const { return place_; }
const std::string& name() const { return name_; } const std::string& name() const { return name_; }
...@@ -121,9 +122,9 @@ class Type : public DataTypeBase { ...@@ -121,9 +122,9 @@ class Type : public DataTypeBase {
Type(ID id, const std::string& name, bool is_tensor, Type(ID id, const std::string& name, bool is_tensor,
TargetType target = TargetType::kHost, TargetType target = TargetType::kHost,
PrecisionType precision = PrecisionType::kFloat, PrecisionType precision = PrecisionType::kFloat,
DataLayoutType layout = DataLayoutType::kNCHW) DataLayoutType layout = DataLayoutType::kNCHW, short device = 0)
: DataTypeBase(id, is_tensor), : DataTypeBase(id, is_tensor),
place_{target, precision, layout}, place_{target, precision, layout, device},
name_(name) {} name_(name) {}
protected: protected:
...@@ -131,6 +132,32 @@ class Type : public DataTypeBase { ...@@ -131,6 +132,32 @@ class Type : public DataTypeBase {
const std::string name_; const std::string name_;
}; };
// -------------------------------- compatible check ---------------------------
// Targets are compatible when either side is the wildcard (Void) type,
// or both types name the same TargetType.
static bool TargetCompatible(const Type& a, const Type& b) {
  if (a.IsVoid() || b.IsVoid()) return true;
  return a.target() == b.target();
}
// Layouts are compatible when either side is the wildcard (Void) type;
// otherwise both sides must be tensors carrying the same data layout.
// NOTE: a non-Void, non-tensor pair is reported as incompatible here,
// mirroring the original check.
static bool DataLayoutCompatible(const Type& a, const Type& b) {
  if (a.IsVoid() || b.IsVoid()) return true;
  if (!a.IsTensor() || !b.IsTensor()) return false;
  return a.layout() == b.layout();
}
// Precisions are compatible when either side is the wildcard (Void) type,
// or both types carry the same PrecisionType.
static bool PrecisionCompatible(const Type& a, const Type& b) {
  if (a.IsVoid() || b.IsVoid()) return true;
  return a.precision() == b.precision();
}
// Devices are compatible when either side is the wildcard (Void) type,
// or both types are bound to the same device id.
static bool DeviceCompatible(const Type& a, const Type& b) {
  if (a.IsVoid() || b.IsVoid()) return true;
  return a.device() == b.device();
}
// Two types are compatible only when every facet agrees: target, data
// layout, precision, and device. Short-circuits on the first mismatch.
static bool TypeCompatible(const Type& a, const Type& b) {
  if (!TargetCompatible(a, b)) return false;
  if (!DataLayoutCompatible(a, b)) return false;
  if (!PrecisionCompatible(a, b)) return false;
  return DeviceCompatible(a, b);
}
// -------------------------------- predefined types --------------------------- // -------------------------------- predefined types ---------------------------
// TODO(Superjomn) make all the Types' constructs protected to make sure there // TODO(Superjomn) make all the Types' constructs protected to make sure there
// is only one instance across the system. // is only one instance across the system.
...@@ -232,14 +259,14 @@ struct ParamType { ...@@ -232,14 +259,14 @@ struct ParamType {
// For unsupported types. // For unsupported types.
size_t element_type_hash{}; size_t element_type_hash{};
Place tensor_place{}; Place tensor_place{};
const Type* type_; const Type* type;
explicit ParamType() = default; explicit ParamType() = default;
explicit ParamType(size_t element_type_hash) explicit ParamType(size_t element_type_hash)
: element_type_hash(element_type_hash) {} : element_type_hash(element_type_hash) {}
ParamType(size_t element_type_hash, const Place& place) ParamType(size_t element_type_hash, const Place& place)
: element_type_hash(element_type_hash), tensor_place(place) {} : element_type_hash(element_type_hash), tensor_place(place) {}
ParamType(const Type* type) : type_(type) { tensor_place = type_->place(); } ParamType(const Type* type) : type(type) { tensor_place = type->place(); }
std::string DebugString() const { return tensor_place.DebugString(); } std::string DebugString() const { return tensor_place.DebugString(); }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册