From 31bdb3f3cc98a38a3ec4951f26dd06260cdc4695 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 2 Oct 2017 14:03:38 -0700 Subject: [PATCH] tmp --- paddle/framework/operator.h | 98 +++++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 73e53a4176d..f807909650f 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -317,6 +317,104 @@ class ExecutionContext : public InferShapeContext { const platform::DeviceContext& device_context_; }; +class CompileTimeInferShapeContext : public InferShapeContextBase { + public: + CompileTimeInferShapeContext(const OperatorBase& op, const Scope& scope) + : op_(op), scope_(scope) {} + + bool HasInput(const std::string& name) const { + auto ipt = op_.Input(name); + auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); + return var != nullptr; + } + + bool HasOutput(const std::string& name) const { + auto ipt = op_.Output(name); + auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); + return var != nullptr; + } + + bool HasInputs(const std::string& name) const { + auto inputs = op_.Inputs(name); + if (inputs.size() == 0UL) { + return false; + } + for (auto& input : inputs) { + if (scope_.FindVar(input) == nullptr) { + return false; + } + } + return true; + } + + bool HasOutputs(const std::string& name) const { + auto outputs = op_.Outputs(name); + if (outputs.size() == 0UL) { + return false; + } + for (auto& output : outputs) { + if (scope_.FindVar(output) == nullptr) { + return false; + } + } + return true; + } + + DDim GetInputDim(const std::string& name) const { + return GetDim(op_.Input(name)); + } + + void SetInputDim(const std::string& name, const DDim& dim) { + SetDim(op_.Input(name), dim); + } + + DDim GetOutputDim(const std::string& name) const { + return GetDim(op_.Output(name)); + } + + void SetOutputDim(const std::string& name, const DDim& dim) { + SetDim(op_.Output(name), dim); + } + + AttrReader Attrs() const { return AttrReader(op_.Attrs()); } + + const std::vector& Inputs(const std::string& name) const { + return op_.Inputs(name); + } + + const std::vector& Outputs(const std::string& name) const { + return op_.Outputs(name); + } + + private: + template + Tensor* GetTensor(const std::string& name) const { + Tensor* t = nullptr; + auto* var = scope_.FindVar(name); + if (!var->IsType() && !var->IsType()) { + if (Allocate) { + t = var->GetMutable(); + } else { + PADDLE_THROW("Variable(%s) should be tensor", name); + } + } else { + t = GetTensorFromVar(scope_.FindVar(name)); + } + return t; + } + + DDim GetDim(const std::string& name) const { + return GetTensor(name)->dims(); + } + + void SetDim(const std::string& name, const DDim& dim) { + GetTensor(name)->Resize(dim); + } + + const OperatorBase& op_; + const Scope& scope_; +}; + class RuntimeInferShapeContext : public InferShapeContextBase { public: RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope) -- GitLab