提交 8814bec0 编写于 作者: E emailweixu 提交者: dzhwinter

Show argument dimensions with operator::DebugStringEx (#7268)

This can make it easier to locate error.
上级 a9f7cd34
......@@ -111,7 +111,7 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
VLOG(3) << op->DebugString();
VLOG(3) << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_);
if (FLAGS_check_nan_inf) {
for (auto& vname : op->OutputVars(true)) {
......
......@@ -73,6 +73,17 @@ void UseALL() {
UseCUDNN();
}
static DDim GetDims(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name);
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims();
} else if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().GetCompleteDims();
} else {
return DDim({-1});
}
}
std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(ins.size(), 1UL,
......@@ -105,7 +116,7 @@ const std::vector<std::string>& OperatorBase::Outputs(
return it->second;
}
std::string OperatorBase::DebugString() const {
std::string OperatorBase::DebugStringEx(const Scope* scope) const {
std::stringstream ss;
ss << "Op(" << type_ << "), inputs:{";
for (auto it = inputs_.begin(); it != inputs_.end();) {
......@@ -113,6 +124,9 @@ std::string OperatorBase::DebugString() const {
ss << input.first << "[";
for (size_t i = 0; i < input.second.size(); ++i) {
ss << input.second[i];
if (scope) {
ss << "(" << GetDims(*scope, input.second[i]) << ")";
}
if (i != input.second.size() - 1) {
ss << ", ";
}
......@@ -129,6 +143,9 @@ std::string OperatorBase::DebugString() const {
ss << output.first << "[";
for (size_t i = 0; i < output.second.size(); ++i) {
ss << output.second[i];
if (scope) {
ss << "(" << GetDims(*scope, output.second[i]) << ")";
}
if (i != output.second.size() - 1) {
ss << ", ";
}
......
......@@ -108,7 +108,10 @@ class OperatorBase {
return boost::get<T>(attrs_.at(name));
}
virtual std::string DebugString() const;
/// if scope is not null, also show dimensions of arguments
virtual std::string DebugStringEx(const Scope* scope) const;
std::string DebugString() const { return DebugStringEx(nullptr); }
/// Net will call this function to Run an op.
virtual void Run(const Scope& scope, const platform::Place& place) const = 0;
......
......@@ -56,11 +56,11 @@ void NetOp::CompleteAddOp(bool calc) {
std::copy(output_set.begin(), output_set.end(), std::back_inserter(outputs));
}
std::string NetOp::DebugString() const {
std::string NetOp::DebugStringEx(const framework::Scope* scope) const {
std::ostringstream os;
os << OperatorBase::DebugString() << std::endl;
os << OperatorBase::DebugStringEx(scope) << std::endl;
for (auto& op : ops_) {
std::istringstream is(op->DebugString());
std::istringstream is(op->DebugStringEx(scope));
for (std::string line; std::getline(is, line);) {
os << " " << line << std::endl;
}
......
......@@ -106,7 +106,8 @@ class NetOp : public framework::OperatorBase {
void CompleteAddOp(bool calculate = true);
std::string DebugString() const override;
std::string DebugStringEx(
const framework::Scope* scope = nullptr) const override;
bool IsNetOp() const override;
std::vector<std::string> OutputVars(bool has_intermediate) const override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册