From 54d3b5a1ebdb11686d60c024a4fc82c1e2927709 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Sat, 11 Apr 2020 14:04:21 +0800 Subject: [PATCH] enhance err_msg, test=develop (#23714) --- paddle/fluid/imperative/op_base.h | 26 +++++++++++++++++---- paddle/fluid/imperative/tests/test_layer.cc | 9 +++++++ 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/imperative/op_base.h b/paddle/fluid/imperative/op_base.h index 079177acc51..fa0e66ee1e4 100644 --- a/paddle/fluid/imperative/op_base.h +++ b/paddle/fluid/imperative/op_base.h @@ -42,13 +42,25 @@ class OpBase { ~OpBase() { VLOG(3) << "Destruct Op: " << Type(); } - const std::string& Type() const { return op_->Type(); } + const std::string& Type() const { + return op_ ? op_->Type() : UnknownOpType(); + } const framework::AttributeMap& Attrs() const { return attrs_; } - const framework::OpInfo& Info() const { return op_->Info(); } + const framework::OpInfo& Info() const { + PADDLE_ENFORCE_NOT_NULL(op_, platform::errors::PreconditionNotMet( + "OpBase::Info() should be called after " + "OpBase::SetType() is called")); + return op_->Info(); + } - const framework::OperatorBase& InnerOp() const { return *op_; } + const framework::OperatorBase& InnerOp() const { + PADDLE_ENFORCE_NOT_NULL(op_, platform::errors::PreconditionNotMet( + "OpBase::InnerOp() should be called after " + "OpBase::SetType() is called")); + return *op_; + } void ClearBackwardTrace(); @@ -63,7 +75,7 @@ class OpBase { void SetType(const std::string& type); void CheckAttrs() { - auto& info = op_->Info(); + auto& info = Info(); if (info.Checker() != nullptr) { info.Checker()->Check(&attrs_, true); } @@ -150,6 +162,12 @@ class OpBase { const framework::AttributeMap& attrs, const platform::Place& place); + private: + static const std::string& UnknownOpType() { + static std::string kUnknownOpType{"unknown"}; + return kUnknownOpType; + } + private: NameVarMap ins_; NameVarMap outs_; diff --git a/paddle/fluid/imperative/tests/test_layer.cc b/paddle/fluid/imperative/tests/test_layer.cc index f249a09f4b5..36c448402d7 100644 --- a/paddle/fluid/imperative/tests/test_layer.cc +++ b/paddle/fluid/imperative/tests/test_layer.cc @@ -271,6 +271,15 @@ TEST(test_layer, test_dygraph_infershape_context) { ASSERT_EQ(have_z, false); } +TEST(test_layer, test_inner_op_not_inited) { + OpBase op; + std::string kUnknown = "unknown"; + ASSERT_EQ(op.Type(), kUnknown); + ASSERT_THROW(op.Info(), platform::EnforceNotMet); + ASSERT_THROW(op.InnerOp(), platform::EnforceNotMet); + ASSERT_THROW(op.CheckAttrs(), platform::EnforceNotMet); +} + } // namespace imperative } // namespace paddle -- GitLab