From 3011949033fe727912c1f82f10272cabfcba64ab Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Mon, 27 Jul 2020 13:25:07 +0800 Subject: [PATCH] [Cherry-pick] Append error op hint for GradOpMaker and Dygraph (#25704) * Append error op hint for GradOpMaker (#24750) * append error op hint for grad op maker, test=develop * add unittests for coverage, test=develop * append try-catch to opbase run, test=develop (#24870) --- paddle/fluid/framework/CMakeLists.txt | 1 + paddle/fluid/framework/grad_op_desc_maker.h | 20 ++++++- paddle/fluid/framework/op_call_stack.cc | 12 +++- paddle/fluid/framework/op_call_stack.h | 7 +++ paddle/fluid/framework/op_call_stack_test.cc | 61 ++++++++++++++++++++ paddle/fluid/imperative/dygraph_grad_maker.h | 2 + paddle/fluid/imperative/tracer.cc | 18 +++++- 7 files changed, 115 insertions(+), 6 deletions(-) create mode 100644 paddle/fluid/framework/op_call_stack_test.cc diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 361f25d09e7..5e77ac4c6a3 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -148,6 +148,7 @@ cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) cc_library(op_call_stack SRCS op_call_stack.cc DEPS op_proto_maker enforce) +cc_test(op_call_stack_test SRCS op_call_stack_test.cc DEPS op_call_stack) nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) diff --git a/paddle/fluid/framework/grad_op_desc_maker.h b/paddle/fluid/framework/grad_op_desc_maker.h index 8d55c79a0dd..b1ca10c6155 100644 --- a/paddle/fluid/framework/grad_op_desc_maker.h +++ b/paddle/fluid/framework/grad_op_desc_maker.h @@ -18,7 +18,9 @@ limitations under the License. */ #include #include #include +#include #include +#include "paddle/fluid/framework/op_call_stack.h" #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/imperative/dygraph_grad_maker.h" @@ -195,7 +197,14 @@ class SingleGradOpMaker : public GradOpDescMakerBase { std::vector> operator()() const { std::vector> retv; retv.emplace_back(new OpDesc()); - this->Apply(retv.front().get()); + try { + this->Apply(retv.front().get()); + } catch (platform::EnforceNotMet& exception) { + framework::AppendErrorOpHint(retv.front().get()->Type(), &exception); + throw std::move(exception); + } catch (...) { + std::rethrow_exception(std::current_exception()); + } return retv; } @@ -213,7 +222,14 @@ class SingleGradOpMaker auto node = this->NewGradNode(); { imperative::TracedGradOp traced_grad_op(node); - this->Apply(&traced_grad_op); + try { + this->Apply(&traced_grad_op); + } catch (platform::EnforceNotMet& exception) { + framework::AppendErrorOpHint(traced_grad_op.Type(), &exception); + throw std::move(exception); + } catch (...) { + std::rethrow_exception(std::current_exception()); + } } return node->empty() ? nullptr : node; } diff --git a/paddle/fluid/framework/op_call_stack.cc b/paddle/fluid/framework/op_call_stack.cc index c3c56210b62..dee98969902 100644 --- a/paddle/fluid/framework/op_call_stack.cc +++ b/paddle/fluid/framework/op_call_stack.cc @@ -56,9 +56,15 @@ void InsertCallStackInfo(const std::string &type, const AttributeMap &attrs, } // Step 3. Construct final call stack & append error op name sout << exception->err_str_; - if (callstack) { - sout << " [operator < " << type << " > error]"; - } + sout << " [operator < " << type << " > error]"; + exception->err_str_ = sout.str(); +} + +void AppendErrorOpHint(const std::string &type, + platform::EnforceNotMet *exception) { + std::ostringstream sout; + sout << exception->err_str_; + sout << " [operator < " << type << " > error]"; exception->err_str_ = sout.str(); } diff --git a/paddle/fluid/framework/op_call_stack.h b/paddle/fluid/framework/op_call_stack.h index 4408601abf0..d48cf27285a 100644 --- a/paddle/fluid/framework/op_call_stack.h +++ b/paddle/fluid/framework/op_call_stack.h @@ -20,7 +20,14 @@ limitations under the License. */ namespace paddle { namespace framework { + +// insert python call stack & append error op for exception message void InsertCallStackInfo(const std::string &type, const AttributeMap &attrs, platform::EnforceNotMet *exception); + +// only append error op for exception message +void AppendErrorOpHint(const std::string &type, + platform::EnforceNotMet *exception); + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/op_call_stack_test.cc b/paddle/fluid/framework/op_call_stack_test.cc new file mode 100644 index 00000000000..93db97a93f4 --- /dev/null +++ b/paddle/fluid/framework/op_call_stack_test.cc @@ -0,0 +1,61 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_call_stack.h" + +#include +#include + +#include "gtest/gtest.h" + +namespace paddle { +namespace framework { +namespace details { + +static void ThrowEnforceNotMet() { + PADDLE_THROW(platform::errors::InvalidArgument( + "\n----------------------\nError Message " + "Summary:\n----------------------\n" + "Created error.")); +} + +} // namespace details +} // namespace framework +} // namespace paddle + +TEST(OpCallStack, InsertCallStackInfo) { + try { + paddle::framework::details::ThrowEnforceNotMet(); + } catch (paddle::platform::EnforceNotMet &exception) { + paddle::framework::AttributeMap attr_map; + std::string stack_test_str = "test for op callstack"; + std::vector stack_test_vec; + stack_test_vec.emplace_back(stack_test_str); + attr_map["op_callstack"] = stack_test_vec; + paddle::framework::InsertCallStackInfo("test", attr_map, &exception); + std::string ex_msg = exception.what(); + EXPECT_TRUE(ex_msg.find(stack_test_str) != std::string::npos); + EXPECT_TRUE(ex_msg.find("[operator < test > error]") != std::string::npos); + } +} + +TEST(OpCallStack, AppendErrorOpHint) { + try { + paddle::framework::details::ThrowEnforceNotMet(); + } catch (paddle::platform::EnforceNotMet &exception) { + paddle::framework::AppendErrorOpHint("test", &exception); + std::string ex_msg = exception.what(); + EXPECT_TRUE(ex_msg.find("[operator < test > error]") != std::string::npos); + } +} diff --git a/paddle/fluid/imperative/dygraph_grad_maker.h b/paddle/fluid/imperative/dygraph_grad_maker.h index 757f4193690..07a18a1a0dc 100644 --- a/paddle/fluid/imperative/dygraph_grad_maker.h +++ b/paddle/fluid/imperative/dygraph_grad_maker.h @@ -258,6 +258,8 @@ class TracedGradOp { } } + std::string Type() const { return op_->Type(); } + void SetType(const std::string& type) { op_->SetType(type); } void SetAttrMap(const framework::AttributeMap& attrs) { diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 873963db1a1..ee4c5617397 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -53,7 +53,23 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, attr_checker->Check(&attrs, true); } - OpBase::Run(*op, ins, outs, attrs, place); + try { + OpBase::Run(*op, ins, outs, attrs, place); + } catch (platform::EnforceNotMet& exception) { + framework::AppendErrorOpHint(type, &exception); + throw std::move(exception); + } catch (std::exception& ex) { + PADDLE_THROW(platform::errors::Fatal( + "Operator %s raises an %s exception.\n" + "The exception content is\n:%s.", + type, platform::demangle(typeid(ex).name()), ex.what())); + } catch (...) { + // NOTE: this branch represents a very serious bug with + // low probability of occurrence, and we can't get its + // exception content here. + PADDLE_THROW(platform::errors::Fatal( + "Operator %s raises an unknown exception.", type)); + } if (enable_program_desc_tracing_) { VLOG(5) << "Trace op " << type << " into ProgramDesc"; -- GitLab