Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
afd16fbf
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
afd16fbf
编写于
8月 25, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 25, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4963 fix bug of switch layer join
Merge pull request !4963 from fary86/fix_switch_layer_join_bug
上级
5b722a10
947e19b8
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
229 addition
and
35 deletion
+229
-35
mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
+93
-3
mindspore/ccsrc/pipeline/jit/parse/parse_base.h
mindspore/ccsrc/pipeline/jit/parse/parse_base.h
+1
-1
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
+3
-2
mindspore/ccsrc/utils/convert_utils.cc
mindspore/ccsrc/utils/convert_utils.cc
+22
-2
mindspore/core/abstract/abstract_value.cc
mindspore/core/abstract/abstract_value.cc
+5
-1
mindspore/core/abstract/utils.cc
mindspore/core/abstract/utils.cc
+10
-2
mindspore/core/ir/scalar.h
mindspore/core/ir/scalar.h
+6
-6
tests/mindspore_test_framework/utils/check_gradient.py
tests/mindspore_test_framework/utils/check_gradient.py
+1
-1
tests/ut/python/ops/test_control_ops.py
tests/ut/python/ops/test_control_ops.py
+70
-0
tests/ut/python/ops/test_ops.py
tests/ut/python/ops/test_ops.py
+4
-3
tests/ut/python/ops/test_ops_reid.py
tests/ut/python/ops/test_ops_reid.py
+6
-6
tests/ut/python/parameter_feature/test_var_grad.py
tests/ut/python/parameter_feature/test_var_grad.py
+8
-8
未找到文件。
mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
浏览文件 @
afd16fbf
...
...
@@ -283,9 +283,99 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
MS_LOG
(
ERROR
)
<<
"Resolve type is invalid "
<<
((
std
::
string
)
py
::
str
(
obj
));
return
false
;
}
bool
ConvertIntegerWithType
(
const
int
&
obj
,
ValuePtr
*
const
data
,
TypePtr
dtype
=
nullptr
)
{
if
(
dtype
==
nullptr
)
{
*
data
=
std
::
make_shared
<
Int32Imm
>
(
obj
);
return
true
;
}
auto
int_dypte
=
dyn_cast
<
Int
>
(
dtype
);
if
(
int_dypte
!=
nullptr
)
{
switch
(
int_dypte
->
nbits
())
{
case
8
:
*
data
=
std
::
make_shared
<
Int8Imm
>
(
static_cast
<
int8_t
>
(
obj
));
break
;
case
16
:
*
data
=
std
::
make_shared
<
Int16Imm
>
(
obj
);
break
;
case
32
:
*
data
=
std
::
make_shared
<
Int32Imm
>
(
obj
);
break
;
case
64
:
*
data
=
std
::
make_shared
<
Int64Imm
>
(
obj
);
break
;
default:
*
data
=
std
::
make_shared
<
Int32Imm
>
(
obj
);
}
return
true
;
}
auto
uint_dypte
=
dyn_cast
<
UInt
>
(
dtype
);
if
(
int_dypte
!=
nullptr
)
{
switch
(
uint_dypte
->
nbits
())
{
case
8
:
*
data
=
std
::
make_shared
<
UInt8Imm
>
(
obj
);
break
;
case
16
:
*
data
=
std
::
make_shared
<
UInt16Imm
>
(
obj
);
break
;
case
32
:
*
data
=
std
::
make_shared
<
UInt32Imm
>
(
obj
);
break
;
case
64
:
*
data
=
std
::
make_shared
<
UInt64Imm
>
(
obj
);
break
;
default:
*
data
=
std
::
make_shared
<
UInt32Imm
>
(
obj
);
}
return
true
;
}
auto
float_dypte
=
dyn_cast
<
Float
>
(
dtype
);
if
(
float_dypte
!=
nullptr
)
{
switch
(
float_dypte
->
nbits
())
{
case
32
:
*
data
=
std
::
make_shared
<
FP32Imm
>
(
obj
);
break
;
case
64
:
*
data
=
std
::
make_shared
<
FP64Imm
>
(
obj
);
break
;
default:
*
data
=
std
::
make_shared
<
FP32Imm
>
(
obj
);
}
return
true
;
}
return
false
;
}
bool
ConvertFloatWithType
(
const
float
&
obj
,
ValuePtr
*
const
data
,
TypePtr
dtype
=
nullptr
)
{
if
(
dtype
==
nullptr
)
{
*
data
=
std
::
make_shared
<
FP32Imm
>
(
obj
);
return
true
;
}
auto
float_dypte
=
dyn_cast
<
Float
>
(
dtype
);
if
(
float_dypte
==
nullptr
)
{
return
false
;
}
switch
(
float_dypte
->
nbits
())
{
case
32
:
*
data
=
std
::
make_shared
<
FP32Imm
>
(
obj
);
break
;
case
64
:
*
data
=
std
::
make_shared
<
FP64Imm
>
(
obj
);
break
;
default:
*
data
=
std
::
make_shared
<
FP32Imm
>
(
obj
);
}
return
true
;
}
}
// namespace
bool
ConvertData
(
const
py
::
object
&
obj
,
ValuePtr
*
const
data
,
bool
use_signature
)
{
bool
ConvertData
(
const
py
::
object
&
obj
,
ValuePtr
*
const
data
,
bool
use_signature
,
TypePtr
dtype
)
{
// check parameter valid
if
(
data
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Data is null pointer"
;
...
...
@@ -299,9 +389,9 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
}
else
if
(
py
::
isinstance
<
py
::
bool_
>
(
obj
))
{
converted
=
std
::
make_shared
<
BoolImm
>
(
py
::
cast
<
bool
>
(
obj
));
}
else
if
(
py
::
isinstance
<
py
::
int_
>
(
obj
))
{
converted
=
std
::
make_shared
<
Int32Imm
>
(
py
::
cast
<
int
>
(
obj
)
);
ret
=
ConvertIntegerWithType
(
py
::
cast
<
int
>
(
obj
),
&
converted
,
dtype
);
}
else
if
(
py
::
isinstance
<
py
::
float_
>
(
obj
))
{
converted
=
std
::
make_shared
<
FP32Imm
>
(
py
::
cast
<
float
>
(
obj
)
);
ret
=
ConvertFloatWithType
(
py
::
cast
<
float
>
(
obj
),
&
converted
,
dtype
);
}
else
if
(
py
::
isinstance
<
py
::
str
>
(
obj
))
{
converted
=
std
::
make_shared
<
StringImm
>
(
py
::
cast
<
std
::
string
>
(
obj
));
}
else
if
(
py
::
isinstance
<
py
::
dict
>
(
obj
))
{
...
...
mindspore/ccsrc/pipeline/jit/parse/parse_base.h
浏览文件 @
afd16fbf
...
...
@@ -139,7 +139,7 @@ enum ClassInstanceTypeDef {
};
// Convert python object to ValuePtr
bool
ConvertData
(
const
py
::
object
&
obj
,
ValuePtr
*
data
,
bool
use_signature
=
false
);
bool
ConvertData
(
const
py
::
object
&
obj
,
ValuePtr
*
data
,
bool
use_signature
=
false
,
TypePtr
dtype
=
nullptr
);
// Convert python obj to graph
FuncGraphPtr
ConvertToFuncGraph
(
const
py
::
object
&
obj
,
...
...
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
浏览文件 @
afd16fbf
...
...
@@ -407,9 +407,9 @@ py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrLi
AbstractBasePtr
PyInferRes2Abstract
(
const
PrimitivePyPtr
&
prim_py
,
const
py
::
dict
&
output
)
{
// Convert to AbstractValue based on type and shape
auto
out_dtype
=
output
[
"dtype"
];
if
(
output
[
"value"
].
is_none
())
{
auto
out_shape
=
output
[
"shape"
];
auto
out_dtype
=
output
[
"dtype"
];
py
::
object
min_shape
=
output
.
contains
(
"min_shape"
)
?
(
py
::
object
)
output
[
"min_shape"
]
:
(
py
::
object
)
py
::
none
();
py
::
object
max_shape
=
output
.
contains
(
"max_shape"
)
?
(
py
::
object
)
output
[
"max_shape"
]
:
(
py
::
object
)
py
::
none
();
...
...
@@ -417,7 +417,8 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
}
// Convert pyobject to Value, then to AbstractValue
ValuePtr
converted_ret
=
nullptr
;
bool
converted
=
parse
::
ConvertData
(
output
[
"value"
],
&
converted_ret
);
TypePtr
dtype
=
py
::
isinstance
<
Type
>
(
out_dtype
)
?
out_dtype
.
cast
<
TypePtr
>
()
:
nullptr
;
bool
converted
=
parse
::
ConvertData
(
output
[
"value"
],
&
converted_ret
,
false
,
dtype
);
if
(
!
converted
)
{
MS_LOG
(
EXCEPTION
)
<<
"Convert data failed"
;
}
...
...
mindspore/ccsrc/utils/convert_utils.cc
浏览文件 @
afd16fbf
...
...
@@ -45,14 +45,34 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
MS_LOG
(
EXCEPTION
)
<<
"value is null"
;
}
py
::
object
ret
;
if
(
value
->
isa
<
Int32Imm
>
())
{
MS_LOG
(
DEBUG
)
<<
"int"
;
if
(
value
->
isa
<
Int8Imm
>
())
{
MS_LOG
(
DEBUG
)
<<
"int8"
;
py
::
int_
v
=
value
->
cast
<
Int8ImmPtr
>
()
->
value
();
ret
=
v
;
}
else
if
(
value
->
isa
<
Int16Imm
>
())
{
MS_LOG
(
DEBUG
)
<<
"int16"
;
py
::
int_
v
=
value
->
cast
<
Int16ImmPtr
>
()
->
value
();
ret
=
v
;
}
else
if
(
value
->
isa
<
Int32Imm
>
())
{
MS_LOG
(
DEBUG
)
<<
"int32"
;
py
::
int_
v
=
value
->
cast
<
Int32ImmPtr
>
()
->
value
();
ret
=
v
;
}
else
if
(
value
->
isa
<
Int64Imm
>
())
{
MS_LOG
(
DEBUG
)
<<
"int64"
;
py
::
int_
v
=
value
->
cast
<
Int64ImmPtr
>
()
->
value
();
ret
=
v
;
}
else
if
(
value
->
isa
<
UInt8Imm
>
())
{
MS_LOG
(
DEBUG
)
<<
"uint8"
;
py
::
int_
v
=
value
->
cast
<
UInt8ImmPtr
>
()
->
value
();
ret
=
v
;
}
else
if
(
value
->
isa
<
UInt16Imm
>
())
{
MS_LOG
(
DEBUG
)
<<
"uint16"
;
py
::
int_
v
=
value
->
cast
<
UInt16ImmPtr
>
()
->
value
();
ret
=
v
;
}
else
if
(
value
->
isa
<
UInt32Imm
>
())
{
MS_LOG
(
DEBUG
)
<<
"uint32"
;
py
::
int_
v
=
value
->
cast
<
UInt32ImmPtr
>
()
->
value
();
ret
=
v
;
}
else
if
(
value
->
isa
<
UInt64Imm
>
())
{
MS_LOG
(
DEBUG
)
<<
"uint64"
;
py
::
int_
v
=
value
->
cast
<
UInt64ImmPtr
>
()
->
value
();
...
...
mindspore/core/abstract/abstract_value.cc
浏览文件 @
afd16fbf
...
...
@@ -97,8 +97,12 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
}
auto
value_self
=
GetValueTrack
();
MS_EXCEPTION_IF_NULL
(
value_self
);
ValuePtr
res_value
=
ValueJoin
(
value_self
,
other
->
GetValueTrack
());
TypePtr
res_type
=
TypeJoin
(
GetTypeTrack
(),
other
->
GetTypeTrack
());
if
(
res_type
==
kAnyType
)
{
MS_EXCEPTION
(
TypeError
)
<<
"Type join failed, type1 = "
<<
GetTypeTrack
()
->
ToString
()
<<
", type2 = "
<<
other
->
GetTypeTrack
()
->
ToString
();
}
ValuePtr
res_value
=
ValueJoin
(
value_self
,
other
->
GetValueTrack
());
if
(
res_value
==
value_self
)
{
return
shared_from_base
<
AbstractBase
>
();
}
...
...
mindspore/core/abstract/utils.cc
浏览文件 @
afd16fbf
...
...
@@ -50,9 +50,17 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
if
(
*
shape1
==
*
shape2
)
{
return
shape1
;
}
// lengths of two shapes are not same, join failed
if
(
shape1
->
shape
().
size
()
!=
shape2
->
shape
().
size
())
{
MS_LOG
(
WARNING
)
<<
"Unsupported shape join. shape1 = "
<<
shape1
->
ToString
()
<<
", shape2 = "
<<
shape2
->
ToString
();
return
shape1
;
// special case: shape(1), shape() -> shape(1)
if
(
shape1
->
shape
().
size
()
==
1
&&
shape1
->
shape
()[
0
]
==
1
&&
shape2
->
shape
().
size
()
==
0
)
{
return
shape1
;
}
if
(
shape2
->
shape
().
size
()
==
1
&&
shape2
->
shape
()[
0
]
==
1
&&
shape1
->
shape
().
size
()
==
0
)
{
return
shape2
;
}
MS_EXCEPTION
(
ValueError
)
<<
"Unsupported shape join. shape1 = "
<<
shape1
->
ToString
()
<<
", shape2 = "
<<
shape2
->
ToString
();
}
std
::
vector
<
int
>
dims
;
bool
has_dynamic_shape
=
false
;
...
...
mindspore/core/ir/scalar.h
浏览文件 @
afd16fbf
...
...
@@ -105,7 +105,7 @@ class Int8Imm : public IntergerImm {
std
::
string
DumpText
()
const
override
{
std
::
ostringstream
oss
;
oss
<<
"I8("
<<
v_
<<
")"
;
oss
<<
"I8("
<<
int
(
v_
)
<<
")"
;
return
oss
.
str
();
}
...
...
@@ -131,7 +131,7 @@ class Int16Imm : public IntergerImm {
std
::
string
DumpText
()
const
override
{
std
::
ostringstream
oss
;
oss
<<
"I16("
<<
v_
<<
")"
;
oss
<<
"I16("
<<
int
(
v_
)
<<
")"
;
return
oss
.
str
();
}
...
...
@@ -157,7 +157,7 @@ class Int32Imm : public IntergerImm {
std
::
string
DumpText
()
const
override
{
std
::
ostringstream
oss
;
oss
<<
"I32("
<<
v_
<<
")"
;
oss
<<
"I32("
<<
int
(
v_
)
<<
")"
;
return
oss
.
str
();
}
...
...
@@ -211,7 +211,7 @@ class UInt8Imm : public IntergerImm {
std
::
string
DumpText
()
const
override
{
std
::
ostringstream
oss
;
oss
<<
"U8("
<<
v_
<<
")"
;
oss
<<
"U8("
<<
unsigned
(
v_
)
<<
")"
;
return
oss
.
str
();
}
...
...
@@ -239,7 +239,7 @@ class UInt16Imm : public IntergerImm {
std
::
string
DumpText
()
const
override
{
std
::
ostringstream
oss
;
oss
<<
"U16("
<<
v_
<<
")"
;
oss
<<
"U16("
<<
unsigned
(
v_
)
<<
")"
;
return
oss
.
str
();
}
...
...
@@ -267,7 +267,7 @@ class UInt32Imm : public IntergerImm {
std
::
string
DumpText
()
const
override
{
std
::
ostringstream
oss
;
oss
<<
"U32("
<<
v_
<<
")"
;
oss
<<
"U32("
<<
unsigned
(
v_
)
<<
")"
;
return
oss
.
str
();
}
...
...
tests/mindspore_test_framework/utils/check_gradient.py
浏览文件 @
afd16fbf
...
...
@@ -324,7 +324,7 @@ class ScalarGradChecker(_GradChecker):
self
.
input_selector
=
[
i
for
i
in
range
(
self
.
nin
)]
def
get_sens
(
self
,
i
):
return
1
return
1
.0
def
check_against_numeric
(
self
,
out_index
):
args
=
list
(
self
.
args
)
...
...
tests/ut/python/ops/test_control_ops.py
浏览文件 @
afd16fbf
...
...
@@ -916,3 +916,73 @@ def test_recursive_call():
with
pytest
.
raises
(
RuntimeError
):
net
(
input_data
)
context
.
set_context
(
max_call_depth
=
old_max_call_depth
)
def
test_switch_layer_shape_join_failed
():
class
AddFuncNet
(
nn
.
Cell
):
def
__init__
(
self
,
funcs
,
new_func
):
super
(
AddFuncNet
,
self
).
__init__
()
self
.
funcs
=
funcs
self
.
new_func
=
new_func
def
construct
(
self
,
i
,
inputs
):
final_funcs
=
self
.
funcs
+
(
self
.
new_func
,)
x
=
final_funcs
[
i
](
inputs
)
return
x
class
ReLUTuple
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
ReLUTuple
,
self
).
__init__
()
self
.
op
=
nn
.
ReLU
()
def
construct
(
self
,
x
):
return
self
.
op
(
x
[
0
])
func1
=
nn
.
Softmax
()
func2
=
nn
.
ReLU
()
func3
=
ReLUTuple
()
funcs
=
(
func1
,
func2
)
net
=
AddFuncNet
(
funcs
,
func3
)
inp
=
Tensor
(
np
.
random
.
randn
(
2
,
3
,
4
,
5
).
astype
(
np
.
float32
))
i
=
Tensor
(
1
,
mstype
.
int32
)
with
pytest
.
raises
(
ValueError
)
as
err
:
net
(
i
,
inp
)
def
test_switch_layer_dtype_join_failed
():
class
Cast
(
nn
.
Cell
):
def
__init__
(
self
,
dtype
):
super
(
Cast
,
self
).
__init__
()
self
.
op
=
P
.
Cast
()
self
.
dtype
=
dtype
def
construct
(
self
,
x
):
y
=
self
.
op
(
x
,
self
.
dtype
)
return
y
+
y
class
SwitchNegNet
(
nn
.
Cell
):
def
__init__
(
self
,
funcs
):
super
(
SwitchNegNet
,
self
).
__init__
()
self
.
funcs
=
funcs
self
.
op
=
P
.
Neg
()
def
construct
(
self
,
i
,
inputs
):
x
=
self
.
funcs
[
i
](
inputs
)
x
=
self
.
op
(
x
)
return
x
func1
=
nn
.
ReLU
()
func2
=
Cast
(
mstype
.
int32
)
funcs
=
(
func1
,
func2
)
net
=
SwitchNegNet
(
funcs
)
inp
=
Tensor
(
np
.
random
.
randn
(
2
,
3
,
4
,
5
).
astype
(
np
.
float32
))
i
=
Tensor
(
0
,
mstype
.
int32
)
with
pytest
.
raises
(
TypeError
)
as
err
:
net
(
i
,
inp
)
tests/ut/python/ops/test_ops.py
浏览文件 @
afd16fbf
...
...
@@ -33,6 +33,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception
)
from
....mindspore_test_framework.pipeline.gradient.compile_gradient
\
import
pipeline_for_compile_grad_ge_graph_for_case_by_case_config
from
....ops_common
import
convert
grad_all_with_sens
=
C
.
GradOperation
(
'grad_all_with_sens'
,
get_all
=
True
,
sens_param
=
True
)
...
...
@@ -1703,7 +1704,7 @@ test_case_nn_ops = [
(
'ResizeBilinear'
,
{
'block'
:
P
.
ResizeBilinear
((
5
,
5
)),
'desc_inputs'
:
[
Tensor
([[[[
1
,
2
,
3
,
4
,
5
],
[
1
,
2
,
3
,
4
,
5
]]]],
mstype
.
float16
)],
'desc_bprop'
:
[
Tensor
([[[[
1
,
2
,
3
,
4
,
5
],
[
1
,
2
,
3
,
4
,
5
]]]],
mstype
.
float
16
)]}),
'desc_bprop'
:
[
Tensor
([[[[
1
,
2
,
3
,
4
,
5
],
[
1
,
2
,
3
,
4
,
5
]]]],
mstype
.
float
32
)]}),
(
'ResizeBilinearGrad'
,
{
'block'
:
G
.
ResizeBilinearGrad
(),
'desc_inputs'
:
[
Tensor
([[[[
1
,
2
,
3
,
4
,
5
]]]],
mstype
.
float32
),
Tensor
([[[[
1
,
2
,
3
,
4
,
5
]]]],
mstype
.
float32
)],
...
...
@@ -1712,7 +1713,7 @@ test_case_nn_ops = [
(
'ROIAlign'
,
{
'block'
:
P
.
ROIAlign
(
7
,
7
,
0.03125
,
2
),
'desc_inputs'
:
[[
2
,
256
,
192
,
320
],
[
1024
,
5
]],
'desc_bprop'
:
[[
7
,
7
]]}),
'desc_bprop'
:
[[
1024
,
256
,
7
,
7
]]}),
(
'ROIAlignGrad'
,
{
'block'
:
G
.
ROIAlignGrad
((
1
,
1
,
1
,
1
),
2
,
2
,
0.5
,
2
),
'desc_inputs'
:
[[
1
,
1
,
2
,
2
],
[
1
,
5
]],
...
...
@@ -2315,7 +2316,7 @@ test_case_other_ops = [
(
'IOU'
,
{
'block'
:
P
.
IOU
(),
'desc_inputs'
:
[
Tensor
(
np
.
ones
((
256
,
4
),
np
.
float16
)),
Tensor
(
np
.
ones
((
128
,
4
),
np
.
float16
))],
'desc_bprop'
:
[
[
128
,
256
]
]}),
'desc_bprop'
:
[
convert
([
128
,
256
],
np
.
float16
)
]}),
(
'Summary'
,
{
'block'
:
SummaryNet
(),
'desc_inputs'
:
[
Tensor
(
np
.
array
([
1.1
]).
astype
(
np
.
float32
)),
...
...
tests/ut/python/ops/test_ops_reid.py
浏览文件 @
afd16fbf
...
...
@@ -118,29 +118,29 @@ test_case_reid_ops = [
'desc_inputs'
:
[[
256
,
8
]],
'desc_bprop'
:
[[
256
,
8
]]}),
(
'Pow'
,
{
'block'
:
P
.
Pow
(),
# 输入有标量插件产生了段错误。
'block'
:
P
.
Pow
(),
'desc_const'
:
[
2.0
],
'desc_inputs'
:
[[
1
,
512
]],
'desc_bprop'
:
[[
1
,
512
]]}),
(
'LogicalNot'
,
{
'block'
:
P
.
LogicalNot
(),
'desc_inputs'
:
[
convert
([
256
],
np
.
bool_
)],
'desc_bprop'
:
[
[
256
]]}),
# 自定义算子 input bool没转换,gongchen提单。
'desc_bprop'
:
[
convert
([
256
],
np
.
bool_
)]}),
(
'Equal'
,
{
'block'
:
P
.
Equal
(),
'desc_inputs'
:
[
convert
([
256
],
np
.
float16
),
convert
([
256
],
np
.
float16
)],
'desc_bprop'
:
[
[
256
]
]}),
'desc_bprop'
:
[
convert
([
256
],
np
.
bool_
)
]}),
(
'Greater'
,
{
'block'
:
P
.
Greater
(),
'desc_inputs'
:
[
convert
([
256
],
np
.
float16
),
convert
([
256
],
np
.
float16
)],
'desc_bprop'
:
[
[
256
]
]}),
'desc_bprop'
:
[
convert
([
256
],
np
.
bool_
)
]}),
(
'Dropout'
,
{
'block'
:
nn
.
Dropout
(),
'desc_inputs'
:
[[
1
,
512
,
7
,
7
]],
'desc_bprop'
:
[[
1
,
512
,
7
,
7
]]}),
# 输入有标量插件产生了段错误。
'desc_bprop'
:
[[
1
,
512
,
7
,
7
]]}),
(
'MatMul'
,
{
'block'
:
P
.
MatMul
(),
'desc_inputs'
:
[[
64
,
512
],
[
512
,
64
]],
# fp16不行。很有问题。
'desc_inputs'
:
[[
64
,
512
],
[
512
,
64
]],
'desc_bprop'
:
[[
64
,
64
]]}),
(
'Maximum'
,
{
'block'
:
P
.
Maximum
(),
...
...
tests/ut/python/parameter_feature/test_var_grad.py
浏览文件 @
afd16fbf
...
...
@@ -84,8 +84,8 @@ class Bprop(Cell):
self
.
grad
=
grad_op
self
.
with_sens
=
False
self
.
sens
=
sens
if
sens
:
self
.
sens
=
Tensor
(
sens
,
dtype
=
mstype
.
float32
)
if
not
sens
is
None
:
self
.
sens
=
sens
if
isinstance
(
sens
,
Tensor
)
else
Tensor
(
sens
,
dtype
=
mstype
.
float32
)
self
.
with_sens
=
True
def
construct
(
self
,
*
inputs
):
...
...
@@ -115,7 +115,7 @@ def test_all_var_args_grad_with_sens():
x
=
Tensor
(
np
.
ones
([
3
,
4
,
5
]),
dtype
=
mstype
.
float32
)
y
=
Tensor
(
np
.
ones
([
3
,
4
,
5
]),
dtype
=
mstype
.
float32
)
sens
=
Tensor
(
1.0
,
dtype
=
mstype
.
float32
)
sens
=
Tensor
(
np
.
ones
([
3
,
4
,
5
])
,
dtype
=
mstype
.
float32
)
net
=
VarNet
(
SecondNet
())
grad_net
=
GradNet
(
net
)
_
=
grad_net
(
x
,
y
,
sens
)
...
...
@@ -167,7 +167,7 @@ def test_grad_all_var_args_with_sens():
x
=
Tensor
(
np
.
ones
([
3
,
4
,
5
]),
dtype
=
mstype
.
float32
)
y
=
Tensor
(
np
.
ones
([
3
,
4
,
5
]),
dtype
=
mstype
.
float32
)
sens
=
Tensor
(
1.0
,
dtype
=
mstype
.
float32
)
sens
=
Tensor
(
np
.
ones
([
3
,
4
,
5
])
,
dtype
=
mstype
.
float32
)
net
=
VarNet
(
SecondNet
())
grad_net
=
GradNet
(
net
)
_
=
grad_net
(
x
,
y
,
sens
)
...
...
@@ -185,7 +185,7 @@ def test_grad_var_args_with_sens():
x
=
Tensor
(
np
.
ones
([
3
,
4
,
5
]),
dtype
=
mstype
.
float32
)
y
=
Tensor
(
np
.
ones
([
3
,
4
,
5
]),
dtype
=
mstype
.
float32
)
sens
=
Tensor
(
1.0
,
dtype
=
mstype
.
float32
)
sens
=
Tensor
(
np
.
ones
([
3
,
4
,
5
])
,
dtype
=
mstype
.
float32
)
net
=
VarNet
(
SecondNet
())
grad_net
=
GradNet
(
net
)
_
=
grad_net
(
x
,
y
,
sens
)
...
...
@@ -244,7 +244,7 @@ def test_var_args_grad():
x
=
Tensor
(
np
.
ones
([
3
,
4
,
5
]),
dtype
=
mstype
.
float32
)
y
=
Tensor
(
np
.
ones
([
3
,
4
,
5
]),
dtype
=
mstype
.
float32
)
sens
=
Tensor
(
1.0
,
dtype
=
mstype
.
float32
)
sens
=
Tensor
(
np
.
ones
([
3
,
4
,
5
])
,
dtype
=
mstype
.
float32
)
net
=
VarNet
(
SecondNet
())
grad_net
=
GradNet
(
net
)
_
=
grad_net
(
x
,
y
,
sens
)
...
...
@@ -292,14 +292,14 @@ def test_grad_within_if_else():
self
.
net
=
net
grad_op
=
C
.
GradOperation
(
name
=
'grad'
,
get_all
=
False
,
get_by_list
=
True
,
sens_param
=
True
)
self
.
grad
=
Bprop
(
self
.
net
,
True
,
self
.
weights
,
grad_op
,
1.0
)
sens
=
Tensor
(
np
.
ones
([
3
,
4
,
5
]),
dtype
=
mstype
.
float32
)
self
.
grad
=
Bprop
(
self
.
net
,
True
,
self
.
weights
,
grad_op
,
sens
)
def
construct
(
self
,
*
inputs
):
return
self
.
grad
(
*
inputs
)
x
=
Tensor
(
np
.
ones
([
3
,
4
,
5
]),
dtype
=
mstype
.
float32
)
y
=
Tensor
(
np
.
ones
([
3
,
4
,
5
]),
dtype
=
mstype
.
float32
)
_
=
Tensor
(
1.0
,
dtype
=
mstype
.
float32
)
net
=
VarNet
(
SecondNet
())
grad_net
=
GradNet
(
net
)
out
=
grad_net
(
x
,
y
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录