diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index ef98558820f5696077c83424d5946caeec00b6fb..4d204aefde4562f67d552e5b10ee9ba38b6d5163 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -149,8 +149,29 @@ class CompileTimeInferShapeContext : public InferShapeContext { bool IsRuntime() const override; + std::vector GetInputsVarType( + const std::string &name) const override { + return GetVarTypes(Inputs(name)); + } + + std::vector GetOutputsVarType( + const std::string &name) const override { + return GetVarTypes(Outputs(name)); + } + protected: - proto::VarType::Type GetVarType(const std::string &name) const override; + std::vector GetVarTypes( + const std::vector &names) const { + std::vector retv; + retv.resize(names.size()); + std::transform( + names.begin(), names.end(), retv.begin(), + std::bind(std::mem_fn(&CompileTimeInferShapeContext::GetVarType), this, + std::placeholders::_1)); + return retv; + } + + proto::VarType::Type GetVarType(const std::string &name) const; DDim GetDim(const std::string &name) const { auto var = block_.FindVarRecursive(name); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 2bfe055b4ca0addeb872dcdcd7ea92172d7e5040..eb172ca88f3c168a8ddd6e12c2bc06092a582127 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -722,6 +722,16 @@ class RuntimeInferShapeContext : public InferShapeContext { return GetDims(vars); } + std::vector GetInputsVarType( + const std::string& name) const override { + return GetVarTypes(InputVars(name)); + } + + std::vector GetOutputsVarType( + const std::string& name) const override { + return GetVarTypes(OutputVars(name)); + } + protected: DDim GetDim(Variable* var) const { PADDLE_ENFORCE_NOT_NULL(var); @@ -766,8 +776,17 @@ class RuntimeInferShapeContext : public InferShapeContext { PADDLE_THROW("Only compile time support this method"); } - proto::VarType::Type GetVarType(const std::string& name) const override { - auto* var = scope_.FindVar(name); + std::vector GetVarTypes( + const std::vector& vars) const { + std::vector retv; + retv.resize(vars.size()); + std::transform(vars.begin(), vars.end(), retv.begin(), + std::bind(std::mem_fn(&RuntimeInferShapeContext::GetVarType), + this, std::placeholders::_1)); + return retv; + } + + proto::VarType::Type GetVarType(Variable* var) const { return ToVarType(var->Type()); } diff --git a/paddle/fluid/framework/shape_inference.cc b/paddle/fluid/framework/shape_inference.cc index f274a1b73f47bcd99ffc7a70fc7026fc97427200..4e67855b5caeb7b000cba2eaa1922b368b5df7c9 100644 --- a/paddle/fluid/framework/shape_inference.cc +++ b/paddle/fluid/framework/shape_inference.cc @@ -68,25 +68,5 @@ void InferShapeContext::SetDims(const std::vector &names, } } -std::vector InferShapeContext::GetInputsVarType( - const std::string &name) const { - return GetVarTypes(Inputs(name)); -} - -std::vector InferShapeContext::GetOutputsVarType( - const std::string &name) const { - return GetVarTypes(Outputs(name)); -} - -std::vector InferShapeContext::GetVarTypes( - const std::vector &names) const { - std::vector retv; - retv.resize(names.size()); - std::transform(names.begin(), names.end(), retv.begin(), - std::bind(std::mem_fn(&InferShapeContext::GetVarType), this, - std::placeholders::_1)); - return retv; -} - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index 6cf9cf3f38608cc22425b1ea8ba71c6fe90580d4..415339a01dd579ea5ef68e945335660cb46024fe 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -34,9 +34,9 @@ class InferShapeContext { virtual bool HasOutput(const std::string &name) const = 0; virtual std::vector GetInputsVarType( - const std::string &name) const; + const std::string &name) const = 0; virtual std::vector GetOutputsVarType( - const std::string &name) const; + const std::string &name) const = 0; virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0; @@ -82,11 +82,6 @@ class InferShapeContext { virtual std::vector GetRepeatedDims(const std::string &name) const = 0; virtual void SetRepeatedDims(const std::string &name, const std::vector &dims) = 0; - - std::vector GetVarTypes( - const std::vector &names) const; - - virtual proto::VarType::Type GetVarType(const std::string &name) const = 0; }; } // namespace framework