未验证 提交 2ade1f92 编写于 作者: A Aurelius84 提交者: GitHub

[NewIR]Add ConvertIRType and fix some TODO for IR+CINN (#55691)

* [NewIR]Add ConvertIRType and fix some TODO for IR+CINN

* modify into GPUPlace
上级 9f3b5f15
......@@ -50,7 +50,7 @@ class NewIRCompiler final {
const Target& target,
const std::shared_ptr<Scope>& scope)
: program_(prog),
m_builder_("NewIR", target), // TODO(dev): need unique name
m_builder_("NewIR", target),
target_(target),
scope_(scope) {}
std::unique_ptr<Program> Build() {
......@@ -103,20 +103,13 @@ class NewIRCompiler final {
// TODO(Aurelius84): For now, use addr as name but it's not wise.
std::string input_id = CompatibleInfo::kInputPrefix +
std::to_string(std::hash<::ir::Value>()(in_value));
// NOTE(Aurelius84): whether need to support other Type?
auto type_info =
in_value.type().dyn_cast<paddle::dialect::DenseTensorType>();
auto in_shape = phi::vectorize<int>(type_info.dims());
ir::Tensor temp;
auto dtype = type_info.dtype();
// TODO(Aurelius84): support more type
if (dtype.isa<::ir::Float32Type>()) {
temp = lang::Placeholder<float>(input_id, in_shape);
} else if (dtype.isa<::ir::Int32Type>()) {
temp = lang::Placeholder<int>(input_id, in_shape);
}
ir::Tensor temp = lang::CreatePlaceHolder(
in_shape, utils::ConvertIRType(dtype), input_id);
inputs.push_back(temp);
cinn_inputs.push_back(common::CINNValue(temp));
}
......@@ -133,8 +126,7 @@ class NewIRCompiler final {
auto out_value = op.result(i);
auto type_info =
out_value.type().dyn_cast<paddle::dialect::DenseTensorType>();
// TODO(Aurelius84): need to support ::ir::Type -> common::Type
out_types.push_back(common::Float(32));
out_types.push_back(utils::ConvertIRType(type_info.dtype()));
auto out_shape = phi::vectorize<int>(type_info.dims());
out_shapes.push_back(std::move(out_shape));
}
......@@ -294,12 +286,10 @@ std::shared_ptr<Scope> BuildScope(const Target& target,
shape.push_back(Shape::dim_t(type_info.dims()[i]));
}
tensor->Resize(Shape{shape});
// TODO(Aurelius84): need convert this.
tensor->set_type(common::Float(32));
tensor->set_type(utils::ConvertIRType(type_info.dtype()));
};
for (auto it = program.block()->begin(); it != program.block()->end(); ++it) {
// visit OpOprands
for (auto i = 0; i < (*it)->num_operands(); ++i) {
auto in_value = (*it)->operand(i);
create_var(CompatibleInfo::kInputPrefix, in_value);
......
......@@ -16,8 +16,10 @@
#include <string>
#include <unordered_map>
#include "paddle/cinn/common/type.h"
#include "paddle/cinn/utils/type_defs.h"
#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/phi/common/data_type.h"
namespace cinn {
......@@ -40,15 +42,12 @@ Attribute ConvertAttribute(const ::ir::Attribute& src_attr) {
} else if (src_attr.isa<::ir::DoubleAttribute>()) {
dst_attr = src_attr.dyn_cast<::ir::DoubleAttribute>().data();
} else if (src_attr.isa<paddle::dialect::IntArrayAttribute>()) {
auto arr = src_attr.dyn_cast<paddle::dialect::IntArrayAttribute>().data();
std::vector<int> val;
for (size_t i = 0; i < arr.size(); ++i) {
val.push_back(arr[i]);
}
auto& arr = src_attr.dyn_cast<paddle::dialect::IntArrayAttribute>()
.data()
.GetData();
std::vector<int> val(arr.begin(), arr.end());
dst_attr = val;
} else if (src_attr.isa<paddle::dialect::DataTypeAttribute>()) {
// TODO(Aurelius84): Need add convert logic from phi::DataType into cinn
// String.
auto dtype = src_attr.dyn_cast<paddle::dialect::DataTypeAttribute>().data();
dst_attr = phi::DataTypeToString(dtype);
} else {
......@@ -62,16 +61,36 @@ AttributeMap ConvertAttributes(const NewIR_AttributeMap& src_attrs) {
AttributeMap dst_attrs;
for (auto& item : src_attrs) {
VLOG(4) << "deal with " << item.first;
if (!item.second.isa<paddle::dialect::PlaceAttribute>()) {
dst_attrs[item.first] = std::move(ConvertAttribute(item.second));
if (item.second.isa<paddle::dialect::PlaceAttribute>()) {
auto is_cpu =
item.second.dyn_cast<paddle::dialect::PlaceAttribute>().data() ==
phi::CPUPlace();
dst_attrs["force_cpu"] = is_cpu;
} else {
// TODO(Aurelius84): support place attribute for special Op
dst_attrs["force_cpu"] = false;
dst_attrs[item.first] = std::move(ConvertAttribute(item.second));
}
}
VLOG(4) << "dst_attrs.size(): " << dst_attrs.size();
return dst_attrs;
}
// Helper for the else-if chain below: maps a builtin ::ir scalar type to the
// corresponding CINN common::Type constructor. #undef'd right after use so the
// short generic name does not leak into the rest of the translation unit.
#define CASE_TYPE(src, dst) \
  else if (type.isa<::ir::src>()) return common::dst();

// Converts a NewIR builtin scalar type into its CINN common::Type equivalent.
// Covers the floating-point, integer, and bool builtin types; any other
// ::ir::Type has no CINN counterpart and aborts via LOG(FATAL), so callers
// never observe an invalid return value.
common::Type ConvertIRType(::ir::Type type) {
  if (type.isa<::ir::BFloat16Type>()) return common::BF16();
  CASE_TYPE(Float16Type, F16)
  CASE_TYPE(Float32Type, F32)
  CASE_TYPE(Float64Type, F64)
  CASE_TYPE(Int8Type, I8)
  CASE_TYPE(UInt8Type, UI8)
  CASE_TYPE(Int16Type, I16)
  CASE_TYPE(Int32Type, I32)
  CASE_TYPE(Int64Type, I64)
  // NOTE(review): bool is represented as a 1-bit unsigned int in CINN —
  // presumably matching CINN's own bool encoding; confirm against common::Bool
  // if one exists.
  CASE_TYPE(BoolType, UI1)
  // LOG(FATAL) terminates the process, so no return is needed after it.
  LOG(FATAL) << "unknown ir::Type " << type;
}
#undef CASE_TYPE
} // namespace utils
} // namespace cinn
......@@ -39,13 +39,13 @@ TEST(GraphCompier, TestNewIR) {
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64, 128},
value,
phi::DataType::FLOAT32,
phi::CPUPlace());
phi::GPUPlace());
auto full_op_y =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{128, 64},
value,
phi::DataType::FLOAT32,
phi::CPUPlace());
phi::GPUPlace());
// TODO(Aurelius84): test more op
// auto add_z = builder.Build<paddle::dialect::MatmulOp>(full_op_x->result(0),
// full_op_y->result(0));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册