Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
8bf35b2b
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8bf35b2b
编写于
4月 28, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 28, 2020
浏览文件
操作
浏览文件
下载
差异文件
!700 validate bprop rules
Merge pull request !700 from penn/validate_bprop_rules
上级
92bb8a7c
9e633b6c
变更
25
隐藏空白更改
内联
并排
Showing
25 changed file
with
275 addition
and
70 deletion
+275
-70
mindspore/ccsrc/ir/dtype.cc
mindspore/ccsrc/ir/dtype.cc
+1
-0
mindspore/ccsrc/operator/ops.cc
mindspore/ccsrc/operator/ops.cc
+1
-0
mindspore/ccsrc/operator/ops.h
mindspore/ccsrc/operator/ops.h
+1
-0
mindspore/ccsrc/optimizer/ad/dfunctor.cc
mindspore/ccsrc/optimizer/ad/dfunctor.cc
+0
-8
mindspore/ccsrc/optimizer/ad/dfunctor.h
mindspore/ccsrc/optimizer/ad/dfunctor.h
+2
-5
mindspore/ccsrc/optimizer/ad/kprim.cc
mindspore/ccsrc/optimizer/ad/kprim.cc
+21
-25
mindspore/ccsrc/optimizer/irpass.cc
mindspore/ccsrc/optimizer/irpass.cc
+1
-0
mindspore/ccsrc/optimizer/irpass.h
mindspore/ccsrc/optimizer/irpass.h
+1
-0
mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h
mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h
+19
-0
mindspore/ccsrc/pipeline/pass.cc
mindspore/ccsrc/pipeline/pass.cc
+1
-0
mindspore/ccsrc/pipeline/static_analysis/prim.cc
mindspore/ccsrc/pipeline/static_analysis/prim.cc
+7
-0
mindspore/common/dtype.py
mindspore/common/dtype.py
+3
-0
mindspore/ops/_grad/grad_array_ops.py
mindspore/ops/_grad/grad_array_ops.py
+3
-3
mindspore/ops/_grad/grad_math_ops.py
mindspore/ops/_grad/grad_math_ops.py
+2
-2
mindspore/ops/_grad/grad_nn_ops.py
mindspore/ops/_grad/grad_nn_ops.py
+1
-1
mindspore/ops/composite/multitype_ops/zeros_like_impl.py
mindspore/ops/composite/multitype_ops/zeros_like_impl.py
+4
-0
mindspore/ops/functional.py
mindspore/ops/functional.py
+1
-0
mindspore/ops/operations/__init__.py
mindspore/ops/operations/__init__.py
+2
-1
mindspore/ops/operations/other_ops.py
mindspore/ops/operations/other_ops.py
+63
-0
tests/ut/python/model/test_bert_cell.py
tests/ut/python/model/test_bert_cell.py
+4
-4
tests/ut/python/model/test_mix_precision.py
tests/ut/python/model/test_mix_precision.py
+1
-1
tests/ut/python/ops/test_ops.py
tests/ut/python/ops/test_ops.py
+11
-11
tests/ut/python/pynative_mode/test_cell_bprop.py
tests/ut/python/pynative_mode/test_cell_bprop.py
+31
-8
tests/ut/python/pynative_mode/test_framstruct.py
tests/ut/python/pynative_mode/test_framstruct.py
+93
-0
tests/ut/python/pynative_mode/test_insert_grad_of.py
tests/ut/python/pynative_mode/test_insert_grad_of.py
+1
-1
未找到文件。
mindspore/ccsrc/ir/dtype.cc
浏览文件 @
8bf35b2b
...
...
@@ -695,6 +695,7 @@ REGISTER_PYBIND_DEFINE(
(
void
)
py
::
class_
<
String
,
Type
,
std
::
shared_ptr
<
String
>>
(
m_sub
,
"String"
).
def
(
py
::
init
());
(
void
)
py
::
class_
<
RefKeyType
,
Type
,
std
::
shared_ptr
<
RefKeyType
>>
(
m_sub
,
"RefKeyType"
).
def
(
py
::
init
());
(
void
)
py
::
class_
<
RefType
,
Type
,
std
::
shared_ptr
<
RefType
>>
(
m_sub
,
"RefType"
).
def
(
py
::
init
());
(
void
)
py
::
class_
<
TypeAnything
,
Type
,
std
::
shared_ptr
<
TypeAnything
>>
(
m_sub
,
"TypeAnything"
).
def
(
py
::
init
());
}));
const
TypePtr
kTypeExternal
=
std
::
make_shared
<
External
>
();
...
...
mindspore/ccsrc/operator/ops.cc
浏览文件 @
8bf35b2b
...
...
@@ -213,6 +213,7 @@ const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_orig
const
PrimitivePtr
kPrimInsertGradientOf
=
std
::
make_shared
<
Primitive
>
(
"InsertGradientOf"
);
const
PrimitivePtr
kPrimPrintShapeType
=
std
::
make_shared
<
Primitive
>
(
"PrintShapeType"
);
const
PrimitivePtr
kPrimSameTypeShape
=
std
::
make_shared
<
Primitive
>
(
"SameTypeShape"
);
const
PrimitivePtr
kPrimCheckBprop
=
std
::
make_shared
<
Primitive
>
(
"CheckBprop"
);
const
PrimitivePtr
kPrimPrint
=
std
::
make_shared
<
Primitive
>
(
"Print"
);
const
PrimitivePtr
kPrimMakeRef
=
std
::
make_shared
<
Primitive
>
(
"make_ref"
);
...
...
mindspore/ccsrc/operator/ops.h
浏览文件 @
8bf35b2b
...
...
@@ -220,6 +220,7 @@ extern const PrimitivePtr kPrimInsertGradientOf;
extern
const
PrimitivePtr
kPrimPrintShapeType
;
extern
const
PrimitivePtr
kPrimPrint
;
extern
const
PrimitivePtr
kPrimSameTypeShape
;
extern
const
PrimitivePtr
kPrimCheckBprop
;
extern
const
PrimitivePtr
kPrimDepend
;
extern
const
PrimitivePtr
kPrimStateSetItem
;
extern
const
PrimitivePtr
kPrimScalarSummary
;
...
...
mindspore/ccsrc/optimizer/ad/dfunctor.cc
浏览文件 @
8bf35b2b
...
...
@@ -309,14 +309,6 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
auto
bprop
=
primal
->
transforms
().
find
(
"bprop"
);
if
(
bprop
!=
primal
->
transforms
().
end
())
{
FuncGraphPtr
bprop_graph
=
bprop
->
second
.
func_graph
();
const
size_t
param_diff
=
1
;
if
(
bprop_graph
->
output
()
->
isa
<
CNode
>
()
&&
bprop_graph
->
output
()
->
cast
<
CNodePtr
>
()
->
size
()
+
param_diff
!=
bprop_graph
->
parameters
().
size
())
{
// It does not matter with the final tangents, just a tip for debugging
MS_LOG
(
DEBUG
)
<<
"User defined Cell bprop "
<<
primal
->
ToString
()
<<
" in scope "
<<
primal
->
output
()
->
scope
()
->
name
()
<<
" output must be a tuple and output number should be the same with inputs."
;
}
resources_
->
manager
()
->
AddFuncGraph
(
bprop_graph
);
if
(
bprop_graph
->
free_variables_nodes
().
size
()
!=
0
||
primal
->
free_variables_nodes
().
size
()
!=
0
)
{
...
...
mindspore/ccsrc/optimizer/ad/dfunctor.h
浏览文件 @
8bf35b2b
...
...
@@ -127,7 +127,7 @@ class KPrim {
AnfNodePtr
BuildOutput
(
const
FuncGraphPtr
&
bprop_fg
);
void
TransformArgs
(
const
FuncGraphManagerPtr
&
mng
,
const
FuncGraphPtr
&
bprop_fg
,
const
FuncGraphPtr
&
outer
,
std
::
vector
<
AnfNodePtr
>
*
const
transf_args
);
void
AddCheckTypeShapeOp
(
const
FuncGraphPtr
&
bprop_fg
);
void
CheckBprop
(
const
FuncGraphPtr
&
bprop_fg
,
const
string
&
prim_to_check
);
Registry
bprop_registry_
;
std
::
unordered_map
<
PrimitivePtr
,
MetaFuncGraphPtr
>
bprop_registry_meta_
;
...
...
@@ -137,10 +137,7 @@ template <typename T>
FuncGraphPtr
KPrim
::
BpropToK
(
const
T
&
primal
,
const
FuncGraphPtr
&
bprop_fg
)
{
MS_EXCEPTION_IF_NULL
(
primal
);
MS_EXCEPTION_IF_NULL
(
bprop_fg
);
if
(
IsPrimitiveCNode
(
bprop_fg
->
output
(),
prim
::
kPrimMakeTuple
))
{
AddCheckTypeShapeOp
(
bprop_fg
);
}
CheckBprop
(
bprop_fg
,
primal
->
ToString
());
auto
debug_info
=
std
::
make_shared
<
GraphDebugInfo
>
();
debug_info
->
set_name
(
primal
->
ToString
());
...
...
mindspore/ccsrc/optimizer/ad/kprim.cc
浏览文件 @
8bf35b2b
...
...
@@ -50,9 +50,13 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
grad_op_child_scope_prefix
+
prim
->
name
());
ScopeGuard
scope_guard
(
scope
);
py
::
function
fn
=
prim
->
GetBpropFunction
();
if
(
fn
==
nullptr
||
py
::
isinstance
<
py
::
none
>
(
fn
))
{
MS_LOG
(
DEBUG
)
<<
"Fail to find bprop function for "
<<
prim
->
name
()
<<
"."
;
return
nullptr
;
}
FuncGraphPtr
func_graph
=
parse
::
ParsePythonCode
(
fn
);
if
(
func_graph
==
nullptr
)
{
MS_LOG
(
WARNING
)
<<
"Fail to find
bprop function for "
<<
prim
->
name
()
<<
"."
;
MS_LOG
(
ERROR
)
<<
"Fail to parse
bprop function for "
<<
prim
->
name
()
<<
"."
;
return
nullptr
;
}
return
func_graph
;
...
...
@@ -153,31 +157,23 @@ void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bp
}
}
void
KPrim
::
AddCheckTypeShapeOp
(
const
FuncGraphPtr
&
bprop_fg
)
{
void
KPrim
::
CheckBprop
(
const
FuncGraphPtr
&
bprop_fg
,
const
string
&
prim_to_check
)
{
// bprop_fg has been checked in caller
auto
same_type_shape
=
prim
::
GetPythonOps
(
"same_type_shape"
,
"mindspore.ops.functional"
)
->
cast
<
PrimitivePtr
>
();
MS_EXCEPTION_IF_NULL
(
same_type_shape
);
std
::
vector
<
AnfNodePtr
>
bout_input
;
bout_input
.
push_back
(
NewValueNode
(
prim
::
kPrimMakeTuple
));
auto
fg_out
=
bprop_fg
->
output
();
MS_EXCEPTION_IF_NULL
(
fg_out
);
auto
cnode
=
fg_out
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
auto
&
inputs
=
cnode
->
inputs
();
auto
params
=
bprop_fg
->
parameters
();
std
::
vector
<
AnfNodePtr
>
sub_input
;
for
(
size_t
i
=
1
;
i
<
inputs
.
size
();
++
i
)
{
sub_input
.
clear
();
sub_input
.
push_back
(
NewValueNode
(
same_type_shape
));
sub_input
.
push_back
(
inputs
[
i
]);
sub_input
.
push_back
(
params
[
i
-
1
]);
bout_input
.
push_back
(
bprop_fg
->
NewCNode
(
sub_input
));
}
AnfNodePtr
cbout
=
bprop_fg
->
NewCNode
(
bout_input
);
bprop_fg
->
set_output
(
cbout
);
auto
check_bprop
=
prim
::
GetPythonOps
(
"check_bprop"
,
"mindspore.ops.functional"
)
->
cast
<
PrimitivePtr
>
();
MS_EXCEPTION_IF_NULL
(
check_bprop
);
check_bprop
->
set_attr
(
"prim_to_check"
,
std
::
make_shared
<
StringImm
>
(
prim_to_check
));
std
::
vector
<
AnfNodePtr
>
inputs
;
inputs
.
emplace_back
(
NewValueNode
(
prim
::
kPrimMakeTuple
));
inputs
.
insert
(
inputs
.
begin
()
+
1
,
bprop_fg
->
parameters
().
begin
(),
bprop_fg
->
parameters
().
end
()
-
2
);
AnfNodePtr
params
=
bprop_fg
->
NewCNode
(
inputs
);
inputs
.
clear
();
inputs
.
push_back
(
NewValueNode
(
check_bprop
));
inputs
.
push_back
(
bprop_fg
->
output
());
inputs
.
push_back
(
params
);
AnfNodePtr
bprop_out
=
bprop_fg
->
NewCNode
(
inputs
);
bprop_fg
->
set_output
(
bprop_out
);
}
FuncGraphPtr
KPrim
::
KUserDefinedCellBprop
(
const
FuncGraphPtr
bprop_fg
)
{
...
...
mindspore/ccsrc/optimizer/irpass.cc
浏览文件 @
8bf35b2b
...
...
@@ -67,6 +67,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
{
prim
::
kPrimReduceMean
,
prim
::
kPrimReduceAll
,
prim
::
kPrimReduceSum
,
prim
::
kPrimReduceMax
,
prim
::
kPrimReduceMin
});
partial_eliminate_
=
MakeSubstitution
(
PartialEliminater
(),
"partial_eliminate"
,
IsCNodeDup
);
same_eliminate_
=
MakeSubstitution
(
SameEliminater
(),
"same_eliminate"
,
prim
::
kPrimSameTypeShape
);
check_bprop_eliminate_
=
MakeSubstitution
(
CheckBpropEliminater
(),
"check_bprop_eliminate"
,
prim
::
kPrimCheckBprop
);
reset_defer_inline_
=
MakeSubstitution
(
ResetDeferInline
(),
"reset_defer_inline"
,
IsValueNode
<
FuncGraph
>
);
// Env Item Eliminate
...
...
mindspore/ccsrc/optimizer/irpass.h
浏览文件 @
8bf35b2b
...
...
@@ -45,6 +45,7 @@ class OptimizeIRPassLib {
SubstitutionPtr
reduce_eliminate_
;
SubstitutionPtr
partial_eliminate_
;
SubstitutionPtr
same_eliminate_
;
SubstitutionPtr
check_bprop_eliminate_
;
SubstitutionPtr
reset_defer_inline_
;
// Env Item Eliminate
...
...
mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h
浏览文件 @
8bf35b2b
...
...
@@ -109,6 +109,25 @@ class SameEliminater : public AnfVisitor {
AnfNodePtr
x_
{
nullptr
};
};
// {prim::kPrimCheckBprop, X, Y} -> X
class
CheckBpropEliminater
:
public
AnfVisitor
{
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
{
x_
=
nullptr
;
AnfVisitor
::
Match
(
prim
::
kPrimCheckBprop
,
{
IsNode
,
IsNode
})(
node
);
return
x_
;
}
void
Visit
(
const
AnfNodePtr
&
node
)
override
{
if
(
x_
==
nullptr
)
{
x_
=
node
;
}
}
private:
AnfNodePtr
x_
{
nullptr
};
};
// Reset defer_inline flag
class
ResetDeferInline
:
public
AnfVisitor
{
public:
...
...
mindspore/ccsrc/pipeline/pass.cc
浏览文件 @
8bf35b2b
...
...
@@ -108,6 +108,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
});
opt
::
OptPassConfig
a_3
=
opt
::
OptPassConfig
({
irpass
.
same_eliminate_
,
irpass
.
check_bprop_eliminate_
,
irpass
.
replace_applicator_
,
});
opt
::
OptPassConfig
virtual_dataset
=
opt
::
OptPassConfig
({
irpass
.
virtual_dataset_eliminate_
});
...
...
mindspore/ccsrc/pipeline/static_analysis/prim.cc
浏览文件 @
8bf35b2b
...
...
@@ -295,6 +295,9 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
dic
[
"shape"
]
=
shape
;
dic
[
"dtype"
]
=
arg_slice
->
BuildType
();
dic
[
"value"
]
=
BuildValue
(
arg_slice
->
BuildValue
());
}
else
if
(
abs_base
->
isa
<
AbstractRef
>
())
{
auto
value
=
abs_base
->
cast
<
AbstractRefPtr
>
()
->
ref
();
dic
=
ConvertAbstractToPython
(
value
);
}
else
if
(
abs_base
->
isa
<
AbstractTuple
>
())
{
auto
arg_tuple
=
dyn_cast
<
AbstractTuple
>
(
abs_base
);
size_t
len
=
arg_tuple
->
size
();
...
...
@@ -327,6 +330,10 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
dic
[
"shape"
]
=
py
::
none
();
dic
[
"dtype"
]
=
py
::
none
();
dic
[
"value"
]
=
py
::
none
();
}
else
if
(
abs_base
->
isa
<
AbstractFunction
>
())
{
dic
[
"shape"
]
=
py
::
none
();
dic
[
"dtype"
]
=
abs_base
->
BuildType
();
dic
[
"value"
]
=
py
::
none
();
}
else
{
auto
value
=
abs_base
->
BuildValue
();
if
((
*
value
==
*
kAnyValue
))
{
...
...
mindspore/common/dtype.py
浏览文件 @
8bf35b2b
...
...
@@ -85,13 +85,16 @@ list_ = typing.List()
tuple_
=
typing
.
Tuple
()
tensor
=
typing
.
TensorType
()
function
=
typing
.
Function
()
function_type
=
typing
.
Function
symbolic_key
=
typing
.
SymbolicKeyType
()
env_type
=
typing
.
EnvType
()
env_type_type
=
typing
.
EnvType
type_type
=
typing
.
TypeType
()
type_none
=
typing
.
TypeNone
()
string
=
typing
.
String
()
type_refkey
=
typing
.
RefKeyType
()
tensor_type
=
typing
.
TensorType
anything_type
=
typing
.
TypeAnything
number_type
=
(
int8
,
int16
,
...
...
mindspore/ops/_grad/grad_array_ops.py
浏览文件 @
8bf35b2b
...
...
@@ -211,11 +211,11 @@ def get_bprop_slice(self):
def
bprop
(
x
,
begin
,
size
,
out
,
dout
):
dx
=
P
.
Pad
(
_slice_grad_pad
(
begin
,
size
,
shape_op
(
x
)))(
dout
)
return
(
dx
,)
return
(
dx
,
zeros_like
(
begin
),
zeros_like
(
size
)
)
def
bprop_gpu
(
x
,
begin
,
size
,
out
,
dout
):
dx
=
dx
=
G
.
SliceGrad
()(
dout
,
x
,
begin
,
size
)
return
(
dx
,)
return
(
dx
,
zeros_like
(
begin
),
zeros_like
(
size
)
)
if
context
.
get_context
(
'device_target'
)
==
"GPU"
:
return
bprop_gpu
...
...
@@ -262,7 +262,7 @@ def get_bprop_gather_v2(self):
# Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
perm_2
=
_generate_inverse_index
(
x_shp
,
axis
)
params_grad
=
transpose
(
params_grad
,
perm_2
)
return
params_grad
,
zeros_like
(
indices
)
return
params_grad
,
zeros_like
(
indices
)
,
zeros_like
(
axis
)
return
bprop
...
...
mindspore/ops/_grad/grad_math_ops.py
浏览文件 @
8bf35b2b
...
...
@@ -505,7 +505,7 @@ def get_bprop_reducemax(self):
def
bprop
(
x
,
axis
,
out
,
dout
):
dx
=
_min_or_max_grad
(
x
,
axis
,
out
,
dout
)
return
(
dx
,)
return
(
dx
,
zeros_like
(
axis
)
)
return
bprop
...
...
@@ -528,7 +528,7 @@ def get_bprop_reducemin(self):
def
bprop
(
x
,
axis
,
out
,
dout
):
dx
=
_min_or_max_grad
(
x
,
axis
,
out
,
dout
)
return
(
dx
,)
return
(
dx
,
zeros_like
(
axis
)
)
return
bprop
...
...
mindspore/ops/_grad/grad_nn_ops.py
浏览文件 @
8bf35b2b
...
...
@@ -436,7 +436,7 @@ def get_bprop_onehot(self):
"""Grad definition for `OneHot` operation."""
def
bprop
(
indices
,
depth
,
on_value
,
off_value
,
out
,
dout
):
return
zeros_like
(
indices
),
zeros_like
(
depth
)
return
zeros_like
(
indices
),
zeros_like
(
depth
)
,
zeros_like
(
on_value
),
zeros_like
(
off_value
)
return
bprop
...
...
mindspore/ops/composite/multitype_ops/zeros_like_impl.py
浏览文件 @
8bf35b2b
...
...
@@ -31,6 +31,10 @@ def _zeros_like_scala(x):
"""Returns 0 which has the same dtype as x where x is a scalar."""
return
0
@
zeros_like_leaf
.
register
(
"Bool"
)
def
_zeros_like_bool
(
x
):
"""Returns False if x is a bool."""
return
False
newenv
=
base
.
EnvInstance_
()
...
...
mindspore/ops/functional.py
浏览文件 @
8bf35b2b
...
...
@@ -56,6 +56,7 @@ tensor_pow = P.Pow()
tensor_mod
=
P
.
FloorMod
()
strided_slice
=
P
.
StridedSlice
()
same_type_shape
=
P
.
SameTypeShape
()
check_bprop
=
P
.
CheckBprop
()
equal
=
P
.
Equal
()
not_equal
=
P
.
NotEqual
()
assign_sub
=
P
.
AssignSub
()
...
...
mindspore/ops/operations/__init__.py
浏览文件 @
8bf35b2b
...
...
@@ -67,7 +67,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
SparseSoftmaxCrossEntropyWithLogits
,
Tanh
,
TopK
,
BinaryCrossEntropy
,
SparseApplyAdagrad
,
LARSUpdate
,
ApplyFtrl
,
ApplyRMSProp
,
ApplyCenteredRMSProp
)
from
.other_ops
import
Assign
,
IOU
,
BoundingBoxDecode
,
BoundingBoxEncode
,
CheckValid
,
MakeRefKey
from
.other_ops
import
Assign
,
IOU
,
BoundingBoxDecode
,
BoundingBoxEncode
,
CheckValid
,
MakeRefKey
,
CheckBprop
from
.
import
_quant_ops
from
._quant_ops
import
*
...
...
@@ -179,6 +179,7 @@ __all__ = [
'GeSwitch'
,
'Merge'
,
'SameTypeShape'
,
'CheckBprop'
,
'CheckValid'
,
'BoundingBoxEncode'
,
'BoundingBoxDecode'
,
...
...
mindspore/ops/operations/other_ops.py
浏览文件 @
8bf35b2b
...
...
@@ -269,3 +269,66 @@ class MakeRefKey(Primitive):
def
__call__
(
self
):
pass
class
CheckBprop
(
PrimitiveWithInfer
):
"""
Checks whether data type and shape of corresponding element from tuple x and y are the same.
Raises:
TypeError: If not the same.
Inputs:
- **input_x** (tuple[Tensor]) - The input_x contains the outputs of bprop to be checked.
- **input_y** (tuple[Tensor]) - The input_y contains the inputs of bprop to check against.
Outputs:
(tuple[Tensor]), the input_x,
if data type and shape of corresponding elements from `input_x` and `input_y` are the same.
Examples:
>>> input_x = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
>>> input_y = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
>>> out = P.CheckBprop()(input_x, input_y)
"""
@
prim_attr_register
def
__init__
(
self
):
"""init CheckBprop"""
def
infer_shape
(
self
,
xshapes
,
yshapes
):
tips
=
f
'Bprop of
{
self
.
prim_to_check
}
'
if
len
(
xshapes
)
<
len
(
yshapes
):
raise
TypeError
(
f
"
{
tips
}
, the size of output should be
{
len
(
yshapes
)
}
,"
f
" but got
{
len
(
xshapes
)
}
."
)
checking_range
=
len
(
yshapes
)
for
i
in
range
(
checking_range
):
xshape
=
xshapes
[
i
]
yshape
=
yshapes
[
i
]
if
not
xshape
or
not
yshape
:
continue
if
xshape
!=
yshape
:
raise
TypeError
(
f
"
{
tips
}
, the shape of
{
i
}
th output should be
{
yshape
}
,"
f
" but got
{
xshape
}
."
)
return
xshapes
def
infer_dtype
(
self
,
xdtypes
,
ydtypes
):
tips
=
f
'Bprop of
{
self
.
prim_to_check
}
'
if
len
(
xdtypes
)
<
len
(
ydtypes
):
raise
TypeError
(
f
"
{
tips
}
, the size of output should be
{
len
(
ydtypes
)
}
,"
f
" but got
{
len
(
xdtypes
)
}
."
)
checking_range
=
len
(
ydtypes
)
for
i
in
range
(
checking_range
):
xdtype
=
xdtypes
[
i
]
ydtype
=
ydtypes
[
i
]
if
isinstance
(
xdtype
,
mstype
.
anything_type
)
or
isinstance
(
ydtype
,
mstype
.
anything_type
):
continue
if
isinstance
(
ydtype
,
mstype
.
function_type
):
if
not
isinstance
(
xdtype
,
mstype
.
env_type_type
):
raise
TypeError
(
f
"
{
tips
}
, the dtype of
{
i
}
th output should be
{
mstype
.
env_type_type
}
,"
f
" but got
{
xdtype
}
."
)
continue
if
xdtype
!=
ydtype
:
raise
TypeError
(
f
"
{
tips
}
, the dtype of
{
i
}
th output should be
{
ydtype
}
,"
f
" but got
{
xdtype
}
."
)
return
xdtypes
tests/ut/python/model/test_bert_cell.py
浏览文件 @
8bf35b2b
...
...
@@ -317,7 +317,7 @@ test_case_cell_ops = [
initializer_range
=
0.02
,
dropout_prob
=
0.1
),
'desc_inputs'
:
[[
1
,
768
],
[
1
,
768
]],
'desc_bprop'
:
[[
1
,
128
,
768
]]}),
# maybe not right
'desc_bprop'
:
[[
1
,
768
]]}),
(
'BertTransformer_2'
,
{
'block'
:
bert_trans
(),
'desc_inputs'
:
[[
1
,
128
,
768
],
[
1
,
128
,
128
]]}),
...
...
@@ -331,7 +331,7 @@ test_case_cell_ops = [
'desc_inputs'
:
[
Tensor
(
np
.
random
.
rand
(
128
).
astype
(
np
.
int32
)),
Tensor
(
np
.
random
.
rand
(
128
).
astype
(
np
.
int32
)),
[
128
]],
'desc_bprop'
:
[[
1
,
128
,
768
],
[
1
,
128
,
768
],
[
1
,
128
,
768
]],
'num_output'
:
3
}),
# maybe not right
'num_output'
:
3
}),
(
'BertModel_1'
,
{
'block'
:
BertModel
(
config
=
BertConfig
(
batch_size
=
1
,
...
...
@@ -342,7 +342,7 @@ test_case_cell_ops = [
'desc_inputs'
:
[
Tensor
(
np
.
random
.
rand
(
128
).
astype
(
np
.
int32
)),
Tensor
(
np
.
random
.
rand
(
128
).
astype
(
np
.
int32
)),
[
128
]],
'desc_bprop'
:
[[
1
,
128
,
768
],
[
1
,
128
,
768
],
[
1
,
128
,
768
]],
'num_output'
:
3
}),
# maybe not right
'num_output'
:
3
}),
(
'BertModel_2'
,
{
'block'
:
BertModel
(
config
=
BertConfig
(
batch_size
=
1
,
...
...
@@ -354,7 +354,7 @@ test_case_cell_ops = [
'desc_inputs'
:
[
Tensor
(
np
.
random
.
rand
(
128
).
astype
(
np
.
int32
)),
Tensor
(
np
.
random
.
rand
(
128
).
astype
(
np
.
int32
)),
[
128
]],
'desc_bprop'
:
[[
1
,
128
,
768
],
[
1
,
128
,
768
],
[
1
,
128
,
768
]],
'num_output'
:
3
}),
# maybe not right
'num_output'
:
3
}),
(
'BertPretrainingLoss'
,
{
'block'
:
BertPretrainingLoss
(
config
=
BertConfig
(
batch_size
=
1
)),
...
...
tests/ut/python/model/test_mix_precision.py
浏览文件 @
8bf35b2b
...
...
@@ -175,7 +175,7 @@ class GetParamGrad(nn.Cell):
def
test_grad_conv_prelu
():
shapes
=
[[
64
,
64
,
112
,
112
]]
outshape
=
[[
64
,
64
,
56
,
56
]]
outshape
=
[[
64
,
64
,
112
,
112
]]
net
=
IRBlockZ
(
inplanes
=
64
,
planes
=
64
).
add_flags_recursive
(
fp16
=
True
)
inputs
=
[
convert
(
shp
,
dtype
=
np
.
float16
)
for
shp
in
shapes
]
sens_shape
=
outshape
[
0
]
...
...
tests/ut/python/ops/test_ops.py
浏览文件 @
8bf35b2b
...
...
@@ -585,7 +585,7 @@ test_case_nn_ops = [
(
'ReLUV2'
,
{
'block'
:
P
.
ReLUV2
(),
'desc_inputs'
:
[[
1
,
3
,
4
,
4
]],
'desc_bprop'
:
[[
1
,
3
,
4
,
4
],
[
1
,
3
,
4
,
4
]
]}),
'desc_bprop'
:
[[
1
,
3
,
4
,
4
],
([
1
,
1
,
4
,
4
,
2
],
{
'dtype'
:
np
.
uint8
})
]}),
(
'ReLUGrad'
,
{
'block'
:
G
.
ReluGrad
(),
'desc_inputs'
:
[[
1
,
3
,
4
,
4
],
[
1
,
3
,
4
,
4
]],
...
...
@@ -626,7 +626,7 @@ test_case_nn_ops = [
(
'MaxPoolWithArgmax'
,
{
'block'
:
P
.
MaxPoolWithArgmax
(
ksize
=
2
,
strides
=
2
),
'desc_inputs'
:
[[
128
,
32
,
32
,
64
]],
'desc_bprop'
:
[[
128
,
32
,
8
,
16
],
[
128
,
32
,
8
,
16
]
]}),
'desc_bprop'
:
[[
128
,
32
,
16
,
32
],
([
128
,
32
,
4
,
33
],
{
'dtype'
:
np
.
uint16
})
]}),
(
'SoftmaxCrossEntropyWithLogits'
,
{
'block'
:
P
.
SoftmaxCrossEntropyWithLogits
(),
'desc_inputs'
:
[[
1
,
10
],
[
1
,
10
]],
...
...
@@ -639,7 +639,7 @@ test_case_nn_ops = [
(
'LogSoftmax'
,
{
'block'
:
P
.
LogSoftmax
(),
'desc_inputs'
:
[[
64
,
2
]],
'desc_bprop'
:
[[
160
,
3052
2
]]}),
'desc_bprop'
:
[[
64
,
2
]]}),
(
'LogSoftmaxGrad'
,
{
'block'
:
G
.
LogSoftmaxGrad
(),
'desc_inputs'
:
[[
16
,
1234
],
[
16
,
1234
]],
...
...
@@ -648,7 +648,7 @@ test_case_nn_ops = [
(
'LayerNorm'
,
{
'block'
:
P
.
LayerNorm
(),
'desc_inputs'
:
[[
2
,
16
],
[
16
],
[
16
]],
'desc_bprop'
:
[[
2
,
16
],
[
2
,
1
6
],
[
2
,
16
]]}),
'desc_bprop'
:
[[
2
,
16
],
[
2
,
1
],
[
2
,
1
]]}),
(
'LayerNormGrad'
,
{
'block'
:
G
.
LayerNormGrad
(),
'desc_inputs'
:
[[
2
,
16
],
[
2
,
16
],
[
2
,
16
],
[
2
,
16
],
[
16
]],
...
...
@@ -845,7 +845,7 @@ test_case_nn_ops = [
'block'
:
P
.
OneHot
(),
'desc_const'
:
[
3
,
Tensor
(
1.0
,
mstype
.
float32
),
Tensor
(
0.0
,
mstype
.
float32
)],
'desc_inputs'
:
[
Tensor
(
np
.
array
([
64
]).
astype
(
np
.
int32
))],
'desc_bprop'
:
[[
64
,
2
]]}),
'desc_bprop'
:
[[
1
,
3
]]}),
(
'ReduceProd_0'
,
{
'block'
:
P
.
ReduceProd
(),
'desc_const'
:
[
0
],
...
...
@@ -950,7 +950,7 @@ test_case_array_ops = [
'block'
:
P
.
Cast
(),
'desc_const'
:
[
mstype
.
int32
],
'desc_inputs'
:
[[
2
,
3
,
4
,
5
]],
'desc_bprop'
:
[
Tensor
(
np
.
ones
((
2
,
3
,
3
,
5
)).
astype
(
np
.
int32
))]}),
'desc_bprop'
:
[
Tensor
(
np
.
ones
((
2
,
3
,
4
,
5
)).
astype
(
np
.
int32
))]}),
(
'ExpandDims'
,
{
'block'
:
P
.
ExpandDims
(),
'desc_const'
:
[
0
],
...
...
@@ -1002,12 +1002,12 @@ test_case_array_ops = [
'desc_inputs'
:
[
(
Tensor
(
np
.
array
([[
0
,
1
],
[
2
,
1
]]).
astype
(
np
.
int32
)),
Tensor
(
np
.
array
([[
0
,
1
],
[
2
,
1
]]).
astype
(
np
.
int32
)))],
'desc_bprop'
:
[
[
4
,
2
]
]}),
'desc_bprop'
:
[
([
4
,
2
],
{
'dtype'
:
np
.
int32
})
]}),
(
'ConcatV2_1'
,
{
'block'
:
P
.
Concat
(
axis
=
2
),
'desc_inputs'
:
[(
Tensor
(
np
.
array
([[[
0
,
1
,
2
]],
[[
2
,
1
,
2
]]]).
astype
(
np
.
int32
)),
Tensor
(
np
.
array
([[[
0
,
1
]],
[[
2
,
1
]]]).
astype
(
np
.
int32
)))],
'desc_bprop'
:
[
[
2
,
1
,
5
]
]}),
'desc_bprop'
:
[
([
2
,
1
,
5
],
{
'dtype'
:
np
.
int32
})
]}),
(
'ConcatV2_2'
,
{
'block'
:
NetForConcat
(),
'desc_inputs'
:
[[
2
,
2
]],
...
...
@@ -1042,7 +1042,7 @@ test_case_array_ops = [
(
'Pack_2'
,
{
'block'
:
NetForPackInput
(
P
.
Pack
()),
'desc_inputs'
:[[
2
,
2
]],
'desc_bprop'
:[[
2
,
2
,
2
]],
'desc_bprop'
:[[
1
,
2
,
2
]],
}),
(
'Pack_3'
,
{
'block'
:
NetForPackInput
(
P
.
Pack
()),
...
...
@@ -1077,7 +1077,7 @@ test_case_array_ops = [
(
'SpaceToBatch_2'
,
{
'block'
:
P
.
SpaceToBatch
(
2
,
[[
1
,
1
],
[
0
,
4
]]),
'desc_inputs'
:
[[
1
,
3
,
2
,
2
]],
'desc_bprop'
:
[[
4
,
3
,
2
,
4
]],
'desc_bprop'
:
[[
4
,
3
,
2
,
3
]],
}),
(
'BatchToSpace_1'
,
{
'block'
:
P
.
BatchToSpace
(
2
,
[[
0
,
0
],
[
0
,
0
]]),
...
...
@@ -1124,7 +1124,7 @@ test_case_other_ops = [
'desc_const'
:
[(
3
,
3
)],
'desc_inputs'
:
(
Tensor
(
np
.
ones
((
2
,
2
),
np
.
int32
)),
Tensor
(
np
.
ones
((
2
,),
np
.
int32
))),
'desc_bprop'
:
[
[
3
,
3
]
]}),
'desc_bprop'
:
[
([
3
,
3
],
{
'dtype'
:
np
.
int32
})
]}),
(
'SmoothL1Loss'
,
{
'block'
:
P
.
SmoothL1Loss
(),
'desc_inputs'
:
[[
256
,
4
],
[
256
,
4
]],
...
...
tests/ut/python/pynative_mode/test_cell_bprop.py
浏览文件 @
8bf35b2b
...
...
@@ -229,12 +229,6 @@ class TwoInputBprop(nn.Cell):
def
bprop
(
self
,
x
,
y
,
out
,
dout
):
return
5
*
x
,
8
*
y
class
TwoInput
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
op
=
P
.
Mul
()
def
construct
(
self
,
x
,
y
):
return
self
.
op
(
x
,
y
)
class
TwoInputWithParameter
(
nn
.
Cell
):
def
__init__
(
self
):
...
...
@@ -301,8 +295,37 @@ class MulAddWithWrongOutputNum(nn.Cell):
def
construct
(
self
,
x
,
y
):
return
2
*
x
+
y
def
bprop
(
self
,
x
,
y
,
out
,
dout
):
return
2
*
dout
,
2
*
y
,
out
return
2
*
dout
,
def
test_grad_mul_add_with_wrong_output_num
():
mul_add
=
MulAddWithWrongOutputNum
()
C
.
grad_all
(
mul_add
)(
1
,
2
)
with
pytest
.
raises
(
TypeError
):
C
.
grad_all
(
mul_add
)(
1
,
2
)
class
MulAddWithWrongOutputType
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
MulAddWithWrongOutputType
,
self
).
__init__
()
def
construct
(
self
,
x
,
y
):
return
2
*
x
+
y
def
bprop
(
self
,
x
,
y
,
out
,
dout
):
return
2
*
dout
,
2
def
test_grad_mul_add_with_wrong_output_type
():
mul_add
=
MulAddWithWrongOutputType
()
with
pytest
.
raises
(
TypeError
):
C
.
grad_all
(
mul_add
)(
1
,
Tensor
(
np
.
ones
([
2
,
2
])))
class
MulAddWithWrongOutputShape
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
MulAddWithWrongOutputShape
,
self
).
__init__
()
self
.
ones
=
Tensor
(
np
.
ones
([
2
,]))
def
construct
(
self
,
x
,
y
):
return
2
*
x
+
y
def
bprop
(
self
,
x
,
y
,
out
,
dout
):
return
2
,
self
.
ones
def
test_grad_mul_add_with_wrong_output_shape
():
mul_add
=
MulAddWithWrongOutputShape
()
with
pytest
.
raises
(
TypeError
):
C
.
grad_all
(
mul_add
)(
1
,
Tensor
(
np
.
ones
([
2
,
2
])))
tests/ut/python/pynative_mode/test_framstruct.py
浏览文件 @
8bf35b2b
...
...
@@ -32,6 +32,8 @@ from ....mindspore_test_framework.utils.check_gradient import (
OperationGradChecker
,
check_gradient
,
ScalarGradChecker
)
from
....mindspore_test_framework.utils.bprop_util
import
bprop
import
mindspore.context
as
context
from
mindspore.ops._grad.grad_base
import
bprop_getters
from
mindspore.ops.primitive
import
prim_attr_register
,
PrimitiveWithInfer
def
setup_module
(
module
):
...
...
@@ -721,3 +723,94 @@ def test_grad_if_defer_inline():
inp
=
Tensor
(
np
.
ones
([
128
,
96
]).
astype
(
np
.
float32
))
grads
=
C
.
grad_all
(
network
)(
inp
)
assert
grads
==
(
Tensor
(
np
.
full
([
128
,
96
],
0.6
,
dtype
=
np
.
float32
)),)
def
test_bprop_with_wrong_output_num
():
class
BpropWithWrongOutputNum
(
PrimitiveWithInfer
):
@
prim_attr_register
def
__init__
(
self
):
super
(
BpropWithWrongOutputNum
,
self
).
__init__
(
'BpropWithWrongOutputNum'
)
def
__call__
(
self
,
x
,
y
):
return
x
def
infer_shape
(
self
,
x_shape
,
yshape
):
return
x_shape
def
infer_dtype
(
self
,
x_type
,
y_type
):
return
x_type
@
bprop_getters
.
register
(
BpropWithWrongOutputNum
)
def
get_bprop_with_wrong_output_num
(
self
):
"""Generate bprop for BpropWithWrongOutputNum"""
def
bprop
(
x
,
y
,
out
,
dout
):
return
(
dout
,)
return
bprop
class
BpropWithWrongOutputNumCell
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
BpropWithWrongOutputNumCell
,
self
).
__init__
()
def
construct
(
self
,
x
,
y
):
return
BpropWithWrongOutputNum
()(
x
,
y
)
with
pytest
.
raises
(
TypeError
):
C
.
grad_all
(
BpropWithWrongOutputNumCell
())(
1
,
2
)
def
test_bprop_with_wrong_output_type
():
class
BpropWithWrongOutputType
(
PrimitiveWithInfer
):
@
prim_attr_register
def
__init__
(
self
):
super
(
BpropWithWrongOutputType
,
self
).
__init__
(
'BpropWithWrongOutputType'
)
def
__call__
(
self
,
x
):
return
x
def
infer_shape
(
self
,
x_shape
):
return
x_shape
def
infer_dtype
(
self
,
x_type
):
return
x_type
@
bprop_getters
.
register
(
BpropWithWrongOutputType
)
def
get_bprop_with_wrong_output_type
(
self
):
"""Generate bprop for BpropWithWrongOutputType"""
def
bprop
(
x
,
out
,
dout
):
return
(
1
,)
return
bprop
class
BpropWithWrongOutputTypeCell
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
BpropWithWrongOutputTypeCell
,
self
).
__init__
()
def
construct
(
self
,
x
):
return
BpropWithWrongOutputType
()(
x
)
with
pytest
.
raises
(
TypeError
):
C
.
grad_all
(
BpropWithWrongOutputTypeCell
())(
Tensor
(
np
.
ones
([
64
,
10
]).
astype
(
np
.
int32
)))
def
test_bprop_with_wrong_output_shape
():
class
BpropWithWrongOutputShape
(
PrimitiveWithInfer
):
@
prim_attr_register
def
__init__
(
self
):
super
(
BpropWithWrongOutputShape
,
self
).
__init__
(
'BpropWithWrongOutputShape'
)
def
__call__
(
self
,
x
):
return
x
def
infer_shape
(
self
,
x_shape
):
return
x_shape
def
infer_dtype
(
self
,
x_type
):
return
x_type
@
bprop_getters
.
register
(
BpropWithWrongOutputShape
)
def
get_bprop_with_wrong_output_shape
(
self
):
"""Generate bprop for BpropWithWrongOutputShape"""
ones
=
Tensor
(
np
.
ones
([
2
,]).
astype
(
np
.
int32
))
def
bprop
(
x
,
out
,
dout
):
return
(
ones
,)
return
bprop
class
BpropWithWrongOutputShapeCell
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
BpropWithWrongOutputShapeCell
,
self
).
__init__
()
def
construct
(
self
,
x
):
return
BpropWithWrongOutputShape
()(
x
)
with
pytest
.
raises
(
TypeError
):
C
.
grad_all
(
BpropWithWrongOutputShapeCell
())(
Tensor
(
np
.
ones
([
64
,
10
]).
astype
(
np
.
int32
)))
tests/ut/python/pynative_mode/test_insert_grad_of.py
浏览文件 @
8bf35b2b
...
...
@@ -79,7 +79,7 @@ def test_InsertGradientOf_2():
summary
=
P
.
ScalarSummary
()
def
debug_gradient
(
dx
):
""" debug_gradient """
dx
=
summary
(
"dx: "
,
dx
)
summary
(
"dx: "
,
dx
)
return
dx
debug
=
P
.
InsertGradientOf
(
debug_gradient
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录