提交 a8b2c5a0 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1335 check memcpy_async task src size and dst size

Merge pull request !1335 from caifubi/get-device-datatype-in-memcpy
......@@ -137,8 +137,9 @@ bool TaskGenerator::LaunchAllKernel(const std::vector<CNodePtr> &anf_node_list,
for (const auto &anf_node_ptr : anf_node_list) {
size_t old_size = task_info_list->size();
uint32_t stream_id = AnfAlgo::GetStreamId(anf_node_ptr);
MS_EXCEPTION_IF_NULL(anf_node_ptr);
MS_LOG(INFO) << "Task gen launch begin, current_op_idx:" << current_op_index
<< " type:" << (AnfAlgo::GetCNodeName(anf_node_ptr)) << ", stream id:" << stream_id;
<< " name:" << anf_node_ptr->fullname_with_scope() << ", stream id:" << stream_id;
if (!LaunchKernel(anf_node_ptr, stream_id, task_info_list)) {
MS_LOG(ERROR) << "LaunchKernel failed.";
return false;
......
......@@ -658,31 +658,10 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
#if defined(_WIN32) || defined(_WIN64)
auto start_time = std::chrono::steady_clock::now();
#else
struct timeval start_time, end_time;
(void)gettimeofday(&start_time, nullptr);
#endif
auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
if (!ret) {
MS_LOG(ERROR) << "Launch kernel failed.";
return false;
} else {
if (AnfAlgo::GetKernelType(kernel) == TBE_KERNEL && !SyncStream()) {
MS_LOG(EXCEPTION) << "SyncStream failed.";
}
#if defined(_WIN32) || defined(_WIN64)
auto end_time = std::chrono::steady_clock::now();
std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time;
MS_LOG(DEBUG) << "d " << kernel->fullname_with_scope() << " in " << cost.count() << " us";
#else
(void)gettimeofday(&end_time, nullptr);
const uint64_t kUSecondInSecond = 1000000;
uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
MS_LOG(DEBUG) << "d " << kernel->fullname_with_scope() << " in " << cost << " us";
#endif
}
}
return true;
......
......@@ -48,6 +48,13 @@ bool MemCpyAsyncKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
MS_LOG(INFO) << "input addr is same with output addr , no need exe memcpy async";
return true;
}
if (outputs[0]->size < inputs[0]->size) {
MS_LOG(EXCEPTION) << "rtMemcpyAsync destMax < src size";
}
// input x -> memcpy_async -> AllReduce
if (outputs[0]->size > inputs[0]->size) {
MS_LOG(WARNING) << "rtMemcpyAsync destMax > src size";
}
rtError_t status = rtMemcpyAsync(outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size,
RT_MEMCPY_DEVICE_TO_DEVICE, stream_ptr);
if (status != RT_ERROR_NONE) {
......@@ -70,7 +77,7 @@ void MemCpyAsyncKernel::GetInputOutputDataType(const AnfNodePtr &anf_node) {
if (input_size != 1) {
MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1";
}
input_type_id_ = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, 0);
input_type_id_ = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, 0);
}
void MemCpyAsyncKernel::GetInputOutputTotalCount(const AnfNodePtr &anf_node) {
......@@ -102,6 +109,14 @@ std::vector<TaskInfoPtr> MemCpyAsyncKernel::GenTask(const std::vector<AddressPtr
MS_LOG(EXCEPTION) << "MemCpyAsync op output is not one";
}
if (outputs[0]->size < inputs[0]->size) {
MS_LOG(EXCEPTION) << "rtMemcpyAsync destMax < src size";
}
// input x -> memcpy_async -> AllReduce
if (outputs[0]->size > inputs[0]->size) {
MS_LOG(WARNING) << "rtMemcpyAsync destMax > src size";
}
stream_id_ = stream_id;
std::shared_ptr<MemcpyAsyncTaskInfo> task_info_ptr = std::make_shared<MemcpyAsyncTaskInfo>(
stream_id, outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size, RT_MEMCPY_DEVICE_TO_DEVICE);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册