Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
cd7090ac
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
cd7090ac
编写于
1月 27, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(opencl): enable image on mali(cl2.1)
GitOrigin-RevId: 0c670fba807e9bf25e7825e7de5ce8c04d30dae8
上级
fc8d13cd
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
263 addition
and
152 deletion
+263
-152
dnn/include/megdnn/handle.h
dnn/include/megdnn/handle.h
+14
-0
dnn/include/megdnn/tensor_format.h
dnn/include/megdnn/tensor_format.h
+56
-32
dnn/src/common/handle.cpp
dnn/src/common/handle.cpp
+4
-0
dnn/src/common/relayout_format.cpp
dnn/src/common/relayout_format.cpp
+7
-6
dnn/src/common/tensor_format.cpp
dnn/src/common/tensor_format.cpp
+145
-111
dnn/src/cuda/handle.cpp
dnn/src/cuda/handle.cpp
+4
-0
dnn/src/cuda/handle.h
dnn/src/cuda/handle.h
+1
-0
dnn/src/naive/handle.cpp
dnn/src/naive/handle.cpp
+4
-0
dnn/src/naive/handle.h
dnn/src/naive/handle.h
+1
-0
dnn/test/common/test_basic_types.cpp
dnn/test/common/test_basic_types.cpp
+27
-3
未找到文件。
dnn/include/megdnn/handle.h
浏览文件 @
cd7090ac
...
@@ -38,6 +38,17 @@ class Handle {
...
@@ -38,6 +38,17 @@ class Handle {
CAMBRICON
=
12
,
CAMBRICON
=
12
,
};
};
//! Device vendor
enum
class
HandleVendorType
:
uint32_t
{
NOT_SPEC
=
0
,
MALI
=
1
,
ADRENO
=
2
,
CUDA
=
3
,
INTEL
=
4
,
POWERVR
=
5
,
AMD
=
6
,
};
protected:
protected:
Handle
(
megcoreComputingHandle_t
computing_handle
,
HandleType
type
);
Handle
(
megcoreComputingHandle_t
computing_handle
,
HandleType
type
);
...
@@ -130,6 +141,9 @@ class Handle {
...
@@ -130,6 +141,9 @@ class Handle {
//! get alignment in bytes for rows of image 2D tensor format
//! get alignment in bytes for rows of image 2D tensor format
virtual
size_t
image2d_pitch_alignment
()
const
;
virtual
size_t
image2d_pitch_alignment
()
const
;
//! get vendor type
virtual
HandleVendorType
vendor_type
()
const
;
HandleType
type
()
const
{
HandleType
type
()
const
{
return
m_handle_type
;
return
m_handle_type
;
}
}
...
...
dnn/include/megdnn/tensor_format.h
浏览文件 @
cd7090ac
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
#pragma once
#pragma once
#include "megdnn/basic_types.h"
#include "megdnn/basic_types.h"
#include "megdnn/handle.h"
#include "megdnn/internal/visibility_prologue.h"
#include "megdnn/internal/visibility_prologue.h"
namespace
megdnn
{
namespace
megdnn
{
...
@@ -43,7 +44,7 @@ public:
...
@@ -43,7 +44,7 @@ public:
protected:
protected:
ImplBase
(
Type
type
)
:
m_type
{
type
}
{}
ImplBase
(
Type
type
)
:
m_type
{
type
}
{}
~
ImplBase
()
=
default
;
virtual
~
ImplBase
()
=
default
;
static
TensorFormat
impl_to_tensor_format
(
ImplBase
*
impl
)
{
return
{
impl
};
}
static
TensorFormat
impl_to_tensor_format
(
ImplBase
*
impl
)
{
return
{
impl
};
}
...
@@ -93,8 +94,8 @@ namespace detail {
...
@@ -93,8 +94,8 @@ namespace detail {
*
*
* \p align_axis is the axis to be aligned, also the first axis of image width.
* \p align_axis is the axis to be aligned, also the first axis of image width.
* More precisely speaking, `stride[align_axis-1] * dtype.size()` must divide \p
* More precisely speaking, `stride[align_axis-1] * dtype.size()` must divide \p
* align_size_in_
byte. Axes from 0 to align_axis-1 would be considered as the
* align_size_in_
elements. Axes from 0 to align_axis-1 would be considered as
* height of the image, and other axes are the width.
*
the
height of the image, and other axes are the width.
*
*
* Empty tensors and negative strides are not allowed. Only contiguous or
* Empty tensors and negative strides are not allowed. Only contiguous or
* broadcasted cases are allowed.
* broadcasted cases are allowed.
...
@@ -103,41 +104,32 @@ namespace detail {
...
@@ -103,41 +104,32 @@ namespace detail {
* considered as contiguous.
* considered as contiguous.
*/
*/
class
Image2DTensorFormatBase
:
public
TensorFormat
::
ImplBase
{
class
Image2DTensorFormatBase
:
public
TensorFormat
::
ImplBase
{
size_t
m_align_axis
,
m_align_size_in_
byte
_log2
;
size_t
m_align_axis
,
m_align_size_in_
elements
_log2
;
protected:
protected:
Image2DTensorFormatBase
(
Type
type
,
size_t
align_axis
,
Image2DTensorFormatBase
(
Type
type
,
size_t
align_axis
,
size_t
align_size_in_
byte
);
size_t
align_size_in_
elements
);
~
Image2DTensorFormatBase
()
=
default
;
virtual
~
Image2DTensorFormatBase
()
=
default
;
public:
public:
/*!
/*!
* \brief get alignment requirement in
byte
s
* \brief get alignment requirement in
element
s
* \param div_log2 the result would be divided by `(1 << div_log2)`
* \param div_log2 the result would be divided by `(1 << div_log2)`
*/
*/
size_t
align_size_in_
byte
(
size_t
div_log2
=
0
)
const
{
size_t
align_size_in_
elements
(
size_t
div_log2
=
0
)
const
{
return
1
<<
(
m_align_size_in_
byte
_log2
>
div_log2
return
1
<<
(
m_align_size_in_
elements
_log2
>
div_log2
?
m_align_size_in_
byte
_log2
-
div_log2
?
m_align_size_in_
elements
_log2
-
div_log2
:
0
);
:
0
);
}
}
size_t
align_axis
()
const
{
return
m_align_axis
;
}
size_t
align_axis
()
const
{
return
m_align_axis
;
}
size_t
init_contiguous_stride
(
TensorLayout
&
layout
)
const
override
;
size_t
align_size_in_elements_log2
()
const
{
return
m_align_size_in_elements_log2
;
bool
is_contiguous_spec
(
const
TensorLayout
&
layout
)
const
override
;
}
TensorLayout
collapse_contiguous_spec
(
const
TensorLayout
&
layout
)
const
override
;
//! span for image must include the padding at the last row
TensorLayout
::
Span
span_spec
(
const
TensorLayout
&
layout
)
const
override
;
std
::
string
to_string
()
const
override
;
std
::
string
to_string
()
const
override
;
//! raise exception if preconditions violated
virtual
void
assert_valid
(
const
TensorLayout
&
layout
)
const
;
//! modify the align axis and return a new TensorFormat
//! modify the align axis and return a new TensorFormat
virtual
TensorFormat
change_axis
(
size_t
axis
)
const
=
0
;
virtual
TensorFormat
change_axis
(
size_t
axis
)
const
=
0
;
...
@@ -147,9 +139,6 @@ public:
...
@@ -147,9 +139,6 @@ public:
//! number of rows
//! number of rows
size_t
image_height
(
const
TensorLayout
&
layout
)
const
;
size_t
image_height
(
const
TensorLayout
&
layout
)
const
;
//! delta of addresses of consecutive rows (in bytes)
size_t
image_row_pitch
(
const
TensorLayout
&
layout
)
const
;
void
serialize_append
(
std
::
string
&
result
)
const
override
;
void
serialize_append
(
std
::
string
&
result
)
const
override
;
protected:
protected:
struct
SerializePack
{
struct
SerializePack
{
...
@@ -159,9 +148,27 @@ protected:
...
@@ -159,9 +148,27 @@ protected:
template
<
size_t
PIXEL_SIZE
>
template
<
size_t
PIXEL_SIZE
>
class
Image2DPackedTensorFormatBase
:
public
Image2DTensorFormatBase
{
class
Image2DPackedTensorFormatBase
:
public
Image2DTensorFormatBase
{
Handle
::
HandleVendorType
m_vendor_type
=
Handle
::
HandleVendorType
::
NOT_SPEC
;
/*!
* \brief get fix alignment requirement in bytes, consider m_vendor_type,
* for example on MALI, CL_DEVICE_IMAGE_PITCH_ALIGNMENT means image_width
* align COUNT, but mdl needs align size in byte, which equal to
* (image_width algin count) * sizeof(data_type) * pixel_size
*/
size_t
image_pitch_alignment_in_bytes
(
size_t
align_size_in_elements
,
const
TensorLayout
&
layout
)
const
;
protected:
protected:
using
Image2DTensorFormatBase
::
Image2DTensorFormatBase
;
Image2DPackedTensorFormatBase
(
Type
type
,
size_t
align_axis
,
~
Image2DPackedTensorFormatBase
()
=
default
;
size_t
align_size_in_elements
,
Handle
::
HandleVendorType
vendor_type
)
:
detail
::
Image2DTensorFormatBase
(
type
,
align_axis
,
align_size_in_elements
),
m_vendor_type
(
vendor_type
)
{}
virtual
~
Image2DPackedTensorFormatBase
()
=
default
;
Handle
::
HandleVendorType
vendor
()
const
{
return
m_vendor_type
;
}
public:
public:
/*!
/*!
...
@@ -173,7 +180,20 @@ public:
...
@@ -173,7 +180,20 @@ public:
*/
*/
size_t
image_width
(
const
TensorLayout
&
layout
)
const
;
size_t
image_width
(
const
TensorLayout
&
layout
)
const
;
void
assert_valid
(
const
TensorLayout
&
layout
)
const
override
;
//! raise exception if preconditions violated
void
assert_valid
(
const
TensorLayout
&
layout
)
const
;
size_t
image_row_pitch
(
const
TensorLayout
&
layout
)
const
;
//! span for image must include the padding at the last row
TensorLayout
::
Span
span_spec
(
const
TensorLayout
&
layout
)
const
override
;
size_t
init_contiguous_stride
(
TensorLayout
&
layout
)
const
override
;
bool
is_contiguous_spec
(
const
TensorLayout
&
layout
)
const
override
;
TensorLayout
collapse_contiguous_spec
(
const
TensorLayout
&
layout
)
const
override
;
};
};
using
Image2DPack4TensorFormatBase
=
Image2DPackedTensorFormatBase
<
4
>
;
using
Image2DPack4TensorFormatBase
=
Image2DPackedTensorFormatBase
<
4
>
;
}
// namespace detail
}
// namespace detail
...
@@ -190,7 +210,10 @@ public:
...
@@ -190,7 +210,10 @@ public:
static
constexpr
Type
TYPE
=
Type
::
IMAGE2D_PACK4
;
static
constexpr
Type
TYPE
=
Type
::
IMAGE2D_PACK4
;
//! for internal usage or test purposes
//! for internal usage or test purposes
static
TensorFormat
make_raw
(
size_t
align_axis
,
size_t
align_size_in_byte
);
static
TensorFormat
make_raw
(
size_t
align_axis
,
size_t
align_size_in_elements
,
Handle
::
HandleVendorType
vendor_type
=
Handle
::
HandleVendorType
::
NOT_SPEC
);
static
TensorFormat
make
(
size_t
align_axis
,
const
Handle
*
handle
);
static
TensorFormat
make
(
size_t
align_axis
,
const
Handle
*
handle
);
...
@@ -215,9 +238,10 @@ public:
...
@@ -215,9 +238,10 @@ public:
TensorFormat
change_axis
(
size_t
axis
)
const
override
;
TensorFormat
change_axis
(
size_t
axis
)
const
override
;
private:
private:
Image2DPack4TensorFormat
(
size_t
align_axis
,
size_t
align_size_in_byte
)
Image2DPack4TensorFormat
(
size_t
align_axis
,
size_t
align_size_in_elements
,
:
detail
::
Image2DPack4TensorFormatBase
(
TYPE
,
align_axis
,
Handle
::
HandleVendorType
vendor_type
)
align_size_in_byte
)
{}
:
detail
::
Image2DPack4TensorFormatBase
(
TYPE
,
align_axis
,
align_size_in_elements
,
vendor_type
)
{}
};
};
}
// namespace megdnn
}
// namespace megdnn
...
...
dnn/src/common/handle.cpp
浏览文件 @
cd7090ac
...
@@ -147,6 +147,10 @@ std::unique_ptr<Handle> Handle::make(megcoreComputingHandle_t computing_handle,
...
@@ -147,6 +147,10 @@ std::unique_ptr<Handle> Handle::make(megcoreComputingHandle_t computing_handle,
megdnn_throw
(
"image2d tensor format not supported on this handle"
);
megdnn_throw
(
"image2d tensor format not supported on this handle"
);
}
}
megdnn
::
HandleImplHelper
::
HandleVendorType
Handle
::
vendor_type
()
const
{
return
HandleVendorType
::
NOT_SPEC
;
}
bool
Handle
::
check_cross_dev_copy_constraint
(
const
TensorLayout
&
src
)
{
bool
Handle
::
check_cross_dev_copy_constraint
(
const
TensorLayout
&
src
)
{
return
src
.
is_contiguous
();
return
src
.
is_contiguous
();
}
}
...
...
dnn/src/common/relayout_format.cpp
浏览文件 @
cd7090ac
...
@@ -236,6 +236,7 @@ void RelayoutFormat::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
...
@@ -236,6 +236,7 @@ void RelayoutFormat::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
void
RelayoutFormat
::
deduce_format
(
TensorFormat
src
,
TensorFormat
&
dst
)
{
void
RelayoutFormat
::
deduce_format
(
TensorFormat
src
,
TensorFormat
&
dst
)
{
size_t
align
=
handle
()
->
image2d_pitch_alignment
();
size_t
align
=
handle
()
->
image2d_pitch_alignment
();
auto
vendor_type
=
handle
()
->
vendor_type
();
using
Param
=
param
::
RelayoutFormat
;
using
Param
=
param
::
RelayoutFormat
;
#define CHECK_SRC(_expect) \
#define CHECK_SRC(_expect) \
megdnn_assert(src == _expect, "invalid src format: expect=%s got=%s", \
megdnn_assert(src == _expect, "invalid src format: expect=%s got=%s", \
...
@@ -251,7 +252,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
...
@@ -251,7 +252,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
break
;
break
;
case
Param
::
Mode
::
NHWC_NHWCD4I
:
case
Param
::
Mode
::
NHWC_NHWCD4I
:
CHECK_SRC
(
DefaultTensorFormat
::
make
());
CHECK_SRC
(
DefaultTensorFormat
::
make
());
dst
=
Image2DPack4TensorFormat
::
make_raw
(
2
,
align
);
dst
=
Image2DPack4TensorFormat
::
make_raw
(
2
,
align
,
vendor_type
);
break
;
break
;
case
Param
::
Mode
::
NCHW_NHWCD4
:
case
Param
::
Mode
::
NCHW_NHWCD4
:
CHECK_SRC
(
DefaultTensorFormat
::
make
());
CHECK_SRC
(
DefaultTensorFormat
::
make
());
...
@@ -263,10 +264,10 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
...
@@ -263,10 +264,10 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
break
;
break
;
case
Param
::
Mode
::
NCHW_NHWCD4I
:
case
Param
::
Mode
::
NCHW_NHWCD4I
:
CHECK_SRC
(
DefaultTensorFormat
::
make
());
CHECK_SRC
(
DefaultTensorFormat
::
make
());
dst
=
Image2DPack4TensorFormat
::
make_raw
(
2
,
align
);
dst
=
Image2DPack4TensorFormat
::
make_raw
(
2
,
align
,
vendor_type
);
break
;
break
;
case
Param
::
Mode
::
NHWCD4I_NCHW
:
case
Param
::
Mode
::
NHWCD4I_NCHW
:
CHECK_SRC
(
Image2DPack4TensorFormat
::
make_raw
(
2
,
align
));
CHECK_SRC
(
Image2DPack4TensorFormat
::
make_raw
(
2
,
align
,
vendor_type
));
dst
=
DefaultTensorFormat
::
make
();
dst
=
DefaultTensorFormat
::
make
();
break
;
break
;
case
Param
::
Mode
::
NHWCD4_NCHW
:
case
Param
::
Mode
::
NHWCD4_NCHW
:
...
@@ -280,7 +281,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
...
@@ -280,7 +281,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
case
Param
::
Mode
::
INTER_WEIGHT_DENSEI
:
case
Param
::
Mode
::
INTER_WEIGHT_DENSEI
:
case
Param
::
Mode
::
INTER_WEIGHT_DENSEI_DOT
:
case
Param
::
Mode
::
INTER_WEIGHT_DENSEI_DOT
:
CHECK_SRC
(
DefaultTensorFormat
::
make
());
CHECK_SRC
(
DefaultTensorFormat
::
make
());
dst
=
Image2DPack4TensorFormat
::
make_raw
(
3
,
align
);
dst
=
Image2DPack4TensorFormat
::
make_raw
(
3
,
align
,
vendor_type
);
break
;
break
;
case
Param
::
Mode
::
INTER_WEIGHT_GROUP
:
case
Param
::
Mode
::
INTER_WEIGHT_GROUP
:
CHECK_SRC
(
DefaultTensorFormat
::
make
());
CHECK_SRC
(
DefaultTensorFormat
::
make
());
...
@@ -289,7 +290,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
...
@@ -289,7 +290,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
case
Param
::
Mode
::
INTER_WEIGHT_GROUPI
:
case
Param
::
Mode
::
INTER_WEIGHT_GROUPI
:
case
Param
::
Mode
::
INTER_WEIGHT_GROUPI_DOT
:
case
Param
::
Mode
::
INTER_WEIGHT_GROUPI_DOT
:
CHECK_SRC
(
DefaultTensorFormat
::
make
());
CHECK_SRC
(
DefaultTensorFormat
::
make
());
dst
=
Image2DPack4TensorFormat
::
make_raw
(
4
,
align
);
dst
=
Image2DPack4TensorFormat
::
make_raw
(
4
,
align
,
vendor_type
);
break
;
break
;
case
Param
::
Mode
::
INTER_WEIGHT_CHAN
:
case
Param
::
Mode
::
INTER_WEIGHT_CHAN
:
CHECK_SRC
(
DefaultTensorFormat
::
make
());
CHECK_SRC
(
DefaultTensorFormat
::
make
());
...
@@ -297,7 +298,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
...
@@ -297,7 +298,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
break
;
break
;
case
Param
::
Mode
::
INTER_WEIGHT_CHANI
:
case
Param
::
Mode
::
INTER_WEIGHT_CHANI
:
CHECK_SRC
(
DefaultTensorFormat
::
make
());
CHECK_SRC
(
DefaultTensorFormat
::
make
());
dst
=
Image2DPack4TensorFormat
::
make_raw
(
1
,
align
);
dst
=
Image2DPack4TensorFormat
::
make_raw
(
1
,
align
,
vendor_type
);
break
;
break
;
case
Param
::
Mode
::
NCHW4_CHWN4
:
case
Param
::
Mode
::
NCHW4_CHWN4
:
CHECK_SRC
(
DefaultTensorFormat
::
make
());
CHECK_SRC
(
DefaultTensorFormat
::
make
());
...
...
dnn/src/common/tensor_format.cpp
浏览文件 @
cd7090ac
...
@@ -185,23 +185,134 @@ TensorFormat DefaultTensorFormat::make() {
...
@@ -185,23 +185,134 @@ TensorFormat DefaultTensorFormat::make() {
/* ===================== Image2DTensorFormatBase ===================== */
/* ===================== Image2DTensorFormatBase ===================== */
Image2DTensorFormatBase
::
Image2DTensorFormatBase
(
Type
type
,
size_t
align_axis
,
Image2DTensorFormatBase
::
Image2DTensorFormatBase
(
Type
type
,
size_t
align_axis
,
size_t
align_size_in_
byte
)
size_t
align_size_in_
elements
)
:
ImplBase
(
type
)
{
:
ImplBase
(
type
)
,
m_align_axis
(
align_axis
)
{
megdnn_assert
(
align_size_in_
byte
&&
align_axis
);
megdnn_assert
(
align_size_in_
elements
&&
align_axis
);
m_align_
axis
=
align_axis
;
m_align_
size_in_elements_log2
=
__builtin_ctz
(
align_size_in_elements
)
;
m
_align_size_in_byte_log2
=
__builtin_ctz
(
align_size_in_byte
);
m
egdnn_assert
(
megdnn_assert
((
1u
<<
m_align_size_in_byte_log2
)
==
align_size_in_byte
,
(
1u
<<
m_align_size_in_elements_log2
)
==
align_size_in_elements
,
"align size not power of 2: %zu"
,
align_size_in_byte
);
"align size not power of 2: %zu"
,
align_size_in_elements
);
}
}
size_t
Image2DTensorFormatBase
::
init_contiguous_stride
(
void
Image2DTensorFormatBase
::
serialize_append
(
std
::
string
&
result
)
const
{
SerializePack
pack
;
pack
.
align_axis
=
m_align_axis
;
megdnn_assert
(
pack
.
align_axis
==
m_align_axis
);
// detect overflow
result
.
append
(
reinterpret_cast
<
char
*>
(
&
pack
),
sizeof
(
pack
));
}
size_t
Image2DTensorFormatBase
::
image_height
(
const
TensorLayout
&
layout
)
const
{
size_t
accum
=
1
;
for
(
int
i
=
m_align_axis
-
1
;
i
>=
0
;
--
i
)
{
if
(
layout
.
stride
[
i
]
==
0
)
{
// this dimension is broadcasted
}
else
{
accum
*=
layout
.
shape
[
i
];
}
}
return
accum
;
}
size_t
Image2DTensorFormatBase
::
image_width_elems
(
const
TensorLayout
&
layout
)
const
{
size_t
high_elem
=
0
;
for
(
size_t
i
=
m_align_axis
;
i
<
layout
.
ndim
;
++
i
)
{
high_elem
+=
(
layout
.
shape
[
i
]
-
1
)
*
layout
.
stride
[
i
];
}
return
high_elem
+
1
;
}
std
::
string
Image2DTensorFormatBase
::
to_string
()
const
{
return
ssprintf
(
"I2D{%zu,%d}"
,
m_align_axis
,
1
<<
m_align_size_in_elements_log2
);
}
/* ===================== Image2DPackedTensorFormatBase ===================== */
template
<
size_t
PIXEL_SIZE
>
size_t
Image2DPackedTensorFormatBase
<
PIXEL_SIZE
>::
image_width
(
const
TensorLayout
&
layout
)
const
{
auto
ret
=
image_width_elems
(
layout
);
megdnn_assert
(
ret
%
PIXEL_SIZE
==
0
);
return
ret
/
PIXEL_SIZE
;
}
template
<
size_t
PIXEL_SIZE
>
void
Image2DPackedTensorFormatBase
<
PIXEL_SIZE
>::
assert_valid
(
const
TensorLayout
&
layout
)
const
{
auto
m_align_axis
=
align_axis
();
megdnn_assert
(
!
(
layout
.
shape
[
layout
.
ndim
-
1
]
%
PIXEL_SIZE
),
"bad shape: %zu"
,
layout
.
shape
[
layout
.
ndim
-
1
]);
megdnn_assert
(
layout
.
dtype
.
valid
()
&&
layout
.
ndim
>
m_align_axis
);
ptrdiff_t
first_non_zero_stride
=
0
;
for
(
int
i
=
layout
.
ndim
-
1
;
i
>=
0
;
--
i
)
{
megdnn_assert
(
layout
.
shape
[
i
]
&&
layout
.
stride
[
i
]
>=
0
);
if
(
i
<
static_cast
<
int
>
(
m_align_axis
)
&&
!
first_non_zero_stride
)
{
first_non_zero_stride
=
layout
.
stride
[
i
];
}
}
size_t
mask
=
image_pitch_alignment_in_bytes
(
align_size_in_elements
(
layout
.
dtype
.
size_log
()),
layout
)
-
1
;
megdnn_assert
(
!
(
first_non_zero_stride
&
mask
),
"first stride is %d, but alignment is %zu"
,
static_cast
<
int
>
(
first_non_zero_stride
),
mask
+
1
);
}
template
<
size_t
PIXEL_SIZE
>
size_t
Image2DPackedTensorFormatBase
<
PIXEL_SIZE
>::
image_row_pitch
(
const
TensorLayout
&
layout
)
const
{
for
(
int
i
=
align_axis
()
-
1
;
i
>=
0
;
--
i
)
{
// find a non-broadcast axis
if
(
auto
s
=
layout
.
stride
[
i
])
{
return
layout
.
dtype
.
size
(
s
);
}
}
// use width for all broadcasted case
size_t
alignment_in_bytes_log2
=
align_size_in_elements_log2
();
if
(
m_vendor_type
==
Handle
::
HandleVendorType
::
MALI
)
{
alignment_in_bytes_log2
+=
__builtin_ctz
(
layout
.
dtype
.
size
()
*
PIXEL_SIZE
);
}
return
get_aligned_power2
<
size_t
>
(
layout
.
dtype
.
size
(
image_width_elems
(
layout
)),
1
<<
alignment_in_bytes_log2
);
}
template
<
size_t
PIXEL_SIZE
>
size_t
Image2DPackedTensorFormatBase
<
PIXEL_SIZE
>::
image_pitch_alignment_in_bytes
(
size_t
align_size_in_elements
,
const
TensorLayout
&
layout
)
const
{
return
m_vendor_type
==
Handle
::
HandleVendorType
::
MALI
?
(
align_size_in_elements
*
layout
.
dtype
.
size
()
*
PIXEL_SIZE
)
:
align_size_in_elements
;
}
template
<
size_t
PIXEL_SIZE
>
TensorLayout
::
Span
Image2DPackedTensorFormatBase
<
PIXEL_SIZE
>::
span_spec
(
const
TensorLayout
&
layout
)
const
{
assert_valid
(
layout
);
size_t
size
=
image_height
(
layout
)
*
image_row_pitch
(
layout
);
auto
mask
=
(
1
<<
layout
.
dtype
.
size_log
())
-
1
;
megdnn_assert
(
!
(
size
&
mask
),
"unaligned size: %zu"
,
size
);
return
{
0
,
0
,
size
>>
layout
.
dtype
.
size_log
(),
size
};
}
template
<
size_t
PIXEL_SIZE
>
size_t
Image2DPackedTensorFormatBase
<
PIXEL_SIZE
>::
init_contiguous_stride
(
TensorLayout
&
layout
)
const
{
TensorLayout
&
layout
)
const
{
auto
m_align_axis
=
align_axis
();
if
(
!
layout
.
ndim
)
if
(
!
layout
.
ndim
)
return
0
;
return
0
;
megdnn_assert
(
layout
.
dtype
.
valid
()
&&
layout
.
ndim
>
m_align_axis
,
megdnn_assert
(
layout
.
dtype
.
valid
()
&&
layout
.
ndim
>
m_align_axis
,
"dtype=%s ndim=%zu align=%zu"
,
layout
.
dtype
.
name
(),
"dtype=%s ndim=%zu align=%zu"
,
layout
.
dtype
.
name
(),
layout
.
ndim
,
m_align_axis
);
layout
.
ndim
,
m_align_axis
);
size_t
align_size
=
align_size_in_byte
(
layout
.
dtype
.
size_log
());
size_t
align_size
=
image_pitch_alignment_in_bytes
(
align_size_in_elements
(
layout
.
dtype
.
size_log
()),
layout
);
size_t
accum
=
1
;
size_t
accum
=
1
;
SafeMultiplies
<
size_t
>
mul
;
SafeMultiplies
<
size_t
>
mul
;
for
(
size_t
i
=
layout
.
ndim
;
i
;
--
i
)
{
for
(
size_t
i
=
layout
.
ndim
;
i
;
--
i
)
{
...
@@ -216,12 +327,15 @@ size_t Image2DTensorFormatBase::init_contiguous_stride(
...
@@ -216,12 +327,15 @@ size_t Image2DTensorFormatBase::init_contiguous_stride(
return
accum
;
return
accum
;
};
};
bool
Image2DTensorFormatBase
::
is_contiguous_spec
(
template
<
size_t
PIXEL_SIZE
>
bool
Image2DPackedTensorFormatBase
<
PIXEL_SIZE
>::
is_contiguous_spec
(
const
TensorLayout
&
layout
)
const
{
const
TensorLayout
&
layout
)
const
{
megdnn_assert
(
layout
.
dtype
.
valid
());
megdnn_assert
(
layout
.
dtype
.
valid
());
size_t
align_size
=
align_size_in_byte
(
layout
.
dtype
.
size_log
());
size_t
align_size
=
image_pitch_alignment_in_bytes
(
align_size_in_elements
(
layout
.
dtype
.
size_log
()),
layout
);
ptrdiff_t
expected
=
1
;
ptrdiff_t
expected
=
1
;
int
height_axis
=
static_cast
<
int
>
(
m_align_axis
-
1
);
int
height_axis
=
static_cast
<
int
>
(
align_axis
()
-
1
);
for
(
int
i
=
layout
.
ndim
-
1
;
i
>=
0
;
--
i
)
{
for
(
int
i
=
layout
.
ndim
-
1
;
i
>=
0
;
--
i
)
{
if
(
i
==
height_axis
)
{
if
(
i
==
height_axis
)
{
expected
=
megdnn
::
get_aligned_power2
<
size_t
>
(
expected
,
align_size
);
expected
=
megdnn
::
get_aligned_power2
<
size_t
>
(
expected
,
align_size
);
...
@@ -235,7 +349,12 @@ bool Image2DTensorFormatBase::is_contiguous_spec(
...
@@ -235,7 +349,12 @@ bool Image2DTensorFormatBase::is_contiguous_spec(
return
false
;
return
false
;
}
}
size_t
mask
=
align_size_in_byte
(
layout
.
dtype
.
size_log
())
-
1
;
size_t
mask
=
image_pitch_alignment_in_bytes
(
align_size_in_elements
(
layout
.
dtype
.
size_log
()),
layout
)
-
1
;
megdnn_assert
(
s
>
expected
&&
!
(
s
&
mask
),
megdnn_assert
(
s
>
expected
&&
!
(
s
&
mask
),
"invalid row pitch: %d; layout: %s"
,
"invalid row pitch: %d; layout: %s"
,
static_cast
<
int
>
(
s
),
layout
.
to_string
().
c_str
());
static_cast
<
int
>
(
s
),
layout
.
to_string
().
c_str
());
...
@@ -250,11 +369,12 @@ bool Image2DTensorFormatBase::is_contiguous_spec(
...
@@ -250,11 +369,12 @@ bool Image2DTensorFormatBase::is_contiguous_spec(
return
expected
!=
0
;
return
expected
!=
0
;
}
}
TensorLayout
Image2DTensorFormatBase
::
collapse_contiguous_spec
(
template
<
size_t
PIXEL_SIZE
>
TensorLayout
Image2DPackedTensorFormatBase
<
PIXEL_SIZE
>::
collapse_contiguous_spec
(
const
TensorLayout
&
layout
)
const
{
const
TensorLayout
&
layout
)
const
{
assert_valid
(
layout
);
assert_valid
(
layout
);
TensorLayout
res
{
layout
};
TensorLayout
res
{
layout
};
int
new_axis
=
m_align_axis
;
int
new_axis
=
align_axis
()
;
// remove all dims with shape 1
// remove all dims with shape 1
for
(
int
i
=
static_cast
<
int
>
(
res
.
ndim
)
-
1
;
i
>=
0
&&
res
.
ndim
>=
3
;
--
i
)
{
for
(
int
i
=
static_cast
<
int
>
(
res
.
ndim
)
-
1
;
i
>=
0
&&
res
.
ndim
>=
3
;
--
i
)
{
if
(
i
==
new_axis
&&
static_cast
<
int
>
(
res
.
ndim
)
==
new_axis
+
1
)
{
if
(
i
==
new_axis
&&
static_cast
<
int
>
(
res
.
ndim
)
==
new_axis
+
1
)
{
...
@@ -302,95 +422,6 @@ TensorLayout Image2DTensorFormatBase::collapse_contiguous_spec(
...
@@ -302,95 +422,6 @@ TensorLayout Image2DTensorFormatBase::collapse_contiguous_spec(
return
res
;
return
res
;
}
}
TensorLayout
::
Span
Image2DTensorFormatBase
::
span_spec
(
const
TensorLayout
&
layout
)
const
{
assert_valid
(
layout
);
size_t
size
=
image_height
(
layout
)
*
image_row_pitch
(
layout
);
auto
mask
=
(
1
<<
layout
.
dtype
.
size_log
())
-
1
;
megdnn_assert
(
!
(
size
&
mask
),
"unaligned size: %zu"
,
size
);
return
{
0
,
0
,
size
>>
layout
.
dtype
.
size_log
(),
size
};
}
void
Image2DTensorFormatBase
::
serialize_append
(
std
::
string
&
result
)
const
{
SerializePack
pack
;
pack
.
align_axis
=
m_align_axis
;
megdnn_assert
(
pack
.
align_axis
==
m_align_axis
);
// detect overflow
result
.
append
(
reinterpret_cast
<
char
*>
(
&
pack
),
sizeof
(
pack
));
}
size_t
Image2DTensorFormatBase
::
image_height
(
const
TensorLayout
&
layout
)
const
{
size_t
accum
=
1
;
for
(
int
i
=
m_align_axis
-
1
;
i
>=
0
;
--
i
)
{
if
(
layout
.
stride
[
i
]
==
0
)
{
// this dimension is broadcasted
}
else
{
accum
*=
layout
.
shape
[
i
];
}
}
return
accum
;
}
size_t
Image2DTensorFormatBase
::
image_row_pitch
(
const
TensorLayout
&
layout
)
const
{
for
(
int
i
=
m_align_axis
-
1
;
i
>=
0
;
--
i
)
{
// find a non-broadcast axis
if
(
auto
s
=
layout
.
stride
[
i
])
{
return
layout
.
dtype
.
size
(
s
);
}
}
// use width for all broadcasted case
return
get_aligned_power2
<
size_t
>
(
layout
.
dtype
.
size
(
image_width_elems
(
layout
)),
1
<<
m_align_size_in_byte_log2
);
}
void
Image2DTensorFormatBase
::
assert_valid
(
const
TensorLayout
&
layout
)
const
{
megdnn_assert
(
layout
.
dtype
.
valid
()
&&
layout
.
ndim
>
m_align_axis
);
ptrdiff_t
first_non_zero_stride
=
0
;
for
(
int
i
=
layout
.
ndim
-
1
;
i
>=
0
;
--
i
)
{
megdnn_assert
(
layout
.
shape
[
i
]
&&
layout
.
stride
[
i
]
>=
0
);
if
(
i
<
static_cast
<
int
>
(
m_align_axis
)
&&
!
first_non_zero_stride
)
{
first_non_zero_stride
=
layout
.
stride
[
i
];
}
}
size_t
mask
=
align_size_in_byte
(
layout
.
dtype
.
size_log
())
-
1
;
megdnn_assert
(
!
(
first_non_zero_stride
&
mask
),
"first stride is %d, but alignment is %zu"
,
static_cast
<
int
>
(
first_non_zero_stride
),
mask
+
1
);
}
size_t
Image2DTensorFormatBase
::
image_width_elems
(
const
TensorLayout
&
layout
)
const
{
size_t
high_elem
=
0
;
for
(
size_t
i
=
m_align_axis
;
i
<
layout
.
ndim
;
++
i
)
{
high_elem
+=
(
layout
.
shape
[
i
]
-
1
)
*
layout
.
stride
[
i
];
}
return
high_elem
+
1
;
}
std
::
string
Image2DTensorFormatBase
::
to_string
()
const
{
return
ssprintf
(
"I2D{%zu,%d}"
,
m_align_axis
,
1
<<
m_align_size_in_byte_log2
);
}
/* ===================== Image2DPackedTensorFormatBase ===================== */
template
<
size_t
PIXEL_SIZE
>
size_t
Image2DPackedTensorFormatBase
<
PIXEL_SIZE
>::
image_width
(
const
TensorLayout
&
layout
)
const
{
auto
ret
=
image_width_elems
(
layout
);
megdnn_assert
(
ret
%
PIXEL_SIZE
==
0
);
return
ret
/
PIXEL_SIZE
;
}
template
<
size_t
PIXEL_SIZE
>
void
Image2DPackedTensorFormatBase
<
PIXEL_SIZE
>::
assert_valid
(
const
TensorLayout
&
layout
)
const
{
Image2DTensorFormatBase
::
assert_valid
(
layout
);
megdnn_assert
(
!
(
layout
.
shape
[
layout
.
ndim
-
1
]
%
PIXEL_SIZE
),
"bad shape: %zu"
,
layout
.
shape
[
layout
.
ndim
-
1
]);
}
namespace
megdnn
{
namespace
megdnn
{
namespace
detail
{
namespace
detail
{
template
class
Image2DPackedTensorFormatBase
<
4
>;
template
class
Image2DPackedTensorFormatBase
<
4
>;
...
@@ -398,26 +429,29 @@ template class Image2DPackedTensorFormatBase<4>;
...
@@ -398,26 +429,29 @@ template class Image2DPackedTensorFormatBase<4>;
}
// namespace megdnn
}
// namespace megdnn
/* ===================== Image2DPack4TensorFormat ===================== */
/* ===================== Image2DPack4TensorFormat ===================== */
TensorFormat
Image2DPack4TensorFormat
::
make_raw
(
size_t
align_axis
,
TensorFormat
Image2DPack4TensorFormat
::
make_raw
(
size_t
align_size_in_byte
)
{
size_t
align_axis
,
size_t
align_size_in_elements
,
Handle
::
HandleVendorType
vendor_type
)
{
static
std
::
mutex
mtx
;
static
std
::
mutex
mtx
;
static
std
::
unordered_map
<
uint64_t
,
static
std
::
unordered_map
<
uint64_t
,
std
::
unique_ptr
<
Image2DPack4TensorFormat
>>
std
::
unique_ptr
<
Image2DPack4TensorFormat
>>
cache
;
cache
;
megdnn_assert
(
std
::
max
(
align_axis
,
align_size_in_
byte
)
<=
megdnn_assert
(
std
::
max
(
align_axis
,
align_size_in_
elements
)
<=
std
::
numeric_limits
<
uint32_t
>::
max
());
std
::
numeric_limits
<
uint32_t
>::
max
());
MEGDNN_LOCK_GUARD
(
mtx
);
MEGDNN_LOCK_GUARD
(
mtx
);
auto
&&
ptr
=
cache
[(
static_cast
<
uint64_t
>
(
align_axis
)
<<
32
)
|
auto
&&
ptr
=
cache
[(
static_cast
<
uint64_t
>
(
align_axis
)
<<
32
)
|
align_size_in_
byte
];
align_size_in_
elements
];
if
(
!
ptr
)
{
if
(
!
ptr
)
{
ptr
.
reset
(
new
Image2DPack4TensorFormat
{
align_axis
,
align_size_in_byte
});
ptr
.
reset
(
new
Image2DPack4TensorFormat
{
align_axis
,
align_size_in_elements
,
vendor_type
});
}
}
return
impl_to_tensor_format
(
ptr
.
get
());
return
impl_to_tensor_format
(
ptr
.
get
());
}
}
TensorFormat
Image2DPack4TensorFormat
::
make
(
size_t
align_axis
,
TensorFormat
Image2DPack4TensorFormat
::
make
(
size_t
align_axis
,
const
Handle
*
handle
)
{
const
Handle
*
handle
)
{
return
make_raw
(
align_axis
,
handle
->
image2d_pitch_alignment
());
return
make_raw
(
align_axis
,
handle
->
image2d_pitch_alignment
(),
handle
->
vendor_type
());
}
}
TensorFormat
Image2DPack4TensorFormat
::
deserialize
(
const
Handle
*
handle
,
TensorFormat
Image2DPack4TensorFormat
::
deserialize
(
const
Handle
*
handle
,
...
@@ -429,7 +463,7 @@ TensorFormat Image2DPack4TensorFormat::deserialize(const Handle* handle,
...
@@ -429,7 +463,7 @@ TensorFormat Image2DPack4TensorFormat::deserialize(const Handle* handle,
}
}
TensorFormat
Image2DPack4TensorFormat
::
change_axis
(
size_t
axis
)
const
{
TensorFormat
Image2DPack4TensorFormat
::
change_axis
(
size_t
axis
)
const
{
return
make_raw
(
axis
,
align_size_in_
byte
());
return
make_raw
(
axis
,
align_size_in_
elements
(),
vendor
());
}
}
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
dnn/src/cuda/handle.cpp
浏览文件 @
cd7090ac
...
@@ -123,6 +123,10 @@ size_t HandleImpl::image2d_pitch_alignment() const {
...
@@ -123,6 +123,10 @@ size_t HandleImpl::image2d_pitch_alignment() const {
return
align
;
return
align
;
}
}
HandleImpl
::
HandleVendorType
HandleImpl
::
vendor_type
()
const
{
return
HandleVendorType
::
CUDA
;
}
}
// namespace cuda
}
// namespace cuda
}
// namespace megdnn
}
// namespace megdnn
...
...
dnn/src/cuda/handle.h
浏览文件 @
cd7090ac
...
@@ -123,6 +123,7 @@ class HandleImpl: public HandleImplHelper {
...
@@ -123,6 +123,7 @@ class HandleImpl: public HandleImplHelper {
TypeCvt
*
typecvt_opr
()
{
return
get_helper_opr
<
TypeCvt
,
0
>
(
this
);
}
TypeCvt
*
typecvt_opr
()
{
return
get_helper_opr
<
TypeCvt
,
0
>
(
this
);
}
size_t
image2d_pitch_alignment
()
const
override
;
size_t
image2d_pitch_alignment
()
const
override
;
HandleVendorType
vendor_type
()
const
override
;
private:
private:
bool
m_is_tegra_k1
;
bool
m_is_tegra_k1
;
int
m_device_id
;
int
m_device_id
;
...
...
dnn/src/naive/handle.cpp
浏览文件 @
cd7090ac
...
@@ -118,6 +118,10 @@ size_t HandleImpl::image2d_pitch_alignment() const {
...
@@ -118,6 +118,10 @@ size_t HandleImpl::image2d_pitch_alignment() const {
return
g_image2d_pitch_alignment
;
return
g_image2d_pitch_alignment
;
}
}
HandleImpl
::
HandleVendorType
HandleImpl
::
vendor_type
()
const
{
return
HandleVendorType
::
NOT_SPEC
;
}
size_t
HandleImpl
::
exchange_image2d_pitch_alignment
(
size_t
alignment
)
{
size_t
HandleImpl
::
exchange_image2d_pitch_alignment
(
size_t
alignment
)
{
auto
ret
=
g_image2d_pitch_alignment
;
auto
ret
=
g_image2d_pitch_alignment
;
g_image2d_pitch_alignment
=
alignment
;
g_image2d_pitch_alignment
=
alignment
;
...
...
dnn/src/naive/handle.h
浏览文件 @
cd7090ac
...
@@ -169,6 +169,7 @@ public:
...
@@ -169,6 +169,7 @@ public:
* \param alignment the new alignment value to set
* \param alignment the new alignment value to set
*/
*/
static
size_t
exchange_image2d_pitch_alignment
(
size_t
alignment
);
static
size_t
exchange_image2d_pitch_alignment
(
size_t
alignment
);
HandleVendorType
vendor_type
()
const
override
;
};
};
}
// namespace naive
}
// namespace naive
...
...
dnn/test/common/test_basic_types.cpp
浏览文件 @
cd7090ac
...
@@ -175,6 +175,30 @@ namespace {
...
@@ -175,6 +175,30 @@ namespace {
}
}
}
}
TEST
(
BASIC_TYPES
,
TENSOR_LAYOUT_FMT_WITH_VENDOR_MALI
)
{
TensorFormat
fmt
=
Image2DPack4TensorFormat
::
make_raw
(
1
,
512
,
Handle
::
HandleVendorType
::
MALI
);
TensorLayout
layout
{{
5
,
3
,
8
},
dtype
::
Float32
{},
fmt
};
ASSERT_EQ
(
layout
.
stride
[
2
],
1
);
ASSERT_EQ
(
layout
.
stride
[
1
],
8
);
ASSERT_EQ
(
layout
.
stride
[
0
],
2048
);
ASSERT_EQ
(
8192u
,
image_row_pitch
(
layout
));
ASSERT_EQ
(
6u
,
image_width
(
layout
));
ASSERT_EQ
(
5u
,
image_height
(
layout
));
fmt
=
Image2DPack4TensorFormat
::
make_raw
(
1
,
512
,
Handle
::
HandleVendorType
::
MALI
);
TensorLayout
layout_s
{{
5
,
3
,
8
},
dtype
::
Float16
{},
fmt
};
ASSERT_EQ
(
layout_s
.
stride
[
2
],
1
);
ASSERT_EQ
(
layout_s
.
stride
[
1
],
8
);
ASSERT_EQ
(
layout_s
.
stride
[
0
],
2048
);
ASSERT_EQ
(
4096u
,
image_row_pitch
(
layout_s
));
ASSERT_EQ
(
6u
,
image_width
(
layout_s
));
ASSERT_EQ
(
5u
,
image_height
(
layout_s
));
}
TEST
(
BASIC_TYPES
,
TENSOR_LAYOUT_FMT
)
{
TEST
(
BASIC_TYPES
,
TENSOR_LAYOUT_FMT
)
{
TensorFormat
fmt
=
Image2DPack4TensorFormat
::
make_raw
(
1
,
1024
);
TensorFormat
fmt
=
Image2DPack4TensorFormat
::
make_raw
(
1
,
1024
);
ASSERT_FALSE
(
fmt
.
is_default
());
ASSERT_FALSE
(
fmt
.
is_default
());
...
@@ -233,7 +257,7 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT) {
...
@@ -233,7 +257,7 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT) {
auto
&&
impl
=
contig
.
format
.
as_impl
<
Image2DPack4TensorFormat
>
();
auto
&&
impl
=
contig
.
format
.
as_impl
<
Image2DPack4TensorFormat
>
();
ASSERT_EQ
(
make_layout
({
1
,
8
},
{
32
,
1
},
layout
.
dtype
),
contig
);
ASSERT_EQ
(
make_layout
({
1
,
8
},
{
32
,
1
},
layout
.
dtype
),
contig
);
ASSERT_EQ
(
1u
,
impl
.
align_axis
());
ASSERT_EQ
(
1u
,
impl
.
align_axis
());
ASSERT_EQ
(
64u
,
impl
.
align_size_in_
byte
());
ASSERT_EQ
(
64u
,
impl
.
align_size_in_
elements
());
}
}
}
}
...
@@ -258,7 +282,7 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_H) {
...
@@ -258,7 +282,7 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_H) {
auto
&&
impl
=
contig
.
format
.
as_impl
<
Image2DPack4TensorFormat
>
();
auto
&&
impl
=
contig
.
format
.
as_impl
<
Image2DPack4TensorFormat
>
();
ASSERT_EQ
(
make_layout
({
v0
,
8
},
{
32
,
1
},
layout
.
dtype
),
contig
);
ASSERT_EQ
(
make_layout
({
v0
,
8
},
{
32
,
1
},
layout
.
dtype
),
contig
);
ASSERT_EQ
(
1u
,
impl
.
align_axis
());
ASSERT_EQ
(
1u
,
impl
.
align_axis
());
ASSERT_EQ
(
64u
,
impl
.
align_size_in_
byte
());
ASSERT_EQ
(
64u
,
impl
.
align_size_in_
elements
());
}
}
}
}
...
@@ -274,7 +298,7 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_W) {
...
@@ -274,7 +298,7 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_W) {
layout
.
dtype
),
layout
.
dtype
),
contig
);
contig
);
ASSERT_EQ
(
1u
,
impl
.
align_axis
());
ASSERT_EQ
(
1u
,
impl
.
align_axis
());
ASSERT_EQ
(
64u
,
impl
.
align_size_in_
byte
());
ASSERT_EQ
(
64u
,
impl
.
align_size_in_
elements
());
}
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录