提交 da06310a 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1766 [bug][auto cast]fix bug when do auto cast

Merge pull request !1766 from vlne-v1/I1J0M0-amp-do-auto-cast-failed
......@@ -65,6 +65,7 @@ test_temp_summary_event_file/
*.ckpt
*.shp
*.pkl
*.pb
.clangd
mindspore/version.py
mindspore/default_config.py
......
......@@ -253,7 +253,7 @@ std::string Dtype2String(const std::string &dtypes) {
std::string TypeId2String(TypeId type_id) {
auto iter = type_id_str_map.find(type_id);
if (iter == type_id_str_map.end()) {
MS_EXCEPTION(ArgumentError) << "Illegal input dtype." << TypeIdLabel(type_id);
return std::string(TypeIdLabel(type_id));
}
return iter->second;
}
......
......@@ -47,16 +47,6 @@ const std::vector<Signature> &GetSignature(const ValuePtr &function) {
return empty;
}
const std::string GetOpName(const ValuePtr &function) {
std::string name = "";
if (function->isa<Primitive>()) {
name = function->cast<PrimitivePyPtr>()->name();
} else if (function->isa<MetaFuncGraph>()) {
name = function->cast<MetaFuncGraphPtr>()->name();
}
return name;
}
void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list,
const std::vector<Signature> &signature, bool has_var, std::vector<AnfNodePtr> *const op_inputs) {
std::size_t sig_size = signature.size();
......@@ -93,7 +83,8 @@ void setMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number,
*max_type_number = type_number;
}
TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indexs) {
TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indexs,
const std::set<size_t> &write_indexs) {
TypeId max_type_id = kTypeUnknown;
TypeId max_type = kTypeUnknown;
size_t max_type_number = 0;
......@@ -103,7 +94,12 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
TypeId arg_type = kTypeUnknown;
AbstractBasePtr arg_value = args_spec_list[index];
if (arg_value->isa<abstract::AbstractRef>()) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
auto is_write = (write_indexs.find(index) != write_indexs.end());
if (is_write) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
} else {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
}
}
if (arg_value->isa<abstract::AbstractTensor>()) {
auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
......@@ -157,7 +153,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
// Get the largest type of index in the same SignatureEnumDType of arguments.
std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
const abstract::AbstractBasePtrList &args_spec_list) {
const abstract::AbstractBasePtrList &args_spec_list,
const std::set<size_t> &write_indexs) {
// record index for signature.dtypes of the same type
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
......@@ -192,7 +189,7 @@ std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnum
(void)dst_type.insert(std::make_pair(type, kTypeUnknown));
continue;
}
(void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indexs)));
(void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indexs, write_indexs)));
}
return dst_type;
}
......@@ -205,9 +202,9 @@ AnfNodePtr DoCast(const AnfNodePtr &param, const TypeId &type_id, const FuncGrap
return NewCNode({cast_node, param, dtype_node}, graph);
}
void DoAutoCast(const std::vector<Signature> &signature, const abstract::AbstractBasePtrList &args_spec_list,
const FuncGraphPtr &graph, std::vector<AnfNodePtr> *const op_inputs,
const std::set<size_t> &write_indexs) {
void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature,
const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph,
std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indexs) {
std::vector<SignatureEnumDType> dtypes;
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
[](const Signature &sig) { return sig.dtype; });
......@@ -216,16 +213,23 @@ void DoAutoCast(const std::vector<Signature> &signature, const abstract::Abstrac
return;
}
// Stat the index of the arguments with the largest type in the same SignatureEnumDType.
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list);
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indexs);
// Identify which arg requires auto cast
for (size_t i = 0; i < args_spec_list.size(); ++i) {
auto it = dst_type.find(dtypes[i]);
if (it == dst_type.end() || it->second == kTypeUnknown) {
continue;
}
auto rw_it = write_indexs.find(i);
auto is_write = (rw_it != write_indexs.end());
AbstractBasePtr arg_value = args_spec_list[i];
if (arg_value->isa<abstract::AbstractRef>()) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
if (is_write) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
} else {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
}
}
TypeId arg_type_id = kTypeUnknown;
if (arg_value->isa<abstract::AbstractTensor>()) {
......@@ -243,10 +247,9 @@ void DoAutoCast(const std::vector<Signature> &signature, const abstract::Abstrac
if (it_map == type_map.end()) {
continue;
}
auto rw_it = write_indexs.find(i);
if (rw_it != write_indexs.end()) {
if (is_write) {
if (arg_type_id != it->second) {
MS_LOG(EXCEPTION) << "In op '" << GetOpName(graph) << "', argument '" << args_spec_list[i]
MS_LOG(EXCEPTION) << "In op '" << func_name << "', argument '" << args_spec_list[i]
<< "' can not cast type from '" << TypeIdLabel(arg_type_id) << "' to '"
<< TypeIdLabel(it->second) << "' automatically.";
}
......@@ -299,8 +302,8 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
if (sig == SignatureEnumRW::kRWRead) {
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param});
} else if (sig == SignatureEnumRW::kRWWrite) {
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param});
write_indexs.insert(i);
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefKey), param});
}
// If sig is SignatureEnumRW::kRWRef, not do anything.
} else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
......@@ -310,7 +313,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
}
// process default
ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs);
DoAutoCast(signature, args_spec_list, func_graph, &op_inputs, write_indexs);
DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indexs);
return func_graph->NewCNode(op_inputs);
}
} // namespace
......
......@@ -160,7 +160,7 @@ AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const Primitive
const AbstractBasePtrList &args_spec_list) {
// arguments: value
if (args_spec_list.size() != 1) {
MS_LOG(EXCEPTION) << "get_ref_value requires 1 parameters, while the input size is " << args_spec_list.size()
MS_LOG(EXCEPTION) << "get_ref_origin requires 1 parameters, while the input size is " << args_spec_list.size()
<< ".";
}
TypePtr type = args_spec_list[0]->GetTypeTrack();
......
......@@ -81,8 +81,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Ref eliminate
make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef);
get_make_ref_eliminate_ =
MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", {prim::kPrimGetRefKey, prim::kPrimGetRefValue});
get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate",
{prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
replace_refkey_by_param_ =
MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM);
......
......@@ -48,6 +48,7 @@ class MakeRefEliminater : public AnfVisitor {
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
class GetMakeRefEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
......@@ -71,6 +72,10 @@ class GetMakeRefEliminater : public AnfVisitor {
return ref->input(2);
}
if (cnode->IsApply(prim::kPrimGetRefOrigin)) {
return ref->input(3);
}
return nullptr;
}
};
......
......@@ -315,7 +315,7 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple);
ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend);
ValueNodePtr get_refkey_op = NewValueNode(prim::kPrimGetRefKey);
ValueNodePtr get_ref_origin_op = NewValueNode(prim::kPrimGetRefOrigin);
ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient);
const std::string primitive_name("assign");
const std::string module_name("mindspore.ops.functional");
......@@ -329,8 +329,8 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
vec_states.emplace_back(make_tuple_op);
for (auto &item : state_assign_) {
auto source = ReadVariable(item.second);
auto refkey = func_graph()->NewCNode({get_refkey_op, item.first});
auto assign = func_graph()->NewCNode({assign_op, refkey, source});
auto origin = func_graph()->NewCNode({get_ref_origin_op, item.first});
auto assign = func_graph()->NewCNode({assign_op, origin, source});
MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second;
vec_states.emplace_back(assign);
}
......
......@@ -801,8 +801,8 @@ bool AbstractRef::operator==(const AbstractBase &other) const {
std::string AbstractRef::ToString() const {
std::ostringstream buffer;
buffer << type_name() << "("
<< "key: " << ref_key_->ToString() << "ref_value: " << ref_->ToString()
<< "origin_value: " << ref_origin_->ToString();
<< "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString()
<< " origin_value: " << ref_origin_->ToString();
auto value = GetValueTrack();
if (value) {
buffer << ", value: " << value->ToString();
......
......@@ -783,7 +783,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract();
AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
if (ref_abs == nullptr) {
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref.";
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
return nullptr;
}
auto key_abs = ref_abs->ref_key();
......
......@@ -170,7 +170,7 @@ def get_py_obj_dtype(obj):
Type of MindSpore type.
"""
# Tensor
if hasattr(obj, 'dtype'):
if hasattr(obj, 'dtype') and callable(obj.dtype) and isinstance(obj.dtype(), typing.Type):
return tensor_type(obj.dtype())
if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'):
return function
......
......@@ -31,7 +31,9 @@ from ...common.tensor import Tensor
from ..operations.math_ops import _infer_shape_reduce
from .._utils import get_concat_offset
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from ..._c_expression import signature_rw as sig_rw
from ..._c_expression import signature_kind as sig_kind
from ..._c_expression import signature_dtype as sig_dtype
def _check_infer_attr_reduce(axis, keep_dims, prim_name):
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
......@@ -2156,13 +2158,17 @@ class ScatterUpdate(PrimitiveWithInfer):
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> op = P.ScatterNdUpdate()
>>> op = P.ScatterUpdate()
>>> output = op(input_x, indices, update)
"""
__mindspore_signature__ = (
('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
)
@prim_attr_register
def __init__(self, use_locking=True):
"""Init ScatterNdUpdate"""
"""Init ScatterUpdate"""
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
def infer_shape(self, x_shape, indices_shape, value_shape):
......@@ -2201,7 +2207,11 @@ class ScatterNdUpdate(PrimitiveWithInfer):
>>> op = P.ScatterNdUpdate()
>>> output = op(input_x, indices, update)
"""
__mindspore_signature__ = (
('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
)
@prim_attr_register
def __init__(self, use_locking=True):
"""Init ScatterNdUpdate"""
......
......@@ -179,7 +179,7 @@ class AssignAdd(PrimitiveWithInfer):
return value
def infer_dtype(self, variable, value):
args = {"value": value}
args = {"variable": variable, "value": value}
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
return value
......@@ -222,7 +222,7 @@ class AssignSub(PrimitiveWithInfer):
return value
def infer_dtype(self, variable, value):
args = {"value": value}
args = {"variable": variable, "value": value}
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
return value
......
......@@ -58,6 +58,8 @@ class Assign(PrimitiveWithInfer):
return variable
def infer_dtype(self, variable, value):
args = {"variable": variable, "value": value}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return variable
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test layer switch"""
import numpy as np
import mindspore
......
......@@ -345,19 +345,6 @@ class Conv2dNativeNet(nn.Cell):
return self.flatten(self.conv(input_x, self.weight))
class MakeRefKeyNet(nn.Cell):
""" MakeRefKeyNet definition """
def __init__(self):
super(MakeRefKeyNet, self).__init__()
self.y = Parameter(Tensor([1.0], mindspore.float32), name="y")
def construct(self, x):
key = P.MakeRefKey("y")()
P.Assign()(key, x)
return x
class StateNet(nn.Cell):
""" StateTestTensor definition """
......@@ -538,10 +525,6 @@ test_cases = [
'block': Grad(NetWithLossClass(Conv2dNativeNet())),
'desc_inputs': [Tensor(np.ones([1, 3, 16, 16], np.float32)), Tensor(np.zeros([1, 1764], np.float32))],
}),
('MakeRefKey', {
'block': MakeRefKeyNet(),
'desc_inputs': [Tensor([2.0], mindspore.float32)],
}),
('StateTest', {
'block': StateNet(),
'desc_inputs': [Tensor(np.ones([2, 1, 2, 2]).astype(np.float32))],
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test assign sub
"""
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
import mindspore as ms
class AssignW(nn.Cell):
def __init__(self):
super(AssignW, self).__init__()
self.assign = P.Assign()
def construct(self, x, w):
self.assign(x, w)
return x
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.b = Parameter(initializer('ones', [5]), name='b')
self.assign = AssignW()
def construct(self, value):
return self.assign(self.b, value)
def test_assign_through_cell():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
net = Net()
net.to_float(ms.float16)
net.add_flags_recursive(fp16=False)
input_data = Tensor(np.ones([5]).astype(np.float32))
net(input_data)
with pytest.raises(TypeError):
net(None)
class NetScatterNdUpdate(nn.Cell):
def __init__(self):
super(NetScatterNdUpdate, self).__init__()
self.b = Parameter(initializer('ones', [5, 5]), name='b')
self.scatter = P.ScatterNdUpdate()
def construct(self, idx, x):
return self.scatter(self.b, idx, x)
def test_scatter_nd_update():
context.set_context(mode=context.GRAPH_MODE)
net = NetScatterNdUpdate()
x = Tensor(np.ones([5]).astype(np.float16))
idx = Tensor(np.ones([1]).astype(np.int32))
net(idx, x)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册