Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
7b48a122
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7b48a122
编写于
7月 17, 2020
作者:
L
lvchangquan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
insert trans_data to reduce time in print process
上级
4e0cfafc
变更
4
隐藏空白更改
内联
并排
Showing
4 changed files
with
346 additions
and
13 deletions
+346
-13
mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc
...pore/ccsrc/runtime/device/ascend/ascend_device_address.cc
+321
-13
mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h
...spore/ccsrc/runtime/device/ascend/ascend_device_address.h
+11
-0
mindspore/ccsrc/runtime/device/kernel_runtime.cc
mindspore/ccsrc/runtime/device/kernel_runtime.cc
+12
-0
mindspore/ccsrc/runtime/device/kernel_runtime.h
mindspore/ccsrc/runtime/device/kernel_runtime.h
+2
-0
未找到文件。
mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc
浏览文件 @
7b48a122
...
...
@@ -16,14 +16,19 @@
#include "runtime/device/ascend/ascend_device_address.h"
#include <memory>
#include <vector>
#include <unordered_map>
#include <utility>
#include <set>
#include <algorithm>
#include "runtime/mem.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/device/kernel_runtime.h"
#include "runtime/device/convert_tensor_utils.h"
#include "ir/dtype/type.h"
#include "ir/tensor.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h"
#include "utils/utils.h"
#include "common/utils.h"
#include "common/trans.h"
...
...
@@ -34,6 +39,58 @@
#include "debug/tensor_load.h"
#endif
namespace
{
const
std
::
unordered_map
<
mindspore
::
TypeId
,
std
::
string
>
type_id_name_map
=
{
{
mindspore
::
kNumberTypeBool
,
"bool"
},
{
mindspore
::
kNumberTypeInt8
,
"int8"
},
{
mindspore
::
kNumberTypeInt16
,
"int16"
},
{
mindspore
::
kNumberTypeInt32
,
"int32"
},
{
mindspore
::
kNumberTypeInt64
,
"int64"
},
{
mindspore
::
kNumberTypeFloat16
,
"float16"
},
{
mindspore
::
kNumberTypeFloat32
,
"float32"
},
{
mindspore
::
kNumberTypeUInt8
,
"uint8"
},
{
mindspore
::
kNumberTypeUInt16
,
"uint16"
},
{
mindspore
::
kNumberTypeUInt32
,
"uint32"
},
{
mindspore
::
kNumberTypeUInt64
,
"uint64"
}};
const
std
::
set
<
std
::
pair
<
std
::
string
,
std
::
string
>>
use_trans_data
=
{
std
::
make_pair
(
"float16"
,
mindspore
::
kOpFormat_NC1HWC0
),
std
::
make_pair
(
"float32"
,
mindspore
::
kOpFormat_NC1HWC0
),
std
::
make_pair
(
"bool"
,
mindspore
::
kOpFormat_NC1HWC0
),
std
::
make_pair
(
"float32"
,
mindspore
::
kOpFormat_FRAC_Z
),
std
::
make_pair
(
"float16"
,
mindspore
::
kOpFormat_FRAC_Z
),
std
::
make_pair
(
"float16"
,
mindspore
::
kOpFormat_FRAC_NZ
),
std
::
make_pair
(
"float32"
,
mindspore
::
kOpFormat_FRAC_NZ
),
std
::
make_pair
(
"int32"
,
mindspore
::
kOpFormat_FRAC_NZ
),
std
::
make_pair
(
"float16"
,
mindspore
::
kOpFormat_NHWC
),
std
::
make_pair
(
"float32"
,
mindspore
::
kOpFormat_NHWC
),
std
::
make_pair
(
"int8"
,
mindspore
::
kOpFormat_NHWC
),
std
::
make_pair
(
"int16"
,
mindspore
::
kOpFormat_NHWC
),
std
::
make_pair
(
"int32"
,
mindspore
::
kOpFormat_NHWC
),
std
::
make_pair
(
"int64"
,
mindspore
::
kOpFormat_NHWC
),
std
::
make_pair
(
"uint8"
,
mindspore
::
kOpFormat_NHWC
),
std
::
make_pair
(
"uint16"
,
mindspore
::
kOpFormat_NHWC
),
std
::
make_pair
(
"uint32"
,
mindspore
::
kOpFormat_NHWC
),
std
::
make_pair
(
"uint64"
,
mindspore
::
kOpFormat_NHWC
),
std
::
make_pair
(
"float16"
,
mindspore
::
kOpFormat_HWCN
),
std
::
make_pair
(
"float32"
,
mindspore
::
kOpFormat_HWCN
),
std
::
make_pair
(
"int8"
,
mindspore
::
kOpFormat_HWCN
),
std
::
make_pair
(
"int16"
,
mindspore
::
kOpFormat_HWCN
),
std
::
make_pair
(
"int32"
,
mindspore
::
kOpFormat_HWCN
),
std
::
make_pair
(
"int64"
,
mindspore
::
kOpFormat_HWCN
),
std
::
make_pair
(
"uint8"
,
mindspore
::
kOpFormat_HWCN
),
std
::
make_pair
(
"uint16"
,
mindspore
::
kOpFormat_HWCN
),
std
::
make_pair
(
"uint32"
,
mindspore
::
kOpFormat_HWCN
),
std
::
make_pair
(
"uint64"
,
mindspore
::
kOpFormat_HWCN
)};
constexpr
auto
src_format
=
"src_format"
;
constexpr
auto
dst_format
=
"dst_format"
;
constexpr
auto
src
=
"src_0"
;
constexpr
auto
dst
=
"dst"
;
constexpr
auto
param_type_required
=
"required"
;
constexpr
auto
gen_model_single
=
"single"
;
constexpr
auto
trans_data
=
"trans_data"
;
constexpr
auto
platform_tbe
=
"TBE"
;
constexpr
auto
name
=
"name"
;
constexpr
auto
valid
=
"valid"
;
constexpr
auto
value
=
"value"
;
constexpr
auto
dtype
=
"dtype"
;
constexpr
auto
format_str
=
"format"
;
constexpr
auto
ori_format
=
"ori_format"
;
constexpr
auto
ori_shape
=
"ori_shape"
;
constexpr
auto
param_type
=
"param_type"
;
constexpr
auto
shape_str
=
"shape"
;
constexpr
auto
process_aicore
=
"aicore"
;
constexpr
auto
gen_model_str
=
"gen_model"
;
constexpr
auto
impl_path_str
=
"impl_path"
;
constexpr
auto
attrs_str
=
"attrs"
;
constexpr
auto
inputs_str
=
"inputs"
;
constexpr
auto
outputs_str
=
"outputs"
;
constexpr
auto
kernel_name_str
=
"kernel_name"
;
constexpr
auto
op_info_str
=
"op_info"
;
constexpr
auto
platform_str
=
"platform"
;
constexpr
auto
fractal_z
=
"FRACTAL_Z"
;
}
// namespace
namespace
mindspore
{
namespace
device
{
namespace
ascend
{
...
...
@@ -96,6 +153,102 @@ bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *s
return
true
;
}
// Returns the aligned buffer size for input_size: kMemAlignSize + 31 bytes are added first and the
// sum is then truncated down to a whole multiple of kMemAlignSize.
size_t GetCommonAlignSize(size_t input_size) {
  const auto padded_size = input_size + kMemAlignSize + 31;
  return padded_size / kMemAlignSize * kMemAlignSize;
}
// Builds the "attrs" array of the trans_data op: a src_format entry followed by a dst_format entry.
// The source format is the device format (kOpFormat_FRAC_Z is written as "FRACTAL_Z"); the
// destination format is always kOpFormat_NCHW.
nlohmann::json ConstructAttrs(const std::string &format) {
  nlohmann::json source_attr;
  source_attr[name] = src_format;
  source_attr[valid] = true;
  source_attr[value] = (format == kOpFormat_FRAC_Z) ? std::string(fractal_z) : format;

  nlohmann::json target_attr;
  target_attr[name] = dst_format;
  target_attr[valid] = true;
  target_attr[value] = kOpFormat_NCHW;

  nlohmann::json attr_list;
  attr_list.push_back(source_attr);
  attr_list.push_back(target_attr);
  return attr_list;
}
// Builds the "inputs" entry of the trans_data op: a nested array holding one input description.
// input_shape fills "shape", output_shape fills "ori_shape" (with "ori_format" fixed to NCHW).
// type_id_name_map.at() throws std::out_of_range for a TypeId that has no dtype name.
nlohmann::json ConstructInputs(const std::vector<size_t> &input_shape, const std::vector<size_t> &output_shape,
                               const std::string &format, mindspore::TypeId type) {
  nlohmann::json input_desc;
  input_desc[dtype] = type_id_name_map.at(type);
  // kOpFormat_FRAC_Z is written as "FRACTAL_Z"; every other format is passed through unchanged.
  input_desc[format_str] = (format == kOpFormat_FRAC_Z) ? std::string(fractal_z) : format;
  input_desc[name] = src;
  input_desc[ori_format] = kOpFormat_NCHW;
  // Appending element by element keeps the key absent when the shape vector is empty.
  for (const auto dim : output_shape) {
    input_desc[ori_shape].push_back(dim);
  }
  input_desc[param_type] = param_type_required;
  for (const auto dim : input_shape) {
    input_desc[shape_str].push_back(dim);
  }
  input_desc[valid] = true;

  nlohmann::json input_group;
  input_group.push_back(input_desc);
  nlohmann::json all_inputs;
  all_inputs.push_back(input_group);
  return all_inputs;
}
// Builds the "outputs" entry of the trans_data op: a nested array holding one NCHW output description
// whose "shape" and "ori_shape" are both filled from output_shape.
// type_id_name_map.at() throws std::out_of_range for a TypeId that has no dtype name.
nlohmann::json ConstructOutputs(const std::vector<size_t> &output_shape, mindspore::TypeId type) {
  nlohmann::json output_desc;
  output_desc[dtype] = type_id_name_map.at(type);
  output_desc[format_str] = kOpFormat_NCHW;
  output_desc[name] = dst;
  output_desc[ori_format] = kOpFormat_NCHW;
  // Appending element by element keeps the key absent when the shape vector is empty.
  for (const auto dim : output_shape) {
    output_desc[ori_shape].push_back(dim);
  }
  output_desc[param_type] = param_type_required;
  for (const auto dim : output_shape) {
    output_desc[shape_str].push_back(dim);
  }
  output_desc[valid] = true;

  nlohmann::json output_group;
  output_group.push_back(output_desc);
  nlohmann::json all_outputs;
  all_outputs.push_back(output_group);
  return all_outputs;
}
// Assembles the complete kernel json for a trans_data op that converts device_shape data in the
// given device format to host_shape data in NCHW.
// The kernel name is "<op name>_<hash>", where the hash is std::hash of the dumped op_info taken
// while its kernel_name field is still the empty string.
nlohmann::json ConstructTransDataKernelJson(const std::vector<size_t> &host_shape,
                                            const std::vector<size_t> &device_shape, const std::string &format,
                                            mindspore::TypeId type) {
  // Describe the op first, with a placeholder kernel name.
  nlohmann::json op_info;
  op_info[attrs_str] = ConstructAttrs(format);
  op_info[inputs_str] = ConstructInputs(device_shape, host_shape, format, type);
  op_info[kernel_name_str] = "";
  op_info[name] = trans_data;
  op_info[outputs_str] = ConstructOutputs(host_shape, type);

  // Derive the kernel name from the op description, then store it back into the description.
  const size_t hash_id = std::hash<std::string>()(op_info.dump());
  const std::string op_name = op_info[name].get<std::string>();
  const std::string json_name = op_name + "_" + std::to_string(hash_id);
  op_info[kernel_name_str] = json_name;

  nlohmann::json kernel_json;
  kernel_json[gen_model_str] = gen_model_single;
  kernel_json[impl_path_str] = "";
  kernel_json[op_info_str] = op_info;
  kernel_json[platform_str] = platform_tbe;
  return kernel_json;
}
void
AscendDeviceAddress
::
SyncStream
()
const
{
MS_LOG
(
INFO
)
<<
"Start!"
;
auto
ms_context
=
MsContext
::
GetInstance
();
...
...
@@ -158,31 +311,186 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t
return
sync_ok
;
}
// Launches the given trans_data kernel with this device address (ptr_, size_) as the only input and
// output_address_ptr / output_size as the only output, then calls SyncStream().
// One HBM workspace buffer is allocated per entry of workspace_size_list (sized with
// GetCommonAlignSize) before the launch and released again after the stream sync.
// NOTE(review): rtMalloc, launch and rtFree failures are only logged; execution continues regardless
// and the caller gets no indication of the failure.
void AscendDeviceAddress::LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, void *output_address_ptr,
                                          size_t output_size,
                                          const std::vector<size_t> &workspace_size_list) const {
  MS_EXCEPTION_IF_NULL(kernel_mod_ptr);

  auto src_address = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(src_address);
  src_address->addr = ptr_;
  src_address->size = size_;

  auto dst_address = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(dst_address);
  dst_address->addr = output_address_ptr;
  dst_address->size = output_size;

  AddressPtrList kernel_inputs = {src_address};
  AddressPtrList kernel_outputs = {dst_address};
  AddressPtrList kernel_workspaces;

  // Allocate every workspace buffer the kernel asked for.
  const size_t workspace_num = workspace_size_list.size();
  std::vector<void *> workspaces_address_ptr(workspace_num, nullptr);
  for (size_t idx = 0; idx < workspace_num; ++idx) {
    auto aligned_size = GetCommonAlignSize(workspace_size_list[idx]);
    if (rtMalloc(&workspaces_address_ptr[idx], aligned_size, RT_MEMORY_HBM) != RT_ERROR_NONE) {
      MS_LOG(ERROR) << "Failed to rtMalloc memory";
    }
    auto workspace_address = std::make_shared<kernel::Address>();
    MS_EXCEPTION_IF_NULL(workspace_address);
    workspace_address->addr = workspaces_address_ptr[idx];
    workspace_address->size = aligned_size;
    kernel_workspaces.push_back(workspace_address);
  }

  // Launch through the Ascend kernel runtime of the current device.
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto device_id = ms_context->device_id();
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  const auto launched = runtime_instance->LaunchTaskBasedOnSingleKernel(kernel_mod_ptr, kernel_inputs,
                                                                        kernel_outputs, kernel_workspaces);
  if (!launched) {
    MS_LOG(ERROR) << "Launch kernel failed.";
  }
  SyncStream();

  // Release the workspace buffers.
  for (size_t idx = 0; idx < workspace_num; ++idx) {
    if (rtFree(workspaces_address_ptr[idx]) != RT_ERROR_NONE) {
      MS_LOG(ERROR) << "Failed to rtFree memory";
    }
  }
}
// Compiles the trans_data kernel described by kernel_json (unless a kernel with the same name was
// already compiled through this function) and returns a kernel mod generated from the cached
// kernel pack.
// Throws via MS_EXCEPTION when waiting for a build task fails or a task reports a result other
// than "Success", and via MS_EXCEPTION_IF_NULL when the kernel pack is not found in the cache.
kernel::KernelModPtr AscendDeviceAddress::CompileTransDataAndObtainKernelMod(
  const nlohmann::json &kernel_json) const {
  // Names of the kernels that already went through a build in this function.
  static std::set<std::string> constructed_kernel;

  auto build_manager = std::make_shared<kernel::ParallelBuildManager>();
  MS_EXCEPTION_IF_NULL(build_manager);
  std::string processor = process_aicore;

  // get size
  std::vector<size_t> input_size_list;
  std::vector<size_t> output_size_list;
  (void)kernel::TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list);
  std::string json_name = kernel_json[op_info_str][kernel_name_str];

  // op build: only kernels that were not built before are submitted
  const bool already_built = constructed_kernel.count(json_name) != 0;
  if (!already_built) {
    auto task_id = build_manager->StartCompileOp(kernel_json);
    build_manager->SaveTaskInfo(task_id, nullptr, json_name, input_size_list, output_size_list);
  }

  // Drain the build manager until every submitted task has finished.
  while (!build_manager->IsAllTaskFinish()) {
    int task_id = -1;
    char *task_result = nullptr;
    char *pre_build_result = nullptr;
    auto ret = build_manager->WaitOne(&task_id, &task_result, &pre_build_result);
    if (!ret) {
      MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id;
    }
    const bool has_result = (task_result != nullptr);
    if (has_result && (strcmp(task_result, "Success") != 0)) {
      MS_EXCEPTION(ArgumentError) << "task compile Failed, task id:" << task_id << ", cause:" << task_result;
    }
    (void)build_manager->TaskFinishProcess(task_id, false);
  }
  (void)constructed_kernel.insert(json_name);

  // search cache
  auto cached_kernel_pack = TbeUtils::SearchCache(json_name, processor);
  MS_EXCEPTION_IF_NULL(cached_kernel_pack);
  return build_manager->GenKernelMod(json_name, processor, input_size_list, output_size_list, cached_kernel_pack);
}
bool
AscendDeviceAddress
::
SyncDeviceToHostAndConvertFormatBasedOnTransData
(
const
std
::
vector
<
size_t
>
&
host_shape
,
const
std
::
vector
<
size_t
>
&
device_shape
,
size_t
size
,
mindspore
::
TypeId
type
,
void
*
host_ptr
)
const
{
bool
sync_ok
=
true
;
// construct trans data kernel json
nlohmann
::
json
kernel_json
=
ConstructTransDataKernelJson
(
host_shape
,
device_shape
,
format_
,
type_id_
);
MS_LOG
(
INFO
)
<<
"Construct trans_data kernel json: "
<<
kernel_json
.
dump
();
auto
kernel_mod_ptr
=
CompileTransDataAndObtainKernelMod
(
kernel_json
);
MS_EXCEPTION_IF_NULL
(
kernel_mod_ptr
);
auto
host_size
=
size
;
if
(
type_id_
!=
type
)
{
auto
device_dtype_size
=
trans
::
TypeIdSize
(
type_id_
);
if
(
device_dtype_size
<
1
)
{
MS_LOG
(
ERROR
)
<<
"Illegal dtype."
;
}
auto
shape_size
=
trans
::
ShapeSize
(
host_shape
);
auto
size_tmp
=
device_dtype_size
*
shape_size
;
size
=
GetCommonAlignSize
(
size_tmp
);
}
void
*
output_address_ptr
=
nullptr
;
auto
ret_malloc
=
rtMalloc
(
&
output_address_ptr
,
size
,
RT_MEMORY_HBM
);
if
(
ret_malloc
!=
RT_ERROR_NONE
)
{
MS_LOG
(
ERROR
)
<<
"Failed to rtMalloc memory"
;
}
auto
workspace_size_list
=
GetWorkspaceSizeList
(
kernel_json
);
// launch
LaunchTransData
(
kernel_mod_ptr
,
output_address_ptr
,
size
,
workspace_size_list
);
if
(
type_id_
==
type
)
{
SyncMemory
(
host_ptr
,
output_address_ptr
,
size
,
RT_MEMCPY_DEVICE_TO_HOST
);
}
else
{
auto
host
=
std
::
vector
<
uint8_t
>
(
size
);
SyncMemory
(
host
.
data
(),
output_address_ptr
,
size
,
RT_MEMCPY_DEVICE_TO_HOST
);
auto
shape_size
=
trans
::
ShapeSize
(
host_shape
);
const
trans
::
TypeIdArgs
type_args
{
host
.
data
(),
shape_size
,
type_id_
,
type
,
host_size
};
sync_ok
=
trans
::
TransDataType
(
type_args
,
host_ptr
);
if
(
!
sync_ok
)
{
MS_LOG
(
ERROR
)
<<
"Trans format failed."
;
return
false
;
}
}
auto
ret_free
=
rtFree
(
output_address_ptr
);
if
(
ret_free
!=
RT_ERROR_NONE
)
{
MS_LOG
(
ERROR
)
<<
"Failed to rtFree memory"
;
}
return
sync_ok
;
}
// Returns the workspace sizes recorded in the cached kernel pack of the kernel named by
// kernel_json["op_info"]["kernel_name"].
// Throws via MS_EXCEPTION_IF_NULL when no kernel pack is found in the cache.
std::vector<size_t> AscendDeviceAddress::GetWorkspaceSizeList(const nlohmann::json &kernel_json) const {
  std::string kernel_name = kernel_json[op_info_str][kernel_name_str];
  std::string processor = process_aicore;
  auto kernel_pack = TbeUtils::SearchCache(kernel_name, processor);
  MS_EXCEPTION_IF_NULL(kernel_pack);
  return kernel_pack->kernel_json_info().workspaces;
}
// Computes the device shape for this address's format_.
// For kOpFormat_FRAC_NZ and kOpFormat_NDHWC, *host_shape is used as given. For every other format
// *host_shape is first rewritten in place: padded to 4D when host_shape_ is empty, otherwise
// replaced by host_shape_ converted with IntToSize.
std::vector<size_t> AscendDeviceAddress::GetDeviceShape(std::vector<size_t> *host_shape) const {
  const bool use_shape_as_is = (format_ == kOpFormat_FRAC_NZ) || (format_ == kOpFormat_NDHWC);
  if (!use_shape_as_is) {
    if (host_shape_.empty()) {
      *host_shape = trans::PaddingShapeTo4d(*host_shape);
    } else {
      host_shape->clear();
      (void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(*host_shape), IntToSize);
    }
  }
  return trans::TransShapeToDevice(*host_shape, format_);
}
bool
AscendDeviceAddress
::
SyncDeviceToHostAndConvertFormat
(
const
std
::
vector
<
int
>
&
shape
,
size_t
size
,
mindspore
::
TypeId
type
,
void
*
host_ptr
)
const
{
MS_LOG
(
INFO
)
<<
"SyncDeviceToHostAndConvertFormat, Device(format:"
<<
format_
<<
", type_id:"
<<
TypeIdLabel
(
type_id_
)
<<
", size:"
<<
size_
<<
"), Host(type_id:"
<<
TypeIdLabel
(
type
)
<<
", size:"
<<
size
<<
")"
;
bool
sync_ok
=
false
;
auto
host_tmp
=
std
::
vector
<
uint8_t
>
(
size_
);
SyncMemory
(
host_tmp
.
data
(),
ptr_
,
size_
,
RT_MEMCPY_DEVICE_TO_HOST
);
std
::
vector
<
size_t
>
host_shape
;
(
void
)
std
::
transform
(
shape
.
begin
(),
shape
.
end
(),
std
::
back_inserter
(
host_shape
),
IntToSize
);
std
::
vector
<
size_t
>
device_shape
;
if
(
host_shape
.
empty
())
{
host_shape
.
emplace_back
(
1
);
}
if
(
format_
==
kOpFormat_FRAC_NZ
||
format_
==
kOpFormat_NDHWC
)
{
device_shape
=
trans
::
TransShapeToDevice
(
host_shape
,
format_
);
}
else
{
if
(
host_shape_
.
empty
())
{
host_shape
=
trans
::
PaddingShapeTo4d
(
host_shape
);
}
else
{
host_shape
.
clear
();
(
void
)
std
::
transform
(
host_shape_
.
begin
(),
host_shape_
.
end
(),
std
::
back_inserter
(
host_shape
),
IntToSize
);
std
::
vector
<
size_t
>
device_shape
=
GetDeviceShape
(
&
host_shape
);
if
(
type_id_name_map
.
find
(
type_id_
)
!=
type_id_name_map
.
end
())
{
std
::
pair
<
std
::
string
,
std
::
string
>
type_format
=
std
::
make_pair
(
type_id_name_map
.
at
(
type_id_
),
format_
);
if
(
use_trans_data
.
find
(
type_format
)
!=
use_trans_data
.
end
())
{
sync_ok
=
SyncDeviceToHostAndConvertFormatBasedOnTransData
(
host_shape
,
device_shape
,
size
,
type
,
host_ptr
);
return
sync_ok
;
}
device_shape
=
trans
::
TransShapeToDevice
(
host_shape
,
format_
);
}
auto
host_tmp
=
std
::
vector
<
uint8_t
>
(
size_
);
SyncMemory
(
host_tmp
.
data
(),
ptr_
,
size_
,
RT_MEMCPY_DEVICE_TO_HOST
);
if
(
type_id_
!=
type
)
{
const
trans
::
FormatArgs
format_args
{
host_tmp
.
data
(),
size_
,
kOpFormat_NCHW
,
format_
,
host_shape
,
device_shape
,
type_id_
};
...
...
mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h
浏览文件 @
7b48a122
...
...
@@ -20,9 +20,11 @@
#include <string>
#include <vector>
#include <memory>
#include <nlohmann/json.hpp>
#include "runtime/device/device_address.h"
#include "runtime/device/ascend/ascend_memory_pool.h"
#include "ir/dtype.h"
#include "backend/kernel_compiler/kernel.h"
namespace
mindspore
{
#ifdef ENABLE_DEBUGGER
...
...
@@ -53,7 +55,16 @@ class AscendDeviceAddress : public DeviceAddress {
bool
SyncDeviceToHostAndConvertFormat
(
const
std
::
vector
<
int
>
&
shape
,
size_t
size
,
TypeId
type
,
void
*
host_ptr
)
const
;
bool
ConvertFormatAndSyncHostToDevice
(
const
std
::
vector
<
int
>
&
shape
,
size_t
size
,
TypeId
type
,
const
void
*
host_ptr
)
const
;
// Converts the device data to NCHW with a compiled trans_data kernel and copies the result to
// host_ptr. When the device dtype differs from `type`, the data is staged on the host and
// converted with trans::TransDataType; false is returned if that conversion fails.
bool
SyncDeviceToHostAndConvertFormatBasedOnTransData
(
const
std
::
vector
<
size_t
>
&
host_shape
,
const
std
::
vector
<
size_t
>
&
device_shape
,
size_t
size
,
mindspore
::
TypeId
type
,
void
*
host_ptr
)
const
;
void
SyncStream
()
const
;
// Launches kernel_mod_ptr with this device address as the single input and
// output_address_ptr / output_size as the single output, then synchronizes the stream.
// A temporary device workspace is allocated for each entry of workspace_size_list and freed
// after the launch.
void
LaunchTransData
(
kernel
::
KernelModPtr
kernel_mod_ptr
,
void
*
output_address_ptr
,
size_t
output_size
,
const
std
::
vector
<
size_t
>
&
workspace_size_list
)
const
;
// Returns the device shape for this address's format. host_shape is an in/out parameter: for
// formats other than FRAC_NZ and NDHWC it is overwritten, either padded to 4D or replaced by the
// stored host_shape_.
std
::
vector
<
size_t
>
GetDeviceShape
(
std
::
vector
<
size_t
>
*
host_shape
)
const
;
// Looks up the cached kernel pack named by kernel_json's op_info kernel_name and returns the
// workspace sizes stored in its kernel json info.
std
::
vector
<
size_t
>
GetWorkspaceSizeList
(
const
nlohmann
::
json
&
kernel_json
)
const
;
// Compiles the kernel described by kernel_json if a kernel of that name has not been compiled
// through this function before, then builds and returns a kernel mod from the cached kernel pack.
kernel
::
KernelModPtr
CompileTransDataAndObtainKernelMod
(
const
nlohmann
::
json
&
kernel_json
)
const
;
};
using
AscendDeviceAddressPtr
=
std
::
shared_ptr
<
AscendDeviceAddress
>
;
}
// namespace ascend
...
...
mindspore/ccsrc/runtime/device/kernel_runtime.cc
浏览文件 @
7b48a122
...
...
@@ -757,6 +757,18 @@ void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
MS_LOG
(
INFO
)
<<
"Clear graph:"
<<
graph_id
<<
" runtime resource"
;
}
// Launches a single kernel mod on this runtime's stream_ with the given input, workspace and
// output address lists. Returns true on success; logs an error and returns false when the launch
// fails. Throws via MS_EXCEPTION_IF_NULL when kernel_mod_ptr is null.
bool KernelRuntime::LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, AddressPtrList kernel_inputs,
                                                  AddressPtrList kernel_outputs,
                                                  AddressPtrList kernel_workspaces) const {
  MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
  // Note the argument order expected by Launch: inputs, workspaces, outputs.
  if (kernel_mod_ptr->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_)) {
    return true;
  }
  MS_LOG(ERROR) << "Launch kernel failed.";
  return false;
}
#ifdef ENABLE_DUMP_E2E
bool
KernelRuntime
::
SetDumpConf
()
{
dump_conf_ptr_
=
std
::
make_shared
<
Dump
>
();
...
...
mindspore/ccsrc/runtime/device/kernel_runtime.h
浏览文件 @
7b48a122
...
...
@@ -61,6 +61,8 @@ class KernelRuntime {
virtual
bool
RunTask
(
const
session
::
KernelGraph
*
graph
);
virtual
bool
GenTask
(
const
session
::
KernelGraph
*
graph
);
bool
LaunchKernel
(
const
session
::
KernelGraph
*
graph
);
// Launches one kernel mod on this runtime's stream with the given input, output and workspace
// addresses; returns false when the launch fails.
bool
LaunchTaskBasedOnSingleKernel
(
kernel
::
KernelModPtr
kernel_mod_ptr
,
AddressPtrList
kernel_inputs
,
AddressPtrList
kernel_outputs
,
AddressPtrList
kernel_workspaces
)
const
;
virtual
void
AssignStaticMemoryInput
(
const
session
::
KernelGraph
*
graph
);
virtual
void
AssignStaticMemoryValueNode
(
session
::
KernelGraph
*
graph
);
virtual
void
ClearGraphRuntimeResource
(
uint32_t
graph_id
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录