未验证 提交 54d3b5a1 编写于 作者: Z Zeng Jinle 提交者: GitHub

enhance err_msg, test=develop (#23714)

上级 a63bcf9a
...@@ -42,13 +42,25 @@ class OpBase { ...@@ -42,13 +42,25 @@ class OpBase {
~OpBase() { VLOG(3) << "Destruct Op: " << Type(); } ~OpBase() { VLOG(3) << "Destruct Op: " << Type(); }
const std::string& Type() const { return op_->Type(); } const std::string& Type() const {
return op_ ? op_->Type() : UnknownOpType();
}
const framework::AttributeMap& Attrs() const { return attrs_; } const framework::AttributeMap& Attrs() const { return attrs_; }
const framework::OpInfo& Info() const { return op_->Info(); } const framework::OpInfo& Info() const {
PADDLE_ENFORCE_NOT_NULL(op_, platform::errors::PreconditionNotMet(
"OpBase::Info() should be called after "
"OpBase::SetType() is called"));
return op_->Info();
}
const framework::OperatorBase& InnerOp() const { return *op_; } const framework::OperatorBase& InnerOp() const {
PADDLE_ENFORCE_NOT_NULL(op_, platform::errors::PreconditionNotMet(
"OpBase::InnerOp() should be called after "
"OpBase::SetType() is called"));
return *op_;
}
void ClearBackwardTrace(); void ClearBackwardTrace();
...@@ -63,7 +75,7 @@ class OpBase { ...@@ -63,7 +75,7 @@ class OpBase {
void SetType(const std::string& type); void SetType(const std::string& type);
void CheckAttrs() { void CheckAttrs() {
auto& info = op_->Info(); auto& info = Info();
if (info.Checker() != nullptr) { if (info.Checker() != nullptr) {
info.Checker()->Check(&attrs_, true); info.Checker()->Check(&attrs_, true);
} }
...@@ -150,6 +162,12 @@ class OpBase { ...@@ -150,6 +162,12 @@ class OpBase {
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const platform::Place& place); const platform::Place& place);
private:
static const std::string& UnknownOpType() {
static std::string kUnknownOpType{"unknown"};
return kUnknownOpType;
}
private: private:
NameVarMap<VariableWrapper> ins_; NameVarMap<VariableWrapper> ins_;
NameVarMap<VariableWrapper> outs_; NameVarMap<VariableWrapper> outs_;
......
...@@ -271,6 +271,15 @@ TEST(test_layer, test_dygraph_infershape_context) { ...@@ -271,6 +271,15 @@ TEST(test_layer, test_dygraph_infershape_context) {
ASSERT_EQ(have_z, false); ASSERT_EQ(have_z, false);
} }
TEST(test_layer, test_inner_op_not_inited) {
OpBase op;
std::string kUnknown = "unknown";
ASSERT_EQ(op.Type(), kUnknown);
ASSERT_THROW(op.Info(), platform::EnforceNotMet);
ASSERT_THROW(op.InnerOp(), platform::EnforceNotMet);
ASSERT_THROW(op.CheckAttrs(), platform::EnforceNotMet);
}
} // namespace imperative } // namespace imperative
} // namespace paddle } // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册