diff --git a/paddle/phi/core/infermeta_utils.cc b/paddle/phi/core/infermeta_utils.cc index 70f26102cbad1a3d7e1116a7c9352ca54435ea80..8bdad9d6d2b6e6098a35a74c3858b622f6c30aa5 100644 --- a/paddle/phi/core/infermeta_utils.cc +++ b/paddle/phi/core/infermeta_utils.cc @@ -30,7 +30,7 @@ void InferMetaContext::EmplaceBackOutput(MetaTensor output) { outputs_.emplace_back(std::move(output)); output_range_.emplace_back(std::pair(index, index + 1)); } -void InferMetaContext::EmplaceBackAttr(paddle::any attr) { +void InferMetaContext::EmplaceBackAttr(Attribute attr) { attrs_.emplace_back(std::move(attr)); } @@ -120,6 +120,38 @@ std::vector InferMetaContext::MutableOutputBetween(size_t start, return result; } +template +const AttrType& InferMetaContext::AttrAt(size_t idx) const { + try { + return paddle::get(attrs_.at(idx)); + } catch (paddle::bad_variant_access const& e) { + PADDLE_THROW(phi::errors::InvalidArgument( + "Attribute cast error in InferMeta Context, the expected attribute " + "type is `%s`.", + std::type_index(typeid(AttrType)).name())); + } +} + +template const bool& InferMetaContext::AttrAt(size_t idx) const; +template const int& InferMetaContext::AttrAt(size_t idx) const; +template const int64_t& InferMetaContext::AttrAt(size_t idx) const; +template const float& InferMetaContext::AttrAt(size_t idx) const; +template const double& InferMetaContext::AttrAt(size_t idx) const; +template const std::string& InferMetaContext::AttrAt(size_t idx) const; +template const std::vector& InferMetaContext::AttrAt(size_t idx) const; +template const std::vector& InferMetaContext::AttrAt(size_t idx) const; +template const std::vector& InferMetaContext::AttrAt(size_t idx) const; +template const std::vector& InferMetaContext::AttrAt(size_t idx) const; +template const std::vector& InferMetaContext::AttrAt(size_t idx) const; +template const std::vector& InferMetaContext::AttrAt( + size_t idx) const; +template const Scalar& InferMetaContext::AttrAt(size_t idx) const; +template const std::vector& InferMetaContext::AttrAt(size_t idx) const; +template const IntArray& InferMetaContext::AttrAt(size_t idx) const; +template const DataType& InferMetaContext::AttrAt(size_t idx) const; +template const DataLayout& InferMetaContext::AttrAt(size_t idx) const; +template const Place& InferMetaContext::AttrAt(size_t idx) const; + MetaFnFactory& MetaFnFactory::Instance() { static MetaFnFactory g_meta_fn_map; return g_meta_fn_map; diff --git a/paddle/phi/core/infermeta_utils.h b/paddle/phi/core/infermeta_utils.h index 699c38ebd470236dcbdb641eaeb6873829e13f40..8c726bffa2fc983827088a647e73515af8c9a8e9 100644 --- a/paddle/phi/core/infermeta_utils.h +++ b/paddle/phi/core/infermeta_utils.h @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/phi/common/int_array.h" #include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/attribute.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/macros.h" #include "paddle/phi/core/meta_tensor.h" @@ -41,7 +42,7 @@ class InferMetaContext { void EmplaceBackInput(MetaTensor input); void EmplaceBackOutput(MetaTensor output); - void EmplaceBackAttr(paddle::any attr); + void EmplaceBackAttr(Attribute attr); void EmplaceBackInputs( paddle::SmallVector inputs); @@ -61,17 +62,7 @@ class InferMetaContext { size_t end); template - AttrType AttrAt(size_t idx) { - try { - return paddle::any_cast(attrs_.at(idx)); - } catch (paddle::bad_any_cast& e) { - PADDLE_THROW(phi::errors::InvalidArgument( - "Attribute cast error in InferMeta Context, the expected attribute " - "type is `%s`, but actual attribute type is `%s`.", - std::type_index(typeid(AttrType)).name(), - std::type_index(attrs_.at(idx).type()).name())); - } - } + const AttrType& AttrAt(size_t idx) const; const std::pair& InputRangeAt(size_t idx) const; const std::pair& OutputRangeAt(size_t idx) const; @@ -81,7 +72,7 @@ class InferMetaContext { protected: MetaConfig config_; - paddle::SmallVector attrs_; + paddle::SmallVector attrs_; paddle::SmallVector, phi::kInputSmallVectorSize> input_range_; @@ -111,6 +102,21 @@ class InferMetaContext { } \ } +#define PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(attr_type) \ + template \ + struct InferMetaFnCallHelper { \ + template \ + static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { \ + static_assert(out_idx == 0, \ + "InferMeta's Attributes should appear before Outputs."); \ + const attr_type& arg = ctx->AttrAt(attr_idx); \ + InferMetaFnCallHelper< \ + Tail...>::template Call(ctx, \ + pargs..., \ + arg); \ + } \ + } + template struct InferMetaTypeTag {}; @@ -201,27 +207,27 @@ struct InferMetaFnImpl { } }; - // TODO(chenweihang): support other attr type later PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(float); - PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::string&); - PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector&); - PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector&); - PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE( - const std::vector&); - PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector&); - PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector&); - PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE( - const std::vector&); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(Backend); PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout); - PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const Scalar&); - PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const IntArray&); - - // TODO(chenweihang): support vector input later + PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::string); + PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(Scalar); + PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(IntArray); + PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF( + std::vector); + PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector); + PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF( + std::vector); + PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF( + std::vector); + PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF( + std::vector); + PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF( + std::vector); template struct InferMetaFnCallHelper { diff --git a/paddle/phi/core/type_defs.h b/paddle/phi/core/type_defs.h index e3cbf2cedd077fcc0417bfde6beb3ee867f0a467..0af1c0af230f752b0869675d96be9af92a6dab2a 100644 --- a/paddle/phi/core/type_defs.h +++ b/paddle/phi/core/type_defs.h @@ -18,37 +18,8 @@ #include #include -#include "paddle/phi/common/data_type.h" -#include "paddle/phi/common/int_array.h" -#include "paddle/phi/common/layout.h" -#include "paddle/phi/common/scalar.h" - -#include "paddle/utils/variant.h" - namespace phi { -class Place; - -// NOTE: Add needed type in the future -using Attribute = paddle::variant, - std::vector, - std::vector, - std::vector, - std::vector, - std::vector, - Scalar, - std::vector, - IntArray, - DataType, - DataLayout, - Place>; - class Kernel; class KernelKey; class KernelArgsDef; diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index e3e1211e3ece864a11da5f6280a6335e39a79f80..e5d83a4013d30e6f7a0dd5af4820a7b724d6024e 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -228,13 +228,6 @@ void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out) { out->set_dtype(x.dtype()); } -void CopyToInferMeta(const MetaTensor& x, - Backend backend, - bool blocking, - MetaTensor* out) { - UnchangedInferMeta(x, out); -} - void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) { out->set_dims(x.dims()); out->set_dtype(dtype == DataType::UNDEFINED ? x.dtype() : dtype); @@ -3008,6 +3001,5 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) { } // namespace phi -PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta); PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta); PD_REGISTER_INFER_META_FN(split, phi::SplitInferMeta); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index ac5040388b3341b322fb9d42b4c3f9323a454c42..70b868eeb5d8d0ec4d3e352c50510e69942d3ea3 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -58,11 +58,6 @@ void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out); void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out); -void CopyToInferMeta(const MetaTensor& x, - Backend backend, - bool blocking, - MetaTensor* out); - void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out); void CumsumInferMeta(const MetaTensor& x, diff --git a/paddle/phi/tests/core/test_meta_fn_utils.cc b/paddle/phi/tests/core/test_meta_fn_utils.cc index 028b9d23352c7f96740a98851d28bdd087a108fe..07832494d50ec346a8a4a3f44086f3ccf9f19687 100644 --- a/paddle/phi/tests/core/test_meta_fn_utils.cc +++ b/paddle/phi/tests/core/test_meta_fn_utils.cc @@ -60,32 +60,6 @@ TEST(MetaFnFactory, InferMetaFnExists) { EXPECT_EQ(dense_out1.dims()[1], dense_out2.dims()[1]); } -TEST(MetaFnFactory, CopyInferMetaFn) { - phi::DenseTensor dense_x; - dense_x.Resize({3, 4}); - - phi::MetaTensor meta_x(&dense_x); - phi::DenseTensor dense_out1; - phi::MetaTensor meta_out(&dense_out1); - phi::UnchangedInferMeta(meta_x, &meta_out); - - auto shared_meat_x = phi::MetaTensor(&dense_x); - phi::DenseTensor dense_out2; - auto shared_meta_out = phi::MetaTensor(&dense_out2); - - phi::InferMetaContext ctx; - ctx.EmplaceBackInput(shared_meat_x); - ctx.EmplaceBackAttr(Backend::CPU); - ctx.EmplaceBackAttr(false); - ctx.EmplaceBackOutput(shared_meta_out); - ctx.SetMetaConfig({/*is_runtime =*/true, /*is_run_mkldnn_kernel=*/false}); - phi::MetaFnFactory::Instance().Get("copy_to")(&ctx); - - EXPECT_EQ(dense_out1.dims().size(), dense_out2.dims().size()); - EXPECT_EQ(dense_out1.dims()[0], dense_out2.dims()[0]); - EXPECT_EQ(dense_out1.dims()[1], dense_out2.dims()[1]); -} - TEST(MetaFnFactory, SplitInferMetaFn) { phi::DenseTensor dense_x; dense_x.Resize({4, 10});