提交 593c4fc7 编写于 作者: D dengwentao

fix shape used for dump

上级 6721541c
......@@ -151,14 +151,18 @@ void DumpOutput(mindspore::session::KernelGraph *graph, const string &dump_path,
auto output_size = AnfAlgo::GetOutputTensorNum(node);
for (size_t j = 0; j < output_size; ++j) {
auto addr = AnfAlgo::GetOutputAddr(node, j);
auto shape = trans::GetRuntimePaddingShape(node, j);
std::vector<int> int_shapes;
if (trans_flag) {
int_shapes = trans::GetRuntimePaddingShape(node, j);
} else {
auto shape = AnfAlgo::GetOutputDeviceShape(node, j);
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
[](size_t inner_item) { return SizeToInt(inner_item); });
}
auto type = AnfAlgo::GetOutputInferDataType(node, j);
auto format = kOpFormat_DEFAULT;
string filepath = dump_path + '/' + kernel_name + '_' + "output_" + std::to_string(j);
auto ascend_addr = dynamic_cast<const mindspore::device::ascend::AscendDeviceAddress *>(addr);
std::vector<int> int_shapes;
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
[](size_t inner_item) { return SizeToInt(inner_item); });
auto ret = ascend_addr->DumpMemToFile(trans_flag, filepath, format, int_shapes, type);
if (!ret) {
MS_LOG(ERROR) << "DumpMemToFile Failed: flag:" << trans_flag << ", path:" << filepath
......@@ -182,14 +186,18 @@ void DumpParameters(mindspore::session::KernelGraph *graph, const string &dump_p
continue;
}
auto addr = AnfAlgo::GetOutputAddr(item, PRAMATER_OUTPUT_INDEX);
auto shape = trans::GetRuntimePaddingShape(item, PRAMATER_OUTPUT_INDEX);
std::vector<int> int_shapes;
if (trans_flag) {
int_shapes = trans::GetRuntimePaddingShape(item, PRAMATER_OUTPUT_INDEX);
} else {
auto shape = AnfAlgo::GetOutputDeviceShape(item, PRAMATER_OUTPUT_INDEX);
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
[](size_t inner_item) { return SizeToInt(inner_item); });
}
auto type = AnfAlgo::GetOutputInferDataType(item, PRAMATER_OUTPUT_INDEX);
auto format = kOpFormat_DEFAULT;
string filepath = dump_path + '/' + parameter_name + '_' + "output_0";
auto ascend_addr = dynamic_cast<const mindspore::device::ascend::AscendDeviceAddress *>(addr);
std::vector<int> int_shapes;
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
[](size_t inner_item) { return SizeToInt(inner_item); });
auto ret = ascend_addr->DumpMemToFile(trans_flag, filepath, format, int_shapes, type);
if (!ret) {
MS_LOG(ERROR) << "DumpMemToFile Failed: flag:" << trans_flag << ", path:" << filepath
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册