diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc
index aa650d60ac20af390a2d1ce7200c9dfd5b027527..7552ee29669b80b27a1ccae419fb5cac81c96ef3 100644
--- a/paddle/fluid/ir_adaptor/translator/op_translator.cc
+++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -31,6 +31,7 @@
 #include "paddle/fluid/ir_adaptor/translator/op_compat_info.h"
 #include "paddle/fluid/ir_adaptor/translator/program_translator.h"
 #include "paddle/fluid/ir_adaptor/translator/type_translator.h"
+#include "paddle/fluid/ir_adaptor/translator/utils.h"
 #include "paddle/ir/core/builder.h"
 #include "paddle/ir/core/builtin_op.h"
 #include "paddle/ir/core/builtin_type.h"
@@ -124,30 +125,6 @@ inline std::string OpNameCompatibleMapping(std::string op_name) {
   return op_normalizer[op_name];
 }
 
-inline ir::Operation* InsertSliceOperationForTarget(
-    ir::IrContext* ctx,
-    TranslationContext* param_map,
-    ir::Program* program,
-    const VariableDefiningInfo& defining_info,
-    const std::string& arg_name) {
-  std::string slice_op_name(ir::SliceOp::name());
-  ir::OpInfo op_info = ctx->GetRegisteredOpInfo(slice_op_name);
-  std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
-      {"index", ir::Int32Attribute::get(ctx, defining_info.idx_in_vector)},
-  };
-  ir::VectorType src_vec_type =
-      defining_info.value.type().dyn_cast<ir::VectorType>();
-  ir::Operation* operation =
-      ir::Operation::Create({defining_info.value},
-                            op_attribute_map,
-                            {src_vec_type[defining_info.idx_in_vector]},
-                            op_info);
-  program->block()->push_back(operation);
-  ir::OpResult target_op_result = operation->result(0);
-  (*param_map)[arg_name] = VariableDefiningInfo(target_op_result);
-  return operation;
-}
-
 inline ir::Operation* InsertCombineOperationForTarget(
     ir::IrContext* ctx,
     TranslationContext* param_map,
@@ -307,6 +284,11 @@ struct OpTranscriber {
       const std::string& input_name) {
     return nullptr;
   }
+  virtual void InsertSliceOperationForInput(ir::IrContext* ctx,
+                                            TranslationContext* param_map,
+                                            const OpDesc& op_desc,
+                                            const OpInputInfoList& input_infos,
+                                            ir::Program* program);
 };
 
 ir::OpInfo OpTranscriber::LoopkUpOpInfo(ir::IrContext* ctx,
@@ -328,26 +310,20 @@ ir::OpInfo OpTranscriber::LoopkUpOpInfo(ir::IrContext* ctx,
 
   return op_info;
 }
 
-std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
+void OpTranscriber::InsertSliceOperationForInput(
     ir::IrContext* ctx,
     TranslationContext* param_map,
     const OpDesc& op_desc,
-    const std::string& normalized_op_name,
    const OpInputInfoList& input_infos,
     ir::Program* program) {
-  VLOG(10) << "[op:" << op_desc.Type() << "][input] entrance";
-
   auto& op_normalizer = OpNameNormalizer::instance();
-  const auto* mutable_attributes =
-      op_normalizer.GetMutableAttributes(op_desc.Type());
 
-  std::set<std::string> yaml_input_set;
   for (const auto& info : input_infos) {
     std::string legacy_input_name =
         op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);
-    yaml_input_set.insert(legacy_input_name);
   }
+
   // scan all inputs to see if any of them is generated as a vector
   // so need an additional `SliceOp` to take it out.
   for (const auto& n : op_desc.Inputs()) {
@@ -366,9 +342,25 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
       if (defining_info.generated_by_vector) {
         InsertSliceOperationForTarget(
             ctx, param_map, program, defining_info, arg_name);
+        VLOG(8) << "[op:" << op_desc.Type()
+                << "] insert slice for var: " << arg_name;
       }
     }
   }
+}
+
+std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
+    ir::IrContext* ctx,
+    TranslationContext* param_map,
+    const OpDesc& op_desc,
+    const std::string& normalized_op_name,
+    const OpInputInfoList& input_infos,
+    ir::Program* program) {
+  VLOG(10) << "[op:" << op_desc.Type() << "][input] entrance";
+
+  auto& op_normalizer = OpNameNormalizer::instance();
+  const auto* mutable_attributes =
+      op_normalizer.GetMutableAttributes(op_desc.Type());
 
   VLOG(10) << "[op:" << op_desc.Type() << "][input] start";
 
@@ -540,8 +532,8 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx,
                  op_desc.Type(),
                  var_name);
       VLOG(10) << "[output translating]"
-               << "[" << op_desc.Type() << "]" << info.name << " " << var_name
-               << " " << var->GetType();
+               << "[" << op_desc.Type() << "]" << info.name
+               << " var: " << var_name << " type: " << var->GetType();
 
       ir::Type translated_var_type =
           type_translator[var->GetType()](ctx, *var);
@@ -552,7 +544,7 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx,
     } else {
       VLOG(10) << "[output translating]"
                << "[" << op_desc.Type() << "]" << info.name << " :"
-               << info.type_name << " " << legacy_output_name;
+               << info.type_name << " var: " << legacy_output_name;
       std::vector<ir::Type> types;
       for (const auto& var_name : legacy_output_vars) {
         if (var_name == kEmptyVarName) {
@@ -562,8 +554,8 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx,
         }
         VarDesc* var = block->FindVarRecursive(var_name);
         VLOG(10) << "[output translating]"
-                 << "[" << op_desc.Type() << "]" << info.name << " " << var_name
-                 << " " << var->GetType();
+                 << "[" << op_desc.Type() << "]" << info.name
+                 << " var: " << var_name << " type: " << var->GetType();
         ir::Type translated_var_type =
             type_translator[var->GetType()](ctx, *var);
         types.push_back(translated_var_type);
@@ -631,6 +623,7 @@ void OpTranscriber::RecordOpResultMapping(ir::IrContext* ctx,
     size_t idx_in_vector = 0;
     for (const auto& arg_name : args) {
       if (arg_name == kEmptyVarName) {
+        idx_in_vector++;
         continue;
       }
       auto idx_iter = arg_to_idx.find(arg_name);
@@ -678,6 +671,9 @@ ir::Operation* OpTranscriber::operator()(ir::IrContext* ctx,
   std::tie(input_infos, attr_infos, output_infos, std::ignore) =
       op_info_concept->get_op_info_();
 
+  this->InsertSliceOperationForInput(
+      ctx, param_map, op_desc, input_infos, program);
+
   auto op_inputs = this->GenerateOperationInput(
       ctx, param_map, op_desc, op_info.name(), input_infos, program);
 
@@ -1137,6 +1133,9 @@ struct FetchOpTranscriber : public OpTranscriber {
     std::tie(input_infos, attr_infos, output_infos, std::ignore) =
         op_info_concept->get_op_info_();
 
+    this->InsertSliceOperationForInput(
+        ctx, param_map, op_desc, input_infos, program);
+
     auto op_inputs = this->GenerateOperationInput(
         ctx, param_map, op_desc, op_info.name(), input_infos, program);
 
diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc
index 3c16eebc34914a8b1ebd06d9952532dcdc33f52f..69c5c8b0a08621f17a40e7fdc5e6c8645dd86a1f 100644
--- a/paddle/fluid/ir_adaptor/translator/program_translator.cc
+++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc
@@ -22,6 +22,7 @@
 #include "paddle/fluid/framework/var_desc.h"
 #include "paddle/fluid/ir_adaptor/translator/op_translator.h"
 #include "paddle/fluid/ir_adaptor/translator/type_translator.h"
+#include "paddle/fluid/ir_adaptor/translator/utils.h"
 #include "paddle/ir/core/attribute.h"
 #include "paddle/ir/core/block.h"
 #include "paddle/ir/core/builtin_attribute.h"
@@ -189,6 +190,12 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
         continue;
       }
 
+      if (param_map_[var_name].generated_by_vector) {
+        InsertSliceOperationForTarget(
+            ctx_, &param_map_, program_, param_map_[var_name], var_name);
+        defining_op_result = param_map_.at(var_name).value;
+      }
+
       ir::Operation* op = InsertSetParamaterOp(
           ctx_, defining_op_result, parameter_name_mappings_[var_name]);
 
@@ -218,6 +225,7 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue(
   // Currently we set stop gradient for operation that generated a value
   // connected with VarDesc
   for (const auto& [var_name, value_info] : param_map_) {
+    if (no_cast_var_names.count(var_name) != 0) continue;
     VLOG(10) << "[op translated][stop gradient]" << var_name;
     VarDesc* var = block.FindVarRecursive(var_name);
     if (var == nullptr) {
diff --git a/paddle/fluid/ir_adaptor/translator/utils.cc b/paddle/fluid/ir_adaptor/translator/utils.cc
new file mode 100644
index 0000000000000000000000000000000000000000..75bcebd16d7343c86eb8fcc37f9c28c8c789822d
--- /dev/null
+++ b/paddle/fluid/ir_adaptor/translator/utils.cc
@@ -0,0 +1,50 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/ir_adaptor/translator/utils.h"
+
+#include <unordered_map>
+
+#include "paddle/ir/core/builtin_attribute.h"
+#include "paddle/ir/core/builtin_type.h"
+
+namespace paddle {
+namespace translator {
+
+ir::Operation* InsertSliceOperationForTarget(
+    ir::IrContext* ctx,
+    TranslationContext* param_map,
+    ir::Program* program,
+    const VariableDefiningInfo& defining_info,
+    const std::string& arg_name) {
+  std::string slice_op_name(ir::SliceOp::name());
+  ir::OpInfo op_info = ctx->GetRegisteredOpInfo(slice_op_name);
+  std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
+      {"index", ir::Int32Attribute::get(ctx, defining_info.idx_in_vector)},
+  };
+  ir::VectorType src_vec_type =
+      defining_info.value.type().dyn_cast<ir::VectorType>();
+  ir::Operation* operation =
+      ir::Operation::Create({defining_info.value},
+                            op_attribute_map,
+                            {src_vec_type[defining_info.idx_in_vector]},
+                            op_info);
+  program->block()->push_back(operation);
+  ir::OpResult target_op_result = operation->result(0);
+  (*param_map)[arg_name] = VariableDefiningInfo(target_op_result);
+  return operation;
+}
+
+}  // namespace translator
+}  // namespace paddle
diff --git a/paddle/fluid/ir_adaptor/translator/utils.h b/paddle/fluid/ir_adaptor/translator/utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..a6e5a8fd20969ddc4eeccdd5469f957276fc034c
--- /dev/null
+++ b/paddle/fluid/ir_adaptor/translator/utils.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+
+#include "paddle/fluid/ir_adaptor/translator/program_translator.h"
+#include "paddle/ir/core/ir_context.h"
+#include "paddle/ir/core/operation.h"
+#include "paddle/ir/core/program.h"
+
+namespace paddle {
+namespace translator {
+
+ir::Operation* InsertSliceOperationForTarget(
+    ir::IrContext* ctx,
+    TranslationContext* param_map,
+    ir::Program* program,
+    const VariableDefiningInfo& defining_info,
+    const std::string& arg_name);
+
+}  // namespace translator
+}  // namespace paddle
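Note (illustration only, not part of the patch): since InsertSliceOperationForInput is now a virtual member of OpTranscriber, a specialized transcriber can customize or skip the slice-insertion step. The sketch below shows one such override; the subclass name and its no-op behaviour are hypothetical, while the method signature is the one introduced by this patch.

// Hypothetical subclass (assumption, not in this patch): a transcriber that
// opts out of the automatic builtin.slice insertion done by the base class.
struct NoSliceOpTranscriber : public OpTranscriber {
  void InsertSliceOperationForInput(ir::IrContext*,
                                    TranslationContext*,
                                    const OpDesc&,
                                    const OpInputInfoList&,
                                    ir::Program*) override {
    // Deliberately a no-op: inputs are consumed as whole vector-typed values,
    // so no slice operations are inserted before GenerateOperationInput runs.
  }
};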