MegEngine / MegEngine · Commit 9a6ba334

Authored on November 17, 2022 by Megvii Engine Team

feat(imperative): speed up concat and stack

GitOrigin-RevId: 614e87171908f419f98eb4c150c0a10a13f85c60
Parent: 489af281
Showing 16 changed files with 539 additions and 36 deletions (+539 / -36)
imperative/python/megengine/functional/tensor.py               +9   -10
imperative/python/scripts/format.sh                            +12  -4
imperative/python/src/tensor_utils.cpp                         +3   -0
imperative/python/test/unit/functional/test_tensor.py          +49  -0
imperative/src/impl/dnn_op_helper.h                            +3   -3
imperative/src/impl/ops/concatenate.cpp                        +259 -0
imperative/src/impl/ops/specializations.cpp                    +0   -12
imperative/src/impl/proxy_graph/proxy_graph.cpp                +1   -1
imperative/src/impl/transformations/dtype_promote.cpp          +1   -0
imperative/src/include/megbrain/imperative/physical_tensor.h   +1   -1
imperative/tablegen/generated/hash.txt                         +5   -5
imperative/tablegen/generated/opdef.cpp.inl                    +40  -0
imperative/tablegen/generated/opdef.cpy.inl                    +128 -0
imperative/tablegen/generated/opdef.h.inl                      +14  -0
imperative/tablegen/generated/opdef.py.inl                     +8   -0
src/core/include/megbrain/ir/ops.td                            +6   -0
imperative/python/megengine/functional/tensor.py

@@ -489,9 +489,9 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
         return inps[0]
 
     if device is None:
-        device = get_device(inps)
-    device = as_device(device)
-    (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps)
+        device = get_default_device()
+    (result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inps)
     return result

@@ -516,13 +516,12 @@ def stack(inps, axis=0, device=None):
     array([[0., 1., 2.],
            [6., 7., 8.]], dtype=float32)
     """
-    if len(inps) > 0 and not isinstance(inps[0].shape, inps[0].__class__):
-        shapes = {arr.shape for arr in inps}
-        if len(shapes) != 1:
-            raise ValueError("All input tensors must have the same shape")
-    inps = [expand_dims(inp, axis=axis) for inp in inps]
-    return concat(inps, axis=axis, device=device)
+    if len(inps) == 1:
+        return expand_dims(inps[0], axis=axis)
+    if device is None:
+        device = get_default_device()
+    (result,) = apply(builtin.Stack(axis=axis, comp_node=device), *inps)
+    return result
 
 
 def split(inp, nsplits_or_sections, axis=0):
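As a quick illustration of the Python-level effect of this change: concat now dispatches straight to builtin.Concat on the default device when no device is given, and stack goes through the new builtin.Stack op instead of expand_dims + concat. A minimal sketch, assuming a standard MegEngine build (tensor values are examples only):

# Minimal usage sketch of the two rewritten entry points.
import numpy as np
import megengine.functional as F
from megengine import Tensor

a = Tensor(np.arange(6, dtype="float32").reshape(2, 3))
b = Tensor(np.arange(6, 12, dtype="float32").reshape(2, 3))

# concat of two (2, 3) tensors along axis 0 -> (4, 3)
print(F.concat([a, b], axis=0).shape)

# stack of two (2, 3) tensors along axis 0 -> (2, 2, 3)
print(F.stack([a, b], axis=0).shape)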
imperative/python/scripts/format.sh

@@ -5,21 +5,29 @@ cd $(dirname $0)/..
 
 ISORT_ARG=""
 BLACK_ARG=""
 
-while getopts 'd' OPT; do
+while getopts 'dt:' OPT; do
   case $OPT in
     d)
       ISORT_ARG="--diff --check-only"
       BLACK_ARG="--diff --check"
       ;;
+    t)
+      TARGET=$OPTARG
+      ;;
     ?)
       echo "Usage: `basename $0` [-d]"
   esac
 done
 
-directories=(megengine test)
-if [[ -d examples ]]; then
-    directories+=(examples)
+if [[ $TARGET ]]; then
+    directories=($TARGET)
+else
+    directories=(megengine test)
+    if [[ -d examples ]]; then
+        directories+=(examples)
+    fi
 fi
 
 # do not isort megengine/__init__.py file, caused we must
 # init library load path before load dependent lib in core
 isort $ISORT_ARG -j $(nproc) -rc "${directories[@]}" -s megengine/__init__.py
imperative/python/src/tensor_utils.cpp

@@ -176,9 +176,12 @@ PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
     return res;
 }
 
+// if all the inputs are not megengine tensor, return get_default_device()
+// else check whether all input tensors have the same device
 CompNode _get_device(PyObject* const* args, size_t nargs) {
     bool is_tuple = false;
     PyObject* tuple = nullptr;
+    // convert input args to a tuple
     if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
         if (PyList_Check(args[0])) {
             tuple = PyList_AsTuple(args[0]);
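A small sketch of the rule the new _get_device comment describes (an illustration only, assuming a default MegEngine install): with no explicit device, concat/stack land on the device shared by their tensor inputs, which here is the default device.

import numpy as np
import megengine.functional as F
from megengine import Tensor, get_default_device

x = Tensor(np.ones((2, 3), dtype="float32"))   # lives on the default device
y = Tensor(np.zeros((2, 3), dtype="float32"))

print(get_default_device())                    # e.g. "xpux"
print(F.concat([x, y], axis=0).device)         # the inputs' (default) device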
imperative/python/test/unit/functional/test_tensor.py

@@ -217,6 +217,55 @@ def test_split_basic(is_varnode):
        set_symbolic_shape(saved_symbolic_shape)


def test_concat_and_stack():
    import copy

    def generate_test_data(max_nr_inp, max_dim, max_dim_len, test_concat=True):
        nr_inp = np.random.randint(1, max_nr_inp)
        dims = np.random.randint(1, max_dim)
        cat_axis = (
            np.random.randint(-dims, dims)
            if test_concat
            else np.random.randint(-dims - 1, dims + 1)
        )
        ishape = [np.random.randint(0, max_dim_len) for _ in range(dims)]
        ishapes = [copy.deepcopy(ishape) for _ in range(nr_inp)]
        if test_concat:
            for i in range(nr_inp):
                ishapes[i][cat_axis] = np.random.randint(0, max_dim_len)

        inp_nps = []
        for ishape in ishapes:
            inp_nps.append(np.random.randn(*ishape))
        return inp_nps, cat_axis

    def test_impl(max_nr_inp, max_dim, max_dim_len, test_concat):
        inp_nps, cat_axis = generate_test_data(max_nr_inp, max_dim, max_dim_len, test_concat)
        inp_mges = [Tensor(inp_np) for inp_np in inp_nps]
        if test_concat:
            np_func, mge_func = np.concatenate, F.concat
        else:
            np_func, mge_func = np.stack, F.stack
        res_np = np_func(inp_nps, axis=cat_axis)
        res_mge = mge_func(inp_mges, axis=cat_axis)
        np.testing.assert_allclose(res_mge.numpy(), res_np)

    def test_concat(max_nr_inp, max_dim, max_dim_len):
        test_impl(max_nr_inp, max_dim, max_dim_len, test_concat=True)

    def test_stack(max_nr_inp, max_dim, max_dim_len):
        test_impl(max_nr_inp, max_dim, max_dim_len, test_concat=False)

    for _ in range(3):
        test_concat(10, 7, 16)

    for _ in range(3):
        test_stack(10, 7, 16)


@pytest.mark.parametrize("symbolic", [None, False, True])
def test_split(symbolic):
    x = Tensor(np.random.random((10, 20)), dtype=np.float32)
imperative/src/impl/dnn_op_helper.h

@@ -28,7 +28,7 @@ public:
     // FIXME: maybe in-place style deduction works better
     template <typename... TArgs>
     TensorLayout deduce_layout(TArgs&&... args) {
-        static_assert((std::is_convertible_v<TArgs, TensorLayout> && ...));
+        // static_assert((std::is_convertible_v<TArgs, TensorLayout> && ...));
         TensorLayout output_layout;
         m_opr->deduce_layout(args..., output_layout);
         return output_layout;

@@ -36,7 +36,7 @@ public:
     template <typename... TArgs>
     TensorLayout deduce_layout_fallible(TArgs&&... args) {
-        static_assert((std::is_convertible_v<TArgs, TensorLayout> && ...));
+        // static_assert((std::is_convertible_v<TArgs, TensorLayout> && ...));
         TensorLayout output_layout;
         bool success = (args.ndim * ...) > 0;
         if (success) {

@@ -49,7 +49,7 @@ public:
     template <size_t nr_outputs, typename... TArgs>
     std::array<TensorLayout, nr_outputs> deduce_layouts(TArgs&&... args) {
-        static_assert((std::is_convertible_v<TArgs, TensorLayout> && ...));
+        // static_assert((std::is_convertible_v<TArgs, TensorLayout> && ...));
         std::array<TensorLayout, nr_outputs> layouts;
         std::apply(
                 [&](auto&&... outputs) { m_opr->deduce_layout(args..., outputs...); },
imperative/src/impl/ops/concatenate.cpp (new file, 0 → 100644)

#include <climits>
#include "../dnn_op_helper.h"
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/utils/stats.h"

namespace mgb::imperative {

namespace {

template <typename Opr>
CompNode get_device(const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<Opr>();
    const char* op_name = op_def.make_name().c_str();
    CompNode oup_cn = op_def.comp_node;
    if (!oup_cn.valid()) {
        size_t nr_inp = inputs.size();
        mgb_assert(
                nr_inp > 0, "number of inputs of %s should be greater than 0", op_name);
        auto&& inp_cn = inputs[0].comp_node;
        for (size_t i = 1; i < nr_inp; ++i) {
            mgb_assert(
                    inp_cn == inputs[i].comp_node,
                    "input tensors of %s operator should have same device, but get "
                    "%s vs %s",
                    op_name, inp_cn.to_string().c_str(),
                    inputs[i].comp_node.to_string().c_str());
        }
        oup_cn = inp_cn;
    }
    return oup_cn;
}

bool is_all_inputs_valid(const SmallVector<LogicalTensorDesc>& inputs) {
    bool input_valid = true;
    size_t nr_inp = inputs.size();
    for (size_t i = 0; i < nr_inp; ++i) {
        if (inputs[i].layout.ndim == 0) {
            input_valid = false;
            break;
        }
    }
    return input_valid;
}

}  // namespace

namespace concatenate {

TensorLayout concat_layout_deduce(
        const SmallVector<const TensorLayout*> inputs, int axis) {
    // if we use megdnn::Concat::deduce_layout directly, we need construct
    // TensorLayoutArray, which will result in much memory copy
    auto shape_equal_but_specific_axis = [](const TensorShape& lhs,
                                            const TensorShape& rhs, int axis) -> bool {
        if (lhs.ndim != rhs.ndim) {
            return false;
        }
        for (size_t i = 0; i < lhs.ndim; ++i) {
            if (i == axis)
                continue;
            if (lhs.shape[i] != rhs.shape[i])
                return false;
        }
        return true;
    };

    TensorLayout oup_layout = *inputs[0];
    for (size_t i = 1; i < inputs.size(); ++i) {
        mgb_assert(
                shape_equal_but_specific_axis(oup_layout, *inputs[i], axis),
                "Concat input shape mismatch: %s vs %s",
                inputs[0]->to_string().c_str(), inputs[i]->to_string().c_str());
        oup_layout.shape[axis] += inputs[i]->shape[axis];
    }
    oup_layout.init_contiguous_stride();
    return oup_layout;
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Concat&>(def);
    cg::OperatorNodeConfig config{op.comp_node};
    config.name(op.make_name());
    return opr::Concat::make(inputs, op.axis, config);
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<Concat>();
    auto oup_cn = get_device<Concat>(def, inputs);

    if (!is_all_inputs_valid(inputs)) {
        // because dtypepromote_trans, so use inputs[0].dtype as oup_dtype here
        return {{{TensorLayout{inputs[0].layout.dtype}, oup_cn, {}}}, false};
    }

    SmallVector<const TensorLayout*> inputs_holder(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        inputs_holder[i] = &inputs[i].layout;
    }
    int axis = op_def.axis >= 0 ? op_def.axis : op_def.axis + inputs[0].layout.ndim;
    TensorLayout oup_layout = concat_layout_deduce(inputs_holder, axis);
    return {{{oup_layout, oup_cn, {}}}, true};
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op_def = def.cast_final_safe<Concat>();
    int axis = op_def.axis >= 0 ? op_def.axis : op_def.axis + inputs[0]->layout().ndim;

    CompNode& oup_cn = output_descs[0].comp_node;
    if (op_def.comp_node.valid()) {
        mgb_assert(op_def.comp_node == oup_cn, "Concat compnode infer error");
    }

    // prepare inputs and output layout
    TensorLayout& oup_layout = output_descs[0].layout;
    if (!validated) {
        SmallVector<const TensorLayout*> inputs_holder(inputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            inputs_holder[i] = &inputs[i]->layout();
        }
        oup_layout = concat_layout_deduce(inputs_holder, axis);
    }
    auto oup = Tensor::make(oup_layout, oup_cn);

    // because the dnn concat is very slow, we copy the slice code from
    // src/opr/impl/tensor_manip.cpp
    auto&& out = oup->dev_tensor();
    size_t end = 0;
    for (auto&& input : inputs) {
        auto&& in = input->dev_tensor();
        auto begin = end;
        end = begin + in.shape().shape[axis];
        if (!in.layout().is_empty()) {
            out.sub(Slice(begin, end).apply(out.layout(), axis))
                    .copy_from_fixlayout(in);
        }
    }
    return {oup};
}

OP_TRAIT_REG(Concat, Concat)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();

}  // namespace concatenate

namespace stack {

TensorLayout stack_layout_deduce(
        const SmallVector<const TensorLayout*> inputs, int axis) {
    size_t nr_inp = inputs.size();
    auto&& inp_layout0 = *inputs[0];
    for (size_t i = 1; i < nr_inp; ++i) {
        mgb_assert(
                inp_layout0.eq_shape(*inputs[i]),
                "Stack input shape mismatch: %s vs %s",
                inp_layout0.to_string().c_str(), inputs[i]->to_string().c_str());
    }

    TensorLayout oup_layout{TensorShape{inp_layout0}, inp_layout0.dtype};
    oup_layout.add_axis_cont_inplace(axis);
    oup_layout.shape[axis] = nr_inp;
    oup_layout.init_contiguous_stride();
    return oup_layout;
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Stack&>(def);
    cg::OperatorNodeConfig config{op.comp_node};
    config.name(op.make_name());
    using Desc = opr::AxisAddRemove::AxisDesc;
    std::vector<Desc> param{Desc::make_add(op.axis)};
    VarNodeArray expanded_inputs;
    for (auto&& inp : inputs) {
        expanded_inputs.emplace_back(
                opr::AxisAddRemove::make(inp, param, cg::OperatorNodeConfig{}).node());
    }
    return opr::Concat::make(expanded_inputs, op.axis, config);
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<Stack>();
    auto oup_cn = get_device<Stack>(def, inputs);

    if (!is_all_inputs_valid(inputs)) {
        // because dtypepromote_trans, so use inputs[0].dtype as oup_dtype here
        return {{{TensorLayout{inputs[0].layout.dtype}, oup_cn, {}}}, false};
    }

    SmallVector<const TensorLayout*> inputs_holder(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        inputs_holder[i] = &inputs[i].layout;
    }
    int axis =
            op_def.axis >= 0 ? op_def.axis : op_def.axis + inputs[0].layout.ndim + 1;
    TensorLayout oup_layout = stack_layout_deduce(inputs_holder, axis);
    return {{{oup_layout, oup_cn, {}}}, true};
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op_def = def.cast_final_safe<Stack>();
    size_t nr_inp = inputs.size();
    TensorLayout inp_layout = inputs[0]->layout();
    int axis =
            op_def.axis >= 0 ? op_def.axis : op_def.axis + inputs[0]->layout().ndim + 1;

    CompNode& oup_cn = output_descs[0].comp_node;
    if (op_def.comp_node.valid()) {
        mgb_assert(op_def.comp_node == oup_cn, "Stack compnode infer error");
    }

    // prepare inputs and output layout
    TensorLayout& oup_layout = output_descs[0].layout;
    if (!validated) {
        SmallVector<const TensorLayout*> inputs_holder(inputs.size());
        for (size_t i = 0; i < nr_inp; ++i) {
            inputs_holder[i] = &inputs[i]->layout();
        }
        oup_layout = stack_layout_deduce(inputs_holder, axis);
    }
    inp_layout.add_axis_cont_inplace(axis);
    SmallVector<TensorPtr> expanded;
    for (size_t i = 0; i < nr_inp; ++i) {
        expanded.push_back(
                Tensor::make(inputs[i]->blob(), inputs[i]->offset(), inp_layout));
    }
    auto oup = Tensor::make(oup_layout, oup_cn);

    // because the dnn concat is very slow, we copy the slice code from
    // src/opr/impl/tensor_manip.cpp
    auto&& out = oup->dev_tensor();
    size_t end = 0;
    for (auto&& input : expanded) {
        auto&& in = input->dev_tensor();
        auto begin = end;
        end = begin + in.shape().shape[axis];
        if (!in.layout().is_empty()) {
            out.sub(Slice(begin, end).apply(out.layout(), axis))
                    .copy_from_fixlayout(in);
        }
    }
    return {oup};
}

OP_TRAIT_REG(Stack, Stack)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();

}  // namespace stack

}  // namespace mgb::imperative
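To make the layout-deduction rules above easy to check at a glance, here is a plain-Python paraphrase (an illustration only, not part of the commit; the helper names are made up): concat requires all shapes to match except along the concat axis, whose lengths are summed; stack requires identical shapes and inserts a new axis whose length is the number of inputs.

# Hypothetical helpers mirroring concat_layout_deduce / stack_layout_deduce.
def concat_shape_deduce(shapes, axis):
    out = list(shapes[0])
    for s in shapes[1:]:
        assert len(s) == len(out)
        assert all(s[i] == out[i] for i in range(len(s)) if i != axis)
        out[axis] += s[axis]          # sum the lengths along the concat axis
    return tuple(out)

def stack_shape_deduce(shapes, axis):
    for s in shapes[1:]:
        assert tuple(s) == tuple(shapes[0])
    out = list(shapes[0])
    out.insert(axis, len(shapes))     # new axis of length nr_inp
    return tuple(out)

print(concat_shape_deduce([(2, 3), (4, 3)], axis=0))  # (6, 3)
print(stack_shape_deduce([(2, 3), (2, 3)], axis=0))   # (2, 2, 3)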
imperative/src/impl/ops/specializations.cpp

@@ -384,18 +384,6 @@ OP_TRAIT_REG(TypeCvt, TypeCvt).apply_on_var_node(apply_on_var_node).fallback();
 }  // namespace typecvt
 }  // namespace
 
-namespace {
-namespace concat {
-auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
-    auto&& op = static_cast<const Concat&>(def);
-    cg::OperatorNodeConfig config{op.comp_node};
-    config.name(op.make_name());
-    return opr::Concat::make(inputs, op.axis, config);
-}
-OP_TRAIT_REG(Concat, Concat).apply_on_var_node(apply_on_var_node).fallback();
-}  // namespace concat
-}  // namespace
-
 namespace {
 namespace copy {
 auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
imperative/src/impl/proxy_graph/proxy_graph.cpp

@@ -53,7 +53,7 @@ SmallVector<LayoutConstraintCallback> get_input_layout_constraint(
     VarNodeArray vinputs(inputs.size());
     for (size_t i = 0; i < inputs.size(); ++i) {
         OperatorNodeConfig config;
-        auto&& layout = inputs[i]->layout();
+        auto layout = inputs[i]->layout();
         layout.init_contiguous_stride();
         vinputs[i] = graph->insert_opr(std::make_unique<mgb::opr::SharedDeviceTensor>(
                 *graph,
imperative/src/impl/transformations/dtype_promote.cpp

@@ -391,6 +391,7 @@ struct DTypePromoteRuleRegistry {
         register_dtype_promote_rule<Elemwise>(elemwise_rule);
         register_dtype_promote_rule<ElemwiseMultiType>(elemwise_multi_type_rule);
         register_dtype_promote_rule<Concat>(naive_promote_rule);
+        register_dtype_promote_rule<Stack>(naive_promote_rule);
         register_dtype_promote_rule<GroupLocal>(naive_promote_rule);
         register_dtype_promote_rule<Reduce>(reduce_rule);
         register_dtype_promote_rule<Convolution>(convolution_rule);
imperative/src/include/megbrain/imperative/physical_tensor.h

@@ -133,7 +133,7 @@ public:
     DType dtype() const { return m_dtype; }
-    TensorLayout layout() const { return m_layout; }
+    const TensorLayout& layout() const { return m_layout; }
     const TensorShape& shape() const { return m_shape; }
imperative/tablegen/generated/hash.txt

 8dd504f360fd3d3bfb560c970b568153 ../../dnn/scripts/opr_param_defs.py
-06e8a3af239b545470b38b3e82960935 ../../src/core/include/megbrain/ir/ops.td
+7d6df1c8e50a22ef2c36b7ea89daa9c5 ../../src/core/include/megbrain/ir/ops.td
-7f37497cffb24554073cbc42b89e2db8 generated/opdef.h.inl
+f30ae9494b4bf3363cd74d9396acaf49 generated/opdef.h.inl
-1e2041f6374e48d53762ddfe7a6ebca3 generated/opdef.cpp.inl
+cb27f486b28a099221f38c6fcaa06a44 generated/opdef.cpp.inl
-9a813355a742330e9ba6e5c14ea67c7c generated/opdef.py.inl
+adb758acd1147f213db7f0cb1b708773 generated/opdef.py.inl
-8d4ae7fef8234d8c79ac52017f4710e3 generated/opdef.cpy.inl
+30ad8e75a5994edf9ec46387c6285312 generated/opdef.cpy.inl
 71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h
imperative/tablegen/generated/opdef.cpp.inl

@@ -6993,6 +6993,46 @@ OP_TRAIT_REG(Split, Split)
    .props(Split_props_impl)
    .make_name(Split_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Stack);
namespace {
size_t Stack_hash_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<Stack>();
static_cast<void>(op_);
size_t val = mgb::hash(op_.dyn_typeinfo());
val = mgb::hash_pair_combine(val, mgb::hash(op_.axis));
val = mgb::hash_pair_combine(val, mgb::hash(op_.comp_node));
return val;
}
bool Stack_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
auto &&a_ = lhs_.cast_final_safe<Stack>(),
&&b_ = rhs_.cast_final_safe<Stack>();
static_cast<void>(a_);
static_cast<void>(b_);
if (a_.axis != b_.axis) return false;
if (a_.comp_node != b_.comp_node) return false;
return true;
}
std::vector<std::pair<const char*, std::string>> Stack_props_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<Stack>();
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
props_.emplace_back("axis", std::to_string(op_.axis));
props_.emplace_back("comp_node", op_.comp_node.to_string());
return props_;
}
std::string Stack_make_name_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<Stack>();
static_cast<void>(op_);
return "Stack";
}
} // anonymous namespace
OP_TRAIT_REG(Stack, Stack)
.hash(Stack_hash_impl)
.is_same_st(Stack_is_same_st_impl)
.props(Stack_props_impl)
.make_name(Stack_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Subtensor);
namespace {
imperative/tablegen/generated/opdef.cpy.inl

@@ -20376,6 +20376,133 @@ void _init_py_Split(py::module m) {
    mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Split::typeinfo(), &py_type).second);
}
PyOpDefBegin(Stack) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
static PyObject* getstate(PyObject* self, PyObject*) {
auto& opdef = reinterpret_cast<PyOp(Stack)*>(self)->inst();
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {
{"axis", serialization<decltype(opdef.axis)>::dump(opdef.axis)},
{"comp_node", serialization<decltype(opdef.comp_node)>::dump(opdef.comp_node)}
};
return py::cast(state).release().ptr();
}
static PyObject* setstate(PyObject* self, PyObject* args) {
PyObject* dict = PyTuple_GetItem(args, 0);
if (!dict) return NULL;
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
auto& opdef = reinterpret_cast<PyOp(Stack)*>(self)->inst();
static_cast<void>(opdef);
{
auto&& iter = state.find("axis");
if (iter != state.end()) {
opdef.axis = serialization<decltype(opdef.axis)>::load(iter->second);
}
}
{
auto&& iter = state.find("comp_node");
if (iter != state.end()) {
opdef.comp_node = serialization<decltype(opdef.comp_node)>::load(iter->second);
}
}
Py_RETURN_NONE;
}
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds);
static PyMethodDef py_init_methoddef;
// };
PyOpDefEnd(Stack)
int PyOp(Stack)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
static const char* kwlist[] = {"axis", "comp_node", "scope", NULL};
PyObject *axis = NULL, *comp_node = NULL, *scope = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOO", const_cast<char**>(kwlist), &axis, &comp_node, &scope))
return -1;
if (axis) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(Stack)*>(self)->inst().axis =
py::cast<decltype(Stack::axis)>(py::handle(axis));
} CATCH_ALL(-1)
}
if (comp_node) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(Stack)*>(self)->inst().comp_node =
py::cast<decltype(Stack::comp_node)>(py::handle(comp_node));
} CATCH_ALL(-1)
}
if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(py::cast<std::string>(py::handle(scope)));
} CATCH_ALL(-1)
}
return 0;
}
PyGetSetDef PyOp(Stack)::py_getsetters[] = {
{const_cast<char*>("axis"), py_get_generic(Stack, axis), py_set_generic(Stack, axis), const_cast<char*>("axis"), NULL},
{const_cast<char*>("comp_node"), py_get_generic(Stack, comp_node), py_set_generic(Stack, comp_node), const_cast<char*>("comp_node"), NULL},
{NULL} /* Sentinel */
};
PyMethodDef PyOp(Stack)::tp_methods[] = {
{const_cast<char*>("__getstate__"), PyOp(Stack)::getstate, METH_NOARGS, "Stack getstate"},
{const_cast<char*>("__setstate__"), PyOp(Stack)::setstate, METH_VARARGS, "Stack setstate"},
{NULL} /* Sentinel */
};
PyObject *PyOp(Stack)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) {
if (PyOp(Stack)::py_init(self, args, kwds) < 0) {
return NULL;
}
Py_RETURN_NONE;
}
PyMethodDef PyOp(Stack)::py_init_methoddef = {
"__init__",
(PyCFunction)PyOp(Stack)::py_init_proxy,
METH_VARARGS | METH_KEYWORDS,
"__init__(self, axis: int = ..., comp_node: str = ...) -> None\n"
};
void _init_py_Stack(py::module m) {
using py_op = PyOp(Stack);
auto& py_type = PyOpType(Stack);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.Stack";
py_type.tp_basicsize = sizeof(PyOp(Stack));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "Stack";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
py_type.tp_dict = PyDict_New();
PyObject* descr = PyDescr_NewMethod(&PyOpType(Stack), &PyOp(Stack)::py_init_methoddef);
PyDict_SetItemString(py_type.tp_dict, "__init__", descr);
mgb_assert(PyType_Ready(&py_type) >= 0);
PyType_Modified(&py_type);
m.add_object("Stack", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Stack::typeinfo(), &py_type).second);
}
PyOpDefBegin(Subtensor) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];

@@ -22064,6 +22191,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) {
_init_py_SlidingWindowTranspose(m); \
_init_py_Softmax(m); \
_init_py_Split(m); \
_init_py_Stack(m); \
_init_py_Subtensor(m); \
_init_py_TQT(m); \
_init_py_TensorRTRuntime(m); \
imperative/tablegen/generated/opdef.h.inl

@@ -1819,6 +1819,20 @@ public:
    }
};
class Stack : public OpDefImplBase<Stack> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
int32_t axis = 0;
::mgb::CompNode comp_node;
Stack() = default;
Stack(int32_t axis_, ::mgb::CompNode comp_node_, std::string scope_ = {}): axis(axis_), comp_node(comp_node_) { set_scope(scope_); }
Stack(::megdnn::param::Axis packed_param_0, ::mgb::CompNode comp_node_): axis(packed_param_0.axis), comp_node(comp_node_) {}
::megdnn::param::Axis param() const {
return {axis};
}
};
class Subtensor : public OpDefImplBase<Subtensor> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;
imperative/tablegen/generated/opdef.py.inl

@@ -1896,6 +1896,14 @@ SplitInst
    .def_readwrite("axis", &Split::axis)
    .def_readwrite("nsections", &Split::nsections);
py::class_<Stack, std::shared_ptr<Stack>, OpDef> StackInst(m, "Stack");
StackInst
.def(py::init<int32_t, ::mgb::CompNode, std::string>(), py::arg("axis") = 0, py::arg("comp_node"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("axis", &Stack::axis)
.def_readwrite("comp_node", &Stack::comp_node);
py::class_<Subtensor, std::shared_ptr<Subtensor>, OpDef> SubtensorInst(m, "Subtensor");
SubtensorInst
src/core/include/megbrain/ir/ops.td

@@ -296,6 +296,12 @@ def Concat: MgbHashableOp<"Concat", [AxisParam]> {
    );
}
def Stack: MgbHashableOp<"Stack", [AxisParam]> {
let extraArguments = (ins
MgbCompNodeAttr:$comp_node
);
}
def Broadcast : MgbHashableOp<"Broadcast", [EmptyParam]> {
    let extraArguments = (ins
        MgbArrayAttr<MgbI32Attr>:$shape
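Taken together, the tablegen changes above generate a Python-visible megengine.core._imperative_rt.ops.Stack class (re-exported through megengine.core.ops.builtin), which is what the rewritten F.stack applies. A hedged sketch of driving the new op directly; the device string "xpux" and the tensors are examples only:

# Constructing the new Stack OpDef via the generated binding and applying it,
# mirroring what F.stack now does internally.
import numpy as np
from megengine import Tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin

a = Tensor(np.ones((2, 3), dtype="float32"))
b = Tensor(np.zeros((2, 3), dtype="float32"))

op = builtin.Stack(axis=0, comp_node="xpux")  # axis and comp_node come from ops.td
(result,) = apply(op, a, b)
print(result.shape)  # (2, 2, 3)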