提交 afd16fbf 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4963 fix bug of switch layer join

Merge pull request !4963 from fary86/fix_switch_layer_join_bug
......@@ -283,9 +283,99 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj));
return false;
}
bool ConvertIntegerWithType(const int &obj, ValuePtr *const data, TypePtr dtype = nullptr) {
if (dtype == nullptr) {
*data = std::make_shared<Int32Imm>(obj);
return true;
}
auto int_dypte = dyn_cast<Int>(dtype);
if (int_dypte != nullptr) {
switch (int_dypte->nbits()) {
case 8:
*data = std::make_shared<Int8Imm>(static_cast<int8_t>(obj));
break;
case 16:
*data = std::make_shared<Int16Imm>(obj);
break;
case 32:
*data = std::make_shared<Int32Imm>(obj);
break;
case 64:
*data = std::make_shared<Int64Imm>(obj);
break;
default:
*data = std::make_shared<Int32Imm>(obj);
}
return true;
}
auto uint_dypte = dyn_cast<UInt>(dtype);
if (int_dypte != nullptr) {
switch (uint_dypte->nbits()) {
case 8:
*data = std::make_shared<UInt8Imm>(obj);
break;
case 16:
*data = std::make_shared<UInt16Imm>(obj);
break;
case 32:
*data = std::make_shared<UInt32Imm>(obj);
break;
case 64:
*data = std::make_shared<UInt64Imm>(obj);
break;
default:
*data = std::make_shared<UInt32Imm>(obj);
}
return true;
}
auto float_dypte = dyn_cast<Float>(dtype);
if (float_dypte != nullptr) {
switch (float_dypte->nbits()) {
case 32:
*data = std::make_shared<FP32Imm>(obj);
break;
case 64:
*data = std::make_shared<FP64Imm>(obj);
break;
default:
*data = std::make_shared<FP32Imm>(obj);
}
return true;
}
return false;
}
bool ConvertFloatWithType(const float &obj, ValuePtr *const data, TypePtr dtype = nullptr) {
if (dtype == nullptr) {
*data = std::make_shared<FP32Imm>(obj);
return true;
}
auto float_dypte = dyn_cast<Float>(dtype);
if (float_dypte == nullptr) {
return false;
}
switch (float_dypte->nbits()) {
case 32:
*data = std::make_shared<FP32Imm>(obj);
break;
case 64:
*data = std::make_shared<FP64Imm>(obj);
break;
default:
*data = std::make_shared<FP32Imm>(obj);
}
return true;
}
} // namespace
bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature) {
bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, TypePtr dtype) {
// check parameter valid
if (data == nullptr) {
MS_LOG(ERROR) << "Data is null pointer";
......@@ -299,9 +389,9 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
} else if (py::isinstance<py::bool_>(obj)) {
converted = std::make_shared<BoolImm>(py::cast<bool>(obj));
} else if (py::isinstance<py::int_>(obj)) {
converted = std::make_shared<Int32Imm>(py::cast<int>(obj));
ret = ConvertIntegerWithType(py::cast<int>(obj), &converted, dtype);
} else if (py::isinstance<py::float_>(obj)) {
converted = std::make_shared<FP32Imm>(py::cast<float>(obj));
ret = ConvertFloatWithType(py::cast<float>(obj), &converted, dtype);
} else if (py::isinstance<py::str>(obj)) {
converted = std::make_shared<StringImm>(py::cast<std::string>(obj));
} else if (py::isinstance<py::dict>(obj)) {
......
......@@ -139,7 +139,7 @@ enum ClassInstanceTypeDef {
};
// Convert python object to ValuePtr
bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false);
bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false, TypePtr dtype = nullptr);
// Convert python obj to graph
FuncGraphPtr ConvertToFuncGraph(const py::object &obj,
......
......@@ -407,9 +407,9 @@ py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrLi
AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
// Convert to AbstractValue based on type and shape
auto out_dtype = output["dtype"];
if (output["value"].is_none()) {
auto out_shape = output["shape"];
auto out_dtype = output["dtype"];
py::object min_shape = output.contains("min_shape") ? (py::object)output["min_shape"] : (py::object)py::none();
py::object max_shape = output.contains("max_shape") ? (py::object)output["max_shape"] : (py::object)py::none();
......@@ -417,7 +417,8 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
}
// Convert pyobject to Value, then to AbstractValue
ValuePtr converted_ret = nullptr;
bool converted = parse::ConvertData(output["value"], &converted_ret);
TypePtr dtype = py::isinstance<Type>(out_dtype) ? out_dtype.cast<TypePtr>() : nullptr;
bool converted = parse::ConvertData(output["value"], &converted_ret, false, dtype);
if (!converted) {
MS_LOG(EXCEPTION) << "Convert data failed";
}
......
......@@ -45,14 +45,34 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
MS_LOG(EXCEPTION) << "value is null";
}
py::object ret;
if (value->isa<Int32Imm>()) {
MS_LOG(DEBUG) << "int";
if (value->isa<Int8Imm>()) {
MS_LOG(DEBUG) << "int8";
py::int_ v = value->cast<Int8ImmPtr>()->value();
ret = v;
} else if (value->isa<Int16Imm>()) {
MS_LOG(DEBUG) << "int16";
py::int_ v = value->cast<Int16ImmPtr>()->value();
ret = v;
} else if (value->isa<Int32Imm>()) {
MS_LOG(DEBUG) << "int32";
py::int_ v = value->cast<Int32ImmPtr>()->value();
ret = v;
} else if (value->isa<Int64Imm>()) {
MS_LOG(DEBUG) << "int64";
py::int_ v = value->cast<Int64ImmPtr>()->value();
ret = v;
} else if (value->isa<UInt8Imm>()) {
MS_LOG(DEBUG) << "uint8";
py::int_ v = value->cast<UInt8ImmPtr>()->value();
ret = v;
} else if (value->isa<UInt16Imm>()) {
MS_LOG(DEBUG) << "uint16";
py::int_ v = value->cast<UInt16ImmPtr>()->value();
ret = v;
} else if (value->isa<UInt32Imm>()) {
MS_LOG(DEBUG) << "uint32";
py::int_ v = value->cast<UInt32ImmPtr>()->value();
ret = v;
} else if (value->isa<UInt64Imm>()) {
MS_LOG(DEBUG) << "uint64";
py::int_ v = value->cast<UInt64ImmPtr>()->value();
......
......@@ -97,8 +97,12 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
}
auto value_self = GetValueTrack();
MS_EXCEPTION_IF_NULL(value_self);
ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack());
if (res_type == kAnyType) {
MS_EXCEPTION(TypeError) << "Type join failed, type1 = " << GetTypeTrack()->ToString()
<< ", type2 = " << other->GetTypeTrack()->ToString();
}
ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
if (res_value == value_self) {
return shared_from_base<AbstractBase>();
}
......
......@@ -50,9 +50,17 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
if (*shape1 == *shape2) {
return shape1;
}
// lengths of two shapes are not same, join failed
if (shape1->shape().size() != shape2->shape().size()) {
MS_LOG(WARNING) << "Unsupported shape join. shape1 = " << shape1->ToString() << ", shape2 = " << shape2->ToString();
return shape1;
// special case: shape(1), shape() -> shape(1)
if (shape1->shape().size() == 1 && shape1->shape()[0] == 1 && shape2->shape().size() == 0) {
return shape1;
}
if (shape2->shape().size() == 1 && shape2->shape()[0] == 1 && shape1->shape().size() == 0) {
return shape2;
}
MS_EXCEPTION(ValueError) << "Unsupported shape join. shape1 = " << shape1->ToString()
<< ", shape2 = " << shape2->ToString();
}
std::vector<int> dims;
bool has_dynamic_shape = false;
......
......@@ -105,7 +105,7 @@ class Int8Imm : public IntergerImm {
std::string DumpText() const override {
std::ostringstream oss;
oss << "I8(" << v_ << ")";
oss << "I8(" << int(v_) << ")";
return oss.str();
}
......@@ -131,7 +131,7 @@ class Int16Imm : public IntergerImm {
std::string DumpText() const override {
std::ostringstream oss;
oss << "I16(" << v_ << ")";
oss << "I16(" << int(v_) << ")";
return oss.str();
}
......@@ -157,7 +157,7 @@ class Int32Imm : public IntergerImm {
std::string DumpText() const override {
std::ostringstream oss;
oss << "I32(" << v_ << ")";
oss << "I32(" << int(v_) << ")";
return oss.str();
}
......@@ -211,7 +211,7 @@ class UInt8Imm : public IntergerImm {
std::string DumpText() const override {
std::ostringstream oss;
oss << "U8(" << v_ << ")";
oss << "U8(" << unsigned(v_) << ")";
return oss.str();
}
......@@ -239,7 +239,7 @@ class UInt16Imm : public IntergerImm {
std::string DumpText() const override {
std::ostringstream oss;
oss << "U16(" << v_ << ")";
oss << "U16(" << unsigned(v_) << ")";
return oss.str();
}
......@@ -267,7 +267,7 @@ class UInt32Imm : public IntergerImm {
std::string DumpText() const override {
std::ostringstream oss;
oss << "U32(" << v_ << ")";
oss << "U32(" << unsigned(v_) << ")";
return oss.str();
}
......
......@@ -324,7 +324,7 @@ class ScalarGradChecker(_GradChecker):
self.input_selector = [i for i in range(self.nin)]
def get_sens(self, i):
return 1
return 1.0
def check_against_numeric(self, out_index):
args = list(self.args)
......
......@@ -916,3 +916,73 @@ def test_recursive_call():
with pytest.raises(RuntimeError):
net(input_data)
context.set_context(max_call_depth=old_max_call_depth)
def test_switch_layer_shape_join_failed():
class AddFuncNet(nn.Cell):
def __init__(self, funcs, new_func):
super(AddFuncNet, self).__init__()
self.funcs = funcs
self.new_func = new_func
def construct(self, i, inputs):
final_funcs = self.funcs + (self.new_func,)
x = final_funcs[i](inputs)
return x
class ReLUTuple(nn.Cell):
def __init__(self):
super(ReLUTuple, self).__init__()
self.op = nn.ReLU()
def construct(self, x):
return self.op(x[0])
func1 = nn.Softmax()
func2 = nn.ReLU()
func3 = ReLUTuple()
funcs = (func1, func2)
net = AddFuncNet(funcs, func3)
inp = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
i = Tensor(1, mstype.int32)
with pytest.raises(ValueError) as err:
net(i, inp)
def test_switch_layer_dtype_join_failed():
class Cast(nn.Cell):
def __init__(self, dtype):
super(Cast, self).__init__()
self.op = P.Cast()
self.dtype = dtype
def construct(self, x):
y = self.op(x, self.dtype)
return y + y
class SwitchNegNet(nn.Cell):
def __init__(self, funcs):
super(SwitchNegNet, self).__init__()
self.funcs = funcs
self.op = P.Neg()
def construct(self, i, inputs):
x = self.funcs[i](inputs)
x = self.op(x)
return x
func1 = nn.ReLU()
func2 = Cast(mstype.int32)
funcs = (func1, func2)
net = SwitchNegNet(funcs)
inp = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
i = Tensor(0, mstype.int32)
with pytest.raises(TypeError) as err:
net(i, inp)
......@@ -33,6 +33,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
from ....ops_common import convert
grad_all_with_sens = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True)
......@@ -1703,7 +1704,7 @@ test_case_nn_ops = [
('ResizeBilinear', {
'block': P.ResizeBilinear((5, 5)),
'desc_inputs': [Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mstype.float16)],
'desc_bprop': [Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mstype.float16)]}),
'desc_bprop': [Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mstype.float32)]}),
('ResizeBilinearGrad', {
'block': G.ResizeBilinearGrad(),
'desc_inputs': [Tensor([[[[1, 2, 3, 4, 5]]]], mstype.float32), Tensor([[[[1, 2, 3, 4, 5]]]], mstype.float32)],
......@@ -1712,7 +1713,7 @@ test_case_nn_ops = [
('ROIAlign', {
'block': P.ROIAlign(7, 7, 0.03125, 2),
'desc_inputs': [[2, 256, 192, 320], [1024, 5]],
'desc_bprop': [[7, 7]]}),
'desc_bprop': [[1024, 256, 7, 7]]}),
('ROIAlignGrad', {
'block': G.ROIAlignGrad((1, 1, 1, 1), 2, 2, 0.5, 2),
'desc_inputs': [[1, 1, 2, 2], [1, 5]],
......@@ -2315,7 +2316,7 @@ test_case_other_ops = [
('IOU', {
'block': P.IOU(),
'desc_inputs': [Tensor(np.ones((256, 4), np.float16)), Tensor(np.ones((128, 4), np.float16))],
'desc_bprop': [[128, 256]]}),
'desc_bprop': [convert([128, 256], np.float16)]}),
('Summary', {
'block': SummaryNet(),
'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)),
......
......@@ -118,29 +118,29 @@ test_case_reid_ops = [
'desc_inputs': [[256, 8]],
'desc_bprop': [[256, 8]]}),
('Pow', {
'block': P.Pow(), # 输入有标量插件产生了段错误。
'block': P.Pow(),
'desc_const': [2.0],
'desc_inputs': [[1, 512]],
'desc_bprop': [[1, 512]]}),
('LogicalNot', {
'block': P.LogicalNot(),
'desc_inputs': [convert([256], np.bool_)],
'desc_bprop': [[256]]}), # 自定义算子 input bool没转换,gongchen提单。
'desc_bprop': [convert([256], np.bool_)]}),
('Equal', {
'block': P.Equal(),
'desc_inputs': [convert([256], np.float16), convert([256], np.float16)],
'desc_bprop': [[256]]}),
'desc_bprop': [convert([256], np.bool_)]}),
('Greater', {
'block': P.Greater(),
'desc_inputs': [convert([256], np.float16), convert([256], np.float16)],
'desc_bprop': [[256]]}),
'desc_bprop': [convert([256], np.bool_)]}),
('Dropout', {
'block': nn.Dropout(),
'desc_inputs': [[1, 512, 7, 7]],
'desc_bprop': [[1, 512, 7, 7]]}), # 输入有标量插件产生了段错误。
'desc_bprop': [[1, 512, 7, 7]]}),
('MatMul', {
'block': P.MatMul(),
'desc_inputs': [[64, 512], [512, 64]], # fp16不行。很有问题。
'desc_inputs': [[64, 512], [512, 64]],
'desc_bprop': [[64, 64]]}),
('Maximum', {
'block': P.Maximum(),
......
......@@ -84,8 +84,8 @@ class Bprop(Cell):
self.grad = grad_op
self.with_sens = False
self.sens = sens
if sens:
self.sens = Tensor(sens, dtype=mstype.float32)
if not sens is None:
self.sens = sens if isinstance(sens, Tensor) else Tensor(sens, dtype=mstype.float32)
self.with_sens = True
def construct(self, *inputs):
......@@ -115,7 +115,7 @@ def test_all_var_args_grad_with_sens():
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
sens = Tensor(1.0, dtype=mstype.float32)
sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
net = VarNet(SecondNet())
grad_net = GradNet(net)
_ = grad_net(x, y, sens)
......@@ -167,7 +167,7 @@ def test_grad_all_var_args_with_sens():
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
sens = Tensor(1.0, dtype=mstype.float32)
sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
net = VarNet(SecondNet())
grad_net = GradNet(net)
_ = grad_net(x, y, sens)
......@@ -185,7 +185,7 @@ def test_grad_var_args_with_sens():
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
sens = Tensor(1.0, dtype=mstype.float32)
sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
net = VarNet(SecondNet())
grad_net = GradNet(net)
_ = grad_net(x, y, sens)
......@@ -244,7 +244,7 @@ def test_var_args_grad():
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
sens = Tensor(1.0, dtype=mstype.float32)
sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
net = VarNet(SecondNet())
grad_net = GradNet(net)
_ = grad_net(x, y, sens)
......@@ -292,14 +292,14 @@ def test_grad_within_if_else():
self.net = net
grad_op = C.GradOperation(
name='grad', get_all=False, get_by_list=True, sens_param=True)
self.grad = Bprop(self.net, True, self.weights, grad_op, 1.0)
sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
self.grad = Bprop(self.net, True, self.weights, grad_op, sens)
def construct(self, *inputs):
return self.grad(*inputs)
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
_ = Tensor(1.0, dtype=mstype.float32)
net = VarNet(SecondNet())
grad_net = GradNet(net)
out = grad_net(x, y)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册