From 94b31f90a41d653f9e587204ee71bdd49475f4d6 Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Thu, 24 Feb 2022 11:24:54 +0800 Subject: [PATCH] [pten] add optional type for infermeta (#39848) * modify infershape by args_def * add optional type for infermate * add optional type for infermate * add optional type for infermate * support scalar type * change OptionalInputAt function to none template * support phi::DataType --- paddle/fluid/framework/infershape_utils.cc | 85 ++++++++++++++++++---- paddle/phi/core/infermeta_utils.cc | 8 ++ paddle/phi/core/infermeta_utils.h | 21 ++++++ 3 files changed, 99 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index aae36cf455..4bec1baeaa 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar_array.h" #include "paddle/phi/core/compat/arg_map_context.h" #include "paddle/phi/core/compat/convert_utils.h" @@ -376,47 +377,101 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, attr_name)); } } + } else if (attr_defs[i].type_index == + std::type_index(typeid(phi::Scalar))) { + if (ctx->HasAttr(attr_name)) { + // TODO(chentianyu03): support other attrs later + auto& attr = attr_reader.GetAttr(attr_name); + if (std::type_index(attr.type()) == std::type_index(typeid(float))) { + infer_meta_context.EmplaceBackAttr( + phi::Scalar(BOOST_GET_CONST(float, attr))); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::string))) { + infer_meta_context.EmplaceBackAttr( + phi::Scalar(BOOST_GET_CONST(std::string, attr))); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(int))) { + infer_meta_context.EmplaceBackAttr( + phi::Scalar(BOOST_GET_CONST(int, attr))); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to Scalar when construct " + "InferMetaContext.", + attr_name)); + } + } else if (ctx->HasInput(attr_name)) { + const auto& infershape_input = ctx->GetInputVarPtrs(attr_name); + if (infershape_input.size() == 1) { + if (ctx->IsRuntime()) { + Variable* var = BOOST_GET_CONST(Variable*, infershape_input[0]); + infer_meta_context.EmplaceBackAttr( + std::move(experimental::MakePtenScalarFromVar(*var))); + } else { + phi::Scalar tensor_scalar(-1); + tensor_scalar.SetFromTensor(true); + infer_meta_context.EmplaceBackAttr(std::move(tensor_scalar)); + } + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Invalid input.size() when cast op attribute `%s` to Scalar, " + "expected 1, but actually is %d .", + attr_name, infershape_input.size())); + } + } } else if (ctx->HasAttr(attr_name)) { // Emplace Back Attr according to the type of attr. auto& attr = attr_reader.GetAttr(attr_name); - if (std::type_index(attr.type()) == std::type_index(typeid(bool))) { + if (attr_defs[i].type_index == std::type_index(typeid(bool))) { infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); - } else if (std::type_index(attr.type()) == std::type_index(typeid(int))) { + } else if (attr_defs[i].type_index == std::type_index(typeid(int))) { infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int, attr)); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(int64_t))) { + } else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) { infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr)); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(float))) { + } else if (attr_defs[i].type_index == std::type_index(typeid(float))) { infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(float, attr)); - } else if (std::type_index(attr.type()) == + } else if (attr_defs[i].type_index == std::type_index(typeid(std::string))) { infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(std::string, attr)); - } else if (std::type_index(attr.type()) == + } else if (attr_defs[i].type_index == std::type_index(typeid(std::vector))) { infer_meta_context.EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr)); - } else if (std::type_index(attr.type()) == + } else if (attr_defs[i].type_index == std::type_index(typeid(std::vector))) { infer_meta_context.EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr)); - } else if (std::type_index(attr.type()) == + } else if (attr_defs[i].type_index == std::type_index(typeid(std::vector))) { - infer_meta_context.EmplaceBackAttr( - BOOST_GET_CONST(std::vector, attr)); - } else if (std::type_index(attr.type()) == + if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + // Emplace Back Attr according to the type of Phi_Kernel args. + const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); + const std::vector vector_int64_attr(vector_int_attr.begin(), + vector_int_attr.end()); + infer_meta_context.EmplaceBackAttr(vector_int64_attr); + } else { + infer_meta_context.EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); + } + } else if (attr_defs[i].type_index == std::type_index(typeid(std::vector))) { infer_meta_context.EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr)); - } else if (std::type_index(attr.type()) == + } else if (attr_defs[i].type_index == std::type_index(typeid(std::vector))) { infer_meta_context.EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr)); - } else if (std::type_index(attr.type()) == + } else if (attr_defs[i].type_index == std::type_index(typeid(std::vector))) { infer_meta_context.EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr)); + } else if (attr_defs[i].type_index == + std::type_index(typeid(phi::DataType))) { + auto data_type = paddle::framework::TransToPtenDataType( + static_cast( + BOOST_GET_CONST(int, attr))); + infer_meta_context.EmplaceBackAttr(data_type); } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported attribute type is received when call " diff --git a/paddle/phi/core/infermeta_utils.cc b/paddle/phi/core/infermeta_utils.cc index d21232ed82..f3dd056911 100644 --- a/paddle/phi/core/infermeta_utils.cc +++ b/paddle/phi/core/infermeta_utils.cc @@ -67,6 +67,14 @@ const MetaTensor& InferMetaContext::InputAt(size_t idx) const { return *inputs_.at(idx); } +paddle::optional InferMetaContext::OptionalInputAt( + size_t idx) const { + const auto& input = inputs_.at(idx); + return input ? paddle::optional{static_cast< + const phi::MetaTensor&>(*input)} + : paddle::optional{paddle::none}; +} + std::vector InferMetaContext::InputsBetween(size_t start, size_t end) const { std::vector result; diff --git a/paddle/phi/core/infermeta_utils.h b/paddle/phi/core/infermeta_utils.h index 7cf92e4d93..203dbb2698 100644 --- a/paddle/phi/core/infermeta_utils.h +++ b/paddle/phi/core/infermeta_utils.h @@ -51,6 +51,9 @@ class InferMetaContext { const MetaConfig& GetMetaConfig() const; const MetaTensor& InputAt(size_t idx) const; + + paddle::optional OptionalInputAt(size_t idx) const; + std::vector InputsBetween(size_t start, size_t end) const; MetaTensor* MutableOutputAt(size_t idx); std::vector MutableOutputBetween(size_t start, size_t end); @@ -135,6 +138,24 @@ struct InferMetaFnImpl { } }; + template + struct InferMetaFnCallHelper, Tail...> { + template + static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { + static_assert(attr_idx == 0, + "InferMeta's Input should appear before Attributes."); + static_assert(out_idx == 0, + "InferMeta's Input should appear before Outputs."); + const std::pair range = ctx->InputRangeAt(in_idx); + auto arg = ctx->OptionalInputAt(range.first); + + InferMetaFnCallHelper< + Tail...>::template Call(ctx, + pargs..., + arg); + } + }; + template struct InferMetaFnCallHelper&, Tail...> { template -- GitLab