Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
41c969ab
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
41c969ab
编写于
4月 17, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 17, 2020
浏览文件
操作
浏览文件
下载
差异文件
!414 add 6d format transfer
Merge pull request !414 from liubuyu/dev_lby
上级
f69a668d
fc07cd90
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
228 addition
and
37 deletion
+228
-37
mindspore/ccsrc/common/trans.cc
mindspore/ccsrc/common/trans.cc
+212
-31
mindspore/ccsrc/common/trans.h
mindspore/ccsrc/common/trans.h
+2
-0
mindspore/ccsrc/device/ascend/ascend_device_address.cc
mindspore/ccsrc/device/ascend/ascend_device_address.cc
+10
-4
mindspore/ccsrc/utils/utils.h
mindspore/ccsrc/utils/utils.h
+4
-2
未找到文件。
mindspore/ccsrc/common/trans.cc
浏览文件 @
41c969ab
...
...
@@ -231,7 +231,98 @@ std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std
return
shape_4d
;
}
namespace
{
bool
CheckDims
(
const
std
::
vector
<
size_t
>
&
shape
)
{
if
(
shape
.
size
()
!=
4
)
{
MS_LOG
(
ERROR
)
<<
"Host shape dims shoud be 4"
;
return
false
;
}
return
true
;
}
std
::
vector
<
size_t
>
NchwDeviceShape
(
const
std
::
vector
<
size_t
>
&
shape
)
{
if
(
!
CheckDims
(
shape
))
{
MS_LOG
(
EXCEPTION
)
<<
"Check dims failed."
;
}
return
shape
;
}
std
::
vector
<
size_t
>
NhwcDeviceShape
(
const
std
::
vector
<
size_t
>
&
shape
)
{
if
(
!
CheckDims
(
shape
))
{
MS_LOG
(
EXCEPTION
)
<<
"Ccheck dims failed."
;
}
std
::
vector
<
size_t
>
device_shape
;
device_shape
.
push_back
(
shape
[
0
]);
device_shape
.
push_back
(
shape
[
2
]);
device_shape
.
push_back
(
shape
[
3
]);
device_shape
.
push_back
(
shape
[
1
]);
return
device_shape
;
}
std
::
vector
<
size_t
>
HwchDeviceShape
(
const
std
::
vector
<
size_t
>
&
shape
)
{
if
(
!
CheckDims
(
shape
))
{
MS_LOG
(
EXCEPTION
)
<<
"Check dims failed."
;
}
std
::
vector
<
size_t
>
device_shape
;
device_shape
.
push_back
(
shape
[
2
]);
device_shape
.
push_back
(
shape
[
3
]);
device_shape
.
push_back
(
shape
[
1
]);
device_shape
.
push_back
(
shape
[
0
]);
return
device_shape
;
}
std
::
vector
<
size_t
>
FracZDeviceShape
(
const
std
::
vector
<
size_t
>
&
shape
)
{
if
(
!
CheckDims
(
shape
))
{
MS_LOG
(
EXCEPTION
)
<<
"Check dims failed."
;
}
std
::
vector
<
size_t
>
device_shape
;
size_t
cout16
=
((
shape
[
0
]
+
kCubeSize
-
1
)
/
kCubeSize
)
*
kCubeSize
;
size_t
cin16
=
((
shape
[
1
]
+
kCubeSize
-
1
)
/
kCubeSize
)
*
kCubeSize
;
device_shape
.
push_back
(
shape
[
2
]
*
shape
[
3
]
*
cin16
/
kCubeSize
);
device_shape
.
push_back
(
cout16
/
kCubeSize
);
device_shape
.
push_back
(
kCubeSize
);
device_shape
.
push_back
(
kCubeSize
);
return
device_shape
;
}
std
::
vector
<
size_t
>
Nc1hwc0DeviceShape
(
const
std
::
vector
<
size_t
>
&
shape
)
{
if
(
!
CheckDims
(
shape
))
{
MS_LOG
(
EXCEPTION
)
<<
"Check dims failed."
;
}
std
::
vector
<
size_t
>
device_shape
;
size_t
C1
=
(
shape
[
1
]
+
kCubeSize
-
1
)
/
kCubeSize
;
size_t
C0
=
kCubeSize
;
device_shape
.
push_back
(
shape
[
0
]);
device_shape
.
push_back
(
C1
);
device_shape
.
push_back
(
shape
[
2
]);
device_shape
.
push_back
(
shape
[
3
]);
device_shape
.
push_back
(
C0
);
return
device_shape
;
}
std
::
vector
<
size_t
>
C1hwncoc0DeviceShape
(
const
std
::
vector
<
size_t
>
&
shape
)
{
if
(
!
CheckDims
(
shape
))
{
MS_LOG
(
EXCEPTION
)
<<
"Check dims failed."
;
}
std
::
vector
<
size_t
>
device_shape
;
device_shape
.
push_back
((
shape
[
1
]
-
1
)
/
kCubeSize
+
1
);
device_shape
.
push_back
(
shape
[
2
]);
device_shape
.
push_back
(
shape
[
3
]);
device_shape
.
push_back
(
shape
[
0
]);
device_shape
.
push_back
(
kCubeSize
);
device_shape
.
push_back
(
kCubeSize
);
return
device_shape
;
}
}
// namespace
std
::
vector
<
size_t
>
TransShapeToDevice
(
const
std
::
vector
<
size_t
>
&
shape
,
const
std
::
string
&
format
)
{
using
DeviceShapeTransfer
=
std
::
function
<
std
::
vector
<
size_t
>
(
const
std
::
vector
<
size_t
>
&
)
>
;
const
std
::
map
<
std
::
string
,
DeviceShapeTransfer
>
device_shape_map
{
{
kOpFormat_NCHW
,
NchwDeviceShape
},
{
kOpFormat_NHWC
,
NhwcDeviceShape
},
{
kOpFormat_HWCN
,
HwchDeviceShape
},
{
kOpFormat_FRAC_Z
,
FracZDeviceShape
},
{
kOpFormat_NC1HWC0
,
Nc1hwc0DeviceShape
},
{
kOpFormat_C1HWNCoC0
,
C1hwncoc0DeviceShape
},
};
if
(
format
==
kOpFormat_ND
||
format
==
kOpFormat_DEFAULT
)
{
return
shape
;
}
...
...
@@ -255,37 +346,31 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
MS_LOG
(
WARNING
)
<<
"Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly"
;
temp_shape
=
PaddingShapeTo4dByDefault
(
shape
);
}
if
(
format
==
kOpFormat_NC1HWC0
)
{
size_t
C1
=
(
temp_shape
[
1
]
+
kCubeSize
-
1
)
/
kCubeSize
;
size_t
C0
=
kCubeSize
;
device_shape
.
push_back
(
temp_shape
[
0
]);
device_shape
.
push_back
(
C1
);
device_shape
.
push_back
(
temp_shape
[
2
]);
device_shape
.
push_back
(
temp_shape
[
3
]);
device_shape
.
push_back
(
C0
);
return
device_shape
;
}
else
if
(
format
==
kOpFormat_FRAC_Z
)
{
size_t
cout16
=
((
temp_shape
[
0
]
+
kCubeSize
-
1
)
/
kCubeSize
)
*
kCubeSize
;
size_t
cin16
=
((
temp_shape
[
1
]
+
kCubeSize
-
1
)
/
kCubeSize
)
*
kCubeSize
;
device_shape
.
push_back
(
temp_shape
[
2
]
*
temp_shape
[
3
]
*
cin16
/
kCubeSize
);
device_shape
.
push_back
(
cout16
/
kCubeSize
);
device_shape
.
push_back
(
kCubeSize
);
device_shape
.
push_back
(
kCubeSize
);
return
device_shape
;
}
else
if
(
format
==
kOpFormat_NHWC
)
{
device_shape
.
push_back
(
temp_shape
[
0
]);
device_shape
.
push_back
(
temp_shape
[
2
]);
device_shape
.
push_back
(
temp_shape
[
3
]);
device_shape
.
push_back
(
temp_shape
[
1
]);
return
device_shape
;
}
else
if
(
format
==
kOpFormat_HWCN
)
{
return
{
temp_shape
[
2
],
temp_shape
[
3
],
temp_shape
[
1
],
temp_shape
[
0
]};
}
else
if
(
format
==
kOpFormat_NCHW
)
{
return
temp_shape
;
auto
iter
=
device_shape_map
.
find
(
format
);
if
(
iter
!=
device_shape_map
.
end
())
{
return
iter
->
second
(
temp_shape
);
}
MS_LOG
(
EXCEPTION
)
<<
"Unexpected format["
<<
format
<<
"]"
;
}
bool
CheckArgs
(
const
FormatArgs
&
args
,
size_t
*
size
,
size_t
*
total_size
)
{
if
(
args
.
host_shape
.
size
()
!=
kNchwDims
)
{
MS_LOG
(
ERROR
)
<<
"Invalid host shape, host shape dims:"
<<
args
.
host_shape
.
size
()
<<
", expect dims:"
<<
kNchwDims
;
return
false
;
}
*
size
=
TypeIdSize
(
args
.
src_data_type
);
if
(
*
size
<
1
)
{
MS_LOG
(
ERROR
)
<<
"Illegal dtype."
;
return
false
;
}
*
total_size
=
ShapeSize
(
args
.
device_shape
)
*
(
*
size
);
if
(
*
total_size
!=
args
.
device_size
)
{
MS_LOG
(
ERROR
)
<<
"Illegal total data size, total_size:"
<<
*
total_size
<<
", device_size:"
<<
args
.
device_size
;
return
false
;
}
return
true
;
}
bool
TransDataType
(
const
TypeIdArgs
&
args
,
void
*
result
)
{
MS_LOG
(
DEBUG
)
<<
"Begin trans datatype from "
<<
TypeIdLabel
(
args
.
host_data_type
)
<<
" to "
<<
TypeIdLabel
(
args
.
device_data_type
);
...
...
@@ -320,13 +405,14 @@ bool TransFormat(const FormatArgs &args, void *result) {
MS_LOG
(
ERROR
)
<<
"Invalid datatype.."
;
return
false
;
}
if
((
args
.
host_format
==
kOpFormat_NCHW
||
args
.
host_format
==
kOpFormat_ND
)
&&
args
.
device_format
==
kOpFormat_FRAC_Z
)
{
if
(
args
.
device_format
==
kOpFormat_FRAC_Z
)
{
return
NchwToFracZ
(
args
,
result
);
}
else
if
(
args
.
device_format
==
kOpFormat_FRAC_NZ
)
{
return
NchwToFracNz
(
args
,
result
);
}
else
if
(
args
.
device_format
==
kOpFormat_NC1HWC0
)
{
return
NchwToNc1hwc0
(
args
,
result
);
}
else
if
(
args
.
device_format
==
kOpFormat_C1HWNCoC0
)
{
return
NchwToC1hwncoc0
(
args
,
result
);
}
return
true
;
}
...
...
@@ -337,13 +423,14 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
MS_LOG
(
ERROR
)
<<
"Invalid datatype.."
;
return
false
;
}
if
((
args
.
host_format
==
kOpFormat_NCHW
||
args
.
host_format
==
kOpFormat_ND
)
&&
args
.
device_format
==
kOpFormat_FRAC_Z
)
{
if
(
args
.
device_format
==
kOpFormat_FRAC_Z
)
{
return
FracZToNchw
(
args
,
result
);
}
else
if
(
args
.
device_format
==
kOpFormat_FRAC_NZ
)
{
return
FracNzToNchw
(
args
,
result
);
}
else
if
(
args
.
device_format
==
kOpFormat_NC1HWC0
)
{
return
Nc1hwc0ToNchw
(
args
,
result
);
}
else
if
(
args
.
device_format
==
kOpFormat_C1HWNCoC0
)
{
return
C1hwncoc0ToNchw
(
args
,
result
);
}
return
true
;
}
...
...
@@ -801,5 +888,99 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
}
return
true
;
}
bool
NchwToC1hwncoc0
(
const
FormatArgs
&
args
,
void
*
result
)
{
// trans nchw to c1hwncoc0
MS_LOG
(
DEBUG
)
<<
"Trans format from nchw to c1hwncoc0."
;
MS_EXCEPTION_IF_NULL
(
result
);
size_t
size
=
0
;
size_t
total_size
=
0
;
if
(
!
CheckArgs
(
args
,
&
size
,
&
total_size
))
{
MS_LOG
(
ERROR
)
<<
"Check args failed."
;
return
false
;
}
auto
n
=
args
.
host_shape
[
0
];
auto
c
=
args
.
host_shape
[
1
];
auto
h
=
args
.
host_shape
[
2
];
auto
w
=
args
.
host_shape
[
3
];
auto
c1
=
args
.
device_shape
[
0
];
auto
co
=
args
.
device_shape
[
4
];
auto
c0
=
args
.
device_shape
[
5
];
for
(
size_t
c1_i
=
0
;
c1_i
<
c1
;
c1_i
++
)
{
for
(
size_t
h_i
=
0
;
h_i
<
h
;
h_i
++
)
{
for
(
size_t
w_i
=
0
;
w_i
<
w
;
w_i
++
)
{
for
(
size_t
n_i
=
0
;
n_i
<
n
;
n_i
++
)
{
for
(
size_t
co_i
=
0
;
co_i
<
co
;
co_i
++
)
{
for
(
size_t
c0_i
=
0
;
c0_i
<
c0
;
c0_i
++
)
{
size_t
dst_offset
=
(
c1_i
*
h
*
w
*
n
*
co
*
c0
+
h_i
*
w
*
n
*
co
*
c0
+
w_i
*
n
*
co
*
c0
+
n_i
*
co
*
c0
+
co_i
*
c0
+
c0_i
)
*
size
;
size_t
protected_size
=
total_size
-
dst_offset
<
static_cast
<
size_t
>
(
SECUREC_MEM_MAX_LEN
)
?
total_size
-
dst_offset
:
static_cast
<
size_t
>
(
SECUREC_MEM_MAX_LEN
);
size_t
c_i
=
c0_i
+
c1_i
*
c0
;
size_t
src_offset
=
(
n_i
*
c
*
h
*
w
+
c_i
*
h
*
w
+
h_i
*
w
+
w_i
)
*
size
;
error_t
ret
;
if
(
c_i
<
c
&&
c0_i
==
co_i
)
{
ret
=
memcpy_s
(
static_cast
<
uint8_t
*>
(
result
)
+
dst_offset
,
protected_size
,
static_cast
<
uint8_t
const
*>
(
args
.
data
)
+
src_offset
,
size
);
}
else
{
ret
=
memset_s
(
static_cast
<
uint8_t
*>
(
result
)
+
dst_offset
,
protected_size
,
0
,
size
);
}
if
(
ret
!=
EOK
)
{
MS_LOG
(
ERROR
)
<<
"Failed to operate the dst memory, error-code:"
<<
ret
;
return
false
;
}
}
}
}
}
}
}
return
true
;
}
bool
C1hwncoc0ToNchw
(
const
FormatArgs
&
args
,
void
*
result
)
{
// trans c1hwncoc0 to nchw
MS_LOG
(
DEBUG
)
<<
"Trans format from c1hwncoc0 to nchw"
;
MS_EXCEPTION_IF_NULL
(
result
);
size_t
size
=
0
;
size_t
total_size
=
0
;
if
(
!
CheckArgs
(
args
,
&
size
,
&
total_size
))
{
MS_LOG
(
ERROR
)
<<
"Check args failed."
;
return
false
;
}
auto
n
=
args
.
host_shape
[
0
];
auto
c
=
args
.
host_shape
[
1
];
auto
h
=
args
.
host_shape
[
2
];
auto
w
=
args
.
host_shape
[
3
];
auto
co
=
args
.
device_shape
[
4
];
auto
c0
=
args
.
device_shape
[
5
];
for
(
size_t
n_i
=
0
;
n_i
<
n
;
n_i
++
)
{
for
(
size_t
c_i
=
0
;
c_i
<
c
;
c_i
++
)
{
for
(
size_t
h_i
=
0
;
h_i
<
h
;
h_i
++
)
{
for
(
size_t
w_i
=
0
;
w_i
<
w
;
w_i
++
)
{
size_t
dst_offset
=
(
n_i
*
c
*
h
*
w
+
c_i
*
h
*
w
+
h_i
*
w
+
w_i
)
*
size
;
size_t
c1_i
=
c_i
/
kCubeSize
;
size_t
c0_i
=
c_i
%
kCubeSize
;
size_t
co_i
=
c0_i
;
size_t
src_offset
=
(
c1_i
*
h
*
w
*
n
*
co
*
c0
+
h_i
*
w
*
n
*
co
*
c0
+
w_i
*
n
*
co
*
c0
+
n_i
*
co
*
c0
+
co_i
*
c0
+
c0_i
)
*
size
;
size_t
protected_size
=
total_size
-
dst_offset
<
static_cast
<
size_t
>
(
SECUREC_MEM_MAX_LEN
)
?
total_size
-
dst_offset
:
static_cast
<
size_t
>
(
SECUREC_MEM_MAX_LEN
);
auto
ret
=
memcpy_s
(
static_cast
<
uint8_t
*>
(
result
)
+
dst_offset
,
protected_size
,
static_cast
<
uint8_t
const
*>
(
args
.
data
)
+
src_offset
,
size
);
if
(
ret
!=
EOK
)
{
MS_LOG
(
ERROR
)
<<
"Failed to operate the dst memory, error-code:"
<<
ret
;
return
false
;
}
}
}
}
}
return
true
;
}
}
// namespace trans
}
// namespace mindspore
mindspore/ccsrc/common/trans.h
浏览文件 @
41c969ab
...
...
@@ -63,10 +63,12 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result);
bool
NchwToFracZ
(
const
FormatArgs
&
args
,
void
*
result
);
bool
NchwToFracNz
(
const
FormatArgs
&
args
,
void
*
result
);
bool
NchwToNc1hwc0
(
const
FormatArgs
&
args
,
void
*
result
);
bool
NchwToC1hwncoc0
(
const
FormatArgs
&
args
,
void
*
result
);
// device to host
bool
FracZToNchw
(
const
FormatArgs
&
args
,
void
*
result
);
bool
FracNzToNchw
(
const
FormatArgs
&
args
,
void
*
result
);
bool
Nc1hwc0ToNchw
(
const
FormatArgs
&
args
,
void
*
result
);
bool
C1hwncoc0ToNchw
(
const
FormatArgs
&
args
,
void
*
result
);
}
// namespace trans
}
// namespace mindspore
...
...
mindspore/ccsrc/device/ascend/ascend_device_address.cc
浏览文件 @
41c969ab
...
...
@@ -114,8 +114,11 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t
return
false
;
}
}
}
else
if
(
format_
==
kOpFormat_NC1HWC0
||
format_
==
kOpFormat_FRAC_Z
||
format_
==
kOpFormat_FRAC_NZ
)
{
sync_ok
=
SyncDeviceToHostAndConvertFormat
(
shape
,
size
,
type
,
host_ptr
);
}
else
{
auto
iter
=
kNeedTransFormatSet
.
find
(
format_
);
if
(
iter
!=
kNeedTransFormatSet
.
end
())
{
sync_ok
=
ConvertFormatAndSyncHostToDevice
(
shape
,
size
,
type
,
host_ptr
);
}
}
if
(
!
sync_ok
)
{
MS_LOG
(
ERROR
)
<<
"Not support to trans, dev_format:"
<<
format_
<<
", dev_type:"
<<
TypeIdLabel
(
type_id_
)
...
...
@@ -199,9 +202,12 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector<int> &shape, size_t
}
SyncMemory
(
ptr_
,
host_tmp
.
data
(),
size_
,
RT_MEMCPY_HOST_TO_DEVICE
);
}
}
else
if
(
format_
==
kOpFormat_NC1HWC0
||
format_
==
kOpFormat_FRAC_Z
||
format_
==
kOpFormat_FRAC_NZ
)
{
}
else
{
auto
iter
=
kNeedTransFormatSet
.
find
(
format_
);
if
(
iter
!=
kNeedTransFormatSet
.
end
())
{
sync_ok
=
ConvertFormatAndSyncHostToDevice
(
shape
,
size
,
type
,
host_ptr
);
}
}
if
(
!
sync_ok
)
{
MS_LOG
(
ERROR
)
<<
"Not support to trans, dev_format:"
<<
format_
<<
", dev_type:"
<<
TypeIdLabel
(
type_id_
)
<<
", host_type:"
<<
TypeIdLabel
(
type
);
...
...
mindspore/ccsrc/utils/utils.h
浏览文件 @
41c969ab
...
...
@@ -187,7 +187,9 @@ constexpr auto kOpFormat_FRAC_NZ = "FRACTAL_NZ";
constexpr
auto
kOpFormat_C1HWNCoC0
=
"C1HWNCoC0"
;
constexpr
auto
kOpFormat_NC1HWC0_C04
=
"NC1HWC0_C04"
;
const
std
::
set
<
std
::
string
>
k1DSupportFormat
=
{
kOpFormat_DEFAULT
,
kOpFormat_NCHW
,
kOpFormat_NHWC
,
kOpFormat_FRAC_Z
,
kOpFormat_NC1KHKWHWC0
,
kOpFormat_NC1HWC0
};
kOpFormat_FRAC_Z
,
kOpFormat_NC1KHKWHWC0
,
kOpFormat_NC1HWC0
,
kOpFormat_C1HWNCoC0
};
const
std
::
set
<
std
::
string
>
k2DSupportFormat
=
{
kOpFormat_DEFAULT
,
kOpFormat_NCHW
,
kOpFormat_NHWC
,
kOpFormat_FRAC_Z
,
kOpFormat_NC1KHKWHWC0
};
const
std
::
set
<
std
::
string
>
k3DSupportFormat
=
{
kOpFormat_DEFAULT
,
kOpFormat_NC1KHKWHWC0
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录