提交 b927e72f 编写于 作者: S superjomn

add type compatible check

上级 12db9f3c
......@@ -36,7 +36,9 @@ void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
inst.place, inst.op_type, tmp);
CHECK(type) << "no param type found for " << inst.op_type << ":" << name
<< " " << inst.place;
if (type->tensor_place != in->AsArgument().place) {
CHECK(type->type);
CHECK(in->AsArgument().type);
if (!TypeCompatible(*type->type, *in->AsArgument().type)) {
LOG(INFO) << "found IO unmatched tensor: " << in->AsArgument().name;
}
}
......
......@@ -58,6 +58,7 @@ class Node {
struct Argument {
std::string name;
Place place;
const Type* type;
// Weight is a special kind of argument, it is marked as weight explicitly
// so that some weight related optimization can take place.
bool is_weight{false};
......
......@@ -68,6 +68,7 @@ class VariablePlaceInferencePass : public DebugPass {
auto& arg_node = node->AsArgument();
if (arg_node.place.is_valid()) continue;
UpdatePlace(&arg_node.place, type->tensor_place);
arg_node.type = type->type;
}
}
......@@ -86,6 +87,7 @@ class VariablePlaceInferencePass : public DebugPass {
CHECK(node) << "argument " << arg_name << " not exists in the graph";
auto& arg_node = node->AsArgument();
if (arg_node.place.is_valid()) continue;
node->AsArgument().type = type->type;
UpdatePlace(&arg_node.place, type->tensor_place);
}
}
......
......@@ -94,6 +94,7 @@ class Type : public DataTypeBase {
TargetType target() const { return place_.target; }
PrecisionType precision() const { return place_.precision; }
DataLayoutType layout() const { return place_.layout; }
short device() const { return place().device; }
const Place& place() const { return place_; }
const std::string& name() const { return name_; }
......@@ -121,9 +122,9 @@ class Type : public DataTypeBase {
Type(ID id, const std::string& name, bool is_tensor,
TargetType target = TargetType::kHost,
PrecisionType precision = PrecisionType::kFloat,
DataLayoutType layout = DataLayoutType::kNCHW)
DataLayoutType layout = DataLayoutType::kNCHW, short device = 0)
: DataTypeBase(id, is_tensor),
place_{target, precision, layout},
place_{target, precision, layout, device},
name_(name) {}
protected:
......@@ -131,6 +132,32 @@ class Type : public DataTypeBase {
const std::string name_;
};
// -------------------------------- compatible check ---------------------------
static bool TargetCompatible(const Type& a, const Type& b) {
return (a.IsVoid() || b.IsVoid()) || //
a.target() == b.target();
}
static bool DataLayoutCompatible(const Type& a, const Type& b) {
return (a.IsVoid() || b.IsVoid()) || //
(a.IsTensor() && b.IsTensor() && a.layout() == b.layout());
}
static bool PrecisionCompatible(const Type& a, const Type& b) {
return (a.IsVoid() || b.IsVoid()) || //
(a.precision() == b.precision());
}
static bool DeviceCompatible(const Type& a, const Type& b) {
return (a.IsVoid() || b.IsVoid()) || //
(a.device() == b.device());
}
static bool TypeCompatible(const Type& a, const Type& b) {
return TargetCompatible(a, b) && DataLayoutCompatible(a, b) &&
PrecisionCompatible(a, b) && DeviceCompatible(a, b);
}
// -------------------------------- predefined types ---------------------------
// TODO(Superjomn) make all the Types' constructs protected to make sure there
// is only one instance across the system.
......@@ -232,14 +259,14 @@ struct ParamType {
// For unsupported types.
size_t element_type_hash{};
Place tensor_place{};
const Type* type_;
const Type* type;
explicit ParamType() = default;
explicit ParamType(size_t element_type_hash)
: element_type_hash(element_type_hash) {}
ParamType(size_t element_type_hash, const Place& place)
: element_type_hash(element_type_hash), tensor_place(place) {}
ParamType(const Type* type) : type_(type) { tensor_place = type_->place(); }
ParamType(const Type* type) : type(type) { tensor_place = type->place(); }
std::string DebugString() const { return tensor_place.DebugString(); }
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册