未验证 提交 3a452e4e 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Refine IR builder and throw methods (#54396)

* refine code

* refine code

* refine code

* refine code

* refine code

* refine code

* refine code

* fix bug

* refine code

* refine code

* refine code

* refine code

* refine code

* delete unused code

* delete unused code

* refine code
上级 b62b384b
此差异已折叠。
...@@ -26,5 +26,23 @@ phi::DataLayout DataLayoutAttribute::data() const { ...@@ -26,5 +26,23 @@ phi::DataLayout DataLayoutAttribute::data() const {
return storage()->GetAsKey(); return storage()->GetAsKey();
} }
phi::Scalar ScalarAttribute::data() {
if (isa<ir::FloatAttribute>()) {
return phi::Scalar(dyn_cast<ir::FloatAttribute>().data());
} else if (isa<ir::DoubleAttribute>()) {
return phi::Scalar(dyn_cast<ir::DoubleAttribute>().data());
} else if (isa<ir::Int32_tAttribute>()) {
return phi::Scalar(dyn_cast<ir::Int32_tAttribute>().data());
} else if (isa<ir::Int64_tAttribute>()) {
return phi::Scalar(dyn_cast<ir::Int64_tAttribute>().data());
} else if (isa<ir::BoolAttribute>()) {
return phi::Scalar(dyn_cast<ir::BoolAttribute>().data());
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported ir attribute when casting it into "
"phi scalar."));
}
}
} // namespace dialect } // namespace dialect
} // namespace paddle } // namespace paddle
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
#include "paddle/fluid/ir/dialect/pd_attribute_storage.h" #include "paddle/fluid/ir/dialect/pd_attribute_storage.h"
#include "paddle/ir/core/attribute.h" #include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_attribute.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/enforce.h"
namespace paddle { namespace paddle {
namespace dialect { namespace dialect {
...@@ -45,6 +47,8 @@ class ScalarAttribute : public ir::Attribute { ...@@ -45,6 +47,8 @@ class ScalarAttribute : public ir::Attribute {
(val.type_id() == ir::Int32_tAttribute::type_id()) || (val.type_id() == ir::Int32_tAttribute::type_id()) ||
(val.type_id() == ir::Int64_tAttribute::type_id()); (val.type_id() == ir::Int64_tAttribute::type_id());
} }
phi::Scalar data();
}; };
class DataTypeAttribute : public ir::Attribute { class DataTypeAttribute : public ir::Attribute {
......
...@@ -101,14 +101,17 @@ struct OpInputInfo { ...@@ -101,14 +101,17 @@ struct OpInputInfo {
std::string type_name; std::string type_name;
bool optional = false; bool optional = false;
bool no_need_buffer = false; bool no_need_buffer = false;
bool is_mutable_attribute = false;
OpInputInfo(std::string name, OpInputInfo(std::string name,
std::string type_name, std::string type_name,
bool optional, bool optional,
bool no_need_buffer) bool no_need_buffer,
bool is_mutable_attribute)
: name(name), : name(name),
type_name(type_name), type_name(type_name),
optional(optional), optional(optional),
no_need_buffer(no_need_buffer) {} no_need_buffer(no_need_buffer),
is_mutable_attribute(is_mutable_attribute) {}
}; };
struct OpOutputInfo { struct OpOutputInfo {
......
...@@ -56,7 +56,7 @@ class Builder { ...@@ -56,7 +56,7 @@ class Builder {
template <typename OpTy, typename... Args> template <typename OpTy, typename... Args>
OpTy Build(Args &&...args) { OpTy Build(Args &&...args) {
OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name())); OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
OpTy::Build(argument, std::forward<Args>(args)...); OpTy::Build(*this, argument, std::forward<Args>(args)...);
Operation *op = Build(std::move(argument)); Operation *op = Build(std::move(argument));
return op->dyn_cast<OpTy>(); return op->dyn_cast<OpTy>();
} }
......
...@@ -57,20 +57,15 @@ void ModuleOp::Verify(const std::vector<ir::OpResult> &inputs, ...@@ -57,20 +57,15 @@ void ModuleOp::Verify(const std::vector<ir::OpResult> &inputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp."; VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp.";
// Verify inputs type: // Verify inputs type:
if (inputs.size() != 0) { IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0.");
throw("The size of inputs must be equal to 0.");
}
// Verify if attributes contain attribute name in attributes_name: // Verify if attributes contain attribute name in attributes_name:
auto iter = attributes.find("program"); auto iter = attributes.find("program");
if (iter == attributes.end() || !iter->second.isa<PointerAttribute>()) { IR_ENFORCE(iter != attributes.end() && iter->second.isa<PointerAttribute>(),
throw("Type of attribute: program is not right."); "Type of attribute: program is not right.");
}
// Verify outputs type: // Verify outputs type:
if (outputs.size() != 0) { IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0.");
throw("The size of outputs must be equal to 0.");
}
} }
const char *GetParameterOp::attributes_name[attributes_num] = { const char *GetParameterOp::attributes_name[attributes_num] = {
...@@ -81,17 +76,15 @@ void GetParameterOp::Verify(const std::vector<ir::OpResult> &inputs, ...@@ -81,17 +76,15 @@ void GetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp."; VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp.";
// Verify inputs type: // Verify inputs type:
if (inputs.size() != 0) { IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0.");
throw("The size of inputs must be equal to 0.");
}
// Verify outputs type:
if (outputs.size() != 1) {
throw("The size of outputs must be equal to 1.");
}
// Verify if attributes contain attribute name in attributes_name: // Verify if attributes contain attribute name in attributes_name:
if (!attributes.at("parameter_name").isa<StrAttribute>()) { auto iter = attributes.find("parameter_name");
throw("Type of attribute: parameter_name is not right."); IR_ENFORCE(iter != attributes.end() && iter->second.isa<StrAttribute>(),
} "Type of attribute: parameter_name is not right.");
// Verify outputs type:
IR_ENFORCE(outputs.size() == 1, "The size of outputs must be equal to 1.");
} }
const char *SetParameterOp::attributes_name[attributes_num] = { const char *SetParameterOp::attributes_name[attributes_num] = {
...@@ -102,54 +95,45 @@ void SetParameterOp::Verify(const std::vector<ir::OpResult> &inputs, ...@@ -102,54 +95,45 @@ void SetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp."; VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp.";
// Verify inputs type: // Verify inputs type:
if (inputs.size() != 1) { IR_ENFORCE(inputs.size() == 1, "The size of outputs must be equal to 1.");
throw("The size of inputs must be equal to 1.");
}
// Verify outputs type:
if (outputs.size() != 0) {
throw("The size of outputs must be equal to 0.");
}
// Verify if attributes contain attribute name in attributes_name: // Verify if attributes contain attribute name in attributes_name:
if (!attributes.at("parameter_name").isa<StrAttribute>()) { auto iter = attributes.find("parameter_name");
throw("Type of attribute: parameter_name is not right."); IR_ENFORCE(iter != attributes.end() && iter->second.isa<StrAttribute>(),
} "Type of attribute: parameter_name is not right.");
// Verify outputs type:
IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0.");
} }
void CombineOp::Verify(const std::vector<ir::OpResult> &inputs, void CombineOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
// outputs.size() == 1 // outputs.size() == 1
PADDLE_ENFORCE_EQ( IR_ENFORCE(outputs.size() == 1,
outputs.size(), "The size %d of outputs must be equal to 1.",
1, outputs.size());
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", outputs.size()));
// outputs[0].type == Vector<Type> // outputs[0].type == Vector<Type>
PADDLE_ENFORCE(outputs[0].isa<ir::VectorType>(), IR_ENFORCE(outputs[0].isa<ir::VectorType>(),
phi::errors::PreconditionNotMet( "The type %s of outputs[0] must be equal to VectorType.",
"The type %s of outputs[0] must be equal to VectorType.", outputs[0]);
outputs[0]));
ir::VectorType output_type = outputs[0].dyn_cast<ir::VectorType>(); ir::VectorType output_type = outputs[0].dyn_cast<ir::VectorType>();
// inputs.size() == outputs[0].size() // inputs.size() == outputs[0].size()
PADDLE_ENFORCE_EQ( IR_ENFORCE(output_type.size() == inputs.size(),
output_type.size(), "The size %d of outputs[0] must be equal to size %d of inputs.",
inputs.size(), output_type.size(),
phi::errors::PreconditionNotMet( inputs.size());
"The size %d of outputs[0] must be equal to size %d of inputs.",
output_type.size(),
inputs.size()));
// forall i in inputs.size(): inputs[i].type == outputs[0][i].type // forall i in inputs.size(): inputs[i].type == outputs[0][i].type
for (size_t i = 0; i < inputs.size(); i++) { for (size_t i = 0; i < inputs.size(); i++) {
PADDLE_ENFORCE_EQ( IR_ENFORCE(output_type[i] == inputs[i].type(),
output_type[i], "The type %s of outputs[0][%d] must be "
inputs[i].type(), "equal to type %s of inputs[%d].",
phi::errors::PreconditionNotMet("The type %s of outputs[0][%d] must be " output_type[i],
"equal to type %s of inputs[%d].", i,
output_type[i], inputs[i].type(),
i, i);
inputs[i].type(),
i));
} }
} }
...@@ -158,65 +142,50 @@ void SliceOp::Verify(const std::vector<ir::OpResult> &inputs, ...@@ -158,65 +142,50 @@ void SliceOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) { const ir::AttributeMap &attributes) {
// inputs.size() == 1 // inputs.size() == 1
PADDLE_ENFORCE_EQ( IR_ENFORCE(inputs.size() == 1,
inputs.size(), "The size %d of inputs must be equal to 1.",
1, inputs.size());
phi::errors::PreconditionNotMet(
"The size %d of inputs must be equal to 1.", inputs.size()));
// inputs[0].type == Vector<Type> // inputs[0].type == Vector<Type>
PADDLE_ENFORCE(inputs[0].type().isa<ir::VectorType>(), IR_ENFORCE(inputs[0].type().isa<ir::VectorType>(),
phi::errors::PreconditionNotMet( "The type %s of inputs[0] must be equal to VectorType.",
"The type %s of inputs[0] must be equal to VectorType.", inputs[0].type());
inputs[0].type()));
ir::VectorType input_type = inputs[0].type().dyn_cast<ir::VectorType>(); ir::VectorType input_type = inputs[0].type().dyn_cast<ir::VectorType>();
// outputs.size() == 1 // outputs.size() == 1
PADDLE_ENFORCE_EQ( IR_ENFORCE(outputs.size() == 1,
outputs.size(), "The size %d of outputs must be equal to 1.",
1, outputs.size());
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", outputs.size()));
// attributes contains index: Int32 // attributes contains index: Int32
PADDLE_ENFORCE_NE( IR_ENFORCE(attributes.count("index") != 0,
attributes.count("index"), "The attributes must contains index.");
0,
phi::errors::PreconditionNotMet("The attributes must contains index."));
const ir::Attribute &attr = attributes.at("index"); const ir::Attribute &attr = attributes.at("index");
PADDLE_ENFORCE( IR_ENFORCE(attr.isa<ir::Int32_tAttribute>(),
attr.isa<ir::Int32_tAttribute>(), "The attribute index must be INT32.");
phi::errors::PreconditionNotMet("The attribute index must be INT32."));
auto index = attr.dyn_cast<ir::Int32_tAttribute>().data(); auto index = attr.dyn_cast<ir::Int32_tAttribute>().data();
// index >= 0 and < inputs[0].size() // index >= 0 and < inputs[0].size()
PADDLE_ENFORCE_GE( IR_ENFORCE(
index, index >= 0, "The index %d must be greater or equal than 0.", index);
0, IR_ENFORCE(static_cast<size_t>(index) < input_type.size(),
phi::errors::PreconditionNotMet( "The index %d must be less or equal than size %d of inputs[0].",
"The index %d must be greater or equal than 0.", index)); index,
PADDLE_ENFORCE_LT( input_type.size());
index,
input_type.size(),
phi::errors::PreconditionNotMet(
"The index %d must be less or equal than size %d of inputs[0].",
index,
input_type.size()));
// inputs[index].type == outputs[0].type // inputs[index].type == outputs[0].type
PADDLE_ENFORCE_EQ( IR_ENFORCE(
input_type[index] == outputs[0],
"The type %s of inputs[%d] must be equal to type %s of outputs[0].",
input_type[index], input_type[index],
outputs[0], index,
phi::errors::PreconditionNotMet( outputs[0]);
"The type %s of inputs[%d] must be equal to type %s of outputs[0].",
input_type[index],
index,
outputs[0]));
} }
const char *ConstantOp::attributes_name[attributes_num] = {"value"}; const char *ConstantOp::attributes_name[attributes_num] = {"value"};
void ConstantOp::Build(OperationArgument &argument, void ConstantOp::Build(Builder &builder,
OperationArgument &argument,
Attribute value, Attribute value,
Type output_type) { Type output_type) {
argument.AddAttribute("value", value); argument.AddAttribute("value", value);
......
...@@ -86,6 +86,7 @@ class CombineOp : public ir::Op<CombineOp> { ...@@ -86,6 +86,7 @@ class CombineOp : public ir::Op<CombineOp> {
static constexpr uint32_t attributes_num = 0; static constexpr uint32_t attributes_num = 0;
static constexpr const char **attributes_name = nullptr; static constexpr const char **attributes_name = nullptr;
static void Verify(const std::vector<ir::OpResult> &inputs, static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs, const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes); const ir::AttributeMap &attributes);
...@@ -125,7 +126,8 @@ class ConstantOp : public Op<ConstantOp, ConstantLikeTrait> { ...@@ -125,7 +126,8 @@ class ConstantOp : public Op<ConstantOp, ConstantLikeTrait> {
static constexpr uint32_t attributes_num = 1; static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void Build(OperationArgument &argument, // NOLINT static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
Attribute value, Attribute value,
Type output_type); Type output_type);
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "paddle/ir/core/block.h" #include "paddle/ir/core/block.h"
#include "paddle/ir/core/dialect.h" #include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/op_info.h" #include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation.h" #include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h" #include "paddle/ir/core/program.h"
...@@ -85,7 +86,7 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs, ...@@ -85,7 +86,7 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
base_ptr += sizeof(Operation); base_ptr += sizeof(Operation);
// 3.3. Construct OpOperands. // 3.3. Construct OpOperands.
if ((reinterpret_cast<uintptr_t>(base_ptr) & 0x7) != 0) { if ((reinterpret_cast<uintptr_t>(base_ptr) & 0x7) != 0) {
throw("The address of OpOperandImpl must be divisible by 8."); IR_THROW("The address of OpOperandImpl must be divisible by 8.");
} }
for (size_t idx = 0; idx < num_operands; idx++) { for (size_t idx = 0; idx < num_operands; idx++) {
new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op); new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op);
...@@ -147,7 +148,7 @@ void Operation::Destroy() { ...@@ -147,7 +148,7 @@ void Operation::Destroy() {
// 2.2. Deconstruct Operation. // 2.2. Deconstruct Operation.
if (reinterpret_cast<uintptr_t>(base_ptr) != if (reinterpret_cast<uintptr_t>(base_ptr) !=
reinterpret_cast<uintptr_t>(this)) { reinterpret_cast<uintptr_t>(this)) {
throw("Operation address error"); IR_THROW("Operation address error");
} }
reinterpret_cast<Operation *>(base_ptr)->~Operation(); reinterpret_cast<Operation *>(base_ptr)->~Operation();
base_ptr += sizeof(Operation); base_ptr += sizeof(Operation);
...@@ -178,7 +179,7 @@ Operation::Operation(const AttributeMap &attributes, ...@@ -178,7 +179,7 @@ Operation::Operation(const AttributeMap &attributes,
ir::OpResult Operation::GetResultByIndex(uint32_t index) const { ir::OpResult Operation::GetResultByIndex(uint32_t index) const {
if (index >= num_results_) { if (index >= num_results_) {
throw("index exceeds OP output range."); IR_THROW("index exceeds OP output range.");
} }
uint32_t max_inline_idx = detail::OpResultImpl::GetMaxInlineResultIndex(); uint32_t max_inline_idx = detail::OpResultImpl::GetMaxInlineResultIndex();
const char *ptr = const char *ptr =
...@@ -199,7 +200,7 @@ ir::OpResult Operation::GetResultByIndex(uint32_t index) const { ...@@ -199,7 +200,7 @@ ir::OpResult Operation::GetResultByIndex(uint32_t index) const {
ir::OpOperand Operation::GetOperandByIndex(uint32_t index) const { ir::OpOperand Operation::GetOperandByIndex(uint32_t index) const {
if (index >= num_operands_) { if (index >= num_operands_) {
throw("index exceeds OP input range."); IR_THROW("index exceeds OP input range.");
} }
const char *ptr = reinterpret_cast<const char *>(this) + sizeof(Operation) + const char *ptr = reinterpret_cast<const char *>(this) + sizeof(Operation) +
(index) * sizeof(detail::OpOperandImpl); (index) * sizeof(detail::OpOperandImpl);
......
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include "paddle/ir/core/enforce.h"
namespace ir { namespace ir {
// This is a structure for creating, caching, and looking up Storage of // This is a structure for creating, caching, and looking up Storage of
// parametric types. // parametric types.
...@@ -76,7 +78,7 @@ StorageManager::StorageBase *StorageManager::GetParametricStorageImpl( ...@@ -76,7 +78,7 @@ StorageManager::StorageBase *StorageManager::GetParametricStorageImpl(
<< std::hash<ir::TypeId>()(type_id) << ", param_hash=" << hash_value << std::hash<ir::TypeId>()(type_id) << ", param_hash=" << hash_value
<< "]."; << "].";
if (parametric_instance_.find(type_id) == parametric_instance_.end()) { if (parametric_instance_.find(type_id) == parametric_instance_.end()) {
throw("The input data pointer is null."); IR_THROW("The input data pointer is null.");
} }
ParametricStorageManager &parametric_storage = *parametric_instance_[type_id]; ParametricStorageManager &parametric_storage = *parametric_instance_[type_id];
return parametric_storage.GetOrCreate(hash_value, equal_func, constructor); return parametric_storage.GetOrCreate(hash_value, equal_func, constructor);
...@@ -88,7 +90,7 @@ StorageManager::StorageBase *StorageManager::GetParameterlessStorageImpl( ...@@ -88,7 +90,7 @@ StorageManager::StorageBase *StorageManager::GetParameterlessStorageImpl(
VLOG(4) << "Try to get a parameterless storage of: [TypeId_hash=" VLOG(4) << "Try to get a parameterless storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "]."; << std::hash<ir::TypeId>()(type_id) << "].";
if (parameterless_instance_.find(type_id) == parameterless_instance_.end()) if (parameterless_instance_.find(type_id) == parameterless_instance_.end())
throw("TypeId not found in IrContext."); IR_THROW("TypeId not found in IrContext.");
StorageBase *parameterless_instance = parameterless_instance_[type_id]; StorageBase *parameterless_instance = parameterless_instance_[type_id];
return parameterless_instance; return parameterless_instance;
} }
...@@ -107,7 +109,7 @@ void StorageManager::RegisterParameterlessStorageImpl( ...@@ -107,7 +109,7 @@ void StorageManager::RegisterParameterlessStorageImpl(
VLOG(4) << "Register a parameterless storage of: [TypeId_hash=" VLOG(4) << "Register a parameterless storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "]."; << std::hash<ir::TypeId>()(type_id) << "].";
if (parameterless_instance_.find(type_id) != parameterless_instance_.end()) if (parameterless_instance_.find(type_id) != parameterless_instance_.end())
throw("storage class already registered"); IR_THROW("storage class already registered");
parameterless_instance_.emplace(type_id, constructor()); parameterless_instance_.emplace(type_id, constructor());
} }
......
...@@ -427,6 +427,18 @@ ...@@ -427,6 +427,18 @@
data_type : dtype data_type : dtype
backend : place backend : place
- op : full_int_array
args : (IntArray value, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
output: Tensor(out)
infer_meta :
func : CreateIntArrayInferMeta
param : [value, dtype]
kernel :
func : full_int_array
param : [value, dtype]
data_type : dtype
backend : place
- op : full_like - op : full_like
args : (Tensor x, Scalar value, DataType dtype = DataType::UNDEFINED, Place place = {}) args : (Tensor x, Scalar value, DataType dtype = DataType::UNDEFINED, Place place = {})
output: Tensor(out) output: Tensor(out)
......
...@@ -62,9 +62,9 @@ ...@@ -62,9 +62,9 @@
beta2 : beta2 :
data_type : float data_type : float
tensor_name : Beta2Tensor tensor_name : Beta2Tensor
episilon : epsilon :
data_type : float data_type : float
tensor_name : EpisilonTensor tensor_name : EpsilonTensor
manual_signature : [adam_] manual_signature : [adam_]
- op : adamax_ - op : adamax_
...@@ -85,9 +85,9 @@ ...@@ -85,9 +85,9 @@
beta2 : beta2 :
data_type : float data_type : float
tensor_name : Beta2Tensor tensor_name : Beta2Tensor
episilon : epsilon :
data_type : float data_type : float
tensor_name : EpisilonTensor tensor_name : EpsilonTensor
- op : add (elementwise_add) - op : add (elementwise_add)
backward : add_grad (elementwise_add_grad) backward : add_grad (elementwise_add_grad)
...@@ -1970,7 +1970,7 @@ ...@@ -1970,7 +1970,7 @@
outputs: outputs:
out : Out out : Out
int_array: int_array:
axis : dims :
data_type : int data_type : int
extra : extra :
attrs : [bool use_mkldnn = false] attrs : [bool use_mkldnn = false]
......
...@@ -41,6 +41,15 @@ void CreateInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out) { ...@@ -41,6 +41,15 @@ void CreateInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out) {
CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out); CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out);
} }
void CreateIntArrayInferMeta(const IntArray& data,
DataType dtype,
MetaTensor* out) {
CreateInferMetaBase({static_cast<int64_t>(data.GetData().size())},
dtype,
DataLayout::NCHW,
out);
}
void CreateInferMetaBase(const std::vector<int64_t>& shape, void CreateInferMetaBase(const std::vector<int64_t>& shape,
DataType dtype, DataType dtype,
DataLayout layout, DataLayout layout,
......
...@@ -35,6 +35,10 @@ void AssignValueInferMeta(const std::vector<int>& shape, ...@@ -35,6 +35,10 @@ void AssignValueInferMeta(const std::vector<int>& shape,
DataType dtype, DataType dtype,
MetaTensor* out); MetaTensor* out);
void CreateIntArrayInferMeta(const IntArray& data,
DataType dtype,
MetaTensor* out);
void CreateInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out); void CreateInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out);
void CreateInferMetaBase(const std::vector<int64_t>& shape, void CreateInferMetaBase(const std::vector<int64_t>& shape,
......
...@@ -80,6 +80,18 @@ void FullLikeKernel(const Context& dev_ctx, ...@@ -80,6 +80,18 @@ void FullLikeKernel(const Context& dev_ctx,
FullValue<T>(dev_ctx, out, value); FullValue<T>(dev_ctx, out, value);
} }
template <typename T, typename Context>
void FullIntArrayKernel(const Context& dev_ctx,
const IntArray& val,
DataType dtype UNUSED,
DenseTensor* out) {
out->Resize(phi::make_ddim({static_cast<int64_t>(val.GetData().size())}));
T* out_data = dev_ctx.template Alloc<T>(out);
for (size_t i = 0; i < val.GetData().size(); ++i) {
out_data[i] = static_cast<T>(val.GetData()[i]);
}
}
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(full, PD_REGISTER_KERNEL(full,
...@@ -115,3 +127,6 @@ PD_REGISTER_KERNEL(full_like, ...@@ -115,3 +127,6 @@ PD_REGISTER_KERNEL(full_like,
phi::dtype::complex<double>) { phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
} }
PD_REGISTER_KERNEL(
full_int_array, CPU, ALL_LAYOUT, phi::FullIntArrayKernel, int, int64_t) {}
...@@ -83,4 +83,10 @@ DenseTensor FullLike(const Context& dev_ctx, ...@@ -83,4 +83,10 @@ DenseTensor FullLike(const Context& dev_ctx,
return dense_out; return dense_out;
} }
template <typename T, typename Context>
void FullIntArrayKernel(const Context& dev_ctx,
const IntArray& val,
DataType dtype,
DenseTensor* out);
} // namespace phi } // namespace phi
...@@ -44,76 +44,61 @@ ...@@ -44,76 +44,61 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "test/cpp/ir/core/phi_kernel_adaptor.h" #include "test/cpp/ir/core/phi_kernel_adaptor.h"
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(full_int_array, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(uniform, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(uniform, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
bool simple_cmp(float a, float b) { return std::abs((a - b) / a) < 1e-5; } bool simple_cmp(float a, float b) { return std::abs((a - b) / a) < 1e-5; }
TEST(program_test, program) { TEST(program_test, program) {
// Prepare ir env
ir::IrContext* ctx = ir::IrContext::Instance(); ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>(); ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program(ctx); ir::Program program(ctx);
ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program.block());
ir::Block* block = program.block(); ir::Block* block = program.block();
ir::Type fp32_dtype = ir::Float32Type::get(ctx);
paddle::dialect::DenseTensorTypeStorage::Dim dims = {2, 2};
paddle::dialect::DenseTensorTypeStorage::DataLayout data_layout =
paddle::dialect::DenseTensorTypeStorage::DataLayout::NCHW;
paddle::dialect::DenseTensorTypeStorage::LoD lod = {};
size_t offset = 0;
ir::Type dense_tensor_dtype = paddle::dialect::DenseTensorType::get(
ctx, fp32_dtype, dims, data_layout, lod, offset);
// (1) Def a = GetParameterOp("a")
std::string op1_name = std::string(paddle::dialect::UniformOp::name());
ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name);
// ir::Attribute shape_1 = ir::ArrayAttribute::get(ctx, {ten} );
ir::Attribute shape_1 = paddle::dialect::IntArrayAttribute::get(
ctx, std::vector<int64_t>({2, 2}));
ir::Attribute data_type =
paddle::dialect::DataTypeAttribute::get(ctx, phi::DataType::FLOAT32);
ir::Attribute min = ir::FloatAttribute::get(ctx, 0.0);
ir::Attribute max = ir::FloatAttribute::get(ctx, 1.0);
ir::Attribute seed = ir::Int32_tAttribute::get(ctx, 2);
ir::Attribute uni_place = paddle::dialect::PlaceAttribute::get(
ctx, phi::Place(phi::AllocationType::CPU));
std::unordered_map<std::string, ir::Attribute> op1_attribute{
{"shape", shape_1},
{"dtype", data_type},
{"min", min},
{"max", max},
{"seed", seed},
{"place", uni_place}};
ir::Operation* op1 =
ir::Operation::Create({}, op1_attribute, {dense_tensor_dtype}, op1_info);
block->push_back(op1);
// (2) Def b = GetParameterOp("b")
std::string op2_name = std::string(paddle::dialect::UniformOp::name());
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name);
ir::Attribute ten2 = ir::Int32_tAttribute::get(ctx, 3);
std::unordered_map<std::string, ir::Attribute> op2_attribute{{"shape", ten2}};
ir::Operation* op2 =
ir::Operation::Create({}, op1_attribute, {dense_tensor_dtype}, op2_info);
block->push_back(op2);
// (3) Def out = AddOp(a, b)
std::string add_op_name = std::string(paddle::dialect::AddOp::name());
ir::OpInfo add_op_info = ctx->GetRegisteredOpInfo(add_op_name);
ir::Operation* add_op = ir::Operation::Create(
{op1->GetResultByIndex(0), op2->GetResultByIndex(0)},
{},
{dense_tensor_dtype},
add_op_info);
block->push_back(add_op);
// Def: A = paddle::dialect::UniformOp(std::vector<int64_t> shape,
// phi::DataType dtype, float min, float max, int seed, phi::Place place)
paddle::dialect::UniformOp uniform1 =
builder.Build<paddle::dialect::UniformOp>(std::vector<int64_t>{2, 2},
phi::DataType::FLOAT32,
0.0,
1.0,
2,
phi::CPUPlace());
EXPECT_EQ(uniform1->GetResultByIndex(0)
.type()
.isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 4u);
// Def: B = paddle::dialect::UniformOp(...)
paddle::dialect::UniformOp uniform2 =
builder.Build<paddle::dialect::UniformOp>(std::vector<int64_t>{2, 2},
phi::DataType::FLOAT32,
0.0,
1.0,
2,
phi::CPUPlace());
EXPECT_EQ(uniform2->GetResultByIndex(0)
.type()
.isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 8u);
// Def: C = paddle::dialect::AddOp(ir::OpResult x_, ir::OpResult y_)
paddle::dialect::AddOp add = builder.Build<paddle::dialect::AddOp>(
uniform1->GetResultByIndex(0), uniform2->GetResultByIndex(0));
EXPECT_EQ(
add->GetResultByIndex(0).type().isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 9u);
// Execute program
paddle::framework::Scope scope; paddle::framework::Scope scope;
PhiKernelAdaptor phi_kernel_adaptor(&scope); PhiKernelAdaptor phi_kernel_adaptor(&scope);
phi_kernel_adaptor.run(&program); phi_kernel_adaptor.run(&program);
auto out_tensor = auto out_tensor =
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h" #include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_base.h" #include "paddle/ir/core/op_base.h"
#include "paddle/ir/core/program.h" #include "paddle/ir/core/program.h"
...@@ -240,10 +241,12 @@ TEST(op_test, module_op_death) { ...@@ -240,10 +241,12 @@ TEST(op_test, module_op_death) {
ir::AttributeMap attrs{{"program", ir::Int32_tAttribute::get(ctx, 1)}}; ir::AttributeMap attrs{{"program", ir::Int32_tAttribute::get(ctx, 1)}};
std::vector<ir::Type> output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> output_types = {ir::Float32Type::get(ctx)};
EXPECT_THROW(ir::Operation::Create(inputs, {}, {}, op_info), const char *); EXPECT_THROW(ir::Operation::Create(inputs, {}, {}, op_info),
EXPECT_THROW(ir::Operation::Create({}, attrs, {}, op_info), const char *); ir::IrNotMetException);
EXPECT_THROW(ir::Operation::Create({}, attrs, {}, op_info),
ir::IrNotMetException);
EXPECT_THROW(ir::Operation::Create({}, {}, output_types, op_info), EXPECT_THROW(ir::Operation::Create({}, {}, output_types, op_info),
const char *); ir::IrNotMetException);
ir::Program program(ctx); ir::Program program(ctx);
......
...@@ -98,27 +98,30 @@ void build_context(ir::Operation* op, ...@@ -98,27 +98,30 @@ void build_context(ir::Operation* op,
op->dyn_cast<paddle::dialect::OpYamlInfoInterface>(); op->dyn_cast<paddle::dialect::OpYamlInfoInterface>();
auto op_info_res = op_info_interface.GetOpInfo(); auto op_info_res = op_info_interface.GetOpInfo();
// inputs include input and mutable attributes
auto input_info = std::get<0>(op_info_res); auto input_info = std::get<0>(op_info_res);
std::map<std::string, size_t> input_index_map;
std::set<std::string> input_set; std::map<std::string, std::string> mutable_attr_type_map;
int input_index = 0;
for (auto& t : input_info) { for (auto& t : input_info) {
VLOG(6) << t.name << "\t" << t.type_name; VLOG(6) << t.name << "\t" << t.type_name;
input_index_map[t.name] = input_index++;
input_set.insert(t.name); if (t.is_mutable_attribute) {
mutable_attr_type_map[t.name] = t.type_name;
}
} }
auto attr_map = op->attributes();
std::map<std::string, std::string> attr_type_map;
auto attr_info = std::get<1>(op_info_res); auto attr_info = std::get<1>(op_info_res);
std::map<std::string, std::string> attr_type_map;
for (auto& t : attr_info) { for (auto& t : attr_info) {
VLOG(6) << t.name << "\t" << t.type_name; VLOG(6) << t.name << "\t" << t.type_name;
attr_type_map[t.name] = t.type_name; attr_type_map[t.name] = t.type_name;
} }
auto attr_map = op->attributes();
auto runtime_info = std::get<3>(op_info_res); auto runtime_info = std::get<3>(op_info_res);
int input_index = 0; // int input_index = 0;
std::vector<std::string> vec_param_list; std::vector<std::string> vec_param_list;
if (is_infer_meta) { if (is_infer_meta) {
vec_param_list = runtime_info.infer_meta_param; vec_param_list = runtime_info.infer_meta_param;
...@@ -126,13 +129,31 @@ void build_context(ir::Operation* op, ...@@ -126,13 +129,31 @@ void build_context(ir::Operation* op,
vec_param_list = runtime_info.kernel_param; vec_param_list = runtime_info.kernel_param;
} }
for (auto& t : vec_param_list) { for (auto& t : vec_param_list) {
if (input_set.count(t)) { if (input_index_map.count(t)) {
// get information from input // get information from input
ir::Value ptr = op->GetOperandByIndex(input_index++).source(); ir::Value ptr = op->GetOperandByIndex(input_index_map[t]).source();
auto in_var_name = name_map.at(ptr); auto in_var_name = name_map.at(ptr);
ctx->EmplaceBackInput( if (mutable_attr_type_map.count(t)) {
scope->Var(in_var_name)->GetMutable<phi::DenseTensor>()); VLOG(6) << "ctx->EmplaceBack mutable attr: " << t << "\t"
<< in_var_name;
if (mutable_attr_type_map[t] == "paddle::dialect::IntArrayAttribute") {
ctx->EmplaceBackAttr(phi::IntArray(
*(scope->Var(in_var_name)->GetMutable<phi::DenseTensor>())));
} else if (mutable_attr_type_map[t] ==
"paddle::dialect::ScalarAttribute") {
ctx->EmplaceBackAttr(phi::Scalar(
*(scope->Var(in_var_name)->GetMutable<phi::DenseTensor>())));
} else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
mutable_attr_type_map[t]));
}
} else {
VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name;
ctx->EmplaceBackInput(
scope->Var(in_var_name)->GetMutable<phi::DenseTensor>());
}
} }
if (attr_type_map.count(t)) { if (attr_type_map.count(t)) {
...@@ -149,10 +170,14 @@ void build_context(ir::Operation* op, ...@@ -149,10 +170,14 @@ void build_context(ir::Operation* op,
} else if (type_name == "paddle::dialect::PlaceAttribute") { } else if (type_name == "paddle::dialect::PlaceAttribute") {
ctx->EmplaceBackAttr( ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::PlaceAttribute>().data()); attr_map[t].dyn_cast<paddle::dialect::PlaceAttribute>().data());
} else if (type_name == "paddle::dialect::ScalarAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::ScalarAttribute>().data());
} else { } else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ", PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
type_name)); type_name));
} }
VLOG(6) << "ctx->EmplaceBackAttr: " << t;
} }
} }
...@@ -197,6 +222,9 @@ class PhiKernelAdaptor { ...@@ -197,6 +222,9 @@ class PhiKernelAdaptor {
phi::KernelKey kernel_key(phi::TransToPhiBackend(cpu_place), phi::KernelKey kernel_key(phi::TransToPhiBackend(cpu_place),
phi::DataLayout::ANY, phi::DataLayout::ANY,
phi::DataType::FLOAT32); phi::DataType::FLOAT32);
if (runtime_info.kernel_func[0] == "full_int_array") {
kernel_key.set_dtype(phi::DataType::INT64);
}
auto found_it = phi_kernels.find(kernel_key); auto found_it = phi_kernels.find(kernel_key);
if (found_it == phi_kernels.end()) { if (found_it == phi_kernels.end()) {
std::cerr << "kernel name " << runtime_info.kernel_func[0] << std::endl; std::cerr << "kernel name " << runtime_info.kernel_func[0] << std::endl;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册