未验证 提交 bb3fb69c 编写于 作者: K kangguangli 提交者: GitHub

[NewIR] add method to check_unregistered_ops in `paddle dialect` (#56587)

* check_unregistered_ops

* fix
上级 4e0fc706
......@@ -246,66 +246,6 @@ inline ir::OpResult GetAttributeAsInput(ir::IrContext* ctx,
} // namespace
/// @brief This class is used to translate a OpDesc, it's a functor class and
/// should have no non-static data member, since we expected it's stateless.
struct OpTranscriber {
public:
virtual ~OpTranscriber() = default;
public:
virtual ir::Operation* operator()(ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
ir::Program* program);
public:
virtual ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc);
virtual std::vector<ir::OpResult> GenerateOperationInput(
ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
const std::string& normalized_op_name,
const OpInputInfoList& input_infos,
ir::Program* program);
virtual std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
ir::IrContext* ctx,
const OpDesc& op_desc,
const OpOutputInfoList& output_infos);
virtual void HandleNonexistentAttribute(ir::IrContext*,
ir::AttributeMap* attribute_map,
const OpAttributeInfo& info) {
auto& attribute_translator = AttributeTranslator::instance();
(*attribute_map)[info.name] =
attribute_translator(info.type_name, paddle::framework::Attribute());
}
virtual ir::AttributeMap TranslateOpAttribute(
ir::IrContext* ctx,
const std::string& normalized_op_name,
const OpAttributeInfoList& op_attr_infos,
const OpDesc& op_desc);
virtual void RecordOpResultMapping(ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
ir::Operation* operation,
const OpOutputMapping& arg_to_idx);
public:
virtual InputHandlerFn GetSpecialInputHandlers(
const std::string& input_name) {
return nullptr;
}
virtual AttributeHandlerFn GetSpecialAttributeHandlers(
const std::string& input_name) {
return nullptr;
}
virtual void InsertSliceOperationForInput(ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
const OpInputInfoList& input_infos,
ir::Program* program);
};
ir::OpInfo OpTranscriber::LoopkUpOpInfo(ir::IrContext* ctx,
const OpDesc& op_desc) {
std::string target_op_name =
......@@ -625,6 +565,14 @@ ir::AttributeMap OpTranscriber::TranslateOpAttribute(
return attribute_map;
}
void OpTranscriber::HandleNonexistentAttribute(ir::IrContext*,
ir::AttributeMap* attribute_map,
const OpAttributeInfo& info) {
auto& attribute_translator = AttributeTranslator::instance();
(*attribute_map)[info.name] =
attribute_translator(info.type_name, paddle::framework::Attribute());
}
void OpTranscriber::RecordOpResultMapping(ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
......
......@@ -20,6 +20,7 @@
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h"
#include "paddle/fluid/ir_adaptor/translator/program_translator.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
......@@ -29,6 +30,84 @@
namespace paddle {
namespace translator {
/// @brief This class is used to translate a OpDesc, it's a functor class and
/// should have no non-static data member, since we expected it's stateless.
struct OpTranscriber {
public:
virtual ~OpTranscriber() = default;
public:
using IdxInOp = size_t;
using IdxInVector = size_t;
using ResultIdx = std::tuple<IdxInOp, IdxInVector>;
using OpDesc = paddle::framework::OpDesc;
using OpOutputTypeList = std::vector<ir::Type>;
using OpOutputMapping = std::unordered_map<std::string, ResultIdx>;
using OpInputInfo = dialect::OpInputInfo;
using OpInputInfoList = std::vector<dialect::OpInputInfo>;
using OpAttributeInfo = dialect::OpAttributeInfo;
using OpAttributeInfoList = std::vector<dialect::OpAttributeInfo>;
using OpOutputInfo = dialect::OpOutputInfo;
using OpOutputInfoList = std::vector<dialect::OpOutputInfo>;
using InputHandlerFn = std::function<ir::OpResult(ir::IrContext*,
TranslationContext*,
const OpDesc&,
const std::string&,
const OpInputInfo&,
ir::Program*)>;
using AttributeHandlerFn = std::function<ir::Attribute(
ir::IrContext*, const OpDesc&, const OpAttributeInfo&)>;
public:
virtual ir::Operation* operator()(ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
ir::Program* program);
public:
virtual ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc);
virtual std::vector<ir::OpResult> GenerateOperationInput(
ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
const std::string& normalized_op_name,
const OpInputInfoList& input_infos,
ir::Program* program);
virtual std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
ir::IrContext* ctx,
const OpDesc& op_desc,
const OpOutputInfoList& output_infos);
virtual void HandleNonexistentAttribute(ir::IrContext*,
ir::AttributeMap* attribute_map,
const OpAttributeInfo& info);
virtual ir::AttributeMap TranslateOpAttribute(
ir::IrContext* ctx,
const std::string& normalized_op_name,
const OpAttributeInfoList& op_attr_infos,
const OpDesc& op_desc);
virtual void RecordOpResultMapping(ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
ir::Operation* operation,
const OpOutputMapping& arg_to_idx);
public:
virtual InputHandlerFn GetSpecialInputHandlers(
const std::string& input_name) {
return nullptr;
}
virtual AttributeHandlerFn GetSpecialAttributeHandlers(
const std::string& input_name) {
return nullptr;
}
virtual void InsertSliceOperationForInput(ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
const OpInputInfoList& input_infos,
ir::Program* program);
};
class OpTranslator {
public:
using ResultIdx = size_t;
......@@ -61,6 +140,10 @@ class OpTranslator {
return special_handlers[op_type];
}
}
bool HasSpecialHandler(const std::string& op_type) {
return special_handlers.count(op_type) != 0;
}
};
using OpTranslateFn = OpTranslator::OpTranslateFn;
......
......@@ -16,8 +16,11 @@
#include <unordered_map>
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h"
#include "paddle/fluid/ir_adaptor/translator/op_translator.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/utils.h"
namespace paddle {
......@@ -57,5 +60,37 @@ std::ostream& operator<<(std::ostream& os,
return os;
}
std::vector<std::string> CheckUnregisteredOperationInBlock(
ir::IrContext* ctx, const framework::BlockDesc& block) {
auto& op_translator = OpTranslator::instance();
std::vector<std::string> unregistered_ops;
for (auto op : block.AllOps()) {
if (op_translator.HasSpecialHandler(op->Type())) {
continue;
}
OpTranscriber general_handler;
try {
general_handler.LoopkUpOpInfo(ctx, *op);
} catch (ir::IrNotMetException& e) {
unregistered_ops.push_back(op->Type());
}
}
return unregistered_ops;
}
std::vector<std::string> CheckUnregisteredOperation(
ir::IrContext* ctx, const framework::ProgramDesc& legacy_program) {
ctx->GetOrRegisterDialect<dialect::PaddleDialect>();
std::vector<std::string> unregistered_ops;
for (size_t block_idx = 0; block_idx < legacy_program.Size(); block_idx++) {
const framework::BlockDesc& block = legacy_program.Block(block_idx);
auto ops = CheckUnregisteredOperationInBlock(ctx, block);
unregistered_ops.insert(unregistered_ops.end(), ops.begin(), ops.end());
}
return unregistered_ops;
}
} // namespace translator
} // namespace paddle
......@@ -15,7 +15,9 @@
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/ir_adaptor/translator/program_translator.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
......@@ -34,5 +36,8 @@ ir::Operation* InsertSliceOperationForTarget(
std::ostream& operator<<(std::ostream& os,
const std::vector<std::string>& vec_str);
std::vector<std::string> CheckUnregisteredOperation(
ir::IrContext* ctx, const framework::ProgramDesc& legacy_program);
} // namespace translator
} // namespace paddle
......@@ -31,6 +31,7 @@
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h"
#include "paddle/fluid/ir_adaptor/translator/translate.h"
#include "paddle/fluid/ir_adaptor/translator/utils.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/program.h"
......@@ -444,6 +445,21 @@ void BindUtils(pybind11::module *m) {
print(newir_program)
)DOC");
m->def(
"check_unregistered_ops",
[](const framework::ProgramDesc &legacy_program) {
ir::IrContext *ctx = ir::IrContext::Instance();
return paddle::translator::CheckUnregisteredOperation(ctx,
legacy_program);
},
R"DOC(
Check unregistered operators in paddle dialect.
Args:
legacy_program (ProgramDesc): The Fluid Program that need checked.
Returns:
list[str] : List of unregistered operators in paddle dialect, the name is expressed by origin op name.
)DOC");
}
void BindNewIR(pybind11::module *module) {
......
......@@ -27,6 +27,7 @@ from paddle.fluid.libpaddle.ir import (
set_insertion_point,
reset_insertion_point_to_start,
reset_insertion_point_to_end,
check_unregistered_ops,
) # noqa: F401
from . import core
......
......@@ -16,14 +16,18 @@
#include <chrono>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h"
#include "paddle/fluid/ir_adaptor/translator/translate.h"
#include "paddle/fluid/ir_adaptor/translator/utils.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_context.h"
......@@ -85,3 +89,21 @@ TEST(PaddleDialectTest, StartupProgram) {
program->Print(ss);
EXPECT_GT(ss.str().size(), 0u);
}
TEST(RegisterInfoTest, MainProgram) {
auto p = load_from_file("resnet50_startup.prog");
ir::IrContext *ctx = ir::IrContext::Instance();
auto unregistered_ops =
paddle::translator::CheckUnregisteredOperation(ctx, p);
EXPECT_EQ(unregistered_ops.size(), 0u);
auto new_op = std::unique_ptr<OpDesc>(
new OpDesc("something must not be registered", {}, {}, {}));
auto *block = p.MutableBlock(0);
block->AppendAllocatedOp(std::move(new_op));
unregistered_ops = paddle::translator::CheckUnregisteredOperation(ctx, p);
EXPECT_EQ(unregistered_ops.size(), 1u);
EXPECT_EQ(unregistered_ops[0], "something must not be registered");
}
......@@ -325,5 +325,19 @@ class TestShadowOutputSlice(unittest.TestCase):
l = ir.translate_to_new_ir(main_program.desc)
class TestCheckUnregisteredOp(unittest.TestCase):
def test_program(self):
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
x = paddle.randn((4, 16))
prev_h = paddle.randn((4, 32))
cell = paddle.nn.SimpleRNNCell(16, 32)
y, h = cell(x, prev_h)
ops = ir.check_unregistered_ops(main_program.desc)
assert len(ops) == 0
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册