提交 a659b37a 编写于 作者: D dongdaxiang

make lodtensor_printer usable in gpu setting

test=develop
上级 ceac9df8
...@@ -52,16 +52,26 @@ void PrintVar(framework::Scope* scope, const std::string& var_name, ...@@ -52,16 +52,26 @@ void PrintVar(framework::Scope* scope, const std::string& var_name,
return; return;
} }
framework::LoDTensor printed_tensor;
printed_tensor.set_lod(tensor->lod());
printed_tensor.Resize(tensor->dims());
if (platform::is_cpu_place(tensor->place())) {
printed_tensor.ShareDataWith(*tensor);
} else {
platform::CPUPlace place;
framework::TensorCopy(*tensor, place, &printed_tensor);
}
#define PrintLoDTensorCallback(cpp_type, proto_type) \ #define PrintLoDTensorCallback(cpp_type, proto_type) \
do { \ do { \
if (tensor->type() == proto_type) { \ if (tensor->type() == proto_type) { \
print_lod_tensor<cpp_type>(var_name, *tensor, print_info); \ print_lod_tensor<cpp_type>(var_name, printed_tensor, print_info); \
return; \ return; \
} \ } \
} while (0) } while (0)
_ForEachDataType_(PrintLoDTensorCallback); _ForEachDataType_(PrintLoDTensorCallback);
VLOG(1) << "PrintVar: unrecognized data type:" << tensor->type(); VLOG(1) << "PrintVar: unrecognized data type:" << printed_tensor.type();
} }
} // end namespace platform } // end namespace platform
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册