提交 728665d7 编写于 作者: Q Qiao Longfei 提交者: GitHub

Add Init to OperatorBase (#2838)

上级 90cf44d7
...@@ -119,6 +119,7 @@ class OpRegistry { ...@@ -119,6 +119,7 @@ class OpRegistry {
op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr); op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
} }
op_checkers().at(op_type).Check(op->attrs_); op_checkers().at(op_type).Check(op->attrs_);
op->Init();
return op; return op;
} }
......
...@@ -49,6 +49,10 @@ class OperatorBase { ...@@ -49,6 +49,10 @@ class OperatorBase {
std::string DebugString() const; std::string DebugString() const;
/// Init will be called after CreateOperator, you can put some initialization
/// logic here.
virtual void Init() {}
/// InferShape infer the size of Variables used by this Operator with /// InferShape infer the size of Variables used by this Operator with
/// information inside scope /// information inside scope
virtual void InferShape(const std::shared_ptr<Scope>& scope) const = 0; virtual void InferShape(const std::shared_ptr<Scope>& scope) const = 0;
......
...@@ -21,14 +21,19 @@ namespace framework { ...@@ -21,14 +21,19 @@ namespace framework {
class OperatorTest : public OperatorBase { class OperatorTest : public OperatorBase {
public: public:
void Init() override { x = 1; }
void InferShape(const std::shared_ptr<Scope>& scope) const override {} void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope, void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::DeviceContext& dev_ctx) const override {
float scale = GetAttr<float>("scale"); float scale = GetAttr<float>("scale");
ASSERT_NEAR(scale, 3.14, 1e-5); ASSERT_NEAR(scale, 3.14, 1e-5);
ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr); ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
ASSERT_EQ(x, 1);
ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr); ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
} }
public:
float x = 0;
}; };
class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册