Unverified commit 67c84c45, authored by zhangbo9674, committed by GitHub

[IR] Support inplace pass (#56672)

* add code

* add code

* refine code

* add code

* fix bug

* fix bug

* fix bug

* add code

* add ut

* polish code

* fix bug

* refine code

* fix bug

* refine code

* fix bug

* refine code

* fix bug

* refine code

* fix bug

* refine code

* add code

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* refine code
Parent 11a526ab
......@@ -16,6 +16,7 @@ set(STANDALONE_EXECUTOR_DEPS
phi_kernel_adaptor
program_translator
instruction_base
pd_inplace_pass
ir)
cc_library(
......
......@@ -16,16 +16,24 @@
#include "paddle/fluid/framework/new_executor/feed_fetch_utils.h"
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/framework/new_executor/program_interpreter.h"
#include "paddle/fluid/platform/flags.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/fluid/ir/transforms/inplace_pass.h"
#include "paddle/fluid/ir_adaptor/translator/translate.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_manager.h"
PHI_DECLARE_bool(enable_new_ir_in_executor);
PHI_DECLARE_bool(enable_new_ir_api);
PADDLE_DEFINE_EXPORTED_bool(new_ir_apply_inplace_pass,
true,
"new ir kernel program apply inplace pass.");
namespace paddle {
namespace framework {
StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
......@@ -101,6 +109,13 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
}
auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(base_program.get(), place);
if (FLAGS_new_ir_apply_inplace_pass) {
ir::PassManager pm(ir::IrContext::Instance(), 3);
pm.AddPass(ir::CreateInplacePass());
pm.Run(kernel_program.get());
}
interpretercores_.emplace_back(
std::make_shared<InterpreterCore>(place_,
fetch_var_names_,
......
......@@ -856,7 +856,7 @@ def OpGenerator(
op_infer_meta_map,
muta_attr_is_input=False,
)
-if len(op_attribute_name_list) > 1:
+if len(op_attribute_name_list) > 0:
(
build_args_with_attr_is_map_for_declare,
build_func_with_attr_is_map,
......
......@@ -84,15 +84,15 @@ const OpRunTimeInfo& OpYamlInfoParser::OpRuntimeInfo() const {
return std::get<3>(op_info_tuple_);
}
-const std::map<std::string, int>& OpYamlInfoParser::InputName2Id() const {
+const std::map<std::string, uint32_t>& OpYamlInfoParser::InputName2Id() const {
return input_name2id_;
}
-const std::map<std::string, int>& OpYamlInfoParser::OutputName2Id() const {
+const std::map<std::string, uint32_t>& OpYamlInfoParser::OutputName2Id() const {
return output_name2id_;
}
-const std::vector<int>& OpYamlInfoParser::NoNeedBufferIds() const {
+const std::vector<uint32_t>& OpYamlInfoParser::NoNeedBufferIds() const {
return no_need_buffer_ids_;
}
......@@ -118,6 +118,17 @@ const std::string& OpYamlInfoParser::InplaceName(
"Can not find inplace input of [%s].", out_name));
}
std::unordered_map<uint32_t, uint32_t> OpYamlInfoParser::GetInplaceIdMap()
const {
std::unordered_map<uint32_t, uint32_t> inplace_id_map;
auto& inplace_info = std::get<3>(op_info_tuple_).inplace;
for (const auto& info : inplace_info) {
inplace_id_map[OutputName2Id().at(info.first)] =
InputName2Id().at(info.second);
}
return inplace_id_map;
}
bool OpYamlInfoParser::HasView(const std::string& out_name) const {
auto& view_info = std::get<3>(op_info_tuple_).view;
for (size_t i = 0; i < view_info.size(); i++) {
......
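GetInplaceIdMap resolves the YAML-level inplace pairs (output name → input name) into result/operand slot indices. A minimal consumption sketch, mirroring how GetInplaceOps in the new pass uses it — here op and info_interface are assumed to be an operation and its OpYamlInfoInterface impl, and the YAML names are illustrative:

paddle::dialect::OpYamlInfoParser parser(info_interface->get_op_info_());
// e.g. for an op whose YAML declares inplace {"out": "x"}, the map is {0: 0}.
for (const auto& kv : parser.GetInplaceIdMap()) {
  uint32_t out_slot = kv.first;   // index into op->result(...)
  uint32_t in_slot = kv.second;   // index into op->operand_source(...)
  // op->result(out_slot) may reuse the buffer of op->operand_source(in_slot).
}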
......@@ -34,10 +34,10 @@ class OpYamlInfoParser {
const std::vector<std::string>& TensorParams(bool is_kernel = false) const;
const std::vector<std::string>& AttrParams(bool is_kernel = false) const;
const OpRunTimeInfo& OpRuntimeInfo() const;
-const std::map<std::string, int>& InputName2Id() const;
-const std::map<std::string, int>& OutputName2Id() const;
+const std::map<std::string, uint32_t>& InputName2Id() const;
+const std::map<std::string, uint32_t>& OutputName2Id() const;
-const std::vector<int>& NoNeedBufferIds() const;
+const std::vector<uint32_t>& NoNeedBufferIds() const;
const std::vector<std::string>& InputNames() const {
return input_name_list_;
......@@ -53,6 +53,8 @@ class OpYamlInfoParser {
const std::string& InplaceName(const std::string& out_name) const;
std::unordered_map<uint32_t, uint32_t> GetInplaceIdMap() const;
bool HasView(const std::string& out_name) const;
const std::string& ViewName(const std::string& out_name) const;
......@@ -68,20 +70,20 @@ class OpYamlInfoParser {
OpInfoTuple op_info_tuple_;
// input info
-std::map<std::string, int> input_name2id_;
+std::map<std::string, uint32_t> input_name2id_;
std::vector<std::string> input_name_list_;
std::map<std::string, OpInputInfo> input_info_;
-int input_tensor_number_{0};
+uint32_t input_tensor_number_{0};
// no_need_buffer_ids
-std::vector<int> no_need_buffer_ids_;
+std::vector<uint32_t> no_need_buffer_ids_;
// attribute info
std::vector<std::string> attribute_name_list_;
std::map<std::string, OpAttributeInfo> attr_info_;
// output info
-std::map<std::string, int> output_name2id_;
+std::map<std::string, uint32_t> output_name2id_;
std::vector<std::string> output_name_list_;
std::map<std::string, OpOutputInfo> output_info_;
......
......@@ -12,3 +12,8 @@ cc_library(
_constant_folding_pass
SRCS constant_folding_pass.cc
DEPS standalone_executor pd_op_to_kernel_pass transform_general_functions)
cc_library(
pd_inplace_pass
SRCS inplace_pass.cc
DEPS pd_dialect_core op_yaml_info_parser)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/ir/transforms/inplace_pass.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/trait/inplace.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.h"
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_dialect.h"
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_registry.h"
// NOTE(zhangbo): Which kind of value can be deleted?
// (1) Value's type needs to be AllocatedDenseTensorType or
// AllocatedSelectedRowsType; (2) the value is not persistable.
bool CanBeDeleted(ir::Value value) {
if (!value.type()) {
return false;
}
if (!value.type().isa<paddle::dialect::AllocatedDenseTensorType>() &&
!value.type().isa<paddle::dialect::AllocatedSelectedRowsType>()) {
return false;
}
if (value.GetDefiningOp()->HasAttribute(kAttrIsPersisable)) {
return !(value.GetDefiningOp()
->attribute(kAttrIsPersisable)
.dyn_cast<::ir::ArrayAttribute>()
.AsVector()[value.dyn_cast<::ir::OpResult>().GetResultIndex()]
.dyn_cast<::ir::BoolAttribute>()
.data());
}
return true;
}
bool CanDoInplace(const std::unordered_set<ir::Value>& eager_dels,
ir::Value input,
ir::Value output) {
if (input.type() != output.type()) {
VLOG(9) << " -- input's type != output's type, can't do inplace";
return false;
}
if (eager_dels.count(input) == 0) {
VLOG(9) << " -- input not in eager_deletion_valus, can't do inplace";
return false;
}
return true;
}
bool IsNoNeedBuffer(ir::Operation* op, ir::Value value) {
if (op->dialect()->name().compare(
paddle::dialect::PaddleKernelDialect::name()) != 0) {
VLOG(8) << op->name()
<< "is not a kernel_dialect op, no need buffer is false";
return false;
}
auto op_name =
op->attributes().at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
ir::OpInfo op_info = ir::IrContext::Instance()->GetRegisteredOpInfo(op_name);
if (op_info) {
auto info_interface =
op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>();
if (info_interface) {
paddle::dialect::OpYamlInfoParser info_parser(
info_interface->get_op_info_());
auto& no_need_buffer_ids = info_parser.NoNeedBufferIds();
for (size_t id = 0; id < no_need_buffer_ids.size(); id++) {
if (value == op->operand_source(no_need_buffer_ids[id])) {
return true;
}
}
}
}
return false;
}
// NOTE(zhangbo): The outputs of pd.feed/pd.data and the inputs of
// pd.fetch/pd.shadow_output can not be eagerly deleted.
std::unordered_set<ir::Value> GetSkipDeletionValues(ir::Block* block) {
std::unordered_set<ir::Value> skip_dels;
for (auto& op : *block) {
if (op->dialect()->name().compare(
paddle::dialect::PaddleKernelDialect::name()) != 0) {
continue;
}
IR_ENFORCE(op->attributes().count("op_name") > 0,
"kernel_dialect op should own an 'op_name' attribute.");
auto upper_op_name = op->attributes()
.at("op_name")
.dyn_cast<::ir::StrAttribute>()
.AsString();
if (upper_op_name == "pd.feed" || upper_op_name == "pd.data") {
skip_dels.insert(op->result(0));
continue;
}
if (upper_op_name == "pd.fetch" || upper_op_name == "pd.shadow_output") {
skip_dels.insert(op->operand_source(0));
continue;
}
}
return skip_dels;
}
// NOTE(zhangbo): The inplace pass currently supports only kernel_dialect
// operators, so this function only returns the values in kernel_dialect
// operators that can be eagerly deleted.
std::unordered_map<ir::Operation*, std::unordered_set<ir::Value>>
GetEagerDeletionValues(ir::Block* block) {
std::unordered_set<ir::Value> skip_dels = GetSkipDeletionValues(block);
std::unordered_map<ir::Value, ir::Operation*> del_value_2_op;
for (auto& op : *block) {
std::string upper_op_name = op->name();
if (op->dialect()->name().compare(
paddle::dialect::PaddleKernelDialect::name()) == 0) {
IR_ENFORCE(op->attributes().count("op_name") > 0,
"kernel_dialect op should own an 'op_name' attribute.");
upper_op_name = op->attributes()
.at("op_name")
.dyn_cast<::ir::StrAttribute>()
.AsString();
}
for (size_t i = 0; i < op->num_operands(); ++i) {
auto input = op->operand_source(i);
if (skip_dels.count(input) > 0 || !input || !CanBeDeleted(input) ||
IsNoNeedBuffer(op, input)) {
VLOG(6) << "The " << i << "-th input value of the Operation("
<< upper_op_name << ") can not be deleted.";
VLOG(8) << " -- skip dels: " << skip_dels.count(input);
VLOG(8) << " -- value is null: " << !input;
VLOG(8) << " -- can be deleted: " << !CanBeDeleted(input);
VLOG(8) << " -- is no_need_buffer: " << IsNoNeedBuffer(op, input);
continue;
}
del_value_2_op[input] = op;
}
for (size_t i = 0; i < op->num_results(); ++i) {
ir::Value output = op->result(i);
if (output && CanBeDeleted(output)) {
del_value_2_op[output] = op;
}
}
}
std::unordered_map<ir::Operation*, std::unordered_set<ir::Value>> eager_dels;
for (auto& kv : del_value_2_op) {
eager_dels[kv.second].insert(kv.first);
}
return eager_dels;
}
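// NOTE: An (input, output) pair is an inplace candidate only if the op has a
// registered "<op_name>_" variant, the two values have identical types, the
// input is eagerly deletable after this op, the output has not already been
// visited during the scan, and neither value is already reused by another
// inplace op.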
std::unordered_map<ir::Operation*, std::string> GetInplaceOps(
ir::Block* block) {
const auto eager_dels = GetEagerDeletionValues(block);
std::unordered_map<ir::Operation*, std::string> inplace_ops;
std::unordered_set<ir::Value> visited_values;
std::unordered_set<ir::Value> reused_input_values;
std::unordered_set<ir::Value> reused_output_values;
for (auto& op : *block) {
for (size_t i = 0; i < op->num_operands(); ++i) {
visited_values.insert(op->operand_source(i));
}
if (op->dialect()->name().compare(
paddle::dialect::PaddleKernelDialect::name()) != 0) {
VLOG(6) << op->name()
<< "is not a kernel_dialect op, inplace only support "
"kernel_dialect operators";
for (size_t i = 0; i < op->num_results(); ++i) {
visited_values.insert(op->result(i));
}
continue;
}
auto upper_op_attrs = op->attributes();
auto upper_op_name =
upper_op_attrs.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
VLOG(6) << "analyse op: " << upper_op_name;
if (upper_op_attrs.count("is_inplace") != 0 &&
upper_op_attrs.at("is_inplace").dyn_cast<ir::BoolAttribute>().data()) {
VLOG(6) << upper_op_name << " is already an inplace op.";
for (size_t i = 0; i < op->num_operands(); ++i) {
reused_input_values.insert(op->operand_source(i));
}
for (size_t i = 0; i < op->num_results(); ++i) {
reused_output_values.insert(op->result(i));
visited_values.insert(op->result(i));
}
continue;
}
ir::OpInfo upper_inplace_op_info =
ir::IrContext::Instance()->GetRegisteredOpInfo(upper_op_name + "_");
if (eager_dels.count(op) == 0 || (!upper_inplace_op_info)) {
VLOG(6) << upper_op_name
<< "'s value can't delete or doesn't have inplace op, so that "
"can't do inplace.";
for (size_t i = 0; i < op->num_results(); ++i) {
visited_values.insert(op->result(i));
}
continue;
}
auto upper_inplace_op_interface =
upper_inplace_op_info
.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>();
PADDLE_ENFORCE_NOT_NULL(
upper_inplace_op_interface,
phi::errors::PreconditionNotMet(
"can not find OpYamlInfoInterface from [%s]", upper_op_name + "_"));
paddle::dialect::OpYamlInfoParser upper_inplace_op_info_parser(
upper_inplace_op_interface->get_op_info_());
std::unordered_map<uint32_t, uint32_t> inplace_out_2_in =
upper_inplace_op_info_parser.GetInplaceIdMap();
bool can_do_inplace = true;
for (auto& kv : inplace_out_2_in) {
uint32_t out_slot = kv.first;
uint32_t in_slot = kv.second;
if ((in_slot >= op->num_operands()) || (out_slot >= op->num_results()) ||
(!CanDoInplace(eager_dels.at(op),
op->operand_source(in_slot),
op->result(out_slot))) ||
(visited_values.count(op->result(out_slot)) > 0) ||
(!CanBeDeleted(op->result(out_slot))) ||
(reused_input_values.count(op->operand_source(in_slot)) > 0) ||
(reused_output_values.count(op->result(out_slot)) > 0)) {
can_do_inplace = false;
VLOG(6) << upper_op_name
<< "'s value has been visited or reused by other inplace op, "
"so that can't do inplace.";
VLOG(8) << " -- operand " << in_slot << " and result " << out_slot
<< " can do inplace: "
<< CanDoInplace(eager_dels.at(op),
op->operand_source(in_slot),
op->result(out_slot));
VLOG(8) << " -- result " << out_slot << " visited: "
<< (visited_values.count(op->result(out_slot)) > 0);
VLOG(8) << " -- operand " << in_slot << " has been reused: "
<< (reused_input_values.count(op->operand_source(in_slot)) > 0);
VLOG(8) << " -- result " << out_slot << " has been reused: "
<< (reused_output_values.count(op->result(out_slot)) > 0);
break;
}
}
if (can_do_inplace) {
inplace_ops[op] = upper_op_name + "_";
for (auto& kv : inplace_out_2_in) {
reused_input_values.insert(op->operand_source(kv.second));
reused_output_values.insert(op->result(kv.first));
}
VLOG(6) << upper_op_name
<< " will change to inplace version op: " << upper_op_name + "_";
}
for (size_t i = 0; i < op->num_results(); ++i) {
visited_values.insert(op->result(i));
}
}
return inplace_ops;
}
class InplacePass : public ir::Pass {
public:
InplacePass() : ir::Pass("InplacePass", 3) {}
void Run(ir::Operation* op) override {
auto module_op = op->dyn_cast<ir::ModuleOp>();
IR_ENFORCE(module_op, "InplacePass should run on module op.");
auto* block = module_op.block();
auto inplace_ops = GetInplaceOps(block);
for (auto kv : inplace_ops) {
VLOG(6) << "Do inplace for: "
<< kv.first->attributes()
.at("op_name")
.dyn_cast<::ir::StrAttribute>()
.AsString();
ir::Block::iterator insert_pos =
std::find(block->begin(), block->end(), kv.first);
IR_ENFORCE(insert_pos != block->end(),
"Operator %s not found in block.",
kv.first->name());
kv.first->set_attribute(
"op_name",
ir::StrAttribute::get(ir::IrContext::Instance(), kv.second));
kv.first->set_attribute(
"is_inplace",
ir::BoolAttribute::get(ir::IrContext::Instance(), true));
}
}
bool CanApplyOn(ir::Operation* op) const override {
return op->name() == "builtin.module" && op->num_regions() > 0;
}
};
namespace ir {
std::unique_ptr<ir::Pass> CreateInplacePass() {
return std::make_unique<InplacePass>();
}
} // namespace ir
REGISTER_PASS(inplace, InplacePass);
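For reference, the pass registered above can also be applied outside the executor; a minimal sketch mirroring the StandaloneExecutor integration earlier in this commit, where kernel_program is assumed to be a lowered ir::Program:

ir::PassManager pm(ir::IrContext::Instance(), /*opt_level=*/3);
pm.AddPass(ir::CreateInplacePass());
pm.Run(kernel_program.get());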
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/ir/core/dll_decl.h"
namespace ir {
class Pass;
std::unique_ptr<Pass> CreateInplacePass();
} // namespace ir
......@@ -41,6 +41,7 @@ set(PYBIND_DEPS
phi_kernel_adaptor
pd_dialect
program_translator
pd_inplace_pass
ir
new_profiler
jit_layer
......
......@@ -30,6 +30,7 @@
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h"
#include "paddle/fluid/ir/transforms/inplace_pass.h"
#include "paddle/fluid/ir_adaptor/translator/translate.h"
#include "paddle/fluid/ir_adaptor/translator/utils.h"
#include "paddle/ir/core/block.h"
......@@ -59,6 +60,7 @@ using paddle::dialect::DenseTensorType;
using pybind11::return_value_policy;
USE_PASS(dead_code_elimination);
USE_PASS(inplace);
namespace paddle {
namespace pybind {
......
......@@ -759,6 +759,10 @@ def relu(x, name=None):
    if in_dynamic_mode():
        return _C_ops.relu(x)
    else:
        if paddle.ir.core._use_new_ir_api():
            # Below code will be removed after we can generate IR api automatically
            return paddle._ir_ops.relu(x)
        check_variable_and_dtype(
            x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'relu'
        )
......
......@@ -4,7 +4,7 @@ file(
"test_*.py")
string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}")
-set(TEST_IR_SYSTEM_CASES test_build_model)
+set(TEST_IR_SYSTEM_CASES test_build_model test_pd_inplace_pass)
list(REMOVE_ITEM TEST_INTERP_CASES ${TEST_IR_SYSTEM_CASES})
foreach(target ${TEST_INTERP_CASES})
......@@ -16,3 +16,5 @@ foreach(target ${TEST_IR_SYSTEM_CASES})
py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1
FLAGS_enable_new_ir_api=true)
endforeach()
set_tests_properties(test_pd_inplace_pass PROPERTIES TIMEOUT 60)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.fluid import core
paddle.enable_static()
class TestPdInplacePass(unittest.TestCase):
    def test_pd_inplace_pass(self):
        place = core.Place()
        place.set_place(paddle.CPUPlace())
        new_scope = paddle.static.Scope()
        main_program = paddle.static.Program()
        with paddle.static.scope_guard(new_scope):
            with paddle.static.program_guard(main_program):
                x = paddle.static.data('x', [2, 2], dtype='float32')
                y = paddle.ones([2, 2], dtype='float32')
                z = paddle.divide(x, y)
                out = paddle.nn.functional.relu(z)
                exe = paddle.static.Executor()
                x_feed = np.ones([2, 2], dtype=np.float32) * 10
                (sum_value,) = exe.run(feed={'x': x_feed}, fetch_list=[out])
                self.assertEqual(
                    (sum_value == np.ones([2, 2], dtype="float32") * 10).all(),
                    True,
                )

if __name__ == "__main__":
    unittest.main()
......@@ -1370,9 +1370,19 @@ class OpTest(unittest.TestCase):
return
if self._check_cinn:
return
-stored_flag = get_flags('FLAGS_enable_new_ir_in_executor')
+stored_flag = get_flags(
+    [
+        'FLAGS_enable_new_ir_in_executor',
+        "FLAGS_new_ir_apply_inplace_pass",
+    ]
+)
try:
-set_flags({"FLAGS_enable_new_ir_in_executor": True})
+set_flags(
+    {
+        "FLAGS_enable_new_ir_in_executor": True,
+        "FLAGS_new_ir_apply_inplace_pass": 0,
+    }
+)
new_scope = paddle.static.Scope()
executor = Executor(place)
new_program = None
......@@ -3215,9 +3225,19 @@ class OpTest(unittest.TestCase):
if self._check_cinn:
return
-stored_flag = get_flags('FLAGS_enable_new_ir_in_executor')
+stored_flag = get_flags(
+    [
+        'FLAGS_enable_new_ir_in_executor',
+        "FLAGS_new_ir_apply_inplace_pass",
+    ]
+)
try:
-set_flags({"FLAGS_enable_new_ir_in_executor": True})
+set_flags(
+    {
+        "FLAGS_enable_new_ir_in_executor": True,
+        "FLAGS_new_ir_apply_inplace_pass": 0,
+    }
+)
executor = Executor(place)
new_gradients = list(
map(
......