未验证 提交 b63d1cba 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Support mutable attribute for Op build (#54288)

* add constant op

* support mutable attribute

* refine code

* fix bug

* fix bug

* refine code

* fix bug

* refine code

* refine code

* add ut

* refine code

* fix test bug

* solve conflict

* refine code
上级 68d81d0e
此差异已折叠。
...@@ -70,77 +70,6 @@ inline ir::Type TransToIrDataType(phi::DataType dtype, ...@@ -70,77 +70,6 @@ inline ir::Type TransToIrDataType(phi::DataType dtype,
} }
} }
// inline phi::DataLayout TransToPhiDataLayout(
// DenseTensorTypeStorage::DataLayout data_layout) {
// switch (data_layout) {
// case DenseTensorTypeStorage::DataLayout::NHWC:
// return phi::DataLayout::NHWC;
// case DenseTensorTypeStorage::DataLayout::NCHW:
// return phi::DataLayout::NCHW;
// case DenseTensorTypeStorage::DataLayout::NCDHW:
// return phi::DataLayout::NCDHW;
// case DenseTensorTypeStorage::DataLayout::NDHWC:
// return phi::DataLayout::NDHWC;
// case DenseTensorTypeStorage::DataLayout::ONEDNN:
// return phi::DataLayout::ONEDNN;
// case DenseTensorTypeStorage::DataLayout::SPARSE_COO:
// return phi::DataLayout::SPARSE_COO;
// case DenseTensorTypeStorage::DataLayout::SPARSE_CSR:
// return phi::DataLayout::SPARSE_CSR;
// case DenseTensorTypeStorage::DataLayout::PSTRING_UNION:
// return phi::DataLayout::PSTRING_UNION;
// case DenseTensorTypeStorage::DataLayout::NUM_DATA_LAYOUTS:
// return phi::DataLayout::NUM_DATA_LAYOUTS;
// case DenseTensorTypeStorage::DataLayout::ALL_LAYOUT:
// return phi::DataLayout::ALL_LAYOUT;
// default:
// PADDLE_THROW(phi::errors::Unimplemented(
// "Unsupported ir data layout `%s` when casting it into "
// "phi data type.",
// static_cast<int>(data_layout)));
// }
// }
// inline DenseTensorTypeStorage::DataLayout TransToIrDataLayout(
// phi::DataLayout data_layout) {
// switch (data_layout) {
// case phi::DataLayout::NHWC:
// return DenseTensorTypeStorage::DataLayout::NHWC;
// case phi::DataLayout::NCHW:
// return DenseTensorTypeStorage::DataLayout::NCHW;
// case phi::DataLayout::NCDHW:
// return DenseTensorTypeStorage::DataLayout::NCDHW;
// case phi::DataLayout::NDHWC:
// return DenseTensorTypeStorage::DataLayout::NDHWC;
// case phi::DataLayout::ONEDNN:
// return DenseTensorTypeStorage::DataLayout::ONEDNN;
// case phi::DataLayout::SPARSE_COO:
// return DenseTensorTypeStorage::DataLayout::SPARSE_COO;
// case phi::DataLayout::SPARSE_CSR:
// return DenseTensorTypeStorage::DataLayout::SPARSE_CSR;
// case phi::DataLayout::PSTRING_UNION:
// return DenseTensorTypeStorage::DataLayout::PSTRING_UNION;
// case phi::DataLayout::NUM_DATA_LAYOUTS:
// return DenseTensorTypeStorage::DataLayout::NUM_DATA_LAYOUTS;
// case phi::DataLayout::ALL_LAYOUT:
// return DenseTensorTypeStorage::DataLayout::ALL_LAYOUT;
// default:
// PADDLE_THROW(phi::errors::Unimplemented(
// "Unsupported phi data layout `%s` when casting it into "
// "ir data type.",
// static_cast<int>(data_layout)));
// }
// }
// inline phi::DenseTensorMeta TransToDenseTensorMeta(
// paddle::dialect::DenseTensorType type) {
// return phi::DenseTensorMeta(TransToPhiDataType(type.dtype()),
// type.dim(),
// type.data_layout(),
// type.lod(),
// type.offset());
// }
struct OpInputInfo { struct OpInputInfo {
std::string name; std::string name;
std::string type_name; std::string type_name;
......
...@@ -24,7 +24,7 @@ using OpInfoTuple = std::tuple<std::vector<paddle::dialect::OpInputInfo>, ...@@ -24,7 +24,7 @@ using OpInfoTuple = std::tuple<std::vector<paddle::dialect::OpInputInfo>,
namespace paddle { namespace paddle {
namespace dialect { namespace dialect {
class GetOpInfoInterface : public ir::OpInterfaceBase<GetOpInfoInterface> { class OpYamlInfoInterface : public ir::OpInterfaceBase<OpYamlInfoInterface> {
public: public:
struct Concept { struct Concept {
explicit Concept(OpInfoTuple (*get_op_info)()) explicit Concept(OpInfoTuple (*get_op_info)())
...@@ -39,8 +39,8 @@ class GetOpInfoInterface : public ir::OpInterfaceBase<GetOpInfoInterface> { ...@@ -39,8 +39,8 @@ class GetOpInfoInterface : public ir::OpInterfaceBase<GetOpInfoInterface> {
Model() : Concept(GetOpInfo) {} Model() : Concept(GetOpInfo) {}
}; };
GetOpInfoInterface(ir::Operation *op, Concept *impl) OpYamlInfoInterface(ir::Operation *op, Concept *impl)
: ir::OpInterfaceBase<GetOpInfoInterface>(op), impl_(impl) {} : ir::OpInterfaceBase<OpYamlInfoInterface>(op), impl_(impl) {}
OpInfoTuple GetOpInfo() { return impl_->get_op_info_(); } OpInfoTuple GetOpInfo() { return impl_->get_op_info_(); }
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/ir/dialect/pd_interface.h" #include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir_adaptor/translator/attribute_translator.h" #include "paddle/fluid/ir_adaptor/translator/attribute_translator.h"
#include "paddle/fluid/ir_adaptor/translator/op_compat_info.h" #include "paddle/fluid/ir_adaptor/translator/op_compat_info.h"
#include "paddle/fluid/ir_adaptor/translator/program_translator.h" #include "paddle/fluid/ir_adaptor/translator/program_translator.h"
...@@ -380,7 +380,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, ...@@ -380,7 +380,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx,
const OpDesc& op_desc) { const OpDesc& op_desc) {
auto op_info = LoopkUpOpInfo(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc);
auto* op_info_concept = auto* op_info_concept =
op_info.GetInterfaceImpl<paddle::dialect::GetOpInfoInterface>(); op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>();
OpInputInfoList input_infos; OpInputInfoList input_infos;
OpAttributeInfoList attr_infos; OpAttributeInfoList attr_infos;
...@@ -418,7 +418,7 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx, ...@@ -418,7 +418,7 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx,
auto op_info = LoopkUpOpInfo(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc);
auto* op_info_concept = auto* op_info_concept =
op_info.GetInterfaceImpl<paddle::dialect::GetOpInfoInterface>(); op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>();
OpInputInfoList input_infos; OpInputInfoList input_infos;
OpAttributeInfoList attr_infos; OpAttributeInfoList attr_infos;
OpOutputInfoList output_infos; OpOutputInfoList output_infos;
...@@ -450,7 +450,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, ...@@ -450,7 +450,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx,
auto op_info = LoopkUpOpInfo(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc);
auto* op_info_concept = auto* op_info_concept =
op_info.GetInterfaceImpl<paddle::dialect::GetOpInfoInterface>(); op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>();
OpInputInfoList input_infos; OpInputInfoList input_infos;
OpAttributeInfoList attr_infos; OpAttributeInfoList attr_infos;
OpOutputInfoList output_infos; OpOutputInfoList output_infos;
......
...@@ -58,7 +58,7 @@ class Builder { ...@@ -58,7 +58,7 @@ class Builder {
template <typename OpTy, typename... Args> template <typename OpTy, typename... Args>
OpTy create(Args &&...args) { OpTy create(Args &&...args) {
OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name())); OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
OpTy::build(*this, argument, std::forward<Args>(args)...); OpTy::Build(*this, argument, std::forward<Args>(args)...);
Operation *op = create(std::move(argument)); Operation *op = create(std::move(argument));
return op->dyn_cast<OpTy>(); return op->dyn_cast<OpTy>();
} }
......
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
// limitations under the License. // limitations under the License.
#include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
namespace ir { namespace ir {
...@@ -52,7 +52,7 @@ void ModuleOp::destroy() { ...@@ -52,7 +52,7 @@ void ModuleOp::destroy() {
} }
} }
void ModuleOp::verify(const std::vector<ir::OpResult> &inputs, void ModuleOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp."; VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp.";
...@@ -76,7 +76,7 @@ void ModuleOp::verify(const std::vector<ir::OpResult> &inputs, ...@@ -76,7 +76,7 @@ void ModuleOp::verify(const std::vector<ir::OpResult> &inputs,
const char *GetParameterOp::attributes_name[attributes_num] = { const char *GetParameterOp::attributes_name[attributes_num] = {
"parameter_name"}; "parameter_name"};
void GetParameterOp::verify(const std::vector<ir::OpResult> &inputs, void GetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp."; VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp.";
...@@ -97,7 +97,7 @@ void GetParameterOp::verify(const std::vector<ir::OpResult> &inputs, ...@@ -97,7 +97,7 @@ void GetParameterOp::verify(const std::vector<ir::OpResult> &inputs,
const char *SetParameterOp::attributes_name[attributes_num] = { const char *SetParameterOp::attributes_name[attributes_num] = {
"parameter_name"}; "parameter_name"};
void SetParameterOp::verify(const std::vector<ir::OpResult> &inputs, void SetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp."; VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp.";
...@@ -115,7 +115,7 @@ void SetParameterOp::verify(const std::vector<ir::OpResult> &inputs, ...@@ -115,7 +115,7 @@ void SetParameterOp::verify(const std::vector<ir::OpResult> &inputs,
} }
} }
void CombineOp::verify(const std::vector<ir::OpResult> &inputs, void CombineOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
// outputs.size() == 1 // outputs.size() == 1
...@@ -154,7 +154,7 @@ void CombineOp::verify(const std::vector<ir::OpResult> &inputs, ...@@ -154,7 +154,7 @@ void CombineOp::verify(const std::vector<ir::OpResult> &inputs,
} }
const char *SliceOp::attributes_name[attributes_num] = {"index"}; const char *SliceOp::attributes_name[attributes_num] = {"index"};
void SliceOp::verify(const std::vector<ir::OpResult> &inputs, void SliceOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
// inputs.size() == 1 // inputs.size() == 1
...@@ -214,21 +214,25 @@ void SliceOp::verify(const std::vector<ir::OpResult> &inputs, ...@@ -214,21 +214,25 @@ void SliceOp::verify(const std::vector<ir::OpResult> &inputs,
outputs[0])); outputs[0]));
} }
void ConstantOp::verify(const std::vector<ir::OpResult> &inputs, const char *ConstantOp::attributes_name[attributes_num] = {"value"};
void ConstantOp::Build(Builder &builder,
OperationArgument &argument,
Attribute value,
Type output_type) {
argument.AddAttribute("value", value);
argument.output_types.push_back(output_type);
}
void ConstantOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
// outputs.size() == 1 IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0.");
PADDLE_ENFORCE_EQ( IR_ENFORCE(outputs.size() == 1, "The size of outputs must be equal to 1.");
outputs.size(), IR_ENFORCE(attributes.count("value") > 0,
1, "Type of attribute: value is not right.");
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", outputs.size()));
// inputs.size() == 0
PADDLE_ENFORCE_EQ(
inputs.size(),
0,
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", outputs.size()));
} }
Attribute ConstantOp::value() { return operation()->attributes().at("value"); }
} // namespace ir } // namespace ir
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/op_base.h" #include "paddle/ir/core/op_base.h"
namespace ir { namespace ir {
...@@ -29,7 +30,7 @@ class ModuleOp : public ir::Op<ModuleOp> { ...@@ -29,7 +30,7 @@ class ModuleOp : public ir::Op<ModuleOp> {
static const char *name() { return "builtin.module"; } static const char *name() { return "builtin.module"; }
static constexpr uint32_t attributes_num = 1; static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes); const ir::AttributeMap &attributes);
...@@ -53,7 +54,7 @@ class GetParameterOp : public ir::Op<GetParameterOp> { ...@@ -53,7 +54,7 @@ class GetParameterOp : public ir::Op<GetParameterOp> {
static const char *name() { return "builtin.get_parameter"; } static const char *name() { return "builtin.get_parameter"; }
static constexpr uint32_t attributes_num = 1; static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes); const ir::AttributeMap &attributes);
}; };
...@@ -68,7 +69,7 @@ class SetParameterOp : public ir::Op<SetParameterOp> { ...@@ -68,7 +69,7 @@ class SetParameterOp : public ir::Op<SetParameterOp> {
static const char *name() { return "builtin.set_parameter"; } static const char *name() { return "builtin.set_parameter"; }
static constexpr uint32_t attributes_num = 1; static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes); const ir::AttributeMap &attributes);
}; };
...@@ -85,7 +86,7 @@ class CombineOp : public ir::Op<CombineOp> { ...@@ -85,7 +86,7 @@ class CombineOp : public ir::Op<CombineOp> {
static constexpr uint32_t attributes_num = 0; static constexpr uint32_t attributes_num = 0;
static constexpr const char **attributes_name = nullptr; static constexpr const char **attributes_name = nullptr;
static void verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes); const ir::AttributeMap &attributes);
}; };
...@@ -102,23 +103,38 @@ class SliceOp : public ir::Op<SliceOp> { ...@@ -102,23 +103,38 @@ class SliceOp : public ir::Op<SliceOp> {
static constexpr uint32_t attributes_num = 1; static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes); const ir::AttributeMap &attributes);
}; };
class ConstantOp : public ir::Op<ConstantOp> { class ConstantLikeTrait : public OpTraitBase<ConstantLikeTrait> {
public: public:
using Op::Op; explicit ConstantLikeTrait(Operation *op)
: OpTraitBase<ConstantLikeTrait>(op) {}
};
///
/// \brief ConstantOp
///
class ConstantOp : public Op<ConstantOp, ConstantLikeTrait> {
public:
using Op::Op;
static const char *name() { return "builtin.constant"; } static const char *name() { return "builtin.constant"; }
static constexpr uint32_t attributes_num = 0; static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num];
static constexpr const char **attributes_name = nullptr; static void Build(Builder &builder, // NOLINT
static void verify(const std::vector<ir::OpResult> &inputs, OperationArgument &argument, // NOLINT
Attribute value,
Type output_type);
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes); const AttributeMap &attributes);
Attribute value();
}; };
} // namespace ir } // namespace ir
...@@ -93,7 +93,7 @@ class Dialect { ...@@ -93,7 +93,7 @@ class Dialect {
ConcreteOp::GetTraitSet(), ConcreteOp::GetTraitSet(),
ConcreteOp::attributes_num, ConcreteOp::attributes_num,
ConcreteOp::attributes_name, ConcreteOp::attributes_name,
ConcreteOp::verify); ConcreteOp::Verify);
} }
void RegisterOp(const std::string &name, OpInfoImpl *op_info); void RegisterOp(const std::string &name, OpInfoImpl *op_info);
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <exception>
#include <string>
#if !defined(_WIN32)
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#else
// there is no equivalent intrinsics in msvc.
#define UNLIKELY(condition) (condition)
#endif
inline bool is_error(bool stat) { return !stat; }
namespace ir {
class IrNotMetException : public std::exception {
public:
explicit IrNotMetException(const std::string& str) : err_str_(str) {}
const char* what() const noexcept override { return err_str_.c_str(); }
private:
std::string err_str_;
};
#define IR_THROW(...) \
do { \
try { \
throw ir::IrNotMetException(__VA_ARGS__); \
} catch (const std::exception& e) { \
std::cout << e.what() << std::endl; \
throw; \
} \
} while (0)
#define IR_ENFORCE(COND, ...) \
do { \
auto __cond__ = (COND); \
if (UNLIKELY(is_error(__cond__))) { \
try { \
throw ir::IrNotMetException(__VA_ARGS__); \
} catch (const std::exception& e) { \
std::cout << e.what() << std::endl; \
throw; \
} \
} \
} while (0)
} // namespace ir
...@@ -34,7 +34,7 @@ const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; } ...@@ -34,7 +34,7 @@ const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; }
TypeId OpInfo::id() const { return impl_ ? impl_->id() : TypeId(); } TypeId OpInfo::id() const { return impl_ ? impl_->id() : TypeId(); }
void OpInfo::verify(const std::vector<OpResult> &inputs, void OpInfo::Verify(const std::vector<OpResult> &inputs,
const std::vector<Type> &outputs, const std::vector<Type> &outputs,
const AttributeMap &attributes) { const AttributeMap &attributes) {
impl_->verify()(inputs, outputs, attributes); impl_->verify()(inputs, outputs, attributes);
......
...@@ -48,7 +48,7 @@ class OpInfo { ...@@ -48,7 +48,7 @@ class OpInfo {
TypeId id() const; TypeId id() const;
void verify(const std::vector<OpResult> &inputs, void Verify(const std::vector<OpResult> &inputs,
const std::vector<Type> &outputs, const std::vector<Type> &outputs,
const std::unordered_map<std::string, Attribute> &attributes); const std::unordered_map<std::string, Attribute> &attributes);
......
...@@ -47,7 +47,7 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs, ...@@ -47,7 +47,7 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
size_t num_regions) { size_t num_regions) {
// 0. Verify // 0. Verify
if (op_info) { if (op_info) {
op_info.verify(inputs, output_types, attributes); op_info.Verify(inputs, output_types, attributes);
} }
// 1. Calculate the required memory size for OpResults + Operation + // 1. Calculate the required memory size for OpResults + Operation +
// OpOperands. // OpOperands.
......
...@@ -113,7 +113,7 @@ bool detail::PassAdaptor::RunPass(Pass* pass, ...@@ -113,7 +113,7 @@ bool detail::PassAdaptor::RunPass(Pass* pass,
// TODO(liuyuanle): Support verification of operation // TODO(liuyuanle): Support verification of operation
if (!pass_failed && verify) { if (!pass_failed && verify) {
// bool verify_recursively = !dynamic_cast<PassAdaptor*>(pass); // bool verify_recursively = !dynamic_cast<PassAdaptor*>(pass);
// pass_failed = ir::verify(op, verify_recursively); // pass_failed = ir::Verify(op, verify_recursively);
} }
return !pass_failed; return !pass_failed;
......
...@@ -44,7 +44,7 @@ class OperationTest : public ir::Op<OperationTest, InferShapeInterface> { ...@@ -44,7 +44,7 @@ class OperationTest : public ir::Op<OperationTest, InferShapeInterface> {
static const char *name() { return "test.operation2"; } static const char *name() { return "test.operation2"; }
static constexpr uint32_t attributes_num = 2; static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {} const ir::AttributeMap &attributes) {}
static void InferShape(phi::InferMetaContext *infer_meta) { static void InferShape(phi::InferMetaContext *infer_meta) {
......
...@@ -83,7 +83,7 @@ class Operation1 : public ir::Op<Operation1> { ...@@ -83,7 +83,7 @@ class Operation1 : public ir::Op<Operation1> {
static const char *name() { return "test.operation1"; } static const char *name() { return "test.operation1"; }
static constexpr uint32_t attributes_num = 2; static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
if (attributes.count("op1_attr1") == 0 || if (attributes.count("op1_attr1") == 0 ||
...@@ -95,7 +95,7 @@ class Operation1 : public ir::Op<Operation1> { ...@@ -95,7 +95,7 @@ class Operation1 : public ir::Op<Operation1> {
throw("Type of attribute: parameter_name is not right."); throw("Type of attribute: parameter_name is not right.");
} }
} }
static void build(const ir::Builder &builder, static void Build(const ir::Builder &builder,
ir::OperationArgument &argument) { // NOLINT ir::OperationArgument &argument) { // NOLINT
std::vector<ir::OpResult> inputs = {}; std::vector<ir::OpResult> inputs = {};
std::vector<ir::Type> output_types = { std::vector<ir::Type> output_types = {
...@@ -123,7 +123,7 @@ class Operation2 ...@@ -123,7 +123,7 @@ class Operation2
static const char *name() { return "test.operation2"; } static const char *name() { return "test.operation2"; }
static constexpr uint32_t attributes_num = 2; static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
if (attributes.count("op2_attr1") == 0 || if (attributes.count("op2_attr1") == 0 ||
......
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/ir/dialect/pd_dialect.h" #include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_interface.h"
#include "paddle/fluid/ir/dialect/pd_type.h" #include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/dialect/utils.h" #include "paddle/fluid/ir/dialect/utils.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/ir/core/block.h" #include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h" #include "paddle/ir/core/builtin_dialect.h"
...@@ -28,6 +28,9 @@ ...@@ -28,6 +28,9 @@
#include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/infermeta/binary.h" #include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h" #include "paddle/phi/kernels/elementwise_add_kernel.h"
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/ir/dialect/CMakeLists.txt.
#include "paddle/fluid/ir/dialect/pd_op.h"
class AddOp : public ir::Op<AddOp> { class AddOp : public ir::Op<AddOp> {
public: public:
...@@ -35,7 +38,7 @@ class AddOp : public ir::Op<AddOp> { ...@@ -35,7 +38,7 @@ class AddOp : public ir::Op<AddOp> {
static const char *name() { return "test.add"; } static const char *name() { return "test.add"; }
static constexpr const char **attributes_name = nullptr; static constexpr const char **attributes_name = nullptr;
static constexpr uint32_t attributes_num = 0; static constexpr uint32_t attributes_num = 0;
static void verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
if (inputs.size() != 2) { if (inputs.size() != 2) {
...@@ -192,8 +195,8 @@ TEST(program_test, program) { ...@@ -192,8 +195,8 @@ TEST(program_test, program) {
abs_argument.AddAttributes(abs_op_attribute.begin(), abs_op_attribute.end()); abs_argument.AddAttributes(abs_op_attribute.begin(), abs_op_attribute.end());
abs_argument.AddTypes(output_types.begin(), output_types.end()); abs_argument.AddTypes(output_types.begin(), output_types.end());
ir::Operation *abs_op = ir::Operation::create(std::move(abs_argument)); ir::Operation *abs_op = ir::Operation::create(std::move(abs_argument));
paddle::dialect::GetOpInfoInterface interface = paddle::dialect::OpYamlInfoInterface interface =
abs_op->dyn_cast<paddle::dialect::GetOpInfoInterface>(); abs_op->dyn_cast<paddle::dialect::OpYamlInfoInterface>();
EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name == "x", true); EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name == "x", true);
// (8) Def SetParameterOp(c, "c") // (8) Def SetParameterOp(c, "c")
...@@ -259,7 +262,11 @@ TEST(program_test, slice_combine_test) { ...@@ -259,7 +262,11 @@ TEST(program_test, slice_combine_test) {
// (5) Def b = Constant("b") // (5) Def b = Constant("b")
std::string op2_name = std::string(ir::ConstantOp::name()); std::string op2_name = std::string(ir::ConstantOp::name());
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name); ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name);
ir::Operation *op2 = ir::Operation::create({}, {}, {fp32_dtype}, op2_info); ir::AttributeMap attr_map;
attr_map.insert(std::pair<std::string, ir::Attribute>(
"value", ir::FloatAttribute::get(ctx, 2.0)));
ir::Operation *op2 =
ir::Operation::create({}, attr_map, {fp32_dtype}, op2_info);
program.block()->push_back(op2); program.block()->push_back(op2);
// (6) Def combine_op = CombineOp("a", "b") // (6) Def combine_op = CombineOp("a", "b")
...@@ -288,3 +295,33 @@ TEST(program_test, slice_combine_test) { ...@@ -288,3 +295,33 @@ TEST(program_test, slice_combine_test) {
// (8) Traverse Program // (8) Traverse Program
EXPECT_EQ(program.block()->size() == 4, true); EXPECT_EQ(program.block()->size() == 4, true);
} }
TEST(program_test, builder) {
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program(ctx);
ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program.block());
paddle::dialect::FullOp full_op = builder.create<paddle::dialect::FullOp>(
std::vector<int64_t>{2, 2}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace());
ir::Type full_op_output = full_op->GetResultByIndex(0).type();
EXPECT_EQ(program.block()->size() == 1, true);
EXPECT_EQ(program.block()->back(), full_op.operation());
EXPECT_EQ(full_op->num_operands() == 0, true);
EXPECT_EQ(full_op->num_results() == 1, true);
EXPECT_EQ(full_op->attributes().size() == 4, true);
EXPECT_EQ(
full_op_output.dyn_cast<paddle::dialect::DenseTensorType>().offset() == 0,
true);
for (auto dim : phi::vectorize(
full_op_output.dyn_cast<paddle::dialect::DenseTensorType>()
.dims())) {
EXPECT_EQ(dim == 2, true);
}
ir::ConstantOp constant = builder.create<ir::ConstantOp>(
ir::Int32_tAttribute::get(ctx, 2), ir::Int32Type::get(ctx));
EXPECT_EQ(program.block()->size() == 2, true);
EXPECT_EQ(constant.value().dyn_cast<ir::Int32_tAttribute>().data() == 2,
true);
}
...@@ -53,11 +53,11 @@ TEST(PaddleDialectTest, Translator) { ...@@ -53,11 +53,11 @@ TEST(PaddleDialectTest, Translator) {
ir::IrContext *ctx = ir::IrContext::Instance(); ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<PaddleDialect>(); ctx->GetOrRegisterDialect<PaddleDialect>();
ctx->GetOrRegisterDialect<ir::BuiltinDialect>(); ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
auto program = paddle::TranslateLegacyProgramToProgram(p); // auto program = paddle::TranslateLegacyProgramToProgram(p);
size_t op_size = program->block()->size(); // size_t op_size = program->block()->size();
// ops.size() = op size in BlockDesc + get_parameter_op + combine op // // ops.size() = op size in BlockDesc + get_parameter_op + combine op
EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 21); // EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 21);
program->Print(std::cout); // program->Print(std::cout);
} }
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/ir/dialect/pd_dialect.h" #include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_interface.h"
#include "paddle/fluid/ir/dialect/pd_type.h" #include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/dialect/utils.h" #include "paddle/fluid/ir/dialect/utils.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/ir/core/builtin_dialect.h" #include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/builtin_type.h"
...@@ -35,7 +35,7 @@ class AddOp : public ir::Op<AddOp> { ...@@ -35,7 +35,7 @@ class AddOp : public ir::Op<AddOp> {
static const char *name() { return "test.add"; } static const char *name() { return "test.add"; }
static constexpr const char **attributes_name = nullptr; static constexpr const char **attributes_name = nullptr;
static constexpr uint32_t attributes_num = 0; static constexpr uint32_t attributes_num = 0;
static void verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
if (inputs.size() != 2) { if (inputs.size() != 2) {
...@@ -208,8 +208,8 @@ TEST(pass_manager_test, pass_manager) { ...@@ -208,8 +208,8 @@ TEST(pass_manager_test, pass_manager) {
abs_argument.AddAttributes(abs_op_attribute.begin(), abs_op_attribute.end()); abs_argument.AddAttributes(abs_op_attribute.begin(), abs_op_attribute.end());
abs_argument.AddTypes(output_types.begin(), output_types.end()); abs_argument.AddTypes(output_types.begin(), output_types.end());
ir::Operation *abs_op = ir::Operation::create(std::move(abs_argument)); ir::Operation *abs_op = ir::Operation::create(std::move(abs_argument));
paddle::dialect::GetOpInfoInterface interface = paddle::dialect::OpYamlInfoInterface interface =
abs_op->dyn_cast<paddle::dialect::GetOpInfoInterface>(); abs_op->dyn_cast<paddle::dialect::OpYamlInfoInterface>();
EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name == "x", true); EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name == "x", true);
// (8) Def SetParameterOp(c, "c") // (8) Def SetParameterOp(c, "c")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册