From 8814bec0c5275036e50f9bf12ceb67648419d04f Mon Sep 17 00:00:00 2001 From: emailweixu Date: Sun, 7 Jan 2018 19:54:54 -0800 Subject: [PATCH] Show argument dimensions with operator::DebugStringEx (#7268) This can make it easier to locate error. --- paddle/framework/executor.cc | 2 +- paddle/framework/operator.cc | 19 ++++++++++++++++++- paddle/framework/operator.h | 5 ++++- paddle/operators/net_op.cc | 6 +++--- paddle/operators/net_op.h | 3 ++- 5 files changed, 28 insertions(+), 7 deletions(-) diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index bf1f0471c..844d98916 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -111,7 +111,7 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, for (auto& op_desc : block.AllOps()) { auto op = paddle::framework::OpRegistry::CreateOp(*op_desc); - VLOG(3) << op->DebugString(); + VLOG(3) << op->DebugStringEx(local_scope); op->Run(*local_scope, place_); if (FLAGS_check_nan_inf) { for (auto& vname : op->OutputVars(true)) { diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index b9dcf16da..4ef0c2523 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -73,6 +73,17 @@ void UseALL() { UseCUDNN(); } +static DDim GetDims(const Scope& scope, const std::string& name) { + Variable* var = scope.FindVar(name); + if (var->IsType()) { + return var->Get().dims(); + } else if (var->IsType()) { + return var->Get().GetCompleteDims(); + } else { + return DDim({-1}); + } +} + std::string OperatorBase::Input(const std::string& name) const { auto& ins = Inputs(name); PADDLE_ENFORCE_LE(ins.size(), 1UL, @@ -105,7 +116,7 @@ const std::vector& OperatorBase::Outputs( return it->second; } -std::string OperatorBase::DebugString() const { +std::string OperatorBase::DebugStringEx(const Scope* scope) const { std::stringstream ss; ss << "Op(" << type_ << "), inputs:{"; for (auto it = inputs_.begin(); it != inputs_.end();) { @@ -113,6 +124,9 @@ std::string OperatorBase::DebugString() const { ss << input.first << "["; for (size_t i = 0; i < input.second.size(); ++i) { ss << input.second[i]; + if (scope) { + ss << "(" << GetDims(*scope, input.second[i]) << ")"; + } if (i != input.second.size() - 1) { ss << ", "; } @@ -129,6 +143,9 @@ std::string OperatorBase::DebugString() const { ss << output.first << "["; for (size_t i = 0; i < output.second.size(); ++i) { ss << output.second[i]; + if (scope) { + ss << "(" << GetDims(*scope, output.second[i]) << ")"; + } if (i != output.second.size() - 1) { ss << ", "; } diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 1f5a4af58..800397c07 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -108,7 +108,10 @@ class OperatorBase { return boost::get(attrs_.at(name)); } - virtual std::string DebugString() const; + /// if scope is not null, also show dimensions of arguments + virtual std::string DebugStringEx(const Scope* scope) const; + + std::string DebugString() const { return DebugStringEx(nullptr); } /// Net will call this function to Run an op. virtual void Run(const Scope& scope, const platform::Place& place) const = 0; diff --git a/paddle/operators/net_op.cc b/paddle/operators/net_op.cc index 78b5e2767..03302f5cb 100644 --- a/paddle/operators/net_op.cc +++ b/paddle/operators/net_op.cc @@ -56,11 +56,11 @@ void NetOp::CompleteAddOp(bool calc) { std::copy(output_set.begin(), output_set.end(), std::back_inserter(outputs)); } -std::string NetOp::DebugString() const { +std::string NetOp::DebugStringEx(const framework::Scope* scope) const { std::ostringstream os; - os << OperatorBase::DebugString() << std::endl; + os << OperatorBase::DebugStringEx(scope) << std::endl; for (auto& op : ops_) { - std::istringstream is(op->DebugString()); + std::istringstream is(op->DebugStringEx(scope)); for (std::string line; std::getline(is, line);) { os << " " << line << std::endl; } diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h index 85d0153b3..b24042f5e 100644 --- a/paddle/operators/net_op.h +++ b/paddle/operators/net_op.h @@ -106,7 +106,8 @@ class NetOp : public framework::OperatorBase { void CompleteAddOp(bool calculate = true); - std::string DebugString() const override; + std::string DebugStringEx( + const framework::Scope* scope = nullptr) const override; bool IsNetOp() const override; std::vector OutputVars(bool has_intermediate) const override; -- GitLab