diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index c324fa6702de1eabab3f75cbf4e6568c99b60470..ceef9f028b05a0c99986d520c9e28b09a68137ce 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -120,10 +120,10 @@ class OperatorBase { std::shared_ptr> in_out_idxs_; }; -class OperatorContext { +class InferShapeContext { public: - OperatorContext(const OperatorBase* op, const Scope& scope) - : op_(*op), scope_(scope) {} + InferShapeContext(const OperatorBase& op, const Scope& scope) + : op_(op), scope_(scope) {} size_t InputSize() const { return op_.inputs_.size(); } @@ -234,12 +234,6 @@ class OperatorContext { const Scope& scope_; }; -class InferShapeContext : public OperatorContext { - public: - InferShapeContext(const OperatorBase* op, const Scope& scope) - : OperatorContext(op, scope) {} -}; - template struct EigenDeviceConverter; @@ -255,11 +249,11 @@ struct EigenDeviceConverter { }; #endif -class ExecutionContext : public OperatorContext { +class ExecutionContext : public InferShapeContext { public: - ExecutionContext(const OperatorBase* op, const Scope& scope, + ExecutionContext(const OperatorBase& op, const Scope& scope, const platform::DeviceContext* device_context) - : OperatorContext(op, scope), device_context_(device_context) {} + : InferShapeContext(op, scope), device_context_(device_context) {} template , OpKernelHash>; void InferShape(const Scope& scope) const override { - InferShape(InferShapeContext(this, scope)); + InferShape(InferShapeContext(*this, scope)); } void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const final { auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); - opKernel->Compute(ExecutionContext(this, scope, &dev_ctx)); + opKernel->Compute(ExecutionContext(*this, scope, &dev_ctx)); } static std::unordered_map&