Commit b075674c authored by buxue

Support tensor attributes shape and dtype in graph mode

Parent fa96dfd1
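What this commit enables, as a minimal usage sketch (the net name is illustrative and not part of this commit; the assertion mirrors the tests added at the bottom of this diff): in graph mode, construct can now read a tensor's shape and dtype as plain attributes.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore import dtype as mstype

context.set_context(mode=context.GRAPH_MODE)

class ShapeDtypeNet(nn.Cell):  # illustrative name, not part of this commit
    def __init__(self):
        super(ShapeDtypeNet, self).__init__()

    def construct(self, x):
        return x.shape, x.dtype  # both attrs now resolve during graph compilation

net = ShapeDtypeNet()
out = net(Tensor(np.ones([1, 2, 3], np.int32)))
assert out == ((1, 2, 3), mstype.int32)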
......@@ -28,7 +28,8 @@ from ...ops.composite.base import _append
__all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like']
trans = P.Transpose()
shape_ = P.Shape()
dtype_ = P.DType()
def transpose(x):
    """Implementation of `transpose`."""
......
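The two module-level callables registered above are what the new attribute lookups resolve to, so `x.shape` and `x.dtype` in graph mode behave like direct calls to the primitives. A sketch of that equivalence, assuming the standard `P.Shape`/`P.DType` semantics:

import numpy as np
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.ops import operations as P

x = Tensor(np.ones([4, 5], np.float32))
assert P.Shape()(x) == (4, 5)           # what `x.shape` lowers to
assert P.DType()(x) == mstype.float32   # what `x.dtype` lowers to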
......@@ -93,7 +93,6 @@ inline const PrimitivePtr kPrimArrayToScalar = std::make_shared<Primitive>("arra
inline const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("broadcast_shape");
inline const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map");
inline const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_reduce");
-inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape");
inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast");
inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
......
......@@ -15,7 +15,6 @@
*/
#include "pipeline/jit/static_analysis/prim.h"
#include "frontend/operator/ops.h"
#include "abstract/utils.h"
#include "frontend/operator/cc_implementations.h"
#include "abstract/param_validator.h"
......@@ -80,23 +79,6 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti
  return std::make_shared<AbstractTuple>(elems);
}

AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list) {
  // Inputs: a tensor.
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 1);
  AbstractTensorPtr arg = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  MS_LOG(DEBUG) << "InferImplShape:" << arg->ToString();
  AbstractBasePtrList values;
  auto shp = arg->shape();
  for (int entry : shp->shape()) {
    auto entry_v = MakeValue(entry);
    values.push_back(std::make_shared<AbstractScalar>(entry_v, entry_v->type()));
  }
  return std::make_shared<AbstractTuple>(values);
}

AbstractBasePtr InferImplTile(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const AbstractBasePtrList &args_spec_list) {
  // Inputs: a tensor and a tuple.
......
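The C++ InferImplShape above can be removed because the same tuple-of-scalars result is now produced by `Shape.__infer__` on the Python side (see array_ops.py later in this diff). A Python sketch of the shared logic (the helper name is hypothetical, for illustration only):

def infer_shape_value(tensor_shape):  # hypothetical helper, illustration only
    """Fold a statically known tensor shape into a tuple of scalar values."""
    return tuple(int(entry) for entry in tensor_shape)

assert infer_shape_value([2, 3]) == (2, 3)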
......@@ -963,6 +963,7 @@ void ClearResAtexit() {
  abstract::ClearPrimEvaluatorMap();
  compile::ClearConvertCache();
  pipeline::GetMethodMap().clear();
+  pipeline::GetAttrMap().clear();
  pipeline::ExecutorPy::ClearRes();
  pipeline::ReclaimOptimizer();
  pynative::PynativeExecutor::GetInstance()->ClearRes();
......
......@@ -17,23 +17,20 @@
*/
#include "pipeline/jit/resource.h"
#include "pipeline/jit/pipeline.h"
#include "pipeline/jit/static_analysis/static_analysis.h"
#include "debug/draw.h"
#include "debug/trace.h"
#include "ir/dtype.h"
#include "pipeline/jit/parse/data_converter.h"
#include "frontend/operator/ops.h"
#include "ir/graph_utils.h"
#include "frontend/optimizer/ad/dfunctor.h"
#include "vm/segment_runner.h"
namespace mindspore {
// namespace to support opmap definition
namespace pipeline {
-MethodMap &GetMethodMap() {
-  static MethodMap method_map = {
+BuiltInTypeMap &GetMethodMap() {
+  static BuiltInTypeMap method_map = {
    {kObjectTypeString,
     {
       {"__bool__", std::string("str_bool")}  // C.str_bool
......@@ -191,6 +188,15 @@ MethodMap &GetMethodMap() {
  return method_map;
}
+BuiltInTypeMap &GetAttrMap() {
+  static BuiltInTypeMap attr_map = {{kObjectTypeTensorType,
+                                     {
+                                       {"shape", std::string("shape_")},  // C.shape_
+                                       {"dtype", std::string("dtype_")},  // C.dtype_
+                                     }}};
+  return attr_map;
+}
Resource::Resource(const py::object &obj)
: engine_(std::make_shared<abstract::AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), manager_)),
input_(obj),
......@@ -218,31 +224,42 @@ Resource::~Resource() {
}
}
-bool Resource::IsTypeInMethodMap(const TypeId &type) {
-  TypeId type_id = NormalizeTypeId(type);
-  const MethodMap &method_map = GetMethodMap();
-  auto iter = method_map.find(static_cast<int>(type_id));
-  if (iter != method_map.end()) {
-    return true;
-  }
-  return false;
-}
+Any GetMethodOrAttr(const string &name, const TypeId &type_id, const BuiltInTypeMap &method_map) {
+  auto type_method_map = method_map.find(static_cast<int>(type_id));
+  if (type_method_map == method_map.end()) {
+    return Any();
+  }
+  auto method = type_method_map->second.find(name);
+  if (method == type_method_map->second.end()) {
+    return Any();
+  }
+  return method->second;
+}

-Any Resource::GetMethodPtr(const TypeId &type, const std::string &name) {
+bool Resource::IsTypeInBuiltInMap(const TypeId &type) {
  TypeId type_id = NormalizeTypeId(type);
-  const MethodMap &method_map = GetMethodMap();
+  const BuiltInTypeMap &method_map = GetMethodMap();
  auto iter = method_map.find(static_cast<int>(type_id));
  if (iter == method_map.end()) {
-    MS_LOG(WARNING) << "Object type: " << type_id << " not in the method_map";
-    return Any();
+    const BuiltInTypeMap &attr_map = GetAttrMap();
+    iter = attr_map.find(static_cast<int>(type_id));
+    if (iter == attr_map.end()) {
+      return false;
+    }
  }
+  return true;
+}

-  auto iter_map = iter->second.find(name);
-  if (iter_map == iter->second.end()) {
-    MS_LOG(WARNING) << "Object type: " << type_id << " have no method: " << name;
-    return Any();
-  }
-  return iter_map->second;
+Any Resource::GetMethodPtr(const TypeId &type, const std::string &name) {
+  TypeId type_id = NormalizeTypeId(type);
+  const BuiltInTypeMap &method_map = GetMethodMap();
+  return GetMethodOrAttr(name, type_id, method_map);
+}
+
+Any Resource::GetAttrPtr(const TypeId &type, const std::string &name) {
+  TypeId type_id = NormalizeTypeId(type);
+  const BuiltInTypeMap &attr_map = GetAttrMap();
+  return GetMethodOrAttr(name, type_id, attr_map);
}
void Resource::Clean() {
......
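GetAttrMap deliberately mirrors GetMethodMap: both map a built-in type and a name to a composite registered under C, and GetMethodOrAttr tries them with the method map taking precedence. A Python sketch of the lookup order (dict contents abbreviated; names illustrative):

METHOD_MAP = {"Tensor": {"transpose": "transpose"}}            # -> C.transpose
ATTR_MAP = {"Tensor": {"shape": "shape_", "dtype": "dtype_"}}  # -> C.shape_, C.dtype_

def get_method_or_attr(type_name, item_name):
    """Methods win over attributes; a miss in both maps is an error."""
    if item_name in METHOD_MAP.get(type_name, {}):
        return METHOD_MAP[type_name][item_name], "method"
    if item_name in ATTR_MAP.get(type_name, {}):
        return ATTR_MAP[type_name][item_name], "attr"
    raise RuntimeError("The object of type: %s has no method or attr: %s" % (type_name, item_name))

assert get_method_or_attr("Tensor", "shape") == ("shape_", "attr")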
......@@ -44,9 +44,11 @@ const char kOutput[] = "output";
class InferenceResource;
-using MethodMap = std::unordered_map<int, std::unordered_map<std::string, Any>>;
+using BuiltInTypeMap = std::unordered_map<int, std::unordered_map<std::string, Any>>;

-MethodMap &GetMethodMap();
+BuiltInTypeMap &GetMethodMap();
+BuiltInTypeMap &GetAttrMap();
class ResourceBase {
public:
......@@ -87,10 +89,12 @@ class Resource : public ResourceBase {
  abstract::AnalysisEnginePtr engine() { return engine_; }

-  static bool IsTypeInMethodMap(const TypeId &type);
+  static bool IsTypeInBuiltInMap(const TypeId &type);

  static Any GetMethodPtr(const TypeId &type, const std::string &name);
+  static Any GetAttrPtr(const TypeId &type, const std::string &name);

  const py::object &input() const { return input_; }
  FuncGraphPtr func_graph() const { return func_graph_; }
......
......@@ -21,7 +21,6 @@
#include <algorithm>
#include <limits>
#include <mutex>
#include <set>
#include <string>
#include <utility>
......@@ -31,10 +30,8 @@
#include "frontend/operator/prim_to_function.h"
#include "abstract/utils.h"
#include "utils/symbolic.h"
#include "./common.h"
#include "pipeline/jit/resource.h"
#include "pipeline/jit/parse/resolve.h"
#include "ir/tensor.h"
#include "utils/convert_utils.h"
#include "utils/context/ms_context.h"
#include "pipeline/jit/parse/data_converter.h"
......@@ -64,7 +61,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
    {prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
    {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
    {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}},
-    {prim::kPrimShape, {InferImplShape, true}},
    {prim::kPrimPack, {InferImplPack, true}},
    // Structure
    {prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
......@@ -634,7 +630,7 @@ EvaluatorPtr InitUniformPrimEvaluator(const PrimitivePtr &primitive, PrimitiveIm
}
const int kResolveCaseUserDefineClass = 1;
-const int kResolveCaseBuildinTypeMethod = 2;
+const int kResolveCaseBuiltInType = 2;
const int kResolveCaseFunction = 3;
int GetResolveCase(const TypePtr &data_type) {
MS_EXCEPTION_IF_NULL(data_type);
......@@ -643,8 +639,8 @@ int GetResolveCase(const TypePtr &data_type) {
}
  // try method map, if not in method map, the data_type should be External type.
-  if (pipeline::Resource::IsTypeInMethodMap(data_type->type_id())) {
-    return kResolveCaseBuildinTypeMethod;
+  if (pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id())) {
+    return kResolveCaseBuiltInType;
  }
  return kResolveCaseFunction;
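With the rename, the static-getter resolver still dispatches three ways; only the built-in case now covers attributes as well as methods. In outline (a Python sketch of GetResolveCase, constants as in the C++ above):

RESOLVE_CASE_USER_DEFINE_CLASS = 1  # user-defined class: resolve attr/method on it
RESOLVE_CASE_BUILT_IN_TYPE = 2      # built-in type: consult method map, then attr map
RESOLVE_CASE_FUNCTION = 3           # anything else resolves as a plain function

def get_resolve_case(is_user_class, is_in_built_in_map):
    if is_user_class:
        return RESOLVE_CASE_USER_DEFINE_CLASS
    if is_in_built_in_map:
        return RESOLVE_CASE_BUILT_IN_TYPE
    return RESOLVE_CASE_FUNCTION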
......@@ -674,8 +670,10 @@ inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr fun
  manager->AddFuncGraph(func_graph);
}

-EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf,
-                                   const AnfNodeConfigPtr &old_conf) {
+enum REQUIRE_TYPE { ATTR, METHOD };
+
+EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, const AnfNodeConfigPtr &old_conf,
+                                   REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD) {
  MS_EXCEPTION_IF_NULL(old_conf);
  AbstractBasePtr abs_ptr = ToAbstract(value, AnalysisContext::DummyContext(), old_conf);
......@@ -701,6 +699,9 @@ EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_
  MS_EXCEPTION_IF_NULL(old_conf);
  FuncGraphPtr func_graph = old_conf->node()->func_graph();
  CNodePtr new_cnode = func_graph->NewCNode(input);
+  if (require_type == REQUIRE_TYPE::ATTR) {
+    new_cnode = func_graph->NewCNode({new_cnode});
+  }
  AnalysisEnginePtr eng = old_conf->engine();
  AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context());
  return eng->ForwardConfig(old_conf, fn_conf);
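The REQUIRE_TYPE distinction matters at this point: a method access stays a bound callable whose remaining arguments arrive at the call site, while an attribute access must yield a value immediately, so the extra NewCNode({new_cnode}) wraps the bound getter in a zero-argument call. A Python analogy (a sketch only, assuming the elided `input` builds a partial application of the resolved getter over the data node):

from functools import partial

def lower_getter(fn, data, require_type):
    bound = partial(fn, data)   # roughly: the new_cnode binding the getter to `data`
    if require_type == "attr":
        return bound()          # attr: call at once, e.g. x.shape -> shape_(x)
    return bound                # method: caller supplies the remaining args later

shape_ = lambda t: (4, 5)       # stand-in for the registered composite
assert lower_getter(shape_, object(), "attr") == (4, 5)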
......@@ -781,9 +782,9 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng
  return StaticGetterInferred(converted_v, data_conf, out_conf);
}

-EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v,
-                                                    const TypePtr &data_type, const ConfigPtr &data_conf,
-                                                    const AnfNodeConfigPtr &out_conf) {
+EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v,
+                                                          const TypePtr &data_type, const ConfigPtr &data_conf,
+                                                          const AnfNodeConfigPtr &out_conf) {
  MS_EXCEPTION_IF_NULL(item_v);
  MS_EXCEPTION_IF_NULL(data_type);
  // The method maybe a Primitive or Composite
......@@ -792,22 +793,29 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &eng
  }
  std::string item_name = item_v->cast<StringImmPtr>()->value();
-  Any method = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name);
-  if (method.empty()) {
-    MS_LOG(EXCEPTION) << "Object type: " << data_type->ToString() << " has no method: " << item_name;
+  REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD;
+  Any require = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name);
+  if (require.empty()) {
+    require = pipeline::Resource::GetAttrPtr(data_type->type_id(), item_name);
+    if (require.empty()) {
+      MS_LOG(EXCEPTION) << "The object of type: " << data_type->ToString() << " has no method or attr: " << item_name;
+    }
+    require_type = REQUIRE_TYPE::ATTR;
  }

  ValuePtr converted_v = nullptr;
-  if (method.is<std::string>()) {
+  if (require.is<std::string>()) {
    // composite registered in standard_method_map go to this branch
-    converted_v = prim::GetPythonOps(method.cast<std::string>());
-    AddToManager(engine, converted_v->cast<FuncGraphPtr>());
-  } else if (method.is<PrimitivePtr>()) {
-    converted_v = method.cast<PrimitivePtr>();
+    converted_v = prim::GetPythonOps(require.cast<std::string>());
+    if (!converted_v->isa<Primitive>()) {
+      AddToManager(engine, converted_v->cast<FuncGraphPtr>());
+    }
+  } else if (require.is<PrimitivePtr>()) {
+    converted_v = require.cast<PrimitivePtr>();
  } else {
-    MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from method map, but got " << method.ToString();
+    MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from attr or method map, but got " << require.ToString();
  }
-  return StaticGetterInferred(converted_v, data_conf, out_conf);
+  return StaticGetterInferred(converted_v, data_conf, out_conf, require_type);
}
EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
......@@ -831,8 +839,8 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
  int case_v = GetResolveCase(data_type);
  if (case_v == kResolveCaseUserDefineClass) {
    return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf);
-  } else if (case_v == kResolveCaseBuildinTypeMethod) {
-    return GetEvaluatedValueForBuiltinTypeMethod(engine, item_value, data_type, data_conf, out_conf);
+  } else if (case_v == kResolveCaseBuiltInType) {
+    return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, item_value, data_type, data_conf, out_conf);
  } else {
    return GetEvaluatedValueForNameSpaceString(engine, args_spec_list, out_conf);
  }
......
......@@ -218,10 +218,6 @@ AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const P
                                              const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                              const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                  const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
......@@ -246,8 +242,6 @@ AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const Primitiv
                                       const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                               const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const AbstractBasePtrList &args_spec_list);
......
......@@ -22,20 +22,21 @@ import copy
import functools
import itertools
import numbers
import numpy as np
-from ..._checkparam import Validator as validator
-from ..._checkparam import Rel
-from ...common import dtype as mstype
-from ...common.tensor import Tensor
-from ...common.parameter import Parameter
-from ..operations.math_ops import _infer_shape_reduce
from .._utils import get_concat_offset
-from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op
-from ..._c_expression import signature_rw as sig_rw
-from ..._c_expression import signature_kind as sig_kind
+from ..operations.math_ops import _infer_shape_reduce
+from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op
+from ..._c_expression import signature_dtype as sig_dtype
+from ..._c_expression import signature_kind as sig_kind
+from ..._c_expression import signature_rw as sig_rw
+from ..._c_expression import typing
+from ..._checkparam import Rel
+from ..._checkparam import Validator as validator
+from ...common import dtype as mstype
+from ...common.parameter import Parameter
+from ...common.tensor import Tensor
class _ScatterOp(PrimitiveWithInfer):
......@@ -415,7 +416,7 @@ class Reshape(PrimitiveWithInfer):
        return out


-class Shape(Primitive):
+class Shape(PrimitiveWithInfer):
    """
    Returns the shape of input tensor.
......@@ -436,6 +437,13 @@ class Shape(Primitive):
    def __init__(self):
        """init Shape"""
+
+    def __infer__(self, x):
+        validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
+        out = {'shape': (),
+               'dtype': mstype.tuple_,
+               'value': tuple(x['shape'])}
+        return out


class Squeeze(PrimitiveWithInfer):
    """
......
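Because __infer__ now reports a concrete value, the shape tuple is a compile-time constant, which is what lets it feed ops whose inputs must be constant, as the P.Fill test added in this commit exercises. A sketch of the folded result outside a graph:

import numpy as np
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.ops import operations as P

x = Tensor(np.ones([1, 2, 3], np.float32))
shape = P.Shape()(x)                      # constant-folded to (1, 2, 3)
y = P.Fill()(mstype.float32, shape, 2.2)  # the constant shape feeds Fill directly
assert np.allclose(y.asnumpy(), 2.2)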
......@@ -267,11 +267,6 @@ TEST_F(TestOps, BroadCastShapeTest) {
  ASSERT_EQ(prim->name(), kPrimBroadcastShape->name());
}

-TEST_F(TestOps, ShapeTest) {
-  auto prim = std::make_shared<Primitive>("Shape");
-  ASSERT_EQ(prim->name(), kPrimShape->name());
-}

TEST_F(TestOps, ArrayMapTest) {
  auto prim = std::make_shared<Primitive>("array_map");
  ASSERT_EQ(prim->name(), kPrimArrayMap->name());
......@@ -36,23 +36,23 @@ class TestResource : public UT::Common {
  void TearDown() {}
};

-TEST_F(TestResource, test_standard_method_map) {
-  ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt));
-  ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt8));
-  ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt16));
-  ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt32));
-  ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt64));
+TEST_F(TestResource, test_built_in_type_map) {
+  ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt));
+  ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt8));
+  ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt16));
+  ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt32));
+  ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt64));

-  ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeFloat));
-  ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeFloat16));
-  ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeFloat32));
-  ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeFloat64));
+  ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeFloat));
+  ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeFloat16));
+  ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeFloat32));
+  ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeFloat64));

-  ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeBool));
-  ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeUInt));
-  ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kObjectTypeTuple));
-  ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kObjectTypeList));
-  ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kObjectTypeTensorType));
+  ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeBool));
+  ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeUInt));
+  ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kObjectTypeTuple));
+  ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kObjectTypeList));
+  ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kObjectTypeTensorType));

  MethodMap& map = GetMethodMap();
  for (auto& iter : map) {
......
......@@ -467,24 +467,6 @@ TEST_F(TestPrim, test_env_add) {
  ASSERT_TRUE(*res == *exp);
}

-TEST_F(TestPrim, test_shape) {
-  PrimitivePtr shap = std::make_shared<Primitive>("Shape");
-  FuncGraphPtr func_graph = MakeFuncGraph(shap, 1);
-  auto a = UTPrimUtils::ArrayFloat64Of({2, 3});
-  AbstractBasePtrList args_spec_list = {a};
-  AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred->abstract());
-  auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value();
-  std::vector<ValuePtr> element_list = {MakeValue(2), MakeValue(3)};
-  ASSERT_TRUE(ret.size() == element_list.size());
-  for (int i = 0; i < element_list.size(); i++) {
-    ASSERT_TRUE(*ret[i] == *element_list[i]);
-  }
-}

TEST_F(TestPrim, test_relu) {
  PrimitivePtr relu = prim::kPrimRelu;
  relu->AddAttr("T", MakeValue(static_cast<int>(kNumberTypeFloat64)));
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test dtype and shape as attr"""
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore import dtype as mstype
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE)
def test_dtype_and_shape_as_attr():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            shape = x.shape
            dtype = x.dtype
            return shape, dtype

    net = Net()
    x = Tensor(np.ones([1, 2, 3], np.int32))
    ret = net(x)
    assert ret == ((1, 2, 3), mstype.int32)


def test_dtype_and_shape_as_attr_to_new_tensor():
    class Net(nn.Cell):
        def __init__(self, value):
            super(Net, self).__init__()
            self.fill = P.Fill()
            self.value = value

        def construct(self, x):
            dtype = x.dtype
            shape = x.shape
            y = self.fill(dtype, shape, self.value)
            return y

    net = Net(2.2)
    x = Tensor(np.ones([1, 2, 3], np.float32))
    ret = net(x)
    assert (ret.asnumpy() == (np.zeros([1, 2, 3], np.float32) + 2.2)).all()


def test_type_not_have_the_attr():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            shape = x.shapes
            return shape

    net = Net()
    x = Tensor(np.ones([1, 2, 3], np.int32))
    with pytest.raises(RuntimeError) as ex:
        net(x)
    assert "The object of type: Tensor[Int32] has no method or attr: shapes" in str(ex.value)


def test_type_not_have_the_method():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            shape = x.dtypes()
            return shape

    net = Net()
    x = Tensor(np.ones([1, 2, 3], np.int32))
    with pytest.raises(RuntimeError) as ex:
        net(x)
    assert "The object of type: Tensor[Int32] has no method or attr: dtypes" in str(ex.value)
......@@ -20,7 +20,7 @@ import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
-context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
+context.set_context(mode=context.GRAPH_MODE)
class FatherNet(nn.Cell):
......@@ -92,7 +92,6 @@ class Net(nn.Cell):
def test_single_super():
    single_net = SingleSubNet(2, 3)
-    context.set_context(mode=context.GRAPH_MODE)
    x = Tensor(np.ones([1, 2, 3], np.int32))
    y = Tensor(np.ones([1, 2, 3], np.int32))
    single_net(x, y)
......@@ -100,7 +99,6 @@ def test_single_super():
def test_mul_super():
    mul_net = MulSubNet(2, 3, 4)
-    context.set_context(mode=context.GRAPH_MODE)
    x = Tensor(np.ones([1, 2, 3], np.int32))
    y = Tensor(np.ones([1, 2, 3], np.int32))
    mul_net(x, y)
......@@ -108,7 +106,6 @@ def test_mul_super():
def test_super_cell():
    net = Net(2)
-    context.set_context(mode=context.GRAPH_MODE)
    x = Tensor(np.ones([1, 2, 3], np.int32))
    y = Tensor(np.ones([1, 2, 3], np.int32))
    with pytest.raises(RuntimeError) as er:
......@@ -142,7 +139,6 @@ def test_single_super_in():
            return ret_father_construct, ret_father_test, ret_father_x, ret_sub_z

    single_net_in = SingleSubNetIN(2, 3)
-    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    x = Tensor(np.ones([1, 2, 3], np.int32))
    y = Tensor(np.ones([1, 2, 3], np.int32))
    single_net_in(x, y)