Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
828d0b12
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
828d0b12
编写于
5月 18, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 18, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1209 add format trans function
Merge pull request !1209 from liubuyu/master
上级
699d0c10
f70429d6
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
262 addition
and
66 deletion
+262
-66
mindspore/ccsrc/common/trans.cc
mindspore/ccsrc/common/trans.cc
+130
-58
mindspore/ccsrc/common/trans.h
mindspore/ccsrc/common/trans.h
+2
-0
mindspore/ccsrc/device/ascend/ascend_device_address.cc
mindspore/ccsrc/device/ascend/ascend_device_address.cc
+17
-8
tests/ut/cpp/common/trans_test.cc
tests/ut/cpp/common/trans_test.cc
+113
-0
未找到文件。
mindspore/ccsrc/common/trans.cc
浏览文件 @
828d0b12
...
...
@@ -63,26 +63,24 @@ const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberType
{
kNumberTypeUInt32
,
4
},
{
kNumberTypeUInt64
,
8
},
{
kNumberTypeFloat
,
4
},
{
kNumberTypeFloat16
,
2
},
{
kNumberTypeFloat32
,
4
},
{
kNumberTypeFloat64
,
8
}};
#define SetDataBysize(size, pad_zero) \
do { \
switch (size) { \
case 1: \
static_cast<uint8_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint8_t *>(args.data)[src_idx]; \
break; \
case 2: \
static_cast<uint16_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint16_t *>(args.data)[src_idx]; \
break; \
case 4: \
static_cast<uint32_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint32_t *>(args.data)[src_idx]; \
break; \
case 8: \
static_cast<uint64_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint64_t *>(args.data)[src_idx]; \
break; \
default: \
MS_LOG(ERROR) << "Trans data not support size " << size; \
return false; \
} \
} while (0)
inline
void
SetData
(
size_t
size
,
bool
pad_zero
,
size_t
src_idx
,
size_t
dst_idx
,
const
FormatArgs
&
args
,
void
*
result
)
{
switch
(
size
)
{
case
1
:
static_cast
<
uint8_t
*>
(
result
)[
dst_idx
]
=
pad_zero
?
0
:
static_cast
<
const
uint8_t
*>
(
args
.
data
)[
src_idx
];
break
;
case
2
:
static_cast
<
uint16_t
*>
(
result
)[
dst_idx
]
=
pad_zero
?
0
:
static_cast
<
const
uint16_t
*>
(
args
.
data
)[
src_idx
];
break
;
case
4
:
static_cast
<
uint32_t
*>
(
result
)[
dst_idx
]
=
pad_zero
?
0
:
static_cast
<
const
uint32_t
*>
(
args
.
data
)[
src_idx
];
break
;
case
8
:
static_cast
<
uint64_t
*>
(
result
)[
dst_idx
]
=
pad_zero
?
0
:
static_cast
<
const
uint64_t
*>
(
args
.
data
)[
src_idx
];
break
;
default:
MS_LOG
(
EXCEPTION
)
<<
"Trans data not support size "
<<
size
;
}
}
template
<
typename
T
>
T
DivCeil
(
T
n1
,
T
n2
)
{
...
...
@@ -401,6 +399,13 @@ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
device_shape
.
push_back
(
C0
);
return
device_shape
;
}
std
::
vector
<
size_t
>
NdhwcDeviceShape
(
const
std
::
vector
<
size_t
>
&
shape
)
{
if
(
shape
.
size
()
<
5
)
{
MS_LOG
(
EXCEPTION
)
<<
"Shape dims must be 5 when format is ndhwc."
;
}
return
shape
;
}
}
// namespace
std
::
vector
<
size_t
>
TransShapeToDevice
(
const
std
::
vector
<
size_t
>
&
shape
,
const
std
::
string
&
format
)
{
...
...
@@ -412,7 +417,8 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
{
kOpFormat_NC1HWC0
,
Nc1hwc0DeviceShape
},
{
kOpFormat_C1HWNCoC0
,
C1hwncoc0DeviceShape
},
{
kOpFormat_FRACTAL_Z_C04
,
FracZc04DeviceShape
},
{
kOpFormat_NC1HWC0_C04
,
Nc1hwc04DeviceShape
}};
{
kOpFormat_NC1HWC0_C04
,
Nc1hwc04DeviceShape
},
{
kOpFormat_NDHWC
,
NdhwcDeviceShape
}};
if
(
format
==
kOpFormat_ND
||
format
==
kOpFormat_DEFAULT
)
{
return
shape
;
...
...
@@ -482,43 +488,109 @@ bool TransDataType(const TypeIdArgs &args, void *result) {
}
bool
TransFormat
(
const
FormatArgs
&
args
,
void
*
result
)
{
using
FormatTransfer
=
std
::
function
<
bool
(
const
FormatArgs
&
,
void
*
)
>
;
const
std
::
map
<
std
::
string
,
FormatTransfer
>
format_trans_map
{
{
kOpFormat_FRAC_Z
,
NchwToFracZ
},
{
kOpFormat_FRAC_NZ
,
NchwToFracNz
},
{
kOpFormat_NC1HWC0
,
NchwToNc1hwc0
},
{
kOpFormat_C1HWNCoC0
,
NchwToC1hwncoc0
},
{
kOpFormat_FRACTAL_Z_C04
,
NchwToFracZc04
},
{
kOpFormat_NC1HWC0_C04
,
NchwToNc1hwc04
}};
MS_LOG
(
DEBUG
)
<<
"Start trans format."
;
if
(
TypeIdSize
(
args
.
src_data_type
)
<
1
)
{
MS_LOG
(
ERROR
)
<<
"Invalid datatype.."
;
return
false
;
}
if
(
args
.
device_format
==
kOpFormat_FRAC_Z
)
{
return
NchwToFracZ
(
args
,
result
);
}
else
if
(
args
.
device_format
==
kOpFormat_FRAC_NZ
)
{
return
NchwToFracNz
(
args
,
result
);
}
else
if
(
args
.
device_format
==
kOpFormat_NC1HWC0
)
{
return
NchwToNc1hwc0
(
args
,
result
);
}
else
if
(
args
.
device_format
==
kOpFormat_C1HWNCoC0
)
{
return
NchwToC1hwncoc0
(
args
,
result
);
}
else
if
(
args
.
device_format
==
kOpFormat_FRACTAL_Z_C04
)
{
return
NchwToFracZc04
(
args
,
result
);
}
else
if
(
args
.
device_format
==
kOpFormat_NC1HWC0_C04
)
{
return
NchwToNc1hwc04
(
args
,
result
);
if
(
args
.
device_format
==
kOpFormat_HWCN
||
args
.
device_format
==
kOpFormat_NHWC
)
{
return
NchwTo4D
(
args
,
result
);
}
return
true
;
auto
iter
=
format_trans_map
.
find
(
args
.
device_format
);
if
(
iter
==
format_trans_map
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"Unexpected format["
<<
args
.
device_format
<<
"]"
;
}
return
iter
->
second
(
args
,
result
);
}
bool
TransFormatFromDeviceToHost
(
const
FormatArgs
&
args
,
void
*
result
)
{
using
FormatTransfer
=
std
::
function
<
bool
(
const
FormatArgs
&
,
void
*
)
>
;
const
std
::
map
<
std
::
string
,
FormatTransfer
>
format_trans_map
{{
kOpFormat_FRAC_Z
,
FracZToNchw
},
{
kOpFormat_FRAC_NZ
,
FracNzToNchw
},
{
kOpFormat_NC1HWC0
,
Nc1hwc0ToNchw
},
{
kOpFormat_C1HWNCoC0
,
C1hwncoc0ToNchw
},
{
kOpFormat_NC1HWC0_C04
,
Nc1hwc04ToNchw
}};
MS_LOG
(
DEBUG
)
<<
"Start trans format."
;
if
(
TypeIdSize
(
args
.
src_data_type
)
<
1
)
{
MS_LOG
(
ERROR
)
<<
"Invalid datatype.."
;
return
false
;
}
if
(
args
.
device_format
==
kOpFormat_FRAC_Z
)
{
return
FracZToNchw
(
args
,
result
);
}
else
if
(
args
.
device_format
==
kOpFormat_FRAC_NZ
)
{
return
FracNzToNchw
(
args
,
result
);
}
else
if
(
args
.
device_format
==
kOpFormat_NC1HWC0
)
{
return
Nc1hwc0ToNchw
(
args
,
result
);
}
else
if
(
args
.
device_format
==
kOpFormat_C1HWNCoC0
)
{
return
C1hwncoc0ToNchw
(
args
,
result
);
}
else
if
(
args
.
device_format
==
kOpFormat_NC1HWC0_C04
)
{
return
Nc1hwc04ToNchw
(
args
,
result
);
if
(
args
.
device_format
==
kOpFormat_HWCN
||
args
.
device_format
==
kOpFormat_NHWC
)
{
return
ToNchw
(
args
,
result
);
}
auto
iter
=
format_trans_map
.
find
(
args
.
device_format
);
if
(
iter
==
format_trans_map
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"Unexpected format["
<<
args
.
device_format
<<
"]"
;
}
return
iter
->
second
(
args
,
result
);
}
bool
NchwTo4D
(
const
FormatArgs
&
args
,
void
*
result
)
{
// trans nchw to 4d
MS_LOG
(
DEBUG
)
<<
"Trans format from nchw to 4d."
;
MS_EXCEPTION_IF_NULL
(
result
);
size_t
size
=
0
;
size_t
total_size
=
0
;
if
(
!
CheckArgs
(
args
,
&
size
,
&
total_size
))
{
MS_LOG
(
ERROR
)
<<
"Check args failed."
;
return
false
;
}
size_t
n
=
args
.
host_shape
[
0
];
size_t
c
=
args
.
host_shape
[
1
];
size_t
h
=
args
.
host_shape
[
2
];
size_t
w
=
args
.
host_shape
[
3
];
for
(
size_t
ni
=
0
;
ni
<
n
;
ni
++
)
{
for
(
size_t
ci
=
0
;
ci
<
c
;
ci
++
)
{
for
(
size_t
hi
=
0
;
hi
<
h
;
hi
++
)
{
for
(
size_t
wi
=
0
;
wi
<
w
;
wi
++
)
{
auto
src_idx
=
ni
*
c
*
h
*
w
+
ci
*
h
*
w
+
hi
*
w
+
wi
;
auto
dst_idx
=
0
;
if
(
args
.
device_format
==
kOpFormat_NHWC
)
{
dst_idx
=
ni
*
h
*
w
*
c
+
hi
*
w
*
c
+
wi
*
c
+
ci
;
}
else
if
(
args
.
device_format
==
kOpFormat_HWCN
)
{
dst_idx
=
hi
*
w
*
c
*
n
+
wi
*
c
*
n
+
ci
*
n
+
ni
;
}
SetData
(
size
,
false
,
src_idx
,
dst_idx
,
args
,
result
);
}
}
}
}
return
true
;
}
bool
ToNchw
(
const
FormatArgs
&
args
,
void
*
result
)
{
MS_LOG
(
DEBUG
)
<<
"Trans format to nchw from 4d."
;
MS_EXCEPTION_IF_NULL
(
result
);
size_t
size
=
0
;
size_t
total_size
=
0
;
if
(
!
CheckArgs
(
args
,
&
size
,
&
total_size
))
{
MS_LOG
(
ERROR
)
<<
"Check args failed."
;
return
false
;
}
size_t
n
=
args
.
host_shape
[
0
];
size_t
c
=
args
.
host_shape
[
1
];
size_t
h
=
args
.
host_shape
[
2
];
size_t
w
=
args
.
host_shape
[
3
];
for
(
size_t
ni
=
0
;
ni
<
n
;
ni
++
)
{
for
(
size_t
ci
=
0
;
ci
<
c
;
ci
++
)
{
for
(
size_t
hi
=
0
;
hi
<
h
;
hi
++
)
{
for
(
size_t
wi
=
0
;
wi
<
w
;
wi
++
)
{
auto
dst_idx
=
ni
*
c
*
h
*
w
+
ci
*
h
*
w
+
hi
*
w
+
wi
;
auto
src_idx
=
0
;
if
(
args
.
device_format
==
kOpFormat_NHWC
)
{
src_idx
=
ni
*
h
*
w
*
c
+
hi
*
w
*
c
+
wi
*
c
+
ci
;
}
else
if
(
args
.
device_format
==
kOpFormat_HWCN
)
{
src_idx
=
hi
*
w
*
c
*
n
+
wi
*
c
*
n
+
ci
*
n
+
ni
;
}
SetData
(
size
,
false
,
src_idx
,
dst_idx
,
args
,
result
);
}
}
}
}
return
true
;
}
...
...
@@ -575,8 +647,8 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
auto
src_ni
=
hfi
*
kCubeSize
+
col
;
auto
src_idx
=
src_row_offset
+
chw
*
col
;
auto
dst_idx
=
gfi
*
fractal_ele_cnt
+
col
*
c0
+
row
;
auto
pad_zero
=
(
src_ni
>=
n
||
src_idx
>=
nchw
||
src_ci
>=
c
)
?
1
:
0
;
SetData
Bysize
(
size
,
pad_zero
);
auto
pad_zero
=
(
src_ni
>=
n
||
src_idx
>=
nchw
||
src_ci
>=
c
)
?
true
:
false
;
SetData
(
size
,
pad_zero
,
src_idx
,
dst_idx
,
args
,
result
);
}
}
}
...
...
@@ -630,7 +702,7 @@ bool FracZToNchw(const FormatArgs &args, void *result) {
size_t
c0_idx
=
c_idx
%
c0
;
size_t
nc_idx
=
n_idx
;
size_t
src_idx
=
c1_idx
*
hwncc0
+
h_idx
*
wncc0
+
w_idx
*
ncc0
+
nc_idx
*
c0
+
c0_idx
;
SetData
Bysize
(
size
,
0
);
SetData
(
size
,
false
,
src_idx
,
dst_idx
,
args
,
result
);
}
}
}
...
...
@@ -679,7 +751,7 @@ bool NchwToFracZc04(const FormatArgs &args, void *result) {
auto
c_idx
=
desc_c1
*
c0
+
desc_c0
;
auto
src_idx
=
desc_g
*
nhwc
+
desc_n
*
hwc
+
c_idx
*
h
*
w
+
desc_h
*
w
+
desc_w
;
auto
pad_zero
=
desc_g
>=
1
||
desc_n
>=
n
||
c_idx
>=
c
;
SetData
Bysize
(
size
,
pad_zero
);
SetData
(
size
,
pad_zero
,
src_idx
,
dst_idx
,
args
,
result
);
dst_idx
++
;
}
}
...
...
@@ -773,7 +845,7 @@ bool NchwToFracNz(const FormatArgs &args, void *result) {
for
(
size_t
i
=
0
;
i
<
w0
;
++
i
)
{
size_t
src_idx
=
src_h_head
+
w1_idx
*
w0
+
i
;
size_t
dst_idx
=
h1h0_head
+
w1_idx
*
h1h0w0
+
i
;
SetData
Bysize
(
size
,
0
);
SetData
(
size
,
false
,
src_idx
,
dst_idx
,
args
,
result
);
}
}
auto
w1_head
=
num_w1
*
w0
;
...
...
@@ -781,7 +853,7 @@ bool NchwToFracNz(const FormatArgs &args, void *result) {
auto
src_w_idx
=
w1_head
+
w0_idx
;
size_t
dst_idx
=
h1h0_head
+
num_w1
*
h1h0w0
+
w0_idx
;
size_t
src_idx
=
src_h_head
+
src_w_idx
;
SetData
Bysize
(
size
,
0
);
SetData
(
size
,
false
,
src_idx
,
dst_idx
,
args
,
result
);
}
}
}
...
...
@@ -835,7 +907,7 @@ bool FracNzToNchw(const FormatArgs &args, void *result) {
for
(
size_t
i
=
0
;
i
<
w0
;
++
i
)
{
size_t
src_idx
=
h1h0_head
+
w1_idx
*
h1h0w0
+
i
;
size_t
dst_idx
=
src_h_head
+
w1_idx
*
w0
+
i
;
SetData
Bysize
(
size
,
0
);
SetData
(
size
,
false
,
src_idx
,
dst_idx
,
args
,
result
);
}
}
auto
w1_head
=
num_w1
*
w0
;
...
...
@@ -843,7 +915,7 @@ bool FracNzToNchw(const FormatArgs &args, void *result) {
auto
src_w_idx
=
w1_head
+
w0_idx
;
size_t
src_idx
=
h1h0_head
+
num_w1
*
h1h0w0
+
w0_idx
;
size_t
dst_idx
=
src_h_head
+
src_w_idx
;
SetData
Bysize
(
size
,
0
);
SetData
(
size
,
false
,
src_idx
,
dst_idx
,
args
,
result
);
}
}
}
...
...
@@ -895,8 +967,8 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
size_t
dst_idx
=
c0_idx
+
w_head_addr
;
size_t
c_idx
=
c0_idx
+
c1_idx
*
c0
;
size_t
src_idx
=
n_idx
*
chw
+
c_idx
*
hw
+
h_idx
*
w
+
w_idx
;
auto
pad_zero
=
(
c_idx
<
c
)
?
0
:
1
;
SetData
Bysize
(
size
,
pad_zero
);
auto
pad_zero
=
(
c_idx
<
c
)
?
false
:
true
;
SetData
(
size
,
pad_zero
,
src_idx
,
dst_idx
,
args
,
result
);
}
}
}
...
...
@@ -947,7 +1019,7 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
size_t
c1_idx
=
c_idx
/
c0
;
size_t
c0_idx
=
c_idx
%
c0
;
size_t
src_idx
=
n_idx
*
c1hwc0
+
c1_idx
*
hwc0
+
h_idx
*
wc0
+
w_idx
*
c0
+
c0_idx
;
SetData
Bysize
(
size
,
0
);
SetData
(
size
,
false
,
src_idx
,
dst_idx
,
args
,
result
);
}
}
}
...
...
@@ -983,8 +1055,8 @@ bool NchwToC1hwncoc0(const FormatArgs &args, void *result) {
co_i
*
c0
+
c0_i
;
size_t
c_i
=
c0_i
+
c1_i
*
c0
;
size_t
src_idx
=
n_i
*
c
*
h
*
w
+
c_i
*
h
*
w
+
h_i
*
w
+
w_i
;
auto
pad_zero
=
(
c_i
<
c
&&
c0_i
==
co_i
)
?
0
:
1
;
SetData
Bysize
(
size
,
pad_zero
);
auto
pad_zero
=
(
c_i
<
c
&&
c0_i
==
co_i
)
?
false
:
true
;
SetData
(
size
,
pad_zero
,
src_idx
,
dst_idx
,
args
,
result
);
}
}
}
...
...
@@ -1020,7 +1092,7 @@ bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) {
size_t
co_i
=
c0_i
;
size_t
src_idx
=
c1_i
*
h
*
w
*
n
*
co
*
c0
+
h_i
*
w
*
n
*
co
*
c0
+
w_i
*
n
*
co
*
c0
+
n_i
*
co
*
c0
+
co_i
*
c0
+
c0_i
;
SetData
Bysize
(
size
,
0
);
SetData
(
size
,
false
,
src_idx
,
dst_idx
,
args
,
result
);
}
}
}
...
...
mindspore/ccsrc/common/trans.h
浏览文件 @
828d0b12
...
...
@@ -61,6 +61,7 @@ bool TransFormat(const FormatArgs &args, void *result);
bool
TransFormatFromDeviceToHost
(
const
FormatArgs
&
args
,
void
*
result
);
// host to device
bool
NchwTo4D
(
const
FormatArgs
&
args
,
void
*
result
);
bool
NchwToFracZ
(
const
FormatArgs
&
args
,
void
*
result
);
bool
NchwToFracNz
(
const
FormatArgs
&
args
,
void
*
result
);
bool
NchwToNc1hwc0
(
const
FormatArgs
&
args
,
void
*
result
);
...
...
@@ -68,6 +69,7 @@ bool NchwToFracZc04(const FormatArgs &args, void *result);
bool
NchwToNc1hwc04
(
const
FormatArgs
&
args
,
void
*
result
);
bool
NchwToC1hwncoc0
(
const
FormatArgs
&
args
,
void
*
result
);
// device to host
bool
ToNchw
(
const
FormatArgs
&
args
,
void
*
result
);
bool
FracZToNchw
(
const
FormatArgs
&
args
,
void
*
result
);
bool
FracNzToNchw
(
const
FormatArgs
&
args
,
void
*
result
);
bool
Nc1hwc0ToNchw
(
const
FormatArgs
&
args
,
void
*
result
);
...
...
mindspore/ccsrc/device/ascend/ascend_device_address.cc
浏览文件 @
828d0b12
...
...
@@ -16,6 +16,7 @@
#include "device/ascend/ascend_device_address.h"
#include <memory>
#include <vector>
#include <set>
#include <algorithm>
#include "runtime/mem.h"
#include "device/kernel_runtime_manager.h"
...
...
@@ -34,6 +35,10 @@ namespace device {
namespace
ascend
{
const
int
FLOAT_LEN
=
sizeof
(
float
);
const
int
FLOAT16_LEN
=
2
;
// sizeof(float16);
const
std
::
set
<
std
::
string
>
kOpNeedTransFormat
=
{
kOpFormat_NHWC
,
kOpFormat_HWCN
,
kOpFormat_NC1HWC0
,
kOpFormat_FRAC_Z
,
kOpFormat_C1HWNCoC0
,
kOpFormat_FRAC_NZ
,
kOpFormat_NC1HWC0_C04
,
kOpFormat_FRACTAL_Z_C04
};
void
SyncMemory
(
void
*
dst
,
const
void
*
src
,
uint64_t
size
,
rtMemcpyKind_t
kind
)
{
auto
ret_rt_memcpy
=
rtMemcpy
(
dst
,
size
,
src
,
size
,
kind
);
if
(
ret_rt_memcpy
!=
RT_ERROR_NONE
)
{
...
...
@@ -97,7 +102,7 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t
if
(
host_shape
.
empty
())
{
host_shape
.
emplace_back
(
1
);
}
if
(
format_
==
kOpFormat_NCHW
||
format_
==
kOpFormat_DEFAULT
)
{
if
(
format_
==
kOpFormat_NCHW
||
format_
==
kOpFormat_DEFAULT
||
format_
==
kOpFormat_NDHWC
)
{
if
(
type_id_
==
type
)
{
SyncMemory
(
host_ptr
,
ptr_
,
size
,
RT_MEMCPY_DEVICE_TO_HOST
);
sync_ok
=
true
;
...
...
@@ -115,9 +120,11 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t
}
}
}
else
{
auto
iter
=
k
NeedTransFormatSe
t
.
find
(
format_
);
if
(
iter
!=
k
NeedTransFormatSe
t
.
end
())
{
auto
iter
=
k
OpNeedTransForma
t
.
find
(
format_
);
if
(
iter
!=
k
OpNeedTransForma
t
.
end
())
{
sync_ok
=
SyncDeviceToHostAndConvertFormat
(
shape
,
size
,
type
,
host_ptr
);
}
else
{
MS_LOG
(
INFO
)
<<
"Can not find format transfer for :"
<<
format_
;
}
}
if
(
!
sync_ok
)
{
...
...
@@ -141,7 +148,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int
if
(
host_shape
.
empty
())
{
host_shape
.
emplace_back
(
1
);
}
if
(
format_
==
kOpFormat_FRAC_NZ
)
{
if
(
format_
==
kOpFormat_FRAC_NZ
||
format_
==
kOpFormat_NDHWC
)
{
device_shape
=
trans
::
TransShapeToDevice
(
host_shape
,
format_
);
}
else
{
host_shape
=
trans
::
PaddingShapeTo4d
(
host_shape
);
...
...
@@ -185,7 +192,7 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector<int> &shape, size_t
if
(
host_shape
.
empty
())
{
host_shape
.
emplace_back
(
1
);
}
if
(
format_
==
kOpFormat_NCHW
||
format_
==
kOpFormat_DEFAULT
)
{
if
(
format_
==
kOpFormat_NCHW
||
format_
==
kOpFormat_DEFAULT
||
format_
==
kOpFormat_NDHWC
)
{
if
(
type_id_
==
type
)
{
SyncMemory
(
ptr_
,
host_ptr
,
size_
,
RT_MEMCPY_HOST_TO_DEVICE
);
sync_ok
=
true
;
...
...
@@ -203,9 +210,11 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector<int> &shape, size_t
SyncMemory
(
ptr_
,
host_tmp
.
data
(),
size_
,
RT_MEMCPY_HOST_TO_DEVICE
);
}
}
else
{
auto
iter
=
k
NeedTransFormatSe
t
.
find
(
format_
);
if
(
iter
!=
k
NeedTransFormatSe
t
.
end
())
{
auto
iter
=
k
OpNeedTransForma
t
.
find
(
format_
);
if
(
iter
!=
k
OpNeedTransForma
t
.
end
())
{
sync_ok
=
ConvertFormatAndSyncHostToDevice
(
shape
,
size
,
type
,
host_ptr
);
}
else
{
MS_LOG
(
INFO
)
<<
"Can not find format transfer for :"
<<
format_
;
}
}
if
(
!
sync_ok
)
{
...
...
@@ -227,7 +236,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int
host_shape
.
emplace_back
(
1
);
}
std
::
vector
<
size_t
>
device_shape
;
if
(
format_
==
kOpFormat_FRAC_NZ
)
{
if
(
format_
==
kOpFormat_FRAC_NZ
||
format_
==
kOpFormat_NDHWC
)
{
device_shape
=
trans
::
TransShapeToDevice
(
host_shape
,
format_
);
}
else
{
host_shape
=
trans
::
PaddingShapeTo4d
(
host_shape
);
...
...
tests/ut/cpp/common/trans_test.cc
0 → 100644
浏览文件 @
828d0b12
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "common/common_test.h"
#include "common/trans.h"
#include "utils/utils.h"
using
namespace
std
;
namespace
mindspore
{
namespace
trans
{
class
FormatTransTest
:
public
UT
::
Common
{
public:
FormatTransTest
()
=
default
;
void
SetUp
()
override
{}
void
TearDown
()
override
{}
};
TEST_F
(
FormatTransTest
,
nchw_to_hwcn
)
{
uint16_t
data
[
2
*
2
*
2
*
2
]
=
{
12581
,
14220
,
14937
,
14302
,
15004
,
14951
,
14694
,
14564
,
14069
,
14554
,
10507
,
14787
,
13016
,
15263
,
14872
,
10838
};
uint16_t
res
[
2
*
2
*
2
*
2
]
=
{
12581
,
14069
,
15004
,
13016
,
14220
,
14554
,
14951
,
15263
,
14937
,
10507
,
14694
,
14872
,
14302
,
14787
,
14564
,
10838
};
size_t
device_size
=
32
;
auto
trans_tmp
=
std
::
vector
<
uint8_t
>
(
device_size
);
FormatArgs
format_args
{
data
,
device_size
,
kOpFormat_NCHW
,
kOpFormat_HWCN
,
{
2
,
2
,
2
,
2
},
{
2
,
2
,
2
,
2
},
kNumberTypeFloat16
};
EXPECT_EQ
(
trans
::
TransFormat
(
format_args
,
trans_tmp
.
data
()),
true
);
for
(
size_t
i
=
0
;
i
<
sizeof
(
res
)
/
sizeof
(
res
[
0
]);
i
++
)
{
EXPECT_EQ
((
reinterpret_cast
<
uint16_t
*>
(
trans_tmp
.
data
()))[
i
],
res
[
i
]);
}
}
TEST_F
(
FormatTransTest
,
hwcn_to_nchw
)
{
uint16_t
data
[
2
*
2
*
2
*
2
]
=
{
12581
,
14069
,
15004
,
13016
,
14220
,
14554
,
14951
,
15263
,
14937
,
10507
,
14694
,
14872
,
14302
,
14787
,
14564
,
10838
};
uint16_t
res
[
2
*
2
*
2
*
2
]
=
{
12581
,
14220
,
14937
,
14302
,
15004
,
14951
,
14694
,
14564
,
14069
,
14554
,
10507
,
14787
,
13016
,
15263
,
14872
,
10838
};
size_t
device_size
=
32
;
auto
trans_tmp
=
std
::
vector
<
uint8_t
>
(
device_size
);
FormatArgs
format_args
{
data
,
device_size
,
kOpFormat_NCHW
,
kOpFormat_HWCN
,
{
2
,
2
,
2
,
2
},
{
2
,
2
,
2
,
2
},
kNumberTypeFloat16
};
EXPECT_EQ
(
trans
::
TransFormatFromDeviceToHost
(
format_args
,
trans_tmp
.
data
()),
true
);
for
(
size_t
i
=
0
;
i
<
sizeof
(
res
)
/
sizeof
(
res
[
0
]);
i
++
)
{
EXPECT_EQ
((
reinterpret_cast
<
uint16_t
*>
(
trans_tmp
.
data
()))[
i
],
res
[
i
]);
}
}
TEST_F
(
FormatTransTest
,
nchw_to_nhwc
)
{
uint16_t
data
[
2
*
2
*
2
*
2
]
=
{
11750
,
13778
,
15007
,
15321
,
15163
,
13446
,
15063
,
14467
,
15056
,
13284
,
15219
,
14797
,
12684
,
14288
,
14855
,
14799
};
uint16_t
res
[
2
*
2
*
2
*
2
]
=
{
11750
,
15163
,
13778
,
13446
,
15007
,
15063
,
15321
,
14467
,
15056
,
12684
,
13284
,
14288
,
15219
,
14855
,
14797
,
14799
};
size_t
device_size
=
32
;
auto
trans_tmp
=
std
::
vector
<
uint8_t
>
(
device_size
);
FormatArgs
format_args
{
data
,
device_size
,
kOpFormat_NCHW
,
kOpFormat_NHWC
,
{
2
,
2
,
2
,
2
},
{
2
,
2
,
2
,
2
},
kNumberTypeFloat16
};
EXPECT_EQ
(
trans
::
TransFormat
(
format_args
,
trans_tmp
.
data
()),
true
);
for
(
size_t
i
=
0
;
i
<
sizeof
(
res
)
/
sizeof
(
res
[
0
]);
i
++
)
{
EXPECT_EQ
((
reinterpret_cast
<
uint16_t
*>
(
trans_tmp
.
data
()))[
i
],
res
[
i
]);
}
}
TEST_F
(
FormatTransTest
,
nhwc_to_nchw
)
{
uint16_t
data
[
2
*
2
*
2
*
2
]
=
{
11750
,
15163
,
13778
,
13446
,
15007
,
15063
,
15321
,
14467
,
15056
,
12684
,
13284
,
14288
,
15219
,
14855
,
14797
,
14799
};
uint16_t
res
[
2
*
2
*
2
*
2
]
=
{
11750
,
13778
,
15007
,
15321
,
15163
,
13446
,
15063
,
14467
,
15056
,
13284
,
15219
,
14797
,
12684
,
14288
,
14855
,
14799
};
size_t
device_size
=
32
;
auto
trans_tmp
=
std
::
vector
<
uint8_t
>
(
device_size
);
FormatArgs
format_args
{
data
,
device_size
,
kOpFormat_NCHW
,
kOpFormat_NHWC
,
{
2
,
2
,
2
,
2
},
{
2
,
2
,
2
,
2
},
kNumberTypeFloat16
};
EXPECT_EQ
(
trans
::
TransFormatFromDeviceToHost
(
format_args
,
trans_tmp
.
data
()),
true
);
for
(
size_t
i
=
0
;
i
<
sizeof
(
res
)
/
sizeof
(
res
[
0
]);
i
++
)
{
EXPECT_EQ
((
reinterpret_cast
<
uint16_t
*>
(
trans_tmp
.
data
()))[
i
],
res
[
i
]);
}
}
}
// namespace trans
}
// namespace mindspore
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录