提交 b008360b 编写于 作者: Q Qiao Longfei 提交者: GitHub

merge InferShapeContext and OperatorContext (#3347)

* merge InferShapeContext and OperatorContext

* OperatorBase& instead of OperatorBase*
上级 0f3a3e98
...@@ -120,10 +120,10 @@ class OperatorBase { ...@@ -120,10 +120,10 @@ class OperatorBase {
std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_; std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_;
}; };
class OperatorContext { class InferShapeContext {
public: public:
OperatorContext(const OperatorBase* op, const Scope& scope) InferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(*op), scope_(scope) {} : op_(op), scope_(scope) {}
size_t InputSize() const { return op_.inputs_.size(); } size_t InputSize() const { return op_.inputs_.size(); }
...@@ -234,12 +234,6 @@ class OperatorContext { ...@@ -234,12 +234,6 @@ class OperatorContext {
const Scope& scope_; const Scope& scope_;
}; };
class InferShapeContext : public OperatorContext {
public:
InferShapeContext(const OperatorBase* op, const Scope& scope)
: OperatorContext(op, scope) {}
};
template <typename T> template <typename T>
struct EigenDeviceConverter; struct EigenDeviceConverter;
...@@ -255,11 +249,11 @@ struct EigenDeviceConverter<platform::GPUPlace> { ...@@ -255,11 +249,11 @@ struct EigenDeviceConverter<platform::GPUPlace> {
}; };
#endif #endif
class ExecutionContext : public OperatorContext { class ExecutionContext : public InferShapeContext {
public: public:
ExecutionContext(const OperatorBase* op, const Scope& scope, ExecutionContext(const OperatorBase& op, const Scope& scope,
const platform::DeviceContext* device_context) const platform::DeviceContext* device_context)
: OperatorContext(op, scope), device_context_(device_context) {} : InferShapeContext(op, scope), device_context_(device_context) {}
template <typename PlaceType, template <typename PlaceType,
typename DeviceType = typename DeviceType =
...@@ -311,13 +305,13 @@ class OperatorWithKernel : public OperatorBase { ...@@ -311,13 +305,13 @@ class OperatorWithKernel : public OperatorBase {
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>; std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
void InferShape(const Scope& scope) const override { void InferShape(const Scope& scope) const override {
InferShape(InferShapeContext(this, scope)); InferShape(InferShapeContext(*this, scope));
} }
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const final { const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
opKernel->Compute(ExecutionContext(this, scope, &dev_ctx)); opKernel->Compute(ExecutionContext(*this, scope, &dev_ctx));
} }
static std::unordered_map<std::string /* op_type */, OpKernelMap>& static std::unordered_map<std::string /* op_type */, OpKernelMap>&
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册