未验证 提交 24e24987 编写于 作者: M mapingshuo 提交者: GitHub

fixes the place info in the Print op (#24934)

fixes the CUDAPlace info in the Print op
上级 6be0ee15
......@@ -73,18 +73,6 @@ class PrintOp : public framework::OperatorBase {
int first_n = Attr<int>("first_n");
if (first_n > 0 && ++times_ > first_n) return;
framework::LoDTensor printed_tensor;
printed_tensor.set_lod(in_tensor.lod());
printed_tensor.Resize(in_tensor.dims());
if (is_cpu_place(in_tensor.place())) {
printed_tensor.ShareDataWith(in_tensor);
} else {
// copy data to cpu to print
platform::CPUPlace place;
TensorCopy(in_tensor, place, &printed_tensor);
}
TensorFormatter formatter;
const std::string &name =
Attr<bool>("print_tensor_name") ? printed_var_name : "";
......@@ -93,7 +81,7 @@ class PrintOp : public framework::OperatorBase {
formatter.SetPrintTensorLod(Attr<bool>("print_tensor_lod"));
formatter.SetPrintTensorLayout(Attr<bool>("print_tensor_layout"));
formatter.SetSummarize(static_cast<int64_t>(Attr<int>("summarize")));
formatter.Print(printed_tensor, name, Attr<std::string>("message"));
formatter.Print(in_tensor, name, Attr<std::string>("message"));
}
private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册