From 2ade1f924fd211d12a84b41ce43a66c56213a119 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 26 Jul 2023 13:06:02 +0800 Subject: [PATCH] [NewIR]Add ConvertIRType and fix some TODO for IR+CINN (#55691) * [NewIR]Add ConvertIRType and fix some TODO for IR+CINN * modify into GPUPlace --- paddle/cinn/hlir/framework/new_ir_compiler.h | 20 +++------ paddle/cinn/utils/attribute_util.h | 41 ++++++++++++++----- .../cpp/ir/cinn/graph_compiler_new_ir_test.cc | 4 +- 3 files changed, 37 insertions(+), 28 deletions(-) diff --git a/paddle/cinn/hlir/framework/new_ir_compiler.h b/paddle/cinn/hlir/framework/new_ir_compiler.h index fc4944d1ca2..2daaf8923ef 100644 --- a/paddle/cinn/hlir/framework/new_ir_compiler.h +++ b/paddle/cinn/hlir/framework/new_ir_compiler.h @@ -50,7 +50,7 @@ class NewIRCompiler final { const Target& target, const std::shared_ptr& scope) : program_(prog), - m_builder_("NewIR", target), // TODO(dev): need unique name + m_builder_("NewIR", target), target_(target), scope_(scope) {} std::unique_ptr Build() { @@ -103,20 +103,13 @@ class NewIRCompiler final { // TODO(Aurelius84): For now, use addr as name but it's not wise. std::string input_id = CompatibleInfo::kInputPrefix + std::to_string(std::hash<::ir::Value>()(in_value)); - // NOTE(Aurelius84): whether need to support other Type? auto type_info = in_value.type().dyn_cast(); auto in_shape = phi::vectorize(type_info.dims()); - ir::Tensor temp; auto dtype = type_info.dtype(); - // TODO(Aurelius84): support more type - if (dtype.isa<::ir::Float32Type>()) { - temp = lang::Placeholder(input_id, in_shape); - } else if (dtype.isa<::ir::Int32Type>()) { - temp = lang::Placeholder(input_id, in_shape); - } - + ir::Tensor temp = lang::CreatePlaceHolder( + in_shape, utils::ConvertIRType(dtype), input_id); inputs.push_back(temp); cinn_inputs.push_back(common::CINNValue(temp)); } @@ -133,8 +126,7 @@ class NewIRCompiler final { auto out_value = op.result(i); auto type_info = out_value.type().dyn_cast(); - // TODO(Aurelius84): need to support ::ir::Type -> common::Type - out_types.push_back(common::Float(32)); + out_types.push_back(utils::ConvertIRType(type_info.dtype())); auto out_shape = phi::vectorize(type_info.dims()); out_shapes.push_back(std::move(out_shape)); } @@ -294,12 +286,10 @@ std::shared_ptr BuildScope(const Target& target, shape.push_back(Shape::dim_t(type_info.dims()[i])); } tensor->Resize(Shape{shape}); - // TODO(Aurelius84): need convert this. - tensor->set_type(common::Float(32)); + tensor->set_type(utils::ConvertIRType(type_info.dtype())); }; for (auto it = program.block()->begin(); it != program.block()->end(); ++it) { - // visit OpOprands for (auto i = 0; i < (*it)->num_operands(); ++i) { auto in_value = (*it)->operand(i); create_var(CompatibleInfo::kInputPrefix, in_value); diff --git a/paddle/cinn/utils/attribute_util.h b/paddle/cinn/utils/attribute_util.h index c920bd2cfc0..b5caaf667b9 100644 --- a/paddle/cinn/utils/attribute_util.h +++ b/paddle/cinn/utils/attribute_util.h @@ -16,8 +16,10 @@ #include #include +#include "paddle/cinn/common/type.h" #include "paddle/cinn/utils/type_defs.h" #include "paddle/fluid/ir/dialect/pd_attribute.h" +#include "paddle/ir/core/builtin_type.h" #include "paddle/phi/common/data_type.h" namespace cinn { @@ -40,15 +42,12 @@ Attribute ConvertAttribute(const ::ir::Attribute& src_attr) { } else if (src_attr.isa<::ir::DoubleAttribute>()) { dst_attr = src_attr.dyn_cast<::ir::DoubleAttribute>().data(); } else if (src_attr.isa()) { - auto arr = src_attr.dyn_cast().data(); - std::vector val; - for (size_t i = 0; i < arr.size(); ++i) { - val.push_back(arr[i]); - } + auto& arr = src_attr.dyn_cast() + .data() + .GetData(); + std::vector val(arr.begin(), arr.end()); dst_attr = val; } else if (src_attr.isa()) { - // TODO(Aurelius84): Need add convert logic from phi::DataType into cinn - // String. auto dtype = src_attr.dyn_cast().data(); dst_attr = phi::DataTypeToString(dtype); } else { @@ -62,16 +61,36 @@ AttributeMap ConvertAttributes(const NewIR_AttributeMap& src_attrs) { AttributeMap dst_attrs; for (auto& item : src_attrs) { VLOG(4) << "deal with " << item.first; - if (!item.second.isa()) { - dst_attrs[item.first] = std::move(ConvertAttribute(item.second)); + if (item.second.isa()) { + auto is_cpu = + item.second.dyn_cast().data() == + phi::CPUPlace(); + dst_attrs["force_cpu"] = is_cpu; } else { - // TODO(Aurelius84): support place attribute for special Op - dst_attrs["force_cpu"] = false; + dst_attrs[item.first] = std::move(ConvertAttribute(item.second)); } } VLOG(4) << "dst_attrs.size(): " << dst_attrs.size(); return dst_attrs; } +#define CASE_TYPE(src, dst) \ + else if (type.isa<::ir::src>()) return common::dst(); + +common::Type ConvertIRType(::ir::Type type) { + if (type.isa<::ir::BFloat16Type>()) return common::BF16(); + CASE_TYPE(Float16Type, F16) + CASE_TYPE(Float32Type, F32) + CASE_TYPE(Float64Type, F64) + CASE_TYPE(Int8Type, I8) + CASE_TYPE(UInt8Type, UI8) + CASE_TYPE(Int16Type, I16) + CASE_TYPE(Int32Type, I32) + CASE_TYPE(Int64Type, I64) + CASE_TYPE(BoolType, UI1) + + LOG(FATAL) << "unknown ir::Type " << type; +} + } // namespace utils } // namespace cinn diff --git a/test/cpp/ir/cinn/graph_compiler_new_ir_test.cc b/test/cpp/ir/cinn/graph_compiler_new_ir_test.cc index 42ef6fe53d0..05ec98529ea 100644 --- a/test/cpp/ir/cinn/graph_compiler_new_ir_test.cc +++ b/test/cpp/ir/cinn/graph_compiler_new_ir_test.cc @@ -39,13 +39,13 @@ TEST(GraphCompier, TestNewIR) { builder.Build(std::vector{64, 128}, value, phi::DataType::FLOAT32, - phi::CPUPlace()); + phi::GPUPlace()); auto full_op_y = builder.Build(std::vector{128, 64}, value, phi::DataType::FLOAT32, - phi::CPUPlace()); + phi::GPUPlace()); // TODO(Aurelius84): test more op // auto add_z = builder.Build(full_op_x->result(0), // full_op_y->result(0)); -- GitLab