diff --git a/.gitignore b/.gitignore
index a9db6cfcf10274a5a8ac7e3d0da63b815bcce0fd..6195e2473081ff5d46eb8876946f6fb11defaef8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -96,3 +96,4 @@ paddle/phi/api/profiler/__init__.py
 python/paddle/incubate/fleet/parameter_server/pslib/ps_pb2.py
 paddle/phi/kernels/fusion/cutlass/conv2d/generated/*
 python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py
+paddle/fluid/translator/op_compat_info.cc
diff --git a/paddle/fluid/dialect/legacy_pd_op.h b/paddle/fluid/dialect/legacy_pd_op.h
index 21be24720dc3fae58f99abdb0cf73c4356d5a144..3d06cd4d64c1fa590a4d655cd0799bab5bdd8b89 100644
--- a/paddle/fluid/dialect/legacy_pd_op.h
+++ b/paddle/fluid/dialect/legacy_pd_op.h
@@ -79,6 +79,19 @@ REIGSTER_EMPTY_OP(batch_norm_grad,
 REIGSTER_EMPTY_OP(conv2d_grad, Conv2DGradOp);  // To be customized: conv2d_grad
 REIGSTER_EMPTY_OP(sum, SumOp);  // To be customized: sum(reduce_sum)
 REIGSTER_EMPTY_OP(fetch_v2, FetchV2Op);  // To be customized: fetch_v2
+REIGSTER_EMPTY_OP(add, AddOp);
+REIGSTER_EMPTY_OP(add_grad, AddGradOp);
+REIGSTER_EMPTY_OP(matmul, MatMulOp);
+REIGSTER_EMPTY_OP(matmul_grad, MatMulGradOp);
+REIGSTER_EMPTY_OP(reshape, ReshapeOp);
+REIGSTER_EMPTY_OP(reshape_grad, ReshapeGradOp);
+REIGSTER_EMPTY_OP(mean, MeanOp);
+REIGSTER_EMPTY_OP(cross_entropy_with_softmax, CrossEntropyOp);
+REIGSTER_EMPTY_OP(cross_entropy_with_softmax_grad, CrossEntropyGradOp);
+REIGSTER_EMPTY_OP(topk, TopKOp);
+REIGSTER_EMPTY_OP(topk_grad, TopKGradOp);
+REIGSTER_EMPTY_OP(full, FullOp);
+REIGSTER_EMPTY_OP(add_n, AddNOp);
 
 }  // namespace dialect
 }  // namespace paddle
diff --git a/paddle/fluid/dialect/pd_dialect.cc b/paddle/fluid/dialect/pd_dialect.cc
index 9baeb4c1f9b1d0bca41bb877c752d851bbf768ec..a11dd5fb6da7747861c2a18780dad66144e824bc 100644
--- a/paddle/fluid/dialect/pd_dialect.cc
+++ b/paddle/fluid/dialect/pd_dialect.cc
@@ -133,7 +133,20 @@ void PaddleDialect::initialize() {
               BatchNormGradOp,
               Conv2DGradOp,
               SumOp,
-              FetchV2Op>();
+              FetchV2Op,
+              AddOp,
+              MatMulOp,
+              ReshapeOp,
+              CrossEntropyOp,
+              TopKOp,
+              FullOp,
+              MeanOp,
+              AddNOp,
+              AddGradOp,
+              MatMulGradOp,
+              ReshapeGradOp,
+              CrossEntropyGradOp,
+              TopKGradOp>();
 }
 
 void PaddleDialect::PrintType(ir::Type type, std::ostream &os) {
diff --git a/paddle/fluid/translator/CMakeLists.txt b/paddle/fluid/translator/CMakeLists.txt
index a443c06674648e8ae21f3e2a751a1ff2e5a5709f..2ffd12be5c80b046cc76448c8b942b3b8d6ca918 100644
--- a/paddle/fluid/translator/CMakeLists.txt
+++ b/paddle/fluid/translator/CMakeLists.txt
@@ -2,9 +2,20 @@ set(PD_PROGRAM_TRANSLATOR_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}")
 set(PD_PROGRAM_TRANSLATOR_BINARY_DIR
     "${PADDLE_BINARY_DIR}/paddle/fluid/translator")
 
+set(op_gen_file ${PD_PROGRAM_TRANSLATOR_SOURCE_DIR}/op_compat_gen.py)
+set(op_compat_yaml_file ${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml)
+set(op_compat_source_file ${PD_PROGRAM_TRANSLATOR_SOURCE_DIR}/op_compat_info.cc)
+
+add_custom_command(
+  OUTPUT ${op_compat_source_file}
+  COMMAND ${PYTHON_EXECUTABLE} ${op_gen_file} --op_compat_yaml_file
+          ${op_compat_yaml_file} --output_source_file ${op_compat_source_file}
+  DEPENDS ${op_gen_file} ${op_compat_yaml_file}
+  VERBATIM)
+
 file(GLOB PD_PROGRAM_TRANSLATOR_SRCS "*.cc")
 
 cc_library(
   program_translator
-  SRCS ${PD_PROGRAM_TRANSLATOR_SRCS}
+  SRCS ${PD_PROGRAM_TRANSLATOR_SRCS} ${op_compat_source_file}
   DEPS proto_desc pd_dialect new_ir framework_proto)
diff --git a/paddle/fluid/translator/op_compat_gen.py b/paddle/fluid/translator/op_compat_gen.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6aeeeaf8ee4bf9f206b075ada81296d8c181e64
--- /dev/null
+++ b/paddle/fluid/translator/op_compat_gen.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+from pathlib import Path
+
+import yaml
+from jinja2 import Environment, FileSystemLoader, StrictUndefined
+
+file_loader = FileSystemLoader(Path(__file__).parent)
+env = Environment(
+    loader=file_loader,
+    keep_trailing_newline=True,
+    trim_blocks=True,
+    lstrip_blocks=True,
+    undefined=StrictUndefined,
+    extensions=['jinja2.ext.do'],
+)
+
+
+def OpNameNormalizerInitialization(
+    op_compat_yaml_file: str = "", output_source_file: str = ""
+) -> None:
+    def to_phi_and_fluid_op_name(op_item):
+        # Template: - op : phi_name (fluid_name)
+        names = op_item.split('(')
+        if len(names) == 1:
+            phi_fluid_name = names[0].strip()
+            return phi_fluid_name, phi_fluid_name
+        else:
+            phi_name = names[0].strip()
+            fluid_name = names[1].split(')')[0].strip()
+            return phi_name, fluid_name
+
+    with open(op_compat_yaml_file, "r") as f:
+        op_compat_infos = yaml.safe_load(f)
+    op_name_mappings = {}
+    for op_compat_item in op_compat_infos:
+
+        def insert_new_mappings(op_name_str):
+            normalized_name, legacy_name = to_phi_and_fluid_op_name(op_name_str)
+            if normalized_name == legacy_name:
+                return
+            op_name_mappings[legacy_name] = normalized_name
+
+        insert_new_mappings(op_compat_item["op"])
+        if "backward" in op_compat_item:
+            insert_new_mappings(op_compat_item["backward"])
+    op_name_normailzer_template = env.get_template("op_compat_info.cc.j2")
+    with open(output_source_file, 'wt') as f:
+        op_compat_definition = op_name_normailzer_template.render(
+            op_name_paris=op_name_mappings
+        )
+        f.write(op_compat_definition)
+
+
+# =====================================
+# Script parameter parsing
+# =====================================
+def ParseArguments():
+    parser = argparse.ArgumentParser(
+        description='Generate OP Compatiable info Files By Yaml'
+    )
+    parser.add_argument('--op_compat_yaml_file', type=str)
+    parser.add_argument('--output_source_file', type=str)
+    return parser.parse_args()
+
+
+# =====================================
+# Main
+# =====================================
+if __name__ == "__main__":
+    # parse arguments
+    args = ParseArguments()
+    OpNameNormalizerInitialization(**vars(args))
diff --git a/paddle/fluid/translator/op_compat_info.cc.j2 b/paddle/fluid/translator/op_compat_info.cc.j2
new file mode 100644
index 0000000000000000000000000000000000000000..af42cf9b8abdc391dcdeebb27af5bb42d1cb9c4e
--- /dev/null
+++ b/paddle/fluid/translator/op_compat_info.cc.j2
@@ -0,0 +1,15 @@
+#include "paddle/fluid/translator/op_compat_info.h"
+
+namespace paddle {
+namespace translator {
+
+OpNameNormalizer::OpNameNormalizer() {
+  op_name_mappings = {
+    {% for legacy_name, normalized_name in op_name_paris.items() %}
+    { "{{legacy_name}}", "{{normalized_name}}" },
+    {% endfor %}
+  };
+}
+
+}  // namespace translator
+}  // namespace paddle
diff --git a/paddle/fluid/translator/op_compat_info.h b/paddle/fluid/translator/op_compat_info.h
new file mode 100644
index 0000000000000000000000000000000000000000..86acafe7a0f1a5cb480b8881ba99cc27c41bdc11
--- /dev/null
+++ b/paddle/fluid/translator/op_compat_info.h
@@ -0,0 +1,50 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include <unordered_map>
+
+#include "glog/logging.h"
+
+#pragma once
+
+namespace paddle {
+namespace translator {
+
+class OpNameNormalizer {
+ private:
+  OpNameNormalizer();  // Disallow instantiation outside of the class.
+  std::unordered_map<std::string, std::string> op_name_mappings;
+
+ public:
+  OpNameNormalizer(const OpNameNormalizer&) = delete;
+  OpNameNormalizer& operator=(const OpNameNormalizer&) = delete;
+  OpNameNormalizer(OpNameNormalizer&&) = delete;
+  OpNameNormalizer& operator=(OpNameNormalizer&&) = delete;
+
+  static auto& instance() {
+    static OpNameNormalizer OpNameNormalizer;
+    return OpNameNormalizer;
+  }
+
+  std::string operator[](const std::string& op_type) {
+    if (op_name_mappings.find(op_type) == op_name_mappings.end()) {
+      return op_type;
+    }
+    return op_name_mappings.at(op_type);
+  }
+};
+
+}  // namespace translator
+}  // namespace paddle
diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc
index 65102a4eb0b1928dd22d045d9d76dd8cc1fc7db1..9654f1f09ec558751f1afd3c54151b25d8221749 100644
--- a/paddle/fluid/translator/op_translator.cc
+++ b/paddle/fluid/translator/op_translator.cc
@@ -22,6 +22,7 @@ #include <unordered_map>
 #include <vector>
 
 #include "paddle/fluid/framework/op_desc.h"
+#include "paddle/fluid/translator/op_compat_info.h"
 #include "paddle/fluid/translator/program_translator.h"
 #include "paddle/fluid/translator/type_translator.h"
 #include "paddle/ir/builtin_op.h"
@@ -70,11 +71,19 @@ inline bool IsInplace(const OpDesc& op_desc) {
   return inplace;
 }
 
+inline std::string OpNamecompatibleMapping(std::string op_name) {
+  auto& op_normalizer = OpNameNormalizer::instance();
+  return op_normalizer[op_name];
+}
+
 inline ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) {
-  std::string target_op_name = kTargetDialectPrefix + op_desc.Type();
+  std::string target_op_name =
+      kTargetDialectPrefix + OpNamecompatibleMapping(op_desc.Type());
   if (IsInplace(op_desc)) {
     target_op_name += "_";
   }
+  VLOG(6) << "[op name normalizing: " << op_desc.Type() << " to "
+          << target_op_name;
   auto op_info = ctx->GetRegisteredOpInfo(target_op_name);
   if (!op_info) {
     PADDLE_THROW(platform::errors::PreconditionNotMet(
diff --git a/test/cpp/ir/program_translator_test.cc b/test/cpp/ir/program_translator_test.cc
index 5ae7b0b9b31e7ef8446956145ecfb39c5606a7f0..0035f860c5861b0527316c3a8b4a32da9343694d 100644
--- a/test/cpp/ir/program_translator_test.cc
+++ b/test/cpp/ir/program_translator_test.cc
@@ -49,8 +49,6 @@ ProgramDesc load_from_file(const std::string &file_name) {
 TEST(PaddleDialectTest, Translator) {
   LOG(WARNING) << "TODO";
   // auto p = load_from_file("restnet50_main.prog");
-  // std::cout << p.Size() << std::endl;
-  // EXPECT_EQ(p.Size(), 1u);
 
   // ir::IrContext *ctx = ir::IrContext::Instance();
 
@@ -58,8 +56,8 @@ TEST(PaddleDialectTest, Translator) {
   // ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
 
   // auto program = paddle::TranslateLegacyProgramToProgram(p);
-  // std::list<ir::Operation *> ops = program->ops();
-  // ops.size() = op size in BlockDesc + get_parameter_op + combine op
-  // EXPECT_EQ(ops.size(), p.Block(0).OpSize() + program->parameters_num() +
-  // 20); std::cout << *program << std::endl;
+  // size_t op_size = program->block()->size();
+  // // ops.size() = op size in BlockDesc + get_parameter_op + combine op
+  // EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 20);
+  // VLOG(0) << *program;
 }