提交 a0767228 编写于 作者: Q qiaolongfei

merge InferShapeContext and ExecutionContext

上级 c3b46d16
...@@ -205,13 +205,13 @@ void OperatorBase::GenerateTemporaryNames() { ...@@ -205,13 +205,13 @@ void OperatorBase::GenerateTemporaryNames() {
} }
template <> template <>
const Tensor* InferShapeContext::Input<Tensor>(const std::string& name) const { const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
auto* var = InputVar(name); auto* var = InputVar(name);
return var == nullptr ? nullptr : GetTensorFromVar(var); return var == nullptr ? nullptr : GetTensorFromVar(var);
} }
template <> template <>
const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>( const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
const std::string& name) const { const std::string& name) const {
auto names = op().Inputs(name); auto names = op().Inputs(name);
std::vector<const Tensor*> res; std::vector<const Tensor*> res;
...@@ -225,13 +225,13 @@ const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>( ...@@ -225,13 +225,13 @@ const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
} }
template <> template <>
Tensor* InferShapeContext::Output<Tensor>(const std::string& name) const { Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
auto var = OutputVar(name); auto var = OutputVar(name);
return var == nullptr ? nullptr : var->GetMutable<LoDTensor>(); return var == nullptr ? nullptr : var->GetMutable<LoDTensor>();
} }
template <> template <>
std::vector<Tensor*> InferShapeContext::MultiOutput<Tensor>( std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
const std::string& name) const { const std::string& name) const {
auto names = op().Outputs(name); auto names = op().Outputs(name);
std::vector<Tensor*> res; std::vector<Tensor*> res;
......
...@@ -57,7 +57,6 @@ inline std::string GradVarName(const std::string& var_name) { ...@@ -57,7 +57,6 @@ inline std::string GradVarName(const std::string& var_name) {
} }
class OperatorBase; class OperatorBase;
class InferShapeContext;
class ExecutionContext; class ExecutionContext;
extern const Tensor* GetTensorFromVar(const Variable* var); extern const Tensor* GetTensorFromVar(const Variable* var);
...@@ -169,10 +168,11 @@ class NOP : public OperatorBase { ...@@ -169,10 +168,11 @@ class NOP : public OperatorBase {
} }
}; };
class InferShapeContext { class ExecutionContext {
public: public:
InferShapeContext(const OperatorBase& op, const Scope& scope) ExecutionContext(const OperatorBase& op, const Scope& scope,
: op_(op), scope_(scope) {} const platform::DeviceContext& device_context)
: op_(op), scope_(scope), device_context_(device_context) {}
const OperatorBase& op() const { return op_; } const OperatorBase& op() const { return op_; }
...@@ -278,31 +278,6 @@ class InferShapeContext { ...@@ -278,31 +278,6 @@ class InferShapeContext {
out_tensor->set_lod(in_tensor.lod()); out_tensor->set_lod(in_tensor.lod());
} }
private:
const OperatorBase& op_;
const Scope& scope_;
};
template <>
const Tensor* InferShapeContext::Input<Tensor>(const std::string& name) const;
template <>
const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
const std::string& name) const;
template <>
Tensor* InferShapeContext::Output<Tensor>(const std::string& name) const;
template <>
std::vector<Tensor*> InferShapeContext::MultiOutput<Tensor>(
const std::string& name) const;
class ExecutionContext : public InferShapeContext {
public:
ExecutionContext(const OperatorBase& op, const Scope& scope,
const platform::DeviceContext& device_context)
: InferShapeContext(op, scope), device_context_(device_context) {}
template <typename PlaceType, template <typename PlaceType,
typename DeviceType = typename platform::EigenDeviceConverter< typename DeviceType = typename platform::EigenDeviceConverter<
PlaceType>::EigenDeviceType> PlaceType>::EigenDeviceType>
...@@ -315,9 +290,25 @@ class ExecutionContext : public InferShapeContext { ...@@ -315,9 +290,25 @@ class ExecutionContext : public InferShapeContext {
} }
private: private:
const OperatorBase& op_;
const Scope& scope_;
const platform::DeviceContext& device_context_; const platform::DeviceContext& device_context_;
}; };
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;
template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
const std::string& name) const;
template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;
template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
const std::string& name) const;
class CompileTimeInferShapeContext : public InferShapeContextBase { class CompileTimeInferShapeContext : public InferShapeContextBase {
public: public:
CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block) CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册