Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
715c0735
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
715c0735
编写于
4月 22, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 22, 2020
浏览文件
操作
浏览文件
下载
差异文件
!487 add dtype trans template
Merge pull request !487 from liubuyu/master
上级
67057d13
ac2d5df2
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
56 addition
and
29 deletion
+56
-29
mindspore/ccsrc/common/trans.cc
mindspore/ccsrc/common/trans.cc
+33
-11
mindspore/ccsrc/common/trans.h
mindspore/ccsrc/common/trans.h
+6
-3
mindspore/ccsrc/device/ascend/ascend_device_address.cc
mindspore/ccsrc/device/ascend/ascend_device_address.cc
+17
-15
未找到文件。
mindspore/ccsrc/common/trans.cc
浏览文件 @
715c0735
...
...
@@ -103,17 +103,39 @@ const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{
template
<
typename
SrcT
,
typename
DstT
>
void
TransDataSrc2Dst
(
const
TypeIdArgs
&
args
,
void
*
dst
,
const
size_t
data_size
)
{
auto
src_id
=
TypeIdSize
(
args
.
src_type
);
auto
dst_id
=
TypeIdSize
(
args
.
dst_type
);
if
(
args
.
src_size
/
src_id
!=
args
.
src_shape_size
||
args
.
dst_size
/
dst_id
!=
args
.
dst_shape_size
)
{
MS_LOG
(
EXCEPTION
)
<<
"Invalid src or dst data size."
;
}
for
(
size_t
idx
=
0
;
idx
!=
data_size
;
idx
++
)
{
SrcT
src_data
=
static_cast
<
const
SrcT
*>
(
args
.
data
)[
idx
];
static_cast
<
DstT
*>
(
dst
)[
idx
]
=
static_cast
<
DstT
>
(
src_data
);
}
}
template
<
typename
SrcT
>
void
TransDataSrc2Fp16
(
const
TypeIdArgs
&
args
,
void
*
dst
,
const
size_t
data_size
)
{
auto
src_id
=
TypeIdSize
(
args
.
src_type
);
auto
dst_id
=
TypeIdSize
(
args
.
dst_type
);
if
(
args
.
src_size
/
src_id
!=
args
.
src_shape_size
||
args
.
dst_size
/
dst_id
!=
args
.
dst_shape_size
)
{
MS_LOG
(
EXCEPTION
)
<<
"Invalid src or dst data size."
;
}
auto
src_data
=
static_cast
<
const
SrcT
*>
(
args
.
data
);
auto
half_data
=
static_cast
<
Eigen
::
half
*>
(
dst
);
for
(
size_t
i
=
0
;
i
<
data_size
;
i
++
)
{
half_data
[
i
]
=
Eigen
::
half
(
src_data
[
i
]);
}
}
bool
CastKernel
(
const
TypeIdArgs
&
args
,
void
*
dst
,
const
size_t
data_size
,
const
DataTypeTransMode
mode
)
{
switch
(
mode
)
{
case
FROM_FLOAT_TO_FLOAT16
:
device
::
FloatToHalf
(
dst
,
args
.
data
,
data_size
);
break
;
case
FROM_INT32_TO_FLOAT16
:
TransDataSrc2Fp16
<
int32_t
>
(
args
,
dst
,
data_size
);
break
;
case
FROM_FLOAT16_TO_FLOAT
:
device
::
HalfToFloat
(
dst
,
args
.
data
,
data_size
);
break
;
...
...
@@ -372,27 +394,27 @@ bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
}
bool
TransDataType
(
const
TypeIdArgs
&
args
,
void
*
result
)
{
MS_LOG
(
DEBUG
)
<<
"Begin trans datatype from "
<<
TypeIdLabel
(
args
.
host_data_type
)
<<
" to "
<<
TypeIdLabel
(
args
.
device_data_type
);
MS_LOG
(
DEBUG
)
<<
"Begin trans datatype from "
<<
TypeIdLabel
(
args
.
src_type
)
<<
" to "
<<
TypeIdLabel
(
args
.
dst_type
);
MS_EXCEPTION_IF_NULL
(
result
);
std
::
pair
<
TypeId
,
TypeId
>
type_info
(
args
.
host_data_type
,
args
.
device_data
_type
);
std
::
pair
<
TypeId
,
TypeId
>
type_info
(
args
.
src_type
,
args
.
dst
_type
);
auto
iter
=
mode_map
.
find
(
type_info
);
if
(
iter
==
mode_map
.
end
())
{
MS_LOG
(
ERROR
)
<<
"Unsupported datatype trans. src_type :"
<<
TypeIdLabel
(
args
.
host_data
_type
)
<<
", dst_type:"
<<
TypeIdLabel
(
args
.
d
evice_data
_type
);
MS_LOG
(
ERROR
)
<<
"Unsupported datatype trans. src_type :"
<<
TypeIdLabel
(
args
.
src
_type
)
<<
", dst_type:"
<<
TypeIdLabel
(
args
.
d
st
_type
);
return
false
;
}
auto
trans_mode
=
iter
->
second
;
auto
type_size
=
TypeIdSize
(
args
.
device_data_type
);
if
(
type_size
<
1
)
{
MS_LOG
(
ERROR
)
<<
"Invalid host data type."
;
auto
src_id
=
TypeIdSize
(
args
.
src_type
);
auto
dst_id
=
TypeIdSize
(
args
.
dst_type
);
if
(
src_id
<
1
||
dst_id
<
1
)
{
MS_LOG
(
ERROR
)
<<
"Invalid src or dst data type."
;
return
false
;
}
if
(
args
.
host_shape_size
<
1
)
{
MS_LOG
(
ERROR
)
<<
"Invalid
ho
st data size."
;
if
(
args
.
src_size
/
src_id
!=
args
.
src_shape_size
||
args
.
dst_size
/
dst_id
!=
args
.
dst_shape_size
)
{
MS_LOG
(
ERROR
)
<<
"Invalid
src or d
st data size."
;
return
false
;
}
if
(
!
CastKernel
(
args
,
result
,
args
.
ho
st_shape_size
,
trans_mode
))
{
if
(
!
CastKernel
(
args
,
result
,
args
.
d
st_shape_size
,
trans_mode
))
{
MS_LOG
(
ERROR
)
<<
"Failed to trans datatype.."
;
return
false
;
}
...
...
mindspore/ccsrc/common/trans.h
浏览文件 @
715c0735
...
...
@@ -31,9 +31,12 @@ namespace mindspore {
namespace
trans
{
struct
TypeIdArgs
{
const
void
*
data
;
size_t
host_shape_size
;
// Multiply each dimension elements. [a, b, c, d] => a*b*c*d
TypeId
host_data_type
;
TypeId
device_data_type
;
size_t
src_size
;
size_t
dst_size
;
TypeId
src_type
;
TypeId
dst_type
;
size_t
src_shape_size
;
size_t
dst_shape_size
;
};
struct
FormatArgs
{
...
...
mindspore/ccsrc/device/ascend/ascend_device_address.cc
浏览文件 @
715c0735
...
...
@@ -104,10 +104,10 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t
}
else
if
(
type_id_
==
kNumberTypeFloat32
&&
type
==
kNumberTypeFloat64
)
{
sync_ok
=
SyncDeviceToHostAndFloatToFloat64
(
host_ptr
,
size
,
ptr_
,
size_
);
}
else
{
auto
shape
_size
=
trans
::
ShapeSize
(
host_shape
);
auto
host
_size
=
trans
::
ShapeSize
(
host_shape
);
auto
host
=
std
::
vector
<
uint8_t
>
(
size_
);
SyncMemory
(
host
.
data
(),
ptr_
,
size_
,
RT_MEMCPY_DEVICE_TO_HOST
);
const
trans
::
TypeIdArgs
type_args
{
host
.
data
(),
s
hape_size
,
type_id_
,
typ
e
};
const
trans
::
TypeIdArgs
type_args
{
host
.
data
(),
s
ize_
,
size
,
type_id_
,
type
,
host_size
,
host_siz
e
};
sync_ok
=
trans
::
TransDataType
(
type_args
,
host_ptr
);
if
(
!
sync_ok
)
{
MS_LOG
(
ERROR
)
<<
"trans data type failed."
;
...
...
@@ -153,14 +153,15 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int
auto
host
=
std
::
vector
<
uint8_t
>
(
size_
);
sync_ok
=
trans
::
TransFormatFromDeviceToHost
(
format_args
,
host
.
data
());
if
(
!
sync_ok
)
{
MS_LOG
(
ERROR
)
<<
"
t
rans format failed."
;
MS_LOG
(
ERROR
)
<<
"
T
rans format failed."
;
return
false
;
}
auto
shape_size
=
trans
::
ShapeSize
(
host_shape
);
const
trans
::
TypeIdArgs
type_args
{
host
.
data
(),
shape_size
,
type_id_
,
type
};
auto
host_size
=
trans
::
ShapeSize
(
host_shape
);
auto
device_size
=
trans
::
ShapeSize
(
device_shape
);
const
trans
::
TypeIdArgs
type_args
{
host
.
data
(),
size_
,
size
,
type_id_
,
type
,
device_size
,
host_size
};
sync_ok
=
trans
::
TransDataType
(
type_args
,
host_ptr
);
if
(
!
sync_ok
)
{
MS_LOG
(
ERROR
)
<<
"
t
rans format failed."
;
MS_LOG
(
ERROR
)
<<
"
T
rans format failed."
;
return
false
;
}
}
else
{
...
...
@@ -168,7 +169,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int
host_shape
,
device_shape
,
type_id_
};
sync_ok
=
trans
::
TransFormatFromDeviceToHost
(
format_args
,
host_ptr
);
if
(
!
sync_ok
)
{
MS_LOG
(
ERROR
)
<<
"
t
rans format failed."
;
MS_LOG
(
ERROR
)
<<
"
T
rans format failed."
;
return
false
;
}
}
...
...
@@ -192,12 +193,12 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector<int> &shape, size_t
}
else
if
(
type_id_
==
kNumberTypeFloat32
&&
type
==
kNumberTypeFloat64
)
{
sync_ok
=
Float64ToFloatAndSyncHostToDevice
(
ptr_
,
size_
,
host_ptr
,
size
);
}
else
{
auto
shape
_size
=
trans
::
ShapeSize
(
host_shape
);
const
trans
::
TypeIdArgs
type_args
{
host_ptr
,
s
hape_size
,
type
,
type_id_
};
auto
host
_size
=
trans
::
ShapeSize
(
host_shape
);
const
trans
::
TypeIdArgs
type_args
{
host_ptr
,
s
ize
,
size_
,
type
,
type_id_
,
host_size
,
host_size
};
auto
host_tmp
=
std
::
vector
<
uint8_t
>
(
size_
);
sync_ok
=
trans
::
TransDataType
(
type_args
,
host_tmp
.
data
());
if
(
!
sync_ok
)
{
MS_LOG
(
ERROR
)
<<
"
t
rans data type failed."
;
MS_LOG
(
ERROR
)
<<
"
T
rans data type failed."
;
return
false
;
}
SyncMemory
(
ptr_
,
host_tmp
.
data
(),
size_
,
RT_MEMCPY_HOST_TO_DEVICE
);
...
...
@@ -234,12 +235,13 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int
device_shape
=
trans
::
TransShapeToDevice
(
host_shape
,
format_
);
}
if
(
type_id_
!=
type
)
{
auto
shape_size
=
trans
::
ShapeSize
(
host_shape
);
const
trans
::
TypeIdArgs
type_args
{
host_ptr
,
shape_size
,
type
,
type_id_
};
auto
host_size
=
trans
::
ShapeSize
(
host_shape
);
auto
device_size
=
trans
::
ShapeSize
(
device_shape
);
const
trans
::
TypeIdArgs
type_args
{
host_ptr
,
size
,
size_
,
type
,
type_id_
,
host_size
,
device_size
};
auto
host_tmp
=
std
::
vector
<
uint8_t
>
(
size_
);
sync_ok
=
trans
::
TransDataType
(
type_args
,
host_tmp
.
data
());
if
(
!
sync_ok
)
{
MS_LOG
(
ERROR
)
<<
"
t
rans datatype failed."
;
MS_LOG
(
ERROR
)
<<
"
T
rans datatype failed."
;
return
false
;
}
const
trans
::
FormatArgs
format_args
{
host_tmp
.
data
(),
size_
,
kOpFormat_NCHW
,
format_
,
...
...
@@ -247,7 +249,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int
auto
dst_tmp
=
std
::
vector
<
uint8_t
>
(
size_
);
sync_ok
=
trans
::
TransFormat
(
format_args
,
dst_tmp
.
data
());
if
(
!
sync_ok
)
{
MS_LOG
(
ERROR
)
<<
"
t
rans format failed."
;
MS_LOG
(
ERROR
)
<<
"
T
rans format failed."
;
return
false
;
}
SyncMemory
(
ptr_
,
dst_tmp
.
data
(),
size_
,
RT_MEMCPY_HOST_TO_DEVICE
);
...
...
@@ -256,7 +258,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int
auto
host_tmp
=
std
::
vector
<
uint8_t
>
(
size_
);
sync_ok
=
trans
::
TransFormat
(
format_args
,
host_tmp
.
data
());
if
(
!
sync_ok
)
{
MS_LOG
(
ERROR
)
<<
"
t
rans format failed."
;
MS_LOG
(
ERROR
)
<<
"
T
rans format failed."
;
return
false
;
}
SyncMemory
(
ptr_
,
host_tmp
.
data
(),
size_
,
RT_MEMCPY_HOST_TO_DEVICE
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录