提交 fe8f47dc 编写于 作者: B buxue

add typeid to type conversion scene

上级 16a75779
...@@ -400,7 +400,7 @@ std::string AnfExporter::GetValueText(const FuncGraphPtr &func_graph, const Valu ...@@ -400,7 +400,7 @@ std::string AnfExporter::GetValueText(const FuncGraphPtr &func_graph, const Valu
} else if (value->isa<tensor::Tensor>()) { } else if (value->isa<tensor::Tensor>()) {
auto tensor_ptr = dyn_cast<tensor::Tensor>(value); auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
oss << value->DumpText() << "@" << DumpObject(tensor_ptr->data(), "T"); oss << value->DumpText() << "@" << DumpObject(tensor_ptr->data(), "T");
} else if (value->isa<parse::Symbol>() || value->isa<None>() || value->isa<NullObj>()) { } else if (value->isa<parse::Symbol>() || value->isa<None>() || value->isa<Null>()) {
oss << value->DumpText(); oss << value->DumpText();
} else if (value->isa<ValueSequeue>()) { } else if (value->isa<ValueSequeue>()) {
oss << GetSequenceText(func_graph, value); oss << GetSequenceText(func_graph, value);
......
...@@ -275,6 +275,11 @@ extern const TypePtr kTypeExternal; ...@@ -275,6 +275,11 @@ extern const TypePtr kTypeExternal;
extern const TypePtr kTypeEnv; extern const TypePtr kTypeEnv;
extern const TypePtr kTypeType; extern const TypePtr kTypeType;
extern const TypePtr kString; extern const TypePtr kString;
extern const TypePtr kList;
extern const TypePtr kTuple;
extern const TypePtr kDict;
extern const TypePtr kSlice;
extern const TypePtr kKeyword;
extern const TypePtr kTensorType; extern const TypePtr kTensorType;
} // namespace mindspore } // namespace mindspore
......
...@@ -18,5 +18,7 @@ ...@@ -18,5 +18,7 @@
namespace mindspore { namespace mindspore {
const TypePtr kTypeNone = std::make_shared<TypeNone>(); const TypePtr kTypeNone = std::make_shared<TypeNone>();
const TypePtr kTypeNull = std::make_shared<TypeNull>();
const TypePtr kTypeEllipsis = std::make_shared<TypeEllipsis>();
const TypePtr kAnyType = std::make_shared<TypeAnything>(); const TypePtr kAnyType = std::make_shared<TypeAnything>();
} // namespace mindspore } // namespace mindspore
...@@ -71,20 +71,22 @@ class TypeNull : public Type { ...@@ -71,20 +71,22 @@ class TypeNull : public Type {
}; };
using TypeNullPtr = std::shared_ptr<TypeNull>; using TypeNullPtr = std::shared_ptr<TypeNull>;
class Ellipsis : public Type { class TypeEllipsis : public Type {
public: public:
Ellipsis() : Type(kMetaTypeEllipsis) {} TypeEllipsis() : Type(kMetaTypeEllipsis) {}
~Ellipsis() override {} ~TypeEllipsis() override {}
MS_DECLARE_PARENT(Ellipsis, Type) MS_DECLARE_PARENT(TypeEllipsis, Type)
TypeId generic_type_id() const override { return kMetaTypeEllipsis; } TypeId generic_type_id() const override { return kMetaTypeEllipsis; }
TypePtr DeepCopy() const override { return std::make_shared<Ellipsis>(); } TypePtr DeepCopy() const override { return std::make_shared<TypeEllipsis>(); }
std::string ToReprString() const override { return "Ellipsis"; } std::string ToReprString() const override { return "Ellipsis"; }
std::string DumpText() const override { return "Ellipsis"; } std::string DumpText() const override { return "Ellipsis"; }
}; };
using EllipsisPtr = std::shared_ptr<Ellipsis>; using TypeEllipsisPtr = std::shared_ptr<TypeEllipsis>;
extern const TypePtr kTypeNone; extern const TypePtr kTypeNone;
extern const TypePtr kTypeNull;
extern const TypePtr kTypeEllipsis;
extern const TypePtr kAnyType; extern const TypePtr kAnyType;
} // namespace mindspore } // namespace mindspore
......
...@@ -95,12 +95,30 @@ TypePtr TypeIdToType(TypeId id) { ...@@ -95,12 +95,30 @@ TypePtr TypeIdToType(TypeId id) {
return kAnyType; return kAnyType;
case kMetaTypeNone: case kMetaTypeNone:
return kTypeNone; return kTypeNone;
case kMetaTypeNull:
return kTypeNull;
case kMetaTypeEllipsis:
return kTypeEllipsis;
case kObjectTypeEnvType: case kObjectTypeEnvType:
return kTypeEnv; return kTypeEnv;
case kObjectTypeRefKey: case kObjectTypeRefKey:
return kRefKeyType; return kRefKeyType;
case kObjectTypeRef: case kObjectTypeRef:
return kRefType; return kRefType;
case kMetaTypeTypeType:
return kTypeType;
case kObjectTypeString:
return kString;
case kObjectTypeList:
return kList;
case kObjectTypeTuple:
return kTuple;
case kObjectTypeDictionary:
return kDict;
case kObjectTypeSlice:
return kSlice;
case kObjectTypeKeyword:
return kKeyword;
case kTypeUnknown: case kTypeUnknown:
return kTypeNone; return kTypeNone;
default: default:
...@@ -274,7 +292,7 @@ TypePtr StringToType(const std::string &type_name) { ...@@ -274,7 +292,7 @@ TypePtr StringToType(const std::string &type_name) {
if (type_name.compare("None") == 0) { if (type_name.compare("None") == 0) {
type = std::make_shared<TypeNone>(); type = std::make_shared<TypeNone>();
} else if (type_name.compare("Ellipsis") == 0) { } else if (type_name.compare("Ellipsis") == 0) {
type = std::make_shared<Ellipsis>(); type = std::make_shared<TypeEllipsis>();
} else if (type_name.compare("TypeType") == 0) { } else if (type_name.compare("TypeType") == 0) {
type = std::make_shared<TypeType>(); type = std::make_shared<TypeType>();
} else if (type_name.compare("SymbolicKeyType") == 0) { } else if (type_name.compare("SymbolicKeyType") == 0) {
...@@ -476,7 +494,7 @@ REGISTER_PYBIND_DEFINE( ...@@ -476,7 +494,7 @@ REGISTER_PYBIND_DEFINE(
(void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init()); (void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
(void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init()); (void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
(void)py::class_<Slice, Type, std::shared_ptr<Slice>>(m_sub, "Slice").def(py::init()); (void)py::class_<Slice, Type, std::shared_ptr<Slice>>(m_sub, "Slice").def(py::init());
(void)py::class_<Ellipsis, Type, std::shared_ptr<Ellipsis>>(m_sub, "Ellipsis").def(py::init()); (void)py::class_<TypeEllipsis, Type, std::shared_ptr<TypeEllipsis>>(m_sub, "TypeEllipsis").def(py::init());
})); }));
const TypePtr kTypeExternal = std::make_shared<External>(); const TypePtr kTypeExternal = std::make_shared<External>();
...@@ -484,4 +502,9 @@ const TypePtr kTypeEnv = std::make_shared<EnvType>(); ...@@ -484,4 +502,9 @@ const TypePtr kTypeEnv = std::make_shared<EnvType>();
const TypePtr kTypeType = std::make_shared<TypeType>(); const TypePtr kTypeType = std::make_shared<TypeType>();
const TypePtr kTensorType = std::make_shared<TensorType>(); const TypePtr kTensorType = std::make_shared<TensorType>();
const TypePtr kString = std::make_shared<String>(); const TypePtr kString = std::make_shared<String>();
const TypePtr kList = std::make_shared<List>();
const TypePtr kTuple = std::make_shared<Tuple>();
const TypePtr kDict = std::make_shared<Dictionary>();
const TypePtr kSlice = std::make_shared<Slice>();
const TypePtr kKeyword = std::make_shared<Keyword>();
} // namespace mindspore } // namespace mindspore
...@@ -432,7 +432,7 @@ AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) { ...@@ -432,7 +432,7 @@ AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) {
if (default_value == nullptr) { if (default_value == nullptr) {
MS_LOG(EXCEPTION) << "Graph parameter " << name << " not exist"; MS_LOG(EXCEPTION) << "Graph parameter " << name << " not exist";
} }
if (IsValueNode<NullObj>(default_value)) { if (IsValueNode<Null>(default_value)) {
return nullptr; return nullptr;
} }
return default_value; return default_value;
...@@ -440,8 +440,8 @@ AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) { ...@@ -440,8 +440,8 @@ AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) {
// set the default values // set the default values
void FuncGraph::SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list) { void FuncGraph::SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list) {
auto all_is_null = std::all_of(value_list.begin(), value_list.end(), auto all_is_null =
[](const AnfNodePtr &node) { return IsValueNode<NullObj>(node); }); std::all_of(value_list.begin(), value_list.end(), [](const AnfNodePtr &node) { return IsValueNode<Null>(node); });
if (value_list.empty()) { if (value_list.empty()) {
all_is_null = true; all_is_null = true;
} }
...@@ -457,7 +457,7 @@ void FuncGraph::ClearDefaultValues() { parameter_default_value_.clear(); } ...@@ -457,7 +457,7 @@ void FuncGraph::ClearDefaultValues() { parameter_default_value_.clear(); }
size_t FuncGraph::GetDefaultValueCount() { size_t FuncGraph::GetDefaultValueCount() {
int null_count = int null_count =
std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(), std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(),
[](const std::pair<std::string, AnfNodePtr> &pair) { return IsValueNode<NullObj>(pair.second); }); [](const std::pair<std::string, AnfNodePtr> &pair) { return IsValueNode<Null>(pair.second); });
return parameter_default_value_.size() - IntToSize(null_count); return parameter_default_value_.size() - IntToSize(null_count);
} }
......
...@@ -30,9 +30,9 @@ bool Named::operator==(const Value &other) const { ...@@ -30,9 +30,9 @@ bool Named::operator==(const Value &other) const {
abstract::AbstractBasePtr None::ToAbstract() { return std::make_shared<abstract::AbstractNone>(); } abstract::AbstractBasePtr None::ToAbstract() { return std::make_shared<abstract::AbstractNone>(); }
const NamedPtr kNone = std::make_shared<None>(); const NamedPtr kNone = std::make_shared<None>();
abstract::AbstractBasePtr NullObj::ToAbstract() { return std::make_shared<abstract::AbstractNull>(); } abstract::AbstractBasePtr Null::ToAbstract() { return std::make_shared<abstract::AbstractNull>(); }
const NamedPtr kNull = std::make_shared<NullObj>(); const NamedPtr kNull = std::make_shared<Null>();
abstract::AbstractBasePtr EllipsisObj::ToAbstract() { return std::make_shared<abstract::AbstractEllipsis>(); } abstract::AbstractBasePtr Ellipsis::ToAbstract() { return std::make_shared<abstract::AbstractEllipsis>(); }
const NamedPtr kEllipsis = std::make_shared<EllipsisObj>(); const NamedPtr kEllipsis = std::make_shared<Ellipsis>();
} // namespace mindspore } // namespace mindspore
...@@ -71,20 +71,20 @@ class None : public Named { ...@@ -71,20 +71,20 @@ class None : public Named {
}; };
extern const NamedPtr kNone; extern const NamedPtr kNone;
class NullObj : public Named { class Null : public Named {
public: public:
NullObj() : Named("Null") {} Null() : Named("Null") {}
~NullObj() override = default; ~Null() override = default;
MS_DECLARE_PARENT(NullObj, Named); MS_DECLARE_PARENT(Null, Named);
abstract::AbstractBasePtr ToAbstract() override; abstract::AbstractBasePtr ToAbstract() override;
}; };
extern const NamedPtr kNull; extern const NamedPtr kNull;
class EllipsisObj : public Named { class Ellipsis : public Named {
public: public:
EllipsisObj() : Named("Ellipsis") {} Ellipsis() : Named("Ellipsis") {}
~EllipsisObj() override = default; ~Ellipsis() override = default;
MS_DECLARE_PARENT(EllipsisObj, Named); MS_DECLARE_PARENT(Ellipsis, Named);
abstract::AbstractBasePtr ToAbstract() override; abstract::AbstractBasePtr ToAbstract() override;
}; };
extern const NamedPtr kEllipsis; extern const NamedPtr kEllipsis;
......
...@@ -515,11 +515,11 @@ using AbstractNullPtr = std::shared_ptr<AbstractNull>; ...@@ -515,11 +515,11 @@ using AbstractNullPtr = std::shared_ptr<AbstractNull>;
class AbstractEllipsis : public AbstractBase { class AbstractEllipsis : public AbstractBase {
public: public:
AbstractEllipsis() : AbstractBase(kEllipsis) { set_type(std::make_shared<Ellipsis>()); } AbstractEllipsis() : AbstractBase(kEllipsis) { set_type(std::make_shared<TypeEllipsis>()); }
~AbstractEllipsis() override = default; ~AbstractEllipsis() override = default;
MS_DECLARE_PARENT(AbstractEllipsis, AbstractBase) MS_DECLARE_PARENT(AbstractEllipsis, AbstractBase)
TypePtr BuildType() const override { return std::make_shared<Ellipsis>(); } TypePtr BuildType() const override { return std::make_shared<TypeEllipsis>(); }
bool operator==(const AbstractEllipsis &other) const; bool operator==(const AbstractEllipsis &other) const;
bool operator==(const AbstractBase &other) const override; bool operator==(const AbstractBase &other) const override;
AbstractBasePtr Clone() const override { return std::make_shared<AbstractEllipsis>(); } AbstractBasePtr Clone() const override { return std::make_shared<AbstractEllipsis>(); }
......
...@@ -105,7 +105,7 @@ py::object ValuePtrToPyData(const ValuePtr &value) { ...@@ -105,7 +105,7 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
i++; i++;
} }
ret = rets; ret = rets;
} else if (value->isa<EllipsisObj>()) { } else if (value->isa<Ellipsis>()) {
ret = py::ellipsis(); ret = py::ellipsis();
} else if (value->isa<ValueSlice>()) { } else if (value->isa<ValueSlice>()) {
auto slice = value->cast<ValueSlicePtr>(); auto slice = value->cast<ValueSlicePtr>();
......
...@@ -96,7 +96,7 @@ type_refkey = typing.RefKeyType() ...@@ -96,7 +96,7 @@ type_refkey = typing.RefKeyType()
tensor_type = typing.TensorType tensor_type = typing.TensorType
anything_type = typing.TypeAnything anything_type = typing.TypeAnything
slice_type = typing.Slice slice_type = typing.Slice
ellipsis_type = typing.Ellipsis ellipsis_type = typing.TypeEllipsis
number_type = (int8, number_type = (int8,
int16, int16,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册