From 07223e3474e794514bbcc7a39317fd3e6756e855 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Tue, 23 May 2023 10:29:29 +0800 Subject: [PATCH] [NewIR] Program desc convert to IRProgram (#53707) * Use copy_if_different to avoid recompilation of generated cutlass kernels. * add program parameter dialect_interface * fix op create bug * add conv2d * draft of paddle converter * fix CI * fix windows CI * fix program destructor * printer draft * fix bug * printer draft finish * fix windows CI * reserve inplace semantics * revert program::destroy since no need to do topology sort * revert * modify by reviews * polish * fix op definition * fix CI * refresh file changes --------- Co-authored-by: umiswing Co-authored-by: zhangbo9674 --- paddle/fluid/CMakeLists.txt | 1 + paddle/fluid/translator/CMakeLists.txt | 10 + paddle/fluid/translator/op_translator.cc | 214 ++++++++++++++++++ paddle/fluid/translator/op_translator.h | 70 ++++++ paddle/fluid/translator/program_translator.cc | 91 ++++++++ paddle/fluid/translator/program_translator.h | 53 +++++ paddle/fluid/translator/translate.cc | 40 ++++ paddle/fluid/translator/translate.h | 31 +++ paddle/fluid/translator/type_translator.cc | 60 +++++ paddle/fluid/translator/type_translator.h | 64 ++++++ test/cpp/ir/CMakeLists.txt | 17 ++ test/cpp/ir/program_translator_test.cc | 63 ++++++ 12 files changed, 714 insertions(+) create mode 100644 paddle/fluid/translator/CMakeLists.txt create mode 100644 paddle/fluid/translator/op_translator.cc create mode 100644 paddle/fluid/translator/op_translator.h create mode 100644 paddle/fluid/translator/program_translator.cc create mode 100644 paddle/fluid/translator/program_translator.h create mode 100644 paddle/fluid/translator/translate.cc create mode 100644 paddle/fluid/translator/translate.h create mode 100644 paddle/fluid/translator/type_translator.cc create mode 100644 paddle/fluid/translator/type_translator.h create mode 100644 test/cpp/ir/program_translator_test.cc diff --git a/paddle/fluid/CMakeLists.txt b/paddle/fluid/CMakeLists.txt index 8248dcd3639..c3deb52cb06 100644 --- a/paddle/fluid/CMakeLists.txt +++ b/paddle/fluid/CMakeLists.txt @@ -10,6 +10,7 @@ add_subdirectory(prim) add_subdirectory(jit) if(WITH_NEWIR) add_subdirectory(dialect) + add_subdirectory(translator) endif() # NOTE: please add subdirectory inference at last. add_subdirectory(inference) diff --git a/paddle/fluid/translator/CMakeLists.txt b/paddle/fluid/translator/CMakeLists.txt new file mode 100644 index 00000000000..a443c066746 --- /dev/null +++ b/paddle/fluid/translator/CMakeLists.txt @@ -0,0 +1,10 @@ +set(PD_PROGRAM_TRANSLATOR_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}") +set(PD_PROGRAM_TRANSLATOR_BINARY_DIR + "${PADDLE_BINARY_DIR}/paddle/fluid/translator") + +file(GLOB PD_PROGRAM_TRANSLATOR_SRCS "*.cc") + +cc_library( + program_translator + SRCS ${PD_PROGRAM_TRANSLATOR_SRCS} + DEPS proto_desc pd_dialect new_ir framework_proto) diff --git a/paddle/fluid/translator/op_translator.cc b/paddle/fluid/translator/op_translator.cc new file mode 100644 index 00000000000..c6ff3f94125 --- /dev/null +++ b/paddle/fluid/translator/op_translator.cc @@ -0,0 +1,214 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/translator/op_translator.h" + +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/translator/program_translator.h" +#include "paddle/fluid/translator/type_translator.h" +#include "paddle/ir/ir_context.h" +#include "paddle/ir/value.h" +#include "paddle/phi/core/enforce.h" + +namespace paddle { +namespace translator { + +namespace { + +using ResultIdx = size_t; +using OpDesc = paddle::framework::OpDesc; +using BlockDesc = paddle::framework::BlockDesc; +using VarDesc = paddle::framework::VarDesc; +using OpOutputTypeList = std::vector; +using OpOutputMapping = std::unordered_map; + +static const char kTargetDialectPrefix[] = "pd."; + +inline bool IsInplace(const OpDesc& op_desc) { + bool inplace = false; + auto input_names = op_desc.InputArgumentNames(); + auto output_names = op_desc.OutputArgumentNames(); + + std::vector name_intersection; + std::set_intersection(input_names.begin(), + input_names.end(), + output_names.begin(), + output_names.end(), + std::back_inserter(name_intersection)); + + if (name_intersection.size() > 0) { + std::string redundant_variables = std::accumulate( + std::next(name_intersection.begin()), + name_intersection.end(), + name_intersection[0], + [](std::string a, std::string b) { return a + "," + b; }); + VLOG(4) << "Following variables occur both in inputs and outputs: " + << redundant_variables; + return true; + } + + return inplace; +} + +inline ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) { + std::string target_op_name = kTargetDialectPrefix + op_desc.Type(); + if (IsInplace(op_desc)) { + target_op_name += "_"; + } + auto op_info = ctx->GetRegisteredOpInfo(target_op_name); + if (!op_info) { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "Op %d should have corresponding OpInfo %d", + op_desc.Type(), + target_op_name)); + } + + return op_info; +} + +inline std::vector GenerateOperationInput( + TranslationContext* param_map, const OpDesc& op_desc) { + std::vector op_inputs = {}; + for (const auto& n : op_desc.Inputs()) { + auto& name = n.first; + VLOG(10) << "[input retriving]" + << "[" << op_desc.Type() << "]" << name; + auto& args = n.second; + for (const auto& arg_name : args) { + PADDLE_ENFORCE_NE( + param_map->count(arg_name), + 0, + platform::errors::PreconditionNotMet( + "arg %s as input should be exists before prasing %d", + arg_name, + op_desc.Type())); + op_inputs.push_back((*param_map)[arg_name]); + } + } + return op_inputs; +} + +inline std::tuple GenerateOperationOutput( + ir::IrContext* ctx, const OpDesc& op_desc) { + OpOutputMapping arg_to_idx; + OpOutputTypeList op_output_types = {}; + + auto& type_translator = TypeTranslator::instance(); + + const BlockDesc* block = op_desc.Block(); + for (const auto& n : op_desc.Outputs()) { + auto& name = n.first; + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "]" << name; + auto& args = n.second; + for (const auto& arg_name : args) { + VarDesc* var = block->FindVarRecursive(arg_name); + VLOG(10) << "[output translating]" + << "[" << op_desc.Type() << "]" << name << " " << arg_name << " " + << var->GetType(); + + ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); + + arg_to_idx[arg_name] = op_output_types.size(); + op_output_types.push_back(translated_var_type); + } + } + return {op_output_types, arg_to_idx}; +} + +inline void RecordOpResultMapping(TranslationContext* param_map, + const OpDesc& op_desc, + ir::Operation* operation, + const OpOutputMapping& arg_to_idx) { + for (const auto& n : op_desc.Outputs()) { + auto& name = n.first; + VLOG(10) << "[output recording]" + << "[" << op_desc.Type() << "]" << name; + auto& args = n.second; + for (const auto& arg_name : args) { + auto idx = arg_to_idx.at(arg_name); + VLOG(10) << "[output recording]" + << "[" << op_desc.Type() << "]" << arg_name << " " << idx; + + (*param_map)[arg_name] = operation->GetResultByIndex(idx); + } + } +} + +ir::Operation* GeneralOpHandler(ir::IrContext* ctx, + TranslationContext* param_map, + ir::Program* program, + const OpDesc& op_desc) { + auto op_inputs = GenerateOperationInput(param_map, op_desc); + + OpOutputMapping arg_to_idx; + OpOutputTypeList op_output_types = {}; + std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); + auto op_info = LoopkUpOpInfo(ctx, op_desc); + ir::Operation* operation = + ir::Operation::create(op_inputs, op_output_types, {}, op_info); + program->InsertOp(operation); + RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx); + + return operation; +} + +ir::Operation* FeedOpHandler(ir::IrContext* ctx, + TranslationContext* param_map, + ir::Program* program, + const OpDesc& op_desc) { + std::vector op_inputs = {}; + + OpOutputMapping arg_to_idx; + OpOutputTypeList op_output_types = {}; + std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc); + auto op_info = LoopkUpOpInfo(ctx, op_desc); + ir::Operation* operation = + ir::Operation::create(op_inputs, op_output_types, {}, op_info); + program->InsertOp(operation); + RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx); + + return operation; +} + +ir::Operation* FetchOpHandler(ir::IrContext* ctx, + TranslationContext* param_map, + ir::Program* program, + const OpDesc& op_desc) { + auto op_inputs = GenerateOperationInput(param_map, op_desc); + + OpOutputTypeList op_output_types = {}; + auto op_info = LoopkUpOpInfo(ctx, op_desc); + ir::Operation* operation = + ir::Operation::create(op_inputs, op_output_types, {}, op_info); + program->InsertOp(operation); + + return operation; +} +} // namespace + +OpTranslator::OpTranslator() : general_handler(GeneralOpHandler) { + special_handlers["feed"] = FeedOpHandler; + special_handlers["fetch_v2"] = FetchOpHandler; +} + +} // namespace translator +} // namespace paddle diff --git a/paddle/fluid/translator/op_translator.h b/paddle/fluid/translator/op_translator.h new file mode 100644 index 00000000000..c767f639d53 --- /dev/null +++ b/paddle/fluid/translator/op_translator.h @@ -0,0 +1,70 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/var_desc.h" +#include "paddle/ir/ir_context.h" +#include "paddle/ir/operation.h" +#include "paddle/ir/program.h" +#include "paddle/ir/value.h" + +namespace paddle { +namespace translator { + +using TranslationContext = std::unordered_map; + +class OpTranslator { + public: + using ResultIdx = size_t; + using OpDesc = paddle::framework::OpDesc; + using BlockDesc = paddle::framework::BlockDesc; + using VarDesc = paddle::framework::VarDesc; + using OpTranslateFn = std::function; + + private: + OpTranslator(); // Disallow instantiation outside of the class. + std::unordered_map special_handlers; + OpTranslateFn general_handler; + + public: + OpTranslator(const OpTranslator&) = delete; + OpTranslator& operator=(const OpTranslator&) = delete; + OpTranslator(OpTranslator&&) = delete; + OpTranslator& operator=(OpTranslator&&) = delete; + + static auto& instance() { + static OpTranslator OpTranslator; + return OpTranslator; + } + + OpTranslateFn& operator[](const std::string& op_type) { + if (special_handlers.count(op_type) == 0) { + return general_handler; + } else { + return special_handlers[op_type]; + } + } +}; + +using OpTranslateFn = OpTranslator::OpTranslateFn; + +} // namespace translator +} // namespace paddle diff --git a/paddle/fluid/translator/program_translator.cc b/paddle/fluid/translator/program_translator.cc new file mode 100644 index 00000000000..7cdbef58906 --- /dev/null +++ b/paddle/fluid/translator/program_translator.cc @@ -0,0 +1,91 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/translator/program_translator.h" + +#include + +#include "glog/logging.h" + +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/translator/op_translator.h" +#include "paddle/ir/attribute.h" +#include "paddle/ir/builtin_op.h" +#include "paddle/ir/builtin_type.h" +#include "paddle/ir/operation.h" +#include "paddle/phi/core/enforce.h" + +namespace paddle { +namespace translator { + +using ProgramDesc = ::paddle::framework::ProgramDesc; +using BlockDesc = ::paddle::framework::BlockDesc; + +ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program, + ir::Program* program) + : legacy_program(legacy_program), program(program) { + ctx = ir::IrContext::Instance(); +} + +void ProgramTranslator::Translate() { + PADDLE_ENFORCE_EQ( + legacy_program->Size(), + 1u, + platform::errors::PreconditionNotMet( + "Not support multi block ProgramDesc translated, now has %d blocks", + legacy_program->Size())); + + for (size_t block_idx = 0; block_idx < legacy_program->Size(); block_idx++) { + const BlockDesc& block = legacy_program->Block(block_idx); + ExtractParameterFromSingleBlock(block); + } + + for (size_t block_idx = 0; block_idx < legacy_program->Size(); block_idx++) { + const BlockDesc& block = legacy_program->Block(block_idx); + InsertOperationToSingleBlock(block); + } +} + +void ProgramTranslator::ExtractParameterFromSingleBlock( + const BlockDesc& block) { + for (auto& var : block.AllVars()) { + if (!var->Persistable()) continue; + if (param_map.count(var->Name()) != 0) continue; + + std::string get_parameter_op_name(ir::GetParameterOp::name()); + ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name); + std::unordered_map op_attribute_map = { + {var->Name(), ir::StrAttribute::get(ctx, var->Name())}, + }; + ir::Operation* operation = ir::Operation::create( + {}, {ir::Float32Type::get(ctx)}, op_attribute_map, op_info); + program->InsertOp(operation); + param_map[var->Name()] = operation->GetResultByIndex(0); + VLOG(10) << "[op translated][get parameter]" << operation; + + program->SetParameter(var->Name(), nullptr); + } +} + +void ProgramTranslator::InsertOperationToSingleBlock(const BlockDesc& block) { + auto& op_translator = OpTranslator::instance(); + for (auto op : block.AllOps()) { + OpTranslateFn& fn = op_translator[op->Type()]; + ir::Operation* operation = fn(ctx, ¶m_map, program, *op); + VLOG(10) << "[op translated][special]" << operation; + } +} + +} // namespace translator +} // namespace paddle diff --git a/paddle/fluid/translator/program_translator.h b/paddle/fluid/translator/program_translator.h new file mode 100644 index 00000000000..569b93b06aa --- /dev/null +++ b/paddle/fluid/translator/program_translator.h @@ -0,0 +1,53 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/ir/ir_context.h" +#include "paddle/ir/program.h" +#include "paddle/ir/value.h" + +namespace paddle { +namespace translator { + +using TranslationContext = std::unordered_map; + +class ProgramTranslator { + using ProgramDesc = ::paddle::framework::ProgramDesc; + using BlockDesc = ::paddle::framework::BlockDesc; + + public: + explicit ProgramTranslator(const ProgramDesc* legacy_program, + ir::Program* program); + + void Translate(); + + private: + const ProgramDesc* legacy_program; + ir::Program* program; + TranslationContext param_map; + ir::IrContext* ctx; + + void ExtractParameterFromSingleBlock(const BlockDesc& block); + void InsertOperationToSingleBlock(const BlockDesc& block); +}; + +} // namespace translator +} // namespace paddle diff --git a/paddle/fluid/translator/translate.cc b/paddle/fluid/translator/translate.cc new file mode 100644 index 00000000000..40af6ce5394 --- /dev/null +++ b/paddle/fluid/translator/translate.cc @@ -0,0 +1,40 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/translator/translate.h" + +#include + +#include "paddle/fluid/dialect/pd_dialect.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/translator/program_translator.h" +#include "paddle/ir/program.h" + +namespace paddle { + +using LegacyProgramDesc = ::paddle::framework::ProgramDesc; +using Program = ::ir::Program; + +std::unique_ptr TranslateLegacyProgramToProgram( + const LegacyProgramDesc& legacy_program) { + auto program = std::make_unique(); + + translator::ProgramTranslator program_translator(&legacy_program, + program.get()); + program_translator.Translate(); + + return program; +} + +} // namespace paddle diff --git a/paddle/fluid/translator/translate.h b/paddle/fluid/translator/translate.h new file mode 100644 index 00000000000..aa2571f74c4 --- /dev/null +++ b/paddle/fluid/translator/translate.h @@ -0,0 +1,31 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/dialect/pd_dialect.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/ir/program.h" + +namespace paddle { + +using LegacyProgramDesc = ::paddle::framework::ProgramDesc; +using Program = ::ir::Program; + +std::unique_ptr TranslateLegacyProgramToProgram( + const LegacyProgramDesc& legacy_program); + +} // namespace paddle diff --git a/paddle/fluid/translator/type_translator.cc b/paddle/fluid/translator/type_translator.cc new file mode 100644 index 00000000000..9792a4b8537 --- /dev/null +++ b/paddle/fluid/translator/type_translator.cc @@ -0,0 +1,60 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/translator/type_translator.h" + +#include "paddle/fluid/dialect/pd_type.h" +#include "paddle/fluid/dialect/pd_type_storage.h" +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/ir/builtin_type.h" + +namespace paddle { +namespace translator { + +using DenseTensorType = paddle::dialect::DenseTensorType; +using DenseTensorTypeStorage = paddle::dialect::DenseTensorTypeStorage; + +TypeTranslator::TypeTranslator() { + handlers = { + {VarType::INT64, + [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type { + return ir::Int64Type::get(ctx); + }}, + {VarType::FP32, + [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type { + return ir::Float32Type::get(ctx); + }}, + {VarType::FP64, + [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type { + return ir::Float64Type::get(ctx); + }}, + {VarType::LOD_TENSOR, + [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type { + VLOG(10) << "[vartype translating]" + << "[" << var_desc.Name() << "]" << var_desc.GetDataType(); + + ir::Type dtype = + this->operator[](var_desc.GetDataType())(ctx, var_desc); + DenseTensorTypeStorage::Dim dim = var_desc.GetShape(); + DenseTensorTypeStorage::DataLayout layout = + DenseTensorTypeStorage::DataLayout::UNDEFINED; + DenseTensorTypeStorage::LoD lod = {}; + size_t offset = 0; + return DenseTensorType::get(ctx, dtype, dim, layout, lod, offset); + }}, + }; +} + +} // namespace translator +} // namespace paddle diff --git a/paddle/fluid/translator/type_translator.h b/paddle/fluid/translator/type_translator.h new file mode 100644 index 00000000000..b16c1a222a5 --- /dev/null +++ b/paddle/fluid/translator/type_translator.h @@ -0,0 +1,64 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/var_desc.h" +#include "paddle/fluid/translator/program_translator.h" +#include "paddle/ir/builtin_type.h" +#include "paddle/ir/dialect.h" +#include "paddle/ir/ir_context.h" + +namespace paddle { +namespace translator { + +using OpDesc = paddle::framework::OpDesc; +using BlockDesc = paddle::framework::BlockDesc; +using VarDesc = paddle::framework::VarDesc; +using VarType = paddle::framework::proto::VarType; +using TypeTranslateFn = std::function; + +class TypeTranslator { + private: + TypeTranslator(); // Disallow instantiation outside of the class. + std::unordered_map handlers; + + public: + TypeTranslator(const TypeTranslator&) = delete; + TypeTranslator& operator=(const TypeTranslator&) = delete; + TypeTranslator(TypeTranslator&&) = delete; + TypeTranslator& operator=(TypeTranslator&&) = delete; + + static auto& instance() { + static TypeTranslator TypeTranslator; + return TypeTranslator; + } + + TypeTranslateFn& operator[](VarType::Type type) { + PADDLE_ENFORCE_NE( + handlers.count(type), + 0, + platform::errors::PreconditionNotMet( + "ProtoType %d has no corresponding translator", type)); + + return handlers[type]; + } +}; + +} // namespace translator +} // namespace paddle diff --git a/test/cpp/ir/CMakeLists.txt b/test/cpp/ir/CMakeLists.txt index 0c5385a32fb..9dc8d85f8dc 100644 --- a/test/cpp/ir/CMakeLists.txt +++ b/test/cpp/ir/CMakeLists.txt @@ -12,4 +12,21 @@ if(WITH_NEWIR) pd_dialect phi gtest) + + file( + DOWNLOAD + https://paddle-ci.gz.bcebos.com/ir_translator_test/restnet50_main.prog + ${CMAKE_CURRENT_BINARY_DIR}/restnet50_main.prog + EXPECTED_MD5 b64c0ad3c96d99fc37d12094623ce1ad) + + cc_test_old( + program_translator_test + SRCS + program_translator_test.cc + DEPS + program_translator + gtest + new_ir + pd_dialect) + endif() diff --git a/test/cpp/ir/program_translator_test.cc b/test/cpp/ir/program_translator_test.cc new file mode 100644 index 00000000000..e88bbb2cf39 --- /dev/null +++ b/test/cpp/ir/program_translator_test.cc @@ -0,0 +1,63 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "paddle/fluid/dialect/pd_dialect.h" +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/translator/translate.h" +#include "paddle/ir/builtin_dialect.h" +#include "paddle/ir/dialect.h" +#include "paddle/ir/ir_context.h" +#include "paddle/ir/program.h" + +using PaddleDialect = paddle::dialect::PaddleDialect; +using ProgramDesc = paddle::framework::ProgramDesc; +using BlockDesc = paddle::framework::BlockDesc; +using OpDesc = paddle::framework::OpDesc; +using VarDesc = paddle::framework::VarDesc; +using VarType = paddle::framework::proto::VarType; + +ProgramDesc load_from_file(const std::string &file_name) { + std::ifstream fin(file_name, std::ios::in | std::ios::binary); + fin.seekg(0, std::ios::end); + + std::string buffer(fin.tellg(), ' '); + fin.seekg(0, std::ios::beg); + fin.read(&buffer[0], buffer.size()); + fin.close(); + return ProgramDesc(buffer); +} + +TEST(PaddleDialectTest, Translator) { + auto p = load_from_file("restnet50_main.prog"); + std::cout << p.Size() << std::endl; + + EXPECT_EQ(p.Size(), 1u); + + ir::IrContext *ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + auto program = paddle::TranslateLegacyProgramToProgram(p); + + std::list ops = program->ops(); + EXPECT_EQ(ops.size(), p.Block(0).OpSize() + program->parameters_num()); + VLOG(0) << *program << std::endl; +} -- GitLab