未验证 提交 c4694c15 编写于 作者: K kangguangli 提交者: GitHub

[NewIR ]fix bug: program translator not set value index correctly (#55789)

* fix bug: program translator not set value index correctly

* fix slice for setparameter
上级 8ddf51ff
......@@ -31,6 +31,7 @@
#include "paddle/fluid/ir_adaptor/translator/op_compat_info.h"
#include "paddle/fluid/ir_adaptor/translator/program_translator.h"
#include "paddle/fluid/ir_adaptor/translator/type_translator.h"
#include "paddle/fluid/ir_adaptor/translator/utils.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h"
......@@ -124,30 +125,6 @@ inline std::string OpNameCompatibleMapping(std::string op_name) {
return op_normalizer[op_name];
}
inline ir::Operation* InsertSliceOperationForTarget(
ir::IrContext* ctx,
TranslationContext* param_map,
ir::Program* program,
const VariableDefiningInfo& defining_info,
const std::string& arg_name) {
std::string slice_op_name(ir::SliceOp::name());
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(slice_op_name);
std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
{"index", ir::Int32Attribute::get(ctx, defining_info.idx_in_vector)},
};
ir::VectorType src_vec_type =
defining_info.value.type().dyn_cast<ir::VectorType>();
ir::Operation* operation =
ir::Operation::Create({defining_info.value},
op_attribute_map,
{src_vec_type[defining_info.idx_in_vector]},
op_info);
program->block()->push_back(operation);
ir::OpResult target_op_result = operation->result(0);
(*param_map)[arg_name] = VariableDefiningInfo(target_op_result);
return operation;
}
inline ir::Operation* InsertCombineOperationForTarget(
ir::IrContext* ctx,
TranslationContext* param_map,
......@@ -307,6 +284,11 @@ struct OpTranscriber {
const std::string& input_name) {
return nullptr;
}
virtual void InsertSliceOperationForInput(ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
const OpInputInfoList& input_infos,
ir::Program* program);
};
ir::OpInfo OpTranscriber::LoopkUpOpInfo(ir::IrContext* ctx,
......@@ -328,26 +310,20 @@ ir::OpInfo OpTranscriber::LoopkUpOpInfo(ir::IrContext* ctx,
return op_info;
}
std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
void OpTranscriber::InsertSliceOperationForInput(
ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
const std::string& normalized_op_name,
const OpInputInfoList& input_infos,
ir::Program* program) {
VLOG(10) << "[op:" << op_desc.Type() << "][input] entrance";
auto& op_normalizer = OpNameNormalizer::instance();
const auto* mutable_attributes =
op_normalizer.GetMutableAttributes(op_desc.Type());
std::set<std::string> yaml_input_set;
for (const auto& info : input_infos) {
std::string legacy_input_name =
op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);
yaml_input_set.insert(legacy_input_name);
}
// scan all inputs to see if any of them is generated as a vector<Tensor>
// so need an additional `SliceOp` to take it out.
for (const auto& n : op_desc.Inputs()) {
......@@ -366,9 +342,25 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
if (defining_info.generated_by_vector) {
InsertSliceOperationForTarget(
ctx, param_map, program, defining_info, arg_name);
VLOG(8) << "[op:" << op_desc.Type()
<< "] insert slice for var: " << arg_name;
}
}
}
}
std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
const std::string& normalized_op_name,
const OpInputInfoList& input_infos,
ir::Program* program) {
VLOG(10) << "[op:" << op_desc.Type() << "][input] entrance";
auto& op_normalizer = OpNameNormalizer::instance();
const auto* mutable_attributes =
op_normalizer.GetMutableAttributes(op_desc.Type());
VLOG(10) << "[op:" << op_desc.Type() << "][input] start";
......@@ -540,8 +532,8 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx,
op_desc.Type(),
var_name);
VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "]" << info.name << " " << var_name
<< " " << var->GetType();
<< "[" << op_desc.Type() << "]" << info.name
<< " var: " << var_name << " type: " << var->GetType();
ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
......@@ -552,7 +544,7 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx,
} else {
VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "]" << info.name << " :"
<< info.type_name << " " << legacy_output_name;
<< info.type_name << " var: " << legacy_output_name;
std::vector<ir::Type> types;
for (const auto& var_name : legacy_output_vars) {
if (var_name == kEmptyVarName) {
......@@ -562,8 +554,8 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx,
}
VarDesc* var = block->FindVarRecursive(var_name);
VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "]" << info.name << " " << var_name
<< " " << var->GetType();
<< "[" << op_desc.Type() << "]" << info.name
<< " var: " << var_name << " type: " << var->GetType();
ir::Type translated_var_type =
type_translator[var->GetType()](ctx, *var);
types.push_back(translated_var_type);
......@@ -631,6 +623,7 @@ void OpTranscriber::RecordOpResultMapping(ir::IrContext* ctx,
size_t idx_in_vector = 0;
for (const auto& arg_name : args) {
if (arg_name == kEmptyVarName) {
idx_in_vector++;
continue;
}
auto idx_iter = arg_to_idx.find(arg_name);
......@@ -678,6 +671,9 @@ ir::Operation* OpTranscriber::operator()(ir::IrContext* ctx,
std::tie(input_infos, attr_infos, output_infos, std::ignore) =
op_info_concept->get_op_info_();
this->InsertSliceOperationForInput(
ctx, param_map, op_desc, input_infos, program);
auto op_inputs = this->GenerateOperationInput(
ctx, param_map, op_desc, op_info.name(), input_infos, program);
......@@ -1137,6 +1133,9 @@ struct FetchOpTranscriber : public OpTranscriber {
std::tie(input_infos, attr_infos, output_infos, std::ignore) =
op_info_concept->get_op_info_();
this->InsertSliceOperationForInput(
ctx, param_map, op_desc, input_infos, program);
auto op_inputs = this->GenerateOperationInput(
ctx, param_map, op_desc, op_info.name(), input_infos, program);
......
......@@ -22,6 +22,7 @@
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/ir_adaptor/translator/op_translator.h"
#include "paddle/fluid/ir_adaptor/translator/type_translator.h"
#include "paddle/fluid/ir_adaptor/translator/utils.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_attribute.h"
......@@ -189,6 +190,12 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
continue;
}
if (param_map_[var_name].generated_by_vector) {
InsertSliceOperationForTarget(
ctx_, &param_map_, program_, param_map_[var_name], var_name);
defining_op_result = param_map_.at(var_name).value;
}
ir::Operation* op = InsertSetParamaterOp(
ctx_, defining_op_result, parameter_name_mappings_[var_name]);
......@@ -218,6 +225,7 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue(
// Currently we set stop gradient for operation that generated a value
// connected with VarDesc
for (const auto& [var_name, value_info] : param_map_) {
if (no_cast_var_names.count(var_name) != 0) continue;
VLOG(10) << "[op translated][stop gradient]" << var_name;
VarDesc* var = block.FindVarRecursive(var_name);
if (var == nullptr) {
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/ir_adaptor/translator/utils.h"
#include <unordered_map>
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
namespace paddle {
namespace translator {
ir::Operation* InsertSliceOperationForTarget(
ir::IrContext* ctx,
TranslationContext* param_map,
ir::Program* program,
const VariableDefiningInfo& defining_info,
const std::string& arg_name) {
std::string slice_op_name(ir::SliceOp::name());
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(slice_op_name);
std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
{"index", ir::Int32Attribute::get(ctx, defining_info.idx_in_vector)},
};
ir::VectorType src_vec_type =
defining_info.value.type().dyn_cast<ir::VectorType>();
ir::Operation* operation =
ir::Operation::Create({defining_info.value},
op_attribute_map,
{src_vec_type[defining_info.idx_in_vector]},
op_info);
program->block()->push_back(operation);
ir::OpResult target_op_result = operation->result(0);
(*param_map)[arg_name] = VariableDefiningInfo(target_op_result);
return operation;
}
} // namespace translator
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/ir_adaptor/translator/program_translator.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h"
namespace paddle {
namespace translator {
ir::Operation* InsertSliceOperationForTarget(
ir::IrContext* ctx,
TranslationContext* param_map,
ir::Program* program,
const VariableDefiningInfo& defining_info,
const std::string& arg_name);
} // namespace translator
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册