未验证 提交 6eb95caf 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Add attr support for infershape utils (#39513)

* add attr support for infershape

* add unittest for coverage

* add unittest for coverage

* polish unittest detail

* fix windows test failed
上级 ac894ced
...@@ -413,7 +413,7 @@ cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tens ...@@ -413,7 +413,7 @@ cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tens
cc_library(generator SRCS generator.cc DEPS enforce place) cc_library(generator SRCS generator.cc DEPS enforce place)
cc_library(infershape_utils SRCS infershape_utils.cc DEPS lod_tensor selected_rows_utils attribute place pten var_type_traits pten pten_api_utils op_info shape_inference) cc_library(infershape_utils SRCS infershape_utils.cc DEPS lod_tensor selected_rows_utils attribute place pten var_type_traits pten pten_api_utils op_info shape_inference)
cc_test(infershape_utils_test SRCS infershape_utils_test.cc DEPS infershape_utils infermeta_utils meta_tensor)
# Get the current working branch # Get the current working branch
execute_process( execute_process(
......
...@@ -14,6 +14,8 @@ limitations under the License. */ ...@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/infershape_utils.h"
#include <string>
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/pten_utils.h"
...@@ -303,13 +305,45 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -303,13 +305,45 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
auto& attr = attr_reader.GetAttr(attr_name); auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) == std::type_index(typeid(bool))) { if (std::type_index(attr.type()) == std::type_index(typeid(bool))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (std::type_index(attr.type()) == std::type_index(typeid(int))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(int64_t))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr));
} else if (std::type_index(attr.type()) == } else if (std::type_index(attr.type()) ==
std::type_index(typeid(float))) { std::type_index(typeid(float))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(float, attr)); infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::string))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<bool>))) {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<bool>, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int>))) {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int>, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int64_t>))) {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int64_t>, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<float>))) {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<float>, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<double>))) {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<double>, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<std::string>))) {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<std::string>, attr));
} else { } else {
// do nothing, skip useless attrs now PADDLE_THROW(platform::errors::Unimplemented(
// TODO(chenweihang): support other attr type later and throw error "Unsupported attribute type is received when call "
// if attr is cannot parsed "InferShapeFunctor."));
} }
} else { } else {
// do nothing // do nothing
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/pten/core/compat/op_utils.h"
#include "paddle/pten/core/infermeta_utils.h"
namespace paddle {
namespace framework {
void TestInferMeta(bool bool_attr, int int_attr, int64_t int64_attr,
float float_attr, const std::string& str_attr,
const std::vector<bool>& vec_bool_attr,
const std::vector<int>& vec_int_attr,
const std::vector<int64_t>& vec_int64_attr,
const std::vector<float>& vec_float_attr,
const std::vector<double>& vec_double_attr,
const std::vector<std::string>& vec_str_attr) {
ASSERT_EQ(bool_attr, true);
ASSERT_EQ(int_attr, 10);
ASSERT_EQ(int64_attr, 100);
ASSERT_NEAR(float_attr, 3.14, 1e-6);
ASSERT_EQ(str_attr, "test");
ASSERT_EQ(vec_bool_attr.at(0), true);
ASSERT_EQ(vec_bool_attr.at(1), true);
ASSERT_EQ(vec_int_attr.at(0), 10);
ASSERT_EQ(vec_int_attr.at(1), 10);
ASSERT_EQ(vec_int64_attr.at(0), 100L);
ASSERT_EQ(vec_int64_attr.at(1), 100L);
ASSERT_NEAR(vec_float_attr.at(0), 3.14, 1e-6);
ASSERT_NEAR(vec_float_attr.at(1), 3.14, 1e-6);
ASSERT_NEAR(vec_double_attr.at(0), 3.1415, 1e-6);
ASSERT_NEAR(vec_double_attr.at(1), 3.1415, 1e-6);
ASSERT_EQ(vec_str_attr.at(0), "test_vec");
ASSERT_EQ(vec_str_attr.at(1), "test_vec");
}
class InferShapeUtilsTestOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddAttr<bool>("bool", "bool attr of test op");
AddAttr<int>("int", "int attr of test op");
AddAttr<int64_t>("int64", "int64 attr of test op");
AddAttr<float>("float", "float attr of test op");
AddAttr<std::string>("string", "string attr of test op");
AddAttr<std::vector<bool>>("vec_bool", "vec_bool attr of test op");
AddAttr<std::vector<int>>("vec_int", "vec_int attr of test op");
AddAttr<std::vector<int64_t>>("vec_int64", "vec_int attr of test op");
AddAttr<std::vector<float>>("vec_float", "vec_int attr of test op");
AddAttr<std::vector<double>>("vec_double", "vec_int attr of test op");
AddAttr<std::vector<std::string>>("vec_str", "vec_int attr of test op");
AddComment("This is test op");
}
};
class InferShapeUtilsTestOp : public OperatorWithKernel {
public:
using OperatorWithKernel::OperatorWithKernel;
OpKernelType GetExpectedKernelType(
const ExecutionContext& ctx) const override {
return OpKernelType(proto::VarType::FP32, ctx.GetPlace());
}
};
pten::KernelSignature InferShapeUtilsTestOpArgumentMapping(
const pten::ArgumentMappingContext& ctx) {
return pten::KernelSignature(
"infer_shape_utils_test", {},
{"bool", "int", "int64", "float", "string", "vec_bool", "vec_int",
"vec_int64", "vec_float", "vec_double", "vec_str"},
{});
}
} // namespace framework
} // namespace paddle
DELCARE_INFER_SHAPE_FUNCTOR(infer_shape_utils_test,
InferShapeUtilsTestInferShapeFunctor,
PT_INFER_META(paddle::framework::TestInferMeta));
REGISTER_OPERATOR(infer_shape_utils_test,
paddle::framework::InferShapeUtilsTestOp,
paddle::framework::InferShapeUtilsTestOpMaker,
InferShapeUtilsTestInferShapeFunctor);
TEST(InferShapeUtilsTest, ALL) {
paddle::framework::ProgramDesc prog;
paddle::framework::proto::BlockDesc proto_block;
paddle::framework::BlockDesc block_desc(&prog, &proto_block);
auto* op = block_desc.AppendOp();
op->SetType("infer_shape_utils_test");
paddle::framework::Attribute bool_attr(true);
op->SetAttr("bool", bool_attr);
paddle::framework::Attribute int_attr(10);
op->SetAttr("int", int_attr);
int64_t int64_val = 100;
paddle::framework::Attribute int64_attr(int64_val);
op->SetAttr("int64", int64_attr);
float float_value = 3.14;
paddle::framework::Attribute float_attr(float_value);
op->SetAttr("float", float_attr);
std::string str_value("test");
paddle::framework::Attribute str_attr(str_value);
op->SetAttr("string", str_attr);
std::vector<bool> vec_bool(2, true);
paddle::framework::Attribute vec_bool_attr = vec_bool;
op->SetAttr("vec_bool", vec_bool_attr);
std::vector<int> vec_int(2, 10);
paddle::framework::Attribute vec_int_attr = vec_int;
op->SetAttr("vec_int", vec_int_attr);
std::vector<int64_t> vec_int64(2, 100);
paddle::framework::Attribute vec_int64_attr = vec_int64;
op->SetAttr("vec_int64", vec_int64_attr);
std::cout << "after set vec_int64" << std::endl;
std::vector<float> vec_float(2, 3.14);
paddle::framework::Attribute vec_float_attr = vec_float;
op->SetAttr("vec_float", vec_float_attr);
std::vector<double> vec_double(2, 3.1415);
paddle::framework::Attribute vec_double_attr = vec_double;
op->SetAttr("vec_double", vec_double_attr);
std::vector<std::string> vec_str(2, "test_vec");
paddle::framework::Attribute vec_str_attr = vec_str;
op->SetAttr("vec_str", vec_str_attr);
pten::OpUtilsMap::Instance().InsertArgumentMappingFn(
"infer_shape_utils_test",
paddle::framework::InferShapeUtilsTestOpArgumentMapping);
op->InferShape(block_desc);
}
...@@ -15,6 +15,8 @@ limitations under the License. */ ...@@ -15,6 +15,8 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include <typeindex>
#include <typeinfo>
#include <utility> #include <utility>
#include "paddle/pten/common/scalar.h" #include "paddle/pten/common/scalar.h"
...@@ -55,9 +57,12 @@ class InferMetaContext { ...@@ -55,9 +57,12 @@ class InferMetaContext {
AttrType AttrAt(size_t idx) { AttrType AttrAt(size_t idx) {
try { try {
return paddle::any_cast<AttrType>(attrs_.at(idx)); return paddle::any_cast<AttrType>(attrs_.at(idx));
} catch (paddle::bad_any_cast&) { } catch (paddle::bad_any_cast& e) {
PADDLE_THROW(pten::errors::InvalidArgument( PADDLE_THROW(pten::errors::InvalidArgument(
"Attribute cast error in InferMeta Context.")); "Attribute cast error in InferMeta Context, the expected attribute "
"type is `%s`, but actual attribute type is `%s`.",
std::type_index(typeid(AttrType)).name(),
std::type_index(attrs_.at(idx).type()).name()));
} }
} }
...@@ -151,10 +156,15 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> { ...@@ -151,10 +156,15 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(float); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(float);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(double); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::string&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<bool>&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<int>&); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE( PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(
const std::vector<int64_t>&); const std::vector<int64_t>&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<float>&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<double>&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(
const std::vector<std::string>&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(Backend); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(Backend);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册