Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
c28a875f
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
c28a875f
编写于
3月 29, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(imperative/amp): adapt new transformation
GitOrigin-RevId: 6edd577a70a8ea0ae00fbde0a6e4034273a30867
上级
fd41302c
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
157 addition
and
67 deletion
+157
-67
imperative/python/megengine/amp/autocast.py
imperative/python/megengine/amp/autocast.py
+1
-3
imperative/python/megengine/amp/convert_format.py
imperative/python/megengine/amp/convert_format.py
+7
-5
imperative/python/megengine/core/autodiff/grad.py
imperative/python/megengine/core/autodiff/grad.py
+0
-2
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+0
-2
imperative/python/src/tensor_utils.cpp
imperative/python/src/tensor_utils.cpp
+1
-1
imperative/python/test/unit/core/test_formatted_tensor.py
imperative/python/test/unit/core/test_formatted_tensor.py
+35
-12
imperative/src/impl/interpreter/interpreter_impl.cpp
imperative/src/impl/interpreter/interpreter_impl.cpp
+21
-0
imperative/src/impl/transformations/format.cpp
imperative/src/impl/transformations/format.cpp
+92
-40
imperative/src/include/megbrain/imperative/transformations/grad.h
...ve/src/include/megbrain/imperative/transformations/grad.h
+0
-2
未找到文件。
imperative/python/megengine/amp/autocast.py
浏览文件 @
c28a875f
...
@@ -50,8 +50,6 @@ class autocast:
...
@@ -50,8 +50,6 @@ class autocast:
self
.
_origin_enabled
=
None
self
.
_origin_enabled
=
None
self
.
_origin_high
=
None
self
.
_origin_high
=
None
self
.
_origin_low
=
None
self
.
_origin_low
=
None
self
.
_origin_compute_mode
=
None
self
.
_origin_configs
=
None
self
.
_origin_configs
=
None
def
__enter__
(
self
):
def
__enter__
(
self
):
...
@@ -75,7 +73,7 @@ class autocast:
...
@@ -75,7 +73,7 @@ class autocast:
amp
.
_set_amp_high_prec_dtype
(
self
.
_origin_high
)
amp
.
_set_amp_high_prec_dtype
(
self
.
_origin_high
)
amp
.
_set_amp_low_prec_dtype
(
self
.
_origin_low
)
amp
.
_set_amp_low_prec_dtype
(
self
.
_origin_low
)
_config
.
_reset_execution_config
(
*
self
.
_origin_co
mpute_mode
)
_config
.
_reset_execution_config
(
*
self
.
_origin_co
nfigs
)
def
__call__
(
self
,
func
):
def
__call__
(
self
,
func
):
@
functools
.
wraps
(
func
)
@
functools
.
wraps
(
func
)
...
...
imperative/python/megengine/amp/convert_format.py
浏览文件 @
c28a875f
...
@@ -15,11 +15,14 @@ from ..core import _config
...
@@ -15,11 +15,14 @@ from ..core import _config
def
_is_nchw_format
(
param
:
Tensor
):
def
_is_nchw_format
(
param
:
Tensor
):
# TODO: use better condition
# TODO: use better condition
return
(
len
(
param
.
shape
)
==
4
or
len
(
param
.
shape
)
==
5
)
and
param
.
format
!=
"nhwc"
return
(
param
.
ndim
==
4
or
param
.
ndim
==
5
)
and
param
.
format
!=
"nhwc"
def
convert_tensor_format
(
x
:
Tensor
,
inplace
:
bool
=
True
):
def
convert_tensor_format
(
x
:
Tensor
,
inplace
:
bool
=
True
):
"""Convert NCHW Tensor to NHWC Tensor."""
"""Convert NCHW Tensor to NHWC Tensor."""
if
not
_is_nchw_format
(
x
):
return
x
if
x
.
ndim
==
4
:
if
x
.
ndim
==
4
:
pattern
=
(
0
,
2
,
3
,
1
)
pattern
=
(
0
,
2
,
3
,
1
)
elif
x
.
ndim
==
5
:
elif
x
.
ndim
==
5
:
...
@@ -29,8 +32,9 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
...
@@ -29,8 +32,9 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
# TODO: use initialization from tensor after fixing format setting
# TODO: use initialization from tensor after fixing format setting
if
x
.
format
!=
"nhwc"
:
if
x
.
format
!=
"nhwc"
:
if
inplace
:
if
inplace
:
#
reset will destroy backward grad
#
hostvalue should still be valid, so no d2h cost.
data
=
x
.
numpy
().
transpose
(
*
pattern
)
data
=
x
.
numpy
().
transpose
(
*
pattern
)
# reset will destroy existed backward grad
x
[...]
=
Tensor
(
data
,
format
=
"nhwc"
)
x
[...]
=
Tensor
(
data
,
format
=
"nhwc"
)
else
:
else
:
# use mge interface to maintain grad
# use mge interface to maintain grad
...
@@ -45,7 +49,5 @@ def convert_module_format(module: Module, inplace: bool = True):
...
@@ -45,7 +49,5 @@ def convert_module_format(module: Module, inplace: bool = True):
module
=
deepcopy
(
module
)
module
=
deepcopy
(
module
)
for
name
,
param
in
module
.
named_tensors
():
for
name
,
param
in
module
.
named_tensors
():
if
_is_nchw_format
(
param
):
convert_tensor_format
(
param
,
inplace
=
True
)
# hostvalue should still be valid, so no d2h cost.
convert_tensor_format
(
param
,
inplace
=
True
)
return
module
return
module
imperative/python/megengine/core/autodiff/grad.py
浏览文件 @
c28a875f
...
@@ -64,9 +64,7 @@ class Grad:
...
@@ -64,9 +64,7 @@ class Grad:
continue
continue
grad
.
suppress
()
grad
.
suppress
()
print
(
"before backward"
)
self
.
_impl
.
backward
(
ys
,
dys
)
self
.
_impl
.
backward
(
ys
,
dys
)
print
(
"after backward"
)
for
grad
in
group
:
for
grad
in
group
:
if
grad
is
self
:
if
grad
is
self
:
...
...
imperative/python/megengine/functional/nn.py
浏览文件 @
c28a875f
...
@@ -245,8 +245,6 @@ def conv2d(
...
@@ -245,8 +245,6 @@ def conv2d(
sparse_type
=
"dense"
if
groups
==
1
else
"group"
sparse_type
=
"dense"
if
groups
==
1
else
"group"
compute_mode
=
_config
.
_get_actual_op_param
(
compute_mode
,
_config
.
__compute_mode
)
compute_mode
=
_config
.
_get_actual_op_param
(
compute_mode
,
_config
.
__compute_mode
)
with
_config
.
_override
(
auto_format_convert
=
False
):
print
(
compute_mode
,
inp
.
shape
,
inp
.
format
,
weight
.
shape
,
weight
.
format
)
op
=
builtin
.
Convolution
(
op
=
builtin
.
Convolution
(
stride_h
=
stride_h
,
stride_h
=
stride_h
,
stride_w
=
stride_w
,
stride_w
=
stride_w
,
...
...
imperative/python/src/tensor_utils.cpp
浏览文件 @
c28a875f
...
@@ -320,7 +320,7 @@ py::object _Const(py::handle value, py::handle dtype, py::handle device) {
...
@@ -320,7 +320,7 @@ py::object _Const(py::handle value, py::handle dtype, py::handle device) {
}
}
}
}
py
::
object
device_obj
=
device2obj
(
device
,
true
);
py
::
object
device_obj
=
device2obj
(
device
,
true
);
py
::
tuple
tup
=
py
::
make_tuple
(
val
,
dtype
,
device_obj
,
true
,
false
,
py
::
none
());
py
::
tuple
tup
=
py
::
make_tuple
(
val
,
dtype
,
device_obj
,
true
,
false
,
py
::
none
()
,
py
::
none
()
);
return
TensorWrapper
::
make
(
py_tensor_type
,
tup
.
ptr
(),
nullptr
);
return
TensorWrapper
::
make
(
py_tensor_type
,
tup
.
ptr
(),
nullptr
);
}
}
...
...
imperative/python/test/unit/core/test_formatted_tensor.py
浏览文件 @
c28a875f
...
@@ -35,6 +35,7 @@ def test_basic():
...
@@ -35,6 +35,7 @@ def test_basic():
b
.
format
=
"nhwc"
b
.
format
=
"nhwc"
assert
b
.
format
==
"nhwc"
assert
b
.
format
==
"nhwc"
def
_compare_nchw_nhwc
(
data
,
func
,
is_symbolic
=
None
):
def
_compare_nchw_nhwc
(
data
,
func
,
is_symbolic
=
None
):
x1
=
tensor
(
data
)
x1
=
tensor
(
data
)
x2
=
tensor
(
data
.
transpose
(
0
,
2
,
3
,
1
),
format
=
"nhwc"
)
x2
=
tensor
(
data
.
transpose
(
0
,
2
,
3
,
1
),
format
=
"nhwc"
)
...
@@ -335,21 +336,42 @@ def _compare_backward(inps, model, is_symbolic=None):
...
@@ -335,21 +336,42 @@ def _compare_backward(inps, model, is_symbolic=None):
gm
=
GradManager
().
attach
(
model
.
parameters
())
gm
=
GradManager
().
attach
(
model
.
parameters
())
with
gm
:
with
gm
:
rst
=
func
(
*
inps
)
with
mge
.
amp
.
autocast
():
gm
.
backward
(
rst
)
rst
=
func
(
*
inps
)
expected_grads
=
[
param
.
grad
for
param
in
model
.
parameters
()]
gm
.
backward
(
rst
)
expected_grads
=
[
param
.
grad
.
numpy
()
for
param
in
gm
.
attached_tensors
()]
for
param
in
gm
.
attached_tensors
():
param
.
grad
=
None
inps
=
[
mge
.
amp
.
convert_tensor_format
(
inp
)
for
inp
in
inps
]
inps
=
[
mge
.
amp
.
convert_tensor_format
(
inp
)
for
inp
in
inps
]
model
=
mge
.
amp
.
convert_module_format
(
model
)
model
=
mge
.
amp
.
convert_module_format
(
model
)
gm
=
GradManager
().
attach
(
model
.
parameters
())
gm
=
GradManager
().
attach
(
model
.
parameters
())
with
gm
:
with
gm
:
rst
=
func
(
*
inps
)
with
mge
.
amp
.
autocast
():
gm
.
backward
(
rst
)
rst
=
func
(
*
inps
)
actual_grads
=
[
param
.
grad
for
param
in
model
.
parameters
()]
gm
.
backward
(
rst
)
actual_grads
=
[
param
.
grad
.
numpy
()
for
param
in
gm
.
attached_tensors
()]
for
expected
,
actual
in
zip
(
expected_grads
,
actual_grads
):
for
expected
,
actual
in
zip
(
expected_grads
,
actual_grads
):
# print(param.grad)
assert
expected
is
not
None
np
.
testing
.
assert_equal
(
expected
.
numpy
(),
actual
.
numpy
())
assert
actual
is
not
None
np
.
testing
.
assert_almost_equal
(
expected
,
actual
,
decimal
=
5
)
@
pytest
.
mark
.
parametrize
(
"is_symbolic"
,
[
None
])
def
test_backward_basic
(
is_symbolic
):
class
Net
(
M
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
w
=
mge
.
Parameter
([[
2.0
],
[
4.0
],
[
6.0
]])
self
.
b
=
mge
.
Parameter
(
-
1.0
)
def
forward
(
self
,
inp
):
return
F
.
matmul
(
inp
,
self
.
w
)
+
self
.
b
inp
=
mge
.
tensor
([
1.0
,
3.0
,
5.0
]).
reshape
(
1
,
3
)
_compare_backward
([
inp
],
Net
(),
is_symbolic
)
@
pytest
.
mark
.
parametrize
(
"is_symbolic"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"is_symbolic"
,
[
None
])
...
@@ -379,14 +401,15 @@ def test_backward_groupconv2d_bn(is_symbolic):
...
@@ -379,14 +401,15 @@ def test_backward_groupconv2d_bn(is_symbolic):
class
Net
(
M
.
Module
):
class
Net
(
M
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
().
__init__
()
super
().
__init__
()
self
.
conv
=
M
.
Conv2d
(
2
,
2
,
1
,
groups
=
2
)
self
.
conv0
=
M
.
Conv2d
(
32
,
256
,
3
,
groups
=
32
,
stride
=
2
)
self
.
bn
=
M
.
BatchNorm2d
(
2
)
self
.
conv1
=
M
.
Conv2d
(
256
,
2048
,
3
,
groups
=
32
,
stride
=
2
)
# self.bn = M.BatchNorm2d(2048)
def
forward
(
self
,
inp
):
def
forward
(
self
,
inp
):
# test manually convert to NHWC, usually used in detection head
# test manually convert to NHWC, usually used in detection head
return
self
.
bn
(
self
.
conv
(
inp
))
return
self
.
conv1
(
self
.
conv0
(
inp
))
inp
=
mge
.
tensor
(
np
.
arange
(
0
,
24
).
reshape
((
1
,
2
,
3
,
4
)
))
inp
=
mge
.
tensor
(
np
.
ones
(
shape
=
(
32
,
32
,
56
,
56
)).
astype
(
"float32"
))
_compare_backward
([
inp
],
Net
(),
is_symbolic
)
_compare_backward
([
inp
],
Net
(),
is_symbolic
)
# def func(x, w, b, bn_w, bn_b):
# def func(x, w, b, bn_w, bn_b):
# x = F.conv2d(x, w, b, groups=2)
# x = F.conv2d(x, w, b, groups=2)
...
...
imperative/src/impl/interpreter/interpreter_impl.cpp
浏览文件 @
c28a875f
...
@@ -260,6 +260,7 @@ void ChannelImpl::dispatch_default_cpu(
...
@@ -260,6 +260,7 @@ void ChannelImpl::dispatch_default_cpu(
CompNode
output_cn
;
CompNode
output_cn
;
{
{
MGB_LOCK_GUARD
(
m_mutex
);
MGB_LOCK_GUARD
(
m_mutex
);
//mgb_log_warn(">>> MGB_LOCK_GUARD dispatch_default_cpu");
for
(
auto
&&
info
:
input_infos
)
{
for
(
auto
&&
info
:
input_infos
)
{
auto
input_cn
=
info
->
desc
.
comp_node
;
auto
input_cn
=
info
->
desc
.
comp_node
;
if
(
!
output_cn
.
valid
())
{
if
(
!
output_cn
.
valid
())
{
...
@@ -277,6 +278,7 @@ void ChannelImpl::dispatch_default_cpu(
...
@@ -277,6 +278,7 @@ void ChannelImpl::dispatch_default_cpu(
input_tensornds
.
emplace_back
(
info
->
h_value
.
proxy_to_default_cpu
());
input_tensornds
.
emplace_back
(
info
->
h_value
.
proxy_to_default_cpu
());
}
}
}
}
//mgb_log_warn("<<< MGB_LOCK_GUARD dispatch_default_cpu");
}
}
SmallVector
<
DeviceTensorND
>
output_tensornds
;
SmallVector
<
DeviceTensorND
>
output_tensornds
;
...
@@ -530,7 +532,9 @@ void ChannelImpl::sync() {
...
@@ -530,7 +532,9 @@ void ChannelImpl::sync() {
void
ChannelImpl
::
sync_impl
()
{
void
ChannelImpl
::
sync_impl
()
{
m_worker
.
wait_all_task_finish
();
m_worker
.
wait_all_task_finish
();
MGB_LOCK_GUARD
(
m_mutex
);
MGB_LOCK_GUARD
(
m_mutex
);
//mgb_log_warn(">>> MGB_LOCK_GUARD sync_impl");
check_worker_exc_unsafe
();
check_worker_exc_unsafe
();
//mgb_log_warn("<<< MGB_LOCK_GUARD sync_impl");
}
}
void
ChannelImpl
::
close
()
{
void
ChannelImpl
::
close
()
{
...
@@ -689,6 +693,7 @@ ChannelImpl::~ChannelImpl() {
...
@@ -689,6 +693,7 @@ ChannelImpl::~ChannelImpl() {
void
ChannelImpl
::
produce_tensor
(
TensorInfo
*
dest
,
TensorPtr
ptr
)
{
void
ChannelImpl
::
produce_tensor
(
TensorInfo
*
dest
,
TensorPtr
ptr
)
{
auto
&
state
=
get_worker_state
();
auto
&
state
=
get_worker_state
();
MGB_LOCK_GUARD
(
m_mutex
);
MGB_LOCK_GUARD
(
m_mutex
);
//mgb_log_warn(">>> MGB_LOCK_GUARD produce_tensor");
m_dtr
.
update_used_time
(
dest
);
m_dtr
.
update_used_time
(
dest
);
MGB_RECORD_EVENT
(
MGB_RECORD_EVENT
(
TensorProduceEvent
,
dest
->
id
,
ptr
->
layout
(),
ptr
->
comp_node
(),
TensorProduceEvent
,
dest
->
id
,
ptr
->
layout
(),
ptr
->
comp_node
(),
...
@@ -715,16 +720,19 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
...
@@ -715,16 +720,19 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
m_dtr
.
insert_candidate
(
dest
);
m_dtr
.
insert_candidate
(
dest
);
}
}
notify_tensor_unsafe
(
dest
);
notify_tensor_unsafe
(
dest
);
//mgb_log_warn("<<< MGB_LOCK_GUARD produce_tensor");
}
}
void
ChannelImpl
::
release_tensor
(
TensorInfo
*
dest
)
{
void
ChannelImpl
::
release_tensor
(
TensorInfo
*
dest
)
{
MGB_RECORD_EVENT
(
TensorReleaseEvent
,
dest
->
id
);
MGB_RECORD_EVENT
(
TensorReleaseEvent
,
dest
->
id
);
MGB_LOCK_GUARD
(
m_mutex
);
MGB_LOCK_GUARD
(
m_mutex
);
//mgb_log_warn(">>> MGB_LOCK_GUARD release_tensor");
dest
->
ptr
.
reset
();
dest
->
ptr
.
reset
();
auto
&
state
=
get_worker_state
();
auto
&
state
=
get_worker_state
();
if
(
dest
->
size_exceeds_thd
(
state
.
options
.
dtr_evictee_minimum_size
))
{
if
(
dest
->
size_exceeds_thd
(
state
.
options
.
dtr_evictee_minimum_size
))
{
m_dtr
.
erase_candidate
(
dest
);
m_dtr
.
erase_candidate
(
dest
);
}
}
//mgb_log_warn("<<< MGB_LOCK_GUARD release_tensor");
}
}
void
ChannelImpl
::
regenerate
(
TensorInfo
*
dest
)
{
void
ChannelImpl
::
regenerate
(
TensorInfo
*
dest
)
{
...
@@ -1000,6 +1008,7 @@ bool ChannelImpl::check_available() {
...
@@ -1000,6 +1008,7 @@ bool ChannelImpl::check_available() {
TensorPtr
ChannelImpl
::
wait_tensor
(
TensorInfo
*
info
,
TensorProp
prop
)
{
TensorPtr
ChannelImpl
::
wait_tensor
(
TensorInfo
*
info
,
TensorProp
prop
)
{
std
::
unique_lock
<
decltype
(
m_mutex
)
>
lock
(
m_mutex
);
std
::
unique_lock
<
decltype
(
m_mutex
)
>
lock
(
m_mutex
);
//mgb_log_warn(">>> MGB_LOCK_GUARD wait_tensor");
mgb_assert
(
!
m_waitee
,
"duplicate waitee"
);
mgb_assert
(
!
m_waitee
,
"duplicate waitee"
);
m_waitee
=
info
;
m_waitee
=
info
;
m_waitee_id
=
Profiler
::
next_id
();
m_waitee_id
=
Profiler
::
next_id
();
...
@@ -1010,6 +1019,7 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
...
@@ -1010,6 +1019,7 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
if
(
require_host
&&
!
host_available
())
{
if
(
require_host
&&
!
host_available
())
{
// avoid dead lock
// avoid dead lock
lock
.
unlock
();
lock
.
unlock
();
//mgb_log_warn("<<< MGB_LOCK_GUARD wait_tensor unlock");
if
(
Profiler
::
is_profiling
())
{
if
(
Profiler
::
is_profiling
())
{
m_worker
.
add_task
(
m_worker
.
add_task
(
{
Profiler
::
next_id
(),
GetValue
{
info
},
{
Profiler
::
next_id
(),
GetValue
{
info
},
...
@@ -1021,18 +1031,21 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
...
@@ -1021,18 +1031,21 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
});
});
}
}
lock
.
lock
();
lock
.
lock
();
//mgb_log_warn(">>> MGB_LOCK_GUARD wait_tensor lock");
wait_host
=
true
;
wait_host
=
true
;
}
}
m_cv
.
wait
(
lock
,
[
&
]()
{
m_cv
.
wait
(
lock
,
[
&
]()
{
check_worker_exc_unsafe
();
check_worker_exc_unsafe
();
return
require_host
?
host_available
()
:
static_cast
<
bool
>
(
info
->
ptr
);
return
require_host
?
host_available
()
:
static_cast
<
bool
>
(
info
->
ptr
);
});
});
//mgb_log_warn("after cv wait");
MGB_RECORD_EVENT
(
TensorWaitPropFinishEvent
,
info
->
id
,
m_waitee_id
,
prop
);
MGB_RECORD_EVENT
(
TensorWaitPropFinishEvent
,
info
->
id
,
m_waitee_id
,
prop
);
m_waitee
=
nullptr
;
m_waitee
=
nullptr
;
if
(
wait_host
)
{
if
(
wait_host
)
{
auto
err
=
info
->
ptr
->
comp_node
().
check_async_error
();
auto
err
=
info
->
ptr
->
comp_node
().
check_async_error
();
mgb_assert
(
!
err
,
"%s"
,
err
->
what
());
mgb_assert
(
!
err
,
"%s"
,
err
->
what
());
}
}
//mgb_log_warn("<<< MGB_LOCK_GUARD wait_tensor");
return
info
->
ptr
;
return
info
->
ptr
;
}
}
...
@@ -1040,6 +1053,7 @@ void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
...
@@ -1040,6 +1053,7 @@ void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
if
(
info
==
m_waitee
)
{
if
(
info
==
m_waitee
)
{
MGB_RECORD_EVENT
(
TensorNotifyPropEvent
,
info
->
id
);
MGB_RECORD_EVENT
(
TensorNotifyPropEvent
,
info
->
id
);
m_cv
.
notify_all
();
m_cv
.
notify_all
();
//mgb_log_warn("cv notify_all");
}
}
}
}
...
@@ -1102,6 +1116,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
...
@@ -1102,6 +1116,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
using
namespace
ranges
::
views
;
using
namespace
ranges
::
views
;
auto
&
state
=
get_worker_state
();
auto
&
state
=
get_worker_state
();
auto
&
options
=
state
.
options
;
auto
&
options
=
state
.
options
;
//mgb_log_warn("process_one_task %s", to_string<Command>(icmd).c_str());
// TODO: remove std::visit for support osx 10.12
// TODO: remove std::visit for support osx 10.12
auto
cmd_visitor
=
[
&
](
const
auto
&
cmd
)
{
auto
cmd_visitor
=
[
&
](
const
auto
&
cmd
)
{
using
T
=
std
::
decay_t
<
decltype
(
cmd
)
>
;
using
T
=
std
::
decay_t
<
decltype
(
cmd
)
>
;
...
@@ -1123,9 +1138,11 @@ void ChannelImpl::process_one_task(Command& icmd) {
...
@@ -1123,9 +1138,11 @@ void ChannelImpl::process_one_task(Command& icmd) {
for
(
auto
&
i
:
cmd
.
inputs
)
{
for
(
auto
&
i
:
cmd
.
inputs
)
{
if
(
mgb_unlikely
(
i
->
invalid
))
{
if
(
mgb_unlikely
(
i
->
invalid
))
{
MGB_LOCK_GUARD
(
m_mutex
);
MGB_LOCK_GUARD
(
m_mutex
);
//mgb_log_warn(">>> MGB_LOCK_GUARD ApplyOp");
for
(
auto
&
i
:
cmd
.
outputs
)
{
for
(
auto
&
i
:
cmd
.
outputs
)
{
i
->
invalid
=
true
;
i
->
invalid
=
true
;
}
}
//mgb_log_warn("<<< MGB_LOCK_GUARD ApplyOp");
return
;
return
;
}
}
}
}
...
@@ -1210,8 +1227,10 @@ void ChannelImpl::process_one_task(Command& icmd) {
...
@@ -1210,8 +1227,10 @@ void ChannelImpl::process_one_task(Command& icmd) {
}
}
cmd
.
dest
->
ptr
->
fetch_value
();
cmd
.
dest
->
ptr
->
fetch_value
();
MGB_LOCK_GUARD
(
m_mutex
);
MGB_LOCK_GUARD
(
m_mutex
);
//mgb_log_warn(">>> MGB_LOCK_GUARD GetValue");
notify_tensor_unsafe
(
cmd
.
dest
);
notify_tensor_unsafe
(
cmd
.
dest
);
imperative_log_profile_end
(
"GetValue"
);
imperative_log_profile_end
(
"GetValue"
);
//mgb_log_warn("<<< MGB_LOCK_GUARD GetValue");
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
Drop
>
)
{
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
Drop
>
)
{
if
(
cmd
.
dest
->
invalid
)
if
(
cmd
.
dest
->
invalid
)
return
;
return
;
...
@@ -1271,6 +1290,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
...
@@ -1271,6 +1290,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
cmd_visitor
(
cmd
);
cmd_visitor
(
cmd
);
}
catch
(...)
{
}
catch
(...)
{
MGB_LOCK_GUARD
(
m_mutex
);
MGB_LOCK_GUARD
(
m_mutex
);
//mgb_log_warn(">>> MGB_LOCK_GUARD catch exception");
if
constexpr
(
std
::
is_same_v
<
T
,
ApplyOp
>
)
{
if
constexpr
(
std
::
is_same_v
<
T
,
ApplyOp
>
)
{
for
(
auto
oup
:
cmd
.
outputs
)
{
for
(
auto
oup
:
cmd
.
outputs
)
{
oup
->
invalid
=
true
;
oup
->
invalid
=
true
;
...
@@ -1283,6 +1303,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
...
@@ -1283,6 +1303,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
if
(
m_waitee
)
{
if
(
m_waitee
)
{
notify_tensor_unsafe
(
m_waitee
);
notify_tensor_unsafe
(
m_waitee
);
}
}
//mgb_log_warn("<<< MGB_LOCK_GUARD catch exception");
}
}
},
},
icmd
.
data
);
icmd
.
data
);
...
...
imperative/src/impl/transformations/format.cpp
浏览文件 @
c28a875f
...
@@ -33,9 +33,8 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
...
@@ -33,9 +33,8 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
tensor
.
format
().
to_string
().
c_str
(),
tensor
.
format
().
to_string
().
c_str
(),
Format
(
target
).
to_string
().
c_str
());
Format
(
target
).
to_string
().
c_str
());
}
}
auto
output
=
imperative
::
apply
(
auto
output
=
*
Dimshuffle
::
make
(
pattern
,
scope
),
imperative
::
apply
(
*
Dimshuffle
::
make
(
pattern
,
scope
),
{
tensor
.
value
()})[
0
];
SmallVector
<
ValueRef
>
{
tensor
.
value
()})[
0
];
return
m_value_type
.
make
(
output
,
target
);
return
m_value_type
.
make
(
output
,
target
);
}
}
...
@@ -90,6 +89,27 @@ ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
...
@@ -90,6 +89,27 @@ ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
}
}
}
}
std
::
vector
<
int32_t
>
convert_nchw2nhwc_vector
(
const
std
::
vector
<
int32_t
>&
shape
)
{
auto
out
=
std
::
vector
<
int32_t
>
(
shape
);
if
(
shape
.
size
()
==
4
)
{
out
[
1
]
=
shape
[
2
];
out
[
2
]
=
shape
[
3
];
out
[
3
]
=
shape
[
1
];
return
out
;
}
else
if
(
shape
.
size
()
==
5
)
{
// GIOHW -> GIHWO
out
[
2
]
=
shape
[
3
];
out
[
3
]
=
shape
[
4
];
out
[
4
]
=
shape
[
2
];
return
out
;
}
else
{
mgb_throw
(
MegBrainError
,
"Unsupported shape ndim %u in convert NCHW shape to NHWC."
,
shape
.
size
());
}
}
using
FormatRule
=
std
::
function
<
ValueRefList
(
using
FormatRule
=
std
::
function
<
ValueRefList
(
const
OpDef
&
,
Span
<
ValueRef
>&
,
const
bool
&
,
const
FormatTransformation
&
)
>
;
const
OpDef
&
,
Span
<
ValueRef
>&
,
const
bool
&
,
const
FormatTransformation
&
)
>
;
static
std
::
unordered_map
<
Typeinfo
*
,
FormatRule
>
format_rules
;
static
std
::
unordered_map
<
Typeinfo
*
,
FormatRule
>
format_rules
;
...
@@ -156,22 +176,38 @@ ValueRef convert_nchw2nhwc_tensornd(const HostTensorND& shape) {
...
@@ -156,22 +176,38 @@ ValueRef convert_nchw2nhwc_tensornd(const HostTensorND& shape) {
ValueRefList
reshape_rule
(
ValueRefList
reshape_rule
(
const
Reshape
&
op
,
Span
<
ValueRef
>&
inputs
,
const
bool
&
auto_convert
,
const
Reshape
&
op
,
Span
<
ValueRef
>&
inputs
,
const
bool
&
auto_convert
,
const
FormatTransformation
&
t
)
{
const
FormatTransformation
&
t
)
{
mgb_assert
(
inputs
.
size
()
==
2
);
mgb_assert
(
inputs
.
size
()
>=
1
);
auto
&
src
=
inputs
[
0
].
cast
(
t
.
value_type
());
auto
&
src
=
inputs
[
0
].
cast
(
t
.
value_type
());
if
(
auto_convert
&&
src
.
format
()
==
FT
::
NHWC
)
{
if
(
auto_convert
&&
src
.
format
()
==
FT
::
NHWC
)
{
auto
shape
=
t
.
unwrap_input
(
inputs
[
1
]).
numpy
()
->
as_nd
();
if
(
inputs
.
size
()
==
1
)
{
if
(
shape
.
layout
().
total_nr_elems
()
==
4
)
{
if
(
op
.
shape
.
size
()
==
4
)
{
// output is still NHWC format
// output is still NHWC format
auto
nhwc_shape
=
convert_nchw2nhwc_tensornd
(
shape
);
auto
nhwc_shape
=
convert_nchw2nhwc_vector
(
op
.
shape
);
auto
outputs
=
imperative
::
apply
(
auto
outputs
=
imperative
::
apply
(
op
,
SmallVector
<
ValueRef
>
{
t
.
unwrap_input
(
inputs
[
0
]),
nhwc_shape
});
*
Reshape
::
make
(
op
.
axis
,
nhwc_shape
),
{
t
.
unwrap_input
(
inputs
[
0
])});
return
t
.
wrap_outputs
(
outputs
,
FT
::
NHWC
);
return
t
.
wrap_outputs
(
outputs
,
FT
::
NHWC
);
}
else
{
}
else
{
// will not maintain src's format
// will not maintain src's format
auto
nchw_src
=
t
.
to
(
src
,
FT
::
NCHW
,
op
.
scope
())
->
value
();
auto
nchw_src
=
t
.
to
(
src
,
FT
::
NCHW
,
op
.
scope
())
->
value
();
auto
outputs
=
imperative
::
apply
(
auto
outputs
=
imperative
::
apply
(
op
,
{
nchw_src
});
op
,
SmallVector
<
ValueRef
>
{
nchw_src
,
t
.
unwrap_input
(
inputs
[
1
])});
return
t
.
wrap_outputs
(
outputs
);
return
t
.
wrap_outputs
(
outputs
);
}
}
else
if
(
inputs
.
size
()
==
2
)
{
auto
shape
=
t
.
unwrap_input
(
inputs
[
1
]).
numpy
()
->
as_nd
();
if
(
shape
.
layout
().
total_nr_elems
()
==
4
)
{
// output is still NHWC format
auto
nhwc_shape
=
convert_nchw2nhwc_tensornd
(
shape
);
auto
outputs
=
imperative
::
apply
(
op
,
SmallVector
<
ValueRef
>
{
t
.
unwrap_input
(
inputs
[
0
]),
nhwc_shape
});
return
t
.
wrap_outputs
(
outputs
,
FT
::
NHWC
);
}
else
{
// will not maintain src's format
auto
nchw_src
=
t
.
to
(
src
,
FT
::
NCHW
,
op
.
scope
())
->
value
();
auto
outputs
=
imperative
::
apply
(
op
,
SmallVector
<
ValueRef
>
{
nchw_src
,
t
.
unwrap_input
(
inputs
[
1
])});
return
t
.
wrap_outputs
(
outputs
);
}
}
}
}
}
return
t
.
wrap_outputs
(
imperative
::
apply
(
op
,
t
.
unwrap_inputs
(
inputs
)));
return
t
.
wrap_outputs
(
imperative
::
apply
(
op
,
t
.
unwrap_inputs
(
inputs
)));
...
@@ -180,22 +216,38 @@ ValueRefList reshape_rule(
...
@@ -180,22 +216,38 @@ ValueRefList reshape_rule(
ValueRefList
broadcast_rule
(
ValueRefList
broadcast_rule
(
const
Broadcast
&
op
,
Span
<
ValueRef
>&
inputs
,
const
bool
&
auto_convert
,
const
Broadcast
&
op
,
Span
<
ValueRef
>&
inputs
,
const
bool
&
auto_convert
,
const
FormatTransformation
&
t
)
{
const
FormatTransformation
&
t
)
{
mgb_assert
(
inputs
.
size
()
==
2
);
mgb_assert
(
inputs
.
size
()
>=
1
);
auto
&
src
=
inputs
[
0
].
cast
(
t
.
value_type
());
auto
&
src
=
inputs
[
0
].
cast
(
t
.
value_type
());
if
(
auto_convert
&&
src
.
format
()
==
FT
::
NHWC
)
{
if
(
auto_convert
&&
src
.
format
()
==
FT
::
NHWC
)
{
auto
shape
=
t
.
unwrap_input
(
inputs
[
1
]).
numpy
()
->
as_nd
();
if
(
inputs
.
size
()
==
1
)
{
if
(
shape
.
layout
().
total_nr_elems
()
==
4
)
{
if
(
op
.
shape
.
size
()
==
4
)
{
// output is still NHWC format
// output is still NHWC format
auto
nhwc_shape
=
convert_nchw2nhwc_tensornd
(
shape
);
auto
nhwc_shape
=
convert_nchw2nhwc_vector
(
op
.
shape
);
auto
outputs
=
imperative
::
apply
(
auto
outputs
=
imperative
::
apply
(
op
,
SmallVector
<
ValueRef
>
{
t
.
unwrap_input
(
inputs
[
0
]),
nhwc_shape
});
*
Broadcast
::
make
(
nhwc_shape
),
{
t
.
unwrap_input
(
inputs
[
0
])});
return
t
.
wrap_outputs
(
outputs
,
FT
::
NHWC
);
return
t
.
wrap_outputs
(
outputs
,
FT
::
NHWC
);
}
else
{
}
else
{
// will not maintain src's format
// will not maintain src's format
auto
nchw_src
=
t
.
to
(
src
,
FT
::
NCHW
,
op
.
scope
())
->
value
();
auto
nchw_src
=
t
.
to
(
src
,
FT
::
NCHW
,
op
.
scope
())
->
value
();
auto
outputs
=
imperative
::
apply
(
auto
outputs
=
imperative
::
apply
(
op
,
{
nchw_src
});
op
,
SmallVector
<
ValueRef
>
{
nchw_src
,
t
.
unwrap_input
(
inputs
[
1
])});
return
t
.
wrap_outputs
(
outputs
);
return
t
.
wrap_outputs
(
outputs
);
}
}
else
if
(
inputs
.
size
()
==
2
)
{
auto
shape
=
t
.
unwrap_input
(
inputs
[
1
]).
numpy
()
->
as_nd
();
if
(
shape
.
layout
().
total_nr_elems
()
==
4
)
{
// output is still NHWC format
auto
nhwc_shape
=
convert_nchw2nhwc_tensornd
(
shape
);
auto
outputs
=
imperative
::
apply
(
op
,
SmallVector
<
ValueRef
>
{
t
.
unwrap_input
(
inputs
[
0
]),
nhwc_shape
});
return
t
.
wrap_outputs
(
outputs
,
FT
::
NHWC
);
}
else
{
// will not maintain src's format
auto
nchw_src
=
t
.
to
(
src
,
FT
::
NCHW
,
op
.
scope
())
->
value
();
auto
outputs
=
imperative
::
apply
(
op
,
SmallVector
<
ValueRef
>
{
nchw_src
,
t
.
unwrap_input
(
inputs
[
1
])});
return
t
.
wrap_outputs
(
outputs
);
}
}
}
}
}
return
t
.
wrap_outputs
(
imperative
::
apply
(
op
,
t
.
unwrap_inputs
(
inputs
)));
return
t
.
wrap_outputs
(
imperative
::
apply
(
op
,
t
.
unwrap_inputs
(
inputs
)));
...
@@ -240,8 +292,7 @@ ValueRefList subtensor_rule(
...
@@ -240,8 +292,7 @@ ValueRefList subtensor_rule(
// only support NHWC2NCHW convert, otherwise maintain src's format
// only support NHWC2NCHW convert, otherwise maintain src's format
if
(
!
(
auto_convert
&&
src
.
format
()
==
FT
::
NHWC
))
{
if
(
!
(
auto_convert
&&
src
.
format
()
==
FT
::
NHWC
))
{
return
{
t
.
wrap_output
(
return
{
t
.
wrap_output
(
imperative
::
apply
(
op
,
t
.
unwrap_inputs
(
inputs
))[
0
],
imperative
::
apply
(
op
,
t
.
unwrap_inputs
(
inputs
))[
0
],
src
.
format
())};
src
.
format
())};
}
}
auto
nhwc_items
=
convert_nchw2nhwc_idx_items
(
op
.
items
);
auto
nhwc_items
=
convert_nchw2nhwc_idx_items
(
op
.
items
);
auto
outputs
=
imperative
::
apply
(
auto
outputs
=
imperative
::
apply
(
...
@@ -263,8 +314,7 @@ ValueRefList setsubtensor_rule(
...
@@ -263,8 +314,7 @@ ValueRefList setsubtensor_rule(
// only support NHWC2NCHW convert, otherwise maintain src's format
// only support NHWC2NCHW convert, otherwise maintain src's format
if
(
!
(
auto_convert
&&
src
.
format
()
==
FT
::
NHWC
))
{
if
(
!
(
auto_convert
&&
src
.
format
()
==
FT
::
NHWC
))
{
return
{
t
.
wrap_output
(
return
{
t
.
wrap_output
(
imperative
::
apply
(
op
,
t
.
unwrap_inputs
(
inputs
))[
0
],
imperative
::
apply
(
op
,
t
.
unwrap_inputs
(
inputs
))[
0
],
src
.
format
())};
src
.
format
())};
}
}
// value has been broadcasted to src's fake NCHW shape.
// value has been broadcasted to src's fake NCHW shape.
auto
&
value
=
inputs
[
1
].
cast
(
t
.
value_type
());
auto
&
value
=
inputs
[
1
].
cast
(
t
.
value_type
());
...
@@ -329,8 +379,7 @@ ValueRefList identity_rule_helper(
...
@@ -329,8 +379,7 @@ ValueRefList identity_rule_helper(
const
OpDef
&
op
,
const
Span
<
ValueRef
>&
inputs
,
const
FormatTransformation
&
t
)
{
const
OpDef
&
op
,
const
Span
<
ValueRef
>&
inputs
,
const
FormatTransformation
&
t
)
{
// mgb_assert(inputs.size() == 1);
// mgb_assert(inputs.size() == 1);
auto
&
src
=
inputs
[
0
].
cast
(
t
.
value_type
());
auto
&
src
=
inputs
[
0
].
cast
(
t
.
value_type
());
return
t
.
wrap_outputs
(
return
t
.
wrap_outputs
(
imperative
::
apply
(
op
,
t
.
unwrap_inputs
(
inputs
)),
src
.
format
());
imperative
::
apply
(
op
,
t
.
unwrap_inputs
(
inputs
)),
src
.
format
());
}
}
ValueRefList
batchnorm_rule
(
ValueRefList
batchnorm_rule
(
...
@@ -457,6 +506,7 @@ struct FormatRuleRegistry {
...
@@ -457,6 +506,7 @@ struct FormatRuleRegistry {
ValueRefList
FormatTransformation
::
apply_transformation
(
ValueRefList
FormatTransformation
::
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
//mgb_log_warn("Format::apply_transformation %s", op.to_string().c_str());
if
(
auto
*
apply_op
=
op
.
as
<
ApplyOp
>
())
{
if
(
auto
*
apply_op
=
op
.
as
<
ApplyOp
>
())
{
// all inputs should be FormattedTensorValue
// all inputs should be FormattedTensorValue
auto
iter
=
format_rules
.
find
(
apply_op
->
op
().
dyn_typeinfo
());
auto
iter
=
format_rules
.
find
(
apply_op
->
op
().
dyn_typeinfo
());
...
@@ -485,7 +535,7 @@ ValueRefList FormatTransformation::apply_transformation(
...
@@ -485,7 +535,7 @@ ValueRefList FormatTransformation::apply_transformation(
}
}
case
GetAttr
::
Value
:
{
case
GetAttr
::
Value
:
{
auto
nchw_src
=
unwrap_input
(
to
(
src
,
FT
::
NCHW
,
""
));
auto
nchw_src
=
unwrap_input
(
to
(
src
,
FT
::
NCHW
,
""
));
return
imperative
::
apply
(
op
,
SmallVector
<
ValueRef
>
{
nchw_src
});
return
imperative
::
apply
(
op
,
{
nchw_src
});
}
}
default:
default:
return
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
));
return
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
));
...
@@ -508,8 +558,7 @@ ValueRefList FormatTransformation::apply_transformation(
...
@@ -508,8 +558,7 @@ ValueRefList FormatTransformation::apply_transformation(
auto
&&
inp_ref
=
inputs
[
0
].
as_ref
(
m_value_type
);
auto
&&
inp_ref
=
inputs
[
0
].
as_ref
(
m_value_type
);
if
(
inp_ref
)
{
if
(
inp_ref
)
{
auto
&&
format
=
inp_ref
->
format
();
auto
&&
format
=
inp_ref
->
format
();
return
wrap_outputs
(
return
wrap_outputs
(
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)),
format
);
imperative
::
apply
(
op
,
unwrap_inputs
(
inputs
)),
format
);
}
else
{
}
else
{
mgb_log_warn
(
mgb_log_warn
(
"Not FormattedTensorValue input for IdentityLike op: %s, %s"
,
"Not FormattedTensorValue input for IdentityLike op: %s, %s"
,
...
@@ -522,6 +571,7 @@ ValueRefList FormatTransformation::apply_transformation(
...
@@ -522,6 +571,7 @@ ValueRefList FormatTransformation::apply_transformation(
auto
format
=
inp_ref
->
format
();
auto
format
=
inp_ref
->
format
();
GenericFunction
callback
=
GenericFunction
callback
=
(
GenericFunction
&
)
inputs
[
1
].
cast
<
FunctionValue
>
();
(
GenericFunction
&
)
inputs
[
1
].
cast
<
FunctionValue
>
();
// make param grads as FormattedTensor
GenericFunction
new_callback
=
GenericFunction
new_callback
=
[
this
,
callback
,
format
](
Span
<
ValueRef
>
inputs_
)
->
ValueRefList
{
[
this
,
callback
,
format
](
Span
<
ValueRef
>
inputs_
)
->
ValueRefList
{
auto
wrapped_inputs
=
SmallVector
<
ValueRef
>
{
auto
wrapped_inputs
=
SmallVector
<
ValueRef
>
{
...
@@ -531,6 +581,7 @@ ValueRefList FormatTransformation::apply_transformation(
...
@@ -531,6 +581,7 @@ ValueRefList FormatTransformation::apply_transformation(
};
};
auto
&&
outputs
=
imperative
::
apply
(
auto
&&
outputs
=
imperative
::
apply
(
op
,
inp_ref
->
value
(),
FunctionValue
::
make
(
new_callback
));
op
,
inp_ref
->
value
(),
FunctionValue
::
make
(
new_callback
));
// make params(GradValue) as FormattedTensor
return
wrap_outputs
(
outputs
,
format
);
return
wrap_outputs
(
outputs
,
format
);
}
else
{
}
else
{
mgb_log_warn
(
mgb_log_warn
(
...
@@ -539,6 +590,7 @@ ValueRefList FormatTransformation::apply_transformation(
...
@@ -539,6 +590,7 @@ ValueRefList FormatTransformation::apply_transformation(
return
imperative
::
apply
(
op
,
inputs
);
return
imperative
::
apply
(
op
,
inputs
);
}
}
}
else
if
(
auto
*
set_grad
=
op
.
as
<
SetGrad
>
())
{
}
else
if
(
auto
*
set_grad
=
op
.
as
<
SetGrad
>
())
{
// make grads in Function backward as FormattedTensor
size_t
nr_inputs
=
set_grad
->
nr_inputs
();
size_t
nr_inputs
=
set_grad
->
nr_inputs
();
size_t
nr_outputs
=
inputs
.
size
()
-
nr_inputs
;
size_t
nr_outputs
=
inputs
.
size
()
-
nr_inputs
;
Span
<
ValueRef
>
inputs_
=
{
inputs
.
data
(),
nr_inputs
};
Span
<
ValueRef
>
inputs_
=
{
inputs
.
data
(),
nr_inputs
};
...
...
imperative/src/include/megbrain/imperative/transformations/grad.h
浏览文件 @
c28a875f
...
@@ -377,8 +377,6 @@ public:
...
@@ -377,8 +377,6 @@ public:
SetGrad
(
GenericFunction
grad_fn
,
size_t
nr_inputs
)
SetGrad
(
GenericFunction
grad_fn
,
size_t
nr_inputs
)
:
m_grad_fn
(
grad_fn
),
m_nr_inputs
(
nr_inputs
)
{}
:
m_grad_fn
(
grad_fn
),
m_nr_inputs
(
nr_inputs
)
{}
std
::
shared_ptr
<
GradKey
>
key
()
const
{
return
m_key
;
}
GenericFunction
grad_fn
()
const
{
return
m_grad_fn
;
}
GenericFunction
grad_fn
()
const
{
return
m_grad_fn
;
}
size_t
nr_inputs
()
const
{
return
m_nr_inputs
;
}
size_t
nr_inputs
()
const
{
return
m_nr_inputs
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录