Unverified commit ada16f94, authored by zhangbo9674, committed by GitHub

[IR] Refine the Build interface of split op (#56924)

* fix bug

* fix bug
Parent 6d0ef342
@@ -16,8 +16,8 @@
#include "paddle/fluid/framework/new_executor/feed_fetch_utils.h"
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/framework/new_executor/program_interpreter.h"
#include "paddle/fluid/platform/flags.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/phi/core/flags.h"
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
@@ -29,10 +29,7 @@
PHI_DECLARE_bool(enable_new_ir_in_executor);
PHI_DECLARE_bool(enable_new_ir_api);
PADDLE_DEFINE_EXPORTED_bool(new_ir_apply_inplace_pass,
true,
"new ir kernel program apply inplace pass.");
PHI_DECLARE_bool(new_ir_apply_inplace_pass);
namespace paddle {
namespace framework {
@@ -13,6 +13,10 @@
# limitations under the License.
# generator build function
_INFERMETA_NEED_META_CONFIG = {'SplitInferMeta'}
_PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE = {'SplitOp'}
OP_BUILD_TEMPLATE = """
void {op_name}::Build({build_args}) {{
{get_attributes}
@@ -273,6 +277,7 @@ def GenBuildAttributes(
def GenBuildOutputs(
op_class_name,
op_input_name_list,
op_input_type_list,
op_mutable_attribute_name_list,
@@ -316,6 +321,40 @@ def GenBuildOutputs(
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector<int64_t> {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullIntArrayOp>().attributes().at("value").dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData(); (void){name};\n"""
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullOp>().attributes().at("value").dyn_cast<paddle::dialect::ScalarAttribute>().data().to<{dtype}>(); (void){name};\n"""
CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::IntArray {name};
if ({name}_.owner()->info().id() == ir::TypeId::get<paddle::dialect::FullIntArrayOp>()) {{
{name} = std::move(phi::IntArray({name}_.owner()
->dyn_cast<paddle::dialect::FullIntArrayOp>()
.attributes()
.at("value")
.dyn_cast<paddle::dialect::IntArrayAttribute>()
.data()
.GetData()));
}}
else {{
PADDLE_ENFORCE(
{name}_.type().isa<ir::VectorType>(),
phi::errors::PreconditionNotMet("section Type should be VectorType."));
size_t {name}_size = {name}_.type().dyn_cast<ir::VectorType>().size();
{name} = std::move(phi::IntArray(std::vector<int64_t>({name}_size, -1)));
{name}.SetFromTensor(true);
}}\n"""
CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::Scalar {name};
if ({name}_.owner()->info().id() == ir::TypeId::get<paddle::dialect::FullOp>()) {{
{name} = std::move(phi::Scalar({name}_.owner()
->dyn_cast<paddle::dialect::FullOp>()
.attributes()
.at("value")
.dyn_cast<paddle::dialect::ScalarAttribute>()
.data()
.to<int>()));
}}
else {{
{name} = std::move(phi::Scalar(-1));
{name}.SetFromTensor(true);
}}\n"""
CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name};
phi::MetaTensor meta_{name}(&dense_{name});
"""
@@ -348,19 +387,30 @@ def GenBuildOutputs(
attr_dtype = op_mutable_attribute_type_list[idx]
# int_array
if attr_dtype[0] == "paddle::dialect::IntArrayAttribute":
build_output_str += (
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE.format(
if op_class_name in _PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE:
build_output_str += CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx]
)
)
else:
build_output_str += (
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx]
)
)
# scalar
elif attr_dtype[0] == "paddle::dialect::ScalarAttribute":
build_output_str += (
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE.format(
if op_class_name in _PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE:
build_output_str += CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx],
dtype=attr_dtype[1],
)
)
else:
build_output_str += (
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx],
dtype=attr_dtype[1],
)
)
# string
elif attr_dtype[0] == "ir::StrAttribute":
build_output_str += ""
@@ -421,9 +471,19 @@ def GenBuildOutputs(
CREATE_INFER_META_FUNC_TEMPLATE = """
phi::{func}({args});
"""
build_output_str += CREATE_INFER_META_FUNC_TEMPLATE.format(
func=op_infer_meta_map['func'], args=", ".join(infer_meta_args)
)
CREATE_INFER_META_FUNC_WITH_METACINFIG_TEMPLATE = """
phi::{func}({args}, phi::MetaConfig(false, false));
"""
if op_infer_meta_map['func'] in _INFERMETA_NEED_META_CONFIG:
build_output_str += (
CREATE_INFER_META_FUNC_WITH_METACINFIG_TEMPLATE.format(
func=op_infer_meta_map['func'], args=", ".join(infer_meta_args)
)
)
else:
build_output_str += CREATE_INFER_META_FUNC_TEMPLATE.format(
func=op_infer_meta_map['func'], args=", ".join(infer_meta_args)
)
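As a quick illustration of the dispatch above: SplitInferMeta is listed in _INFERMETA_NEED_META_CONFIG, so its call is rendered with the extra phi::MetaConfig(false, false) argument (presumably signalling a non-runtime shape pass). A runnable sketch with hypothetical argument names:

```python
# Both templates are copied from the generator above; the argument list is
# hypothetical and only serves to show the rendered output.
_INFERMETA_NEED_META_CONFIG = {'SplitInferMeta'}

CREATE_INFER_META_FUNC_TEMPLATE = """
  phi::{func}({args});
"""
CREATE_INFER_META_FUNC_WITH_METACINFIG_TEMPLATE = """
  phi::{func}({args}, phi::MetaConfig(false, false));
"""

func = 'SplitInferMeta'
args = ['meta_x', 'sections', 'axis', 'meta_outs']  # hypothetical names

template = (
    CREATE_INFER_META_FUNC_WITH_METACINFIG_TEMPLATE
    if func in _INFERMETA_NEED_META_CONFIG
    else CREATE_INFER_META_FUNC_TEMPLATE
)
print(template.format(func=func, args=", ".join(args)))
# -> phi::SplitInferMeta(meta_x, sections, axis, meta_outs, phi::MetaConfig(false, false));
```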
# use dense_{name} or vec_dense_{name} to create Outputs type
build_output_str += "\n std::vector<ir::Type> argument_outputs;"
@@ -528,6 +588,7 @@ def gen_build_func_str(
op_non_mutable_attribute_type_list,
)
build_outputs_str = GenBuildOutputs(
op_class_name,
op_input_name_list,
op_input_type_list,
op_mutable_attribute_name_list,
@@ -25,10 +25,11 @@
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_registry.h"
namespace details {
// NOTE(zhangbo): Which kind of value can be deleted?
// (1) The value's type must be AllocatedDenseTensorType or
// AllocatedSelectedRowsType; (2) the value is not persistable.
bool CanBeDeleted(ir::Value value) {
static bool CanBeDeleted(ir::Value value) {
if (!value.type()) {
return false;
}
@@ -47,9 +48,9 @@ bool CanBeDeleted(ir::Value value) {
return true;
}
bool CanDoInplace(const std::unordered_set<ir::Value>& eager_dels,
ir::Value input,
ir::Value output) {
static bool CanDoInplace(const std::unordered_set<ir::Value>& eager_dels,
ir::Value input,
ir::Value output) {
if (input.type() != output.type()) {
VLOG(9) << " -- input's type != output's type, can't do inplace";
return false;
@@ -61,7 +62,7 @@ bool CanDoInplace(const std::unordered_set<ir::Value>& eager_dels,
return true;
}
bool IsNoNeedBuffer(ir::Operation* op, ir::Value value) {
static bool IsNoNeedBuffer(ir::Operation* op, ir::Value value) {
if (op->dialect()->name().compare(
paddle::dialect::PaddleKernelDialect::name()) != 0) {
VLOG(8) << op->name()
@@ -90,7 +91,7 @@ bool IsNoNeedBuffer(ir::Operation* op, ir::Value value) {
// NOTE(zhangbo): pd.feed's output and pd.fetch's input cannot be eagerly
// deleted.
std::unordered_set<ir::Value> GetSkipDeletionValues(ir::Block* block) {
static std::unordered_set<ir::Value> GetSkipDeletionValues(ir::Block* block) {
std::unordered_set<ir::Value> skip_dels;
for (auto& op : *block) {
if (op->dialect()->name().compare(
@@ -119,7 +120,7 @@ std::unordered_set<ir::Value> GetSkipDeletionValues(ir::Block* block) {
// NOTE(zhangbo): For the inplace pass, currently only kernel_dialect operators
// are supported. Therefore, this function only returns the values in
// kernel_dialect operators that can be eagerly deleted.
std::unordered_map<ir::Operation*, std::unordered_set<ir::Value>>
static std::unordered_map<ir::Operation*, std::unordered_set<ir::Value>>
GetEagerDeletionValues(ir::Block* block) {
std::unordered_set<ir::Value> skip_dels = GetSkipDeletionValues(block);
@@ -167,7 +168,7 @@ GetEagerDeletionValues(ir::Block* block) {
return eager_dels;
}
std::unordered_map<ir::Operation*, std::string> GetInplaceOps(
static std::unordered_map<ir::Operation*, std::string> GetInplaceOps(
ir::Block* block) {
const auto eager_dels = GetEagerDeletionValues(block);
@@ -282,6 +283,7 @@ std::unordered_map<ir::Operation*, std::string> GetInplaceOps(
}
return inplace_ops;
}
} // namespace details
class InplacePass : public ir::Pass {
public:
@@ -292,7 +294,7 @@ class InplacePass : public ir::Pass {
IR_ENFORCE(module_op, "InplacePass should run on module op.");
auto* block = module_op.block();
auto inplace_ops = GetInplaceOps(block);
auto inplace_ops = details::GetInplaceOps(block);
for (auto kv : inplace_ops) {
VLOG(6) << "Do inplace for: "
@@ -118,7 +118,7 @@ class IrContextImpl {
<< ", OpInfo: ptr=" << iter->second.AsOpaquePointer() << "].";
return iter->second;
}
LOG(WARNING) << "No cache found operation of: [Name=" << name << "].";
VLOG(8) << "No cache found operation of: [Name=" << name << "].";
return OpInfo();
}
const OpInfoMap &registered_op_info_map() { return registed_op_infos_; }
@@ -1289,7 +1289,7 @@ PHI_DEFINE_EXPORTED_bool(enable_new_ir_api,
"Enable new IR API in Python");
/**
 * Using new IR in executor FLAG
* Name: enable_new_ir_in_executor_trace_run
* Since Version: 2.6.0
* Value Range: bool, default=false
@@ -1301,6 +1301,19 @@ PHI_DEFINE_EXPORTED_bool(enable_new_ir_in_executor_trace_run,
false,
"Enable new IR in executor");
/**
* Apply inplace pass to new IR FLAG
* Name: new_ir_apply_inplace_pass
* Since Version: 2.6.0
* Value Range: bool, default=true
* Example:
 * Note: If True, the inplace pass will be applied to the new IR.
*/
PHI_DEFINE_EXPORTED_bool(new_ir_apply_inplace_pass,
true,
"Whether to apply inplace pass on lowering "
"::ir::Program to Kernel Dialect");
PHI_DEFINE_EXPORTED_bool(enable_record_memory, false, "Enable memory recorder");
PHI_DEFINE_EXPORTED_bool(
@@ -759,7 +759,7 @@ def relu(x, name=None):
if in_dynamic_mode():
return _C_ops.relu(x)
else:
if paddle.ir.core._use_new_ir_api():
if paddle.framework.in_dynamic_or_new_ir_mode():
# Below code will be removed after we can generate IR api automatically
return paddle._ir_ops.relu(x)
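The new in_dynamic_or_new_ir_mode helper folds the two mode checks into a single call. A hedged sketch of what it plausibly does, based only on the two predicates visible in this diff, not on Paddle's actual implementation:

```python
import paddle

def in_dynamic_or_new_ir_mode():
    # Hypothetical: true in dygraph mode, or in static mode when the new IR
    # API is switched on.
    return paddle.in_dynamic_mode() or paddle.ir.core._use_new_ir_api()
```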
@@ -17,14 +17,13 @@ import unittest
import numpy as np
import paddle
from paddle.fluid import core
paddle.enable_static()
class TestPdInplacePass(unittest.TestCase):
def test_pd_inplace_pass(self):
place = core.Place()
place = paddle.framework.core.Place()
place.set_place(paddle.CPUPlace())
new_scope = paddle.static.Scope()
main_program = paddle.static.Program()