Commit 7b48a122 authored by lvchangquan

insert trans_data to reduce time in print process
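
Before this change, printing a tensor whose device format differs from NCHW
copied the raw device buffer to the host and converted the format on the CPU.
For the dtype/format pairs listed in use_trans_data, this commit instead builds
a single-op TBE trans_data kernel json on the fly, compiles and launches it to
convert the tensor to NCHW on the device, and only then copies the result back
to the host, reducing the time spent in the print process.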

Parent 4e0cfafc
runtime/device/ascend/ascend_device_address.cc
@@ -16,14 +16,19 @@
#include "runtime/device/ascend/ascend_device_address.h"
#include <memory>
#include <vector>
#include <unordered_map>
#include <utility>
#include <set>
#include <algorithm>
#include "runtime/mem.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/device/kernel_runtime.h"
#include "runtime/device/convert_tensor_utils.h"
#include "ir/dtype/type.h"
#include "ir/tensor.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h"
#include "utils/utils.h"
#include "common/utils.h"
#include "common/trans.h"
@@ -34,6 +39,58 @@
#include "debug/tensor_load.h"
#endif
namespace {
const std::unordered_map<mindspore::TypeId, std::string> type_id_name_map = {
{mindspore::kNumberTypeBool, "bool"}, {mindspore::kNumberTypeInt8, "int8"},
{mindspore::kNumberTypeInt16, "int16"}, {mindspore::kNumberTypeInt32, "int32"},
{mindspore::kNumberTypeInt64, "int64"}, {mindspore::kNumberTypeFloat16, "float16"},
{mindspore::kNumberTypeFloat32, "float32"}, {mindspore::kNumberTypeUInt8, "uint8"},
{mindspore::kNumberTypeUInt16, "uint16"}, {mindspore::kNumberTypeUInt32, "uint32"},
{mindspore::kNumberTypeUInt64, "uint64"}};
const std::set<std::pair<std::string, std::string>> use_trans_data = {
std::make_pair("float16", mindspore::kOpFormat_NC1HWC0), std::make_pair("float32", mindspore::kOpFormat_NC1HWC0),
std::make_pair("bool", mindspore::kOpFormat_NC1HWC0), std::make_pair("float32", mindspore::kOpFormat_FRAC_Z),
std::make_pair("float16", mindspore::kOpFormat_FRAC_Z), std::make_pair("float16", mindspore::kOpFormat_FRAC_NZ),
std::make_pair("float32", mindspore::kOpFormat_FRAC_NZ), std::make_pair("int32", mindspore::kOpFormat_FRAC_NZ),
std::make_pair("float16", mindspore::kOpFormat_NHWC), std::make_pair("float32", mindspore::kOpFormat_NHWC),
std::make_pair("int8", mindspore::kOpFormat_NHWC), std::make_pair("int16", mindspore::kOpFormat_NHWC),
std::make_pair("int32", mindspore::kOpFormat_NHWC), std::make_pair("int64", mindspore::kOpFormat_NHWC),
std::make_pair("uint8", mindspore::kOpFormat_NHWC), std::make_pair("uint16", mindspore::kOpFormat_NHWC),
std::make_pair("uint32", mindspore::kOpFormat_NHWC), std::make_pair("uint64", mindspore::kOpFormat_NHWC),
std::make_pair("float16", mindspore::kOpFormat_HWCN), std::make_pair("float32", mindspore::kOpFormat_HWCN),
std::make_pair("int8", mindspore::kOpFormat_HWCN), std::make_pair("int16", mindspore::kOpFormat_HWCN),
std::make_pair("int32", mindspore::kOpFormat_HWCN), std::make_pair("int64", mindspore::kOpFormat_HWCN),
std::make_pair("uint8", mindspore::kOpFormat_HWCN), std::make_pair("uint16", mindspore::kOpFormat_HWCN),
std::make_pair("uint32", mindspore::kOpFormat_HWCN), std::make_pair("uint64", mindspore::kOpFormat_HWCN)};
constexpr auto src_format = "src_format";
constexpr auto dst_format = "dst_format";
constexpr auto src = "src_0";
constexpr auto dst = "dst";
constexpr auto param_type_required = "required";
constexpr auto gen_model_single = "single";
constexpr auto trans_data = "trans_data";
constexpr auto platform_tbe = "TBE";
constexpr auto name = "name";
constexpr auto valid = "valid";
constexpr auto value = "value";
constexpr auto dtype = "dtype";
constexpr auto format_str = "format";
constexpr auto ori_format = "ori_format";
constexpr auto ori_shape = "ori_shape";
constexpr auto param_type = "param_type";
constexpr auto shape_str = "shape";
constexpr auto process_aicore = "aicore";
constexpr auto gen_model_str = "gen_model";
constexpr auto impl_path_str = "impl_path";
constexpr auto attrs_str = "attrs";
constexpr auto inputs_str = "inputs";
constexpr auto outputs_str = "outputs";
constexpr auto kernel_name_str = "kernel_name";
constexpr auto op_info_str = "op_info";
constexpr auto platform_str = "platform";
constexpr auto fractal_z = "FRACTAL_Z";
} // namespace
namespace mindspore {
namespace device {
namespace ascend {
@@ -96,6 +153,102 @@ bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *s
return true;
}
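// Round input_size up to a multiple of kMemAlignSize; the extra 31 bytes of
// headroom presumably match the runtime's HBM allocation padding rule.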
size_t GetCommonAlignSize(size_t input_size) {
return (input_size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize;
}
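// Build the "attrs" section of the trans_data kernel json: src_format is the
// tensor's device format (FRAC_Z is spelled out as FRACTAL_Z for TBE) and
// dst_format is always NCHW, the host layout we want to recover.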
nlohmann::json ConstructAttrs(const std::string &format) {
nlohmann::json real_attr;
nlohmann::json src_attr;
nlohmann::json des_attr;
src_attr[name] = src_format;
src_attr[valid] = true;
if (format == kOpFormat_FRAC_Z) {
src_attr[value] = fractal_z;
} else {
src_attr[value] = format;
}
des_attr[name] = dst_format;
des_attr[valid] = true;
des_attr[value] = kOpFormat_NCHW;
real_attr.push_back(src_attr);
real_attr.push_back(des_attr);
return real_attr;
}
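// Build the "inputs" section: the single input is this device tensor, carrying
// its device shape/format plus the original host-side (NCHW) shape and format.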
nlohmann::json ConstructInputs(const std::vector<size_t> &input_shape, const std::vector<size_t> &output_shape,
const std::string &format, mindspore::TypeId type) {
nlohmann::json input;
nlohmann::json input_json;
nlohmann::json real_input;
real_input[dtype] = type_id_name_map.at(type);
if (format == kOpFormat_FRAC_Z) {
real_input[format_str] = fractal_z;
} else {
real_input[format_str] = format;
}
real_input[name] = src;
real_input[ori_format] = kOpFormat_NCHW;
for (auto shape : output_shape) {
real_input[ori_shape].push_back(shape);
}
real_input[param_type] = param_type_required;
// obtain inputs shape
for (auto shape : input_shape) {
real_input[shape_str].push_back(shape);
}
real_input[valid] = true;
input_json.push_back(real_input);
input.push_back(input_json);
return input;
}
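// Build the "outputs" section: the single output is the NCHW tensor that
// trans_data writes on the device before it is copied back to the host.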
nlohmann::json ConstructOutputs(const std::vector<size_t> &output_shape, mindspore::TypeId type) {
nlohmann::json output;
nlohmann::json output_json;
nlohmann::json real_output;
real_output[dtype] = type_id_name_map.at(type);
real_output[format_str] = kOpFormat_NCHW;
real_output[name] = dst;
real_output[ori_format] = kOpFormat_NCHW;
for (auto shape : output_shape) {
real_output[ori_shape].push_back(shape);
}
real_output[param_type] = param_type_required;
// obtain outputs shape
for (auto shape : output_shape) {
real_output[shape_str].push_back(shape);
}
real_output[valid] = true;
output_json.push_back(real_output);
output.push_back(output_json);
return output;
}
nlohmann::json ConstructTransDataKernelJson(const std::vector<size_t> &host_shape,
const std::vector<size_t> &device_shape, const std::string &format,
mindspore::TypeId type) {
// generate kernel json
nlohmann::json kernel_json;
kernel_json[gen_model_str] = gen_model_single;
kernel_json[impl_path_str] = "";
// construct op_info
nlohmann::json op_info;
op_info[attrs_str] = ConstructAttrs(format);
op_info[inputs_str] = ConstructInputs(device_shape, host_shape, format, type);
op_info[kernel_name_str] = "";
op_info[name] = trans_data;
op_info[outputs_str] = ConstructOutputs(host_shape, type);
kernel_json[op_info_str] = op_info;
kernel_json[platform_str] = platform_tbe;
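  // Name the kernel with a hash of op_info so identical shape/format/dtype
  // combinations map to the same compiled kernel in the cache.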
std::string json_str = kernel_json[op_info_str].dump();
size_t hash_id = std::hash<std::string>()(json_str);
const std::string op_name = op_info[name];
const std::string json_name = op_name + "_" + std::to_string(hash_id);
kernel_json[op_info_str][kernel_name_str] = json_name;
return kernel_json;
}
void AscendDeviceAddress::SyncStream() const {
MS_LOG(INFO) << "Start!";
auto ms_context = MsContext::GetInstance();
@@ -158,31 +311,186 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t
return sync_ok;
}
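// Run the compiled trans_data kernel on this tensor: ptr_/size_ is the input,
// output_address_ptr receives the converted result, and any workspace buffers
// are rtMalloc'ed before the launch and rtFree'd after the stream syncs.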
void AscendDeviceAddress::LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, void *output_address_ptr,
size_t output_size, const std::vector<size_t> &workspace_size_list) const {
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
auto input_address = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(input_address);
input_address->addr = ptr_;
input_address->size = size_;
auto output_address = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(output_address);
output_address->addr = output_address_ptr;
output_address->size = output_size;
AddressPtrList kernel_inputs = {input_address};
AddressPtrList kernel_outputs = {output_address};
AddressPtrList kernel_workspaces;
std::vector<void *> workspaces_address_ptr(workspace_size_list.size(), nullptr);
if (!workspace_size_list.empty()) {
for (size_t i = 0; i < workspace_size_list.size(); ++i) {
auto workspace_size = GetCommonAlignSize(workspace_size_list[i]);
auto ret_malloc = rtMalloc(&workspaces_address_ptr[i], workspace_size, RT_MEMORY_HBM);
if (ret_malloc != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Failed to rtMalloc memory";
}
auto workspace_address = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(workspace_address);
workspace_address->addr = workspaces_address_ptr[i];
workspace_address->size = workspace_size;
kernel_workspaces.push_back(workspace_address);
}
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto device_id = ms_context->device_id();
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
MS_EXCEPTION_IF_NULL(runtime_instance);
auto ret =
runtime_instance->LaunchTaskBasedOnSingleKernel(kernel_mod_ptr, kernel_inputs, kernel_outputs, kernel_workspaces);
if (!ret) {
MS_LOG(ERROR) << "Launch kernel failed.";
}
SyncStream();
if (!workspace_size_list.empty()) {
for (size_t i = 0; i < workspace_size_list.size(); ++i) {
auto ret_free = rtFree(workspaces_address_ptr[i]);
if (ret_free != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Failed to rtFree memory";
}
}
}
}
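// Compile the single-op trans_data kernel described by kernel_json (skipping
// compilation if this json_name was already built in this process) and wrap
// the cached binary in a launchable KernelMod.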
kernel::KernelModPtr AscendDeviceAddress::CompileTransDataAndObtainKernelMod(const nlohmann::json &kernel_json) const {
static std::set<std::string> constructed_kernel;
auto build_manager = std::make_shared<kernel::ParallelBuildManager>();
MS_EXCEPTION_IF_NULL(build_manager);
std::string processor = process_aicore;
// get size
std::vector<size_t> input_size_list;
std::vector<size_t> output_size_list;
(void)kernel::TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list);
std::string json_name = kernel_json[op_info_str][kernel_name_str];
// op build
if (constructed_kernel.find(json_name) == constructed_kernel.end()) {
auto task_id = build_manager->StartCompileOp(kernel_json);
build_manager->SaveTaskInfo(task_id, nullptr, json_name, input_size_list, output_size_list);
}
while (!build_manager->IsAllTaskFinish()) {
int task_id = -1;
char *task_result = nullptr;
char *pre_build_result = nullptr;
auto ret = build_manager->WaitOne(&task_id, &task_result, &pre_build_result);
if (!ret) {
MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id;
}
if ((task_result != nullptr) && (strcmp(task_result, "Success") != 0)) {
MS_EXCEPTION(ArgumentError) << "task compile Failed, task id:" << task_id << ", cause:" << task_result;
}
(void)build_manager->TaskFinishProcess(task_id, false);
}
constructed_kernel.insert(json_name);
// search cache
auto cached_kernel_pack = TbeUtils::SearchCache(json_name, processor);
MS_EXCEPTION_IF_NULL(cached_kernel_pack);
auto kernel_mod_ptr =
build_manager->GenKernelMod(json_name, processor, input_size_list, output_size_list, cached_kernel_pack);
return kernel_mod_ptr;
}
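// Device-side conversion path for SyncDeviceToHost: compile and launch a
// trans_data kernel that converts the tensor to NCHW on the device, copy the
// result to the host, and convert the dtype on the host only if it differs.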
bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const std::vector<size_t> &host_shape,
const std::vector<size_t> &device_shape,
size_t size, mindspore::TypeId type,
void *host_ptr) const {
bool sync_ok = true;
// construct trans data kernel json
nlohmann::json kernel_json = ConstructTransDataKernelJson(host_shape, device_shape, format_, type_id_);
MS_LOG(INFO) << "Construct trans_data kernel json: " << kernel_json.dump();
auto kernel_mod_ptr = CompileTransDataAndObtainKernelMod(kernel_json);
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
auto host_size = size;
if (type_id_ != type) {
auto device_dtype_size = trans::TypeIdSize(type_id_);
if (device_dtype_size < 1) {
MS_LOG(ERROR) << "Illegal dtype.";
}
auto shape_size = trans::ShapeSize(host_shape);
auto size_tmp = device_dtype_size * shape_size;
size = GetCommonAlignSize(size_tmp);
}
void *output_address_ptr = nullptr;
auto ret_malloc = rtMalloc(&output_address_ptr, size, RT_MEMORY_HBM);
if (ret_malloc != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Failed to rtMalloc memory";
}
auto workspace_size_list = GetWorkspaceSizeList(kernel_json);
// launch
LaunchTransData(kernel_mod_ptr, output_address_ptr, size, workspace_size_list);
if (type_id_ == type) {
SyncMemory(host_ptr, output_address_ptr, size, RT_MEMCPY_DEVICE_TO_HOST);
} else {
auto host = std::vector<uint8_t>(size);
SyncMemory(host.data(), output_address_ptr, size, RT_MEMCPY_DEVICE_TO_HOST);
auto shape_size = trans::ShapeSize(host_shape);
const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, host_size};
sync_ok = trans::TransDataType(type_args, host_ptr);
if (!sync_ok) {
MS_LOG(ERROR) << "Trans format failed.";
return false;
}
}
auto ret_free = rtFree(output_address_ptr);
if (ret_free != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Failed to rtFree memory";
}
return sync_ok;
}
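// Workspace sizes are recorded in the kernel's json info at compile time;
// look them up in the cache by the hashed kernel name.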
std::vector<size_t> AscendDeviceAddress::GetWorkspaceSizeList(const nlohmann::json &kernel_json) const {
std::string json_name = kernel_json[op_info_str][kernel_name_str];
std::string processor = process_aicore;
auto cached_kernel_pack = TbeUtils::SearchCache(json_name, processor);
MS_EXCEPTION_IF_NULL(cached_kernel_pack);
auto kernel_json_info = cached_kernel_pack->kernel_json_info();
return kernel_json_info.workspaces;
}
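// Compute the device-format shape: FRAC_NZ/NDHWC translate directly from the
// host shape; other formats are first padded to 4D (or taken from host_shape_
// when it is set).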
std::vector<size_t> AscendDeviceAddress::GetDeviceShape(std::vector<size_t> *host_shape) const {
std::vector<size_t> device_shape;
if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) {
device_shape = trans::TransShapeToDevice(*host_shape, format_);
} else {
if (host_shape_.empty()) {
*host_shape = trans::PaddingShapeTo4d(*host_shape);
} else {
host_shape->clear();
(void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(*host_shape), IntToSize);
}
device_shape = trans::TransShapeToDevice(*host_shape, format_);
}
return device_shape;
}
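// Sync entry point: prefer the device-side trans_data path for the
// dtype/format pairs in use_trans_data; otherwise fall back to copying the
// raw device buffer and converting the format on the host.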
bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int> &shape, size_t size,
mindspore::TypeId type, void *host_ptr) const {
MS_LOG(INFO) << "SyncDeviceToHostAndConvertFormat, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_)
<< ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")";
bool sync_ok = false;
  std::vector<size_t> host_shape;
  (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize);
  if (host_shape.empty()) {
    host_shape.emplace_back(1);
  }
  std::vector<size_t> device_shape = GetDeviceShape(&host_shape);
  if (type_id_name_map.find(type_id_) != type_id_name_map.end()) {
    std::pair<std::string, std::string> type_format = std::make_pair(type_id_name_map.at(type_id_), format_);
    if (use_trans_data.find(type_format) != use_trans_data.end()) {
      sync_ok = SyncDeviceToHostAndConvertFormatBasedOnTransData(host_shape, device_shape, size, type, host_ptr);
      return sync_ok;
    }
  }
  auto host_tmp = std::vector<uint8_t>(size_);
  SyncMemory(host_tmp.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST);
if (type_id_ != type) {
const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_,
host_shape, device_shape, type_id_};
......
runtime/device/ascend/ascend_device_address.h
@@ -20,9 +20,11 @@
#include <string>
#include <vector>
#include <memory>
#include <nlohmann/json.hpp>
#include "runtime/device/device_address.h"
#include "runtime/device/ascend/ascend_memory_pool.h"
#include "ir/dtype.h"
#include "backend/kernel_compiler/kernel.h"
namespace mindspore {
#ifdef ENABLE_DEBUGGER
@@ -53,7 +55,16 @@ class AscendDeviceAddress : public DeviceAddress {
bool SyncDeviceToHostAndConvertFormat(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const;
bool ConvertFormatAndSyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type,
const void *host_ptr) const;
bool SyncDeviceToHostAndConvertFormatBasedOnTransData(const std::vector<size_t> &host_shape,
const std::vector<size_t> &device_shape, size_t size,
mindspore::TypeId type, void *host_ptr) const;
void SyncStream() const;
void LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, void *output_address_ptr, size_t output_size,
const std::vector<size_t> &workspace_size_list) const;
std::vector<size_t> GetDeviceShape(std::vector<size_t> *host_shape) const;
std::vector<size_t> GetWorkspaceSizeList(const nlohmann::json &kernel_json) const;
kernel::KernelModPtr CompileTransDataAndObtainKernelMod(const nlohmann::json &kernel_json) const;
};
using AscendDeviceAddressPtr = std::shared_ptr<AscendDeviceAddress>;
} // namespace ascend
......
runtime/device/kernel_runtime.cc
@@ -757,6 +757,18 @@ void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource";
}
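// Launch one kernel directly on the runtime's stream, outside any graph/task
// context; used by the trans_data path in AscendDeviceAddress.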
bool KernelRuntime::LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, AddressPtrList kernel_inputs,
AddressPtrList kernel_outputs,
AddressPtrList kernel_workspaces) const {
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
auto ret = kernel_mod_ptr->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
if (!ret) {
MS_LOG(ERROR) << "Launch kernel failed.";
return false;
}
return true;
}
#ifdef ENABLE_DUMP_E2E
bool KernelRuntime::SetDumpConf() {
dump_conf_ptr_ = std::make_shared<Dump>();
......
runtime/device/kernel_runtime.h
@@ -61,6 +61,8 @@ class KernelRuntime {
virtual bool RunTask(const session::KernelGraph *graph);
virtual bool GenTask(const session::KernelGraph *graph);
bool LaunchKernel(const session::KernelGraph *graph);
bool LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, AddressPtrList kernel_inputs,
AddressPtrList kernel_outputs, AddressPtrList kernel_workspaces) const;
virtual void AssignStaticMemoryInput(const session::KernelGraph *graph);
virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph);
virtual void ClearGraphRuntimeResource(uint32_t graph_id);
......