未验证 提交 34301732 编写于 作者: C chengduo 提交者: GitHub

Polish Print Op (#17651)

* enhance print
上级 4aa931dd
......@@ -26,6 +26,19 @@ const char kForward[] = "FORWARD";
const char kBackward[] = "BACKWARD";
const char kBoth[] = "BOTH";
class LogGuard {
public:
inline LogGuard() { LogMutex().lock(); }
inline ~LogGuard() { LogMutex().unlock(); }
private:
static std::mutex &LogMutex() {
static std::mutex mtx;
return mtx;
}
};
struct Formater {
std::string message;
std::string name;
......@@ -34,48 +47,54 @@ struct Formater {
framework::LoD lod;
int summarize;
void *data{nullptr};
platform::Place place;
std::stringstream logs;
void operator()(size_t size) {
PrintMessage();
PrintPlaceInfo();
PrintName();
PrintDims();
PrintDtype();
PrintLod();
PrintData(size);
LogGuard guard;
CLOG << logs.str();
}
private:
void PrintMessage() { CLOG << std::time(nullptr) << "\t" << message << "\t"; }
void PrintPlaceInfo() { logs << "The place is:" << place << std::endl; }
void PrintMessage() { logs << std::time(nullptr) << "\t" << message << "\t"; }
void PrintName() {
if (!name.empty()) {
CLOG << "Tensor[" << name << "]" << std::endl;
logs << "Tensor[" << name << "]" << std::endl;
}
}
void PrintDims() {
if (!dims.empty()) {
CLOG << "\tshape: [";
logs << "\tshape: [";
for (auto i : dims) {
CLOG << i << ",";
logs << i << ",";
}
CLOG << "]" << std::endl;
logs << "]" << std::endl;
}
}
void PrintDtype() {
if (!framework::IsType<const char>(dtype)) {
CLOG << "\tdtype: " << dtype.name() << std::endl;
logs << "\tdtype: " << dtype.name() << std::endl;
}
}
void PrintLod() {
if (!lod.empty()) {
CLOG << "\tLoD: [";
logs << "\tLoD: [";
for (auto level : lod) {
CLOG << "[ ";
logs << "[ ";
for (auto i : level) {
CLOG << i << ",";
logs << i << ",";
}
CLOG << " ]";
logs << " ]";
}
CLOG << "]" << std::endl;
logs << "]" << std::endl;
}
}
......@@ -93,25 +112,25 @@ struct Formater {
} else if (framework::IsType<const bool>(dtype)) {
Display<bool>(size);
} else {
CLOG << "\tdata: unprintable type: " << dtype.name() << std::endl;
logs << "\tdata: unprintable type: " << dtype.name() << std::endl;
}
}
template <typename T>
void Display(size_t size) {
auto *d = reinterpret_cast<T *>(data);
CLOG << "\tdata: ";
logs << "\tdata: ";
if (summarize != -1) {
summarize = std::min(size, (size_t)summarize);
for (int i = 0; i < summarize; i++) {
CLOG << d[i] << ",";
logs << d[i] << ",";
}
} else {
for (size_t i = 0; i < size; i++) {
CLOG << d[i] << ",";
logs << d[i] << ",";
}
}
CLOG << std::endl;
logs << std::endl;
}
};
......@@ -167,6 +186,7 @@ class TensorPrintOp : public framework::OperatorBase {
}
Formater formater;
formater.place = place;
formater.message = Attr<std::string>("message");
if (Attr<bool>("print_tensor_name")) {
formater.name = printed_var_name;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册