未验证 提交 0f611f18 编写于 作者: K kangguangli 提交者: GitHub

[NewIR] Insert `get_parameter` only for paramters (#56325)

* fix inset get_parameter op bug

* fix bug: insert  only for parameters

* fix bug: wrong idx in vector

---------
Co-authored-by: Nzhangbo9674 <zhangbo54@baidu.com>
上级 da72707f
...@@ -367,7 +367,7 @@ std::unique_ptr<::ir::Program> ConstructFowardIrProgram( ...@@ -367,7 +367,7 @@ std::unique_ptr<::ir::Program> ConstructFowardIrProgram(
} }
} }
// add fetch with place op to program // add data op to program
auto *block = local_program.MutableBlock(0); auto *block = local_program.MutableBlock(0);
for (auto &in_t : x) { for (auto &in_t : x) {
auto name = in_t.name(); auto name = in_t.name();
......
...@@ -478,9 +478,10 @@ std::string NewIRInterpreter::DebugValueInfo() { ...@@ -478,9 +478,10 @@ std::string NewIRInterpreter::DebugValueInfo() {
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"var(%s) should exist in var_name_2_id_", kv.second)); "var(%s) should exist in var_name_2_id_", kv.second));
auto* var = InnerScope()->FindVar(kv.second); auto* var = InnerScope()->FindVar(kv.second);
PADDLE_ENFORCE(var != nullptr, PADDLE_ENFORCE(
platform::errors::PreconditionNotMet( var != nullptr,
"var(%s) should exist in var_name_2_id_", kv.second)); platform::errors::PreconditionNotMet(
"var(%s) should exist in scope (%p)", kv.second, InnerScope()));
os << kv.first.impl() << " -> " << kv.second << " -> " os << kv.first.impl() << " -> " << kv.second << " -> "
<< var_name_2_id_.at(kv.second) << " -> " << var << "\n"; << var_name_2_id_.at(kv.second) << " -> " << var << "\n";
} }
......
...@@ -49,7 +49,9 @@ namespace translator { ...@@ -49,7 +49,9 @@ namespace translator {
namespace { namespace {
using ResultIdx = size_t; using IdxInOp = size_t;
using IdxInVector = size_t;
using ResultIdx = std::tuple<IdxInOp, IdxInVector>;
using OpDesc = paddle::framework::OpDesc; using OpDesc = paddle::framework::OpDesc;
using BlockDesc = paddle::framework::BlockDesc; using BlockDesc = paddle::framework::BlockDesc;
using VarDesc = paddle::framework::VarDesc; using VarDesc = paddle::framework::VarDesc;
...@@ -325,18 +327,15 @@ void OpTranscriber::InsertSliceOperationForInput( ...@@ -325,18 +327,15 @@ void OpTranscriber::InsertSliceOperationForInput(
// scan all inputs to see if any of them is generated as a vector<Tensor> // scan all inputs to see if any of them is generated as a vector<Tensor>
// so need an additional `SliceOp` to take it out. // so need an additional `SliceOp` to take it out.
for (const auto& n : op_desc.Inputs()) { for (const auto& n : op_desc.Inputs()) {
auto& name = n.first;
auto& args = n.second; auto& args = n.second;
for (const auto& arg_name : args) { for (const auto& arg_name : args) {
bool check = bool check =
param_map->count(arg_name) != 0 || !yaml_input_set.count(arg_name); param_map->count(arg_name) != 0 && !yaml_input_set.count(arg_name);
IR_ENFORCE(check, if (!check) {
"arg %s.%s as input should be exists before prasing %s", continue;
name, }
arg_name, auto defining_info = param_map->at(arg_name);
op_desc.Type());
auto defining_info = (*param_map)[arg_name];
if (defining_info.generated_by_vector) { if (defining_info.generated_by_vector) {
InsertSliceOperationForTarget( InsertSliceOperationForTarget(
ctx, param_map, program, defining_info, arg_name); ctx, param_map, program, defining_info, arg_name);
...@@ -391,7 +390,8 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput( ...@@ -391,7 +390,8 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
} }
} }
VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " " VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " "
<< legacy_input_name << " " << legacy_input_vars.size(); << legacy_input_name << " " << legacy_input_vars.size() << "["
<< legacy_input_vars << "]";
if (legacy_input_vars.empty() && mutable_attributes != nullptr && if (legacy_input_vars.empty() && mutable_attributes != nullptr &&
mutable_attributes->count(info.name) != 0) { mutable_attributes->count(info.name) != 0) {
...@@ -507,7 +507,7 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx, ...@@ -507,7 +507,7 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx,
ir::Type translated_var_type = ir::Type translated_var_type =
type_translator[var->GetType()](ctx, *var); type_translator[var->GetType()](ctx, *var);
op_output_types.push_back(translated_var_type); op_output_types.push_back(translated_var_type);
arg_to_idx[var->Name()] = cur_output_idx; arg_to_idx[var->Name()] = {cur_output_idx, 0};
continue; continue;
} }
} }
...@@ -535,7 +535,7 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx, ...@@ -535,7 +535,7 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx,
ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
arg_to_idx[var_name] = cur_output_idx; arg_to_idx[var_name] = {cur_output_idx, 0};
op_output_types.push_back(translated_var_type); op_output_types.push_back(translated_var_type);
// if src type is Vector<Tesnor> // if src type is Vector<Tesnor>
...@@ -544,10 +544,12 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx, ...@@ -544,10 +544,12 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx,
<< "[" << op_desc.Type() << "]" << info.name << " :" << "[" << op_desc.Type() << "]" << info.name << " :"
<< info.type_name << " var: " << legacy_output_name; << info.type_name << " var: " << legacy_output_name;
std::vector<ir::Type> types; std::vector<ir::Type> types;
for (const auto& var_name : legacy_output_vars) { for (IdxInVector idx_in_vec = 0; idx_in_vec < legacy_output_vars.size();
idx_in_vec++) {
const auto& var_name = legacy_output_vars[idx_in_vec];
if (var_name == kEmptyVarName) { if (var_name == kEmptyVarName) {
types.emplace_back(nullptr); types.emplace_back(nullptr);
arg_to_idx[var_name] = cur_output_idx; arg_to_idx[var_name] = {cur_output_idx, idx_in_vec};
continue; continue;
} }
VarDesc* var = block->FindVarRecursive(var_name); VarDesc* var = block->FindVarRecursive(var_name);
...@@ -557,7 +559,7 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx, ...@@ -557,7 +559,7 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx,
ir::Type translated_var_type = ir::Type translated_var_type =
type_translator[var->GetType()](ctx, *var); type_translator[var->GetType()](ctx, *var);
types.push_back(translated_var_type); types.push_back(translated_var_type);
arg_to_idx[var_name] = cur_output_idx; arg_to_idx[var_name] = {cur_output_idx, idx_in_vec};
} }
ir::Type vec_type = ir::VectorType::get(ctx, types); ir::Type vec_type = ir::VectorType::get(ctx, types);
op_output_types.push_back(vec_type); op_output_types.push_back(vec_type);
...@@ -613,45 +615,16 @@ void OpTranscriber::RecordOpResultMapping(ir::IrContext* ctx, ...@@ -613,45 +615,16 @@ void OpTranscriber::RecordOpResultMapping(ir::IrContext* ctx,
const OpDesc& op_desc, const OpDesc& op_desc,
ir::Operation* operation, ir::Operation* operation,
const OpOutputMapping& arg_to_idx) { const OpOutputMapping& arg_to_idx) {
for (const auto& n : op_desc.Outputs()) { for (const auto& [arg_name, idx] : arg_to_idx) {
auto& name = n.first; const auto& [idx_in_op, idx_in_vec] = idx;
VLOG(10) << "[output recording]" VLOG(10) << "[output recording]"
<< "[" << op_desc.Type() << "]" << name; << "[" << op_desc.Type() << "]" << arg_name << " " << idx_in_op
const auto& args = n.second; << " " << idx_in_vec;
size_t idx_in_vector = 0; ir::OpResult value = operation->result(idx_in_op);
for (const auto& arg_name : args) { bool generated_by_vector = value.type().isa<ir::VectorType>();
if (arg_name == kEmptyVarName) {
idx_in_vector++;
continue;
}
auto idx_iter = arg_to_idx.find(arg_name);
if (idx_iter == arg_to_idx.end()) {
VLOG(4) << "[output recording]"
<< "[" << op_desc.Type() << "][skip]" << arg_name;
continue;
}
auto idx = idx_iter->second;
VLOG(10) << "[output recording]"
<< "[" << op_desc.Type() << "]" << arg_name << " " << idx;
ir::OpResult value = operation->result(idx);
bool generated_by_vector = value.type().isa<ir::VectorType>();
// Specially process TensorArray, this because we cannot distinguish it
// with Vector<DenseTensor> by other conditions but we cannot support it
// like Vector<DenseTensor>
if (args.size() == 1) {
VarDesc* var = op_desc.Block()->FindVarRecursive(args[0]);
if (var->GetType() ==
paddle::framework::proto::VarType::LOD_TENSOR_ARRAY) {
generated_by_vector = false;
}
}
(*param_map)[arg_name] = VariableDefiningInfo( (*param_map)[arg_name] = VariableDefiningInfo(
value, generated_by_vector, generated_by_vector ? idx_in_vector : -1); value, generated_by_vector, generated_by_vector ? idx_in_vec : -1);
idx_in_vector++;
}
} }
} }
...@@ -1439,9 +1412,10 @@ struct ElementwiseGradTranscriber : public OpTranscriber { ...@@ -1439,9 +1412,10 @@ struct ElementwiseGradTranscriber : public OpTranscriber {
if (idx_iter == arg_to_idx.end()) { if (idx_iter == arg_to_idx.end()) {
IR_THROW("op[%s] should have got its y_grad", op_desc.Type()); IR_THROW("op[%s] should have got its y_grad", op_desc.Type());
} }
auto idx = idx_iter->second; auto [idx_in_op, idx_in_vec] = idx_iter->second;
VLOG(10) << "[output recording]" VLOG(10) << "[output recording]"
<< "[" << op_desc.Type() << "]" << y_grad_var_name << " " << idx; << "[" << op_desc.Type() << "]" << y_grad_var_name << " "
<< idx_in_op << " " << idx_in_vec;
auto y_names = op_desc.Input("Y", true); auto y_names = op_desc.Input("Y", true);
auto y_name = y_names[0]; auto y_name = y_names[0];
...@@ -1465,7 +1439,7 @@ struct ElementwiseGradTranscriber : public OpTranscriber { ...@@ -1465,7 +1439,7 @@ struct ElementwiseGradTranscriber : public OpTranscriber {
y_type.dyn_cast<dialect::DenseTensorType>(); y_type.dyn_cast<dialect::DenseTensorType>();
std::vector<int64_t> y_shape = phi::vectorize(y_tensor_type.dims()); std::vector<int64_t> y_shape = phi::vectorize(y_tensor_type.dims());
ir::OpResult value = operation->result(idx); ir::OpResult value = operation->result(idx_in_op);
ir::Builder builder(ctx, operation->GetParent()); ir::Builder builder(ctx, operation->GetParent());
auto reshape_op = builder.Build<dialect::ReshapeOp>(value, y_shape); auto reshape_op = builder.Build<dialect::ReshapeOp>(value, y_shape);
(*param_map)[y_grad_var_name] = (*param_map)[y_grad_var_name] =
......
...@@ -143,12 +143,12 @@ void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) { ...@@ -143,12 +143,12 @@ void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) {
var_desc = block.FindVarRecursive(var_name); var_desc = block.FindVarRecursive(var_name);
} }
bool need_get_parameter_op = is_parameter || is_unseen_variable; bool need_get_parameter_op = is_parameter && is_unseen_variable;
if (need_get_parameter_op) { if (need_get_parameter_op) {
ir::Operation* op = InsertGetParamaterOp(ctx_, var_desc); ir::Operation* op = InsertGetParamaterOp(ctx_, var_desc);
program_->block()->push_back(op); program_->block()->push_back(op);
param_map_[var_name] = VariableDefiningInfo(op->result(0)); param_map_[var_name] = VariableDefiningInfo(op->result(0));
VLOG(10) << "[op translated][get parameter]" << op; VLOG(10) << "[op translated][get parameter]" << var_name;
program_->SetParameter(var_name, nullptr); program_->SetParameter(var_name, nullptr);
parameter_visited_.insert(var_name); parameter_visited_.insert(var_name);
...@@ -224,7 +224,7 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) { ...@@ -224,7 +224,7 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
insert_pos++; insert_pos++;
block->insert(insert_pos, op); block->insert(insert_pos, op);
VLOG(10) << "[op translated][set parameter]" << op; VLOG(10) << "[op translated][set parameter]" << var_name;
program_->SetParameter(var_name, nullptr); program_->SetParameter(var_name, nullptr);
parameter_visited_.insert(var_name); parameter_visited_.insert(var_name);
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/utils.h"
namespace paddle { namespace paddle {
namespace translator { namespace translator {
...@@ -46,5 +47,15 @@ ir::Operation* InsertSliceOperationForTarget( ...@@ -46,5 +47,15 @@ ir::Operation* InsertSliceOperationForTarget(
return operation; return operation;
} }
std::ostream& operator<<(std::ostream& os,
const std::vector<std::string>& vec_str) {
ir::PrintInterleave(
vec_str.begin(),
vec_str.end(),
[&os](std::string s) { os << s; },
[&os]() { os << ", "; });
return os;
}
} // namespace translator } // namespace translator
} // namespace paddle } // namespace paddle
...@@ -31,5 +31,8 @@ ir::Operation* InsertSliceOperationForTarget( ...@@ -31,5 +31,8 @@ ir::Operation* InsertSliceOperationForTarget(
const VariableDefiningInfo& defining_info, const VariableDefiningInfo& defining_info,
const std::string& arg_name); const std::string& arg_name);
std::ostream& operator<<(std::ostream& os,
const std::vector<std::string>& vec_str);
} // namespace translator } // namespace translator
} // namespace paddle } // namespace paddle
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import paddle import paddle
from paddle.fluid import Variable, core from paddle.fluid import Variable, core
from paddle.fluid.data_feeder import check_type from paddle.fluid.data_feeder import check_type
...@@ -96,7 +98,7 @@ def data(name, shape, dtype=None, lod_level=0): ...@@ -96,7 +98,7 @@ def data(name, shape, dtype=None, lod_level=0):
shape[i] = -1 shape[i] = -1
if dtype: if dtype:
return helper.create_global_variable( out = helper.create_global_variable(
name=name, name=name,
shape=shape, shape=shape,
dtype=dtype, dtype=dtype,
...@@ -107,7 +109,7 @@ def data(name, shape, dtype=None, lod_level=0): ...@@ -107,7 +109,7 @@ def data(name, shape, dtype=None, lod_level=0):
need_check_feed=True, need_check_feed=True,
) )
else: else:
return helper.create_global_variable( out = helper.create_global_variable(
name=name, name=name,
shape=shape, shape=shape,
dtype=paddle.get_default_dtype(), dtype=paddle.get_default_dtype(),
...@@ -118,6 +120,21 @@ def data(name, shape, dtype=None, lod_level=0): ...@@ -118,6 +120,21 @@ def data(name, shape, dtype=None, lod_level=0):
need_check_feed=True, need_check_feed=True,
) )
if os.environ.get("FLAGS_enable_new_ir_in_executor", None):
helper = LayerHelper('data', **locals())
helper.append_op(
type='data',
inputs={},
outputs={'out': out},
attrs={
'index': 0,
'dtype': 0,
'place': 0,
'name': name,
},
)
return out
class InputSpec: class InputSpec:
""" """
......
...@@ -106,7 +106,7 @@ class TestMeanVjp(unittest.TestCase): ...@@ -106,7 +106,7 @@ class TestMeanVjp(unittest.TestCase):
.source() .source()
.get_defining_op() .get_defining_op()
.name(), .name(),
"builtin.get_parameter", "pd.data",
) )
self.assertEqual( self.assertEqual(
grad_outs[0][0] grad_outs[0][0]
......
...@@ -47,7 +47,7 @@ class TestBuildOp(unittest.TestCase): ...@@ -47,7 +47,7 @@ class TestBuildOp(unittest.TestCase):
self.assertEqual( self.assertEqual(
op_name_list, op_name_list,
[ [
'builtin.get_parameter', 'pd.data',
'pd.matmul', 'pd.matmul',
'pd.add', 'pd.add',
'pd.full_int_array', 'pd.full_int_array',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册