提交 8814bec0 编写于 作者: E emailweixu 提交者: dzhwinter

Show argument dimensions with operator::DebugStringEx (#7268)

This can make it easier to locate error.
上级 a9f7cd34
...@@ -111,7 +111,7 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, ...@@ -111,7 +111,7 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
for (auto& op_desc : block.AllOps()) { for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc); auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
VLOG(3) << op->DebugString(); VLOG(3) << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_); op->Run(*local_scope, place_);
if (FLAGS_check_nan_inf) { if (FLAGS_check_nan_inf) {
for (auto& vname : op->OutputVars(true)) { for (auto& vname : op->OutputVars(true)) {
......
...@@ -73,6 +73,17 @@ void UseALL() { ...@@ -73,6 +73,17 @@ void UseALL() {
UseCUDNN(); UseCUDNN();
} }
static DDim GetDims(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name);
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims();
} else if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().GetCompleteDims();
} else {
return DDim({-1});
}
}
std::string OperatorBase::Input(const std::string& name) const { std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name); auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(ins.size(), 1UL, PADDLE_ENFORCE_LE(ins.size(), 1UL,
...@@ -105,7 +116,7 @@ const std::vector<std::string>& OperatorBase::Outputs( ...@@ -105,7 +116,7 @@ const std::vector<std::string>& OperatorBase::Outputs(
return it->second; return it->second;
} }
std::string OperatorBase::DebugString() const { std::string OperatorBase::DebugStringEx(const Scope* scope) const {
std::stringstream ss; std::stringstream ss;
ss << "Op(" << type_ << "), inputs:{"; ss << "Op(" << type_ << "), inputs:{";
for (auto it = inputs_.begin(); it != inputs_.end();) { for (auto it = inputs_.begin(); it != inputs_.end();) {
...@@ -113,6 +124,9 @@ std::string OperatorBase::DebugString() const { ...@@ -113,6 +124,9 @@ std::string OperatorBase::DebugString() const {
ss << input.first << "["; ss << input.first << "[";
for (size_t i = 0; i < input.second.size(); ++i) { for (size_t i = 0; i < input.second.size(); ++i) {
ss << input.second[i]; ss << input.second[i];
if (scope) {
ss << "(" << GetDims(*scope, input.second[i]) << ")";
}
if (i != input.second.size() - 1) { if (i != input.second.size() - 1) {
ss << ", "; ss << ", ";
} }
...@@ -129,6 +143,9 @@ std::string OperatorBase::DebugString() const { ...@@ -129,6 +143,9 @@ std::string OperatorBase::DebugString() const {
ss << output.first << "["; ss << output.first << "[";
for (size_t i = 0; i < output.second.size(); ++i) { for (size_t i = 0; i < output.second.size(); ++i) {
ss << output.second[i]; ss << output.second[i];
if (scope) {
ss << "(" << GetDims(*scope, output.second[i]) << ")";
}
if (i != output.second.size() - 1) { if (i != output.second.size() - 1) {
ss << ", "; ss << ", ";
} }
......
...@@ -108,7 +108,10 @@ class OperatorBase { ...@@ -108,7 +108,10 @@ class OperatorBase {
return boost::get<T>(attrs_.at(name)); return boost::get<T>(attrs_.at(name));
} }
virtual std::string DebugString() const; /// if scope is not null, also show dimensions of arguments
virtual std::string DebugStringEx(const Scope* scope) const;
std::string DebugString() const { return DebugStringEx(nullptr); }
/// Net will call this function to Run an op. /// Net will call this function to Run an op.
virtual void Run(const Scope& scope, const platform::Place& place) const = 0; virtual void Run(const Scope& scope, const platform::Place& place) const = 0;
......
...@@ -56,11 +56,11 @@ void NetOp::CompleteAddOp(bool calc) { ...@@ -56,11 +56,11 @@ void NetOp::CompleteAddOp(bool calc) {
std::copy(output_set.begin(), output_set.end(), std::back_inserter(outputs)); std::copy(output_set.begin(), output_set.end(), std::back_inserter(outputs));
} }
std::string NetOp::DebugString() const { std::string NetOp::DebugStringEx(const framework::Scope* scope) const {
std::ostringstream os; std::ostringstream os;
os << OperatorBase::DebugString() << std::endl; os << OperatorBase::DebugStringEx(scope) << std::endl;
for (auto& op : ops_) { for (auto& op : ops_) {
std::istringstream is(op->DebugString()); std::istringstream is(op->DebugStringEx(scope));
for (std::string line; std::getline(is, line);) { for (std::string line; std::getline(is, line);) {
os << " " << line << std::endl; os << " " << line << std::endl;
} }
......
...@@ -106,7 +106,8 @@ class NetOp : public framework::OperatorBase { ...@@ -106,7 +106,8 @@ class NetOp : public framework::OperatorBase {
void CompleteAddOp(bool calculate = true); void CompleteAddOp(bool calculate = true);
std::string DebugString() const override; std::string DebugStringEx(
const framework::Scope* scope = nullptr) const override;
bool IsNetOp() const override; bool IsNetOp() const override;
std::vector<std::string> OutputVars(bool has_intermediate) const override; std::vector<std::string> OutputVars(bool has_intermediate) const override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册