Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
813628e2
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
813628e2
编写于
6月 29, 2023
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(opr): add interpolate trilinear
GitOrigin-RevId: 19a96ba58bdf645ceb15655dc2f9d29085c677a7
上级
2937ea0e
变更
30
隐藏空白更改
内联
并排
Showing
30 changed file
with
1057 addition
and
25 deletion
+1057
-25
dnn/include/megdnn/oprs/cv.h
dnn/include/megdnn/oprs/cv.h
+29
-0
dnn/scripts/opr_param_defs.py
dnn/scripts/opr_param_defs.py
+5
-0
dnn/src/common/handle_impl.h
dnn/src/common/handle_impl.h
+1
-0
dnn/src/common/opr_trait.h
dnn/src/common/opr_trait.h
+1
-0
dnn/src/common/resize.cpp
dnn/src/common/resize.cpp
+33
-0
dnn/src/cuda/handle_create.cpp
dnn/src/cuda/handle_create.cpp
+1
-0
dnn/src/cuda/resize/common.h
dnn/src/cuda/resize/common.h
+9
-0
dnn/src/cuda/resize/forward.cpp
dnn/src/cuda/resize/forward.cpp
+37
-0
dnn/src/cuda/resize/forward.cu
dnn/src/cuda/resize/forward.cu
+150
-0
dnn/src/cuda/resize/opr_impl.h
dnn/src/cuda/resize/opr_impl.h
+10
-0
dnn/src/naive/resize/opr_impl.cpp
dnn/src/naive/resize/opr_impl.cpp
+145
-0
dnn/src/naive/resize/opr_impl.h
dnn/src/naive/resize/opr_impl.h
+18
-0
dnn/test/cuda/resize.cpp
dnn/test/cuda/resize.cpp
+20
-0
dnn/test/naive/resize.cpp
dnn/test/naive/resize.cpp
+92
-0
imperative/python/megengine/functional/vision.py
imperative/python/megengine/functional/vision.py
+43
-16
imperative/python/test/unit/functional/test_functional.py
imperative/python/test/unit/functional/test_functional.py
+1
-1
imperative/python/test/unit/utils/test_network_node.py
imperative/python/test/unit/utils/test_network_node.py
+11
-0
imperative/src/impl/ops/resize.cpp
imperative/src/impl/ops/resize.cpp
+16
-2
imperative/tablegen/generated/hash.txt
imperative/tablegen/generated/hash.txt
+6
-6
imperative/tablegen/generated/opdef.cpp.inl
imperative/tablegen/generated/opdef.cpp.inl
+72
-0
imperative/tablegen/generated/opdef.cpy.inl
imperative/tablegen/generated/opdef.cpy.inl
+164
-0
imperative/tablegen/generated/opdef.h.inl
imperative/tablegen/generated/opdef.h.inl
+17
-0
imperative/tablegen/generated/opdef.py.inl
imperative/tablegen/generated/opdef.py.inl
+12
-0
src/core/include/megbrain/ir/ops.td
src/core/include/megbrain/ir/ops.td
+2
-0
src/opr/impl/imgproc.cpp
src/opr/impl/imgproc.cpp
+50
-0
src/opr/impl/imgproc.sereg.h
src/opr/impl/imgproc.sereg.h
+1
-0
src/opr/include/megbrain/opr/imgproc.h
src/opr/include/megbrain/opr/imgproc.h
+37
-0
src/opr/test/imgproc.cpp
src/opr/test/imgproc.cpp
+72
-0
src/serialization/impl/schema.fbs
src/serialization/impl/schema.fbs
+1
-0
src/serialization/impl/schema_v2.fbs
src/serialization/impl/schema_v2.fbs
+1
-0
未找到文件。
dnn/include/megdnn/oprs/cv.h
浏览文件 @
813628e2
...
...
@@ -245,6 +245,35 @@ protected:
size_t
workspace_in_bytes
);
};
class
Resize3DBase
:
public
OperatorBase
{
DEF_OPR_PARAM
(
Resize3D
);
DEF_OPR_IMPL
(
Resize3DBase
,
OperatorBase
,
1
,
1
);
public:
using
InterpolationMode
=
Param
::
InterpolationMode
;
protected:
void
check_layout_fwd
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
);
};
class
Resize3DForward
:
public
Resize3DBase
{
DEF_OPR_IMPL
(
Resize3DForward
,
Resize3DBase
,
1
,
1
);
public:
virtual
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
=
0
;
virtual
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
=
0
;
protected:
void
check_exec
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
size_t
workspace_in_bytes
);
};
using
Resize3D
=
Resize3DForward
;
/**
* \brief Remap opr.
*/
...
...
dnn/scripts/opr_param_defs.py
浏览文件 @
813628e2
...
...
@@ -965,6 +965,11 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
'ksize_d'
,
0
,
'ksize_h'
,
3
,
'ksize_w'
,
3
,
'anchor_d'
,
0
,
'anchor_h'
,
1
,
'anchor_w'
,
1
))
(
pdef
(
'Resize3D'
)
.
add_enum_alias
(
'InterpolationMode'
,
'WarpPerspectiveV1'
,
name_field
=
'imode'
)
.
add_enum_alias
(
'Format'
,
'Convolution3D'
,
default
=
1
)
.
add_fields
(
'bool'
,
'align_corners'
,
'false'
))
(
pdef
(
'TopK'
).
add_enum
(
'Mode'
,
...
...
dnn/src/common/handle_impl.h
浏览文件 @
813628e2
...
...
@@ -160,6 +160,7 @@ private:
cb(GaussianBlur) \
cb(Resize) \
cb(ResizeBackward) \
cb(Resize3D) \
cb(ParamPackConcat) \
cb(MaxTensorDiff) \
cb(MaskConvForward) \
...
...
dnn/src/common/opr_trait.h
浏览文件 @
813628e2
...
...
@@ -150,6 +150,7 @@ DEF(GroupNormBackward, 8, true, true);
DEF
(
MaskedFill
,
3
,
false
,
true
);
DEF
(
MultiHeadAttnForward
,
11
,
true
,
true
);
DEF
(
MultiHeadAttnBackward
,
15
,
true
,
true
);
DEF
(
Resize3D
,
2
,
true
,
false
);
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/common/resize.cpp
浏览文件 @
813628e2
...
...
@@ -111,6 +111,39 @@ std::tuple<float, int, float, int> ResizeBase::get_nearest_linear_coord(
int
ResizeBase
::
get_nearest_src
(
float
scale
,
int
size
,
int
idx
)
{
return
std
::
min
(
static_cast
<
int
>
(
idx
/
scale
),
size
-
1
);
}
void
Resize3DBase
::
check_layout_fwd
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
{
auto
errmsg
=
[
&
]()
{
return
megdnn_layout_msg
(
src
)
+
", "
+
", "
+
megdnn_layout_msg
(
dst
);
};
MEGDNN_MARK_USED_VAR
(
errmsg
);
megdnn_assert
(
param
().
format
==
Param
::
Format
::
NCDHW
,
"Resize3D only support NCDHW"
);
megdnn_assert
(
src
.
ndim
==
5
&&
dst
.
ndim
==
5
,
"shape dim mismatch: %s"
,
errmsg
().
c_str
());
megdnn_assert
(
src
.
dtype
==
dst
.
dtype
,
"dtype mismatch: %s"
,
errmsg
().
c_str
());
megdnn_assert
(
src
.
shape
[
0
]
==
dst
.
shape
[
0
],
"batch size mismatch: %s"
,
errmsg
().
c_str
());
megdnn_assert
(
src
.
shape
[
1
]
==
dst
.
shape
[
1
],
"channel size mismatch: %s"
,
errmsg
().
c_str
());
megdnn_assert_contiguous
(
src
);
megdnn_assert_contiguous
(
dst
);
auto
imode
=
param
().
imode
;
using
IMode
=
param
::
Resize3D
::
InterpolationMode
;
megdnn_assert
(
imode
==
IMode
::
INTER_LINEAR
,
"Resize3D only support TriLinear mode"
);
}
void
Resize3D
::
check_exec
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
size_t
workspace_in_bytes
)
{
check_layout_fwd
(
src
,
dst
);
auto
required_workspace_in_bytes
=
get_workspace_in_bytes
(
src
,
dst
);
megdnn_assert
(
workspace_in_bytes
>=
required_workspace_in_bytes
);
}
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/cuda/handle_create.cpp
浏览文件 @
813628e2
...
...
@@ -177,6 +177,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpAffine);
MEGDNN_SPECIALIZE_CREATE_OPERATOR
(
GaussianBlur
);
MEGDNN_SPECIALIZE_CREATE_OPERATOR
(
Resize
);
MEGDNN_SPECIALIZE_CREATE_OPERATOR
(
ResizeBackward
);
MEGDNN_SPECIALIZE_CREATE_OPERATOR
(
Resize3D
);
MEGDNN_SPECIALIZE_CREATE_OPERATOR
(
ParamPackConcat
);
MEGDNN_SPECIALIZE_CREATE_OPERATOR
(
MaxTensorDiff
);
MEGDNN_SPECIALIZE_CREATE_OPERATOR
(
MaskConvForward
);
...
...
dnn/src/cuda/resize/common.h
浏览文件 @
813628e2
...
...
@@ -26,6 +26,15 @@ void backward_data_proxy(
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
cudaStream_t
stream
);
}
// namespace resize
namespace
resize3d
{
template
<
typename
ctype
>
void
resize3d_forward
(
const
bool
align_corners
,
const
ctype
*
iptr
,
ctype
*
optr
,
const
int
N
,
const
int
C
,
const
int
ID
,
const
int
IH
,
const
int
IW
,
const
int
OD
,
const
int
OH
,
const
int
OW
,
cudaStream_t
stream
);
}
// namespace resize3d
}
// namespace cuda
}
// namespace megdnn
...
...
dnn/src/cuda/resize/forward.cpp
浏览文件 @
813628e2
...
...
@@ -168,4 +168,41 @@ void ResizeImpl::exec(
}
}
size_t
Resize3DImpl
::
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
{
return
0
;
}
void
Resize3DImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
dst
,
_megdnn_workspace
workspace
)
{
check_exec
(
src
.
layout
,
dst
.
layout
,
workspace
.
size
);
size_t
out_depth
=
dst
.
layout
.
shape
[
2
];
size_t
out_height
=
dst
.
layout
.
shape
[
3
];
size_t
out_width
=
dst
.
layout
.
shape
[
4
];
size_t
in_depth
=
src
.
layout
.
shape
[
2
];
size_t
in_height
=
src
.
layout
.
shape
[
3
];
size_t
in_width
=
src
.
layout
.
shape
[
4
];
bool
align_corners
=
param
().
align_corners
;
auto
stream
=
cuda_stream
(
this
->
handle
());
if
(
src
.
layout
.
dtype
==
dtype
::
Float32
{})
{
resize3d
::
resize3d_forward
(
align_corners
,
src
.
ptr
<
dt_float32
>
(),
dst
.
ptr
<
dt_float32
>
(),
src
.
layout
[
0
],
src
.
layout
[
1
],
in_depth
,
in_height
,
in_width
,
out_depth
,
out_height
,
out_width
,
stream
);
#if !MEGDNN_DISABLE_FLOAT16
}
else
if
(
src
.
layout
.
dtype
==
dtype
::
Float16
{})
{
resize3d
::
resize3d_forward
(
align_corners
,
src
.
ptr
<
dt_float16
>
(),
dst
.
ptr
<
dt_float16
>
(),
src
.
layout
[
0
],
src
.
layout
[
1
],
in_depth
,
in_height
,
in_width
,
out_depth
,
out_height
,
out_width
,
stream
);
#endif
}
else
{
megdnn_throw
(
ssprintf
(
"unsupported dtype: %s for Resize3D"
,
src
.
layout
.
dtype
.
name
()));
}
}
// vim: syntax=cpp.doxygen
dnn/src/cuda/resize/forward.cu
浏览文件 @
813628e2
...
...
@@ -308,6 +308,156 @@ DNN_INC_FLOAT16(INST(dt_float16))
INST
(
int8_t
);
#undef INST
}
// namespace resize
namespace
resize3d
{
__device__
__forceinline__
static
float
pixel_get_src_index
(
float
scale
,
int64_t
dst_index
,
bool
align_corners
)
{
if
(
align_corners
)
{
return
scale
*
dst_index
;
}
else
{
float
src_idx
=
scale
*
(
dst_index
+
0.5
f
)
-
0.5
f
;
return
src_idx
<
0.
f
?
0.
f
:
src_idx
;
}
}
__device__
__forceinline__
static
size_t
index_getter
(
int
n
,
int
c
,
int
d
,
int
h
,
int
w
,
const
int
N
,
const
int
C
,
const
int
D
,
const
int
H
,
const
int
W
)
{
return
n
*
C
*
D
*
H
*
W
+
c
*
D
*
H
*
W
+
d
*
H
*
W
+
h
*
W
+
w
;
}
template
<
typename
ctype
>
__global__
void
trilinear_forward
(
const
int
num_kernels
,
const
float
rdepth
,
const
float
rheight
,
const
float
rwidth
,
const
bool
align_corners
,
const
ctype
*
iptr
,
ctype
*
optr
,
const
int
N
,
const
int
C
,
const
int
ID
,
const
int
IH
,
const
int
IW
,
const
int
OD
,
const
int
OH
,
const
int
OW
)
{
int
index
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
index
<
num_kernels
)
{
const
int
w2
=
(
index
%
(
OH
*
OW
))
%
OW
;
const
int
h2
=
(
index
%
(
OH
*
OW
))
/
OW
;
const
int
t2
=
index
/
(
OH
*
OW
);
if
(
ID
==
OD
&&
IH
==
OH
&&
IW
==
OW
)
{
const
int
t1
=
t2
;
const
int
h1
=
h2
;
const
int
w1
=
w2
;
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
c
=
0
;
c
<
C
;
++
c
)
{
const
ctype
val
=
iptr
[
index_getter
(
n
,
c
,
t1
,
h1
,
w1
,
N
,
C
,
ID
,
IH
,
IW
)];
optr
[
index_getter
(
n
,
c
,
t2
,
h2
,
w2
,
N
,
C
,
OD
,
OH
,
OW
)]
=
val
;
}
}
return
;
}
const
float
t1r
=
pixel_get_src_index
(
rdepth
,
t2
,
align_corners
);
const
int
t1
=
t1r
;
const
int
t1p
=
(
t1
<
ID
-
1
)
?
1
:
0
;
const
float
t1lambda
=
t1r
-
t1
;
const
float
t0lambda
=
static_cast
<
float
>
(
1
)
-
t1lambda
;
const
float
h1r
=
pixel_get_src_index
(
rheight
,
h2
,
align_corners
);
const
int
h1
=
h1r
;
const
int
h1p
=
(
h1
<
IH
-
1
)
?
1
:
0
;
const
float
h1lambda
=
h1r
-
h1
;
const
float
h0lambda
=
static_cast
<
float
>
(
1
)
-
h1lambda
;
const
float
w1r
=
pixel_get_src_index
(
rwidth
,
w2
,
align_corners
);
const
int
w1
=
w1r
;
const
int
w1p
=
(
w1
<
IW
-
1
)
?
1
:
0
;
const
float
w1lambda
=
w1r
-
w1
;
const
float
w0lambda
=
static_cast
<
float
>
(
1
)
-
w1lambda
;
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
c
=
0
;
c
<
C
;
++
c
)
{
const
float
val
=
t0lambda
*
(
h0lambda
*
(
w0lambda
*
iptr
[
index_getter
(
n
,
c
,
t1
,
h1
,
w1
,
N
,
C
,
ID
,
IH
,
IW
)]
+
w1lambda
*
iptr
[
index_getter
(
n
,
c
,
t1
,
h1
,
w1
+
w1p
,
N
,
C
,
ID
,
IH
,
IW
)])
+
h1lambda
*
(
w0lambda
*
iptr
[
index_getter
(
n
,
c
,
t1
,
h1
+
h1p
,
w1
,
N
,
C
,
ID
,
IH
,
IW
)]
+
w1lambda
*
iptr
[
index_getter
(
n
,
c
,
t1
,
h1
+
h1p
,
w1
+
w1p
,
N
,
C
,
ID
,
IH
,
IW
)]))
+
t1lambda
*
(
h0lambda
*
(
w0lambda
*
iptr
[
index_getter
(
n
,
c
,
t1
+
t1p
,
h1
,
w1
,
N
,
C
,
ID
,
IH
,
IW
)]
+
w1lambda
*
iptr
[
index_getter
(
n
,
c
,
t1
+
t1p
,
h1
,
w1
+
w1p
,
N
,
C
,
ID
,
IH
,
IW
)])
+
h1lambda
*
(
w0lambda
*
iptr
[
index_getter
(
n
,
c
,
t1
+
t1p
,
h1
+
h1p
,
w1
,
N
,
C
,
ID
,
IH
,
IW
)]
+
w1lambda
*
iptr
[
index_getter
(
n
,
c
,
t1
+
t1p
,
h1
+
h1p
,
w1
+
w1p
,
N
,
C
,
ID
,
IH
,
IW
)]));
optr
[
index_getter
(
n
,
c
,
t2
,
h2
,
w2
,
N
,
C
,
OD
,
OH
,
OW
)]
=
static_cast
<
ctype
>
(
val
);
}
}
}
}
__host__
__forceinline__
static
float
get_scale
(
int
input_size
,
int
output_size
,
bool
align_corners
)
{
if
(
align_corners
)
{
if
(
output_size
>
1
)
{
return
static_cast
<
float
>
(
input_size
-
1
)
/
(
output_size
-
1
);
}
else
{
return
0.
f
;
}
}
else
{
return
static_cast
<
float
>
(
input_size
)
/
output_size
;
}
}
template
<
typename
ctype
>
void
resize3d_forward
(
const
bool
align_corners
,
const
ctype
*
iptr
,
ctype
*
optr
,
const
int
N
,
const
int
C
,
const
int
ID
,
const
int
IH
,
const
int
IW
,
const
int
OD
,
const
int
OH
,
const
int
OW
,
cudaStream_t
stream
)
{
const
size_t
num_kernels
=
OD
*
OH
*
OW
;
const
size_t
num_threads
=
512
;
float
rdepth
=
get_scale
(
ID
,
OD
,
align_corners
);
float
rheight
=
get_scale
(
IH
,
OH
,
align_corners
);
float
rwidth
=
get_scale
(
IW
,
OW
,
align_corners
);
trilinear_forward
<
ctype
>
<<<
(
num_kernels
+
num_threads
-
1
)
/
num_threads
,
num_threads
,
0
,
stream
>>>
(
num_kernels
,
rdepth
,
rheight
,
rwidth
,
align_corners
,
iptr
,
optr
,
N
,
C
,
ID
,
IH
,
IW
,
OD
,
OH
,
OW
);
}
#define INST(ctype) \
template void resize3d_forward( \
const bool, const ctype*, ctype*, const int, const int, const int, \
const int, const int, const int, const int, const int, cudaStream_t);
INST
(
float
)
DNN_INC_FLOAT16
(
INST
(
dt_float16
))
#undef INST
}
// namespace resize3d
}
// namespace cuda
}
// namespace megdnn
...
...
dnn/src/cuda/resize/opr_impl.h
浏览文件 @
813628e2
...
...
@@ -24,6 +24,16 @@ public:
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
};
class
Resize3DImpl
final
:
public
Resize3D
{
public:
using
Resize3D
::
Resize3D
;
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
override
;
};
}
// namespace cuda
}
// namespace megdnn
...
...
dnn/src/naive/resize/opr_impl.cpp
浏览文件 @
813628e2
...
...
@@ -545,4 +545,149 @@ void ResizeBackwardImpl::exec(
}
}
template
<
typename
ctype
>
void
Resize3DImpl
::
kern_naive
(
const
float
rdepth
,
const
float
rheight
,
const
float
rwidth
,
const
bool
align_corners
,
const
ctype
*
iptr
,
ctype
*
optr
,
const
int
N
,
const
int
C
,
const
int
ID
,
const
int
IH
,
const
int
IW
,
const
int
OD
,
const
int
OH
,
const
int
OW
)
{
auto
pixel_get_src_index
=
[](
float
scale
,
int64_t
dst_index
,
bool
align_corners
)
{
if
(
align_corners
)
{
return
scale
*
dst_index
;
}
else
{
float
src_idx
=
scale
*
(
dst_index
+
0.5
f
)
-
0.5
f
;
return
src_idx
<
0.
f
?
0.
f
:
src_idx
;
}
};
auto
i_index
=
[
&
](
int
in
,
int
ic
,
int
id
,
int
ih
,
int
iw
)
->
int
{
return
in
*
C
*
ID
*
IH
*
IW
+
ic
*
ID
*
IH
*
IW
+
id
*
IH
*
IW
+
ih
*
IW
+
iw
;
};
auto
o_index
=
[
&
](
int
in
,
int
ic
,
int
id
,
int
ih
,
int
iw
)
->
int
{
return
in
*
C
*
OD
*
OH
*
OW
+
ic
*
OD
*
OH
*
OW
+
id
*
OH
*
OW
+
ih
*
OW
+
iw
;
};
for
(
int
t2
=
0
;
t2
<
OD
;
++
t2
)
{
for
(
int
h2
=
0
;
h2
<
OH
;
++
h2
)
{
for
(
int
w2
=
0
;
w2
<
OW
;
++
w2
)
{
const
float
t1r
=
pixel_get_src_index
(
rdepth
,
t2
,
align_corners
);
const
int
t1
=
t1r
;
const
int
t1p
=
(
t1
<
ID
-
1
)
?
1
:
0
;
const
float
t1lambda
=
t1r
-
t1
;
const
float
t0lambda
=
static_cast
<
float
>
(
1
)
-
t1lambda
;
const
float
h1r
=
pixel_get_src_index
(
rheight
,
h2
,
align_corners
);
const
int
h1
=
h1r
;
const
int
h1p
=
(
h1
<
IH
-
1
)
?
1
:
0
;
const
float
h1lambda
=
h1r
-
h1
;
const
float
h0lambda
=
static_cast
<
float
>
(
1
)
-
h1lambda
;
const
float
w1r
=
pixel_get_src_index
(
rwidth
,
w2
,
align_corners
);
const
int
w1
=
w1r
;
const
int
w1p
=
(
w1
<
IW
-
1
)
?
1
:
0
;
const
float
w1lambda
=
w1r
-
w1
;
const
float
w0lambda
=
static_cast
<
float
>
(
1
)
-
w1lambda
;
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
for
(
int
c
=
0
;
c
<
C
;
++
c
)
{
const
float
val
=
t0lambda
*
(
h0lambda
*
(
w0lambda
*
iptr
[
i_index
(
n
,
c
,
t1
,
h1
,
w1
)]
+
w1lambda
*
iptr
[
i_index
(
n
,
c
,
t1
,
h1
,
w1
+
w1p
)])
+
h1lambda
*
(
w0lambda
*
iptr
[
i_index
(
n
,
c
,
t1
,
h1
+
h1p
,
w1
)]
+
w1lambda
*
iptr
[
i_index
(
n
,
c
,
t1
,
h1
+
h1p
,
w1
+
w1p
)]))
+
t1lambda
*
(
h0lambda
*
(
w0lambda
*
iptr
[
i_index
(
n
,
c
,
t1
+
t1p
,
h1
,
w1
)]
+
w1lambda
*
iptr
[
i_index
(
n
,
c
,
t1
+
t1p
,
h1
,
w1
+
w1p
)])
+
h1lambda
*
(
w0lambda
*
iptr
[
i_index
(
n
,
c
,
t1
+
t1p
,
h1
+
h1p
,
w1
)]
+
w1lambda
*
iptr
[
i_index
(
n
,
c
,
t1
+
t1p
,
h1
+
h1p
,
w1
+
w1p
)]));
optr
[
o_index
(
n
,
c
,
t2
,
h2
,
w2
)]
=
static_cast
<
ctype
>
(
val
);
}
}
}
}
}
}
#define INST(ctype) \
template void Resize3DImpl::kern_naive( \
const float, const float, const float, const bool, const ctype*, ctype*, \
const int, const int, const int, const int, const int, const int, \
const int, const int)
INST
(
dt_float32
);
DNN_INC_FLOAT16
(
INST
(
dt_float16
));
#undef INST
void
Resize3DImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
check_exec
(
src
.
layout
,
dst
.
layout
,
workspace
.
size
);
bool
align_corners
=
param
().
align_corners
;
size_t
N
=
src
.
layout
.
shape
[
0
];
size_t
C
=
src
.
layout
.
shape
[
1
];
size_t
OD
=
dst
.
layout
.
shape
[
2
];
size_t
OH
=
dst
.
layout
.
shape
[
3
];
size_t
OW
=
dst
.
layout
.
shape
[
4
];
size_t
ID
=
src
.
layout
.
shape
[
2
];
size_t
IH
=
src
.
layout
.
shape
[
3
];
size_t
IW
=
src
.
layout
.
shape
[
4
];
auto
get_scale
=
[](
int
input_size
,
int
output_size
,
bool
align_corners
)
->
float
{
if
(
align_corners
)
{
if
(
output_size
>
1
)
{
return
static_cast
<
float
>
(
input_size
-
1
)
/
(
output_size
-
1
);
}
else
{
return
0.
f
;
}
}
else
{
return
static_cast
<
float
>
(
input_size
)
/
output_size
;
}
};
float
rdepth
=
get_scale
(
ID
,
OD
,
align_corners
);
float
rheight
=
get_scale
(
IH
,
OH
,
align_corners
);
float
rwidth
=
get_scale
(
IW
,
OW
,
align_corners
);
if
(
src
.
layout
.
dtype
==
dtype
::
Float32
{})
{
Resize3DImpl
::
kern_naive
(
rdepth
,
rheight
,
rwidth
,
align_corners
,
src
.
ptr
<
dt_float32
>
(),
dst
.
ptr
<
dt_float32
>
(),
N
,
C
,
ID
,
IH
,
IW
,
OD
,
OH
,
OW
);
#if !MEGDNN_DISABLE_FLOAT16
}
else
if
(
src
.
layout
.
dtype
==
dtype
::
Float16
{})
{
Resize3DImpl
::
kern_naive
(
rdepth
,
rheight
,
rwidth
,
align_corners
,
src
.
ptr
<
dt_float16
>
(),
dst
.
ptr
<
dt_float16
>
(),
N
,
C
,
ID
,
IH
,
IW
,
OD
,
OH
,
OW
);
#endif
}
else
{
megdnn_throw
(
ssprintf
(
"unsupported dtype: %s for Resize3D"
,
src
.
layout
.
dtype
.
name
()));
}
}
size_t
Resize3DImpl
::
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
0
;
}
// vim: syntax=cpp.doxygen
dnn/src/naive/resize/opr_impl.h
浏览文件 @
813628e2
...
...
@@ -83,6 +83,24 @@ private:
int
N
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
);
};
class
Resize3DImpl
final
:
public
Resize3D
{
public:
using
Resize3D
::
Resize3D
;
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
override
;
private:
template
<
typename
ctype
>
void
kern_naive
(
const
float
rdepth
,
const
float
rheight
,
const
float
rwidth
,
const
bool
align_corners
,
const
ctype
*
iptr
,
ctype
*
optr
,
const
int
N
,
const
int
C
,
const
int
ID
,
const
int
IH
,
const
int
IW
,
const
int
OD
,
const
int
OH
,
const
int
OW
);
};
}
// namespace naive
}
// namespace megdnn
...
...
dnn/test/cuda/resize.cpp
浏览文件 @
813628e2
...
...
@@ -193,6 +193,26 @@ TEST_F(CUDA, RESIZE_BACKWARD) {
}
}
TEST_F
(
CUDA
,
RESIZE3D_NCDHW
)
{
using
IMode
=
param
::
Resize3D
::
InterpolationMode
;
using
Format
=
param
::
Resize3D
::
Format
;
auto
ac_param
=
param
::
Resize3D
{
IMode
::
LINEAR
,
Format
::
NCDHW
,
true
};
auto
nac_param
=
param
::
Resize3D
{
IMode
::
LINEAR
,
Format
::
NCDHW
,
false
};
auto
run
=
[
&
](
DType
d
,
TensorShape
ishape
,
TensorShape
oshape
)
{
Checker
<
Resize3D
>
checker
(
handle_cuda
());
checker
.
set_param
(
ac_param
).
set_dtype
(
0
,
d
).
set_dtype
(
1
,
d
).
execs
(
{
ishape
,
oshape
});
checker
.
set_param
(
nac_param
).
execs
({
ishape
,
oshape
});
};
for
(
auto
&&
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float32
(),
dtype
::
Float16
()})
{
run
(
dtype
,
{
1
,
1
,
2
,
2
,
2
},
{
1
,
1
,
4
,
4
,
4
});
run
(
dtype
,
{
2
,
2
,
2
,
3
,
4
},
{
2
,
2
,
2
,
3
,
6
});
run
(
dtype
,
{
2
,
2
,
2
,
3
,
4
},
{
2
,
2
,
3
,
4
,
5
});
}
}
#if MEGDNN_WITH_BENCHMARK
TEST_F
(
CUDA
,
BENCHMARK_RESIZE_CV
)
{
...
...
dnn/test/naive/resize.cpp
浏览文件 @
813628e2
...
...
@@ -61,3 +61,95 @@ TEST_F(NAIVE, RESIZE_NCHW4) {
.
execs
({
arg
.
src
,
arg
.
dst
});
}
}
TEST_F
(
NAIVE
,
RESIZE3D_NCDHW
)
{
using
IMode
=
param
::
Resize3D
::
InterpolationMode
;
using
Format
=
param
::
Resize3D
::
Format
;
auto
ac_param
=
param
::
Resize3D
{
IMode
::
LINEAR
,
Format
::
NCDHW
,
true
};
auto
nac_param
=
param
::
Resize3D
{
IMode
::
LINEAR
,
Format
::
NCDHW
,
false
};
Checker
<
Resize3D
>
checker
(
handle
());
checker
.
set_param
(
nac_param
).
exect
(
Testcase
{
TensorValue
(
{
1
,
1
,
2
,
2
,
2
},
dtype
::
Float32
(),
{
0.
,
1.
,
2.
,
3.
,
4.
,
5.
,
6.
,
7.
}),
{}},
Testcase
{
{},
TensorValue
(
{
1
,
1
,
4
,
4
,
4
},
dtype
::
Float32
(),
{
0.
,
0.25
,
0.75
,
1.
,
0.5
,
0.75
,
1.25
,
1.5
,
1.5
,
1.75
,
2.25
,
2.5
,
2.
,
2.25
,
2.75
,
3.
,
1.
,
1.25
,
1.75
,
2.
,
1.5
,
1.75
,
2.25
,
2.5
,
2.5
,
2.75
,
3.25
,
3.5
,
3.
,
3.25
,
3.75
,
4.
,
3.
,
3.25
,
3.75
,
4.
,
3.5
,
3.75
,
4.25
,
4.5
,
4.5
,
4.75
,
5.25
,
5.5
,
5.
,
5.25
,
5.75
,
6.
,
4.
,
4.25
,
4.75
,
5.
,
4.5
,
4.75
,
5.25
,
5.5
,
5.5
,
5.75
,
6.25
,
6.5
,
6.
,
6.25
,
6.75
,
7.
})});
checker
.
set_param
(
ac_param
).
exect
(
Testcase
{
TensorValue
(
{
1
,
1
,
2
,
2
,
2
},
dtype
::
Float32
(),
{
0.
,
1.
,
2.
,
3.
,
4.
,
5.
,
6.
,
7.
}),
{}},
Testcase
{
{},
TensorValue
(
{
1
,
1
,
4
,
4
,
4
},
dtype
::
Float32
(),
{
0.
,
0.3333333
,
0.6666667
,
1.
,
0.6666667
,
1.
,
1.3333333
,
1.6666666
,
1.3333334
,
1.6666667
,
1.9999999
,
2.3333333
,
2.
,
2.3333333
,
2.6666665
,
3.
,
1.3333334
,
1.6666666
,
2.0000002
,
2.3333335
,
2.
,
2.333333
,
2.6666667
,
2.9999998
,
2.6666665
,
3.
,
3.3333333
,
3.6666665
,
3.3333333
,
3.6666665
,
4.
,
4.3333335
,
2.6666667
,
3.
,
3.3333337
,
3.6666667
,
3.3333335
,
3.6666663
,
4.
,
4.333333
,
3.9999998
,
4.333333
,
4.6666665
,
5.
,
4.6666665
,
5.
,
5.3333335
,
5.666667
,
4.
,
4.333333
,
4.666667
,
5.
,
4.6666665
,
4.9999995
,
5.3333335
,
5.6666665
,
5.333333
,
5.6666665
,
6.
,
6.3333335
,
6.
,
6.333333
,
6.666667
,
7.
})});
checker
.
set_param
(
nac_param
).
exect
(
Testcase
{
TensorValue
(
{
1
,
1
,
2
,
2
,
2
},
dtype
::
Float16
(),
{
0.
,
1.
,
2.
,
3.
,
4.
,
5.
,
6.
,
7.
}),
{}},
Testcase
{
{},
TensorValue
(
{
1
,
1
,
4
,
4
,
4
},
dtype
::
Float16
(),
{
0.
,
0.25
,
0.75
,
1.
,
0.5
,
0.75
,
1.25
,
1.5
,
1.5
,
1.75
,
2.25
,
2.5
,
2.
,
2.25
,
2.75
,
3.
,
1.
,
1.25
,
1.75
,
2.
,
1.5
,
1.75
,
2.25
,
2.5
,
2.5
,
2.75
,
3.25
,
3.5
,
3.
,
3.25
,
3.75
,
4.
,
3.
,
3.25
,
3.75
,
4.
,
3.5
,
3.75
,
4.25
,
4.5
,
4.5
,
4.75
,
5.25
,
5.5
,
5.
,
5.25
,
5.75
,
6.
,
4.
,
4.25
,
4.75
,
5.
,
4.5
,
4.75
,
5.25
,
5.5
,
5.5
,
5.75
,
6.25
,
6.5
,
6.
,
6.25
,
6.75
,
7.
})});
checker
.
set_param
(
ac_param
).
exect
(
Testcase
{
TensorValue
(
{
1
,
1
,
2
,
2
,
2
},
dtype
::
Float16
(),
{
0.
,
1.
,
2.
,
3.
,
4.
,
5.
,
6.
,
7.
}),
{}},
Testcase
{
{},
TensorValue
(
{
1
,
1
,
4
,
4
,
4
},
dtype
::
Float16
(),
{
0.
,
0.3333333
,
0.6666667
,
1.
,
0.6666667
,
1.
,
1.3333333
,
1.6666666
,
1.3333334
,
1.6666667
,
1.9999999
,
2.3333333
,
2.
,
2.3333333
,
2.6666665
,
3.
,
1.3333334
,
1.6666666
,
2.0000002
,
2.3333335
,
2.
,
2.333333
,
2.6666667
,
2.9999998
,
2.6666665
,
3.
,
3.3333333
,
3.6666665
,
3.3333333
,
3.6666665
,
4.
,
4.3333335
,
2.6666667
,
3.
,
3.3333337
,
3.6666667
,
3.3333335
,
3.6666663
,
4.
,
4.333333
,
3.9999998
,
4.333333
,
4.6666665
,
5.
,
4.6666665
,
5.
,
5.3333335
,
5.666667
,
4.
,
4.333333
,
4.666667
,
5.
,
4.6666665
,
4.9999995
,
5.3333335
,
5.6666665
,
5.333333
,
5.6666665
,
6.
,
6.3333335
,
6.
,
6.333333
,
6.666667
,
7.
})});
}
imperative/python/megengine/functional/vision.py
浏览文件 @
813628e2
...
...
@@ -474,7 +474,8 @@ def interpolate(
size: the size of the output tensor. Default: None
scale_factor: scaling factor of the output tensor. Default: None
mode: interpolation methods, acceptable values are:
"bilinear", "linear", "bicubic" and "nearest". Default: "bilinear"
"bilinear", "linear", "trilinear", "bicubic" and "nearest". Default: "bilinear"
"trilinear" is valid only when inp is a 5D-tensor
align_corners: This only has an effect when ``mode``
is "bilinear" or "linear". Geometrically, we consider the pixels of the input
and output as squares rather than points. If set to ``True``, the input
...
...
@@ -500,9 +501,9 @@ def interpolate(
>>> np.testing.assert_allclose(out.numpy(), out2.numpy())
"""
mode
=
mode
.
lower
()
if
mode
not
in
[
"bilinear"
,
"linear"
,
"bicubic"
,
"nearest"
]:
if
mode
not
in
[
"bilinear"
,
"linear"
,
"
trilinear"
,
"
bicubic"
,
"nearest"
]:
raise
ValueError
(
"unsupported interpolate mode: {}"
.
format
(
mode
))
if
mode
not
in
[
"bilinear"
,
"linear"
]:
if
mode
not
in
[
"bilinear"
,
"linear"
,
"trilinear"
]:
if
align_corners
is
not
None
:
raise
ValueError
(
"align_corners option can only be set in the bilinear/linear interpolating mode"
...
...
@@ -514,14 +515,22 @@ def interpolate(
if
mode
==
"linear"
:
inp
=
expand_dims
(
inp
,
3
)
if
inp
.
ndim
!=
4
:
raise
ValueError
(
"shape of input tensor must correspond to the operartion mode"
)
if
mode
==
"trilinear"
:
assert
(
inp
.
ndim
==
5
),
"under trilinear mode, input tensor must have 5 dimensions"
else
:
assert
(
inp
.
ndim
==
4
),
"shape of input tensor must correspond to the operartion mode"
def
get_dsize
(
scale_factor
):
if
isinstance
(
scale_factor
,
(
float
,
int
)):
scale_factor
=
float
(
scale_factor
)
if
mode
==
"linear"
:
scale_factor
=
(
scale_factor
,
float
(
1
))
elif
mode
==
"trilinear"
:
scale_factor
=
(
scale_factor
,
scale_factor
,
scale_factor
)
else
:
scale_factor
=
(
scale_factor
,
scale_factor
)
else
:
...
...
@@ -530,21 +539,28 @@ def interpolate(
"under linear mode, scale_factor can only be single value"
)
assert
len
(
scale_factor
)
==
2
,
"shape of scale_factor must be equal to (2, )"
assert
isinstance
(
scale_factor
[
0
],
float
)
and
isinstance
(
scale_factor
[
1
],
float
),
"scale_factor must be float type"
dsize
=
tuple
(
if
mode
==
"trilinear"
:
assert
(
len
(
scale_factor
)
==
3
),
f
"shape of scale_factor of interpolate-
{
mode
}
must be equal to (3, )"
else
:
assert
(
len
(
scale_factor
)
==
2
),
f
"shape of scale_factor of interpolate-
{
mode
}
must be equal to (2, )"
assert
all
(
isinstance
(
x
,
(
float
,
int
))
for
x
in
scale_factor
),
f
"scale_factor of interpolate must be float/int type"
dsize
=
[
floor
(
Tensor
(
inp
.
shape
[
i
+
2
]
*
scale_factor
[
i
]
,
inp
.
shape
[
i
+
2
]
*
float
(
scale_factor
[
i
])
,
dtype
=
"float32"
,
device
=
inp
.
device
,
)
)
for
i
in
range
(
2
)
)
dsize
=
concat
(
[
dsize
[
0
],
dsize
[
1
]]
,
axis
=
0
)
for
i
in
range
(
len
(
scale_factor
)
)
]
dsize
=
concat
(
dsize
,
axis
=
0
)
return
dsize
if
size
is
None
:
...
...
@@ -557,13 +573,24 @@ def interpolate(
raise
ValueError
(
"scale_factor must be None when size is provided"
)
if
isinstance
(
size
,
int
):
size
=
(
size
,
1
)
if
mode
==
"trilinear"
:
size
=
(
size
,
1
,
1
)
else
:
size
=
(
size
,
1
)
else
:
if
mode
==
"linear"
:
raise
ValueError
(
"under linear mode, size can only be single value"
)
dsize
=
size
if
not
align_corners
:
if
mode
==
"trilinear"
:
if
inp
.
dtype
==
np
.
float16
:
inp
=
inp
.
astype
(
"float32"
)
op
=
builtin
.
Resize3D
(
imode
=
"linear"
,
format
=
"NCDHW"
,
align_corners
=
align_corners
)
shape
=
astensor1d
(
dsize
,
inp
,
dtype
=
"int32"
,
device
=
inp
.
device
)
(
ret
,)
=
apply
(
op
,
inp
,
shape
)
elif
not
align_corners
:
# fastpath for interpolate
mode_map
=
{
"linear"
:
"linear"
,
...
...
imperative/python/test/unit/functional/test_functional.py
浏览文件 @
813628e2
...
...
@@ -232,7 +232,7 @@ def test_interpolate():
def
error_shape_linear_interpolate
():
inp
=
tensor
(
np
.
arange
(
1
,
5
,
dtype
=
np
.
float32
).
reshape
(
1
,
1
,
2
,
2
))
with
pytest
.
raises
(
Value
Error
):
with
pytest
.
raises
(
Assertion
Error
):
F
.
vision
.
interpolate
(
inp
,
scale_factor
=
2.0
,
mode
=
"linear"
)
def
inappropriate_scale_linear_interpolate
():
...
...
imperative/python/test/unit/utils/test_network_node.py
浏览文件 @
813628e2
...
...
@@ -465,6 +465,17 @@ def test_resize():
check_pygraph_dump
(
fwd
,
[
x
],
[
out
])
def
test_resize3d
():
x
=
Tensor
(
np
.
random
.
randn
(
10
,
3
,
32
,
32
,
32
))
@
trace
(
symbolic
=
True
,
capture_as_const
=
True
)
def
fwd
(
x
):
return
F
.
vision
.
interpolate
(
x
,
size
=
(
16
,
16
,
16
),
mode
=
"trilinear"
)
out
=
fwd
(
x
)
check_pygraph_dump
(
fwd
,
[
x
],
[
out
])
def
test_index_onehot
():
src
=
Tensor
([[
1.0
,
2.0
]])
index
=
Tensor
([
0
])
...
...
imperative/src/impl/ops/resize.cpp
浏览文件 @
813628e2
...
...
@@ -6,7 +6,7 @@
namespace
mgb
{
namespace
imperative
{
namespace
{
namespace
resize
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
op
=
static_cast
<
const
Resize
&>
(
def
);
...
...
@@ -16,7 +16,21 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
}
OP_TRAIT_REG
(
Resize
,
Resize
).
apply_on_var_node
(
apply_on_var_node
).
fallback
();
}
// anonymous namespace
}
// namespace resize
namespace
resize3d
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
op
=
static_cast
<
const
Resize3D
&>
(
def
);
mgb_assert
(
inputs
.
size
()
==
2
);
OperatorNodeConfig
config
{
op
.
make_name
()};
return
opr
::
Resize3D
::
make
(
inputs
[
0
],
inputs
[
1
],
op
.
param
(),
config
);
}
OP_TRAIT_REG
(
Resize3D
,
Resize3D
).
apply_on_var_node
(
apply_on_var_node
).
fallback
();
}
// namespace resize3d
}
// namespace imperative
}
// namespace mgb
...
...
imperative/tablegen/generated/hash.txt
浏览文件 @
813628e2
2
0aa8ae7e128c1e24564ce68389307cc
../../dnn/scripts/opr_param_defs.py
9e9636d66694dd7d5a7853247a5406f9
../../src/core/include/megbrain/ir/ops.td
e4489c2e1ea2b680d61c352842e56929
generated/opdef.h.inl
fd27534146a1cfcc791e40b2bb53207
6 generated/opdef.cpp.inl
6
754eaa59ef19178eba41e99e418790c
generated/opdef.py.inl
df66a3089aa6c12e5b1d943cd3d20e80
generated/opdef.cpy.inl
2
9b2127eb4034bf24e473945d70ead4a
../../dnn/scripts/opr_param_defs.py
639ff50d64fcb78374de266c88942c2c
../../src/core/include/megbrain/ir/ops.td
16654743e01160eeee879107cc4cac41
generated/opdef.h.inl
97c541ed45b0be98f1ac2700f5b4d8a
6 generated/opdef.cpp.inl
6
f9c6a7a1d71cca195c1e30743a1f542
generated/opdef.py.inl
806c5ceb34f571fc5c9d98d2ca8cad63
generated/opdef.cpy.inl
911001ef0dd771024919f7a1a3a009db generated/enum_macro.h
imperative/tablegen/generated/opdef.cpp.inl
浏览文件 @
813628e2
...
...
@@ -7044,6 +7044,78 @@ OP_TRAIT_REG(Resize, Resize)
.props(Resize_props_impl)
.make_name(Resize_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Resize3D);
namespace {
size_t Resize3D_hash_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<Resize3D>();
static_cast<void>(op_);
size_t val = mgb::hash(op_.dyn_typeinfo());
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.imode));
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.format));
val = mgb::hash_pair_combine(val, mgb::hash(op_.align_corners));
return val;
}
bool Resize3D_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
auto &&a_ = lhs_.cast_final_safe<Resize3D>(),
&&b_ = rhs_.cast_final_safe<Resize3D>();
static_cast<void>(a_);
static_cast<void>(b_);
if (a_.imode != b_.imode) return false;
if (a_.format != b_.format) return false;
if (a_.align_corners != b_.align_corners) return false;
return true;
}
std::vector<std::pair<const char*, std::string>> Resize3D_props_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<Resize3D>();
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
switch (op_.imode){
case Resize3D::InterpolationMode::NEAREST:
props_.emplace_back("imode", "NEAREST");
break;
case Resize3D::InterpolationMode::LINEAR:
props_.emplace_back("imode", "LINEAR");
break;
case Resize3D::InterpolationMode::AREA:
props_.emplace_back("imode", "AREA");
break;
case Resize3D::InterpolationMode::CUBIC:
props_.emplace_back("imode", "CUBIC");
break;
case Resize3D::InterpolationMode::LANCZOS4:
props_.emplace_back("imode", "LANCZOS4");
break;
default:
props_.emplace_back("imode", "INVALID");
break;
}
switch (op_.format){
case Resize3D::Format::NCDHW:
props_.emplace_back("format", "NCDHW");
break;
case Resize3D::Format::NDHWC:
props_.emplace_back("format", "NDHWC");
break;
default:
props_.emplace_back("format", "INVALID");
break;
}
props_.emplace_back("align_corners", std::to_string(op_.align_corners));
return props_;
}
std::string Resize3D_make_name_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<Resize3D>();
static_cast<void>(op_);
return "Resize3D";
}
} // anonymous namespace
OP_TRAIT_REG(Resize3D, Resize3D)
.hash(Resize3D_hash_impl)
.is_same_st(Resize3D_is_same_st_impl)
.props(Resize3D_props_impl)
.make_name(Resize3D_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(SVD);
namespace {
...
...
imperative/tablegen/generated/opdef.cpy.inl
浏览文件 @
813628e2
...
...
@@ -20536,6 +20536,169 @@ void _init_py_Resize(py::module m) {
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Resize::typeinfo(), &py_type).second);
}
void _init_py_Resize3D_InterpolationMode(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<Resize3D::InterpolationMode>::type;
Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "InterpolationMode", reinterpret_cast<PyObject*>(e_type)) >= 0);
}
void _init_py_Resize3D_Format(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<Resize3D::Format>::type;
Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "Format", reinterpret_cast<PyObject*>(e_type)) >= 0);
}
PyOpDefBegin(Resize3D) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
static PyObject* getstate(PyObject* self, PyObject*) {
auto& opdef = reinterpret_cast<PyOp(Resize3D)*>(self)->inst();
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {
{"imode", serialization<decltype(opdef.imode)>::dump(opdef.imode)},
{"format", serialization<decltype(opdef.format)>::dump(opdef.format)},
{"align_corners", serialization<decltype(opdef.align_corners)>::dump(opdef.align_corners)}
};
return py::cast(state).release().ptr();
}
static PyObject* setstate(PyObject* self, PyObject* args) {
PyObject* dict = PyTuple_GetItem(args, 0);
if (!dict) return NULL;
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
auto& opdef = reinterpret_cast<PyOp(Resize3D)*>(self)->inst();
static_cast<void>(opdef);
{
auto&& iter = state.find("imode");
if (iter != state.end()) {
opdef.imode = serialization<decltype(opdef.imode)>::load(iter->second);
}
}
{
auto&& iter = state.find("format");
if (iter != state.end()) {
opdef.format = serialization<decltype(opdef.format)>::load(iter->second);
}
}
{
auto&& iter = state.find("align_corners");
if (iter != state.end()) {
opdef.align_corners = serialization<decltype(opdef.align_corners)>::load(iter->second);
}
}
Py_RETURN_NONE;
}
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds);
static PyMethodDef py_init_methoddef;
// };
PyOpDefEnd(Resize3D)
int PyOp(Resize3D)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
static const char* kwlist[] = {"imode", "format", "align_corners", "scope", NULL};
PyObject *imode = NULL, *format = NULL, *align_corners = NULL, *scope = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOO", const_cast<char**>(kwlist), &imode, &format, &align_corners, &scope))
return -1;
if (imode) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(Resize3D)*>(self)->inst().imode =
py::cast<decltype(Resize3D::imode)>(py::handle(imode));
} CATCH_ALL(-1)
}
if (format) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(Resize3D)*>(self)->inst().format =
py::cast<decltype(Resize3D::format)>(py::handle(format));
} CATCH_ALL(-1)
}
if (align_corners) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(Resize3D)*>(self)->inst().align_corners =
py::cast<decltype(Resize3D::align_corners)>(py::handle(align_corners));
} CATCH_ALL(-1)
}
if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(py::cast<std::string>(py::handle(scope)));
} CATCH_ALL(-1)
}
return 0;
}
PyGetSetDef PyOp(Resize3D)::py_getsetters[] = {
{const_cast<char*>("imode"), py_get_generic(Resize3D, imode), py_set_generic(Resize3D, imode), const_cast<char*>("imode"), NULL},
{const_cast<char*>("format"), py_get_generic(Resize3D, format), py_set_generic(Resize3D, format), const_cast<char*>("format"), NULL},
{const_cast<char*>("align_corners"), py_get_generic(Resize3D, align_corners), py_set_generic(Resize3D, align_corners), const_cast<char*>("align_corners"), NULL},
{NULL} /* Sentinel */
};
PyMethodDef PyOp(Resize3D)::tp_methods[] = {
{const_cast<char*>("__getstate__"), PyOp(Resize3D)::getstate, METH_NOARGS, "Resize3D getstate"},
{const_cast<char*>("__setstate__"), PyOp(Resize3D)::setstate, METH_VARARGS, "Resize3D setstate"},
{NULL} /* Sentinel */
};
PyObject *PyOp(Resize3D)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) {
if (PyOp(Resize3D)::py_init(self, args, kwds) < 0) {
return NULL;
}
Py_RETURN_NONE;
}
PyMethodDef PyOp(Resize3D)::py_init_methoddef = {
"__init__",
(PyCFunction)PyOp(Resize3D)::py_init_proxy,
METH_VARARGS | METH_KEYWORDS,
"__init__(self, imode: Union[str, InterpolationMode] = ..., format: Union[str, Format] = ..., align_corners: bool = ...) -> None\n"
};
void _init_py_Resize3D(py::module m) {
using py_op = PyOp(Resize3D);
auto& py_type = PyOpType(Resize3D);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.Resize3D";
py_type.tp_basicsize = sizeof(PyOp(Resize3D));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "Resize3D";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
py_type.tp_dict = PyDict_New();
PyObject* descr = PyDescr_NewMethod(&PyOpType(Resize3D), &PyOp(Resize3D)::py_init_methoddef);
PyDict_SetItemString(py_type.tp_dict, "__init__", descr);
mgb_assert(PyType_Ready(&py_type) >= 0);
_init_py_Resize3D_InterpolationMode(py_type);
_init_py_Resize3D_Format(py_type);
PyType_Modified(&py_type);
m.add_object("Resize3D", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Resize3D::typeinfo(), &py_type).second);
}
PyOpDefBegin(SVD) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
...
...
@@ -23327,6 +23490,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) {
_init_py_RemoveAxis(m); \
_init_py_Reshape(m); \
_init_py_Resize(m); \
_init_py_Resize3D(m); \
_init_py_SVD(m); \
_init_py_SetMeshIndexing(m); \
_init_py_SetSubtensor(m); \
...
...
imperative/tablegen/generated/opdef.h.inl
浏览文件 @
813628e2
...
...
@@ -1808,6 +1808,23 @@ public:
}
};
class Resize3D : public OpDefImplBase<Resize3D> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using InterpolationMode = ::megdnn::param::Resize3D::InterpolationMode;
using Format = ::megdnn::param::Resize3D::Format;
InterpolationMode imode = ::megdnn::param::Resize3D::InterpolationMode::LINEAR;
Format format = ::megdnn::param::Resize3D::Format::NDHWC;
bool align_corners = false;
Resize3D() = default;
Resize3D(InterpolationMode imode_, Format format_, bool align_corners_, std::string scope_ = {}): imode(imode_), format(format_), align_corners(align_corners_) { set_scope(scope_); }
Resize3D(::megdnn::param::Resize3D packed_param_0): imode(packed_param_0.imode), format(packed_param_0.format), align_corners(packed_param_0.align_corners) {}
::megdnn::param::Resize3D param() const {
return {imode, format, align_corners};
}
};
class SVD : public OpDefImplBase<SVD> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
...
...
imperative/tablegen/generated/opdef.py.inl
浏览文件 @
813628e2
...
...
@@ -1924,6 +1924,18 @@ ResizeInst
.def_readwrite("imode", &Resize::imode)
.def_readwrite("format", &Resize::format);
py::class_<Resize3D, std::shared_ptr<Resize3D>, OpDef> Resize3DInst(m, "Resize3D");
Resize3DInst.attr("InterpolationMode") = RemapInst.attr("InterpolationMode");
Resize3DInst.attr("Format") = Convolution3DInst.attr("Format");
Resize3DInst
.def(py::init<::megdnn::param::Resize3D::InterpolationMode, ::megdnn::param::Resize3D::Format, bool, std::string>(), py::arg("imode") = ::megdnn::param::Resize3D::InterpolationMode::LINEAR, py::arg("format") = ::megdnn::param::Resize3D::Format::NDHWC, py::arg("align_corners") = false, py::arg("scope") = {})
.def_readwrite("imode", &Resize3D::imode)
.def_readwrite("format", &Resize3D::format)
.def_readwrite("align_corners", &Resize3D::align_corners);
py::class_<SVD, std::shared_ptr<SVD>, OpDef> SVDInst(m, "SVD");
SVDInst
...
...
src/core/include/megbrain/ir/ops.td
浏览文件 @
813628e2
...
...
@@ -112,6 +112,8 @@ def Remap: MgbHashableOp<"Remap", [RemapParam]>;
def Resize: MgbHashableOp<"Resize", [ResizeParam]>;
def Resize3D: MgbHashableOp<"Resize3D", [Resize3DParam]>;
def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]> {
let extraArguments = (ins
MgbI32Attr:$ndim
...
...
src/opr/impl/imgproc.cpp
浏览文件 @
813628e2
...
...
@@ -502,6 +502,56 @@ MGB_IMPL_OPR_GRAD(ResizeForward) {
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
ResizeBackward
);
MEGDNN_OPR_INIT2
(
ResizeBackward
,
"resize_bwd"
,
1
,
false
);
/* ======================= Resize3DForward ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
Resize3DForward
);
MEGDNN_OPR_INIT2
(
Resize3DForward
,
"resize3d"
)
void
Resize3DForward
::
init_output_dtype
()
{
output
(
0
)
->
dtype
(
input
(
0
)
->
dtype
());
outshape_by_symvar_enable
(
1
,
1
);
}
void
Resize3DForward
::
add_input_layout_constraint
()
{
input
(
0
)
->
add_layout_constraint_contiguous
();
input
(
1
)
->
add_layout_constraint_contiguous
();
}
void
Resize3DForward
::
outshape_by_symvar_do_get_output_shape
(
TensorShape
&
dest
,
const
ShapeInferInfo
&
shpinfo
)
{
TensorShape
oshp3d
;
cg
::
copy_tensor_value_to_shape
(
oshp3d
,
*
shpinfo
.
shpval_inp_val
.
at
(
0
));
auto
imgshp
=
shpinfo
.
shape_inp_shp
.
at
(
0
);
mgb_assert
(
imgshp
.
ndim
==
5
&&
oshp3d
.
ndim
==
3
,
"shape mismatch for Resize3DForward: img=%s out3d=%s"
,
imgshp
.
to_string
().
c_str
(),
oshp3d
.
to_string
().
c_str
());
dest
=
imgshp
;
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
dest
.
shape
[
2
+
i
]
=
oshp3d
.
shape
[
i
];
}
}
void
Resize3DForward
::
init_output_static_infer_desc
()
{
Super
::
init_output_static_infer_desc
();
init_output_static_infer_desc_workspace
(
false
);
}
void
Resize3DForward
::
scn_do_execute
()
{
intl
::
MegDNNOprMethInvoker
<
megdnn
::
Resize3D
>::
exec
(
megdnn_opr
(),
this
);
}
size_t
Resize3DForward
::
get_workspace_size_bytes
(
const
TensorShapeArray
&
input_shapes
,
const
TensorShapeArray
&
output_shapes
)
const
{
return
intl
::
MegDNNOprMethInvoker
<
megdnn
::
Resize3D
>::
get_workspace_in_bytes
(
megdnn_opr
(),
this
,
input_shapes
,
output_shapes
);
}
void
Resize3DForward
::
record_execute_deps
(
ExecDependencyArray
&
deps
)
{
record_megdnn_opr
(
deps
);
}
/* ======================= WarpAffineForward ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
WarpAffineForward
);
...
...
src/opr/impl/imgproc.sereg.h
浏览文件 @
813628e2
...
...
@@ -195,6 +195,7 @@ MGB_SEREG_OPR(ResizeV2, 2);
using
DctChannelSelectV1
=
opr
::
DctChannelSelect
;
MGB_SEREG_OPR
(
DctChannelSelectV1
,
0
);
MGB_SEREG_OPR
(
Resize3D
,
2
);
}
// namespace opr
}
// namespace mgb
...
...
src/opr/include/megbrain/opr/imgproc.h
浏览文件 @
813628e2
...
...
@@ -206,6 +206,43 @@ public:
const
OperatorNodeConfig
&
config
=
{});
}
;
/* ============================= user set shape =========================== */
MGB_DEFINE_OPR_CLASS
(
Resize3DForward
,
intl
::
WorkspaceSizeInfer
<
intl
::
OutshapeBySymvarSCNOpr
<
mixin
::
MegDNNOprHolderImpl
<
megdnn
::
Resize3DForward
>>>
)
// {
public
:
Resize3DForward
(
VarNode
*
in_tensor
,
VarNode
*
out_shape
,
const
Param
&
param
,
const
OperatorNodeConfig
&
config
);
MGE_WIN_DECLSPEC_FUC
static
SymbolVar
make
(
SymbolVar
in_tensor
,
SymbolVar
out_shape
,
const
Param
&
param
=
{},
const
OperatorNodeConfig
&
config
=
{});
static
SymbolVar
make
(
SymbolVar
in_tensor
,
const
TensorShape
&
out_shape
,
const
Param
&
param
=
{},
const
OperatorNodeConfig
&
config
=
{})
{
return
make
(
in_tensor
,
cg
::
var_from_tensor_shape
(
in_tensor
,
out_shape
),
param
,
config
);
}
private
:
void
init_output_dtype
()
override
;
void
add_input_layout_constraint
()
override
;
void
init_output_static_infer_desc
()
override
;
void
outshape_by_symvar_do_get_output_shape
(
TensorShape
&
dest
,
const
ShapeInferInfo
&
shpinfo
)
override
;
void
scn_do_execute
()
override
;
size_t
get_workspace_size_bytes
(
const
TensorShapeArray
&
input_shapes
,
const
TensorShapeArray
&
output_shapes
)
const
override
;
void
record_execute_deps
(
ExecDependencyArray
&
deps
)
override
;
}
;
using
Resize3D
=
Resize3DForward
;
MGB_DEFINE_OPR_CLASS
(
RemapForward
,
intl
::
MegDNNOprWrapperFwd
<
megdnn
::
RemapForward
>
)
// {
public
:
...
...
src/opr/test/imgproc.cpp
浏览文件 @
813628e2
...
...
@@ -768,6 +768,78 @@ TEST(TestOprImgproc, ResizeBackward) {
{{
10
,
8
,
8
,
4
},
{
10
,
8
,
4
,
8
}},
param
,
1e-1
,
1e-2
);
}
TEST
(
TestOprImgproc
,
Resize3DForward
)
{
using
Param
=
opr
::
Resize3D
::
Param
;
using
IMode
=
Param
::
InterpolationMode
;
using
Format
=
Param
::
Format
;
auto
ac_param
=
Param
{
IMode
::
LINEAR
,
Format
::
NCDHW
,
true
};
auto
nac_param
=
Param
{
IMode
::
LINEAR
,
Format
::
NCDHW
,
false
};
auto
run
=
[
&
](
TensorShape
ishape
,
TensorShape
oshape
,
std
::
vector
<
float
>
idata
,
std
::
vector
<
float
>
oup_ref
,
Param
param
,
DType
test_dtype
)
{
std
::
shared_ptr
<
HostTensorND
>
inp_host
(
new
HostTensorND
{
CompNode
::
load
(
"xpux"
),
ishape
,
test_dtype
});
for
(
size_t
i
=
0
;
i
<
ishape
.
total_nr_elems
();
++
i
)
{
if
(
test_dtype
==
dtype
::
Float32
())
{
inp_host
->
ptr
<
dt_float32
>
()[
i
]
=
idata
[
i
];
}
else
if
(
test_dtype
==
dtype
::
Float16
())
{
inp_host
->
ptr
<
dt_float16
>
()[
i
]
=
idata
[
i
];
}
else
{
mgb_assert
(
false
,
"invalid"
);
}
}
std
::
shared_ptr
<
HostTensorND
>
oup_shape_host
(
new
HostTensorND
{
CompNode
::
load
(
"xpux"
),
TensorShape
({
oshape
.
ndim
}),
dtype
::
Int32
()});
for
(
size_t
i
=
0
;
i
<
oshape
.
ndim
;
++
i
)
{
oup_shape_host
->
ptr
<
dt_int32
>
()[
i
]
=
oshape
[
i
];
}
auto
graph
=
ComputingGraph
::
make
();
auto
inp_sym
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
inp_host
);
auto
oup_shape_sym
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
oup_shape_host
);
auto
oup
=
opr
::
Resize3D
::
make
(
inp_sym
,
oup_shape_sym
,
param
);
HostTensorND
oup_host
;
auto
func
=
graph
->
compile
({
make_callback_copy
(
oup
,
oup_host
)});
func
->
execute
();
for
(
size_t
i
=
0
;
i
<
oshape
.
total_nr_elems
();
++
i
)
{
if
(
test_dtype
==
dtype
::
Float32
())
{
MGB_ASSERT_FLOAT_EQ
(
oup_ref
[
i
],
oup_host
.
ptr
<
dt_float32
>
()[
i
]);
}
else
if
(
test_dtype
==
dtype
::
Float16
())
{
MGB_ASSERT_FLOAT_NEAR
(
oup_ref
[
i
],
oup_host
.
ptr
<
dt_float16
>
()[
i
],
1e-3
);
}
else
{
mgb_assert
(
false
,
"invalid"
);
}
}
};
for
(
auto
&&
test_dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float32
(),
dtype
::
Float16
()})
{
run
({
1
,
1
,
2
,
2
,
2
},
{
4
,
4
,
4
},
{
0.
,
1.
,
2.
,
3.
,
4.
,
5.
,
6.
,
7.
},
{
0.
,
0.25
,
0.75
,
1.
,
0.5
,
0.75
,
1.25
,
1.5
,
1.5
,
1.75
,
2.25
,
2.5
,
2.
,
2.25
,
2.75
,
3.
,
1.
,
1.25
,
1.75
,
2.
,
1.5
,
1.75
,
2.25
,
2.5
,
2.5
,
2.75
,
3.25
,
3.5
,
3.
,
3.25
,
3.75
,
4.
,
3.
,
3.25
,
3.75
,
4.
,
3.5
,
3.75
,
4.25
,
4.5
,
4.5
,
4.75
,
5.25
,
5.5
,
5.
,
5.25
,
5.75
,
6.
,
4.
,
4.25
,
4.75
,
5.
,
4.5
,
4.75
,
5.25
,
5.5
,
5.5
,
5.75
,
6.25
,
6.5
,
6.
,
6.25
,
6.75
,
7.
},
nac_param
,
test_dtype
);
run
({
1
,
1
,
2
,
2
,
2
},
{
4
,
4
,
4
},
{
0.
,
1.
,
2.
,
3.
,
4.
,
5.
,
6.
,
7.
},
{
0.
,
0.3333333
,
0.6666667
,
1.
,
0.6666667
,
1.
,
1.3333333
,
1.6666666
,
1.3333334
,
1.6666667
,
1.9999999
,
2.3333333
,
2.
,
2.3333333
,
2.6666665
,
3.
,
1.3333334
,
1.6666666
,
2.0000002
,
2.3333335
,
2.
,
2.333333
,
2.6666667
,
2.9999998
,
2.6666665
,
3.
,
3.3333333
,
3.6666665
,
3.3333333
,
3.6666665
,
4.
,
4.3333335
,
2.6666667
,
3.
,
3.3333337
,
3.6666667
,
3.3333335
,
3.6666663
,
4.
,
4.333333
,
3.9999998
,
4.333333
,
4.6666665
,
5.
,
4.6666665
,
5.
,
5.3333335
,
5.666667
,
4.
,
4.333333
,
4.666667
,
5.
,
4.6666665
,
4.9999995
,
5.3333335
,
5.6666665
,
5.333333
,
5.6666665
,
6.
,
6.3333335
,
6.
,
6.333333
,
6.666667
,
7.
},
ac_param
,
test_dtype
);
}
}
TEST
(
TestOprImgproc
,
WarpAffineForward
)
{
constexpr
size_t
INP_H
=
6
,
INP_W
=
4
,
N
=
2
,
C
=
3
;
...
...
src/serialization/impl/schema.fbs
浏览文件 @
813628e2
...
...
@@ -127,6 +127,7 @@ union OperatorParam {
param.Fill = 93,
param.GeneralNorm=94,
param.MultiHeadAttn=95,
param.Resize3D = 96,
}
table Operator {
...
...
src/serialization/impl/schema_v2.fbs
浏览文件 @
813628e2
...
...
@@ -144,6 +144,7 @@ union OperatorParam {
param.Fill = 93,
param.GeneralNorm=94,
param.MultiHeadAttn=95,
param.Resize3D = 96,
}
table Operator {
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录