提交 a659b37a 编写于 作者: D dongdaxiang

make lodtensor_printer usable in gpu setting

test=develop
上级 ceac9df8
......@@ -52,16 +52,26 @@ void PrintVar(framework::Scope* scope, const std::string& var_name,
return;
}
framework::LoDTensor printed_tensor;
printed_tensor.set_lod(tensor->lod());
printed_tensor.Resize(tensor->dims());
if (platform::is_cpu_place(tensor->place())) {
printed_tensor.ShareDataWith(*tensor);
} else {
platform::CPUPlace place;
framework::TensorCopy(*tensor, place, &printed_tensor);
}
#define PrintLoDTensorCallback(cpp_type, proto_type) \
do { \
if (tensor->type() == proto_type) { \
print_lod_tensor<cpp_type>(var_name, *tensor, print_info); \
print_lod_tensor<cpp_type>(var_name, printed_tensor, print_info); \
return; \
} \
} while (0)
_ForEachDataType_(PrintLoDTensorCallback);
VLOG(1) << "PrintVar: unrecognized data type:" << tensor->type();
VLOG(1) << "PrintVar: unrecognized data type:" << printed_tensor.type();
}
} // end namespace platform
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册