未验证 提交 5e6645d7 编写于 作者: H hong 提交者: GitHub

[NewIR]Fix new ir concat split bug (#55419)

* fix new ir concat op bug

* fix bug

* using add_n_with_kernel instead of add_n impl

* fix pd_op yaml bug

* fix bug
上级 5e61b04c
...@@ -116,7 +116,42 @@ ...@@ -116,7 +116,42 @@
- {typename: Tensor, name: out, optional: false, intermediate: false} - {typename: Tensor, name: out, optional: false, intermediate: false}
no_need_buffer: null no_need_buffer: null
data_transform: null data_transform: null
invoke: {func: add_n_impl, args: inputs} infer_meta:
func: AddNInferMeta
param: [inputs]
kernel:
func: [add_n]
param: [inputs]
backend: null
layout: null
data_type: null
dispatch: {fetch: null}
force_backend: null
backward: add_n_grad
- name: add_n_with_kernel
inputs:
- typename: Tensor[]
name: inputs
optional: false
no_need_buffer: false
data_transform: {}
attrs: []
outputs:
- {typename: Tensor, name: out, optional: false, intermediate: false}
no_need_buffer: null
data_transform: null
infer_meta:
func: AddNInferMeta
param: [inputs]
kernel:
func: [add_n]
param: [inputs]
backend: null
layout: null
data_type: null
dispatch: {fetch: null}
force_backend: null
backward: add_n_grad backward: add_n_grad
- name: write_to_array - name: write_to_array
......
...@@ -350,7 +350,6 @@ void BuildScope(const ir::Block& block, ...@@ -350,7 +350,6 @@ void BuildScope(const ir::Block& block,
<< paddle::framework::GenScopeTreeDebugInfo( << paddle::framework::GenScopeTreeDebugInfo(
const_cast<paddle::framework::Scope*>(inner_scope->root())); const_cast<paddle::framework::Scope*>(inner_scope->root()));
// int count = value_2_var_name->size();
for (auto it = block.begin(); it != block.end(); ++it) { for (auto it = block.begin(); it != block.end(); ++it) {
ir::Operation* op = *it; ir::Operation* op = *it;
......
...@@ -65,7 +65,6 @@ phi::KernelKey GetKernelKey( ...@@ -65,7 +65,6 @@ phi::KernelKey GetKernelKey(
phi::Backend kernel_backend = phi::Backend::UNDEFINED; phi::Backend kernel_backend = phi::Backend::UNDEFINED;
phi::DataLayout kernel_layout = phi::DataLayout::UNDEFINED; phi::DataLayout kernel_layout = phi::DataLayout::UNDEFINED;
phi::DataType kernel_data_type = phi::DataType::UNDEFINED; phi::DataType kernel_data_type = phi::DataType::UNDEFINED;
if (op_info_parser != nullptr) { if (op_info_parser != nullptr) {
// only suppurt non vector input for now // only suppurt non vector input for now
int tensor_input_number = op_info_parser->InputTensorNumber(); int tensor_input_number = op_info_parser->InputTensorNumber();
...@@ -84,12 +83,36 @@ phi::KernelKey GetKernelKey( ...@@ -84,12 +83,36 @@ phi::KernelKey GetKernelKey(
} else if (input_map.count(slot_name)) { } else if (input_map.count(slot_name)) {
// parse from input // parse from input
int in_index = input_map.at(slot_name); int in_index = input_map.at(slot_name);
auto type = map_value_pair.at(op->operand(in_index)).type();
if (type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
kernel_data_type = TransToPhiDataType(
type.dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
.dtype());
} else if (type.isa<ir::VectorType>()) {
auto vec_data = type.dyn_cast<ir::VectorType>().data();
if (vec_data.size() == 0) {
kernel_data_type = phi::DataType::UNDEFINED;
} else {
if (vec_data[0].isa<paddle::dialect::AllocatedDenseTensorType>()) {
kernel_data_type = TransToPhiDataType(
vec_data[0]
.dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
.dtype());
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Only support DenseTensorType in vector"));
}
}
} else if (type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
kernel_data_type = TransToPhiDataType(
type.dyn_cast<paddle::dialect::AllocatedSelectedRowsType>()
.dtype());
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Only support DenseTensorType, SelectedRows, VectorType"));
}
dialect::DenseTensorType type =
op->operand(in_index)
.type()
.dyn_cast<paddle::dialect::DenseTensorType>();
kernel_data_type = TransToPhiDataType(type.dtype());
} else { } else {
PADDLE_ENFORCE_EQ(attr_map.count(slot_name), PADDLE_ENFORCE_EQ(attr_map.count(slot_name),
true, true,
...@@ -146,7 +169,6 @@ phi::KernelKey GetKernelKey( ...@@ -146,7 +169,6 @@ phi::KernelKey GetKernelKey(
if (op_info_parser != nullptr && op_info_parser->IsTensorAttribute(i)) { if (op_info_parser != nullptr && op_info_parser->IsTensorAttribute(i)) {
continue; continue;
} }
auto input_tmp = op->operand(i); auto input_tmp = op->operand(i);
// NOTE: if not input_tmp, it's an optional input // NOTE: if not input_tmp, it's an optional input
if (!input_tmp) { if (!input_tmp) {
......
...@@ -977,7 +977,8 @@ struct SplitOpTranscriber : public OpTranscriber { ...@@ -977,7 +977,8 @@ struct SplitOpTranscriber : public OpTranscriber {
// process sections // process sections
int num = paddle::get<int>(op_desc.GetAttr("num")); int num = paddle::get<int>(op_desc.GetAttr("num"));
if (num <= 0) { if (num <= 0) {
if (op_desc.HasInput("SectionsTensorList")) { if (op_desc.HasInput("SectionsTensorList") &&
op_desc.Input("SectionsTensorList").size() > 0) {
// get SectionsTensorList from input // get SectionsTensorList from input
auto sec_tensor_list = op_desc.Input("SectionsTensorList"); auto sec_tensor_list = op_desc.Input("SectionsTensorList");
...@@ -989,7 +990,7 @@ struct SplitOpTranscriber : public OpTranscriber { ...@@ -989,7 +990,7 @@ struct SplitOpTranscriber : public OpTranscriber {
ir::Attribute new_attr = attribute_translator( ir::Attribute new_attr = attribute_translator(
"paddle::dialect::IntArrayAttribute", op_desc.GetAttr("sections")); "paddle::dialect::IntArrayAttribute", op_desc.GetAttr("sections"));
auto sec_defin_op = auto sec_defin_op =
InsertFullOperationForAttributeInput(ctx, program, new_attr); InsertFullArrayOperationForAttributeInput(ctx, program, new_attr);
op_inputs.push_back(sec_defin_op->result(0)); op_inputs.push_back(sec_defin_op->result(0));
} }
} }
...@@ -1087,6 +1088,26 @@ struct FetchOpTranscriber : public OpTranscriber { ...@@ -1087,6 +1088,26 @@ struct FetchOpTranscriber : public OpTranscriber {
} }
}; };
// NOTE, add_n op in legacy ops don't have a kernel, so we use a new op for now
struct AddNOpTranscriber : public OpTranscriber {
ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
std::string target_op_name =
kTargetDialectPrefix + OpNameCompatibleMapping(op_desc.Type());
if (IsInplace(op_desc)) {
target_op_name += "_";
} else {
target_op_name += "_with_kernel";
}
const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
if (!op_info) {
IR_THROW(
"Op assign_value should have corresponding OpInfo pd.assign_value_");
}
return op_info;
}
};
OpTranslator::OpTranslator() { OpTranslator::OpTranslator() {
general_handler = OpTranscriber(); general_handler = OpTranscriber();
special_handlers["feed"] = FeedOpTranscriber(); special_handlers["feed"] = FeedOpTranscriber();
...@@ -1098,6 +1119,7 @@ OpTranslator::OpTranslator() { ...@@ -1098,6 +1119,7 @@ OpTranslator::OpTranslator() {
special_handlers["assign_value"] = AssignValueOpTranscriber(); special_handlers["assign_value"] = AssignValueOpTranscriber();
special_handlers["increment"] = IncrementOpTranscriber(); special_handlers["increment"] = IncrementOpTranscriber();
special_handlers["rnn"] = RnnOpTranscriber(); special_handlers["rnn"] = RnnOpTranscriber();
special_handlers["add_n"] = AddNOpTranscriber();
} }
} // namespace translator } // namespace translator
......
...@@ -2534,6 +2534,15 @@ ...@@ -2534,6 +2534,15 @@
out : Out out : Out
extra : extra :
attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"] attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"]
int_array :
starts :
data_type : int
tensor_name : StartsTensor
tensors_name : StartsTensorList
ends :
data_type : int
tensor_name : EndsTensor
tensors_name : EndsTensorList
- op : slogdet(slogdeterminant) - op : slogdet(slogdeterminant)
backward : slogdet_grad(slogdeterminant_grad) backward : slogdet_grad(slogdeterminant_grad)
......
...@@ -24,6 +24,7 @@ void FetchKernel(const Context& dev_ctx, ...@@ -24,6 +24,7 @@ void FetchKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
DenseTensor* out) { DenseTensor* out) {
phi::Copy(dev_ctx, x, phi::CPUPlace(), true, out); phi::Copy(dev_ctx, x, phi::CPUPlace(), true, out);
out->set_lod(x.lod());
} }
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(fetch, PD_REGISTER_KERNEL(fetch,
......
...@@ -855,11 +855,18 @@ class _ExecutorCache: ...@@ -855,11 +855,18 @@ class _ExecutorCache:
if build_strategy is None or build_strategy.enable_inplace if build_strategy is None or build_strategy.enable_inplace
else False else False
) )
enable_addto = ( enable_addto = (
True True
if build_strategy is not None and build_strategy.enable_addto if build_strategy is not None and build_strategy.enable_addto
else False else False
) )
if os.getenv("FLAGS_enable_new_ir_in_executor"):
# todo(phlrain), skip inplace add addto pass in new IR
enable_inplace = False
enable_addto = False
if enable_inplace or enable_addto: if enable_inplace or enable_addto:
# inplace should skip feed and fetch var # inplace should skip feed and fetch var
skip_var_names = eval(_get_program_cache_key(feed, fetch_list)) skip_var_names = eval(_get_program_cache_key(feed, fetch_list))
......
...@@ -2309,6 +2309,8 @@ class OpTest(unittest.TestCase): ...@@ -2309,6 +2309,8 @@ class OpTest(unittest.TestCase):
self.op_type self.op_type
not in compile_vs_runtime_white_list.COMPILE_RUN_OP_WHITE_LIST not in compile_vs_runtime_white_list.COMPILE_RUN_OP_WHITE_LIST
): ):
if os.getenv("FLAGS_enable_new_ir_in_executor"):
return
self.check_compile_vs_runtime(fetch_list, outs) self.check_compile_vs_runtime(fetch_list, outs)
def check_output_customized(self, checker, custom_place=None): def check_output_customized(self, checker, custom_place=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册