MegEngine / MegEngine

Commit 28c6ebfe, authored on Nov 21, 2022 by Megvii Engine Team
perf(imperative): speed up subtensor
GitOrigin-RevId: c3d94bfde8f4d3c7e2efc46af7e17a255ee01785
Parent: a1cbd9bb
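In short: `Subtensor` now carries the static begin/end/step/index values in a new `slice_items` attribute, so indexing with plain Python ints and slices no longer has to materialize one scalar tensor per index. `_getitem_cpp` routes such cases through a new `_fastpath_getitem_cpp`, and the new `imperative/src/impl/ops/subtensor.cpp` deduces the output layout on the host and returns a memory-forwarded view. A rough illustration of the kind of indexing this targets (illustrative values, written against the public MegEngine Python API, not code from this commit):

    import numpy as np
    import megengine as mge

    x = mge.tensor(np.arange(4 * 8 * 8, dtype="float32").reshape(4, 8, 8))

    y = x[1:3, ..., 2]         # ints / slices / Ellipsis / None: eligible for the Subtensor fast path
    z = x[mge.tensor([0, 2])]  # tensor (fancy) index: still handled by IndexingMultiAxisVec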
Showing 13 changed files with 487 additions and 30 deletions (+487, -30)
imperative/python/src/grad_override.cpp (+31, -1)
imperative/python/src/tensor_utils.cpp (+139, -11)
imperative/src/impl/ops/broadcast.cpp (+3, -0)
imperative/src/impl/ops/indexing.cpp (+9, -2)
imperative/src/impl/ops/specializations.cpp (+0, -1)
imperative/src/impl/ops/subtensor.cpp (+218, -0)
imperative/src/impl/transformations/format.cpp (+42, -1)
imperative/tablegen/generated/hash.txt (+5, -5)
imperative/tablegen/generated/opdef.cpp.inl (+3, -0)
imperative/tablegen/generated/opdef.cpy.inl (+23, -5)
imperative/tablegen/generated/opdef.h.inl (+2, -1)
imperative/tablegen/generated/opdef.py.inl (+3, -2)
src/core/include/megbrain/ir/ops.td (+9, -1)
imperative/python/src/grad_override.cpp
@@ -325,6 +325,35 @@ std::optional<ValueRefList> subtensor_grad_rule(
            inputs2.push_back(inputs[i]);
        }
    }
    CompNodeValue::ref_t device = inputs[0].device();
    auto get_subtensor_index = [&](int idx) {
        HostTensorStorage storage(*device);
        storage.ensure_size(dtype::Int32().size());
        auto* ptr = reinterpret_cast<dt_int32*>(storage.ptr());
        ptr[0] = idx;
        return imperative::apply(
                CreateTensor(CreateTensor::Unique, *device, dtype::Int32(), ValueShape({1})),
                HostStorage::make(storage))[0];
    };
    auto slice_items = subtensor.slice_items;
    auto items = subtensor.items;
    for (int i = 0; i < slice_items.size(); i++) {
        auto&& [axis, b_flag, e_flag, s_flag, idx_flag] = items[i];
        auto&& [b_val, e_val, s_val, ax_val] = slice_items[i];
        if (b_flag) {
            inputs2.push_back(get_subtensor_index(b_val));
        };
        if (e_flag) {
            inputs2.push_back(get_subtensor_index(e_val));
        };
        if (s_flag) {
            inputs2.push_back(get_subtensor_index(s_val));
        };
        if (idx_flag) {
            inputs2.push_back(get_subtensor_index(ax_val));
        };
    };
    auto maker = CustomGradMaker(backward, inputs.size());
    maker.output_size(1).output_captured(0, false);
    maker.backward([inputs = std::move(inputs2),
...

@@ -647,8 +676,9 @@ std::optional<ValueRefList> warp_affine_grad_rule(
    ret[1] = imperative::apply(*grad_op, args_)[0];
    std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
    std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items;
    items.push_back(std::make_tuple(1, true, true, false, false));
-   auto&& subtensor = Subtensor::make(items);
+   auto&& subtensor = Subtensor::make(items, slice_items);
    CompNodeValue::ref_t device = inputs[0].device();
    DTypeValue::ref_t dtype = inputs[0].dtype();
...
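The reworked `subtensor_grad_rule` above now also materializes the begin/end/step/index scalars on the device from the op's recorded `slice_items` (via `get_subtensor_index`), so the fast-path op, which has no index tensor inputs, can still be differentiated. Conceptually the gradient of a subtensor read is a subtensor write of the output gradient into zeros at the same indices; a small NumPy analogy (not MegEngine internals, values illustrative):

    import numpy as np

    x = np.arange(12, dtype="float32").reshape(3, 4)
    dy = np.ones((2, 2), dtype="float32")  # grad w.r.t. y = x[0:2, 1:3]

    dx = np.zeros_like(x)
    dx[0:2, 1:3] = dy                      # SetSubtensor-style scatter needs the same begin/end/step values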
imperative/python/src/tensor_utils.cpp
@@ -781,14 +781,8 @@ std::pair<size_t, bool> get_ndim_safe(py::handle tensor) {
    }
}

-py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) {
+py::tuple _unpack_indexes(py::handle inp_hdl, py::tuple tuple_val) {
    py::object inp = py::reinterpret_borrow<py::object>(inp_hdl);
-   py::tuple tuple_val;
-   if (py::isinstance<py::tuple>(idx_hdl)) {
-       tuple_val = py::reinterpret_borrow<py::tuple>(idx_hdl);
-   } else {
-       tuple_val = py::make_tuple(idx_hdl);
-   }
    bool use_subtensor = true;
    bool need_remove_ellipsis = false;
...

@@ -939,6 +933,20 @@ bool enable_fastpath(py::handle inp) {
    return true;
}

bool subtensor_fastpath(py::handle inp_hdl, py::tuple tuple_val) {
    bool use_fastpath = true;
    for (size_t i = 0; i < tuple_val.size(); ++i) {
        PyObject* obj = tuple_val[i].ptr();
        if ((!is_scalar(obj) && !PySlice_Check(obj) && obj != Py_Ellipsis &&
             obj != Py_None) ||
            (PyObject_TypeCheck(obj, py_varnode_type))) {
            use_fastpath = false;
            break;
        }
    }
    return use_fastpath && enable_fastpath(inp_hdl);
}

py::object _broadcast_cpp(py::handle input, py::handle args) {
    py::object shape = _expand_args(args);
    py::list dims;
...

@@ -1128,16 +1136,129 @@ py::object _adaptive_pool2d_cpp(
    return ret[0];
}

py::object _fastpath_getitem_cpp(py::handle inp_hdl, py::tuple tuple_val) {
    py::object inp = py::reinterpret_borrow<py::object>(inp_hdl);
    int ax = 0;
    bool use_ellipsis = false;
    size_t special_dim = 0;
    for (size_t i = 0; i < tuple_val.size(); ++i) {
        PyObject* obj = tuple_val[i].ptr();
        if (obj == Py_Ellipsis) {
            use_ellipsis = true;
            for (size_t j = i + 1; j < tuple_val.size(); j++) {
                PyObject* obj_last = tuple_val[j].ptr();
                if (obj_last == Py_Ellipsis) {
                    throw py::index_error("only one ellipsis is allowed.");
                }
            }
        }
        if (obj != Py_None && obj != Py_Ellipsis && obj != Py_True && obj != Py_False) {
            special_dim++;
        }
    }
    size_t ndim = 0;
    try {
        ndim = getattr(inp_hdl, "ndim").cast<size_t>();
    } catch (py::error_already_set& err) {
        if (use_ellipsis) {
            throw py::index_error("does not support Ellipsis when tensor's ndim is unknown.");
        };
    }
    std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
    std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items;
    std::vector<int32_t> expand_items;
    for (size_t i = 0; i < tuple_val.size(); ++i) {
        py::object t = tuple_val[i];
        if (t.ptr() == Py_Ellipsis) {
            ax += ndim - special_dim;
        } else if (PySlice_Check(t.ptr())) {
            PySliceObject* s = (PySliceObject*)t.ptr();
            std::vector<int> items;
            std::vector<bool> idx_items;
            auto push = [&](PyObject* v, int default_value) {
                if (v == Py_None) {
                    items.push_back(default_value);
                    idx_items.push_back(false);
                } else {
                    auto obj = py::reinterpret_borrow<py::object>(v);
                    items.push_back(obj.cast<int>());
                    idx_items.push_back(true);
                }
            };
            push(s->start, INT_MIN);
            push(s->stop, INT_MAX);
            push(s->step, INT_MAX);
            if (idx_items[0] || idx_items[1] || idx_items[2]) {
                cpp_items.push_back({ax, idx_items[0], idx_items[1], idx_items[2], false});
                slice_items.push_back({items[0], items[1], items[2], INT_MAX});
            }
            ax += 1;
        } else if (PyLong_Check(t.ptr()) && !PyBool_Check(t.ptr())) {
            cpp_items.push_back({ax, false, false, false, true});
            slice_items.push_back({INT_MIN, INT_MAX, INT_MAX, t.cast<int>()});
            ax += 1;
        } else if (PyBool_Check(t.ptr())) {
            expand_items.push_back(ax);
        } else if (t.ptr() == Py_None) {
            expand_items.push_back(ax);
            ax += 1;
        } else if (is_scalar(t.ptr())) {
            cpp_items.push_back({ax, false, false, false, true});
            slice_items.push_back({INT_MIN, INT_MAX, INT_MAX, t.cast<int>()});
            ax += 1;
        } else {
            throw py::value_error("fast path subtensor index not impl");
        }
    }
    if (expand_items.size()) {
        std::shared_ptr<OpDef> op = AddAxis::make(expand_items);
        py::object Op = py::cast(op);
        PyObject* p[2] = {Op.ptr(), inp.ptr()};
        py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
        inp = ret[0];
    }
    std::shared_ptr<OpDef> op;
    op = Subtensor::make(cpp_items, slice_items);
    std::vector<PyObject*> p;
    p.resize(2);
    py::object Op = py::cast(op);
    p[0] = Op.ptr();
    p[1] = inp.ptr();
    py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
    return ret[0];
}

py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
    py::tuple try_res = _try_cond_take(inp_hdl, idx_hdl);
    if (try_res.size() == 2) {
        return try_res[0];
    }
-   py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
+   py::tuple tuple_val;
+   if (py::isinstance<py::tuple>(idx_hdl)) {
+       tuple_val = py::reinterpret_borrow<py::tuple>(idx_hdl);
+   } else {
+       tuple_val = py::make_tuple(idx_hdl);
+   }
+   if (subtensor_fastpath(inp_hdl, tuple_val)) {
+       return _fastpath_getitem_cpp(inp_hdl, tuple_val);
+   }
+   py::tuple up = _unpack_indexes(inp_hdl, tuple_val);
    py::object tensor = py::reinterpret_borrow<py::object>(up[0]);
    py::list tensors = py::reinterpret_borrow<py::list>(up[1]);
    py::list py_items = py::reinterpret_borrow<py::list>(up[2]);
    std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
+   std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items;
    for (size_t i = 0; i < py_items.size(); ++i) {
        py::list item = py::reinterpret_borrow<py::list>(py_items[i]);
        cpp_items.push_back(
...

@@ -1146,7 +1267,7 @@ py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
    }
    std::shared_ptr<OpDef> op;
    if (up[3].cast<bool>()) {
-       op = Subtensor::make(cpp_items);
+       op = Subtensor::make(cpp_items, slice_items);
    } else {
        op = IndexingMultiAxisVec::make(cpp_items);
    }
...

@@ -1170,11 +1291,18 @@ py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_h
        val = _Const(val_hdl, getattr(inp_hdl, "dtype"), getattr(inp_hdl, "device"));
    }
-   py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
+   py::tuple tuple_val;
+   if (py::isinstance<py::tuple>(idx_hdl)) {
+       tuple_val = py::reinterpret_borrow<py::tuple>(idx_hdl);
+   } else {
+       tuple_val = py::make_tuple(idx_hdl);
+   }
+   py::tuple up = _unpack_indexes(inp_hdl, tuple_val);
    py::object tensor = py::reinterpret_borrow<py::object>(up[0]);
    py::list tensors = py::reinterpret_borrow<py::list>(up[1]);
    py::list py_items = py::reinterpret_borrow<py::list>(up[2]);
    std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
+   std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items;
    for (size_t i = 0; i < py_items.size(); ++i) {
        py::list item = py::reinterpret_borrow<py::list>(py_items[i]);
        cpp_items.push_back(
...

@@ -1183,7 +1311,7 @@ py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_h
    }
    std::shared_ptr<OpDef> op, set_op;
    if (up[3].cast<bool>()) {
-       op = Subtensor::make(cpp_items);
+       op = Subtensor::make(cpp_items, slice_items);
    } else {
        op = IndexingMultiAxisVec::make(cpp_items);
    }
...
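For reference, `_fastpath_getitem_cpp` encodes each indexed axis as a flag tuple in `items` plus a value tuple in `slice_items`, using INT_MIN/INT_MAX as "not given" sentinels. A rough Python rendering (my own reconstruction, not code from the repository) of what `x[1:5:2, 3]` turns into:

    INT_MIN, INT_MAX = -2**31, 2**31 - 1

    items = [
        (0, True, True, True, False),    # axis 0: begin/end/step present, not a single index
        (1, False, False, False, True),  # axis 1: a single integer index
    ]
    slice_items = [
        (1, 5, 2, INT_MAX),              # begin=1, end=5, step=2, index unused
        (INT_MIN, INT_MAX, INT_MAX, 3),  # slice fields unused, index=3
    ]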
imperative/src/impl/ops/broadcast.cpp
@@ -208,6 +208,9 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
        cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
    }
    if (tshp.is_empty()) {
        return {Tensor::make(TensorLayout(tshp, src->dtype()), src->comp_node())};
    }
    TensorLayout tlayout = slayout.broadcast(tshp);
    // memory forward
    return {Tensor::make(src->blob(), src->offset(), tlayout)};
...
imperative/src/impl/ops/indexing.cpp
@@ -103,8 +103,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
    if (!src.ndim) {
        return {{{{{}, src.dtype}, comp_node}}, false};
    }
-   mgb_assert(src.is_contiguous(), "src should be contiguous");
-   return {{{src, comp_node}}, true};
+   return {{{{src, src.dtype}, comp_node}}, true};
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
...

@@ -138,10 +137,18 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
    dnn_op.exec_with_ws(out, index, sub);
    return {out};
}

SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
    layout_checker[0] = layout_checker[1] = layout_checker[2] =
            [](const TensorLayout& layout) { return layout.is_contiguous(); };
    return layout_checker;
}

OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_var_node(apply_on_var_node)
        .get_input_layout_constraint(get_input_layout_constraint)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}  // namespace indexing_set_one_hot
...
imperative/src/impl/ops/specializations.cpp
@@ -542,7 +542,6 @@ auto get_index(
    OP_TRAIT_REG(NAME, NAME).apply_on_var_node(apply_on_var_node).fallback(); \
}

-FANCY_INDEXING_IMPL(Subtensor, 1)
FANCY_INDEXING_IMPL(SetSubtensor, 2)
FANCY_INDEXING_IMPL(IncrSubtensor, 2)
FANCY_INDEXING_IMPL(IndexingMultiAxisVec, 1)
...
imperative/src/impl/ops/subtensor.cpp
0 → 100644
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/internal/indexing_helper.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/tensor.h"

#include "../algo_chooser.h"
#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"

using namespace mgb::opr::indexing;

namespace mgb::imperative {
namespace {
namespace subtensor {

auto get_index(
        const VarNodeArray& inputs,
        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask,
        const std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>>& slice) {
    size_t length = mask.size();
    auto graph = inputs[0]->owner_graph();
    auto comp_node = inputs[0]->comp_node();
    opr::Subtensor::IndexDesc ret(length);
    auto immutable_node = [&](int val) {
        DTypeScalar scalar = DTypeScalar(static_cast<megdnn::dt_int32>(val));
        return opr::ImmutableTensor::make(*graph, scalar, {comp_node});
    };
    for (size_t i = 0; i < length; ++i) {
        auto&& [axis, b_flag, e_flag, s_flag, idx_flag] = mask[i];
        auto&& [b_val, e_val, s_val, ax_val] = slice[i];
        ret[i].axis = axis;
        if (idx_flag) {
            ret[i].idx = immutable_node(ax_val);
        } else {
            if (b_flag) {
                ret[i].begin = immutable_node(b_val);
            }
            if (e_flag) {
                ret[i].end = immutable_node(e_val);
            }
            if (s_flag) {
                ret[i].step = immutable_node(s_val);
            }
        }
    }
    return ret;
}

auto origin_get_index(
        const VarNodeArray& inputs, size_t vidx,
        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) {
    size_t length = mask.size();
    opr::Subtensor::IndexDesc ret(length);
    for (size_t i = 0; i < length; ++i) {
        auto&& [axis, begin, end, step, idx] = mask[i];
        ret[i].axis = axis;
        if (idx) {
            ret[i].idx = inputs[vidx++];
        } else {
            mgb_assert(begin || end || step);
            if (begin)
                ret[i].begin = inputs[vidx++];
            if (end)
                ret[i].end = inputs[vidx++];
            if (step)
                ret[i].step = inputs[vidx++];
        }
    }
    mgb_assert(vidx == inputs.size());
    return ret;
}

TensorLayout deduce_layout(
        TensorLayout src,
        std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items,
        std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items) {
    auto mod_size = [](int v, int size_ax) -> int {
        if (size_ax == 0)
            return 0;
        return v < 0 ? v + size_ax : v;
    };
#define CHECK(cond) \
    mgb_assert(cond, "index out of bound: layout=%s", src.to_string().c_str())
    for (int i = items.size() - 1; i >= 0; i--) {
        auto&& [axis, b_flag, e_flag, s_flag, idx_flag] = items[i];
        auto&& [b_val, e_val, s_val, ax_val] = slice_items[i];
        int shape_axis = src.shape[axis];
        int slice_step = s_val == INT_MAX ? 1 : s_val;
        int slice_start = b_val == INT_MIN ? 0 : b_val;
        int slice_stop = e_val == INT_MAX ? shape_axis : e_val;
        if (slice_step > 0) {
            slice_start = mod_size(slice_start, shape_axis);
            slice_stop = mod_size(slice_stop, shape_axis);
            slice_stop = std::min(slice_stop, shape_axis);
            slice_start = std::min(slice_start, slice_stop);
            CHECK(slice_start >= 0 && slice_stop >= slice_start &&
                  slice_stop <= shape_axis);
        } else {
            slice_start = s_val == INT_MIN ? shape_axis - 1 : b_val;
            slice_start = mod_size(slice_start, shape_axis);
            slice_stop = e_val == INT_MAX ? -1 : mod_size(e_val, shape_axis);
            slice_start = std::min(slice_start, std::max(shape_axis - 1, 0));
            slice_stop = std::min(slice_stop, slice_start);
            CHECK(slice_step < 0 && slice_start >= 0 && slice_stop <= slice_start &&
                  slice_start < shape_axis && slice_stop >= -1);
        }
        int abs_step = std::abs(slice_step);
        if (axis < 0) {
            axis = axis + src.ndim;
        };
        if (idx_flag == true) {
            if (src.ndim == 1) {
                src.shape[0] = 1;
            } else {
                src.remove_axis_inplace(axis);
            }
        } else {
            src.shape[axis] =
                    (std::abs(slice_stop - slice_start) + abs_step - 1) / abs_step;
            src.stride[axis] *= slice_step;
        }
    }
    return src;
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Subtensor&>(def);
    OperatorNodeConfig config{op.make_name()};
    if (inputs.size() > 1) {
        return opr::Subtensor::make(
                inputs[0], origin_get_index(inputs, 1, op.items), config);
    } else {
        return opr::Subtensor::make(
                inputs[0], get_index(inputs, op.items, op.slice_items), config);
    }
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    if (inputs.size() >= 2) {
        return proxy_graph_detail::infer_output_attrs_fallible(def, inputs);
    }
    auto&& inp = inputs[0];
    auto& inp_cn = inp.comp_node;
    if (inp.layout.ndim == 0) {
        return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}}}, false};
    }
    auto&& op = static_cast<const Subtensor&>(def);
    auto items = op.items;
    auto slice_itmes = op.slice_items;
    TensorLayout out_layout = deduce_layout(inp.layout, items, slice_itmes);
    return {{{out_layout, inp_cn, {}}}, true};
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    CompNode cn = inputs[0]->comp_node();
    auto&& layout = inputs[0]->layout();
    auto&& op = static_cast<const Subtensor&>(def);
    if (inputs.size() > 1) {
        return proxy_graph_detail::apply_on_physical_tensor(
                def, inputs, output_descs, validated);
    }
    auto&& src = inputs[0];
    auto slice_items = op.slice_items;
    auto items = op.items;
    TensorLayout res_layout = deduce_layout(layout, items, slice_items);
    if (res_layout.is_empty()) {
        return {Tensor::make(res_layout, cn)};
    }
    size_t offset = 0;
    size_t dtype_size = layout.dtype.size();
    TensorPtr tensor = src;
    for (int i = items.size() - 1; i >= 0; i--) {
        auto&& [axis, b_flag, e_flag, s_flag, idx_flag] = items[i];
        auto&& [b_val, e_val, s_val, ax_val] = slice_items[i];
        int start = b_val;
        if (idx_flag) {
            ax_val = ax_val < 0 ? layout.shape[axis] + ax_val : ax_val;
            offset += ax_val * layout.stride[axis] * dtype_size;
        } else {
            start = std::max(start, 0);
            offset += start * layout.stride[axis] * dtype_size;
        }
    }
    // memory forward
    return {Tensor::make(src->blob(), src->offset() + offset, res_layout)};
}

SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
    return layout_checker;
}

OP_TRAIT_REG(Subtensor, Subtensor, opr::Subtensor)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .get_input_layout_constraint(get_input_layout_constraint)
        .fallback();
}  // namespace subtensor
}  // namespace
}  // namespace mgb::imperative
\ No newline at end of file
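The core of the new kernel is that `deduce_layout` plus `apply_on_physical_tensor` turn a static slice into pure shape/stride/offset arithmetic, so the result is a strided view of the source blob rather than a copied tensor. A minimal sketch of that arithmetic for a positive-step slice (plain Python, not the MegEngine implementation, values illustrative):

    def slice_axis(shape_ax, stride_ax, start, stop, step):
        # wrap negative indices, clamp, then compute view length / stride / element offset
        start = start + shape_ax if start < 0 else start
        stop = stop + shape_ax if stop < 0 else stop
        stop = min(stop, shape_ax)
        start = min(start, stop)
        new_len = (abs(stop - start) + abs(step) - 1) // abs(step)
        return new_len, stride_ax * step, start * stride_ax

    # an axis of length 8 with unit stride, sliced 1:7:2 -> length 3, stride 2, offset 1 element
    assert slice_axis(8, 1, 1, 7, 2) == (3, 2, 1)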
imperative/src/impl/transformations/format.cpp
@@ -306,6 +306,19 @@ inline bool is_reduce_ndim_idx_items(
    return false;
}

inline bool is_subtensor_reduce_ndim(
        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& items,
        const std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items) {
    for (auto i = 0; i < items.size(); ++i) {
        auto&& [axis, begin, end, step, idx] = items[i];
        if (idx) {
            auto&& [b_val, e_val, s_val, ax_val] = slice_items[i];
            return ax_val != INT_MAX;
        }
    }
    return false;
}

inline auto convert_nchw2nhwc_idx_items(
        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& items) {
    auto nhwc_items = items;
...

@@ -326,6 +339,34 @@ ValueRefList subtensor_rule(
        const FormatTransformation& t) {
    mgb_assert(inputs.size() >= 1);
    auto& src = inputs[0].cast(t.value_type());
    bool is_reduce_ndim = false;
    if (inputs.size() > 1) {
        is_reduce_ndim = is_reduce_ndim_idx_items(
                op.items, {&inputs[1], &inputs[inputs.size() - 1]});
    } else {
        is_reduce_ndim = is_subtensor_reduce_ndim(op.items, op.slice_items);
    }
    if (!is_reduce_ndim) {
        // only support NHWC2NCHW convert, otherwise maintain src's format
        if (!(auto_convert && src.format() == FT::NHWC)) {
            return {t.wrap_output(
                    imperative::apply(op, t.unwrap_inputs(inputs))[0], src.format())};
        }
        auto nhwc_items = convert_nchw2nhwc_idx_items(op.items);
        auto outputs = imperative::apply(
                *T::make(std::move(nhwc_items), op.slice_items, op.scope()),
                t.unwrap_inputs(inputs));
        return t.wrap_outputs(outputs, FT::NHWC);
    }
    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
}

template <typename T>
ValueRefList indexing_rule(
        const T& op, Span<ValueRef>& inputs, const bool& auto_convert,
        const FormatTransformation& t) {
    mgb_assert(inputs.size() >= 1);
    auto& src = inputs[0].cast(t.value_type());
    bool is_reduce_ndim = is_reduce_ndim_idx_items(
            op.items, {&inputs[1], &inputs[inputs.size() - 1]});
    if (!is_reduce_ndim) {
...

@@ -597,7 +638,7 @@ struct FormatRuleRegistry {
        register_format_rule(reshape_rule);
        register_format_rule(broadcast_rule);
        register_format_rule(subtensor_rule<Subtensor>);
-       register_format_rule(subtensor_rule<IndexingMultiAxisVec>);
+       register_format_rule(indexing_rule<IndexingMultiAxisVec>);
        register_format_rule(setsubtensor_rule<SetSubtensor>);
        register_format_rule(setsubtensor_rule<IndexingSetMultiAxisVec>);
        register_format_rule(elemwise_rule);
...
imperative/tablegen/generated/hash.txt
8dd504f360fd3d3bfb560c970b568153 ../../dnn/scripts/opr_param_defs.py
-cf864561de125ab559c0035158656682 ../../src/core/include/megbrain/ir/ops.td
-f27cdbb7926e0be9f5dabb8651d2e4da generated/opdef.h.inl
-96817f709ee92c8e1eb7cb4168f28565 generated/opdef.cpp.inl
-672668fa3ed11c27781f0fa380e6c8aa generated/opdef.py.inl
-47511e3e7fed8c64a1c4fea48d79b3d1 generated/opdef.cpy.inl
+6811fde221f86d1ef8de425a3c83127b ../../src/core/include/megbrain/ir/ops.td
+55123da1605ef6edd79e3a2ede8aefeb generated/opdef.h.inl
+6f4beb6d12cdd9ec4c4e61b6d7d35144 generated/opdef.cpp.inl
+185ba3c3a0fce480ee498cef058670b2 generated/opdef.py.inl
+b7ed7a638b7586709bb23dd153fb58b1 generated/opdef.cpy.inl
71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h
imperative/tablegen/generated/opdef.cpp.inl
...
...
@@ -6918,6 +6918,7 @@ size_t Subtensor_hash_impl(const OpDef& def_) {
static_cast<void>(op_);
size_t val = mgb::hash(op_.dyn_typeinfo());
val = mgb::hash_pair_combine(val, mgb::hash(op_.items));
val = mgb::hash_pair_combine(val, mgb::hash(op_.slice_items));
return val;
}
bool Subtensor_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
...
...
@@ -6926,6 +6927,7 @@ bool Subtensor_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
static_cast<void>(a_);
static_cast<void>(b_);
if (a_.items != b_.items) return false;
if (a_.slice_items != b_.slice_items) return false;
return true;
}
std::vector<std::pair<const char*, std::string>> Subtensor_props_impl(const OpDef& def_) {
...
...
@@ -6933,6 +6935,7 @@ std::vector<std::pair<const char*, std::string>> Subtensor_props_impl(const OpDe
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
props_.emplace_back("items", "{std::vector}");
props_.emplace_back("slice_items", "{std::vector}");
return props_;
}
std::string Subtensor_make_name_impl(const OpDef& def_) {
...
...
imperative/tablegen/generated/opdef.cpy.inl
...
...
@@ -20113,7 +20113,8 @@ PyOpDefBegin(Subtensor) // {
static_cast<void>(opdef);
    std::unordered_map<std::string, py::object> state {
-        {"items", serialization<decltype(opdef.items)>::dump(opdef.items)}
+        {"items", serialization<decltype(opdef.items)>::dump(opdef.items)},
+        {"slice_items", serialization<decltype(opdef.slice_items)>::dump(opdef.slice_items)}
};
return py::cast(state).release().ptr();
}
...
...
@@ -20130,6 +20131,13 @@ PyOpDefBegin(Subtensor) // {
opdef.items = serialization<decltype(opdef.items)>::load(iter->second);
}
}
{
auto&& iter = state.find("slice_items");
if (iter != state.end()) {
opdef.slice_items = serialization<decltype(opdef.slice_items)>::load(iter->second);
}
}
Py_RETURN_NONE;
}
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
...
...
@@ -20139,9 +20147,9 @@ PyOpDefBegin(Subtensor) // {
PyOpDefEnd(Subtensor)
int PyOp(Subtensor)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
-    static const char* kwlist[] = {"items", "scope", NULL};
-    PyObject *items = NULL, *scope = NULL;
-    if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO", const_cast<char**>(kwlist), &items, &scope))
+    static const char* kwlist[] = {"items", "slice_items", "scope", NULL};
+    PyObject *items = NULL, *slice_items = NULL, *scope = NULL;
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOO", const_cast<char**>(kwlist), &items, &slice_items, &scope))
return -1;
if (items) {
...
...
@@ -20153,6 +20161,15 @@ int PyOp(Subtensor)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
} CATCH_ALL(-1)
}
if (slice_items) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(Subtensor)*>(self)->inst().slice_items =
py::cast<decltype(Subtensor::slice_items)>(py::handle(slice_items));
} CATCH_ALL(-1)
}
if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
...
...
@@ -20165,6 +20182,7 @@ int PyOp(Subtensor)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
PyGetSetDef PyOp(Subtensor)::py_getsetters[] = {
{const_cast<char*>("items"), py_get_generic(Subtensor, items), py_set_generic(Subtensor, items), const_cast<char*>("items"), NULL},
{const_cast<char*>("slice_items"), py_get_generic(Subtensor, slice_items), py_set_generic(Subtensor, slice_items), const_cast<char*>("slice_items"), NULL},
{NULL} /* Sentinel */
};
...
...
@@ -20185,7 +20203,7 @@ PyMethodDef PyOp(Subtensor)::py_init_methoddef = {
"__init__",
(PyCFunction)PyOp(Subtensor)::py_init_proxy,
    METH_VARARGS | METH_KEYWORDS,
-    "__init__(self, items: list[tuple[int, bool, bool, bool, bool]] = ...) -> None\n"
+    "__init__(self, items: list[tuple[int, bool, bool, bool, bool]] = ..., slice_items: list[tuple[int, int, int, int]] = ...) -> None\n"
};
void _init_py_Subtensor(py::module m) {
...
...
imperative/tablegen/generated/opdef.h.inl
...
...
@@ -1795,8 +1795,9 @@ class Subtensor : public OpDefImplBase<Subtensor> {
public:
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items;
Subtensor() = default;
-    Subtensor(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
+    Subtensor(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items_, std::string scope_ = {}): items(items_), slice_items(slice_items_) { set_scope(scope_); }
};
class TQT : public OpDefImplBase<TQT> {
...
...
imperative/tablegen/generated/opdef.py.inl
...
...
@@ -1882,9 +1882,10 @@ SplitInst
py::class_<Subtensor, std::shared_ptr<Subtensor>, OpDef> SubtensorInst(m, "Subtensor");
SubtensorInst
-    .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
+    .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>>, std::string>(), py::arg("items"), py::arg("slice_items"), py::arg("scope") = {})
     .def(py::init<>())
-    .def_readwrite("items", &Subtensor::items);
+    .def_readwrite("items", &Subtensor::items)
+    .def_readwrite("slice_items", &Subtensor::slice_items);
py::class_<TQT, std::shared_ptr<TQT>, OpDef> TQTInst(m, "TQT");
...
...
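With the regenerated bindings, the op can be constructed from Python with both attribute lists. A hedged sketch (assuming the generated class is exposed under `megengine.core.ops.builtin`, as the other generated ops are; values illustrative):

    from megengine.core.ops import builtin

    INT_MIN, INT_MAX = -2**31, 2**31 - 1
    op = builtin.Subtensor(
        items=[(0, True, True, False, False)],   # axis 0: begin and end given
        slice_items=[(1, 3, INT_MAX, INT_MAX)],  # begin=1, end=3, step/index unset
    )
    print(op.items, op.slice_items)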
src/core/include/megbrain/ir/ops.td
...
...
@@ -380,7 +380,15 @@ class FancyIndexingBase<string name>: MgbHashableOp<name> {
);
}
-def Subtensor: FancyIndexingBase<"Subtensor">;
+def Subtensor: MgbHashableOp<"Subtensor"> {
+  let extraArguments = (ins
+    MgbArrayAttr<MgbTupleAttr<
+      [MgbI8Attr, MgbBoolAttr, MgbBoolAttr, MgbBoolAttr, MgbBoolAttr]>>:$items,
+    MgbArrayAttr<MgbTupleAttr<[MgbI32Attr, MgbI32Attr, MgbI32Attr, MgbI32Attr]>>:$slice_items
+  );
+}
+// def Subtensor: FancyIndexingBase<"Subtensor">;
def SetSubtensor: FancyIndexingBase<"SetSubtensor">;
def IncrSubtensor: FancyIndexingBase<"IncrSubtensor">;
def IndexingMultiAxisVec: FancyIndexingBase<"IndexingMultiAxisVec">;
...
...