Commit 45a1df06 — magicwindyyd/mindspore (fork of MindSpore/mindspore)

Authored June 01, 2020 by mindspore-ci-bot; committed via Gitee on June 01, 2020.

!1421 Modify code to support dynamic graph.

Merge pull request !1421 from rick_sanchez/pynative_new_4

Parents: 05177ff9, e2a322b6

Showing 34 changed files with 673 additions and 72 deletions (+673 −72).
Changed files:

  mindspore/_extends/parse/__init__.py                 (+7   −6)
  mindspore/_extends/parse/parser.py                   (+11  −0)
  mindspore/ccsrc/ir/manager.cc                        (+1   −5)
  mindspore/ccsrc/ir/param_value_py.h                  (+1   −1)
  mindspore/ccsrc/ir/primitive.cc                      (+1   −1)
  mindspore/ccsrc/ir/tensor.cc                         (+4   −0)
  mindspore/ccsrc/ir/tensor.h                          (+2   −0)
  mindspore/ccsrc/operator/composite/composite.cc      (+8   −2)
  mindspore/ccsrc/operator/composite/composite.h       (+1   −1)
  mindspore/ccsrc/operator/composite/do_signature.cc   (+5   −0)
  mindspore/ccsrc/optimizer/ad/dfunctor.cc             (+4   −0)
  mindspore/ccsrc/optimizer/ad/dfunctor.h              (+2   −1)
  mindspore/ccsrc/optimizer/ad/grad.cc                 (+10  −4)
  mindspore/ccsrc/optimizer/ad/grad.h                  (+2   −1)
  mindspore/ccsrc/optimizer/irpass/inline.h            (+2   −1)
  mindspore/ccsrc/pipeline/action.cc                   (+2   −0)
  mindspore/ccsrc/pipeline/action.h                    (+1   −0)
  mindspore/ccsrc/pipeline/parse/data_converter.cc     (+4   −0)
  mindspore/ccsrc/pipeline/parse/parse_base.h          (+2   −0)
  mindspore/ccsrc/pipeline/pass.cc                     (+2   −0)
  mindspore/ccsrc/pipeline/pass.h                      (+1   −0)
  mindspore/ccsrc/pipeline/pipeline.cc                 (+7   −2)
  mindspore/ccsrc/pipeline/pipeline.h                  (+2   −0)
  mindspore/ccsrc/pynative/pynative_execute.cc         (+430 −15)
  mindspore/ccsrc/pynative/pynative_execute.h          (+70  −0)
  mindspore/common/api.py                              (+30  −1)
  mindspore/common/tensor.py                           (+0   −7)
  mindspore/nn/cell.py                                 (+29  −4)
  mindspore/nn/wrap/cell_wrapper.py                    (+1   −0)
  mindspore/ops/composite/base.py                      (+26  −4)
  mindspore/ops/operations/debug_ops.py                (+0   −6)
  mindspore/ops/primitive.py                           (+2   −8)
  tests/ut/python/pynative_mode/nn/test_dropout.py     (+1   −1)
  tests/vm_impl/array_ops_vm_impl.py                   (+2   −1)
mindspore/_extends/parse/__init__.py  (+7 −6)

```diff
@@ -19,14 +19,15 @@ Interfaces for parser module in c++.
 from .parser import (Parser, create_obj_instance, generate_scope, get_bprop_method_of_class,
                      get_class_instance_type, get_class_member_namespace_symbol, create_slice_obj,
-                     get_dataclass_attributes, get_dataclass_methods,
+                     get_dataclass_attributes, get_dataclass_methods, get_obj_id,
                      get_module_namespace, get_obj_type, get_object_key,
-                     get_parse_method_of_class, get_scope_name,
+                     get_default_input, get_parse_method_of_class, get_scope_name,
                      is_class_member, parse_cb, resolve_symbol, create_ellipsis_obj)
 from .serialize import *

-__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
-           'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_obj_type',
-           'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol',
-           'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj',
-           'get_dataclass_methods', 'get_scope_name', 'create_slice_obj', 'create_ellipsis_obj']
+__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
+           'get_object_key', 'get_default_input', 'get_class_instance_type', 'is_class_member',
+           'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace',
+           'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes',
+           'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods',
+           'get_scope_name', 'create_slice_obj', 'create_ellipsis_obj']
```
mindspore/_extends/parse/parser.py  (+11 −0)

```diff
@@ -209,6 +209,14 @@ def get_object_key(obj):
             obj_id = instance_id + obj_id
     return obj_id, obj_key


+def get_default_input(obj):
+    if hasattr(obj, '__parameter__'):
+        return obj.default_input
+    if isinstance(obj, tuple):
+        convert = lambda x: x.default_input if hasattr(x, '__parameter__') else x
+        args = tuple(convert(x) for x in obj)
+        return args
+    return obj
+
+
 def is_class_member(node):
     """Check the attr is class member variable."""
@@ -221,6 +229,9 @@ def is_class_member(node):
             return True
     return False


+def get_obj_id(obj):
+    """Get the obj id."""
+    return str(id(obj))
+
+
 def get_obj_type(obj):
     """Get the obj type."""
```
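For orientation, here is how the new `get_default_input` helper behaves, condensed from the hunk above into a standalone sketch. `MockParameter` is a hypothetical stand-in for MindSpore's `Parameter`, which carries the `__parameter__` marker and a `default_input` value:

```python
class MockParameter:
    """Hypothetical stand-in for mindspore's Parameter (not a real class)."""
    __parameter__ = True            # marker attribute probed via hasattr()

    def __init__(self, value):
        self.default_input = value


def get_default_input(obj):
    """Condensed from the diff: unwrap Parameter-like objects to default_input."""
    if hasattr(obj, '__parameter__'):
        return obj.default_input
    if isinstance(obj, tuple):
        return tuple(x.default_input if hasattr(x, '__parameter__') else x for x in obj)
    return obj


print(get_default_input(MockParameter(3)))       # 3
print(get_default_input((MockParameter(1), 2)))  # (1, 2)
print(get_default_input('plain'))                # 'plain'
```

Together with `get_obj_id`, this is what lets the C++ side (see `parse_base.h` below) fetch a Parameter's current value and a stable identity for recorded objects.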
mindspore/ccsrc/ir/manager.cc  (+1 −5)

```diff
@@ -328,9 +328,6 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E
     DropEdge(node, index, inp);
   } else {
     MS_LOG(DEBUG) << "Add node " << node->ToString() << " input[" << index << "] " << inp->ToString();
-    if (inp->func_graph() != nullptr) {
-      AddFuncGraph(inp->func_graph());
-    }
     if (IsValueNode<FuncGraph>(inp)) {
       MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString();
       AddFuncGraph(GetValueNode<FuncGraphPtr>(inp));
@@ -372,9 +369,8 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr> &nodes) {
   for (auto &node : acq) {
     MS_EXCEPTION_IF_NULL(node);
-    FuncGraphPtr fg = node->func_graph();
+    auto fg = node->func_graph();
     if (fg != nullptr) {
       AddFuncGraph(fg);
       fg->AddNode(node);
     }
     ProcessInputs(node, kIncEdge);
```
mindspore/ccsrc/ir/param_value_py.h  (+1 −1)

```diff
@@ -28,7 +28,7 @@ namespace py = pybind11;
 class ParamValuePy : public ParamValue {
  public:
   ParamValuePy() : value_(py::none()) {}
-  explicit ParamValuePy(py::object value) : value_(value) {}
+  explicit ParamValuePy(const py::object &value) : value_(value) {}
   ~ParamValuePy() override = default;

   py::object value() { return value_; }
```
mindspore/ccsrc/ir/primitive.cc  (+1 −1)

```diff
@@ -75,7 +75,7 @@ py::function PrimitivePy::GetComputeFunction() {
   py::function vm_fn = get_fn(python_obj_);

   if (py::isinstance<py::none>(vm_fn)) {
-    MS_LOG(DEBUG) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>();
+    MS_LOG(WARNING) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>();
     vm_fn = mindspore::GetComputeFunction(Primitive::name());
   }
   return vm_fn;
```
mindspore/ccsrc/ir/tensor.cc  (+4 −0)

```diff
@@ -81,6 +81,7 @@ Tensor::Tensor(const Tensor &tensor, const TypePtr &data_type)
     : MetaTensor(tensor), device_address_(tensor.device_address_) {
   init(tensor.data_, data_type);
   dirty_ = tensor.is_dirty();
+  id_ = tensor.id();
 }

 Tensor &Tensor::operator=(const Tensor &tensor) {
@@ -89,6 +90,7 @@ Tensor &Tensor::operator=(const Tensor &tensor) {
     dirty_ = tensor.is_dirty();
     device_address_ = tensor.device_address();
     data_ = tensor.data_;
+    id_ = tensor.id();
   }
   return *this;
 }
@@ -208,6 +210,7 @@ void Tensor::init(const py::array &input, const TypeId &data_type) {
     data_ = input;
   }
   dirty_ = true;
+  id_ = std::to_string((uintptr_t)(this));
 }

 void Tensor::init(TypeId data_type, const std::vector<int> &shape, py::array *const data) {
@@ -254,6 +257,7 @@ void Tensor::init(TypeId data_type, const std::vector<int> &shape, py::array *co
       MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << ".";
       break;
   }
+  id_ = std::to_string((uintptr_t)(this));
 }

 TypePtr Tensor::SetDtype(const TypePtr type_ptr) {
```
mindspore/ccsrc/ir/tensor.h  (+2 −0)

```diff
@@ -263,9 +263,11 @@ class Tensor : public MetaTensor {
   DeviceAddressPtr device_address() const { return device_address_; }
   void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; }
   py::array data_sync();
+  std::string id() const { return id_; }

  private:
   bool dirty_{true};
+  std::string id_{""};
   DeviceAddressPtr device_address_{nullptr};
 };
```
mindspore/ccsrc/operator/composite/composite.cc  (+8 −2)

```diff
@@ -501,10 +501,16 @@ GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_
 }

 FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
-                                    const std::vector<AnfNodePtr> &params_list, bool applyJ) {
+                                    const std::vector<AnfNodePtr> &params_list,
+                                    const std::vector<AnfNodePtr> &args, bool applyJ) {
   FuncGraphPtr ret = std::make_shared<FuncGraph>();
   ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);

+  auto weights_node = weights;
+  if (weights == nullptr && !args.empty()) {
+    weights_node = ret->NewCNode(args);
+  }
+
   ValueNodePtr opsJ = NewValueNode(prim::kPrimJ);
   ValueNodePtr opsTupleItem = NewValueNode(prim::kPrimTupleGetItem);
@@ -537,7 +543,7 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
   inputs.push_back(NewValueNode(1));
   AnfNodePtr ptrBprop = ret->NewCNode(inputs);

-  doGetGrad(ret, out, ptrBprop, weights, opsTupleItem);
+  doGetGrad(ret, out, ptrBprop, weights_node, opsTupleItem);
   return ret;
 }
```
mindspore/ccsrc/operator/composite/composite.h  (+1 −1)

```diff
@@ -129,7 +129,7 @@ class GradOperation : public MetaFuncGraph {
   MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)

   FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector<AnfNodePtr> &ptrParams,
-                       bool applyJ = false);
+                       const std::vector<AnfNodePtr> &args = {}, bool applyJ = false);
   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
   bool sens_param() const { return sens_param_; }
   bool get_all_;
```
mindspore/ccsrc/operator/composite/do_signature.cc  (+5 −0)

```diff
@@ -285,6 +285,10 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
   // and add cast op on other inputs to keep the same type with assigned parameter.
   for (size_t i = 0; i < args_spec_list.size(); ++i) {
     AnfNodePtr param = params_list[i];
+    if (args_spec_list[i] == nullptr) {
+      op_inputs.push_back(param);
+      continue;
+    }
     SignatureEnumRW sig = SignatureEnumRW::kRWDefault;
     // If sig_size is 0 use defalut.
     if (sig_size > 0 && i < sig_size) {
@@ -292,6 +296,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
     } else if (has_var && i >= sig_size) {
       sig = signature[sig_size - 1].rw;
     }
+
     TypePtr type = args_spec_list[i]->GetTypeTrack();
     if (type && type->type_id() == kObjectTypeRef) {
       if (sig == SignatureEnumRW::kRWRead) {
```
mindspore/ccsrc/optimizer/ad/dfunctor.cc  (+4 −0)

```diff
@@ -551,6 +551,10 @@ AdjointPtr DFunctor::FindAdjoint(const AnfNodePtr &primal) {
 }

 void DFunctor::CallDoutHoleOnTape() {
+  if (!is_top_) {
+    return;
+  }
+
   // Call dout hole of all adjoint.
   for (auto &f : func_graph_to_functor_) {
     for (auto &adjoint : f.second->anfnode_to_adjoin_) {
```
mindspore/ccsrc/optimizer/ad/dfunctor.h  (+2 −1)

```diff
@@ -55,6 +55,8 @@ class DFunctor {
   FuncGraphPtr KUserDefined(const FuncGraphPtr &primal);
   // Register functor objects to form a global view.
   void Init(const DFunctorPtr &functor, bool is_top = false);
+  bool IsInScope(const AnfNodePtr &node);
+
   // Clear resources.
   static void Clear();
@@ -62,7 +64,6 @@ class DFunctor {
   // Map one morphism.
   AdjointPtr MapMorphism(const AnfNodePtr &morph);
   bool IsFreeMorphism(const AnfNodePtr &node);
-  bool IsInScope(const AnfNodePtr &node);
   // Map morphism that's not attached to output.
   void MapFreeMorphism();
   void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din);
```
mindspore/ccsrc/optimizer/ad/grad.cc  (+10 −4)

```diff
@@ -23,7 +23,7 @@
 namespace mindspore {
 namespace ad {

-FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources) {
+FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top) {
   MS_EXCEPTION_IF_NULL(func_graph);
   auto gradkv = func_graph->transforms().find("grad");
   if (gradkv != func_graph->transforms().end()) {
@@ -46,14 +46,18 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
   auto user_defined = f->KUserDefined(func_graph);
   if (user_defined != nullptr) {
     multi_graph_sink(user_defined);
-    DFunctor::Clear();
+    if (is_top) {
+      DFunctor::Clear();
+    }
     return user_defined;
   }
-  f->Init(f, true);
+  f->Init(f, is_top);
   f->MapObject();
   f->MapMorphism();
   auto ret = f->k_graph();
-  DFunctor::Clear();
+  if (is_top) {
+    DFunctor::Clear();
+  }
   multi_graph_sink(ret);
   return ret;
@@ -71,5 +75,7 @@ MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePt
   MetaFuncGraphPtr fg = g_k_prims.KMetaFuncGraph(prim);
   return fg;
 }
+
+void CleanRes() { DFunctor::Clear(); }
 }  // namespace ad
 }  // namespace mindspore
```
mindspore/ccsrc/optimizer/ad/grad.h  (+2 −1)

```diff
@@ -28,9 +28,10 @@ namespace mindspore {
 namespace ad {
 using ResourcePtr = std::shared_ptr<pipeline::Resource>;

-FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources);
+FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top = true);
 FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
 MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &);
+void CleanRes();
 }  // namespace ad
 }  // namespace mindspore
```
mindspore/ccsrc/optimizer/irpass/inline.h  (+2 −1)

```diff
@@ -167,7 +167,8 @@ class InlinerBase : public AnfVisitor {
     auto params = fg->parameters();
     auto old_size = params.size();
     if (old_size != new_params.size()) {
-      MS_LOG(EXCEPTION) << "Parameter size not match.";
+      MS_LOG(EXCEPTION) << "Parameter size not match." << old_size << " new " << new_params.size()
+                        << fg->output()->DebugString(10);
     }
     for (size_t i = 0; i < old_size; i++) {
       (void)mng->Replace(params[i], new_params[i]);
```
mindspore/ccsrc/pipeline/action.cc  (+2 −0)

```diff
@@ -276,6 +276,8 @@ bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePa
 bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); }

+bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kPynativePasses); }
+
 static bool IsCtrlSink() {
   auto ms_ctx = MsContext::GetInstance();
   std::string device_target = ms_ctx->device_target();
```
mindspore/ccsrc/pipeline/action.h  (+1 −0)

```diff
@@ -35,6 +35,7 @@ bool SymbolResolveAction(const ResourcePtr &res);
 bool AbstractSpecializeAction(const ResourcePtr &res);
 bool GeOptimizeAction(const ResourcePtr &res);
 bool VmOptimizeAction(const ResourcePtr &res);
+bool PynativeOptimizeAction(const ResourcePtr &res);
 bool TaskEmitAction(const ResourcePtr &res);
 bool ExecuteAction(const ResourcePtr &res);
```
mindspore/ccsrc/pipeline/parse/data_converter.cc  (+4 −0)

```diff
@@ -32,6 +32,7 @@
 #include "utils/symbolic.h"
 #include "utils/context/ms_context.h"
 #include "debug/trace.h"
+#include "optimizer/ad/grad.h"

 namespace mindspore {
 namespace parse {
@@ -338,6 +339,9 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
   } else if (py::hasattr(obj, PYTHON_ENVINSTANCE_FLAG)) {
     std::shared_ptr<EnvInstance> env = obj.cast<std::shared_ptr<EnvInstance>>();
     converted = env;
+  } else if (py::hasattr(obj, "__parameter__")) {
+    auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
+    ret = ConvertData(to_convert, &converted);
   } else {
     ret = ConvertOtherObj(obj, &converted);
   }
```
mindspore/ccsrc/pipeline/parse/parse_base.h  (+2 −0)

```diff
@@ -60,6 +60,7 @@ const char PYTHON_MOD_RESOLVE_FUNCTION[] = "resolve_symbol";
 const char PYTHON_MOD_RESOLVE_GET_OBJ_KEY[] = "get_object_key";
 const char PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER[] = "is_class_member";
 const char PYTHON_MOD_RESOLVE_GET_OBJ_TYPE[] = "get_obj_type";
+const char PYTHON_MOD_GET_OBJ_ID[] = "get_obj_id";
 const char PYTHON_MOD_GET_CLASS_INSTANCE_TYPE[] = "get_class_instance_type";
 const char PYTHON_MOD_CREATE_OBJ_INSTANCE[] = "create_obj_instance";
 const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes";
@@ -83,6 +84,7 @@ const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name";
 const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj";
 const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";
+const char PYTHON_MOD_GET_DEFAULT_INPUT[] = "get_default_input";

 // define the common name
 const char NAMED_PRIMITIVE_ITER[] = "iter";
```
mindspore/ccsrc/pipeline/pass.cc  (+2 −0)

```diff
@@ -278,5 +278,7 @@ std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStru
                                    {"opt_control", ControlGroup},
                                    {"opt_prepare", PrepareGroup},
                                    {"cconv", CconvPass}};
+
+std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}};
 }  // namespace pipeline
 }  // namespace mindspore
```
mindspore/ccsrc/pipeline/pass.h  (+1 −0)

```diff
@@ -29,6 +29,7 @@ using PassItem = std::pair<std::string, std::function<bool(ResourcePtr)>>;
 extern std::vector<PassItem> kGePasses;
 extern std::vector<PassItem> kVmPasses;
+extern std::vector<PassItem> kPynativePasses;

 bool CconvPass(const ResourcePtr &res);
 bool ValidatePass(const ResourcePtr &res);
```
mindspore/ccsrc/pipeline/pipeline.cc  (+7 −2)

```diff
@@ -608,7 +608,7 @@ void Pipeline::Run() {
   MS_LOG(INFO) << "End";
 }

-void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *arg_list) {
+void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *arg_list) {
   std::size_t size = args.size();

   for (std::size_t i = 0; i < size; i++) {
@@ -625,7 +625,6 @@ void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, V
     arg_list->push_back(converted);
   }

-  ResourcePtr res = GetResource(phase);
   MS_EXCEPTION_IF_NULL(res);
   auto graph = res->func_graph();
   MS_EXCEPTION_IF_NULL(graph);
@@ -647,6 +646,10 @@ void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, V
   }
 }

+void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *arg_list) {
+  ProcessVmArgInner(args, GetResource(phase), arg_list);
+}
+
 py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) {
   std::size_t size = args.size();
   if (!py::isinstance<py::str>(phase)) {
@@ -874,6 +877,8 @@ void ClearResAtexit() {
   compile::ClearConvertCache();
   pipeline::GetMethodMap().clear();
   pipeline::ExecutorPy::ClearRes();
   pipeline::ReclaimOptimizer();
+  pynative::PynativeExecutor::GetInstance()->Clean();
+
 #ifdef ENABLE_GE
   transform::DfGraphManager::GetInstance().ClearGraph();
   transform::DfGraphConvertor::get_adpt_map().clear();
```
mindspore/ccsrc/pipeline/pipeline.h  (+2 −0)

```diff
@@ -139,6 +139,8 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
                        const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
                        const std::vector<int64_t> &input_indexes, bool need_run);

+void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *arg_list);
+
 }  // namespace pipeline
 }  // namespace mindspore
```
mindspore/ccsrc/pynative/pynative_execute.cc  (+430 −15)

This diff is collapsed in the viewer; its contents are not included here.
mindspore/ccsrc/pynative/pynative_execute.h  (+70 −0)

```diff
@@ -22,23 +22,93 @@
 #include <string>
 #include <memory>
 #include <unordered_map>
+#include <mutex>
+#include <stack>

 #include "pybind11/pybind11.h"

 #include "pynative/base.h"
 #include "utils/context/ms_context.h"
+#include "ir/anf.h"
+#include "pipeline/resource.h"
+#include "operator/composite/composite.h"

 namespace mindspore {
 namespace pynative {

 namespace py = pybind11;
+using ResourcePtr = std::shared_ptr<pipeline::Resource>;
+using GradOperationPtr = std::shared_ptr<prim::GradOperation>;

 py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);

 py::tuple RunOp(const py::args &args);

+py::list ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args);
+
 void ClearPyNativeSession();

+struct GraphInfo {
+  std::unordered_map<std::string, AnfNodePtr> param_map;
+  std::unordered_map<std::string, std::pair<AnfNodePtr, int>> obj_node_map;
+  AnfNodePtr output;
+  std::vector<std::string> objects;
+};
+
+class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
+ public:
+  static std::shared_ptr<PynativeExecutor> GetInstance() {
+    std::lock_guard<std::mutex> i_lock(instance_lock_);
+    if (executor_ == nullptr) {
+      executor_ = std::shared_ptr<PynativeExecutor>(new (std::nothrow) PynativeExecutor());
+      resource_ = std::make_shared<pipeline::Resource>();
+    }
+    return executor_;
+  }
+  void NewGraph(const py::object &cell, const py::args &args);
+  void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
+  void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
+               const py::args &args);
+  void Clear();
+  void Clean();
+  bool grad_flag() { return grad_flag_; }
+  void set_grad_flag(bool flag) { grad_flag_ = flag; }
+  AnfNodePtr GetInput(const py::object &obj, const py::object &op_mask);
+  AnfNodePtr GetObjNode(const py::object &obj);
+  FuncGraphPtr curr_g() { return curr_g_; }
+  void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); }
+  void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node) {
+    graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, -1);
+  }
+  void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) {
+    graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index);
+  }
+  AnfNodePtr MakeCNode(const py::args &args, const py::tuple &out);
+  py::object Run(const py::tuple &args, const py::object &phase);
+  void Pushp();
+  void Popp();
+  FuncGraphPtr GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights,
+                         size_t arg_size);
+
+  ~PynativeExecutor();
+
+ private:
+  PynativeExecutor();
+  static std::shared_ptr<PynativeExecutor> executor_;
+  static std::mutex instance_lock_;
+  static ResourcePtr resource_;
+  bool grad_flag_;
+  std::unordered_map<std::string, FuncGraphPtr> graph_map_;
+  std::unordered_map<std::string, FuncGraphPtr> cell_graph_map_;
+  std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
+  std::stack<FuncGraphPtr> graph_p_;
+  FuncGraphPtr top_g_;
+  FuncGraphPtr df_builder_;
+  FuncGraphPtr curr_g_;
+};
+
+using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;
+
 }  // namespace pynative
 }  // namespace mindspore
```
mindspore/common/api.py  (+30 −1)

```diff
@@ -20,7 +20,7 @@ from collections import OrderedDict
 from functools import wraps
 from mindspore import context
 from mindspore import log as logger
-from .._c_expression import generate_key, Executor_, Tensor, MetaTensor
+from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, PynativeExecutor_
 from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
 from .tensor import Tensor as MsTensor
@@ -273,6 +273,34 @@ def _generate_pip_args(obj, *args, method="construct"):
     obj.__parse_method__ = parse_method
     return args_names, args_list

+
+class _PynativeExecutor:
+    """
+    An pynative executor used to compile/manage/run graph.
+
+    Returns:
+        Graph, return the result of pipeline running.
+    """
+
+    def __init__(self):
+        self._executor = PynativeExecutor_.get_instance()
+
+    def new_graph(self, obj, *args):
+        self._executor.new_graph(obj, *args)
+
+    def end_graph(self, obj, output, *args):
+        self._executor.end_graph(obj, output, *args)
+
+    def grad(self, grad, obj, weights, *args):
+        self._executor.grad_net(grad, obj, weights, *args)
+
+    def clear(self):
+        self._executor.clear()
+
+    def set_grad_flag(self, flag):
+        self._executor.set_grad_flag(flag)
+
+    def __call__(self, *args):
+        return self._executor(args, "")
+
+
 class _Executor:
     """
@@ -500,5 +528,6 @@ class _Executor:
 _executor = _Executor()
+_pynative_exec = _PynativeExecutor()

 __all__ = ['ms_function']
```
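A sketch of the lifecycle this wrapper exposes. The real callers are the `Cell.__call__` and `GradOperation.__call__` changes later in this commit; user code is not expected to drive `_pynative_exec` directly, so treat this as an internal-behavior illustration, not public API:

```python
# Internal-use sketch (not public API): how other changes in this commit
# drive _pynative_exec when a cell requires gradients in PyNative mode.
from mindspore.common.api import _pynative_exec

def forward_with_recording(cell, *inputs):
    _pynative_exec.set_grad_flag(True)
    _pynative_exec.new_graph(cell, *inputs)          # open a recording graph
    output = cell.construct(*inputs)                 # ops run eagerly and are recorded
    _pynative_exec.end_graph(cell, output, *inputs)  # close and cache the graph
    return output
```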
mindspore/common/tensor.py  (+0 −7)

```diff
@@ -89,7 +89,6 @@ class Tensor(Tensor_):
         return hash(id(self))

     def __mul__(self, other):
-        check_type('tensor input_data', other, (Tensor, float, int))
         out = tensor_operator_registry.get('__mul__')(self, other)
         return out
@@ -101,7 +100,6 @@ class Tensor(Tensor_):
         return out

     def __radd__(self, other):
-        check_type('tensor operation input', other, (Tensor, float, int))
         out = tensor_operator_registry.get('__add__')(other, self)
         return out
@@ -110,22 +108,18 @@ class Tensor(Tensor_):
         return out

     def __rmul__(self, other):
-        check_type('tensor operation input', other, (Tensor, float, int))
         out = tensor_operator_registry.get('__mul__')(other, self)
         return out

     def __truediv__(self, other):
-        check_type('tensor operation input', other, (Tensor, float, int))
         out = tensor_operator_registry.get('__div__')(self, other)
         return out

     def __rtruediv__(self, other):
-        check_type('tensor operation input', other, (Tensor, float, int))
         out = tensor_operator_registry.get('__div__')(other, self)
         return out

     def __sub__(self, other):
-        check_type('tensor operation input', other, (Tensor, float, int))
         out = self.__add__(-other)
         return out
@@ -134,7 +128,6 @@ class Tensor(Tensor_):
         return out

     def __rsub__(self, other):
-        check_type('tensor operation input', other, (Tensor, float, int))
         out = tensor_operator_registry.get('__add__')(other, Tensor(-self.asnumpy()))
         return out
```
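With the eager `check_type` guards gone, the arithmetic dunders hand operands straight to the registered tensor operators, so invalid operands are now rejected by the underlying op rather than by a Python-side check. A quick sanity sketch of the affected operators (standard usage; the error-reporting consequence is my reading of the diff, not stated in it):

```python
import numpy as np
from mindspore import Tensor

a = Tensor(np.ones([2, 2], dtype=np.float32))
b = a * 2.0   # __mul__: dispatches via tensor_operator_registry.get('__mul__')
c = 3.0 / a   # __rtruediv__: registry.get('__div__')(other, self)
d = a - 1.0   # __sub__ is implemented as self.__add__(-other)
```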
mindspore/nn/cell.py  (+29 −4)

```diff
@@ -19,7 +19,7 @@ from collections import OrderedDict
 from mindspore import log as logger
 from .. import context
 from ..common import dtype as mstype
-from ..common.api import _executor
+from ..common.api import _executor, _pynative_exec
 from .._checkparam import _check_str_by_regular
 from ..common.parameter import Parameter, ParameterTuple
 from .._c_expression import init_backend
@@ -60,6 +60,7 @@ class Cell:
         self._params = OrderedDict()
         self._cells = OrderedDict()
         self.training = False
+        self.requires_grad = False
         self.pynative = False
         self._param_prefix = ''
         self._auto_prefix = auto_prefix
@@ -79,6 +80,15 @@ class Cell:
         self._backward_hook = None
         self.enable_hook = False
         self._bprop_debug = False
+        self._is_run = False
+
+    @property
+    def is_run(self):
+        return self._is_run
+
+    @is_run.setter
+    def is_run(self, value):
+        self._is_run = value

     @property
     def create_time(self):
@@ -192,9 +202,20 @@ class Cell:
             out = self.compile_and_run(*inputs)
             return out
         self.init_parameters_data()
-        output = self.construct(*inputs)
+        if self.requires_grad is True:
+            _pynative_exec.set_grad_flag(True)
+            _pynative_exec.new_graph(self, *inputs)
+        else:
+            _pynative_exec.set_grad_flag(False)
+        if self.enable_hook:
+            output = self._hook_construct(*inputs)
+        else:
+            output = self.construct(*inputs)
         if isinstance(output, Parameter):
             output = output.data
+        if self.requires_grad is True:
+            _pynative_exec.end_graph(self, output, *inputs)
+        self._is_run = True
         return output

     def __setattr__(self, name, value):
@@ -722,6 +743,10 @@ class Cell:
         self.add_flags_recursive(**flags)
         return self

+    def set_grad(self, mode=True):
+        self.add_flags_recursive(requires_grad=mode)
+        return self
+
     def set_train(self, mode=True):
         """
         Sets the cell to training mode.
@@ -762,9 +787,9 @@ class Cell:
         self.add_flags(auto_parallel=True)
         self._get_construct_inputs_number_and_name()

-    def _hook_construct(self, inputs):
+    def _hook_construct(self, *inputs):
         """Hook construct method to replace original construct method when hook function enabled."""
-        inputs = self._backward_hook(inputs)
+        inputs = self._backward_hook(*inputs)
         inputs = self.construct(inputs)
         outputs = self._backward_hook(inputs)
         return outputs
```
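Putting the `__call__` changes together: in PyNative mode a forward pass still executes eagerly, but when `requires_grad` is set it also records the graph needed for a later backward pass. A minimal usage sketch (the network here is an arbitrary example, assuming PyNative mode):

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

context.set_context(mode=context.PYNATIVE_MODE)

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.dense = nn.Dense(3, 2)

    def construct(self, x):
        return self.dense(x)

net = Net()
net.set_grad()  # new in this commit: recursively sets requires_grad
x = Tensor(np.ones([1, 3], dtype=np.float32))
out = net(x)    # runs eagerly; new_graph/end_graph record the trace
```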
mindspore/nn/wrap/cell_wrapper.py  (+1 −0)

```diff
@@ -166,6 +166,7 @@ class TrainOneStepCell(Cell):
     def __init__(self, network, optimizer, sens=1.0):
         super(TrainOneStepCell, self).__init__(auto_prefix=False)
         self.network = network
+        self.network.set_grad()
         self.network.add_flags(defer_inline=True)
         self.weights = optimizer.parameters
         self.optimizer = optimizer
```
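Since `TrainOneStepCell.__init__` now calls `set_grad()` itself, wrapping a network is enough to flag it for gradient recording in PyNative mode. A hedged sketch reusing the `Net` example above (the loss/optimizer choices are illustrative, not from the diff):

```python
# Sketch: the wrapped network is flagged for gradient recording automatically.
loss_net = nn.WithLossCell(net, nn.MSELoss())   # example loss wrapper choice
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
train_net = nn.TrainOneStepCell(loss_net, opt)
# train_net.network.requires_grad is True without an explicit set_grad() call.
```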
mindspore/ops/composite/base.py  (+26 −4)

```diff
@@ -18,14 +18,16 @@
 """Basic composite operations."""
 from functools import partial

+from mindspore import context
 from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeFuncGraph_, Tail_, TensorSlice_, \
     TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
 from ...common import dtype as mstype
-from ...common.api import ms_function
+from ...common.api import ms_function, _pynative_exec
 from .. import functional as F
 from .. import operations as P
+from ...common.parameter import Parameter

 __all__ = [EnvInstance_, TensorSlice_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
@@ -105,14 +107,34 @@ class GradOperation(GradOperation_):
         GradOperation_.__init__(self, name, get_all, get_by_list, sens_param)
         self.grad_fn = None
         self.fn = None
+        self.need_forward = False

     def __call__(self, fn, weights=None):
         grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param)
         if self.grad_fn is None or self.fn != fn:
             if self.get_by_list:
-                @ms_function(obj=fn)
-                def after_grad(*args):
-                    return grad_(fn, weights)(*args)
+                if context.get_context("mode") == context.GRAPH_MODE or fn.bprop_debug:
+                    @ms_function(obj=fn)
+                    def after_grad(*args):
+                        return grad_(fn, weights)(*args)
+                else:
+                    def after_grad(*args):
+                        if fn.is_run and not fn.requires_grad:
+                            raise ValueError("obj must set_grad.")
+                        if not fn.is_run:
+                            self.need_forward = True
+                            print("already has forward run before grad by user")
+                        if self.need_forward:
+                            fn.set_grad()
+                            if self.sens_param:
+                                f_args = args[:-1]
+                                fn(*f_args)
+                            else:
+                                fn(*args)
+                        _pynative_exec.grad(grad_, fn, weights, *args)
+                        out = _pynative_exec(*args)
+                        _pynative_exec.clear()
+                        return out
             else:
                 @ms_function(obj=fn)
                 def after_grad(*args):
```
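So in PyNative mode with `bprop_debug` off, a `get_by_list` gradient now runs through the recorded-graph path (`_pynative_exec.grad` followed by a replay) instead of `ms_function` compilation. A minimal end-to-end sketch continuing the `Net` example from above:

```python
from mindspore import ParameterTuple
from mindspore.ops import composite as C

weights = ParameterTuple(net.trainable_params())
grad_by_list = C.GradOperation('grad', get_by_list=True)

net.set_grad()                         # avoid the "obj must set_grad." ValueError
net(x)                                 # forward pass records the graph
grads = grad_by_list(net, weights)(x)  # backward built and run from the record
```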
mindspore/ops/operations/debug_ops.py  (+0 −6)

```diff
@@ -286,12 +286,6 @@ class HookBackward(PrimitiveWithInfer):
         self.register_hook(hook_fn)
         self.cell_id = cell_id

-    def __call__(self, *inputs):
-        """run in PyNative mode."""
-        if len(inputs) == 1:
-            return inputs[0]
-        return inputs
-
     def infer_shape(self, *inputs_shape):
         if len(inputs_shape) == 1:
             return inputs_shape[0]
```
mindspore/ops/primitive.py  (+2 −8)

```diff
@@ -328,15 +328,9 @@ def _run_op(obj, op_name, args):
     op_inputs = []
     for i, arg in enumerate(args):
         if hasattr(arg, '__parameter__'):
-            op_inputs.append(arg.default_input)
             op_mask[i] = 1
-        elif isinstance(arg, tuple):
-            convert = lambda x: x.default_input if hasattr(x, '__parameter__') else x
-            args_ = tuple(convert(x) for x in arg)
-            op_inputs.append(args_)
-        else:
-            op_inputs.append(arg)
-    output = real_run_op(obj, op_name, tuple(op_inputs), tuple(op_mask))
+        op_inputs.append(arg)
+    output = real_run_op(obj, op_name, args, tuple(op_mask))
     if not output:
         raise RuntimeError("Pynative run op %s failed!" % op_name)
     if len(output) == 1:
```
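The new `_run_op` no longer unwraps Parameters itself; it forwards `args` untouched and only computes a positional mask so the C++ executor knows which inputs are Parameters (unwrapping now happens during conversion, see `data_converter.cc` above). Here is the masking step in isolation, as a standalone sketch with a hypothetical Parameter stand-in:

```python
def compute_op_mask(args):
    """Standalone sketch of the op_mask computation in the new _run_op."""
    op_mask = [0] * len(args)
    for i, arg in enumerate(args):
        if hasattr(arg, '__parameter__'):  # Parameter-like inputs are flagged
            op_mask[i] = 1
    return tuple(op_mask)


class MockParameter:
    __parameter__ = True  # hypothetical stand-in, as in the earlier sketch


print(compute_op_mask((1.0, MockParameter(), 'x')))  # (0, 1, 0)
```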
tests/ut/python/pynative_mode/nn/test_dropout.py  (+1 −1)

```diff
@@ -54,4 +54,4 @@ class Net_Dropout(nn.Cell):
 def test_compile_dropout():
     net = Net_Dropout()
     input_data = Tensor(np.ones([20, 16, 50], dtype=np.float32))
-    _executor.compile(net, input_data)
+    net(input_data)
```
tests/vm_impl/array_ops_vm_impl.py  (+2 −1)

```diff
@@ -18,6 +18,7 @@ import numpy as np
 import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
+from mindspore.ops.operations import _grad_ops as G
 from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
 from .vm_interface import vm
@@ -225,7 +226,7 @@ def vm_impl_slice(self):
     return vm_impl

-@vm_impl_getters.register(P._grad_ops.ConcatOffset)
+@vm_impl_getters.register(G.ConcatOffset)
 def vm_impl_concatOffset(self):
     """Generate vm_impl function for ConcatOffset"""
```