未验证 提交 64661927 编写于 作者: K kangguangli 提交者: GitHub

[IR] add op name normalizer (#54143)

* add op name normalizer

* disable unittest
上级 04d6afc9
...@@ -96,3 +96,4 @@ paddle/phi/api/profiler/__init__.py ...@@ -96,3 +96,4 @@ paddle/phi/api/profiler/__init__.py
python/paddle/incubate/fleet/parameter_server/pslib/ps_pb2.py python/paddle/incubate/fleet/parameter_server/pslib/ps_pb2.py
paddle/phi/kernels/fusion/cutlass/conv2d/generated/* paddle/phi/kernels/fusion/cutlass/conv2d/generated/*
python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py
paddle/fluid/translator/op_compat_info.cc
...@@ -79,6 +79,19 @@ REIGSTER_EMPTY_OP(batch_norm_grad, ...@@ -79,6 +79,19 @@ REIGSTER_EMPTY_OP(batch_norm_grad,
REIGSTER_EMPTY_OP(conv2d_grad, Conv2DGradOp); // To be customized: conv2d_grad REIGSTER_EMPTY_OP(conv2d_grad, Conv2DGradOp); // To be customized: conv2d_grad
REIGSTER_EMPTY_OP(sum, SumOp); // To be customized: sum(reduce_sum) REIGSTER_EMPTY_OP(sum, SumOp); // To be customized: sum(reduce_sum)
REIGSTER_EMPTY_OP(fetch_v2, FetchV2Op); // To be customized: fetch_v2 REIGSTER_EMPTY_OP(fetch_v2, FetchV2Op); // To be customized: fetch_v2
REIGSTER_EMPTY_OP(add, AddOp);
REIGSTER_EMPTY_OP(add_grad, AddGradOp);
REIGSTER_EMPTY_OP(matmul, MatMulOp);
REIGSTER_EMPTY_OP(matmul_grad, MatMulGradOp);
REIGSTER_EMPTY_OP(reshape, ReshapeOp);
REIGSTER_EMPTY_OP(reshape_grad, ReshapeGradOp);
REIGSTER_EMPTY_OP(mean, MeanOp);
REIGSTER_EMPTY_OP(cross_entropy_with_softmax, CrossEntropyOp);
REIGSTER_EMPTY_OP(cross_entropy_with_softmax_grad, CrossEntropyGradOp);
REIGSTER_EMPTY_OP(topk, TopKOp);
REIGSTER_EMPTY_OP(topk_grad, TopKGradOp);
REIGSTER_EMPTY_OP(full, FullOp);
REIGSTER_EMPTY_OP(add_n, AddNOp);
} // namespace dialect } // namespace dialect
} // namespace paddle } // namespace paddle
...@@ -133,7 +133,20 @@ void PaddleDialect::initialize() { ...@@ -133,7 +133,20 @@ void PaddleDialect::initialize() {
BatchNormGradOp, BatchNormGradOp,
Conv2DGradOp, Conv2DGradOp,
SumOp, SumOp,
FetchV2Op>(); FetchV2Op,
AddOp,
MatMulOp,
ReshapeOp,
CrossEntropyOp,
TopKOp,
FullOp,
MeanOp,
AddNOp,
AddGradOp,
MatMulGradOp,
ReshapeGradOp,
CrossEntropyGradOp,
TopKGradOp>();
} }
void PaddleDialect::PrintType(ir::Type type, std::ostream &os) { void PaddleDialect::PrintType(ir::Type type, std::ostream &os) {
......
...@@ -2,9 +2,20 @@ set(PD_PROGRAM_TRANSLATOR_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}") ...@@ -2,9 +2,20 @@ set(PD_PROGRAM_TRANSLATOR_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}")
set(PD_PROGRAM_TRANSLATOR_BINARY_DIR set(PD_PROGRAM_TRANSLATOR_BINARY_DIR
"${PADDLE_BINARY_DIR}/paddle/fluid/translator") "${PADDLE_BINARY_DIR}/paddle/fluid/translator")
set(op_gen_file ${PD_PROGRAM_TRANSLATOR_SOURCE_DIR}/op_compat_gen.py)
set(op_compat_yaml_file ${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml)
set(op_compat_source_file ${PD_PROGRAM_TRANSLATOR_SOURCE_DIR}/op_compat_info.cc)
add_custom_command(
OUTPUT ${op_compat_source_file}
COMMAND ${PYTHON_EXECUTABLE} ${op_gen_file} --op_compat_yaml_file
${op_compat_yaml_file} --output_source_file ${op_compat_source_file}
DEPENDS ${op_gen_file} ${op_compat_yaml_file}
VERBATIM)
file(GLOB PD_PROGRAM_TRANSLATOR_SRCS "*.cc") file(GLOB PD_PROGRAM_TRANSLATOR_SRCS "*.cc")
cc_library( cc_library(
program_translator program_translator
SRCS ${PD_PROGRAM_TRANSLATOR_SRCS} SRCS ${PD_PROGRAM_TRANSLATOR_SRCS} ${op_compat_source_file}
DEPS proto_desc pd_dialect new_ir framework_proto) DEPS proto_desc pd_dialect new_ir framework_proto)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path
import yaml
from jinja2 import Environment, FileSystemLoader, StrictUndefined
# Jinja2 environment used to render the C++ source template
# (op_compat_info.cc.j2) that lives next to this script.
file_loader = FileSystemLoader(Path(__file__).parent)
env = Environment(
    loader=file_loader,
    keep_trailing_newline=True,
    trim_blocks=True,  # drop the newline right after a {% block %} tag
    lstrip_blocks=True,  # strip leading whitespace before block tags
    undefined=StrictUndefined,  # fail loudly on any undefined template variable
    extensions=['jinja2.ext.do'],
)
def OpNameNormalizerInitialization(
    op_compat_yaml_file: str = "", output_source_file: str = ""
) -> None:
    """Generate the OpNameNormalizer C++ source from op_compat.yaml.

    Collects legacy (fluid) -> normalized (phi) op-name pairs from the
    yaml compatibility file and renders them into ``output_source_file``
    via the ``op_compat_info.cc.j2`` template.

    Args:
        op_compat_yaml_file: path to paddle/phi/api/yaml/op_compat.yaml.
        output_source_file: path of the generated .cc file to write.
    """

    def to_phi_and_fluid_op_name(op_item):
        # Yaml template: - op : phi_name (fluid_name)
        names = op_item.split('(')
        if len(names) == 1:
            # No parenthesized alias: phi name and fluid name coincide.
            phi_fluid_name = names[0].strip()
            return phi_fluid_name, phi_fluid_name
        else:
            phi_name = names[0].strip()
            fluid_name = names[1].split(')')[0].strip()
            return phi_name, fluid_name

    op_name_mappings = {}

    # Hoisted out of the loop below: the original defined this closure on
    # every iteration for no benefit.
    def insert_new_mappings(op_name_str):
        normalized_name, legacy_name = to_phi_and_fluid_op_name(op_name_str)
        # Identity mappings carry no information; skip them.
        if normalized_name == legacy_name:
            return
        op_name_mappings[legacy_name] = normalized_name

    with open(op_compat_yaml_file, "r") as f:
        op_compat_infos = yaml.safe_load(f)
    for op_compat_item in op_compat_infos:
        insert_new_mappings(op_compat_item["op"])
        if "backward" in op_compat_item:
            insert_new_mappings(op_compat_item["backward"])

    op_name_normalizer_template = env.get_template("op_compat_info.cc.j2")
    with open(output_source_file, 'wt') as f:
        # NOTE(review): the keyword 'op_name_paris' (sic) must stay spelled
        # exactly like the variable used inside op_compat_info.cc.j2.
        op_compat_definition = op_name_normalizer_template.render(
            op_name_paris=op_name_mappings
        )
        f.write(op_compat_definition)
# =====================================
# Script parameter parsing
# =====================================
def ParseArguments():
    """Define and evaluate the command-line options of this generator."""
    arg_parser = argparse.ArgumentParser(
        description='Generate OP Compatiable info Files By Yaml'
    )
    # Both flags are plain string options with no default.
    for flag in ('--op_compat_yaml_file', '--output_source_file'):
        arg_parser.add_argument(flag, type=str)
    return arg_parser.parse_args()
# =====================================
# Main
# =====================================
# Entry point: parse CLI flags and emit the generated op-name-mapping source.
if __name__ == "__main__":
    # parse arguments
    args = ParseArguments()
    OpNameNormalizerInitialization(**vars(args))
{# Jinja2 template rendered by op_compat_gen.py. It generates
   op_compat_info.cc: the OpNameNormalizer constructor that fills
   op_name_mappings with legacy (fluid) -> normalized (phi) name pairs.
   NOTE(review): the render variable is spelled `op_name_paris` (sic) and
   must stay in sync with the keyword passed from op_compat_gen.py. #}
#include "paddle/fluid/translator/op_compat_info.h"
namespace paddle {
namespace translator {
OpNameNormalizer::OpNameNormalizer() {
op_name_mappings = {
{% for legacy_name, normalized_name in op_name_paris.items() %}
{ "{{legacy_name}}", "{{normalized_name}}" },
{% endfor %}
};
}
} // namespace translator
}// namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <unordered_map>
#include "glog/logging.h"
#pragma once
namespace paddle {
namespace translator {
// Singleton that maps legacy (fluid) op names to their normalized (phi)
// names; the mapping table itself is filled by the constructor generated
// from op_compat.yaml (see op_compat_info.cc.j2).
class OpNameNormalizer {
 private:
  OpNameNormalizer();  // Disallow instantiation outside of the class.
  // legacy (fluid) op name -> normalized (phi) op name.
  std::unordered_map<std::string, std::string> op_name_mappings;

 public:
  OpNameNormalizer(const OpNameNormalizer&) = delete;
  OpNameNormalizer& operator=(const OpNameNormalizer&) = delete;
  OpNameNormalizer(OpNameNormalizer&&) = delete;
  OpNameNormalizer& operator=(OpNameNormalizer&&) = delete;

  // Meyers singleton: lazily constructed on first use, thread-safe init.
  // (The original named the static local identically to the class, which
  // shadows the type name inside this scope; renamed for clarity.)
  static OpNameNormalizer& instance() {
    static OpNameNormalizer normalizer;
    return normalizer;
  }

  // Returns the normalized name for op_type, or op_type itself when no
  // mapping is registered. Single hash lookup instead of find() + at().
  std::string operator[](const std::string& op_type) const {
    auto it = op_name_mappings.find(op_type);
    if (it == op_name_mappings.end()) {
      return op_type;
    }
    return it->second;
  }
};
} // namespace translator
} // namespace paddle
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/translator/op_compat_info.h"
#include "paddle/fluid/translator/program_translator.h" #include "paddle/fluid/translator/program_translator.h"
#include "paddle/fluid/translator/type_translator.h" #include "paddle/fluid/translator/type_translator.h"
#include "paddle/ir/builtin_op.h" #include "paddle/ir/builtin_op.h"
...@@ -70,11 +71,19 @@ inline bool IsInplace(const OpDesc& op_desc) { ...@@ -70,11 +71,19 @@ inline bool IsInplace(const OpDesc& op_desc) {
return inplace; return inplace;
} }
inline std::string OpNamecompatibleMapping(std::string op_name) {
auto& op_normalizer = OpNameNormalizer::instance();
return op_normalizer[op_name];
}
inline ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) { inline ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) {
std::string target_op_name = kTargetDialectPrefix + op_desc.Type(); std::string target_op_name =
kTargetDialectPrefix + OpNamecompatibleMapping(op_desc.Type());
if (IsInplace(op_desc)) { if (IsInplace(op_desc)) {
target_op_name += "_"; target_op_name += "_";
} }
VLOG(6) << "[op name normalizing: " << op_desc.Type() << " to "
<< target_op_name;
auto op_info = ctx->GetRegisteredOpInfo(target_op_name); auto op_info = ctx->GetRegisteredOpInfo(target_op_name);
if (!op_info) { if (!op_info) {
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(platform::errors::PreconditionNotMet(
......
...@@ -49,8 +49,6 @@ ProgramDesc load_from_file(const std::string &file_name) { ...@@ -49,8 +49,6 @@ ProgramDesc load_from_file(const std::string &file_name) {
TEST(PaddleDialectTest, Translator) { TEST(PaddleDialectTest, Translator) {
LOG(WARNING) << "TODO"; LOG(WARNING) << "TODO";
// auto p = load_from_file("restnet50_main.prog"); // auto p = load_from_file("restnet50_main.prog");
// std::cout << p.Size() << std::endl;
// EXPECT_EQ(p.Size(), 1u); // EXPECT_EQ(p.Size(), 1u);
// ir::IrContext *ctx = ir::IrContext::Instance(); // ir::IrContext *ctx = ir::IrContext::Instance();
...@@ -58,8 +56,8 @@ TEST(PaddleDialectTest, Translator) { ...@@ -58,8 +56,8 @@ TEST(PaddleDialectTest, Translator) {
// ctx->GetOrRegisterDialect<ir::BuiltinDialect>(); // ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
// auto program = paddle::TranslateLegacyProgramToProgram(p); // auto program = paddle::TranslateLegacyProgramToProgram(p);
// std::list<ir::Operation *> ops = program->ops(); // size_t op_size = program->block()->size();
// ops.size() = op size in BlockDesc + get_parameter_op + combine op // // ops.size() = op size in BlockDesc + get_parameter_op + combine op
// EXPECT_EQ(ops.size(), p.Block(0).OpSize() + program->parameters_num() + // EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 20);
// 20); std::cout << *program << std::endl; // VLOG(0) << *program;
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册