未验证 提交 d06b9ba4 编写于 作者: H hong 提交者: GitHub

[IR]add infer_shape interface (#54250)

* add infer_shape interface

* update
上级 455a6735
......@@ -35,6 +35,10 @@ H_FILE_TEMPLATE = """#ifdef GET_OP_LIST
#include "paddle/fluid/dialect/utils.h"
#include "paddle/fluid/dialect/pd_interface.h"
#include "paddle/fluid/interface/infershape.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
{input}
#endif
"""
......@@ -52,6 +56,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{
static OpInfoTuple GetOpInfo();
static void verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes);
{get_inputs_and_outputs}
{exclusive_interface}
}};
"""
op_0_attribute_declare_str = (
......@@ -77,6 +82,14 @@ CC_FILE_TEMPLATE = """#include "{h_file}"
#include "paddle/ir/core/ir_context.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/core/infermeta_utils.h"
{input}
"""
......@@ -217,6 +230,12 @@ ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attrib
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
}}
"""
OP_INFER_SHAPE_TEMPLATE = """
void {op_name}::InferShape( phi::InferMetaContext *infer_meta ) {{
auto fn = PD_INFER_META(phi::{infer_meta_func});
fn(infer_meta);
}}
"""
def to_phi_and_fluid_op_name(op_item):
......@@ -280,6 +299,11 @@ class OpInfoParser:
self.attribute_data_type_list = self.parse_attribute_data_type_list()
self.cross_check(self.attribute_name_list, self.attribute_type_list)
if 'infer_meta' in self.op_yaml_item:
self.infer_shape_func = self.op_yaml_item['infer_meta']["func"]
else:
self.infer_shape_func = None
def cross_check(self, name_list, type_list, optional_list=None):
assert len(name_list) == len(
type_list
......@@ -496,6 +520,13 @@ def OpGenerator(
op_interfaces = ["GetOpInfoInterface"]
op_traits = []
exclusive_interface_str = ""
if op_info.infer_shape_func:
op_interfaces += ["InferShapeInterface"]
exclusive_interface_str += (
" static void InferShape( phi::InferMetaContext *infer_meta );"
)
# If op has inplace info, we will generate inplace op and non-inplace op.
for op_name in op_info.op_phi_name:
op_class_name = to_pascal_case(op_name) + "Op"
......@@ -531,6 +562,7 @@ def OpGenerator(
attribute_declare=op_0_attribute_declare_str,
attribute_num=0,
get_inputs_and_outputs=op_get_inputs_outputs_str,
exclusive_interface=exclusive_interface_str,
)
op_defined_str = ""
else:
......@@ -544,6 +576,7 @@ def OpGenerator(
),
attribute_num=len(op_attribute_name_list),
get_inputs_and_outputs=op_get_inputs_outputs_str,
exclusive_interface=exclusive_interface_str,
)
attribute_names_str = (
'"' + '", "'.join(op_attribute_name_list) + '"'
......@@ -712,11 +745,19 @@ def OpGenerator(
attributes_check=attributes_check_str,
)
op_infer_shape_str = ""
if op_info.infer_shape_func:
op_infer_shape_str = OP_INFER_SHAPE_TEMPLATE.format(
op_name=op_class_name,
infer_meta_func=op_info.infer_shape_func,
)
ops_name_list.append(op_class_name)
ops_declare_list.append(op_declare_str)
ops_defined_list.append(op_defined_str)
ops_defined_list.append(op_info_func_str)
ops_defined_list.append(op_verify_str)
ops_defined_list.append(op_infer_shape_str)
# (4) Generate head file str
op_namespaces_prev = ""
......
# All source files of pd_dialect, except for the source file of op, which is generated in the compilation directory.
file(GLOB PD_INTERFACE_SRCS "*.cc")
cc_library(
pd_interface
SRCS ${PD_INTERFACE_SRCS}
DEPS new_ir framework_proto dense_tensor phi_utils)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/op_base.h"
#include "paddle/phi/core/infermeta_utils.h"
class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> {
public:
struct Concept {
explicit Concept(void (*infer_shape)(ir::Operation *,
phi::InferMetaContext *))
: infer_shape_(infer_shape) {}
void (*infer_shape_)(ir::Operation *, phi::InferMetaContext *);
};
template <class ConcreteOp>
struct Model : public Concept {
static void InferShape(ir::Operation *op,
phi::InferMetaContext *infer_meta) {
ConcreteOp concret_op = op->dyn_cast<ConcreteOp>();
if (concret_op == nullptr) throw("concret_op is nullptr");
concret_op.InferShape(infer_meta);
}
Model() : Concept(InferShape) {}
};
InferShapeInterface(ir::Operation *op, Concept *impl)
: ir::OpInterfaceBase<InferShapeInterface>(op), impl_(impl) {}
void InferShape(phi::InferMetaContext *infer_meta) {
impl_->infer_shape_(operation(), infer_meta);
}
private:
Concept *impl_;
};
......@@ -12,6 +12,16 @@ cc_test_old(
phi
gtest)
cc_test_old(
ir_infershape_test
SRCS
ir_infershape_test.cc
DEPS
new_ir
pd_dialect
phi
gtest)
file(
DOWNLOAD
https://paddle-ci.gz.bcebos.com/ir_translator_test/restnet50_main.prog
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/ir/core/region.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/fluid/interface/infershape.h"
#include "paddle/fluid/platform/init.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/nullary.h"
// Define op
class OperationTest : public ir::Op<OperationTest, InferShapeInterface> {
public:
using Op::Op;
static const char *name() { return "test.operation2"; }
static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num];
static void verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {}
static void InferShape(phi::InferMetaContext *infer_meta) {
auto fn = PD_INFER_META(phi::CreateInferMeta);
fn(infer_meta);
}
};
const char *OperationTest::attributes_name[attributes_num] = {"op2_attr1",
"op2_attr2"};
// Define a dialect, op1 and op2 will be registered by this dialect.
class TestDialect : public ir::Dialect {
public:
explicit TestDialect(ir::IrContext *context)
: ir::Dialect(name(), context, ir::TypeId::get<TestDialect>()) {
initialize();
}
static const char *name() { return "test"; }
private:
void initialize() { RegisterOps<OperationTest>(); }
};
TEST(infershape_test, infershape_test) {
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Dialect *test_dialect = ctx->GetOrRegisterDialect<TestDialect>();
EXPECT_EQ(test_dialect != nullptr, true);
// (2) Get registered operations.
std::string op_name = OperationTest::name();
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
std::vector<ir::OpResult> op_inputs = {};
std::vector<ir::Type> op_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op =
ir::Operation::create(op_inputs, {}, op_output_types, op_info);
InferShapeInterface interface = op->dyn_cast<InferShapeInterface>();
phi::InferMetaContext infer_meta_ctx;
infer_meta_ctx.EmplaceBackAttr(phi::IntArray({5, 6}));
infer_meta_ctx.EmplaceBackAttr(phi::DataType::FLOAT32);
phi::DenseTensor tensor;
infer_meta_ctx.EmplaceBackOutput(phi::MetaTensor(&tensor));
interface.InferShape(&infer_meta_ctx);
EXPECT_EQ(tensor.dims().size(), 2);
EXPECT_EQ(tensor.dims()[0], 5);
EXPECT_EQ(tensor.dims()[1], 6);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册