diff --git a/mindspore/_extends/parse/standard_method.py b/mindspore/_extends/parse/standard_method.py index 5e0463a074bea656daf9fe0dd71b936d3a3a3df3..6a7aa179e052411d4491c20172a6eb9941ffc430 100644 --- a/mindspore/_extends/parse/standard_method.py +++ b/mindspore/_extends/parse/standard_method.py @@ -28,7 +28,8 @@ from ...ops.composite.base import _append __all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like'] trans = P.Transpose() - +shape_ = P.Shape() +dtype_ = P.DType() def transpose(x): """Implementation of `transpose`.""" diff --git a/mindspore/ccsrc/frontend/operator/ops.h b/mindspore/ccsrc/frontend/operator/ops.h index 46cccfcc47c3b8457ca08766e9a8eaeacf6a0ac2..0404140b73be0cc6e266c4dfe8dd63b48ac2a21b 100755 --- a/mindspore/ccsrc/frontend/operator/ops.h +++ b/mindspore/ccsrc/frontend/operator/ops.h @@ -93,7 +93,6 @@ inline const PrimitivePtr kPrimArrayToScalar = std::make_shared("arra inline const PrimitivePtr kPrimBroadcastShape = std::make_shared("broadcast_shape"); inline const PrimitivePtr kPrimArrayMap = std::make_shared("array_map"); inline const PrimitivePtr kPrimArrayReduce = std::make_shared("array_reduce"); -inline const PrimitivePtr kPrimShape = std::make_shared("Shape"); inline const PrimitivePtr kPrimCast = std::make_shared("Cast"); inline const PrimitivePtr kPrimConcat = std::make_shared("Concat"); inline const PrimitivePtr kPrimSqueeze = std::make_shared("Squeeze"); diff --git a/mindspore/ccsrc/frontend/operator/prim_arrays.cc b/mindspore/ccsrc/frontend/operator/prim_arrays.cc index caaf1d1b2a703c2b6e8704e098e8b8e33e1a3215..1ed97353075e8320d77ce83be05d0df5f831721c 100644 --- a/mindspore/ccsrc/frontend/operator/prim_arrays.cc +++ b/mindspore/ccsrc/frontend/operator/prim_arrays.cc @@ -15,7 +15,6 @@ */ #include "pipeline/jit/static_analysis/prim.h" -#include "frontend/operator/ops.h" #include "abstract/utils.h" #include "frontend/operator/cc_implementations.h" #include "abstract/param_validator.h" @@ -80,23 +79,6 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti return std::make_shared(elems); } -AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - AbstractTensorPtr arg = CheckArg(op_name, args_spec_list, 0); - MS_LOG(DEBUG) << "InferImplShape:" << arg->ToString(); - - AbstractBasePtrList values; - auto shp = arg->shape(); - for (int entry : shp->shape()) { - auto entry_v = MakeValue(entry); - values.push_back(std::make_shared(entry_v, entry_v->type())); - } - return std::make_shared(values); -} - AbstractBasePtr InferImplTile(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: a tensor and a tuple. diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 42539f27edb2b8aa46b7c37462a8d8ca7e2ac2a9..51fcfa7602e66189b2ac571a0a778dedab92735c 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -963,6 +963,7 @@ void ClearResAtexit() { abstract::ClearPrimEvaluatorMap(); compile::ClearConvertCache(); pipeline::GetMethodMap().clear(); + pipeline::GetAttrMap().clear(); pipeline::ExecutorPy::ClearRes(); pipeline::ReclaimOptimizer(); pynative::PynativeExecutor::GetInstance()->ClearRes(); diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc index f31ab2159dbfa43c7bb71d1c21a5d856de76115f..f0466ad8e23cda8771e5e2d5bd1e36baee373012 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.cc +++ b/mindspore/ccsrc/pipeline/jit/resource.cc @@ -17,23 +17,20 @@ */ #include "pipeline/jit/resource.h" -#include "pipeline/jit/pipeline.h" #include "pipeline/jit/static_analysis/static_analysis.h" -#include "debug/draw.h" #include "debug/trace.h" #include "ir/dtype.h" #include "pipeline/jit/parse/data_converter.h" #include "frontend/operator/ops.h" #include "ir/graph_utils.h" #include "frontend/optimizer/ad/dfunctor.h" -#include "vm/segment_runner.h" namespace mindspore { // namespace to support opmap definition namespace pipeline { -MethodMap &GetMethodMap() { - static MethodMap method_map = { +BuiltInTypeMap &GetMethodMap() { + static BuiltInTypeMap method_map = { {kObjectTypeString, { {"__bool__", std::string("str_bool")} // C.str_bool @@ -191,6 +188,15 @@ MethodMap &GetMethodMap() { return method_map; } +BuiltInTypeMap &GetAttrMap() { + static BuiltInTypeMap attr_map = {{kObjectTypeTensorType, + { + {"shape", std::string("shape_")}, // C.shape_ + {"dtype", std::string("dtype_")}, // C.dtype_ + }}}; + return attr_map; +} + Resource::Resource(const py::object &obj) : engine_(std::make_shared(abstract::GetPrimEvaluatorConstructors(), manager_)), input_(obj), @@ -218,31 +224,42 @@ Resource::~Resource() { } } -bool Resource::IsTypeInMethodMap(const TypeId &type) { - TypeId type_id = NormalizeTypeId(type); - const MethodMap &method_map = GetMethodMap(); - auto iter = method_map.find(static_cast(type_id)); - if (iter != method_map.end()) { - return true; +Any GetMethodOrAttr(const string &name, const TypeId &type_id, const BuiltInTypeMap &method_map) { + auto type_method_map = method_map.find(static_cast(type_id)); + if (type_method_map == method_map.end()) { + return Any(); } - return false; + auto method = type_method_map->second.find(name); + if (method == type_method_map->second.end()) { + return Any(); + } + return method->second; } -Any Resource::GetMethodPtr(const TypeId &type, const std::string &name) { +bool Resource::IsTypeInBuiltInMap(const TypeId &type) { TypeId type_id = NormalizeTypeId(type); - const MethodMap &method_map = GetMethodMap(); + const BuiltInTypeMap &method_map = GetMethodMap(); auto iter = method_map.find(static_cast(type_id)); if (iter == method_map.end()) { - MS_LOG(WARNING) << "Object type: " << type_id << " not in the method_map"; - return Any(); + const BuiltInTypeMap &attr_map = GetAttrMap(); + iter = attr_map.find(static_cast(type_id)); + if (iter == attr_map.end()) { + return false; + } } + return true; +} - auto iter_map = iter->second.find(name); - if (iter_map == iter->second.end()) { - MS_LOG(WARNING) << "Object type: " << type_id << " have no method: " << name; - return Any(); - } - return iter_map->second; +Any Resource::GetMethodPtr(const TypeId &type, const std::string &name) { + TypeId type_id = NormalizeTypeId(type); + const BuiltInTypeMap &method_map = GetMethodMap(); + return GetMethodOrAttr(name, type_id, method_map); +} + +Any Resource::GetAttrPtr(const TypeId &type, const std::string &name) { + TypeId type_id = NormalizeTypeId(type); + const BuiltInTypeMap &attr_map = GetAttrMap(); + return GetMethodOrAttr(name, type_id, attr_map); } void Resource::Clean() { diff --git a/mindspore/ccsrc/pipeline/jit/resource.h b/mindspore/ccsrc/pipeline/jit/resource.h index 2e5fda23e48ef38867025994be0af8ee3b1e61ae..243e424d031a6b0eedec12e3bd9202f7c497aaf3 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.h +++ b/mindspore/ccsrc/pipeline/jit/resource.h @@ -44,9 +44,11 @@ const char kOutput[] = "output"; class InferenceResource; -using MethodMap = std::unordered_map>; +using BuiltInTypeMap = std::unordered_map>; -MethodMap &GetMethodMap(); +BuiltInTypeMap &GetMethodMap(); + +BuiltInTypeMap &GetAttrMap(); class ResourceBase { public: @@ -87,10 +89,12 @@ class Resource : public ResourceBase { abstract::AnalysisEnginePtr engine() { return engine_; } - static bool IsTypeInMethodMap(const TypeId &type); + static bool IsTypeInBuiltInMap(const TypeId &type); static Any GetMethodPtr(const TypeId &type, const std::string &name); + static Any GetAttrPtr(const TypeId &type, const std::string &name); + const py::object &input() const { return input_; } FuncGraphPtr func_graph() const { return func_graph_; } diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 0cc887c0a41a361549b3d063f32de549b68d7044..22ef9c25f32d741c40d4fe8093c3c137e9ec93b4 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -21,7 +21,6 @@ #include #include #include -#include #include #include @@ -31,10 +30,8 @@ #include "frontend/operator/prim_to_function.h" #include "abstract/utils.h" #include "utils/symbolic.h" -#include "./common.h" #include "pipeline/jit/resource.h" #include "pipeline/jit/parse/resolve.h" -#include "ir/tensor.h" #include "utils/convert_utils.h" #include "utils/context/ms_context.h" #include "pipeline/jit/parse/data_converter.h" @@ -64,7 +61,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, - {prim::kPrimShape, {InferImplShape, true}}, {prim::kPrimPack, {InferImplPack, true}}, // Structure {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, @@ -634,7 +630,7 @@ EvaluatorPtr InitUniformPrimEvaluator(const PrimitivePtr &primitive, PrimitiveIm } const int kResolveCaseUserDefineClass = 1; -const int kResolveCaseBuildinTypeMethod = 2; +const int kResolveCaseBuiltInType = 2; const int kResolveCaseFunction = 3; int GetResolveCase(const TypePtr &data_type) { MS_EXCEPTION_IF_NULL(data_type); @@ -643,8 +639,8 @@ int GetResolveCase(const TypePtr &data_type) { } // try method map, if not in method map, the data_type should be External type. - if (pipeline::Resource::IsTypeInMethodMap(data_type->type_id())) { - return kResolveCaseBuildinTypeMethod; + if (pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id())) { + return kResolveCaseBuiltInType; } return kResolveCaseFunction; @@ -674,8 +670,10 @@ inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr fun manager->AddFuncGraph(func_graph); } -EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, - const AnfNodeConfigPtr &old_conf) { +enum REQUIRE_TYPE { ATTR, METHOD }; + +EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, const AnfNodeConfigPtr &old_conf, + REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD) { MS_EXCEPTION_IF_NULL(old_conf); AbstractBasePtr abs_ptr = ToAbstract(value, AnalysisContext::DummyContext(), old_conf); @@ -701,6 +699,9 @@ EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_ MS_EXCEPTION_IF_NULL(old_conf); FuncGraphPtr func_graph = old_conf->node()->func_graph(); CNodePtr new_cnode = func_graph->NewCNode(input); + if (require_type == REQUIRE_TYPE::ATTR) { + new_cnode = func_graph->NewCNode({new_cnode}); + } AnalysisEnginePtr eng = old_conf->engine(); AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context()); return eng->ForwardConfig(old_conf, fn_conf); @@ -781,9 +782,9 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng return StaticGetterInferred(converted_v, data_conf, out_conf); } -EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v, - const TypePtr &data_type, const ConfigPtr &data_conf, - const AnfNodeConfigPtr &out_conf) { +EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v, + const TypePtr &data_type, const ConfigPtr &data_conf, + const AnfNodeConfigPtr &out_conf) { MS_EXCEPTION_IF_NULL(item_v); MS_EXCEPTION_IF_NULL(data_type); // The method maybe a Primitive or Composite @@ -792,22 +793,29 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &eng } std::string item_name = item_v->cast()->value(); - Any method = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name); - if (method.empty()) { - MS_LOG(EXCEPTION) << "Object type: " << data_type->ToString() << " has no method: " << item_name; + REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD; + Any require = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name); + if (require.empty()) { + require = pipeline::Resource::GetAttrPtr(data_type->type_id(), item_name); + if (require.empty()) { + MS_LOG(EXCEPTION) << "The object of type: " << data_type->ToString() << " has no method or attr: " << item_name; + } + require_type = REQUIRE_TYPE::ATTR; } ValuePtr converted_v = nullptr; - if (method.is()) { + if (require.is()) { // composite registered in standard_method_map go to this branch - converted_v = prim::GetPythonOps(method.cast()); - AddToManager(engine, converted_v->cast()); - } else if (method.is()) { - converted_v = method.cast(); + converted_v = prim::GetPythonOps(require.cast()); + if (!converted_v->isa()) { + AddToManager(engine, converted_v->cast()); + } + } else if (require.is()) { + converted_v = require.cast(); } else { - MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from method map, but got " << method.ToString(); + MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from attr or method map, but got " << require.ToString(); } - return StaticGetterInferred(converted_v, data_conf, out_conf); + return StaticGetterInferred(converted_v, data_conf, out_conf, require_type); } EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, @@ -831,8 +839,8 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt int case_v = GetResolveCase(data_type); if (case_v == kResolveCaseUserDefineClass) { return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf); - } else if (case_v == kResolveCaseBuildinTypeMethod) { - return GetEvaluatedValueForBuiltinTypeMethod(engine, item_value, data_type, data_conf, out_conf); + } else if (case_v == kResolveCaseBuiltInType) { + return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, item_value, data_type, data_conf, out_conf); } else { return GetEvaluatedValueForNameSpaceString(engine, args_spec_list, out_conf); } diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h index f56cac85aa341711046a2c1253876597221d406c..2f44c173d0b066f321573c5e3923e81ef7178e0d 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h @@ -218,10 +218,6 @@ AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const P const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive, @@ -246,8 +242,6 @@ AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const Primitiv const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 2d074e7367fa78f035b2ee06aa8d5c28281f3912..c8697cf9bcdc1238feceec32b5bde182b8e8a250 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -22,20 +22,21 @@ import copy import functools import itertools import numbers + import numpy as np -from ..._checkparam import Validator as validator -from ..._checkparam import Rel -from ...common import dtype as mstype -from ...common.tensor import Tensor -from ...common.parameter import Parameter -from ..operations.math_ops import _infer_shape_reduce from .._utils import get_concat_offset -from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op -from ..._c_expression import signature_rw as sig_rw -from ..._c_expression import signature_kind as sig_kind +from ..operations.math_ops import _infer_shape_reduce +from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op from ..._c_expression import signature_dtype as sig_dtype +from ..._c_expression import signature_kind as sig_kind +from ..._c_expression import signature_rw as sig_rw from ..._c_expression import typing +from ..._checkparam import Rel +from ..._checkparam import Validator as validator +from ...common import dtype as mstype +from ...common.parameter import Parameter +from ...common.tensor import Tensor class _ScatterOp(PrimitiveWithInfer): @@ -415,7 +416,7 @@ class Reshape(PrimitiveWithInfer): return out -class Shape(Primitive): +class Shape(PrimitiveWithInfer): """ Returns the shape of input tensor. @@ -436,6 +437,13 @@ class Shape(Primitive): def __init__(self): """init Shape""" + def __infer__(self, x): + validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name) + out = {'shape': (), + 'dtype': mstype.tuple_, + 'value': tuple(x['shape'])} + return out + class Squeeze(PrimitiveWithInfer): """ diff --git a/tests/ut/cpp/operator/ops_test.cc b/tests/ut/cpp/operator/ops_test.cc index 4af615e9c6999a0ee62e52f14573c5b5ddc3e195..796bad8053362ee1ed4ab73c59aa4d9e424530cb 100644 --- a/tests/ut/cpp/operator/ops_test.cc +++ b/tests/ut/cpp/operator/ops_test.cc @@ -267,11 +267,6 @@ TEST_F(TestOps, BroadCastShapeTest) { ASSERT_EQ(prim->name(), kPrimBroadcastShape->name()); } -TEST_F(TestOps, ShapeTest) { - auto prim = std::make_shared("Shape"); - ASSERT_EQ(prim->name(), kPrimShape->name()); -} - TEST_F(TestOps, ArrayMapTest) { auto prim = std::make_shared("array_map"); ASSERT_EQ(prim->name(), kPrimArrayMap->name()); diff --git a/tests/ut/cpp/pipeline/resource_test.cc b/tests/ut/cpp/pipeline/resource_test.cc index b6be393652b1fdc8aa3f1e2fd45866107faca106..f6fe8e52421424d100aa4191c97fb6bc5df4fc2f 100644 --- a/tests/ut/cpp/pipeline/resource_test.cc +++ b/tests/ut/cpp/pipeline/resource_test.cc @@ -36,23 +36,23 @@ class TestResource : public UT::Common { void TearDown() {} }; -TEST_F(TestResource, test_standard_method_map) { - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt8)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt16)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt32)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt64)); +TEST_F(TestResource, test_built_in_type_map) { + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt8)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt16)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt32)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt64)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeFloat)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeFloat16)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeFloat32)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeFloat64)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeFloat)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeFloat16)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeFloat32)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeFloat64)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeBool)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeUInt)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kObjectTypeTuple)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kObjectTypeList)); - ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kObjectTypeTensorType)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeBool)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeUInt)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kObjectTypeTuple)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kObjectTypeList)); + ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kObjectTypeTensorType)); MethodMap& map = GetMethodMap(); for (auto& iter : map) { diff --git a/tests/ut/cpp/pipeline/static_analysis/prim_test.cc b/tests/ut/cpp/pipeline/static_analysis/prim_test.cc index 8ebea4d21226edaca263196d7d8c10296a8dc598..d037a9019c05a9e06ddaeca7ff8aa38cdf317590 100644 --- a/tests/ut/cpp/pipeline/static_analysis/prim_test.cc +++ b/tests/ut/cpp/pipeline/static_analysis/prim_test.cc @@ -467,24 +467,6 @@ TEST_F(TestPrim, test_env_add) { ASSERT_TRUE(*res == *exp); } -TEST_F(TestPrim, test_shape) { - PrimitivePtr shap = std::make_shared("Shape"); - FuncGraphPtr func_graph = MakeFuncGraph(shap, 1); - - auto a = UTPrimUtils::ArrayFloat64Of({2, 3}); - - AbstractBasePtrList args_spec_list = {a}; - - AbstractTuplePtr res = dyn_cast(engine_->Run(func_graph, args_spec_list).inferred->abstract()); - auto ret = res->BuildValue()->cast()->value(); - - std::vector element_list = {MakeValue(2), MakeValue(3)}; - ASSERT_TRUE(ret.size() == element_list.size()); - for (int i = 0; i < element_list.size(); i++) { - ASSERT_TRUE(*ret[i] == *element_list[i]); - } -} - TEST_F(TestPrim, test_relu) { PrimitivePtr relu = prim::kPrimRelu; relu->AddAttr("T", MakeValue(static_cast(kNumberTypeFloat64))); diff --git a/tests/ut/python/pipeline/parse/test_dtype_and_shape_as_attr.py b/tests/ut/python/pipeline/parse/test_dtype_and_shape_as_attr.py new file mode 100644 index 0000000000000000000000000000000000000000..94236ac48938e58b364f3b8a99117335f1f84778 --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_dtype_and_shape_as_attr.py @@ -0,0 +1,96 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test dtype and shape as attr""" +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore import dtype as mstype +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE) + + +def test_dtype_and_shape_as_attr(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + shape = x.shape + dtype = x.dtype + return shape, dtype + + + net = Net() + x = Tensor(np.ones([1, 2, 3], np.int32)) + ret = net(x) + assert ret == ((1, 2, 3), mstype.int32) + + +def test_dtype_and_shape_as_attr_to_new_tensor(): + class Net(nn.Cell): + def __init__(self, value): + super(Net, self).__init__() + self.fill = P.Fill() + self.value = value + + def construct(self, x): + dtype = x.dtype + shape = x.shape + y = self.fill(dtype, shape, self.value) + return y + + + net = Net(2.2) + x = Tensor(np.ones([1, 2, 3], np.float32)) + ret = net(x) + assert (ret.asnumpy() == (np.zeros([1, 2, 3], np.float32) + 2.2)).all() + + +def test_type_not_have_the_attr(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + shape = x.shapes + return shape + + + net = Net() + x = Tensor(np.ones([1, 2, 3], np.int32)) + with pytest.raises(RuntimeError) as ex: + net(x) + assert "The object of type: Tensor[Int32] has no method or attr: shapes" in str(ex.value) + + +def test_type_not_have_the_method(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + shape = x.dtypes() + return shape + + + net = Net() + x = Tensor(np.ones([1, 2, 3], np.int32)) + with pytest.raises(RuntimeError) as ex: + net(x) + assert "The object of type: Tensor[Int32] has no method or attr: dtypes" in str(ex.value) diff --git a/tests/ut/python/pipeline/parse/test_super.py b/tests/ut/python/pipeline/parse/test_super.py index 6405b278ae13815052c1d66b81276261c33a7b82..f8734584adfc6ac41bfa74eef671705589b8842d 100644 --- a/tests/ut/python/pipeline/parse/test_super.py +++ b/tests/ut/python/pipeline/parse/test_super.py @@ -20,7 +20,7 @@ import mindspore.nn as nn from mindspore import Tensor from mindspore import context -context.set_context(mode=context.GRAPH_MODE, save_graphs=True) +context.set_context(mode=context.GRAPH_MODE) class FatherNet(nn.Cell): @@ -92,7 +92,6 @@ class Net(nn.Cell): def test_single_super(): single_net = SingleSubNet(2, 3) - context.set_context(mode=context.GRAPH_MODE) x = Tensor(np.ones([1, 2, 3], np.int32)) y = Tensor(np.ones([1, 2, 3], np.int32)) single_net(x, y) @@ -100,7 +99,6 @@ def test_single_super(): def test_mul_super(): mul_net = MulSubNet(2, 3, 4) - context.set_context(mode=context.GRAPH_MODE) x = Tensor(np.ones([1, 2, 3], np.int32)) y = Tensor(np.ones([1, 2, 3], np.int32)) mul_net(x, y) @@ -108,7 +106,6 @@ def test_mul_super(): def test_super_cell(): net = Net(2) - context.set_context(mode=context.GRAPH_MODE) x = Tensor(np.ones([1, 2, 3], np.int32)) y = Tensor(np.ones([1, 2, 3], np.int32)) with pytest.raises(RuntimeError) as er: @@ -142,7 +139,6 @@ def test_single_super_in(): return ret_father_construct, ret_father_test, ret_father_x, ret_sub_z single_net_in = SingleSubNetIN(2, 3) - context.set_context(mode=context.GRAPH_MODE, save_graphs=True) x = Tensor(np.ones([1, 2, 3], np.int32)) y = Tensor(np.ones([1, 2, 3], np.int32)) single_net_in(x, y)