Unverified commit e8cba1cb authored by zhangbo9674, committed by GitHub

[IR] Support inplace execute logic for NewIrInterpreter (#55210)

* add inplace interface

* support inplace

* refine code

* fix bug

* fix bug

* refine code
Parent commit 5f00305d
add_subdirectory(interface)
add_subdirectory(trait)
add_subdirectory(dialect)
add_subdirectory(transforms)
add_subdirectory(phi_kernel_adaptor)
@@ -52,5 +52,5 @@ file(GLOB PD_DIALECT_SRCS "*.cc")
cc_library(
pd_dialect
SRCS ${PD_DIALECT_SRCS} ${op_source_file}
-  DEPS framework_proto phi phi_utils pd_interface ir)
+  DEPS framework_proto phi phi_utils pd_interface pd_trait ir)
target_include_directories(pd_dialect PRIVATE ${PD_DIALECT_BINARY_DIR})
@@ -43,6 +43,7 @@ H_FILE_TEMPLATE = """#ifdef GET_OP_LIST
#include "paddle/fluid/ir/dialect/op_yaml_info_util.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir/interface/infermeta.h"
#include "paddle/fluid/ir/trait/inplace.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
@@ -708,6 +709,10 @@ def OpGenerator(
    op_interfaces_str = ""
    if len(op_interfaces) > 0:
        op_interfaces_str = "," + ",".join(op_interfaces)

    if op_name[-1] == "_":
        op_traits += ["InplaceTrait"]

    op_traits_str = ""
    if len(op_traits) > 0:
        op_traits_str = "," + ",".join(op_traits)
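With this change, any generated op whose name ends in "_" (the inplace naming convention) also carries InplaceTrait. As a hypothetical sketch (class and interface names abridged, not the generator's verbatim output), the emitted declaration for such an op would look like:

class Sqrt_Op : public ir::Op<Sqrt_Op,
                              paddle::dialect::OpYamlInfoInterface,
                              paddle::dialect::InferMetaInterface,
                              paddle::dialect::InplaceTrait> {
  // ... generated members ...
};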
......
@@ -84,8 +84,30 @@ const OpRunTimeInfo& OpYamlInfoParser::OpRuntimeInfo() const {
return std::get<3>(op_info_tuple_);
}
-const std::map<std::string, int>& OpYamlInfoParser::Name2Id() const {
-  return name2id_;
+const std::map<std::string, int>& OpYamlInfoParser::InputName2Id() const {
+  return input_name2id_;
}
bool OpYamlInfoParser::HasInplace(const std::string& out_name) const {
auto inplace_info = std::get<3>(op_info_tuple_).inplace;
for (size_t i = 0; i < inplace_info.size(); i++) {
if (out_name == inplace_info[i].first) {
return true;
}
}
return false;
}
const std::string& OpYamlInfoParser::InplaceName(
const std::string& out_name) const {
auto inplace_info = std::get<3>(op_info_tuple_).inplace;
for (size_t i = 0; i < inplace_info.size(); i++) {
if (out_name == inplace_info[i].first) {
return inplace_info[i].second;
}
}
PADDLE_THROW(phi::errors::PreconditionNotMet(
"Cannot find inplace input of [%s].", out_name));
}
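Callers pair these two queries to resolve output aliasing. A minimal usage sketch, assuming `parser` was built from an op whose YAML declares the inplace pair ("out", "x"):

for (auto& out_name : parser.OutputNames()) {
  if (parser.HasInplace(out_name)) {
    // "out" resolves to "x": this output must reuse the input's variable.
    const std::string& in_name = parser.InplaceName(out_name);
    int operand_index = parser.InputName2Id().at(in_name);
    // ... bind the result to the variable created for operand_index ...
  }
}

The linear scan inside both queries is cheap here: an op declares at most a handful of inplace pairs.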
void OpYamlInfoParser::parse() {
@@ -94,30 +116,30 @@ void OpYamlInfoParser::parse() {
int start_index = 0;
for (size_t i = 0; i < input_info.size(); ++i) {
-    name2id_[input_info[i].name] = start_index++;
+    input_name2id_[input_info[i].name] = start_index++;
+    input_name_list_.push_back(input_info[i].name);
+    input_info_[input_info[i].name] = input_info[i];
     if (!input_info[i].is_mutable_attribute) {
       input_tensor_number_++;
     }
-    input_info_[input_info[i].name] = input_info[i];
}
auto attribute_info = std::get<1>(op_info_tuple_);
for (size_t i = 0; i < attribute_info.size(); ++i) {
attribute_name_list_.push_back(attribute_info[i].name);
attr_info_[attribute_info[i].name] = attribute_info[i];
}
auto output_info = std::get<2>(op_info_tuple_);
for (size_t i = 0; i < output_info.size(); ++i) {
output_name_list_.push_back(output_info[i].name);
output_info_[output_info[i].name] = output_info[i];
}
auto runtime_info = std::get<3>(op_info_tuple_);
for (auto& name : runtime_info.infer_meta_param) {
-    if (name2id_.count(name) && !input_info_[name].is_mutable_attribute) {
+    if (input_name2id_.count(name) && !input_info_[name].is_mutable_attribute) {
infer_meta_tensor_params_.push_back(name);
} else {
infer_meta_attr_params_.push_back(name);
@@ -125,7 +147,7 @@ void OpYamlInfoParser::parse() {
}
for (auto& name : runtime_info.kernel_param) {
-    if (name2id_.count(name) && !input_info_[name].is_mutable_attribute) {
+    if (input_name2id_.count(name) && !input_info_[name].is_mutable_attribute) {
kernel_fn_tensor_params_.push_back(name);
} else {
kernel_fn_attr_params_.push_back(name);
......
@@ -34,7 +34,21 @@ class OpYamlInfoParser {
const std::vector<std::string>& TensorParams(bool is_kernel = false) const;
const std::vector<std::string>& AttrParams(bool is_kernel = false) const;
const OpRunTimeInfo& OpRuntimeInfo() const;
-  const std::map<std::string, int>& Name2Id() const;
+  const std::map<std::string, int>& InputName2Id() const;
const std::vector<std::string>& InputNames() const {
return input_name_list_;
}
const std::vector<std::string>& AttributeNames() const {
return attribute_name_list_;
}
const std::vector<std::string>& OutputNames() const {
return output_name_list_;
}
bool HasInplace(const std::string& out_name) const;
const std::string& InplaceName(const std::string& out_name) const;
private:
void parse();
@@ -44,18 +58,25 @@ class OpYamlInfoParser {
OpInfoTuple op_info_tuple_;
-  std::map<std::string, int> name2id_;
+  // input info
+  std::map<std::string, int> input_name2id_;
+  std::vector<std::string> input_name_list_;
   std::map<std::string, OpInputInfo> input_info_;
+  int input_tensor_number_{0};
+  // attribute info
+  std::vector<std::string> attribute_name_list_;
   std::map<std::string, OpAttributeInfo> attr_info_;
+  // output info
+  std::vector<std::string> output_name_list_;
   std::map<std::string, OpOutputInfo> output_info_;
+  // runtime info
   std::vector<std::string> infer_meta_tensor_params_;
   std::vector<std::string> infer_meta_attr_params_;
   std::vector<std::string> kernel_fn_tensor_params_;
   std::vector<std::string> kernel_fn_attr_params_;
-  int input_tensor_number_{0};
};
} // namespace dialect
......
@@ -64,6 +64,66 @@ paddle::framework::Variable* CreateVar(ir::Value value,
}
}
void CheckInputVars(
ir::Operation* op,
const std::string& op_name,
const std::unordered_map<ir::Value, std::string>& name_map) {
size_t input_num = op->num_operands();
if (input_num > 0) {
for (size_t i = 0; i < input_num; ++i) {
auto value = op->operand(i);
if (value) {
PADDLE_ENFORCE_NE(
name_map.find(value),
name_map.end(),
phi::errors::PreconditionNotMet(
"input should in name map, [%d] 'th input of [%s] op",
i,
op_name));
}
}
}
}
void BuildValue(ir::Value value,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map,
int& count) { // NOLINT
auto inner_local_scope = local_scope != nullptr ? local_scope : scope;
std::string name;
if (name_map->find(value) != name_map->end()) {
name = name_map->at(value);
} else {
name = "inner_var_" + std::to_string(count++);
name_map->emplace(value, name);
}
auto var = CreateVar(value, name, scope, inner_local_scope);
// Only DenseTensor and Vector<DenseTensor> outputs are supported
if (!value.type()) {
var->GetMutable<phi::DenseTensor>();
} else if (value.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
var->GetMutable<phi::DenseTensor>();
} else if (value.type().isa<ir::VectorType>()) {
auto tensor_array = var->GetMutable<paddle::framework::TensorRefArray>();
for (size_t i = 0; i < value.type().dyn_cast<ir::VectorType>().size();
i++) {
PADDLE_ENFORCE(value.type()
.dyn_cast<ir::VectorType>()[i]
.isa<paddle::dialect::AllocatedDenseTensorType>(),
paddle::platform::errors::Fatal(
"Element of VectorType output only supports "
"DenseTensorType"));
std::string name_i = "inner_var_" + std::to_string(count++);
auto var_i = CreateVar(value, name_i, scope, inner_local_scope);
tensor_array->emplace_back(var_i->GetMutable<phi::DenseTensor>());
}
} else {
PADDLE_THROW(phi::errors::PreconditionNotMet(
"Output only supports DenseTensorType or VectorType"));
}
}
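BuildValue is the non-inplace fallback used below: it looks up or allocates one scope variable per SSA value, naming fresh ones inner_var_N. A usage sketch matching BuildScope's default path, assuming op, scope, and local_scope are in hand, name_map is an unordered_map<ir::Value, std::string>, and count is the running variable counter:

for (size_t i = 0; i < op->num_results(); ++i) {
  BuildValue(op->result(i), scope, local_scope, &name_map, count);
}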
void HandleForSpecialOp(ir::Operation* op,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
@@ -91,10 +151,10 @@ void HandleForSpecialOp(ir::Operation* op,
if (op_name == "pd.feed") {
VLOG(6) << "Handle for pd.feed:";
-    auto ptr = op->result(0);
+    auto value = op->result(0);
     std::string name = "inner_var_" + std::to_string(count++);
-    name_map->emplace(ptr, name);
-    auto var = CreateVar(ptr, name, scope, local_scope);
+    name_map->emplace(value, name);
+    auto var = CreateVar(value, name, scope, local_scope);
// TODO(phlrain): need to update here, support StringTensor
auto out_tensor = var->GetMutable<phi::DenseTensor>();
@@ -122,14 +182,14 @@ void HandleForSpecialOp(ir::Operation* op,
auto tensor_array = var->GetMutable<paddle::framework::TensorRefArray>();
for (size_t i = 0; i < input_num; ++i) {
-      auto value = op->operand(i);
+      auto value = op->operand(i);
       PADDLE_ENFORCE_EQ(
-          name_map->count(ptr),
+          name_map->count(value),
           true,
           phi::errors::PreconditionNotMet("cannot find input of combine op"));
       tensor_array->emplace_back(
-          &(CreateVar(ptr, name_map->at(ptr), scope, local_scope)
+          &(CreateVar(value, name_map->at(value), scope, local_scope)
                ->Get<phi::DenseTensor>()));
}
}
@@ -160,6 +220,41 @@ void HandleForSpecialOp(ir::Operation* op,
}
}
void HandleForInplaceOp(ir::Operation* op,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map,
int& count) { // NOLINT
if (op->num_results() < 1) return;
ir::IrContext* ctx = ir::IrContext::Instance();
std::string op_name = op->name();
if (op->attributes().count("op_name")) {
op_name =
op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
}
VLOG(4) << "HandleForInplaceOp: " << op_name;
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
paddle::dialect::OpYamlInfoParser yaml_parser(
op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>()
->get_op_info_());
for (size_t i = 0; i < op->num_results(); ++i) {
ir::Value value = op->result(i);
std::string value_name = yaml_parser.OutputNames()[i];
if (yaml_parser.HasInplace(value_name)) {
std::string inplace_name = yaml_parser.InplaceName(value_name);
ir::Value inplace_value =
op->operand(yaml_parser.InputName2Id().at(inplace_name));
std::string var_name = name_map->at(inplace_value);
VLOG(4) << "inplace: " << value_name << " -> " << inplace_name
<< " (var: " << var_name << ")";
name_map->emplace(value, var_name);
} else {
BuildValue(value, scope, local_scope, name_map, count);
}
}
}
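Worked through the test at the bottom of this change: for a sqrt_ op (inplace pair "out" <-> "x") fed by a full op, the loop aliases the result instead of allocating a new variable:

// Before: name_map = { full.result(0) -> "inner_var_0" }
// sqrt_'s output "out" has inplace partner "x"; operand "x" is full.result(0),
// already named "inner_var_0".
// After:  name_map = { full.result(0)  -> "inner_var_0",
//                      sqrt_.result(0) -> "inner_var_0" }  // aliased, no new var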
void BuildScope(const ir::Block& block,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
@@ -178,77 +273,39 @@ void BuildScope(const ir::Block& block,
for (auto it = block.begin(); it != block.end(); ++it) {
ir::Operation* op = *it;
-    auto attr_map = op->attributes();
     std::string op_name = op->name();
-    if (attr_map.count("op_name")) {
-      op_name = attr_map.at("op_name").dyn_cast<ir::StrAttribute>().data();
+    if (op->attributes().count("op_name")) {
+      op_name =
+          op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
     }
     VLOG(4) << "BuildScope for: " << op_name;
if (op_name == "pd.feed" || op_name == "pd.fetch" ||
op_name == "builtin.combine" || op_name == "builtin.set_parameter" ||
op_name == "builtin.get_parameter") {
VLOG(6) << "HandleForSpecialOp: " << op_name;
VLOG(4) << "HandleForSpecialOp: " << op_name;
HandleForSpecialOp(op, scope, inner_local_scope, name_map, count);
continue;
}
-    size_t input_num = op->num_operands();
-    if (input_num > 0) {
-      for (size_t i = 0; i < input_num; ++i) {
-        auto ptr = op->operand(i);
-        if (ptr) {
-          PADDLE_ENFORCE_NE(
-              name_map->find(ptr),
-              name_map->end(),
-              phi::errors::PreconditionNotMet(
-                  "input should in name map, [%d] 'th input of [%s] op",
-                  i,
-                  op_name));
-        }
-      }
-    }
+    CheckInputVars(op, op_name, *name_map);
-    int out_num = op->num_results();
-    if (out_num > 0) {
-      for (int i = 0; i < out_num; ++i) {
-        ir::Value ptr = op->result(i);
-        std::string name;
-        if (name_map->find(ptr) != name_map->end()) {
-          name = name_map->at(ptr);
-        } else {
-          name = "inner_var_" + std::to_string(count++);
-          name_map->emplace(ptr, name);
-        }
-        auto var = CreateVar(ptr, name, scope, inner_local_scope);
-        // Only support DenseTensor or Vector<DenseTensor>
-        if (!ptr.type()) {
-          var->GetMutable<phi::DenseTensor>();
-        } else if (ptr.type()
-                       .isa<paddle::dialect::AllocatedDenseTensorType>()) {
-          var->GetMutable<phi::DenseTensor>();
-        } else if (ptr.type().isa<ir::VectorType>()) {
-          auto tensor_array =
-              var->GetMutable<paddle::framework::TensorRefArray>();
-          for (size_t i = 0; i < ptr.type().dyn_cast<ir::VectorType>().size();
-               i++) {
-            PADDLE_ENFORCE(
-                ptr.type()
-                    .dyn_cast<ir::VectorType>()[i]
-                    .isa<paddle::dialect::AllocatedDenseTensorType>(),
-                paddle::platform::errors::Fatal(
-                    "Element of VectorType output only support "
-                    "DenseTensorType"));
-            std::string name_i = "inner_var_" + std::to_string(count++);
-            auto var_i = CreateVar(ptr, name_i, scope, inner_local_scope);
-            tensor_array->emplace_back(var_i->GetMutable<phi::DenseTensor>());
-          }
-        } else {
-          PADDLE_THROW(phi::errors::PreconditionNotMet(
-              "Output only support DenseTensorType or VectorType"));
-        }
+    if (op->num_results() < 1) continue;
+    if (op->attributes().count("is_inplace") != 0 &&
+        op->attributes()
+            .at("is_inplace")
+            .dyn_cast<ir::BoolAttribute>()
+            .data()) {
+      HandleForInplaceOp(op, scope, inner_local_scope, name_map, count);
+      continue;
+    } else {
+      for (size_t i = 0; i < op->num_results(); ++i) {
+        BuildValue(op->result(i), scope, local_scope, name_map, count);
+      }
+    }
}
VLOG(4) << "***** [after build] scope: ******\n"
<< paddle::framework::GenScopeTreeDebugInfo(
const_cast<paddle::framework::Scope*>(scope->root()));
......
@@ -45,12 +45,27 @@ paddle::framework::Variable* CreateVar(ir::Value value,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope);
void BuildValue(ir::Value value,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map,
int& count); // NOLINT
void HandleForSpecialOp(ir::Operation* op,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map,
int& count); // NOLINT
void HandleForInplaceOp(ir::Operation* op,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map,
int& count); // NOLINT
void CheckInputVars(ir::Operation* op,
const std::string& op_name,
const std::unordered_map<ir::Value, std::string>& name_map);
void BuildScope(const ir::Block& block,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
@@ -80,13 +95,13 @@ void BuildPhiContext(
auto& vec_kernel_fn_tensor_params = op_yaml_info.TensorParams(is_kernel);
-  auto& name2id = op_yaml_info.Name2Id();
+  auto& name2id = op_yaml_info.InputName2Id();
for (auto& t : vec_kernel_fn_tensor_params) {
PADDLE_ENFORCE_EQ(
name2id.count(t),
true,
phi::errors::NotFound("param [%s] MUST be in name2id map", t));
-    auto index = op_yaml_info.Name2Id().at(t);
+    auto index = op_yaml_info.InputName2Id().at(t);
ir::Value ptr = op->operand(index);
if (!ptr) {
phi::DenseTensor* ptr = nullptr;
@@ -97,7 +112,7 @@ void BuildPhiContext(
auto in_var_name = name_map.at(ptr);
VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name;
-    PADDLE_ENFORCE_NOT_NULL(inner_scope->FindLocalVar(in_var_name),
+    PADDLE_ENFORCE_NOT_NULL(inner_scope->FindVar(in_var_name),
                             phi::errors::PreconditionNotMet(
                                 "cannot find var [%s] in scope", in_var_name));
auto var = inner_scope->FindVar(in_var_name);
......
file(GLOB PD_TRAIT_SRCS "*.cc")
cc_library(
pd_trait
SRCS ${PD_TRAIT_SRCS}
DEPS ir)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/op_base.h"
namespace paddle {
namespace dialect {
class InplaceTrait : public ir::OpTraitBase<InplaceTrait> {
public:
explicit InplaceTrait(ir::Operation *op)
: ir::OpTraitBase<InplaceTrait>(op) {}
};
} // namespace dialect
} // namespace paddle
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::InplaceTrait)
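The trait is a pure marker: it adds no methods beyond the OpTraitBase plumbing, so consumers simply test for its presence, as the kernel-lowering pass does below:

if (op->HasTrait<paddle::dialect::InplaceTrait>()) {
  // This op computes in place; an output may alias one of its inputs.
}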
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/ir/trait/inplace.h"
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InplaceTrait)
cc_library(
transform_general_functions
SRCS transform_general_functions.cc
-  DEPS ir phi pd_dialect)
+  DEPS phi pd_dialect ir)
cc_library(
pd_op_to_kernel_pass
SRCS pd_op_to_kernel_pass.cc
-  DEPS ir phi_utils pd_interface)
+  DEPS phi_utils pd_interface pd_trait ir)
@@ -26,11 +26,13 @@
#include "paddle/fluid/ir/dialect/utils.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir/interface/op_yaml_info_parser.h"
#include "paddle/fluid/ir/trait/inplace.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/kernel_factory.h"
namespace paddle {
namespace dialect {
@@ -63,7 +65,7 @@ phi::KernelKey GetKernelKey(
if (data_type_info.size() > 0 && data_type_info[0] != "") {
// only support single input and attribute
auto slot_name = data_type_info[0];
-    auto& input_map = op_info_parser->Name2Id();
+    auto& input_map = op_info_parser->InputName2Id();
if (input_map.count(slot_name)) {
// parse from input
@@ -340,6 +342,10 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
op_attribute.emplace(it1->first, it1->second);
}
if ((*it)->HasTrait<paddle::dialect::InplaceTrait>()) {
op_attribute.emplace("is_inplace", ir::BoolAttribute::get(ctx, true));
}
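// The kernel-dialect op is recreated below via ir::Operation::Create, so the
// InplaceTrait on the source op is not carried over automatically; the
// "is_inplace" attribute preserves that fact for the interpreter, which reads
// it back in BuildScope:
//   op->attributes().at("is_inplace").dyn_cast<ir::BoolAttribute>().data()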
ir::Operation* op = ir::Operation::Create(
vec_inputs, op_attribute, op_output_types, op_info);
......
@@ -6,6 +6,7 @@ cc_test_old(
pd_op_to_kernel_pass
pd_dialect
phi_kernel_adaptor
pd_trait
ir
phi
gtest)
@@ -39,6 +39,7 @@ PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(full_int_array, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(uniform, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sqrt, CPU, ALL_LAYOUT);
bool simple_cmp(float a, float b) { return std::abs((a - b) / a) < 1e-5; }
@@ -246,5 +247,44 @@ TEST(StandaloneExecutor, data_transfer) {
}
#endif
TEST(StandaloneExecutor, run_inplace_sqrt) {
ir::IrContext* ctx = ir::IrContext::Instance();
ir::Program program((ctx));
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Builder builder = ir::Builder(ctx, program.block());
paddle::dialect::FullOp full = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{2, 2}, 4.0, phi::DataType::FLOAT32, phi::CPUPlace());
builder.Build<paddle::dialect::Sqrt_Op>(full->result(0));
auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
kernel_program->Print(std::cout);
auto place = platform::CPUPlace();
Scope scope;
InterpreterCore test_core(place, std::move(kernel_program), &scope);
test_core.Run({});
auto out_tensor = test_core.local_scope() == nullptr
? scope.FindVar("inner_var_0")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar("inner_var_0")
->Get<phi::DenseTensor>();
bool res0 = simple_cmp(out_tensor.data<float>()[0], 2.0);
bool res1 = simple_cmp(out_tensor.data<float>()[1], 2.0);
bool res2 = simple_cmp(out_tensor.data<float>()[2], 2.0);
bool res3 = simple_cmp(out_tensor.data<float>()[3], 2.0);
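// sqrt_ wrote through full's output variable, so the local scope holds exactly
// one variable (inner_var_0) and every element is sqrt(4.0) = 2.0.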
EXPECT_EQ(scope.kids().size(), 1u);
EXPECT_EQ(scope.kids().front()->Size(), 1u);
EXPECT_EQ(res0, true);
EXPECT_EQ(res1, true);
EXPECT_EQ(res2, true);
EXPECT_EQ(res3, true);
}
} // namespace framework
} // namespace paddle