未验证 提交 e9fac90e 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Refine some ir logic (#57072)

* fix bug

* refine code

* refine code

* refine code

* refine code

* refine code

* refine code

* fix bug
上级 9358b4bc
...@@ -46,7 +46,6 @@ ...@@ -46,7 +46,6 @@
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_dialect.h" #include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_dialect.h"
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_op.h" #include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_op.h"
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.h" #include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.h"
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/legacy_kernel_op.h"
#include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h" #include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h"
#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_attribute.h"
......
...@@ -330,14 +330,16 @@ def GenBuildOutputs( ...@@ -330,14 +330,16 @@ def GenBuildOutputs(
.dyn_cast<paddle::dialect::IntArrayAttribute>() .dyn_cast<paddle::dialect::IntArrayAttribute>()
.data() .data()
.GetData())); .GetData()));
}} }} else if ({name}_.type().isa<ir::VectorType>()) {{
else {{
PADDLE_ENFORCE(
{name}_.type().isa<ir::VectorType>(),
phi::errors::PreconditionNotMet("section Type should be VectorType."));
size_t {name}_size = {name}_.type().dyn_cast<ir::VectorType>().size(); size_t {name}_size = {name}_.type().dyn_cast<ir::VectorType>().size();
{name} = std::move(phi::IntArray(std::vector<int64_t>({name}_size, -1))); {name} = std::move(phi::IntArray(std::vector<int64_t>({name}_size, -1)));
{name}.SetFromTensor(true); {name}.SetFromTensor(true);
}} else if ({name}_.type().isa<paddle::dialect::DenseTensorType>()) {{
size_t {name}_size = phi::product({name}_.type().dyn_cast<paddle::dialect::DenseTensorType>().dims());
{name} = std::move(phi::IntArray(std::vector<int64_t>({name}_size, -1)));
{name}.SetFromTensor(true);
}} else {{
PADDLE_THROW(phi::errors::Unimplemented("Only support VectorType or DenseTensorType"));
}}\n""" }}\n"""
CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::Scalar {name}; CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::Scalar {name};
......
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.h" #include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.h"
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_op.h" #include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_op.h"
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.h" #include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.h"
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/legacy_kernel_op.h"
#include "paddle/fluid/platform/init_phi.h" #include "paddle/fluid/platform/init_phi.h"
#include "paddle/ir/core/ir_printer.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
...@@ -75,6 +75,45 @@ void PaddleKernelDialect::PrintAttribute(ir::Attribute attr, ...@@ -75,6 +75,45 @@ void PaddleKernelDialect::PrintAttribute(ir::Attribute attr,
<< "|dtype:" << kernel.dtype() << ">"; << "|dtype:" << kernel.dtype() << ">";
} }
void PaddleKernelDialect::PrintOperation(ir::Operation *op,
ir::IrPrinter &printer) const {
if (op->dyn_cast<PhiKernelOp>() || op->dyn_cast<LegacyKernelOp>()) {
auto &os = printer.os;
printer.PrintOpResult(op);
os << " =";
if (auto phi_kernel_op = op->dyn_cast<PhiKernelOp>()) {
std::string kernel_name = phi_kernel_op.kernel_name();
if (op->attributes().count("is_inplace") != 0 &&
op->attributes()
.at("is_inplace")
.dyn_cast<ir::BoolAttribute>()
.data()) {
kernel_name = kernel_name + "_";
}
os << " \"" << kernel_name << "(phi_kernel)\"";
} else {
auto legacy_kernel_op = op->dyn_cast<LegacyKernelOp>();
std::string kernel_name = legacy_kernel_op.kernel_name();
if (op->attributes().count("is_inplace") != 0 &&
op->attributes()
.at("is_inplace")
.dyn_cast<ir::BoolAttribute>()
.data()) {
kernel_name = kernel_name + "_";
}
os << " \"" << kernel_name << "(legacy_kernel)\"";
}
printer.PrintOpOperands(op);
printer.PrintAttributeMap(op);
os << " :";
printer.PrintOperandsType(op);
os << " -> ";
printer.PrintOpReturnType(op);
} else {
printer.PrintGeneralOperation(op);
}
}
} // namespace dialect } // namespace dialect
} // namespace paddle } // namespace paddle
......
...@@ -29,6 +29,9 @@ class PaddleKernelDialect : public ir::Dialect { ...@@ -29,6 +29,9 @@ class PaddleKernelDialect : public ir::Dialect {
void PrintAttribute(ir::Attribute attr, std::ostream& os) const override; void PrintAttribute(ir::Attribute attr, std::ostream& os) const override;
void PrintOperation(ir::Operation* op,
ir::IrPrinter& printer) const override; // NOLINT
private: private:
void initialize(); void initialize();
}; };
......
...@@ -56,7 +56,44 @@ phi::KernelKey PhiKernelOp::kernel_key() { ...@@ -56,7 +56,44 @@ phi::KernelKey PhiKernelOp::kernel_key() {
return attributes().at("kernel_key").dyn_cast<KernelAttribute>().data(); return attributes().at("kernel_key").dyn_cast<KernelAttribute>().data();
} }
const char* LegacyKernelOp::attributes_name[attributes_num] = { // NOLINT
"op_name",
"kernel_name",
"kernel_key"};
void LegacyKernelOp::Verify() {
VLOG(4) << "Verifying inputs, outputs and attributes for: LegacyKernelOp.";
auto& attributes = this->attributes();
PADDLE_ENFORCE(attributes.count("op_name") > 0 &&
attributes.at("op_name").isa<ir::StrAttribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: op_name is not right."));
PADDLE_ENFORCE(attributes.count("kernel_name") > 0 &&
attributes.at("kernel_name").isa<ir::StrAttribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: kernel_name is not right."));
PADDLE_ENFORCE(attributes.count("kernel_key") > 0 &&
attributes.at("kernel_key").isa<KernelAttribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: kernel_key is not right."));
}
std::string LegacyKernelOp::op_name() {
return attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString();
}
std::string LegacyKernelOp::kernel_name() {
return attributes().at("kernel_name").dyn_cast<ir::StrAttribute>().AsString();
}
phi::KernelKey LegacyKernelOp::kernel_key() {
return attributes().at("kernel_key").dyn_cast<KernelAttribute>().data();
}
} // namespace dialect } // namespace dialect
} // namespace paddle } // namespace paddle
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::PhiKernelOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::PhiKernelOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::LegacyKernelOp)
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
namespace paddle { namespace paddle {
namespace dialect { namespace dialect {
class PhiKernelOp : public ir::Op<PhiKernelOp> { class PhiKernelOp : public ir::Op<PhiKernelOp> {
public: public:
using Op::Op; using Op::Op;
...@@ -33,7 +32,20 @@ class PhiKernelOp : public ir::Op<PhiKernelOp> { ...@@ -33,7 +32,20 @@ class PhiKernelOp : public ir::Op<PhiKernelOp> {
void Verify(); void Verify();
}; };
class LegacyKernelOp : public ir::Op<LegacyKernelOp> {
public:
using Op::Op;
static const char *name() { return "pd_kernel.legacy_kernel"; }
static constexpr uint32_t attributes_num = 3;
static const char *attributes_name[attributes_num];
std::string op_name();
std::string kernel_name();
phi::KernelKey kernel_key();
void Verify();
};
} // namespace dialect } // namespace dialect
} // namespace paddle } // namespace paddle
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::PhiKernelOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::PhiKernelOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::LegacyKernelOp)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/legacy_kernel_op.h"
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace dialect {
const char* LegacyKernelOp::attributes_name[attributes_num] = { // NOLINT
"op_name",
"kernel_name",
"kernel_key"};
void LegacyKernelOp::Verify() {
VLOG(4) << "Verifying inputs, outputs and attributes for: LegacyKernelOp.";
auto& attributes = this->attributes();
PADDLE_ENFORCE(attributes.count("op_name") > 0 &&
attributes.at("op_name").isa<ir::StrAttribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: op_name is not right."));
PADDLE_ENFORCE(attributes.count("kernel_name") > 0 &&
attributes.at("kernel_name").isa<ir::StrAttribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: kernel_name is not right."));
PADDLE_ENFORCE(attributes.count("kernel_key") > 0 &&
attributes.at("kernel_key").isa<KernelAttribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: kernel_key is not right."));
}
std::string LegacyKernelOp::op_name() {
return attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString();
}
std::string LegacyKernelOp::kernel_name() {
return attributes().at("kernel_name").dyn_cast<ir::StrAttribute>().AsString();
}
phi::KernelKey LegacyKernelOp::kernel_key() {
return attributes().at("kernel_key").dyn_cast<KernelAttribute>().data();
}
} // namespace dialect
} // namespace paddle
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::LegacyKernelOp)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/phi/core/kernel_factory.h"
namespace paddle {
namespace dialect {
class LegacyKernelOp : public ir::Op<LegacyKernelOp> {
public:
using Op::Op;
static const char *name() { return "pd_kernel.legacy_kernel"; }
static constexpr uint32_t attributes_num = 3;
static const char *attributes_name[attributes_num];
std::string op_name();
std::string kernel_name();
phi::KernelKey kernel_key();
void Verify();
};
} // namespace dialect
} // namespace paddle
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::LegacyKernelOp)
...@@ -27,7 +27,6 @@ ...@@ -27,7 +27,6 @@
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_dialect.h" #include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_dialect.h"
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_op.h" #include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_op.h"
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.h" #include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.h"
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/legacy_kernel_op.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/api/lib/kernel_dispatch.h"
......
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h" #include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h"
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_dialect.h" #include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_dialect.h"
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_op.h" #include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_op.h"
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/legacy_kernel_op.h"
#include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h" #include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h"
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h" #include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/fluid/platform/init.h" #include "paddle/fluid/platform/init.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册