From 728665d709811162ac1e2e136e44f88d6e68cb7f Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Thu, 13 Jul 2017 15:57:22 +0800 Subject: [PATCH] Add Init to OperatorBase (#2838) --- paddle/framework/op_registry.h | 1 + paddle/framework/operator.h | 4 ++++ paddle/framework/operator_test.cc | 5 +++++ 3 files changed, 10 insertions(+) diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 248c7a1a3b..e46da822c6 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -119,6 +119,7 @@ class OpRegistry { op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr); } op_checkers().at(op_type).Check(op->attrs_); + op->Init(); return op; } diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 0ce422e007..4336115670 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -49,6 +49,10 @@ class OperatorBase { std::string DebugString() const; + /// Init will be called after CreateOperator, you can put some initialization + /// logic here. + virtual void Init() {} + /// InferShape infer the size of Variables used by this Operator with /// information inside scope virtual void InferShape(const std::shared_ptr& scope) const = 0; diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index be8c4be2d4..01b87bb50e 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -21,14 +21,19 @@ namespace framework { class OperatorTest : public OperatorBase { public: + void Init() override { x = 1; } void InferShape(const std::shared_ptr& scope) const override {} void Run(const std::shared_ptr& scope, const platform::DeviceContext& dev_ctx) const override { float scale = GetAttr("scale"); ASSERT_NEAR(scale, 3.14, 1e-5); ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr); + ASSERT_EQ(x, 1); ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr); } + + public: + float x = 0; }; class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { -- GitLab