From 343017324ec731ab4ff2b185aff845c02b43bd12 Mon Sep 17 00:00:00 2001 From: chengduo Date: Mon, 27 May 2019 11:17:24 +0800 Subject: [PATCH] Polish Print Op (#17651) * enhance print --- paddle/fluid/operators/print_op.cc | 52 +++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/operators/print_op.cc b/paddle/fluid/operators/print_op.cc index 6a5bf170600..200b01797e4 100644 --- a/paddle/fluid/operators/print_op.cc +++ b/paddle/fluid/operators/print_op.cc @@ -26,6 +26,19 @@ const char kForward[] = "FORWARD"; const char kBackward[] = "BACKWARD"; const char kBoth[] = "BOTH"; +class LogGuard { + public: + inline LogGuard() { LogMutex().lock(); } + + inline ~LogGuard() { LogMutex().unlock(); } + + private: + static std::mutex &LogMutex() { + static std::mutex mtx; + return mtx; + } +}; + struct Formater { std::string message; std::string name; @@ -34,48 +47,54 @@ struct Formater { framework::LoD lod; int summarize; void *data{nullptr}; + platform::Place place; + std::stringstream logs; void operator()(size_t size) { PrintMessage(); + PrintPlaceInfo(); PrintName(); PrintDims(); PrintDtype(); PrintLod(); PrintData(size); + LogGuard guard; + CLOG << logs.str(); } private: - void PrintMessage() { CLOG << std::time(nullptr) << "\t" << message << "\t"; } + void PrintPlaceInfo() { logs << "The place is:" << place << std::endl; } + void PrintMessage() { logs << std::time(nullptr) << "\t" << message << "\t"; } void PrintName() { if (!name.empty()) { - CLOG << "Tensor[" << name << "]" << std::endl; + logs << "Tensor[" << name << "]" << std::endl; } } void PrintDims() { if (!dims.empty()) { - CLOG << "\tshape: ["; + logs << "\tshape: ["; for (auto i : dims) { - CLOG << i << ","; + logs << i << ","; } - CLOG << "]" << std::endl; + logs << "]" << std::endl; } } void PrintDtype() { if (!framework::IsType(dtype)) { - CLOG << "\tdtype: " << dtype.name() << std::endl; + logs << "\tdtype: " << dtype.name() << std::endl; } } void PrintLod() { if (!lod.empty()) { - CLOG << "\tLoD: ["; + logs << "\tLoD: ["; for (auto level : lod) { - CLOG << "[ "; + logs << "[ "; for (auto i : level) { - CLOG << i << ","; + logs << i << ","; } - CLOG << " ]"; + logs << " ]"; } - CLOG << "]" << std::endl; + logs << "]" << std::endl; } } @@ -93,25 +112,25 @@ struct Formater { } else if (framework::IsType(dtype)) { Display(size); } else { - CLOG << "\tdata: unprintable type: " << dtype.name() << std::endl; + logs << "\tdata: unprintable type: " << dtype.name() << std::endl; } } template void Display(size_t size) { auto *d = reinterpret_cast(data); - CLOG << "\tdata: "; + logs << "\tdata: "; if (summarize != -1) { summarize = std::min(size, (size_t)summarize); for (int i = 0; i < summarize; i++) { - CLOG << d[i] << ","; + logs << d[i] << ","; } } else { for (size_t i = 0; i < size; i++) { - CLOG << d[i] << ","; + logs << d[i] << ","; } } - CLOG << std::endl; + logs << std::endl; } }; @@ -167,6 +186,7 @@ class TensorPrintOp : public framework::OperatorBase { } Formater formater; + formater.place = place; formater.message = Attr("message"); if (Attr("print_tensor_name")) { formater.name = printed_var_name; -- GitLab