diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc
index a490f95fc62a85b80c9ff8099e05a7e10561c3dd..26dbef12e4327211b7d23094b23fc026d7d5cdba 100644
--- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc
@@ -16,14 +16,19 @@
 #include "runtime/device/ascend/ascend_device_address.h"
 #include <memory>
 #include <vector>
+#include <set>
+#include <unordered_map>
 #include <algorithm>
 #include <utility>
 #include "runtime/mem.h"
 #include "runtime/device/kernel_runtime_manager.h"
+#include "runtime/device/kernel_runtime.h"
 #include "runtime/device/convert_tensor_utils.h"
 #include "ir/dtype/type.h"
 #include "ir/tensor.h"
 #include "backend/kernel_compiler/common_utils.h"
+#include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
+#include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h"
 #include "utils/utils.h"
 #include "common/utils.h"
 #include "common/trans.h"
@@ -34,6 +39,58 @@
 #include "debug/tensor_load.h"
 #endif
 
+namespace {
+const std::unordered_map<mindspore::TypeId, std::string> type_id_name_map = {
+  {mindspore::kNumberTypeBool, "bool"},       {mindspore::kNumberTypeInt8, "int8"},
+  {mindspore::kNumberTypeInt16, "int16"},     {mindspore::kNumberTypeInt32, "int32"},
+  {mindspore::kNumberTypeInt64, "int64"},     {mindspore::kNumberTypeFloat16, "float16"},
+  {mindspore::kNumberTypeFloat32, "float32"}, {mindspore::kNumberTypeUInt8, "uint8"},
+  {mindspore::kNumberTypeUInt16, "uint16"},   {mindspore::kNumberTypeUInt32, "uint32"},
+  {mindspore::kNumberTypeUInt64, "uint64"}};
+const std::set<std::pair<std::string, std::string>> use_trans_data = {
+  std::make_pair("float16", mindspore::kOpFormat_NC1HWC0), std::make_pair("float32", mindspore::kOpFormat_NC1HWC0),
+  std::make_pair("bool", mindspore::kOpFormat_NC1HWC0),    std::make_pair("float32", mindspore::kOpFormat_FRAC_Z),
+  std::make_pair("float16", mindspore::kOpFormat_FRAC_Z),  std::make_pair("float16", mindspore::kOpFormat_FRAC_NZ),
+  std::make_pair("float32", mindspore::kOpFormat_FRAC_NZ), std::make_pair("int32", mindspore::kOpFormat_FRAC_NZ),
+  std::make_pair("float16", mindspore::kOpFormat_NHWC),    std::make_pair("float32", mindspore::kOpFormat_NHWC),
+  std::make_pair("int8", mindspore::kOpFormat_NHWC),       std::make_pair("int16", mindspore::kOpFormat_NHWC),
+  std::make_pair("int32", mindspore::kOpFormat_NHWC),      std::make_pair("int64", mindspore::kOpFormat_NHWC),
+  std::make_pair("uint8", mindspore::kOpFormat_NHWC),      std::make_pair("uint16", mindspore::kOpFormat_NHWC),
+  std::make_pair("uint32", mindspore::kOpFormat_NHWC),     std::make_pair("uint64", mindspore::kOpFormat_NHWC),
+  std::make_pair("float16", mindspore::kOpFormat_HWCN),    std::make_pair("float32", mindspore::kOpFormat_HWCN),
+  std::make_pair("int8", mindspore::kOpFormat_HWCN),       std::make_pair("int16", mindspore::kOpFormat_HWCN),
+  std::make_pair("int32", mindspore::kOpFormat_HWCN),      std::make_pair("int64", mindspore::kOpFormat_HWCN),
+  std::make_pair("uint8", mindspore::kOpFormat_HWCN),      std::make_pair("uint16", mindspore::kOpFormat_HWCN),
+  std::make_pair("uint32", mindspore::kOpFormat_HWCN),     std::make_pair("uint64", mindspore::kOpFormat_HWCN)};
+constexpr auto src_format = "src_format";
+constexpr auto dst_format = "dst_format";
+constexpr auto src = "src_0";
+constexpr auto dst = "dst";
+constexpr auto param_type_required = "required";
+constexpr auto gen_model_single = "single";
+constexpr auto trans_data = "trans_data";
+constexpr auto platform_tbe = "TBE";
+constexpr auto name = "name";
+constexpr auto valid = "valid";
+constexpr auto value = "value";
+constexpr auto dtype = "dtype"; +constexpr auto format_str = "format"; +constexpr auto ori_format = "ori_format"; +constexpr auto ori_shape = "ori_shape"; +constexpr auto param_type = "param_type"; +constexpr auto shape_str = "shape"; +constexpr auto process_aicore = "aicore"; +constexpr auto gen_model_str = "gen_model"; +constexpr auto impl_path_str = "impl_path"; +constexpr auto attrs_str = "attrs"; +constexpr auto inputs_str = "inputs"; +constexpr auto outputs_str = "outputs"; +constexpr auto kernel_name_str = "kernel_name"; +constexpr auto op_info_str = "op_info"; +constexpr auto platform_str = "platform"; +constexpr auto fractal_z = "FRACTAL_Z"; +} // namespace + namespace mindspore { namespace device { namespace ascend { @@ -96,6 +153,102 @@ bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *s return true; } +size_t GetCommonAlignSize(size_t input_size) { + return (input_size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize; +} + +nlohmann::json ConstructAttrs(const std::string &format) { + nlohmann::json real_attr; + nlohmann::json src_attr; + nlohmann::json des_attr; + src_attr[name] = src_format; + src_attr[valid] = true; + if (format == kOpFormat_FRAC_Z) { + src_attr[value] = fractal_z; + } else { + src_attr[value] = format; + } + des_attr[name] = dst_format; + des_attr[valid] = true; + des_attr[value] = kOpFormat_NCHW; + real_attr.push_back(src_attr); + real_attr.push_back(des_attr); + return real_attr; +} + +nlohmann::json ConstructInputs(const std::vector &input_shape, const std::vector &output_shape, + const std::string &format, mindspore::TypeId type) { + nlohmann::json input; + nlohmann::json input_json; + nlohmann::json real_input; + real_input[dtype] = type_id_name_map.at(type); + if (format == kOpFormat_FRAC_Z) { + real_input[format_str] = fractal_z; + } else { + real_input[format_str] = format; + } + real_input[name] = src; + real_input[ori_format] = kOpFormat_NCHW; + for (auto shape : output_shape) { + real_input[ori_shape].push_back(shape); + } + real_input[param_type] = param_type_required; + // obtain inputs shape + for (auto shape : input_shape) { + real_input[shape_str].push_back(shape); + } + real_input[valid] = true; + input_json.push_back(real_input); + input.push_back(input_json); + return input; +} + +nlohmann::json ConstructOutputs(const std::vector &output_shape, mindspore::TypeId type) { + nlohmann::json output; + nlohmann::json output_json; + nlohmann::json real_output; + real_output[dtype] = type_id_name_map.at(type); + real_output[format_str] = kOpFormat_NCHW; + real_output[name] = dst; + real_output[ori_format] = kOpFormat_NCHW; + for (auto shape : output_shape) { + real_output[ori_shape].push_back(shape); + } + real_output[param_type] = param_type_required; + // obtain outputs shape + for (auto shape : output_shape) { + real_output[shape_str].push_back(shape); + } + real_output[valid] = true; + output_json.push_back(real_output); + output.push_back(output_json); + return output; +} + +nlohmann::json ConstructTransDataKernelJson(const std::vector &host_shape, + const std::vector &device_shape, const std::string &format, + mindspore::TypeId type) { + // generate kernel json + nlohmann::json kernel_json; + kernel_json[gen_model_str] = gen_model_single; + kernel_json[impl_path_str] = ""; + // construct op_info + nlohmann::json op_info; + op_info[attrs_str] = ConstructAttrs(format); + op_info[inputs_str] = ConstructInputs(device_shape, host_shape, format, type); + op_info[kernel_name_str] = ""; + op_info[name] = 
+  op_info[outputs_str] = ConstructOutputs(host_shape, type);
+  kernel_json[op_info_str] = op_info;
+  kernel_json[platform_str] = platform_tbe;
+  std::string json_str = kernel_json[op_info_str].dump();
+  size_t hash_id = std::hash<std::string>()(json_str);
+  const std::string op_name = op_info[name];
+  const std::string json_name = op_name + "_" + std::to_string(hash_id);
+  kernel_json[op_info_str][kernel_name_str] = json_name;
+  return kernel_json;
+}
+
 void AscendDeviceAddress::SyncStream() const {
   MS_LOG(INFO) << "Start!";
   auto ms_context = MsContext::GetInstance();
@@ -158,31 +311,186 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t
   return sync_ok;
 }
 
+void AscendDeviceAddress::LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, void *output_address_ptr,
+                                          size_t output_size, const std::vector<size_t> &workspace_size_list) const {
+  MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
+  auto input_address = std::make_shared<kernel::Address>();
+  MS_EXCEPTION_IF_NULL(input_address);
+  input_address->addr = ptr_;
+  input_address->size = size_;
+  auto output_address = std::make_shared<kernel::Address>();
+  MS_EXCEPTION_IF_NULL(output_address);
+  output_address->addr = output_address_ptr;
+  output_address->size = output_size;
+  AddressPtrList kernel_inputs = {input_address};
+  AddressPtrList kernel_outputs = {output_address};
+  AddressPtrList kernel_workspaces;
+  std::vector<void *> workspaces_address_ptr(workspace_size_list.size(), nullptr);
+  if (!workspace_size_list.empty()) {
+    for (size_t i = 0; i < workspace_size_list.size(); ++i) {
+      auto workspace_size = GetCommonAlignSize(workspace_size_list[i]);
+      auto ret_malloc = rtMalloc(&workspaces_address_ptr[i], workspace_size, RT_MEMORY_HBM);
+      if (ret_malloc != RT_ERROR_NONE) {
+        MS_LOG(ERROR) << "Failed to rtMalloc memory";
+      }
+      auto workspace_address = std::make_shared<kernel::Address>();
+      MS_EXCEPTION_IF_NULL(workspace_address);
+      workspace_address->addr = workspaces_address_ptr[i];
+      workspace_address->size = workspace_size;
+      kernel_workspaces.push_back(workspace_address);
+    }
+  }
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  auto device_id = ms_context->device_id();
+  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
+  MS_EXCEPTION_IF_NULL(runtime_instance);
+  auto ret =
+    runtime_instance->LaunchTaskBasedOnSingleKernel(kernel_mod_ptr, kernel_inputs, kernel_outputs, kernel_workspaces);
+  if (!ret) {
+    MS_LOG(ERROR) << "Launch kernel failed.";
+  }
+  SyncStream();
+  if (!workspace_size_list.empty()) {
+    for (size_t i = 0; i < workspace_size_list.size(); ++i) {
+      auto ret_free = rtFree(workspaces_address_ptr[i]);
+      if (ret_free != RT_ERROR_NONE) {
+        MS_LOG(ERROR) << "Failed to rtFree memory";
+      }
+    }
+  }
+}
+
+kernel::KernelModPtr AscendDeviceAddress::CompileTransDataAndObtainKernelMod(const nlohmann::json &kernel_json) const {
+  static std::set<std::string> constructed_kernel;
+  auto build_manager = std::make_shared<kernel::ParallelBuildManager>();
+  MS_EXCEPTION_IF_NULL(build_manager);
+  std::string processor = process_aicore;
+  // get size
+  std::vector<size_t> input_size_list;
+  std::vector<size_t> output_size_list;
+  (void)kernel::TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list);
+  std::string json_name = kernel_json[op_info_str][kernel_name_str];
+  // op build
+  if (constructed_kernel.find(json_name) == constructed_kernel.end()) {
+    auto task_id = build_manager->StartCompileOp(kernel_json);
+    build_manager->SaveTaskInfo(task_id, nullptr, json_name, input_size_list, output_size_list);
+  }
+  while (!build_manager->IsAllTaskFinish()) {
+    int task_id = -1;
+    char *task_result = nullptr;
+    char *pre_build_result = nullptr;
+    auto ret = build_manager->WaitOne(&task_id, &task_result, &pre_build_result);
+    if (!ret) {
+      MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id;
+    }
+    if ((task_result != nullptr) && (strcmp(task_result, "Success") != 0)) {
+      MS_EXCEPTION(ArgumentError) << "task compile Failed, task id:" << task_id << ", cause:" << task_result;
+    }
+    (void)build_manager->TaskFinishProcess(task_id, false);
+  }
+  constructed_kernel.insert(json_name);
+  // search cache
+  auto cached_kernel_pack = TbeUtils::SearchCache(json_name, processor);
+  MS_EXCEPTION_IF_NULL(cached_kernel_pack);
+  auto kernel_mod_ptr =
+    build_manager->GenKernelMod(json_name, processor, input_size_list, output_size_list, cached_kernel_pack);
+  return kernel_mod_ptr;
+}
+
+bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const std::vector<size_t> &host_shape,
+                                                                           const std::vector<size_t> &device_shape,
+                                                                           size_t size, mindspore::TypeId type,
+                                                                           void *host_ptr) const {
+  bool sync_ok = true;
+  // construct trans data kernel json
+  nlohmann::json kernel_json = ConstructTransDataKernelJson(host_shape, device_shape, format_, type_id_);
+  MS_LOG(INFO) << "Construct trans_data kernel json: " << kernel_json.dump();
+  auto kernel_mod_ptr = CompileTransDataAndObtainKernelMod(kernel_json);
+  MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
+  auto host_size = size;
+  if (type_id_ != type) {
+    auto device_dtype_size = trans::TypeIdSize(type_id_);
+    if (device_dtype_size < 1) {
+      MS_LOG(ERROR) << "Illegal dtype.";
+    }
+    auto shape_size = trans::ShapeSize(host_shape);
+    auto size_tmp = device_dtype_size * shape_size;
+    size = GetCommonAlignSize(size_tmp);
+  }
+  void *output_address_ptr = nullptr;
+  auto ret_malloc = rtMalloc(&output_address_ptr, size, RT_MEMORY_HBM);
+  if (ret_malloc != RT_ERROR_NONE) {
+    MS_LOG(ERROR) << "Failed to rtMalloc memory";
+  }
+  auto workspace_size_list = GetWorkspaceSizeList(kernel_json);
+  // launch
+  LaunchTransData(kernel_mod_ptr, output_address_ptr, size, workspace_size_list);
+  if (type_id_ == type) {
+    SyncMemory(host_ptr, output_address_ptr, size, RT_MEMCPY_DEVICE_TO_HOST);
+  } else {
+    auto host = std::vector<uint8_t>(size);
+    SyncMemory(host.data(), output_address_ptr, size, RT_MEMCPY_DEVICE_TO_HOST);
+    auto shape_size = trans::ShapeSize(host_shape);
+    const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, host_size};
+    sync_ok = trans::TransDataType(type_args, host_ptr);
+    if (!sync_ok) {
+      MS_LOG(ERROR) << "Trans format failed.";
+      return false;
+    }
+  }
+  auto ret_free = rtFree(output_address_ptr);
+  if (ret_free != RT_ERROR_NONE) {
+    MS_LOG(ERROR) << "Failed to rtFree memory";
+  }
+  return sync_ok;
+}
+
+std::vector<size_t> AscendDeviceAddress::GetWorkspaceSizeList(const nlohmann::json &kernel_json) const {
+  std::string json_name = kernel_json[op_info_str][kernel_name_str];
+  std::string processor = process_aicore;
+  auto cached_kernel_pack = TbeUtils::SearchCache(json_name, processor);
+  MS_EXCEPTION_IF_NULL(cached_kernel_pack);
+  auto kernel_json_info = cached_kernel_pack->kernel_json_info();
+  return kernel_json_info.workspaces;
+}
+
+std::vector<size_t> AscendDeviceAddress::GetDeviceShape(std::vector<size_t> *host_shape) const {
+  std::vector<size_t> device_shape;
+  if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) {
+    device_shape = trans::TransShapeToDevice(*host_shape, format_);
+  } else {
+    if (host_shape_.empty()) {
+      *host_shape = trans::PaddingShapeTo4d(*host_shape);
+    } else {
+      host_shape->clear();
+      (void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(*host_shape), IntToSize);
+    }
+    device_shape = trans::TransShapeToDevice(*host_shape, format_);
+  }
+  return device_shape;
+}
+
 bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int> &shape, size_t size,
                                                            mindspore::TypeId type, void *host_ptr) const {
   MS_LOG(INFO) << "SyncDeviceToHostAndConvertFormat, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_)
                << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")";
   bool sync_ok = false;
-  auto host_tmp = std::vector<uint8_t>(size_);
-  SyncMemory(host_tmp.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST);
   std::vector<size_t> host_shape;
   (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize);
-  std::vector<size_t> device_shape;
   if (host_shape.empty()) {
     host_shape.emplace_back(1);
   }
-  if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) {
-    device_shape = trans::TransShapeToDevice(host_shape, format_);
-  } else {
-    if (host_shape_.empty()) {
-      host_shape = trans::PaddingShapeTo4d(host_shape);
-    } else {
-      host_shape.clear();
-      (void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(host_shape), IntToSize);
+  std::vector<size_t> device_shape = GetDeviceShape(&host_shape);
+  if (type_id_name_map.find(type_id_) != type_id_name_map.end()) {
+    std::pair<std::string, std::string> type_format = std::make_pair(type_id_name_map.at(type_id_), format_);
+    if (use_trans_data.find(type_format) != use_trans_data.end()) {
+      sync_ok = SyncDeviceToHostAndConvertFormatBasedOnTransData(host_shape, device_shape, size, type, host_ptr);
+      return sync_ok;
     }
-
-    device_shape = trans::TransShapeToDevice(host_shape, format_);
   }
+  auto host_tmp = std::vector<uint8_t>(size_);
+  SyncMemory(host_tmp.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST);
   if (type_id_ != type) {
     const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_,
                                         host_shape, device_shape, type_id_};
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h
index b554e560cf773fc09efa65524c51113d0839666d..6f1bec2d1581afeb98be039c459c21eadbcc3dc1 100644
--- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h
@@ -20,9 +20,11 @@
 #include <string>
 #include <vector>
 #include <memory>
+#include <nlohmann/json.hpp>
 #include "runtime/device/device_address.h"
 #include "runtime/device/ascend/ascend_memory_pool.h"
 #include "ir/dtype.h"
+#include "backend/kernel_compiler/kernel.h"
 
 namespace mindspore {
 #ifdef ENABLE_DEBUGGER
@@ -53,7 +55,16 @@
   bool SyncDeviceToHostAndConvertFormat(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const;
   bool ConvertFormatAndSyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type,
                                         const void *host_ptr) const;
+  bool SyncDeviceToHostAndConvertFormatBasedOnTransData(const std::vector<size_t> &host_shape,
+                                                        const std::vector<size_t> &device_shape, size_t size,
+                                                        mindspore::TypeId type, void *host_ptr) const;
   void SyncStream() const;
+
+  void LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, void *output_address_ptr, size_t output_size,
+                       const std::vector<size_t> &workspace_size_list) const;
+  std::vector<size_t> GetDeviceShape(std::vector<size_t> *host_shape) const;
+  std::vector<size_t> GetWorkspaceSizeList(const nlohmann::json &kernel_json) const;
+  kernel::KernelModPtr CompileTransDataAndObtainKernelMod(const nlohmann::json &kernel_json) const;
 };
 using AscendDeviceAddressPtr = std::shared_ptr<AscendDeviceAddress>;
 }  // namespace ascend
diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc
index d6cce971c2565168866931eecda7e60bda1e7a39..77d19e6121710c14027ddd15ccad9fe2a5c12c1c 100644
--- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc
@@ -757,6 +757,18 @@ void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
   MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource";
 }
 
+bool KernelRuntime::LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, AddressPtrList kernel_inputs,
+                                                  AddressPtrList kernel_outputs,
+                                                  AddressPtrList kernel_workspaces) const {
+  MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
+  auto ret = kernel_mod_ptr->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
+  if (!ret) {
+    MS_LOG(ERROR) << "Launch kernel failed.";
+    return false;
+  }
+  return true;
+}
+
 #ifdef ENABLE_DUMP_E2E
 bool KernelRuntime::SetDumpConf() {
   dump_conf_ptr_ = std::make_shared<Dump>();
diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h
index 41cbd6f4e4912b3d5e7105df5aaf0b8a0f1f672d..e84b55f00778a6e4d9619d1543201cb2bcd04c8c 100644
--- a/mindspore/ccsrc/runtime/device/kernel_runtime.h
+++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h
@@ -61,6 +61,8 @@ class KernelRuntime {
   virtual bool RunTask(const session::KernelGraph *graph);
   virtual bool GenTask(const session::KernelGraph *graph);
   bool LaunchKernel(const session::KernelGraph *graph);
+  bool LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, AddressPtrList kernel_inputs,
+                                     AddressPtrList kernel_outputs, AddressPtrList kernel_workspaces) const;
   virtual void AssignStaticMemoryInput(const session::KernelGraph *graph);
   virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph);
   virtual void ClearGraphRuntimeResource(uint32_t graph_id);
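
Reviewer note (illustration, not part of the patch): the change gates the new device-side sync path on an exact (dtype, device format) pair lookup in use_trans_data, and rounds output/workspace allocations up with GetCommonAlignSize. Below is a minimal standalone C++ sketch of those two pure helpers, assuming kMemAlignSize is 512 as in the Ascend memory pool; the pair subset and all names are illustrative only.

// Standalone sketch of the gating and alignment logic above.
#include <cstddef>
#include <iostream>
#include <set>
#include <string>
#include <utility>

namespace {
constexpr size_t kMemAlignSize = 512;  // assumed value, mirroring the Ascend memory pool

// Same arithmetic as the patch: pad by up to 32 bytes, then round up to the boundary.
size_t GetCommonAlignSize(size_t input_size) {
  return (input_size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize;
}

// Illustrative subset of the (dtype, device format) pairs routed through trans_data.
const std::set<std::pair<std::string, std::string>> use_trans_data = {
  {"float16", "NC1HWC0"}, {"float32", "NC1HWC0"}, {"float16", "FRACTAL_NZ"}};

bool UseTransDataPath(const std::string &dtype, const std::string &format) {
  return use_trans_data.count({dtype, format}) != 0;
}
}  // namespace

int main() {
  std::cout << UseTransDataPath("float16", "NC1HWC0") << "\n";  // 1 -> launch trans_data on device
  std::cout << UseTransDataPath("int8", "NC1HWC0") << "\n";     // 0 -> fall back to host-side conversion
  std::cout << GetCommonAlignSize(1) << "\n";                   // 512
  std::cout << GetCommonAlignSize(512) << "\n";                 // 1024 (the pad crosses the boundary)
  return 0;
}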
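A second sketch, also illustrative and not part of the patch, of how ConstructTransDataKernelJson derives the kernel name: the op_info section is serialized while kernel_name is still empty, hashed with std::hash<std::string>, and the hash is appended to the op name. Repeated syncs with identical dtype/format/shape therefore produce the same json_name and can reuse the compiled kernel via the cache consulted in CompileTransDataAndObtainKernelMod. Field values below are made up; only nlohmann/json is required.

#include <functional>
#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

int main() {
  // Skeleton of the op_info section, kernel_name still empty, as in the patch.
  nlohmann::json op_info;
  op_info["name"] = "trans_data";
  op_info["kernel_name"] = "";
  op_info["inputs"] = nlohmann::json::array();
  op_info["inputs"].push_back({{"dtype", "float16"}, {"format", "NC1HWC0"}, {"shape", {1, 1, 16, 16, 16}}});
  op_info["outputs"] = nlohmann::json::array();
  op_info["outputs"].push_back({{"dtype", "float16"}, {"format", "NCHW"}, {"shape", {1, 3, 16, 16}}});

  // Hash the dump and append it to the op name; identical requests map to the
  // same json_name and hit the TBE build cache (SearchCache in the patch).
  const size_t hash_id = std::hash<std::string>()(op_info.dump());
  const std::string json_name = std::string("trans_data") + "_" + std::to_string(hash_id);
  op_info["kernel_name"] = json_name;
  std::cout << json_name << "\n";
  return 0;
}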