From 9ef8a76873983c61eb91fab99f3306a5be8ef0c0 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 19 Dec 2018 16:13:31 +0800 Subject: [PATCH] convert more test=develop --- paddle/fluid/framework/op_desc.cc | 23 ++++++++++++++++++++++- paddle/fluid/framework/operator.cc | 23 +++++++++++++++++++++-- paddle/fluid/framework/shape_inference.cc | 20 -------------------- paddle/fluid/framework/shape_inference.h | 9 ++------- 4 files changed, 45 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index ef9855882..4d204aefd 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -149,8 +149,29 @@ class CompileTimeInferShapeContext : public InferShapeContext { bool IsRuntime() const override; + std::vector GetInputsVarType( + const std::string &name) const override { + return GetVarTypes(Inputs(name)); + } + + std::vector GetOutputsVarType( + const std::string &name) const override { + return GetVarTypes(Outputs(name)); + } + protected: - proto::VarType::Type GetVarType(const std::string &name) const override; + std::vector GetVarTypes( + const std::vector &names) const { + std::vector retv; + retv.resize(names.size()); + std::transform( + names.begin(), names.end(), retv.begin(), + std::bind(std::mem_fn(&CompileTimeInferShapeContext::GetVarType), this, + std::placeholders::_1)); + return retv; + } + + proto::VarType::Type GetVarType(const std::string &name) const; DDim GetDim(const std::string &name) const { auto var = block_.FindVarRecursive(name); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 2bfe055b4..eb172ca88 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -722,6 +722,16 @@ class RuntimeInferShapeContext : public InferShapeContext { return GetDims(vars); } + std::vector GetInputsVarType( + const std::string& name) const override { + return GetVarTypes(InputVars(name)); + } + + std::vector GetOutputsVarType( + const std::string& name) const override { + return GetVarTypes(OutputVars(name)); + } + protected: DDim GetDim(Variable* var) const { PADDLE_ENFORCE_NOT_NULL(var); @@ -766,8 +776,17 @@ class RuntimeInferShapeContext : public InferShapeContext { PADDLE_THROW("Only compile time support this method"); } - proto::VarType::Type GetVarType(const std::string& name) const override { - auto* var = scope_.FindVar(name); + std::vector GetVarTypes( + const std::vector& vars) const { + std::vector retv; + retv.resize(vars.size()); + std::transform(vars.begin(), vars.end(), retv.begin(), + std::bind(std::mem_fn(&RuntimeInferShapeContext::GetVarType), + this, std::placeholders::_1)); + return retv; + } + + proto::VarType::Type GetVarType(Variable* var) const { return ToVarType(var->Type()); } diff --git a/paddle/fluid/framework/shape_inference.cc b/paddle/fluid/framework/shape_inference.cc index f274a1b73..4e67855b5 100644 --- a/paddle/fluid/framework/shape_inference.cc +++ b/paddle/fluid/framework/shape_inference.cc @@ -68,25 +68,5 @@ void InferShapeContext::SetDims(const std::vector &names, } } -std::vector InferShapeContext::GetInputsVarType( - const std::string &name) const { - return GetVarTypes(Inputs(name)); -} - -std::vector InferShapeContext::GetOutputsVarType( - const std::string &name) const { - return GetVarTypes(Outputs(name)); -} - -std::vector InferShapeContext::GetVarTypes( - const std::vector &names) const { - std::vector retv; - retv.resize(names.size()); - std::transform(names.begin(), names.end(), retv.begin(), - std::bind(std::mem_fn(&InferShapeContext::GetVarType), this, - std::placeholders::_1)); - return retv; -} - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index 6cf9cf3f3..415339a01 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -34,9 +34,9 @@ class InferShapeContext { virtual bool HasOutput(const std::string &name) const = 0; virtual std::vector GetInputsVarType( - const std::string &name) const; + const std::string &name) const = 0; virtual std::vector GetOutputsVarType( - const std::string &name) const; + const std::string &name) const = 0; virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0; @@ -82,11 +82,6 @@ class InferShapeContext { virtual std::vector GetRepeatedDims(const std::string &name) const = 0; virtual void SetRepeatedDims(const std::string &name, const std::vector &dims) = 0; - - std::vector GetVarTypes( - const std::vector &names) const; - - virtual proto::VarType::Type GetVarType(const std::string &name) const = 0; }; } // namespace framework -- GitLab