8b94f493
编写于
9月 10, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(dnn/cuda): fix elemwise and relayout int4 bug when last shape is 1
GitOrigin-RevId: e7d64c49871032deeda4176289f0457d4b9d85b8
Parent: 694aa1bd
Showing 8 changed files with 35 additions and 5 deletions (+35 -5).
dnn/src/common/basic_types.cpp             +8  -0
dnn/src/cuda/elemwise_helper.cpp           +6  -0
dnn/src/cuda/elemwise_helper.cuh           +4  -1
dnn/src/cuda/relayout/param_visitor.cpp    +6  -0
dnn/src/cuda/relayout/param_visitor.cuh    +4  -1
dnn/test/cuda/elemwise_multi_type.cpp      +2  -1
dnn/test/cuda/relayout.cpp                 +1  -0
dnn/test/cuda/type_cvt.cpp                 +4  -2
dnn/src/common/basic_types.cpp

@@ -424,12 +424,20 @@ size_t TensorLayout::access_bytes() const {
     if (dtype.is_low_bit()) {
         ret = 1;
         int align_size_in_elements = 8 / dtype.low_bit();
+        auto min_stride = contig.stride[0];
         for (size_t i = 0; i < contig.ndim; ++i) {
             if (contig.stride[i] == 1) {
                 ret *= round_up((int)contig.shape[i], align_size_in_elements);
             } else {
                 ret *= contig.shape[i];
             }
+            if (min_stride > contig.stride[i]) {
+                min_stride = contig.stride[i];
+            }
         }
+        if (min_stride != 1) {
+            megdnn_assert(min_stride == align_size_in_elements);
+            ret *= min_stride;
+        }
         ret /= align_size_in_elements;
     } else {
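For readers following the low-bit branch above, here is a minimal standalone sketch of the same arithmetic written as a free function. The shape/stride values are a hypothetical collapsed int4 layout whose smallest stride is 2 (no stride-1 dimension), which is the case the added min_stride handling accounts for; the function and its example values are illustrative only, not code from the repository.

#include <cassert>
#include <cstddef>
#include <cstdio>

// Sketch of the low-bit branch of access_bytes() above, specialised to int4
// data, i.e. align_size_in_elements = 8 / low_bit() = 2.
size_t int4_access_bytes(const size_t* shape, const ptrdiff_t* stride, size_t ndim) {
    size_t ret = 1;
    const ptrdiff_t align_size_in_elements = 2;
    ptrdiff_t min_stride = stride[0];
    for (size_t i = 0; i < ndim; ++i) {
        if (stride[i] == 1)
            ret *= (shape[i] + 1) / 2 * 2;  // round_up(shape[i], 2)
        else
            ret *= shape[i];
        if (min_stride > stride[i])
            min_stride = stride[i];
    }
    if (min_stride != 1) {
        assert(min_stride == align_size_in_elements);
        ret *= min_stride;  // the correction added by this commit
    }
    return ret / align_size_in_elements;
}

int main() {
    // Hypothetical collapsed layout with minimum stride 2: the last addressed
    // element sits at nibble offset 1*12 + 1*6 + 2*2 = 22, so 12 bytes are needed.
    size_t shape[] = {2, 2, 3};
    ptrdiff_t stride[] = {12, 6, 2};
    std::printf("%zu\n", int4_access_bytes(shape, stride, 3));  // prints 12
    // Without the min_stride correction the same call would return 6, i.e.
    // only half of the memory the strides actually span.
    return 0;
}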
dnn/src/cuda/elemwise_helper.cpp

@@ -240,6 +240,7 @@ template <int ndim>
 void ParamElemVisitor4bitBase<ndim, BCAST_OTHER>::host_init(
         const TensorND& rv, int /*grid_size*/, int /*block_size*/) {
     m_ptr = reinterpret_cast<Storage*>(rv.raw_ptr);
+    auto min_stride = rv.layout.stride[0];
     for (size_t i = 0; i < rv.layout.ndim; ++i) {
         m_stride[i] = rv.layout.stride[i];
         m_shape[i] = rv.layout.shape[i];
@@ -251,7 +252,12 @@ void ParamElemVisitor4bitBase<ndim, BCAST_OTHER>::host_init(
             else
                 m_align_shape_highdim[i] = rv.layout.shape[i + 1];
         }
+        if (min_stride > rv.layout.stride[i]) {
+            min_stride = rv.layout.stride[i];
+        }
     }
+    megdnn_assert(min_stride == 1 || min_stride == 2);
+    m_is_min_stride_2 = (min_stride == 2);
     for (size_t i = rv.layout.ndim - 1; i < ndim - 1; ++i) {
         m_shape_highdim[i] = 1;
         m_align_shape_highdim[i] = 1;
dnn/src/cuda/elemwise_helper.cuh

@@ -542,6 +542,7 @@ protected:
     int m_stride[ndim];
     int m_shape[ndim];
     bool m_is_physical_contiguous;
+    bool m_is_min_stride_2;
 
     //! m_shape_highdim[i] = original_shape[i + 1]
 #ifdef _MSC_VER
@@ -592,7 +593,7 @@ public:
         int idx = 0;
         if (m_is_physical_contiguous) {
             idx = access_idx;
-        } else {
+        } else if (!m_is_min_stride_2) {
             int shape_idx[ndim];
             bool valid = true;
             get_shape_from_access(access_idx, shape_idx);
@@ -605,6 +606,8 @@ public:
                 idx = (idx + shape_idx[i]) * m_shape[i + 1];
             }
             idx = valid ? idx + shape_idx[ndim - 1] : -1;
+        } else {  // min_stride == 2
+            idx = ((access_idx & 0x1) == 0) ? ((int)access_idx >> 1) : -1;
         }
         return idx;
     }
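A minimal host-side sketch of the mapping that the new m_is_min_stride_2 branch implements may help: when the smallest stride is 2, only every other 4-bit slot carries data, so even access indices map to access_idx >> 1 and odd ones map to -1 and are skipped. This is an illustrative standalone function under that assumption, not the CUDA visitor itself (which additionally has the non-contiguous path shown above).

#include <cstdio>

// Sketch of the added min_stride == 2 mapping: one valid int4 element per
// byte, so odd nibble positions are padding and yield -1.
static int offset_from_access_idx(int access_idx, bool is_min_stride_2) {
    if (!is_min_stride_2)
        return access_idx;  // contiguous case; the real visitor handles more here
    return ((access_idx & 0x1) == 0) ? (access_idx >> 1) : -1;
}

int main() {
    for (int i = 0; i < 6; ++i)
        std::printf("access %d -> offset %d\n", i, offset_from_access_idx(i, true));
    // access 0 -> 0, 1 -> -1, 2 -> 1, 3 -> -1, 4 -> 2, 5 -> -1
    return 0;
}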
dnn/src/cuda/relayout/param_visitor.cpp

@@ -70,6 +70,7 @@ void ParamElemVisitor<ndim, dt_quint4, CONTIG_OTHER>::host_init(
         const TensorND& rv, int /*grid_size*/, int /*block_size*/) {
     megdnn_assert(rv.layout.ndim && rv.layout.ndim <= ndim);
     m_ptr = reinterpret_cast<Storage*>(rv.raw_ptr);
+    auto min_stride = rv.layout.stride[0];
     for (size_t i = 0; i < rv.layout.ndim; ++i) {
         m_stride[i] = rv.layout.stride[i];
         m_shape[i] = rv.layout.shape[i];
@@ -81,7 +82,12 @@ void ParamElemVisitor<ndim, dt_quint4, CONTIG_OTHER>::host_init(
             else
                 m_align_shape_highdim[i] = rv.layout.shape[i + 1];
         }
+        if (min_stride > rv.layout.stride[i]) {
+            min_stride = rv.layout.stride[i];
+        }
     }
+    megdnn_assert(min_stride == 1 || min_stride == 2);
+    m_is_min_stride_2 = (min_stride == 2);
     for (size_t i = rv.layout.ndim - 1; i < ndim - 1; ++i) {
         m_shape_highdim[i] = 1;
         m_align_shape_highdim[i] = 1;
dnn/src/cuda/relayout/param_visitor.cuh

@@ -132,6 +132,7 @@ class ParamElemVisitor<ndim, dt_quint4, CONTIG_OTHER> {
     int m_shape[ndim];
     bool m_is_contiguous;
    bool m_is_physical_contiguous;
+    bool m_is_min_stride_2;
 
     //! m_shape_highdim[i] = original_shape[i + 1]
 #ifdef _MSC_VER
@@ -197,7 +198,7 @@ public:
         int idx = 0;
         if (m_is_physical_contiguous) {
             idx = access_idx;
-        } else {
+        } else if (!m_is_min_stride_2) {
             int shape_idx[ndim];
             bool valid = true;
             get_shape_from_access(access_idx, shape_idx);
@@ -209,6 +210,8 @@ public:
                 idx = (idx + shape_idx[i]) * m_shape[i + 1];
             }
             idx = valid ? idx + shape_idx[ndim - 1] : -1;
+        } else {  // min_stride == 2
+            idx = ((access_idx & 0x1) == 0) ? ((int)access_idx >> 1) : -1;
         }
         return idx;
     }
dnn/test/cuda/elemwise_multi_type.cpp

@@ -152,7 +152,8 @@ static void run_test_q4(int arity, Checker<ElemwiseMultiType>& checker,
                 .execs({{1, 4, 5, 5}, {1, 4, 5, 5}});
     } else if (arity == 2) {
         checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}})
-                .execs({{1, 4, 5, 5}, {1, 4, 5, 5}, {1, 4, 5, 5}});
+                .execs({{1, 4, 5, 5}, {1, 4, 5, 5}, {1, 4, 5, 5}})
+                .execs({{2, 2, 3, 1}, {2, 2, 3, 1}, {2, 2, 3, 1}});
     } else {
         megdnn_assert(0);
     }
dnn/test/cuda/relayout.cpp

@@ -925,6 +925,7 @@ TEST_F(CUDA, RELAYOUT_Q4) {
             .set_rng(1, &rng_int4)
             .set_dtype(0, dtype::QuantizedS4(1.f))
             .set_dtype(1, dtype::QuantizedS4(1.f))
+            .execs({{2, 2, 1, 1}, {1, 1, 2, 2}})
             .execs({{1, 64, 15, 15}, {1, 15, 15, 64}})
             .execs({{1, 5, 9, 32}, {1, 5, 32, 9}})
             .execl(TensorLayoutArray{
dnn/test/cuda/type_cvt.cpp

@@ -123,11 +123,13 @@ TEST_F(CUDA, QUANTIZED_TYPECVT_4BIT) {
         set_err(dst_dtype);
         checker.set_dtype(0, src_dtype)
                 .set_dtype(1, dst_dtype)
-                .execs({{16, 3, 224, 223}, {16, 3, 224, 223}});
+                .execs({{16, 3, 224, 223}, {16, 3, 224, 223}})
+                .execs({{16, 3, 224, 1}, {16, 3, 224, 1}});
         set_err(src_dtype);
         checker.set_dtype(0, dst_dtype)
                 .set_dtype(1, src_dtype)
-                .execs({{16, 3, 224, 223}, {16, 3, 224, 223}});
+                .execs({{16, 3, 224, 223}, {16, 3, 224, 223}})
+                .execs({{16, 3, 224, 1}, {16, 3, 224, 1}});
     };
     run(dtype::Quantized4Asymm{1.19990518f, 8},