diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 8b3842adbb8aa924bb392b9e0b7db985586b3406..78f5bb077aaf189ff0d21aba853d62aebe46f53e 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -413,7 +413,7 @@ cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tens cc_library(generator SRCS generator.cc DEPS enforce place) cc_library(infershape_utils SRCS infershape_utils.cc DEPS lod_tensor selected_rows_utils attribute place pten var_type_traits pten pten_api_utils op_info shape_inference) - +cc_test(infershape_utils_test SRCS infershape_utils_test.cc DEPS infershape_utils infermeta_utils meta_tensor) # Get the current working branch execute_process( diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 9e1958973d2d97a351ef5ced57339fb698b70281..bc0344d405cf795bc96fd3fb2d5376bbde89bd2b 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -14,6 +14,8 @@ limitations under the License. */ #include "paddle/fluid/framework/infershape_utils.h" +#include + #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/pten_utils.h" @@ -303,13 +305,45 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, auto& attr = attr_reader.GetAttr(attr_name); if (std::type_index(attr.type()) == std::type_index(typeid(bool))) { infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); + } else if (std::type_index(attr.type()) == std::type_index(typeid(int))) { + infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int, attr)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(int64_t))) { + infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr)); } else if (std::type_index(attr.type()) == std::type_index(typeid(float))) { infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(float, attr)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::string))) { + infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(std::string, attr)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + infer_meta_context.EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + infer_meta_context.EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + infer_meta_context.EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + infer_meta_context.EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + infer_meta_context.EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + infer_meta_context.EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); } else { - // do nothing, skip useless attrs now - // TODO(chenweihang): support other attr type later and throw error - // if attr is cannot parsed + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported attribute type is received when call " + "InferShapeFunctor.")); } } else { // do nothing diff --git a/paddle/fluid/framework/infershape_utils_test.cc b/paddle/fluid/framework/infershape_utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..755ca3f5ce90b7bcc85e904089262fd7f7e401cb --- /dev/null +++ b/paddle/fluid/framework/infershape_utils_test.cc @@ -0,0 +1,163 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "gtest/gtest.h" + +#include "paddle/fluid/framework/attribute.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_info.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/pten/core/compat/op_utils.h" +#include "paddle/pten/core/infermeta_utils.h" + +namespace paddle { +namespace framework { + +void TestInferMeta(bool bool_attr, int int_attr, int64_t int64_attr, + float float_attr, const std::string& str_attr, + const std::vector& vec_bool_attr, + const std::vector& vec_int_attr, + const std::vector& vec_int64_attr, + const std::vector& vec_float_attr, + const std::vector& vec_double_attr, + const std::vector& vec_str_attr) { + ASSERT_EQ(bool_attr, true); + ASSERT_EQ(int_attr, 10); + ASSERT_EQ(int64_attr, 100); + ASSERT_NEAR(float_attr, 3.14, 1e-6); + ASSERT_EQ(str_attr, "test"); + ASSERT_EQ(vec_bool_attr.at(0), true); + ASSERT_EQ(vec_bool_attr.at(1), true); + ASSERT_EQ(vec_int_attr.at(0), 10); + ASSERT_EQ(vec_int_attr.at(1), 10); + ASSERT_EQ(vec_int64_attr.at(0), 100L); + ASSERT_EQ(vec_int64_attr.at(1), 100L); + ASSERT_NEAR(vec_float_attr.at(0), 3.14, 1e-6); + ASSERT_NEAR(vec_float_attr.at(1), 3.14, 1e-6); + ASSERT_NEAR(vec_double_attr.at(0), 3.1415, 1e-6); + ASSERT_NEAR(vec_double_attr.at(1), 3.1415, 1e-6); + ASSERT_EQ(vec_str_attr.at(0), "test_vec"); + ASSERT_EQ(vec_str_attr.at(1), "test_vec"); +} + +class InferShapeUtilsTestOpMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddAttr("bool", "bool attr of test op"); + AddAttr("int", "int attr of test op"); + AddAttr("int64", "int64 attr of test op"); + AddAttr("float", "float attr of test op"); + AddAttr("string", "string attr of test op"); + AddAttr>("vec_bool", "vec_bool attr of test op"); + AddAttr>("vec_int", "vec_int attr of test op"); + AddAttr>("vec_int64", "vec_int attr of test op"); + AddAttr>("vec_float", "vec_int attr of test op"); + AddAttr>("vec_double", "vec_int attr of test op"); + AddAttr>("vec_str", "vec_int attr of test op"); + AddComment("This is test op"); + } +}; + +class InferShapeUtilsTestOp : public OperatorWithKernel { + public: + using OperatorWithKernel::OperatorWithKernel; + + OpKernelType GetExpectedKernelType( + const ExecutionContext& ctx) const override { + return OpKernelType(proto::VarType::FP32, ctx.GetPlace()); + } +}; + +pten::KernelSignature InferShapeUtilsTestOpArgumentMapping( + const pten::ArgumentMappingContext& ctx) { + return pten::KernelSignature( + "infer_shape_utils_test", {}, + {"bool", "int", "int64", "float", "string", "vec_bool", "vec_int", + "vec_int64", "vec_float", "vec_double", "vec_str"}, + {}); +} + +} // namespace framework +} // namespace paddle + +DELCARE_INFER_SHAPE_FUNCTOR(infer_shape_utils_test, + InferShapeUtilsTestInferShapeFunctor, + PT_INFER_META(paddle::framework::TestInferMeta)); +REGISTER_OPERATOR(infer_shape_utils_test, + paddle::framework::InferShapeUtilsTestOp, + paddle::framework::InferShapeUtilsTestOpMaker, + InferShapeUtilsTestInferShapeFunctor); + +TEST(InferShapeUtilsTest, ALL) { + paddle::framework::ProgramDesc prog; + paddle::framework::proto::BlockDesc proto_block; + paddle::framework::BlockDesc block_desc(&prog, &proto_block); + + auto* op = block_desc.AppendOp(); + op->SetType("infer_shape_utils_test"); + + paddle::framework::Attribute bool_attr(true); + op->SetAttr("bool", bool_attr); + + paddle::framework::Attribute int_attr(10); + op->SetAttr("int", int_attr); + + int64_t int64_val = 100; + paddle::framework::Attribute int64_attr(int64_val); + op->SetAttr("int64", int64_attr); + + float float_value = 3.14; + paddle::framework::Attribute float_attr(float_value); + op->SetAttr("float", float_attr); + + std::string str_value("test"); + paddle::framework::Attribute str_attr(str_value); + op->SetAttr("string", str_attr); + + std::vector vec_bool(2, true); + paddle::framework::Attribute vec_bool_attr = vec_bool; + op->SetAttr("vec_bool", vec_bool_attr); + + std::vector vec_int(2, 10); + paddle::framework::Attribute vec_int_attr = vec_int; + op->SetAttr("vec_int", vec_int_attr); + + std::vector vec_int64(2, 100); + paddle::framework::Attribute vec_int64_attr = vec_int64; + op->SetAttr("vec_int64", vec_int64_attr); + std::cout << "after set vec_int64" << std::endl; + + std::vector vec_float(2, 3.14); + paddle::framework::Attribute vec_float_attr = vec_float; + op->SetAttr("vec_float", vec_float_attr); + + std::vector vec_double(2, 3.1415); + paddle::framework::Attribute vec_double_attr = vec_double; + op->SetAttr("vec_double", vec_double_attr); + + std::vector vec_str(2, "test_vec"); + paddle::framework::Attribute vec_str_attr = vec_str; + op->SetAttr("vec_str", vec_str_attr); + + pten::OpUtilsMap::Instance().InsertArgumentMappingFn( + "infer_shape_utils_test", + paddle::framework::InferShapeUtilsTestOpArgumentMapping); + + op->InferShape(block_desc); +} diff --git a/paddle/pten/core/infermeta_utils.h b/paddle/pten/core/infermeta_utils.h index 6de91db9382e22537e577ce3188764034c7235e3..59d2a4ed3c089d2480bfcbe526d2706371e322bc 100644 --- a/paddle/pten/core/infermeta_utils.h +++ b/paddle/pten/core/infermeta_utils.h @@ -15,6 +15,8 @@ limitations under the License. */ #pragma once #include +#include +#include #include #include "paddle/pten/common/scalar.h" @@ -55,9 +57,12 @@ class InferMetaContext { AttrType AttrAt(size_t idx) { try { return paddle::any_cast(attrs_.at(idx)); - } catch (paddle::bad_any_cast&) { + } catch (paddle::bad_any_cast& e) { PADDLE_THROW(pten::errors::InvalidArgument( - "Attribute cast error in InferMeta Context.")); + "Attribute cast error in InferMeta Context, the expected attribute " + "type is `%s`, but actual attribute type is `%s`.", + std::type_index(typeid(AttrType)).name(), + std::type_index(attrs_.at(idx).type()).name())); } } @@ -151,10 +156,15 @@ struct InferMetaFnImpl { PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(float); - PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(double); + PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::string&); + PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector&); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector&); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE( const std::vector&); + PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector&); + PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector&); + PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE( + const std::vector&); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(Backend); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout);