Commit 12214124 authored by Siddharth Goyal, committed by Abhinav Arora

Fix cpplint for print_op (#10070)

* Fix print op cpplint errors

* Remove commented code
Parent 8113de94
@@ -23,15 +23,15 @@ namespace operators {
 
 #define CLOG std::cout
 
-const std::string kForward = "FORWARD";
-const std::string kBackward = "BACKWARD";
-const std::string kBoth = "BOTH";
+const char kForward[] = "FORWARD";
+const char kBackward[] = "BACKWARD";
+const char kBoth[] = "BOTH";
 
 struct Formater {
   std::string message;
   std::string name;
   std::vector<int> dims;
-  std::type_index dtype{typeid(char)};
+  std::type_index dtype{typeid(const char)};
   framework::LoD lod;
   int summarize;
   void* data{nullptr};
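Context for the first two changes: cpplint's `runtime/string` rule flags `const std::string` objects at namespace scope, because global `std::string` instances run constructors before `main()` and destructors after it, with unspecified ordering across translation units. A `const char[]` constant is initialized at compile time and has no such lifetime. A minimal standalone sketch of the difference (illustrative only, not part of this file):

```cpp
#include <string>

// Flagged by cpplint (runtime/string): dynamic initialization and
// destruction of a global object.
const std::string kOld = "FORWARD";

// Preferred: a compile-time constant with trivial destruction.
const char kNew[] = "FORWARD";

int main() {
  std::string phase(kNew);  // convert to std::string only where one is needed
  return (phase == kOld && phase == kNew) ? 0 : 1;  // exits 0
}
```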
@@ -62,7 +62,7 @@ struct Formater {
     }
   }
   void PrintDtype() {
-    if (dtype.hash_code() != typeid(char).hash_code()) {
+    if (dtype.hash_code() != typeid(const char).hash_code()) {
      CLOG << "\tdtype: " << dtype.name() << std::endl;
    }
  }
@@ -83,15 +83,15 @@ struct Formater {
   void PrintData(size_t size) {
     PADDLE_ENFORCE_NOT_NULL(data);
     // print float
-    if (dtype.hash_code() == typeid(float).hash_code()) {
+    if (dtype.hash_code() == typeid(const float).hash_code()) {
       Display<float>(size);
-    } else if (dtype.hash_code() == typeid(double).hash_code()) {
+    } else if (dtype.hash_code() == typeid(const double).hash_code()) {
       Display<double>(size);
-    } else if (dtype.hash_code() == typeid(int).hash_code()) {
+    } else if (dtype.hash_code() == typeid(const int).hash_code()) {
       Display<int>(size);
-    } else if (dtype.hash_code() == typeid(int64_t).hash_code()) {
+    } else if (dtype.hash_code() == typeid(const int64_t).hash_code()) {
       Display<int64_t>(size);
-    } else if (dtype.hash_code() == typeid(bool).hash_code()) {
+    } else if (dtype.hash_code() == typeid(const bool).hash_code()) {
       Display<bool>(size);
     } else {
       CLOG << "\tdata: unprintable type: " << dtype.name() << std::endl;
@@ -100,7 +100,7 @@ struct Formater {
 
   template <typename T>
   void Display(size_t size) {
-    auto* d = (T*)data;
+    auto* d = reinterpret_cast<T*>(data);
     CLOG << "\tdata: ";
     if (summarize != -1) {
       summarize = std::min(size, (size_t)summarize);
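The cast change fixes cpplint's `readability/casting` warning: the C-style `(T*)data` becomes a named `reinterpret_cast`, which keeps the runtime behavior but makes the unchecked pointer reinterpretation explicit and easy to grep for. A minimal sketch with hypothetical names (not from the op):

```cpp
#include <iostream>

int main() {
  int value = 42;
  void* data = &value;  // type-erased pointer, as in Formater::data
  auto* d = reinterpret_cast<int*>(data);  // replaces the C-style (int*)data
  std::cout << *d << std::endl;  // prints 42
  return 0;
}
```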
@@ -135,7 +135,7 @@ class TensorPrintOp : public framework::OperatorBase {
   void RunImpl(const framework::Scope& scope,
                const platform::Place& place) const override {
     const framework::Variable* in_var_ptr = nullptr;
-    std::string phase = kForward;
+    std::string phase(kForward);
     std::string printed_var_name = "";
 
     auto& inputs = Inputs();
...@@ -146,7 +146,7 @@ class TensorPrintOp : public framework::OperatorBase { ...@@ -146,7 +146,7 @@ class TensorPrintOp : public framework::OperatorBase {
!Inputs("In@GRAD").empty()) { !Inputs("In@GRAD").empty()) {
in_var_ptr = scope.FindVar(Input("In@GRAD")); in_var_ptr = scope.FindVar(Input("In@GRAD"));
printed_var_name = Inputs("In@GRAD").front(); printed_var_name = Inputs("In@GRAD").front();
phase = kBackward; phase = std::string(kBackward);
} else { } else {
PADDLE_THROW("Unknown phase, should be forward or backward."); PADDLE_THROW("Unknown phase, should be forward or backward.");
} }
@@ -163,7 +163,7 @@ class TensorPrintOp : public framework::OperatorBase {
     out_tensor.set_lod(in_tensor.lod());
 
     std::string print_phase = Attr<std::string>("print_phase");
-    if (print_phase != phase && print_phase != kBoth) {
+    if (print_phase != phase && print_phase != std::string(kBoth)) {
       return;
     }
 
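The three `phase`-related edits above follow from the constants now being `const char` arrays: a `std::string` remains constructible, assignable, and comparable from them implicitly (`std::string`'s comparison operators accept C strings directly), so the explicit `std::string(...)` wrappers only mark the conversion points. A compilable sketch (standalone, not from the commit):

```cpp
#include <string>

const char kForward[] = "FORWARD";
const char kBoth[] = "BOTH";

int main() {
  std::string phase(kForward);    // direct-initialization from a char array
  phase = std::string(kForward);  // explicit temporary, as in RunImpl
  // operator!= accepts a C string on either side, wrapped or not:
  bool skip = (phase != kForward) && (phase != std::string(kBoth));
  return skip ? 1 : 0;  // exits 0
}
```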
@@ -199,7 +199,7 @@ class TensorPrintOp : public framework::OperatorBase {
       formater.lod = printed_tensor.lod();
     }
     formater.summarize = Attr<int>("summarize");
-    formater.data = (void*)printed_tensor.data<void>();
+    formater.data = reinterpret_cast<void*>(printed_tensor.data<void>());
     formater(printed_tensor.numel());
   }
 
@@ -223,8 +223,9 @@ class PrintOpProtoAndCheckMaker : public framework::OpProtoAndCheckerMaker {
               "print_phase",
               "(string, default 'BOTH') Which phase to display including 'FORWARD' "
               "'BACKWARD' and 'BOTH'.")
-        .SetDefault(kBoth)
-        .InEnum({kForward, kBackward, kBoth});
+        .SetDefault(std::string(kBoth))
+        .InEnum({std::string(kForward), std::string(kBackward),
+                 std::string(kBoth)});
     AddOutput("Out", "Output tensor with same data as input tensor.");
     AddComment(R"DOC(
 Creates a print op that will print when a tensor is accessed.
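The wrapping in `SetDefault`/`InEnum` is plausibly forced by template argument deduction rather than style: a braced list of `const char` arrays deduces `const char*`, not `std::string`. PaddlePaddle's actual `InEnum` signature is not shown in this diff, so the sketch below uses a hypothetical `MakeEnum` helper to illustrate the deduction issue:

```cpp
#include <initializer_list>
#include <string>
#include <unordered_set>

const char kForward[] = "FORWARD";
const char kBoth[] = "BOTH";

// Hypothetical stand-in for a checker templated on the attribute type.
template <typename T>
std::unordered_set<T> MakeEnum(std::initializer_list<T> values) {
  return std::unordered_set<T>(values);
}

int main() {
  // MakeEnum({kForward, kBoth}) would deduce T = const char*;
  // wrapping each constant keeps the attribute type as std::string.
  auto allowed = MakeEnum({std::string(kForward), std::string(kBoth)});
  return allowed.count("FORWARD") == 1 ? 0 : 1;  // exits 0
}
```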
......