diff --git a/paddle/fluid/lite/core/mir/io_complement_pass.cc b/paddle/fluid/lite/core/mir/io_complement_pass.cc
index 17bbfb948f20789aab83348d740a002539ca6e25..f122bfab6897a61a1e6987bd67ea4af66ed6ba23 100644
--- a/paddle/fluid/lite/core/mir/io_complement_pass.cc
+++ b/paddle/fluid/lite/core/mir/io_complement_pass.cc
@@ -36,7 +36,9 @@ void IoComplementPass::Apply(std::unique_ptr<SSAGraph>& graph) {
           inst.place, inst.op_type, tmp);
       CHECK(type) << "no param type found for " << inst.op_type << ":" << name
                   << " " << inst.place;
-      if (type->tensor_place != in->AsArgument().place) {
+      CHECK(type->type);
+      CHECK(in->AsArgument().type);
+      if (!TypeCompatible(*type->type, *in->AsArgument().type)) {
         LOG(INFO) << "found IO unmatched tensor: " << in->AsArgument().name;
       }
     }
diff --git a/paddle/fluid/lite/core/mir/node.h b/paddle/fluid/lite/core/mir/node.h
index 2077bddc839e54e0f063c9c401378f8a25312fbc..a37c16b62d9f64af194e4e72d14211aa2111bbed 100644
--- a/paddle/fluid/lite/core/mir/node.h
+++ b/paddle/fluid/lite/core/mir/node.h
@@ -58,6 +58,7 @@ class Node {
   struct Argument {
     std::string name;
     Place place;
+    const Type* type;
     // Weight is a special kind of argument, it is marked as weight explicitly
     // so that some weight related optimization can take place.
     bool is_weight{false};
diff --git a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h
index 608894504d7ef8aaa3a1fbdf58bbe517051381d4..25285025354e79a7a7b027754ccd563b0652e048 100644
--- a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+++ b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h
@@ -68,6 +68,7 @@ class VariablePlaceInferencePass : public DebugPass {
         auto& arg_node = node->AsArgument();
         if (arg_node.place.is_valid()) continue;
         UpdatePlace(&arg_node.place, type->tensor_place);
+        arg_node.type = type->type;
       }
     }
 
@@ -86,6 +87,7 @@ class VariablePlaceInferencePass : public DebugPass {
         CHECK(node) << "argument " << arg_name << " not exists in the graph";
         auto& arg_node = node->AsArgument();
         if (arg_node.place.is_valid()) continue;
+        node->AsArgument().type = type->type;
         UpdatePlace(&arg_node.place, type->tensor_place);
       }
     }
diff --git a/paddle/fluid/lite/core/type_system.h b/paddle/fluid/lite/core/type_system.h
index a22c4e59c6fa2bd039f67b1bcac979af83415071..e9aabee51f2708b663a3ee5571327d01944febe8 100644
--- a/paddle/fluid/lite/core/type_system.h
+++ b/paddle/fluid/lite/core/type_system.h
@@ -94,6 +94,7 @@ class Type : public DataTypeBase {
   TargetType target() const { return place_.target; }
   PrecisionType precision() const { return place_.precision; }
   DataLayoutType layout() const { return place_.layout; }
+  short device() const { return place().device; }
   const Place& place() const { return place_; }
   const std::string& name() const { return name_; }
 
@@ -121,9 +122,9 @@ class Type : public DataTypeBase {
   Type(ID id, const std::string& name, bool is_tensor,
        TargetType target = TargetType::kHost,
        PrecisionType precision = PrecisionType::kFloat,
-       DataLayoutType layout = DataLayoutType::kNCHW)
+       DataLayoutType layout = DataLayoutType::kNCHW, short device = 0)
       : DataTypeBase(id, is_tensor),
-        place_{target, precision, layout},
+        place_{target, precision, layout, device},
         name_(name) {}
 
 protected:
@@ -131,6 +132,32 @@ class Type : public DataTypeBase {
   Place place_;
   const std::string name_;
 };
+// -------------------------------- compatible check ---------------------------
+static bool TargetCompatible(const Type& a, const Type& b) {
+  return (a.IsVoid() || b.IsVoid()) ||  //
+         a.target() == b.target();
+}
+
+static bool DataLayoutCompatible(const Type& a, const Type& b) {
+  return (a.IsVoid() || b.IsVoid()) ||  //
+         (a.IsTensor() && b.IsTensor() && a.layout() == b.layout());
+}
+
+static bool PrecisionCompatible(const Type& a, const Type& b) {
+  return (a.IsVoid() || b.IsVoid()) ||  //
+         (a.precision() == b.precision());
+}
+
+static bool DeviceCompatible(const Type& a, const Type& b) {
+  return (a.IsVoid() || b.IsVoid()) ||  //
+         (a.device() == b.device());
+}
+
+static bool TypeCompatible(const Type& a, const Type& b) {
+  return TargetCompatible(a, b) && DataLayoutCompatible(a, b) &&
+         PrecisionCompatible(a, b) && DeviceCompatible(a, b);
+}
+
 // -------------------------------- predefined types ---------------------------
 // TODO(Superjomn) make all the Types' constructs protected to make sure there
 // is only one instance across the system.
@@ -232,14 +259,14 @@ struct ParamType {
   // For unsupported types.
   size_t element_type_hash{};
   Place tensor_place{};
-  const Type* type_;
+  const Type* type;
 
   explicit ParamType() = default;
   explicit ParamType(size_t element_type_hash)
       : element_type_hash(element_type_hash) {}
   ParamType(size_t element_type_hash, const Place& place)
       : element_type_hash(element_type_hash), tensor_place(place) {}
-  ParamType(const Type* type) : type(type) { tensor_place = type->place(); }
+  ParamType(const Type* type) : type(type) { tensor_place = type->place(); }
 
   std::string DebugString() const { return tensor_place.DebugString(); }
 };
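
Review note: the new `TypeCompatible` helper treats a `Void` type as a wildcard that matches anything; otherwise two types must agree on target, precision, device, and (for tensors) data layout, and `IoComplementPass` now logs an IO mismatch only when that check fails rather than comparing raw `Place`s. Below is a minimal standalone sketch of the matching rule for reviewers; `MiniType` and its enums are simplified stand-ins for illustration, not the real `Type` hierarchy from `type_system.h`:

```cpp
#include <iostream>

// Simplified stand-ins for the real TargetType / PrecisionType /
// DataLayoutType enums in the patch.
enum class Target { kHost, kX86, kCUDA };
enum class Precision { kFloat, kInt8 };
enum class Layout { kNCHW, kAny };

// Hypothetical flattened stand-in for paddle::lite's Type.
struct MiniType {
  bool is_void{false};   // a void type matches anything
  bool is_tensor{true};  // layout only applies to tensor types
  Target target{Target::kHost};
  Precision precision{Precision::kFloat};
  Layout layout{Layout::kNCHW};
  short device{0};
};

// Mirrors the patched TypeCompatible: every facet short-circuits to
// "compatible" if either side is void; layout additionally requires
// both sides to be tensors.
bool TypeCompatible(const MiniType& a, const MiniType& b) {
  const bool any_void = a.is_void || b.is_void;
  const bool target_ok = any_void || a.target == b.target;
  const bool layout_ok =
      any_void || (a.is_tensor && b.is_tensor && a.layout == b.layout);
  const bool precision_ok = any_void || a.precision == b.precision;
  const bool device_ok = any_void || a.device == b.device;
  return target_ok && layout_ok && precision_ok && device_ok;
}

int main() {
  MiniType host_fp32;                // kHost / kFloat / kNCHW / device 0
  MiniType cuda_fp32 = host_fp32;
  cuda_fp32.target = Target::kCUDA;  // differs only in target
  MiniType wildcard;
  wildcard.is_void = true;           // void matches anything

  std::cout << TypeCompatible(host_fp32, host_fp32) << "\n";  // 1: identical
  std::cout << TypeCompatible(host_fp32, cuda_fp32) << "\n";  // 0: target mismatch
  std::cout << TypeCompatible(host_fp32, wildcard) << "\n";   // 1: void wildcard
}
```

One observation on the design: because every predicate short-circuits on `IsVoid()`, declaring a parameter type as void opts it out of all four checks at once, which appears to be the intended escape hatch for type-agnostic ops; the two `CHECK`s added in `IoComplementPass` ensure the comparison never dereferences a null `Type*`.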