Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
38436f92
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
38436f92
编写于
6月 09, 2020
作者:
K
kingfo
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
move hook function to primtivePy class
上级
444d9484
变更
27
隐藏空白更改
内联
并排
Showing
27 changed file
with
115 addition
and
58 deletion
+115
-58
mindspore/_extends/builtin_operations.py
mindspore/_extends/builtin_operations.py
+2
-12
mindspore/ccsrc/ir/primitive.h
mindspore/ccsrc/ir/primitive.h
+3
-0
mindspore/ccsrc/ir/primitive_base.h
mindspore/ccsrc/ir/primitive_base.h
+0
-4
mindspore/ccsrc/operator/ops.cc
mindspore/ccsrc/operator/ops.cc
+2
-2
mindspore/ccsrc/optimizer/ad/kprim.cc
mindspore/ccsrc/optimizer/ad/kprim.cc
+6
-2
mindspore/ccsrc/parallel/ops_info/ops_utils.h
mindspore/ccsrc/parallel/ops_info/ops_utils.h
+1
-1
mindspore/ccsrc/pipeline/parse/data_converter.cc
mindspore/ccsrc/pipeline/parse/data_converter.cc
+1
-1
mindspore/ccsrc/pynative/base.h
mindspore/ccsrc/pynative/base.h
+1
-1
mindspore/ccsrc/pynative/pynative_execute.cc
mindspore/ccsrc/pynative/pynative_execute.cc
+1
-1
mindspore/ccsrc/transform/convert.cc
mindspore/ccsrc/transform/convert.cc
+6
-6
mindspore/ccsrc/vm/vm.cc
mindspore/ccsrc/vm/vm.cc
+5
-4
mindspore/common/tensor.py
mindspore/common/tensor.py
+2
-0
mindspore/ops/_grad/grad_implementations.py
mindspore/ops/_grad/grad_implementations.py
+1
-2
mindspore/ops/composite/base.py
mindspore/ops/composite/base.py
+2
-1
mindspore/ops/functional.py
mindspore/ops/functional.py
+3
-3
mindspore/ops/operations/__init__.py
mindspore/ops/operations/__init__.py
+3
-1
mindspore/ops/operations/other_ops.py
mindspore/ops/operations/other_ops.py
+41
-0
tests/ut/cpp/operator/ops_test.cc
tests/ut/cpp/operator/ops_test.cc
+1
-1
tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
+2
-2
tests/ut/cpp/python_input/gtest_input/pre_activate/eliminate_redundant_op_test.py
...t/gtest_input/pre_activate/eliminate_redundant_op_test.py
+1
-1
tests/ut/cpp/python_input/gtest_input/pre_activate/fused_batch_norm_fusion_test.py
.../gtest_input/pre_activate/fused_batch_norm_fusion_test.py
+7
-7
tests/ut/cpp/python_input/gtest_input/pre_activate/hw_opt_test.py
.../cpp/python_input/gtest_input/pre_activate/hw_opt_test.py
+1
-1
tests/ut/cpp/python_input/gtest_input/pre_activate/mixed_precision_test.py
...on_input/gtest_input/pre_activate/mixed_precision_test.py
+1
-1
tests/ut/cpp/python_input/gtest_input/pre_activate/optimize_dependence_test.py
...nput/gtest_input/pre_activate/optimize_dependence_test.py
+1
-1
tests/ut/python/ops/test_nn_ops.py
tests/ut/python/ops/test_nn_ops.py
+2
-2
tests/ut/python/pynative_mode/ops/test_hypermap.py
tests/ut/python/pynative_mode/ops/test_hypermap.py
+1
-1
tests/vm_impl/array_ops_vm_impl.py
tests/vm_impl/array_ops_vm_impl.py
+18
-0
未找到文件。
mindspore/_extends/builtin_operations.py
浏览文件 @
38436f92
...
...
@@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
"""builtin_operations"""
import
functools
import
numpy
as
np
from
mindspore.common.tensor
import
Tensor
from
mindspore.common.dtype
import
dtype_to_nptype
,
get_py_obj_dtype
...
...
@@ -124,17 +123,8 @@ def list_len(x):
"""Implement `list_len`."""
return
len
(
x
)
# only used in PyNative mode
def
partial
(
*
args
):
"""Implement `partial`."""
func
=
args
[
0
].
__call__
partial_func
=
functools
.
partial
(
func
,
*
args
[
1
:])
return
partial_func
# only used in PyNative mode
def
depend
(
value
,
expr
):
def
Depend
(
value
,
expr
):
"""Implement `Depend`."""
return
value
# only used in PyNative mode
...
...
mindspore/ccsrc/ir/primitive.h
浏览文件 @
38436f92
...
...
@@ -49,6 +49,8 @@ class PrimitivePy : public Primitive {
void
AddPyAttr
(
const
py
::
str
&
name
,
const
py
::
object
&
obj
);
py
::
dict
GetAttrDict
();
void
set_hook
(
const
py
::
function
&
hook
)
{
hook_
=
hook
;
}
py
::
function
hook
()
const
{
return
hook_
;
}
const
bool
parse_info_
=
true
;
const
py
::
object
&
GetPyObj
()
const
{
return
python_obj_
;
}
...
...
@@ -56,6 +58,7 @@ class PrimitivePy : public Primitive {
private:
py
::
object
python_obj_
;
py
::
function
hook_
;
std
::
vector
<
Signature
>
signatures_
;
};
...
...
mindspore/ccsrc/ir/primitive_base.h
浏览文件 @
38436f92
...
...
@@ -89,9 +89,6 @@ class Primitive : public Named {
return
iter
==
attrs_
.
cend
()
?
nullptr
:
iter
->
second
;
}
void
set_hook
(
const
py
::
function
&
hook
)
{
hook_
=
hook
;
}
py
::
function
hook
()
const
{
return
hook_
;
}
const
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
&
attrs
()
const
{
return
attrs_
;
}
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
&
evaluate_added_attrs
()
{
return
evaluate_added_attrs_
;
}
...
...
@@ -124,7 +121,6 @@ class Primitive : public Named {
private:
std
::
string
instance_name_
;
py
::
function
hook_
;
bool
is_base_
;
bool
has_signature_
;
PrimType
prim_type_
;
...
...
mindspore/ccsrc/operator/ops.cc
浏览文件 @
38436f92
...
...
@@ -220,7 +220,7 @@ const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
// Other miscellaneous
const
PrimitivePtr
kPrimIdentity
=
std
::
make_shared
<
Primitive
>
(
"identity"
);
const
PrimitivePtr
kPrimPartial
=
std
::
make_shared
<
Primitive
>
(
"
p
artial"
);
const
PrimitivePtr
kPrimPartial
=
std
::
make_shared
<
Primitive
>
(
"
P
artial"
);
const
PrimitivePtr
kPrimJ
=
std
::
make_shared
<
Primitive
>
(
"J"
);
const
PrimitivePtr
kPrimEnvSetItem
=
std
::
make_shared
<
Primitive
>
(
"env_setitem"
);
const
PrimitivePtr
kPrimEnvGetItem
=
std
::
make_shared
<
Primitive
>
(
"env_getitem"
);
...
...
@@ -237,7 +237,7 @@ const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
const
PrimitivePtr
kPrimPrint
=
std
::
make_shared
<
Primitive
>
(
"Print"
);
const
PrimitivePtr
kPrimMakeRef
=
std
::
make_shared
<
Primitive
>
(
"make_ref"
);
const
PrimitivePtr
kPrimDepend
=
std
::
make_shared
<
Primitive
>
(
"
d
epend"
);
const
PrimitivePtr
kPrimDepend
=
std
::
make_shared
<
Primitive
>
(
"
D
epend"
);
const
PrimitivePtr
kPrimStateSetItem
=
std
::
make_shared
<
Primitive
>
(
"state_setitem"
);
const
PrimitivePtr
kPrimBroadcastGradientArgs
=
std
::
make_shared
<
Primitive
>
(
"BroadcastGradientArgs"
);
...
...
mindspore/ccsrc/optimizer/ad/kprim.cc
浏览文件 @
38436f92
...
...
@@ -238,8 +238,12 @@ FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::Res
auto
func_graph
=
std
::
make_shared
<
FuncGraph
>
();
std
::
vector
<
AnfNodePtr
>
outputs
;
auto
bprop_cut
=
std
::
make_shared
<
Primitive
>
(
"bprop_cut"
);
bprop_cut
->
set_hook
(
prim
->
hook
());
auto
bprop_cut
=
std
::
make_shared
<
PrimitivePy
>
(
"bprop_cut"
,
py
::
object
());
if
(
!
prim
->
is_base
())
{
PrimitivePyPtr
prim_py
=
dyn_cast
<
PrimitivePy
>
(
prim
);
bprop_cut
->
set_hook
(
prim_py
->
hook
());
}
auto
cell_id
=
GetValue
<
std
::
string
>
(
prim
->
GetAttr
(
"cell_id"
));
if
(
cell_id
!=
""
)
{
(
void
)
bprop_cut
->
AddAttr
(
"cell_hook"
,
MakeValue
(
true
));
...
...
mindspore/ccsrc/parallel/ops_info/ops_utils.h
浏览文件 @
38436f92
...
...
@@ -72,7 +72,7 @@ constexpr char OP[] = "op";
constexpr
char
IDENTITY_INFO
[]
=
"identity_info"
;
constexpr
char
DIVISOR
[]
=
"divisor"
;
constexpr
char
NONE
[]
=
"None"
;
constexpr
char
DEPEND
[]
=
"
d
epend"
;
constexpr
char
DEPEND
[]
=
"
D
epend"
;
constexpr
char
BATCH_PARALLEL
[]
=
"BatchParallel"
;
constexpr
char
ACTIVATION_TYPE
[]
=
"activation_type"
;
...
...
mindspore/ccsrc/pipeline/parse/data_converter.cc
浏览文件 @
38436f92
...
...
@@ -217,7 +217,7 @@ FuncGraphPtr ConvertToBpropCut(py::object obj) {
FuncGraphPtr
bprop_graph
=
std
::
make_shared
<
FuncGraph
>
();
std
::
vector
<
AnfNodePtr
>
outputs
;
auto
fake_bprop
=
std
::
make_shared
<
Primitive
>
(
"bprop_cut"
);
auto
fake_bprop
=
std
::
make_shared
<
Primitive
Py
>
(
"bprop_cut"
,
py
::
object
()
);
fake_bprop
->
set_hook
(
bprop_func
);
(
void
)
fake_bprop
->
AddAttr
(
"bprop"
,
MakeValue
(
true
));
outputs
.
push_back
(
NewValueNode
(
fake_bprop
));
...
...
mindspore/ccsrc/pynative/base.h
浏览文件 @
38436f92
...
...
@@ -59,7 +59,7 @@ struct OpExecInfo {
using
OpExecInfoPtr
=
std
::
shared_ptr
<
OpExecInfo
>
;
OpExecInfoPtr
GenerateOpExecInfo
(
const
py
::
args
&
args
);
const
std
::
set
<
std
::
string
>
ignore_infer_prim
=
{
"
partial"
,
"
make_ref"
};
const
std
::
set
<
std
::
string
>
ignore_infer_prim
=
{
"make_ref"
};
}
// namespace pynative
}
// namespace mindspore
...
...
mindspore/ccsrc/pynative/pynative_execute.cc
浏览文件 @
38436f92
...
...
@@ -53,7 +53,7 @@
const
char
SINGLE_OP_GRAPH
[]
=
"single_op_graph"
;
// primitive unable to infer value for constant input in PyNative mode
const
std
::
set
<
std
::
string
>
vm_operators
=
{
"
partial"
,
"depend"
,
"
make_ref"
,
"HookBackward"
};
const
std
::
set
<
std
::
string
>
vm_operators
=
{
"make_ref"
,
"HookBackward"
};
namespace
mindspore
{
namespace
pynative
{
...
...
mindspore/ccsrc/transform/convert.cc
浏览文件 @
38436f92
...
...
@@ -959,8 +959,8 @@ void DfGraphConvertor::TraceOutput(const AnfNodePtr node) {
for
(
unsigned
int
i
=
1
;
i
<
c
->
inputs
().
size
();
i
++
)
{
TraceOutput
(
c
->
input
(
i
));
}
}
else
if
(
name
==
"
d
epend"
)
{
if
(
c
->
inputs
().
size
()
<
3
)
{
// "
d
epend" primitive have 3 inputs
}
else
if
(
name
==
"
D
epend"
)
{
if
(
c
->
inputs
().
size
()
<
3
)
{
// "
D
epend" primitive have 3 inputs
MS_LOG
(
EXCEPTION
)
<<
"length of inputs is "
<<
c
->
inputs
().
size
()
<<
", which is less than 3"
;
}
TraceOutput
(
c
->
input
(
1
));
...
...
@@ -1183,7 +1183,7 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node
auto
&
inputs
=
node
->
inputs
();
for
(
size_t
i
=
1
;
i
<
inputs
.
size
();
i
++
)
{
auto
pred
=
inputs
[
i
];
while
(
pred
->
isa
<
CNode
>
()
&&
GetCNodeFuncName
(
pred
->
cast
<
CNodePtr
>
())
==
"
d
epend"
)
{
while
(
pred
->
isa
<
CNode
>
()
&&
GetCNodeFuncName
(
pred
->
cast
<
CNodePtr
>
())
==
"
D
epend"
)
{
pred
=
pred
->
cast
<
CNodePtr
>
()
->
input
(
1
);
}
// skip the None input
...
...
@@ -1362,7 +1362,7 @@ AnfNodePtr DfGraphConvertor::TraceTupleGetItem(const CNodePtr &node, unsigned in
AnfNodePtr
DfGraphConvertor
::
TraceDepend
(
const
CNodePtr
&
node
)
{
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
if
(
cnode
->
inputs
().
size
()
<
3
)
{
// "
d
epend" primitive have 3 inputs
if
(
cnode
->
inputs
().
size
()
<
3
)
{
// "
D
epend" primitive have 3 inputs
MS_LOG
(
EXCEPTION
)
<<
"length of inputs of depend is less than 3"
;
}
return
cnode
->
inputs
()[
1
];
...
...
@@ -1483,7 +1483,7 @@ AnfNodePtr DfGraphConvertor::GetRealOpNode(AnfNodePtr node) {
// depend apply inputs: depend,output,depended_node
if
(
IsPrimitiveCNode
(
node
,
prim
::
kPrimDepend
))
{
auto
depend_inputs
=
node
->
cast
<
CNodePtr
>
()
->
inputs
();
if
(
depend_inputs
.
size
()
!=
3
)
{
// "
d
epend" primitive have 3 inputs
if
(
depend_inputs
.
size
()
!=
3
)
{
// "
D
epend" primitive have 3 inputs
MS_LOG
(
ERROR
)
<<
"depend input items not correct"
;
error_
=
FAILED
;
return
node
;
...
...
@@ -1700,7 +1700,7 @@ void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) {
bool
DfGraphConvertor
::
CheckCNode
(
const
std
::
string
&
name
,
const
CNodePtr
node
)
{
// ignore apply node of return
if
(
name
==
"return"
||
name
==
"
d
epend"
)
{
if
(
name
==
"return"
||
name
==
"
D
epend"
)
{
return
false
;
}
...
...
mindspore/ccsrc/vm/vm.cc
浏览文件 @
38436f92
...
...
@@ -585,8 +585,8 @@ void FinalVM::InstPushPrim(const VectorRef &args) {
return
;
}
VectorRef
tuple
;
auto
prim
=
utils
::
cast
<
PrimitivePtr
>
(
args
[
0
]);
VectorRef
tuple
;
for
(
size_t
i
=
1
;
i
<
args
.
size
();
++
i
)
{
auto
index
=
utils
::
cast
<
int
>
(
args
[
i
]);
tuple
.
push_back
(
Ref
(
index
));
...
...
@@ -618,6 +618,7 @@ void FinalVM::SyncData(const py::object &arg) {
BaseRef
FinalVM
::
RunHook
(
const
PrimitivePtr
&
prim
,
const
VectorRef
&
args
)
{
MS_LOG
(
DEBUG
)
<<
"input for operation:"
;
auto
prim_py
=
dyn_cast
<
PrimitivePy
>
(
prim
);
std
::
size_t
args_size
=
args
.
size
();
auto
py_args
=
py
::
tuple
(
args_size
);
size_t
i
=
0
;
...
...
@@ -631,7 +632,7 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
bool
is_bprop
=
prim
->
HasAttr
(
"bprop"
);
if
(
is_bprop
)
{
SyncData
(
py_args
);
py
::
function
fn_bprop
=
prim
->
hook
();
py
::
function
fn_bprop
=
prim
_py
->
hook
();
obj
=
fn_bprop
(
*
py_args
);
return
obj
;
}
...
...
@@ -647,7 +648,7 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
hook_args
[
0
]
=
cell_id
;
hook_args
[
1
]
=
py
::
make_tuple
(
_hook_grad
[
cell_id
]);
hook_args
[
2
]
=
py
::
make_tuple
(
py_args
[
2
]);
py
::
function
fn_hook
=
prim
->
hook
();
py
::
function
fn_hook
=
prim
_py
->
hook
();
obj
=
fn_hook
(
*
hook_args
);
if
(
py
::
isinstance
<
py
::
none
>
(
obj
))
{
obj
=
py_args
[
2
];
...
...
@@ -659,7 +660,7 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
}
}
else
{
// Hook operator for execute variable hook function
py
::
function
fn_hook
=
prim
->
hook
();
py
::
function
fn_hook
=
prim
_py
->
hook
();
obj
=
fn_hook
(
py
::
make_tuple
(
py_args
[
2
]));
if
(
py
::
isinstance
<
py
::
none
>
(
obj
))
{
obj
=
py_args
[
2
];
...
...
mindspore/common/tensor.py
浏览文件 @
38436f92
...
...
@@ -78,6 +78,8 @@ class Tensor(Tensor_):
def
__eq__
(
self
,
other
):
if
not
isinstance
(
other
,
Tensor
):
return
False
# The GE backend don't support single `Equal` operator execution.
# bool type is not supported for `Equal` operator in backend.
if
context
.
get_context
(
"enable_ge"
)
or
self
.
dtype
()
==
mstype
.
bool_
or
other
.
dtype
()
==
mstype
.
bool_
:
return
Tensor
(
np
.
array
(
self
.
asnumpy
()
==
other
.
asnumpy
()))
return
tensor_operator_registry
.
get
(
'__eq__'
)(
self
,
other
)
...
...
mindspore/ops/_grad/grad_implementations.py
浏览文件 @
38436f92
...
...
@@ -195,7 +195,7 @@ def bprop_array_reduce(fn, x, shp, out, dout):
return
F
.
distribute
(
dout
,
F
.
shape
(
x
)),
C
.
zeros_like
(
shp
)
@
bprops
.
register
(
"
d
epend"
)
@
bprops
.
register
(
"
D
epend"
)
def
bprop_depend
(
x
,
y
,
out
,
dout
):
"""Backpropagator for primitive `depend`."""
return
dout
,
C
.
zeros_like
(
y
)
...
...
@@ -236,7 +236,6 @@ def bprop_control_depend(x, y, out, dout):
"""Backpropagator for primitive `Control_depend`."""
return
C
.
zeros_like
(
x
),
C
.
zeros_like
(
y
)
@
bprops
.
register
(
"switch"
)
def
bprop_switch
(
cond
,
tb
,
fb
,
out
,
dout
):
"""Backpropagator for primitive `switch`."""
...
...
mindspore/ops/composite/base.py
浏览文件 @
38436f92
...
...
@@ -22,7 +22,7 @@ from mindspore import context
from
..._c_expression
import
EnvInstance_
,
GradOperation_
,
HyperMap_
,
MultitypeFuncGraph_
,
Tail_
,
TensorSlice_
,
\
TupleAdd_
,
TupleSlice_
,
UnpackCall_
,
ZipOperation_
,
ListAppend_
,
TupleGetItemTensor_
from
...common
import
dtype
as
mstype
from
...common.api
import
ms_function
,
_pynative_exec
from
...common.api
import
ms_function
,
_pynative_exec
,
_wrap_func
from
..
import
functional
as
F
from
...common.parameter
import
Parameter
...
...
@@ -117,6 +117,7 @@ class GradOperation(GradOperation_):
def
after_grad
(
*
args
):
return
grad_
(
fn
,
weights
)(
*
args
)
else
:
@
_wrap_func
def
after_grad
(
*
args
):
if
fn
.
is_run
and
not
fn
.
requires_grad
:
raise
ValueError
(
"obj must set_grad."
)
...
...
mindspore/ops/functional.py
浏览文件 @
38436f92
...
...
@@ -77,6 +77,9 @@ gather_nd = P.GatherNd()
scatter_update
=
P
.
ScatterUpdate
()
scatter_nd_update
=
P
.
ScatterNdUpdate
()
pack
=
P
.
Pack
()
partial
=
P
.
Partial
()
# depend: mount a node to another node
depend
=
P
.
Depend
()
tuple_setitem
=
Primitive
(
'tuple_setitem'
)
...
...
@@ -131,12 +134,9 @@ mixed_precision_cast = Primitive("mixed_precision_cast")
broadcast_gradient_args
=
Primitive
(
'BroadcastGradientArgs'
)
dot
=
Primitive
(
'dot'
)
array_reduce
=
Primitive
(
'array_reduce'
)
partial
=
Primitive
(
'partial'
)
zeros_like
=
P
.
ZerosLike
()
identity
=
Primitive
(
'identity'
)
distribute
=
Primitive
(
'distribute'
)
# depend: mount a node to another node
depend
=
Primitive
(
'depend'
)
embed
=
Primitive
(
'embed'
)
ref_to_embed
=
_grad_ops
.
RefToEmbed
()
env_setitem
=
Primitive
(
'env_setitem'
)
...
...
mindspore/ops/operations/__init__.py
浏览文件 @
38436f92
...
...
@@ -74,7 +74,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
ApplyProximalAdagrad
,
SparseApplyProximalAdagrad
,
ApplyRMSProp
,
ApplyCenteredRMSProp
,
BasicLSTMCell
)
from
.other_ops
import
(
Assign
,
IOU
,
BoundingBoxDecode
,
BoundingBoxEncode
,
CheckValid
,
MakeRefKey
,
CheckBprop
,
ConfusionMatrix
)
CheckValid
,
MakeRefKey
,
Partial
,
Depend
,
CheckBprop
,
ConfusionMatrix
)
from
.
import
_quant_ops
from
._quant_ops
import
*
from
.thor_ops
import
*
...
...
@@ -213,6 +213,8 @@ __all__ = [
'NMSWithMask'
,
'IOU'
,
'MakeRefKey'
,
'Partial'
,
'Depend'
,
'AvgPool'
,
# Back Primitive
'Equal'
,
...
...
mindspore/ops/operations/other_ops.py
浏览文件 @
38436f92
...
...
@@ -14,6 +14,7 @@
# ============================================================================
"""Other operators."""
import
functools
from
..._c_expression
import
signature_rw
as
sig_rw
from
..._c_expression
import
signature_kind
as
sig_kind
from
..._c_expression
import
signature_dtype
as
sig_dtype
...
...
@@ -304,6 +305,46 @@ class MakeRefKey(Primitive):
pass
class
Partial
(
Primitive
):
"""
Make a partial function instance, used for pynative mode.
Inputs:
- **args** (Union[FunctionType, Tensor]) - The function and bind arguments.
Outputs:
FunctionType, partial function binded with arguments.
"""
@
prim_attr_register
def
__init__
(
self
):
pass
def
__call__
(
self
,
*
args
):
func
=
args
[
0
].
__call__
partial_func
=
functools
.
partial
(
func
,
*
args
[
1
:])
return
partial_func
class
Depend
(
Primitive
):
"""
Depend is used for process side-effect operations.
Inputs:
- **value** (Tensor) - the real value to return for depend operator.
- **expr** (Expression) - the expression to execute with no outputs.
Outputs:
Tensor, the value passed by last operator.
"""
@
prim_attr_register
def
__init__
(
self
):
pass
def
__call__
(
self
,
value
,
expr
):
return
value
class
CheckBprop
(
PrimitiveWithInfer
):
"""
Checks whether data type and shape of corresponding element from tuple x and y are the same.
...
...
tests/ut/cpp/operator/ops_test.cc
浏览文件 @
38436f92
...
...
@@ -341,7 +341,7 @@ TEST_F(TestOps, ResolveTest) {
}
TEST_F
(
TestOps
,
PartialTest
)
{
auto
prim
=
std
::
make_shared
<
Primitive
>
(
"
p
artial"
);
auto
prim
=
std
::
make_shared
<
Primitive
>
(
"
P
artial"
);
ASSERT_EQ
(
prim
->
name
(),
kPrimPartial
->
name
());
}
...
...
tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
浏览文件 @
38436f92
...
...
@@ -636,7 +636,7 @@ def test_tuple_get_set_item(tag):
def
test_partial
(
tag
):
""" test_partial """
fns
=
FnDict
()
partail
=
P
rimitive
(
'partial'
)
partail
=
P
.
Partial
(
)
def
f
(
x
,
y
):
return
scalar_add
(
x
,
y
)
...
...
@@ -655,7 +655,7 @@ def test_partial(tag):
def
test_replace_applicator
(
tag
):
""" test_replace_applicator """
fns
=
FnDict
()
partail
=
P
rimitive
(
'partial'
)
partail
=
P
.
Partial
(
)
def
app1
(
x
,
y
):
return
scalar_add
(
x
,
y
)
...
...
tests/ut/cpp/python_input/gtest_input/pre_activate/eliminate_redundant_op_test.py
浏览文件 @
38436f92
...
...
@@ -22,7 +22,7 @@ four2five = Primitive('Four2Five')
five2four
=
Primitive
(
'Five2Four'
)
transdata
=
Primitive
(
"TransData"
)
cast
=
Primitive
(
'Cast'
)
depend
=
P
rimitive
(
'depend'
)
depend
=
P
.
Depend
(
)
class
FnDict
:
...
...
tests/ut/cpp/python_input/gtest_input/pre_activate/fused_batch_norm_fusion_test.py
浏览文件 @
38436f92
...
...
@@ -16,13 +16,13 @@ import mindspore.common.dtype as mstype
from
mindspore.common.tensor
import
Tensor
from
mindspore.ops
import
Primitive
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
functional
as
F
AssignSub
=
P
.
AssignSub
()
Mul
=
P
.
Mul
()
Sub
=
P
.
Sub
()
make_tuple
=
Primitive
(
'make_tuple'
)
tuple_getitem
=
Primitive
(
'tuple_getitem'
)
depend
=
Primitive
(
'depend'
)
BatchNorm
=
P
.
BatchNorm
()
Cast
=
P
.
Cast
()
BNTrainingReduce
=
Primitive
(
'BNTrainingReduce'
)
...
...
@@ -54,8 +54,8 @@ def test_fused_batch_norm_fusion(tag):
mul1
=
Mul
(
sub1
,
constant1
)
assign_sub0
=
AssignSub
(
var0
,
mul0
)
assign_sub1
=
AssignSub
(
var1
,
mul1
)
depend0
=
depend
(
tuple_getitem
(
batch_norm
,
0
),
assign_sub0
)
depend1
=
depend
(
depend0
,
assign_sub1
)
depend0
=
F
.
depend
(
tuple_getitem
(
batch_norm
,
0
),
assign_sub0
)
depend1
=
F
.
depend
(
depend0
,
assign_sub1
)
outputs
=
make_tuple
(
depend1
,
tuple_getitem
(
batch_norm
,
3
),
tuple_getitem
(
batch_norm
,
4
))
output
=
tuple_getitem
(
outputs
,
0
)
return
output
...
...
@@ -69,8 +69,8 @@ def test_fused_batch_norm_fusion(tag):
mul1
=
Mul
(
sub1
,
constant1
)
assign_sub0
=
AssignSub
(
var0
,
Cast
(
mul0
,
mstype
.
float32
))
assign_sub1
=
AssignSub
(
var1
,
Cast
(
mul1
,
mstype
.
float32
))
depend0
=
depend
(
tuple_getitem
(
batch_norm
,
0
),
assign_sub0
)
depend1
=
depend
(
depend0
,
assign_sub1
)
depend0
=
F
.
depend
(
tuple_getitem
(
batch_norm
,
0
),
assign_sub0
)
depend1
=
F
.
depend
(
depend0
,
assign_sub1
)
outputs
=
make_tuple
(
depend1
,
tuple_getitem
(
batch_norm
,
3
),
tuple_getitem
(
batch_norm
,
4
))
output
=
tuple_getitem
(
outputs
,
0
)
return
output
...
...
@@ -84,8 +84,8 @@ def test_fused_batch_norm_fusion(tag):
mul1
=
Mul
(
Cast
(
sub1
,
mstype
.
float32
),
constant1
)
assign_sub0
=
AssignSub
(
var0
,
mul0
)
assign_sub1
=
AssignSub
(
var1
,
mul1
)
depend0
=
depend
(
tuple_getitem
(
batch_norm
,
0
),
assign_sub0
)
depend1
=
depend
(
depend0
,
assign_sub1
)
depend0
=
F
.
depend
(
tuple_getitem
(
batch_norm
,
0
),
assign_sub0
)
depend1
=
F
.
depend
(
depend0
,
assign_sub1
)
outputs
=
make_tuple
(
depend1
,
tuple_getitem
(
batch_norm
,
3
),
tuple_getitem
(
batch_norm
,
4
))
output
=
tuple_getitem
(
outputs
,
0
)
return
output
...
...
tests/ut/cpp/python_input/gtest_input/pre_activate/hw_opt_test.py
浏览文件 @
38436f92
...
...
@@ -16,7 +16,7 @@ from mindspore.ops import Primitive
from
mindspore.ops
import
operations
as
P
tuple_getitem
=
Primitive
(
'tuple_getitem'
)
depend
=
P
rimitive
(
'depend'
)
depend
=
P
.
Depend
(
)
addn
=
P
.
AddN
()
add
=
P
.
TensorAdd
()
sub
=
P
.
Sub
()
...
...
tests/ut/cpp/python_input/gtest_input/pre_activate/mixed_precision_test.py
浏览文件 @
38436f92
...
...
@@ -16,7 +16,7 @@ from mindspore.ops import Primitive
from
mindspore.ops
import
operations
as
P
tuple_getitem
=
Primitive
(
'tuple_getitem'
)
depend
=
P
rimitive
(
'depend'
)
depend
=
P
.
Depend
(
)
addn
=
P
.
AddN
()
add
=
P
.
TensorAdd
()
sub
=
P
.
Sub
()
...
...
tests/ut/cpp/python_input/gtest_input/pre_activate/optimize_dependence_test.py
浏览文件 @
38436f92
...
...
@@ -15,7 +15,7 @@
from
mindspore.ops
import
Primitive
from
mindspore.ops
import
operations
as
P
depend
=
P
rimitive
(
'depend'
)
depend
=
P
.
Depend
(
)
TransData
=
Primitive
(
'TransData'
)
add
=
P
.
TensorAdd
()
make_tuple
=
Primitive
(
'make_tuple'
)
...
...
tests/ut/python/ops/test_nn_ops.py
浏览文件 @
38436f92
...
...
@@ -20,9 +20,9 @@ import mindspore.context as context
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
,
Parameter
from
mindspore.common.initializer
import
initializer
from
mindspore.ops
import
Primitive
from
mindspore.ops
import
composite
as
C
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
functional
as
F
from
mindspore.ops.operations
import
_grad_ops
as
G
from
mindspore.ops
import
prim_attr_register
,
PrimitiveWithInfer
from
..ut_filter
import
non_graph_engine
...
...
@@ -358,7 +358,7 @@ class StateNet(nn.Cell):
self
.
assign
=
P
.
Assign
()
def
construct
(
self
,
x
):
x
=
Primitive
(
'depend'
)
(
x
,
self
.
assign
(
self
.
s1
,
x
+
self
.
s1
))
x
=
F
.
depend
(
x
,
self
.
assign
(
self
.
s1
,
x
+
self
.
s1
))
self
.
s1
=
self
.
sub
(
self
.
s1
,
x
)
self
.
s2
=
self
.
sub
(
self
.
s2
,
x
)
return
x
...
...
tests/ut/python/pynative_mode/ops/test_hypermap.py
浏览文件 @
38436f92
...
...
@@ -132,7 +132,7 @@ def test_hypermap_add3_easy():
add3
=
C
.
MultitypeFuncGraph
(
'add'
)
partial
=
P
rimitive
(
'partial'
)
partial
=
P
.
Partial
(
)
@
add3
.
register
(
"Number"
,
"Number"
,
"Number"
)
...
...
tests/vm_impl/array_ops_vm_impl.py
浏览文件 @
38436f92
...
...
@@ -284,3 +284,21 @@ def vm_impl_zeros_like(self):
"""Generate vm_impl function for ZerosLike"""
def
vm_impl
(
x
):
return
Tensor
(
np
.
zeros_like
(
x
.
asnumpy
()))
@
vm_impl_getters
.
register
(
P
.
Partial
)
def
vm_impl_partial
(
self
):
"""Generate vm_impl function for Partial"""
def
vm_impl
(
*
args
):
func
=
args
[
0
].
__call__
partial_func
=
functools
.
partial
(
func
,
*
args
[
1
:])
return
partial_func
return
vm_impl
@
vm_impl_getters
.
register
(
P
.
Depend
)
def
vm_impl_depend
(
self
):
"""Generate vm_impl function for Depend"""
def
vm_impl
(
value
,
expr
):
return
value
return
vm_impl
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录