Unverified · Commit ada16f94 authored by zhangbo9674, committed by GitHub

[IR] Refine the Build interface of split op (#56924)

* fix bug

* fix bug
Parent 6d0ef342
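Reviewer context: split's sections and axis are mutable attributes in the new IR, so Build may receive them either from constant-producing ops (FullIntArrayOp / FullOp) or as tensors whose values are unknown until run time. A minimal Python sketch of the public API this Build path serves (shapes and values are illustrative only):

    import paddle

    x = paddle.rand([6, 4])
    # Constant sections/axis: values are known when Build runs.
    out0, out1, out2 = paddle.split(x, num_or_sections=[1, 2, 3], axis=0)
    print(out0.shape, out1.shape, out2.shape)  # [1, 4] [2, 4] [3, 4]
    # When sections/axis are instead produced by other ops, their values are
    # unknown at build time; the templates added below fall back to placeholder
    # values and mark the attribute with SetFromTensor(true).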
@@ -16,8 +16,8 @@
#include "paddle/fluid/framework/new_executor/feed_fetch_utils.h"
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/framework/new_executor/program_interpreter.h"
-#include "paddle/fluid/platform/flags.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
+#include "paddle/phi/core/flags.h"
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
@@ -29,10 +29,7 @@
PHI_DECLARE_bool(enable_new_ir_in_executor);
PHI_DECLARE_bool(enable_new_ir_api);
-PADDLE_DEFINE_EXPORTED_bool(new_ir_apply_inplace_pass,
-                            true,
-                            "new ir kernel program apply inplace pass.");
+PHI_DECLARE_bool(new_ir_apply_inplace_pass);
namespace paddle {
namespace framework {
......
@@ -13,6 +13,10 @@
# limitations under the License.

# generator build function
+_INFERMETA_NEED_META_CONFIG = {'SplitInferMeta'}
+_PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE = {'SplitOp'}

OP_BUILD_TEMPLATE = """
void {op_name}::Build({build_args}) {{
{get_attributes}
@@ -273,6 +277,7 @@ def GenBuildOutputs(
def GenBuildOutputs(
+    op_class_name,
    op_input_name_list,
    op_input_type_list,
    op_mutable_attribute_name_list,
@@ -316,6 +321,40 @@ def GenBuildOutputs(
    CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector<int64_t> {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullIntArrayOp>().attributes().at("value").dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData(); (void){name};\n"""
    CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullOp>().attributes().at("value").dyn_cast<paddle::dialect::ScalarAttribute>().data().to<{dtype}>(); (void){name};\n"""
+    CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::IntArray {name};
+  if ({name}_.owner()->info().id() == ir::TypeId::get<paddle::dialect::FullIntArrayOp>()) {{
+    {name} = std::move(phi::IntArray({name}_.owner()
+                                         ->dyn_cast<paddle::dialect::FullIntArrayOp>()
+                                         .attributes()
+                                         .at("value")
+                                         .dyn_cast<paddle::dialect::IntArrayAttribute>()
+                                         .data()
+                                         .GetData()));
+  }}
+  else {{
+    PADDLE_ENFORCE(
+        {name}_.type().isa<ir::VectorType>(),
+        phi::errors::PreconditionNotMet("section Type should be VectorType."));
+    size_t {name}_size = {name}_.type().dyn_cast<ir::VectorType>().size();
+    {name} = std::move(phi::IntArray(std::vector<int64_t>({name}_size, -1)));
+    {name}.SetFromTensor(true);
+  }}\n"""
+
+    CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::Scalar {name};
+  if ({name}_.owner()->info().id() == ir::TypeId::get<paddle::dialect::FullOp>()) {{
+    {name} = std::move(phi::Scalar({name}_.owner()
+                                       ->dyn_cast<paddle::dialect::FullOp>()
+                                       .attributes()
+                                       .at("value")
+                                       .dyn_cast<paddle::dialect::ScalarAttribute>()
+                                       .data()
+                                       .to<int>()));
+  }}
+  else {{
+    {name} = std::move(phi::Scalar(-1));
+    {name}.SetFromTensor(true);
+  }}\n"""
    CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name};
  phi::MetaTensor meta_{name}(&dense_{name});
"""
@@ -348,19 +387,30 @@ def GenBuildOutputs(
        attr_dtype = op_mutable_attribute_type_list[idx]
        # int_array
        if attr_dtype[0] == "paddle::dialect::IntArrayAttribute":
-            build_output_str += (
-                CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE.format(
-                    name=op_mutable_attribute_name_list[idx]
-                )
-            )
+            if op_class_name in _PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE:
+                build_output_str += CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format(
+                    name=op_mutable_attribute_name_list[idx]
+                )
+            else:
+                build_output_str += (
+                    CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE.format(
+                        name=op_mutable_attribute_name_list[idx]
+                    )
+                )
        # scalar
        elif attr_dtype[0] == "paddle::dialect::ScalarAttribute":
-            build_output_str += (
-                CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE.format(
-                    name=op_mutable_attribute_name_list[idx],
-                    dtype=attr_dtype[1],
-                )
-            )
+            if op_class_name in _PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE:
+                build_output_str += CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format(
+                    name=op_mutable_attribute_name_list[idx],
+                    dtype=attr_dtype[1],
+                )
+            else:
+                build_output_str += (
+                    CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE.format(
+                        name=op_mutable_attribute_name_list[idx],
+                        dtype=attr_dtype[1],
+                    )
+                )
        # string
        elif attr_dtype[0] == "ir::StrAttribute":
            build_output_str += ""
@@ -421,9 +471,19 @@ def GenBuildOutputs(
    CREATE_INFER_META_FUNC_TEMPLATE = """
  phi::{func}({args});
"""
-    build_output_str += CREATE_INFER_META_FUNC_TEMPLATE.format(
-        func=op_infer_meta_map['func'], args=", ".join(infer_meta_args)
-    )
+    CREATE_INFER_META_FUNC_WITH_METACINFIG_TEMPLATE = """
+  phi::{func}({args}, phi::MetaConfig(false, false));
+"""
+    if op_infer_meta_map['func'] in _INFERMETA_NEED_META_CONFIG:
+        build_output_str += (
+            CREATE_INFER_META_FUNC_WITH_METACINFIG_TEMPLATE.format(
+                func=op_infer_meta_map['func'], args=", ".join(infer_meta_args)
+            )
+        )
+    else:
+        build_output_str += CREATE_INFER_META_FUNC_TEMPLATE.format(
+            func=op_infer_meta_map['func'], args=", ".join(infer_meta_args)
+        )
    # use dense_{name} or vec_dense_{name} to create Outputs type
    build_output_str += "\n std::vector<ir::Type> argument_outputs;"
@@ -528,6 +588,7 @@ def gen_build_func_str(
        op_non_mutable_attribute_type_list,
    )
    build_outputs_str = GenBuildOutputs(
+        op_class_name,
        op_input_name_list,
        op_input_type_list,
        op_mutable_attribute_name_list,
......
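A quick way to see what the generator now emits for SplitOp is to expand the template by hand. The snippet below uses a trimmed, hypothetical stand-in for the IntArray template (the full text is in the hunk above); the doubled braces are format escapes that collapse to single braces:

    # Trimmed stand-in for CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.
    TEMPLATE = """  phi::IntArray {name};
      if ({name}_.owner()->info().id() == ir::TypeId::get<paddle::dialect::FullIntArrayOp>()) {{
        // constant producer: read the value out of FullIntArrayOp
      }} else {{
        // runtime tensor: size comes from VectorType, values default to -1
        {name}.SetFromTensor(true);
      }}"""

    # Prints the C++ that would be pasted into SplitOp::Build for "sections".
    print(TEMPLATE.format(name="sections"))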
@@ -25,10 +25,11 @@
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_registry.h"

+namespace details {
// NOTE(zhangbo): Which kind of value can be deleted?
// (1) Value's type needs to be AllocatedDenseTensorType or
// AllocatedSelectedRowsType; (2) Value is not persistable.
-bool CanBeDeleted(ir::Value value) {
+static bool CanBeDeleted(ir::Value value) {
  if (!value.type()) {
    return false;
  }
@@ -47,9 +48,9 @@ bool CanBeDeleted(ir::Value value) {
  return true;
}

-bool CanDoInplace(const std::unordered_set<ir::Value>& eager_dels,
-                  ir::Value input,
-                  ir::Value output) {
+static bool CanDoInplace(const std::unordered_set<ir::Value>& eager_dels,
+                         ir::Value input,
+                         ir::Value output) {
  if (input.type() != output.type()) {
    VLOG(9) << " -- input's type != output's type, can't do inplace";
    return false;
@@ -61,7 +62,7 @@ bool CanDoInplace(const std::unordered_set<ir::Value>& eager_dels,
  return true;
}

-bool IsNoNeedBuffer(ir::Operation* op, ir::Value value) {
+static bool IsNoNeedBuffer(ir::Operation* op, ir::Value value) {
  if (op->dialect()->name().compare(
          paddle::dialect::PaddleKernelDialect::name()) != 0) {
    VLOG(8) << op->name()
@@ -90,7 +91,7 @@ bool IsNoNeedBuffer(ir::Operation* op, ir::Value value) {
// NOTE(zhangbo): pd.feed's output and pd.fetch's input cannot be eagerly
// deleted.
-std::unordered_set<ir::Value> GetSkipDeletionValues(ir::Block* block) {
+static std::unordered_set<ir::Value> GetSkipDeletionValues(ir::Block* block) {
  std::unordered_set<ir::Value> skip_dels;
  for (auto& op : *block) {
    if (op->dialect()->name().compare(
@@ -119,7 +120,7 @@ std::unordered_set<ir::Value> GetSkipDeletionValues(ir::Block* block) {
// NOTE(zhangbo): For the inplace pass, currently only kernel_dialect operators
// are supported. Therefore, this function only returns the values in
// kernel_dialect operators that can be eagerly deleted.
-std::unordered_map<ir::Operation*, std::unordered_set<ir::Value>>
+static std::unordered_map<ir::Operation*, std::unordered_set<ir::Value>>
GetEagerDeletionValues(ir::Block* block) {
  std::unordered_set<ir::Value> skip_dels = GetSkipDeletionValues(block);
@@ -167,7 +168,7 @@ GetEagerDeletionValues(ir::Block* block) {
  return eager_dels;
}

-std::unordered_map<ir::Operation*, std::string> GetInplaceOps(
-    ir::Block* block) {
+static std::unordered_map<ir::Operation*, std::string> GetInplaceOps(
+    ir::Block* block) {
  const auto eager_dels = GetEagerDeletionValues(block);
@@ -282,6 +283,7 @@ std::unordered_map<ir::Operation*, std::string> GetInplaceOps(
  }
  return inplace_ops;
}
+}  // namespace details

class InplacePass : public ir::Pass {
 public:
@@ -292,7 +294,7 @@ class InplacePass : public ir::Pass {
    IR_ENFORCE(module_op, "DcePass should run on module op.");
    auto* block = module_op.block();

-    auto inplace_ops = GetInplaceOps(block);
+    auto inplace_ops = details::GetInplaceOps(block);
    for (auto kv : inplace_ops) {
      VLOG(6) << "Do inplace for: "
......
@@ -118,7 +118,7 @@ class IrContextImpl {
              << ", OpInfo: ptr=" << iter->second.AsOpaquePointer() << "].";
      return iter->second;
    }
-    LOG(WARNING) << "No cache found operation of: [Name=" << name << "].";
+    VLOG(8) << "No cache found operation of: [Name=" << name << "].";
    return OpInfo();
  }
  const OpInfoMap &registered_op_info_map() { return registed_op_infos_; }
......
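Since the cache-miss message above drops from LOG(WARNING) to VLOG(8), it is silent by default. A sketch of how to surface it again, assuming the usual GLOG_v mechanism for Paddle's verbose logging:

    import os

    os.environ['GLOG_v'] = '8'  # must be set before paddle is imported
    import paddle  # unregistered-op lookups now log only at verbose level 8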
@@ -1289,7 +1289,7 @@ PHI_DEFINE_EXPORTED_bool(enable_new_ir_api,
                         "Enable new IR API in Python");

/**
 * Using new IR in executor FLAG
 * Name: enable_new_ir_in_executor_trace_run
 * Since Version: 2.6.0
 * Value Range: bool, default=false
@@ -1301,6 +1301,19 @@ PHI_DEFINE_EXPORTED_bool(enable_new_ir_in_executor_trace_run,
                         false,
                         "Enable new IR in executor");

+/**
+ * Apply inplace pass to new IR FLAG
+ * Name: new_ir_apply_inplace_pass
+ * Since Version: 2.6.0
+ * Value Range: bool, default=true
+ * Example:
+ * Note: If True, will apply the inplace pass to new IR.
+ */
+PHI_DEFINE_EXPORTED_bool(new_ir_apply_inplace_pass,
+                         true,
+                         "Whether to apply inplace pass on lowering "
+                         "::ir::Program to Kernel Dialect");
PHI_DEFINE_EXPORTED_bool(enable_record_memory, false, "Enable memory recorder");

PHI_DEFINE_EXPORTED_bool(
......
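With the definition moved next to the other exported PHI flags, the flag should be togglable from Python like any of them; a sketch, with the flag name taken from the hunk above:

    import paddle

    # Skip the inplace pass applied while lowering ::ir::Program to the
    # kernel dialect (the flag defaults to True).
    paddle.set_flags({'FLAGS_new_ir_apply_inplace_pass': False})
    print(paddle.get_flags(['FLAGS_new_ir_apply_inplace_pass']))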
@@ -759,7 +759,7 @@ def relu(x, name=None):
    if in_dynamic_mode():
        return _C_ops.relu(x)
    else:
-        if paddle.ir.core._use_new_ir_api():
+        if paddle.framework.in_dynamic_or_new_ir_mode():
            # Below code will be removed after we can generate IR api automatically
            return paddle._ir_ops.relu(x)
......
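The public relu API is unchanged by the new gate; in static graph mode the dispatch between _C_ops and paddle._ir_ops happens internally. A usage sketch that runs in either mode:

    import paddle
    import paddle.nn.functional as F

    paddle.enable_static()
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        x = paddle.static.data('x', [2, 3], 'float32')
        y = F.relu(x)  # routed to paddle._ir_ops.relu when new IR mode is on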
@@ -17,14 +17,13 @@ import unittest
import numpy as np

import paddle
-from paddle.fluid import core

paddle.enable_static()


class TestPdInplacePass(unittest.TestCase):
    def test_pd_inplace_pass(self):
-        place = core.Place()
+        place = paddle.framework.core.Place()
        place.set_place(paddle.CPUPlace())
        new_scope = paddle.static.Scope()
        main_program = paddle.static.Program()
......