未验证 提交 058e1d85 编写于 作者: 石晓伟 提交者: GitHub

infrt runtime supports phi, test=develop (#39836)

* runtime supports pten kernels, test=develop

* fixes a bug, test=develop
上级 ca11a0e5
...@@ -23,6 +23,8 @@ class ContextTypeOf<string place, list<Trait> traits=[]>: ...@@ -23,6 +23,8 @@ class ContextTypeOf<string place, list<Trait> traits=[]>:
let summary = !strconcat("!phi.context_", place, " type"); let summary = !strconcat("!phi.context_", place, " type");
} }
def PhiOpTrait : NativeOpTrait<"PhiOpTrait">;
def CPU_Allocator : AllocatorTypeOf<"CPU">; def CPU_Allocator : AllocatorTypeOf<"CPU">;
def GPU_Allocator : AllocatorTypeOf<"GPU">; def GPU_Allocator : AllocatorTypeOf<"GPU">;
......
#ifndef PHI_KERNEL #ifndef PHI_KERNEL
#define PHI_KERNEL #define PHI_KERNEL
include "paddle/infrt/dialect/phi/infrt_phi_tensor.td" include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/infrt_base.td"
include "paddle/infrt/dialect/phi/infrt_phi_base.td"
def PHI_KernelDialect : Dialect { def PHI_KernelDialect : Dialect {
let name = "phi_kernel"; let name = "phi_kernel";
...@@ -14,12 +17,7 @@ def PHI_KernelDialect : Dialect { ...@@ -14,12 +17,7 @@ def PHI_KernelDialect : Dialect {
} }
// PHI Kernel related ops. // PHI Kernel related ops.
class PDT_Kernel<string mnemonic, list<OpTrait> traits = []> : Op<PHI_KernelDialect, mnemonic, !listconcat(traits, [IsolatedFromAbove])> { class PDT_Kernel<string mnemonic, list<OpTrait> traits = []> : Op<PHI_KernelDialect, mnemonic, !listconcat(traits, [PhiOpTrait, IsolatedFromAbove])> {
}
def FakeKernelOp : PDT_Kernel<"phi.matmul.host.fp32"> {
let arguments = (ins CPU_Context:$dev_ctx, DenseTensor:$x, DenseTensor:$y, BoolAttr:$transpose_x, BoolAttr:$transpose_y);
let results = (outs DenseTensor:$output);
} }
def PDCK_AbsOp : PDT_Kernel<"phi.abs.host.fp32"> { def PDCK_AbsOp : PDT_Kernel<"phi.abs.host.fp32"> {
......
...@@ -18,7 +18,7 @@ def PHI_DenseTensorDialect : Dialect { ...@@ -18,7 +18,7 @@ def PHI_DenseTensorDialect : Dialect {
} }
// PHI DenseTensor related Op. // PHI DenseTensor related Op.
class PDT_Op<string mnemonic, list<OpTrait> traits = []> : Op<PHI_DenseTensorDialect, mnemonic, !listconcat(traits, [IsolatedFromAbove])> { class PDT_Op<string mnemonic, list<OpTrait> traits = []> : Op<PHI_DenseTensorDialect, mnemonic, !listconcat(traits, [PhiOpTrait, IsolatedFromAbove])> {
} }
class CreateDenseTensorOp<string place, string dtype, string layout> class CreateDenseTensorOp<string place, string dtype, string layout>
...@@ -53,4 +53,9 @@ def PDT_FillDenseTensorOp_f32 : FillDenseTensorOp<F32ArrayAttr, "f32">; ...@@ -53,4 +53,9 @@ def PDT_FillDenseTensorOp_f32 : FillDenseTensorOp<F32ArrayAttr, "f32">;
def PDT_CreateAllocatorOp_cpu : CreateCPUAllocatorOp; def PDT_CreateAllocatorOp_cpu : CreateCPUAllocatorOp;
def PDT_CreateContextOp_cpu : CreateCPUContextOp; def PDT_CreateContextOp_cpu : CreateCPUContextOp;
// Placeholder op with mnemonic "fake_phi_kernel" in the dense-tensor dialect.
// Operands: a CPU context followed by two dense tensors; attributes: two
// booleans (transpose_x, transpose_y); result: one dense tensor.
// NOTE(review): the operand/attribute layout looks like it mirrors a matmul
// kernel signature -- confirm before relying on that.
def FakeKernelOp : PDT_Op<"fake_phi_kernel"> {
  let arguments = (ins CPU_Context:$dev_ctx, DenseTensor:$x, DenseTensor:$y, BoolAttr:$transpose_x, BoolAttr:$transpose_y);
  let results = (outs DenseTensor:$output);
}
#endif #endif
...@@ -25,6 +25,20 @@ ...@@ -25,6 +25,20 @@
#define GET_TYPEDEF_CLASSES #define GET_TYPEDEF_CLASSES
#include "paddle/infrt/dialect/phi/infrt_phi_baseTypes.h.inc" #include "paddle/infrt/dialect/phi/infrt_phi_baseTypes.h.inc"
namespace mlir {
namespace OpTrait {

// Native C++ counterpart of the TableGen `PhiOpTrait` definition.
// The trait carries no state and imposes no structural constraint on the
// operation it is attached to; it only tags the op type.
template <typename ConcreteType>
class PhiOpTrait : public OpTrait::TraitBase<ConcreteType, PhiOpTrait> {
 public:
  // Verification hook: there is nothing to check, so every op is accepted.
  static LogicalResult verifyTrait(Operation* /*op*/) { return success(); }
};

}  // namespace OpTrait
}  // namespace mlir
namespace infrt { namespace infrt {
namespace phi {} // namespace phi namespace phi {} // namespace phi
} // namespace infrt } // namespace infrt
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/infrt/host_context/kernel_frame.h" #include "paddle/infrt/host_context/kernel_frame.h"
#include <memory> #include <memory>
#include <sstream>
namespace infrt { namespace infrt {
namespace host_context { namespace host_context {
...@@ -25,5 +26,36 @@ std::ostream& operator<<(std::ostream& os, const KernelFrame& frame) { ...@@ -25,5 +26,36 @@ std::ostream& operator<<(std::ostream& os, const KernelFrame& frame) {
return os; return os;
} }
#ifndef NDEBUG
// Builds a human-readable, comma-separated description of every element held
// by the frame: for each Value, the name of the recognised type followed by
// the address of the stored object in parentheses. Values holding a type that
// is not listed below are reported by their variant index instead.
//
// NOTE(review): this definition sits inside an `#ifndef NDEBUG` guard, so it is
// not compiled in release builds; every call site needs the same guard.
//
// The phi types are spelled with a leading `::` on purpose: this code lives in
// `namespace infrt`, where an `infrt::phi` namespace would otherwise shadow
// the global `phi` namespace during unqualified lookup.
std::string KernelFrame::DumpArgTypes() const {
  std::stringstream ss;
  for (auto* value : GetValues(0, GetNumElements())) {
    if (value->is_type<bool>()) {
      ss << "bool (" << &value->get<bool>() << "), ";
    } else if (value->is_type<tensor::DenseHostTensor>()) {
      ss << "DenseHostTensor(" << &value->get<tensor::DenseHostTensor>()
         << "), ";
    } else if (value->is_type<float>()) {
      ss << "float(" << &value->get<float>() << "), ";
    } else if (value->is_type<int>()) {
      ss << "int(" << &value->get<int>() << "), ";
    } else if (value->is_type<::phi::DenseTensor>()) {
      ss << "phi::DenseTensor(" << &value->get<::phi::DenseTensor>() << "), ";
    } else if (value->is_type<::phi::MetaTensor>()) {
      ss << "phi::MetaTensor(" << &value->get<::phi::MetaTensor>() << "), ";
    } else if (value->is_type<::phi::CPUContext>()) {
      ss << "phi::CPUContext(" << &value->get<::phi::CPUContext>() << "), ";
    } else if (value->is_type<host_context::None>()) {
      ss << "none(" << &value->get<host_context::None>() << "), ";
    } else if (value->is_type<backends::CpuPhiContext>()) {
      ss << "CpuPhiContext(" << &value->get<backends::CpuPhiContext>() << "), ";
    } else {
      // Unknown alternative: fall back to the raw variant index.
      ss << "typeid: " << value->index() << ", ";
    }
  }
  return ss.str();
}
#endif
} // namespace host_context } // namespace host_context
} // namespace infrt } // namespace infrt
...@@ -31,20 +31,24 @@ namespace host_context { ...@@ -31,20 +31,24 @@ namespace host_context {
class KernelFrame { class KernelFrame {
public: public:
int GetNumArgs() const { return num_arguments_; } int GetNumArgs() const { return num_arguments_; }
int GetNumResults() const { return num_results_ == -1 ? 0 : num_results_; } int GetNumResults() const {
int GetNumAttributes() const { return value_or_attrs_.size() - num_arguments_ - GetNumAttributes();
return value_or_attrs_.size() - num_arguments_ -
(num_results_ == -1 ? 0 : num_results_);
} }
int GetNumAttributes() const { return num_attrs_ == -1 ? 0 : num_attrs_; }
//! Get something at a specific position \p index. The element might be an //! Get something at a specific position \p index. The element might be an
//! argument, an attribute or a result. //! argument, an attribute or a result.
template <typename T> template <typename T>
T& GetElementAt(int index) { T& GetElementAt(int index) {
CHECK_LT(index, GetNumArgs() + GetNumAttributes() + GetNumResults()); CHECK_LT(static_cast<size_t>(index), GetNumElements());
return value_or_attrs_[index]->template get_or_default<T>(); return value_or_attrs_[index]->template get_or_default<T>();
} }
//! Untyped overload: returns the raw Value pointer stored at position
//! \p index of the frame, whatever kind of element it is.
//! The index is cast to size_t before the bounds check, so a negative index
//! becomes a very large value and fails the CHECK as well.
Value* GetElementAt(int index) {
CHECK_LT(static_cast<size_t>(index), GetNumElements());
return value_or_attrs_[index];
}
// Get number of elements, either input, attributes or results. // Get number of elements, either input, attributes or results.
size_t GetNumElements() const { return value_or_attrs_.size(); } size_t GetNumElements() const { return value_or_attrs_.size(); }
...@@ -70,18 +74,21 @@ class KernelFrame { ...@@ -70,18 +74,21 @@ class KernelFrame {
} }
Value* GetAttributeAt(int idx) { Value* GetAttributeAt(int idx) {
CHECK_NE(num_results_, -1) // CHECK_NE(num_results_, -1)
<< "Must call SetNumResults before GetAttributeAt"; //<< "Must call SetNumResults before GetAttributeAt";
CHECK_LT(idx, CHECK_LT(idx, GetNumAttributes());
static_cast<int>(value_or_attrs_.size() - num_arguments_ - return value_or_attrs_[num_arguments_ + idx];
num_results_));
return value_or_attrs_[num_arguments_ + num_results_ + idx];
} }
void AddAttribute(Value* v) { void AddAttribute(Value* v) {
CHECK_NE(num_results_, -1) CHECK_LE(num_results_, 0)
<< "Must call SetNumResults before calling AddAttribute"; << "Must call SetNumResults after calling AddAttribute";
value_or_attrs_.emplace_back(v); value_or_attrs_.emplace_back(v);
if (num_attrs_ == -1) num_attrs_ = 0;
num_attrs_++;
CHECK_EQ(value_or_attrs_.size(),
static_cast<size_t>(num_arguments_ + num_attrs_));
} }
template <typename T, typename... Args> template <typename T, typename... Args>
...@@ -96,35 +103,43 @@ class KernelFrame { ...@@ -96,35 +103,43 @@ class KernelFrame {
template <typename T> template <typename T>
void SetResultAt(int index, T&& value) { void SetResultAt(int index, T&& value) {
CHECK_LT(index, num_results_) << "Invalid result index"; CHECK_LT(index, GetNumResults()) << "Invalid result index";
CHECK(value_or_attrs_[num_arguments_ + index]); CHECK(value_or_attrs_[num_arguments_ + GetNumAttributes() + index]);
value_or_attrs_[num_arguments_ + index]->set(std::move(value)); value_or_attrs_[num_arguments_ + GetNumAttributes() + index]->set(
std::move(value));
} }
llvm::ArrayRef<Value*> GetResults() const { llvm::ArrayRef<Value*> GetResults() const {
return GetValues(num_arguments_, num_results_); CHECK_GE(num_results_, 0) << "Invalid results num";
return GetValues(num_arguments_ + GetNumAttributes(), num_results_);
} }
llvm::MutableArrayRef<Value*> GetResults() { llvm::MutableArrayRef<Value*> GetResults() {
return GetMutableValues(num_arguments_, num_results_); CHECK_GE(num_results_, 0) << "Invalid results num";
return GetMutableValues(num_arguments_ + GetNumAttributes(), num_results_);
} }
llvm::ArrayRef<Value*> GetValues(size_t from, size_t length) const { llvm::ArrayRef<Value*> GetValues(size_t from, size_t length) const {
CHECK_LE(static_cast<int>(from + length), num_arguments_ + num_results_); CHECK_LE(from + length, GetNumElements());
if (length == 0) return {}; if (length == 0) return {};
return llvm::makeArrayRef(&value_or_attrs_[from], length); return llvm::makeArrayRef(&value_or_attrs_[from], length);
} }
llvm::MutableArrayRef<Value*> GetMutableValues(size_t from, size_t length) { llvm::MutableArrayRef<Value*> GetMutableValues(size_t from, size_t length) {
CHECK_LE(static_cast<int>(from + length), num_arguments_ + num_results_); CHECK_LE(from + length, GetNumElements());
if (length == 0) return {}; if (length == 0) return {};
return llvm::makeMutableArrayRef(&value_or_attrs_[from], length); return llvm::makeMutableArrayRef(&value_or_attrs_[from], length);
} }
#ifndef NDEBUG
std::string DumpArgTypes() const;
#endif
bool IsEmpty() const { return value_or_attrs_.empty(); } bool IsEmpty() const { return value_or_attrs_.empty(); }
protected: protected:
int num_arguments_{}; int num_arguments_{};
int num_attrs_{-1};
int num_results_{-1}; int num_results_{-1};
llvm::SmallVector<Value*, 8> value_or_attrs_; llvm::SmallVector<Value*, 8> value_or_attrs_;
...@@ -136,15 +151,15 @@ class KernelFrameBuilder : public KernelFrame { ...@@ -136,15 +151,15 @@ class KernelFrameBuilder : public KernelFrame {
public: public:
void AddArgument(Value* value) { void AddArgument(Value* value) {
CHECK(value); CHECK(value);
CHECK_EQ(num_results_, -1) CHECK_EQ(num_attrs_, -1)
<< "Should call AddArgument before calling SetNumResults"; << "Should call AddArgument before calling SetAttributes";
value_or_attrs_.push_back(value); value_or_attrs_.push_back(value);
++num_arguments_; ++num_arguments_;
} }
void SetResults(llvm::ArrayRef<Value*> values) { void SetResults(llvm::ArrayRef<Value*> values) {
CHECK_EQ(num_arguments_, static_cast<int>(value_or_attrs_.size())); CHECK_EQ(num_arguments_ + GetNumAttributes(),
CHECK_EQ(num_results_, -1); static_cast<int>(value_or_attrs_.size()));
for (Value* x : values) { for (Value* x : values) {
value_or_attrs_.push_back(x); value_or_attrs_.push_back(x);
} }
...@@ -152,28 +167,30 @@ class KernelFrameBuilder : public KernelFrame { ...@@ -152,28 +167,30 @@ class KernelFrameBuilder : public KernelFrame {
} }
void SetNumResults(size_t n) { void SetNumResults(size_t n) {
CHECK_EQ(num_arguments_, static_cast<int>(value_or_attrs_.size())); CHECK_EQ(num_arguments_ + GetNumAttributes(),
CHECK_EQ(num_results_, -1); static_cast<int>(value_or_attrs_.size()));
num_results_ = n;
for (size_t i = 0; i < n; i++) { for (size_t i = 0; i < n; i++) {
value_or_attrs_.emplace_back(new Value); value_or_attrs_.emplace_back(new Value);
} }
num_results_ = n;
} }
void SetResultAt(int result_id, Value* value) { void SetResultAt(int result_id, Value* value) {
CHECK_EQ(static_cast<int>(value_or_attrs_.size()), CHECK_EQ(static_cast<int>(value_or_attrs_.size()),
num_arguments_ + num_results_) num_arguments_ + GetNumAttributes() + num_results_)
<< "Call SetNumResults first"; << "Call SetNumResults first";
CHECK_LT(result_id + num_arguments_, CHECK_LT(result_id + num_arguments_ + GetNumAttributes(),
static_cast<int>(value_or_attrs_.size())); static_cast<int>(value_or_attrs_.size()));
CHECK(value); CHECK(value);
value_or_attrs_[num_arguments_ + result_id]->set(value); value_or_attrs_[num_arguments_ + GetNumAttributes() + result_id]->set(
value);
} }
void Reset() { void Reset() {
value_or_attrs_.clear(); value_or_attrs_.clear();
num_arguments_ = 0; num_arguments_ = 0;
num_results_ = -1; num_results_ = -1;
num_attrs_ = -1;
} }
}; };
......
...@@ -209,9 +209,11 @@ struct KernelImpl<Return (*)(Args...), impl_fn> { ...@@ -209,9 +209,11 @@ struct KernelImpl<Return (*)(Args...), impl_fn> {
static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) { static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) {
static_assert(out_idx != -1, static_assert(out_idx != -1,
"Do not place Results after RemainingResults"); "Do not place Results after RemainingResults");
static_assert(const_idx == 0, // static_assert(const_idx == 0,
"Arguments and results should appear before attributes"); // "Arguments and results should appear before attributes");
Result<Head> arg(&frame->GetResults()[out_idx]);
// Result<Head> arg(&frame->GetResults()[out_idx]);
Result<Head> arg(new ValueRef());
KernelCallHelper< KernelCallHelper<
Tail...>::template Invoke<in_idx, out_idx + 1, const_idx>(frame, Tail...>::template Invoke<in_idx, out_idx + 1, const_idx>(frame,
pargs..., pargs...,
...@@ -224,8 +226,8 @@ struct KernelImpl<Return (*)(Args...), impl_fn> { ...@@ -224,8 +226,8 @@ struct KernelImpl<Return (*)(Args...), impl_fn> {
struct KernelCallHelper<Attribute<Head>, Tail...> { struct KernelCallHelper<Attribute<Head>, Tail...> {
template <int in_idx, int out_idx, int const_idx, typename... PreviousArgs> template <int in_idx, int out_idx, int const_idx, typename... PreviousArgs>
static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) { static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) {
static_assert(const_idx != -1, // static_assert(const_idx != -1,
"Do not place Attributes after RemainingAttributes"); // "Do not place Attributes after RemainingAttributes");
Attribute<Head> arg(frame->GetAttributeAt(const_idx)); Attribute<Head> arg(frame->GetAttributeAt(const_idx));
KernelCallHelper< KernelCallHelper<
Tail...>::template Invoke<in_idx, out_idx, const_idx + 1>(frame, Tail...>::template Invoke<in_idx, out_idx, const_idx + 1>(frame,
...@@ -242,8 +244,8 @@ struct KernelImpl<Return (*)(Args...), impl_fn> { ...@@ -242,8 +244,8 @@ struct KernelImpl<Return (*)(Args...), impl_fn> {
static_assert(in_idx != -1, static_assert(in_idx != -1,
"Do not place Arguments after RemainingArguments"); "Do not place Arguments after RemainingArguments");
static_assert(out_idx == 0, "Arguments should appear before results"); static_assert(out_idx == 0, "Arguments should appear before results");
static_assert(const_idx == 0, // static_assert(const_idx == 0,
"Arguments and results should appear before attributes."); // "Arguments and results should appear before attributes.");
auto* arg = &frame->template GetElementAt<Head>(in_idx); auto* arg = &frame->template GetElementAt<Head>(in_idx);
KernelCallHelper< KernelCallHelper<
Tail...>::template Invoke<in_idx + 1, out_idx, const_idx>(frame, Tail...>::template Invoke<in_idx + 1, out_idx, const_idx>(frame,
...@@ -265,7 +267,7 @@ struct KernelImpl<Return (*)(Args...), impl_fn> { ...@@ -265,7 +267,7 @@ struct KernelImpl<Return (*)(Args...), impl_fn> {
static_assert(const_idx == 0, static_assert(const_idx == 0,
"Arguments and results should appear before attributes."); "Arguments and results should appear before attributes.");
auto* value = frame->GetArgAt(in_idx); auto* value = frame->GetElementAt(in_idx);
auto&& arg = value->get<ArgT>(); auto&& arg = value->get<ArgT>();
KernelCallHelper< KernelCallHelper<
......
...@@ -67,5 +67,45 @@ TEST(KernelImpl, pair) { ...@@ -67,5 +67,45 @@ TEST(KernelImpl, pair) {
ASSERT_EQ(results[1]->get<float>(), 3.f); ASSERT_EQ(results[1]->get<float>(), 3.f);
} }
// Kernel-shaped helper used by the test below: three string arguments, one
// string attribute and two string results. It only verifies that each
// argument and the attribute arrive with the expected literal value; the two
// result handles are accepted but never written to.
void TestFunc(const std::string& arg_0,
const std::string& arg_1,
const std::string& arg_2,
Attribute<std::string> attr_0,
Result<std::string> res_0,
Result<std::string> res_1) {
CHECK_EQ(arg_0, "arg_0");
CHECK_EQ(arg_1, "arg_1");
CHECK_EQ(arg_2, "arg_2");
CHECK_EQ(attr_0.get(), "attr_0");
// Writing the results is currently disabled:
// res_0.Set(Argument<std::string>(ValueRef(new Value())));
// res_1.Set(Argument<std::string>(ValueRef(new Value())));
}
// Builds a frame laid out as [arguments | attributes | results] and checks
// the bookkeeping accessors before dispatching TestFunc through KernelImpl.
// gtest assertions are used (instead of glog CHECKs) so that a mismatch is
// reported as a test failure rather than aborting the whole test binary.
TEST(KernelRegistry, basic) {
  KernelFrameBuilder kernel_frame;

  Value arg_0(std::string{"arg_0"});
  Value arg_1(std::string{"arg_1"});
  Value arg_2(std::string{"arg_2"});
  Value attr_0(std::string{"attr_0"});

  // Order matters: arguments first, then attributes, then results.
  kernel_frame.AddArgument(&arg_0);
  kernel_frame.AddArgument(&arg_1);
  kernel_frame.AddArgument(&arg_2);
  kernel_frame.AddAttribute(&attr_0);
  kernel_frame.SetNumResults(2);

  ASSERT_EQ(kernel_frame.GetNumArgs(), 3);
  ASSERT_EQ(kernel_frame.GetNumResults(), 2);
  ASSERT_EQ(kernel_frame.GetNumAttributes(), 1);
  ASSERT_EQ(kernel_frame.GetNumElements(), 6UL);
  ASSERT_EQ(kernel_frame.GetArgAt<std::string>(2), "arg_2");
  ASSERT_EQ(kernel_frame.GetAttributeAt(0)->get<std::string>(), "attr_0");

  // TestFunc itself validates the values it receives.
  KernelImpl<decltype(&TestFunc), TestFunc>::Invoke(&kernel_frame);
}
} // namespace host_context } // namespace host_context
} // namespace infrt } // namespace infrt
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "boost/optional.hpp" #include "boost/optional.hpp"
#include "paddle/infrt/common/string.h" #include "paddle/infrt/common/string.h"
#include "paddle/infrt/dialect/dense_tensor.h"
#include "paddle/infrt/dialect/mlir_loader.h" #include "paddle/infrt/dialect/mlir_loader.h"
#include "paddle/infrt/dialect/tensor_shape.h" #include "paddle/infrt/dialect/tensor_shape.h"
#include "paddle/infrt/host_context/core_runtime.h" #include "paddle/infrt/host_context/core_runtime.h"
...@@ -150,6 +151,17 @@ boost::optional<float> MlirToRuntimeTranslator::EmitAttribute( ...@@ -150,6 +151,17 @@ boost::optional<float> MlirToRuntimeTranslator::EmitAttribute(
return boost::none; return boost::none;
} }
// Specialisation for boolean attributes.
// Returns the wrapped bool when `attr` is an mlir::BoolAttr, and boost::none
// for any other attribute kind so the caller can try the next type.
template <>
boost::optional<bool> MlirToRuntimeTranslator::EmitAttribute(
    const mlir::Attribute& attr) {
  if (!attr.isa<mlir::BoolAttr>()) return boost::none;
  return attr.cast<mlir::BoolAttr>().getValue();
}
template <> template <>
boost::optional<double> MlirToRuntimeTranslator::EmitAttribute( boost::optional<double> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute& attr) { const mlir::Attribute& attr) {
...@@ -187,6 +199,7 @@ boost::optional<std::string> MlirToRuntimeTranslator::EmitAttribute( ...@@ -187,6 +199,7 @@ boost::optional<std::string> MlirToRuntimeTranslator::EmitAttribute(
return res; \ return res; \
} }
PROCESS_ARRAY_INT(bool, 1);
PROCESS_ARRAY_INT(int16_t, 16); PROCESS_ARRAY_INT(int16_t, 16);
PROCESS_ARRAY_INT(int32_t, 32); PROCESS_ARRAY_INT(int32_t, 32);
PROCESS_ARRAY_INT(int64_t, 64); PROCESS_ARRAY_INT(int64_t, 64);
...@@ -262,25 +275,6 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) { ...@@ -262,25 +275,6 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
<< GetValue(operand) << " vs " << arg_value; << GetValue(operand) << " vs " << arg_value;
} }
// process results
llvm::SmallVector<Value*, 4> res_values;
for (int i = 0, e = op->getNumResults(); i < e; i++) {
auto res = op->getResult(i);
res_values.push_back(AddValue(res));
VLOG(3) << "* op mlir res: " << DumpToString(res) << " " << GetValue(res);
}
impl_->cur_op->SetResults(res_values);
#ifdef INFRT_DEBUG
{
VLOG(3) << "check result";
for (int i = 0; i < impl_->cur_op->frame().GetNumResults(); i++) {
VLOG(3) << "+ res value: " << impl_->cur_op->frame().GetResults()[i];
}
}
#endif
// process attributes // process attributes
auto attrs = op->getAttrs(); auto attrs = op->getAttrs();
...@@ -296,6 +290,8 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) { ...@@ -296,6 +290,8 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
impl_->cur_op->AppendAttribute(new Value(*v)); impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<std::string>(attr.getValue())) { } else if (auto v = EmitAttribute<std::string>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v))); impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<bool>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<std::vector<int16_t>>(attr.getValue())) { } else if (auto v = EmitAttribute<std::vector<int16_t>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v))); impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<int32_t>>(attr.getValue())) { } else if (auto v = EmitAttribute<std::vector<int32_t>>(attr.getValue())) {
...@@ -311,6 +307,33 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) { ...@@ -311,6 +307,33 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
} }
} }
// process results
llvm::SmallVector<Value*, 4> res_values;
for (int i = 0, e = op->getNumResults(); i < e; i++) {
auto res = op->getResult(i);
if (res.getType().isa<::infrt::DenseTensorType>()) {
auto r = impl_->value_map.try_emplace(
res, ValueRef(new Value{::phi::DenseTensor()}));
CHECK(r.second) << "Duplicate add mlir value [" << DumpToString(res)
<< "]";
res_values.push_back(r.first->second.get());
} else {
res_values.push_back(AddValue(res));
}
VLOG(3) << "* op mlir res: " << DumpToString(res) << " " << GetValue(res);
}
impl_->cur_op->SetResults(res_values);
#ifdef INFRT_DEBUG
{
VLOG(3) << "check result";
for (int i = 0; i < impl_->cur_op->frame().GetNumResults(); i++) {
VLOG(3) << "+ res value: " << impl_->cur_op->frame().GetResults()[i];
}
}
#endif
// process regions, we treat regions as attribute. // process regions, we treat regions as attribute.
auto num_regions = op->getNumRegions(); auto num_regions = op->getNumRegions();
if (num_regions > 0) { if (num_regions > 0) {
...@@ -440,14 +463,6 @@ bool MlirToRuntimeTranslator::EmitCallOp(mlir::Operation* op, ...@@ -440,14 +463,6 @@ bool MlirToRuntimeTranslator::EmitCallOp(mlir::Operation* op,
impl_->cur_op->AppendArgument(arg_value); impl_->cur_op->AppendArgument(arg_value);
} }
// process results
llvm::SmallVector<Value*, 4> res_values;
for (int i = 0, e = op->getNumResults(); i < e; i++) {
auto res = op->getResult(i);
res_values.push_back(AddValue(res));
}
impl_->cur_op->SetResults(res_values);
// process attribute // process attribute
auto& table = function_table ? *function_table : impl_->func_defs; auto& table = function_table ? *function_table : impl_->func_defs;
{ {
...@@ -460,6 +475,14 @@ bool MlirToRuntimeTranslator::EmitCallOp(mlir::Operation* op, ...@@ -460,6 +475,14 @@ bool MlirToRuntimeTranslator::EmitCallOp(mlir::Operation* op,
impl_->cur_op->AppendAttribute(new Value(function)); impl_->cur_op->AppendAttribute(new Value(function));
} }
// process results
llvm::SmallVector<Value*, 4> res_values;
for (int i = 0, e = op->getNumResults(); i < e; i++) {
auto res = op->getResult(i);
res_values.push_back(AddValue(res));
}
impl_->cur_op->SetResults(res_values);
VLOG(3) << "Emit call " << callee_name.getValue().str() << " " VLOG(3) << "Emit call " << callee_name.getValue().str() << " "
<< impl_->cur_op->frame(); << impl_->cur_op->frame();
return true; return true;
......
...@@ -133,7 +133,8 @@ void OpExecutable::Execute() { ...@@ -133,7 +133,8 @@ void OpExecutable::Execute() {
VLOG(3) << "execute " << name() VLOG(3) << "execute " << name()
<< " --- frame args: " << impl_->frame.GetNumArgs() << " results " << " --- frame args: " << impl_->frame.GetNumArgs() << " results "
<< impl_->frame.GetNumResults() << " attributes " << impl_->frame.GetNumResults() << " attributes "
<< impl_->frame.GetNumAttributes(); << impl_->frame.GetNumAttributes() << "\n"
<< frame().DumpArgTypes();
for (int i = 0; i < impl_->frame.GetNumArgs(); i++) { for (int i = 0; i < impl_->frame.GetNumArgs(); i++) {
VLOG(3) << "function arg: " << impl_->frame.GetArgAt(i); VLOG(3) << "function arg: " << impl_->frame.GetArgAt(i);
} }
......
...@@ -45,10 +45,13 @@ ...@@ -45,10 +45,13 @@
namespace infrt { namespace infrt {
namespace host_context { namespace host_context {
// Empty tag type with no members.
// NOTE(review): presumably stands for "no value" inside ValueVariantType --
// confirm against the variant's default-construction behaviour.
struct None {};
struct MlirFunctionExecutable; struct MlirFunctionExecutable;
using ValueVariantType = using ValueVariantType =
Variant<int16_t, Variant<None,
int16_t,
int32_t, int32_t,
int64_t, int64_t,
float, float,
...@@ -118,13 +121,15 @@ class Value : public common::Object { ...@@ -118,13 +121,15 @@ class Value : public common::Object {
template <typename T> template <typename T>
const T& get() const { const T& get() const {
CHECK(data.template is<T>()); CHECK(data.template is<T>()) << "typeid: " << data.index()
<< " != " << ValueVariantType::IndexOf<T>;
return data.get<T>(); return data.get<T>();
} }
template <typename T> template <typename T>
T& get() { T& get() {
CHECK(data.template is<T>()); CHECK(data.template is<T>()) << "typeid: " << data.index()
<< " != " << ValueVariantType::IndexOf<T>;
return data.get<T>(); return data.get<T>();
} }
...@@ -153,6 +158,8 @@ class Value : public common::Object { ...@@ -153,6 +158,8 @@ class Value : public common::Object {
const char* type_info() const override; const char* type_info() const override;
// Forwards to the underlying variant and returns the index it reports for
// the currently stored alternative.
ValueVariantType::IndexT index() const { return data.index(); }
friend void CopyTo(const Value& from, Value* to); friend void CopyTo(const Value& from, Value* to);
private: private:
......
...@@ -18,7 +18,7 @@ namespace infrt { ...@@ -18,7 +18,7 @@ namespace infrt {
namespace kernel { namespace kernel {
namespace phi { namespace phi {
backends::CpuPhiContext CreateCpuContext() { return {}; } ::phi::CPUContext CreateCpuContext() { return {}; }
} // namespace phi } // namespace phi
} // namespace kernel } // namespace kernel
......
...@@ -21,7 +21,7 @@ namespace infrt { ...@@ -21,7 +21,7 @@ namespace infrt {
namespace kernel { namespace kernel {
namespace phi { namespace phi {
backends::CpuPhiContext CreateCpuContext(); ::phi::CPUContext CreateCpuContext();
} // namespace phi } // namespace phi
} // namespace kernel } // namespace kernel
......
...@@ -26,9 +26,6 @@ void InferShapedKernelLauncher::CreateKernelFrameForInferShape( ...@@ -26,9 +26,6 @@ void InferShapedKernelLauncher::CreateKernelFrameForInferShape(
if (value->is_type<::phi::DenseTensor>()) { if (value->is_type<::phi::DenseTensor>()) {
values.emplace_back(::phi::MetaTensor{&value->get<::phi::DenseTensor>()}); values.emplace_back(::phi::MetaTensor{&value->get<::phi::DenseTensor>()});
infershape_kernel_frame_builder.AddArgument(values.back().get()); infershape_kernel_frame_builder.AddArgument(values.back().get());
} else if (value->is_type<phi::DenseTensor>()) {
values.emplace_back(phi::MetaTensor{&value->get<phi::DenseTensor>()});
infershape_kernel_frame_builder.AddArgument(values.back().get());
} else { } else {
infershape_kernel_frame_builder.AddArgument(value); infershape_kernel_frame_builder.AddArgument(value);
} }
......
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
#pragma once #pragma once
#include <llvm/ADT/SmallVector.h> #include <llvm/ADT/SmallVector.h>
#include <iostream>
#include "paddle/infrt/backends/host/phi_context.h"
#include "paddle/infrt/host_context/kernel_utils.h" #include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.h" #include "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.h"
#include "paddle/infrt/kernel/phi/infershaped/infershaped_utils.h" #include "paddle/infrt/kernel/phi/infershaped/infershaped_utils.h"
...@@ -22,6 +24,26 @@ ...@@ -22,6 +24,26 @@
namespace infrt { namespace infrt {
namespace kernel { namespace kernel {
// Stand-in infer-shape function paired with FakePhiKernel. It performs no
// shape computation; it only logs the address of the output meta tensor and
// its element count. `c` is dereferenced and must not be null.
//
// Declared `inline` rather than `static`: this definition lives in a header,
// and `static` would give every including translation unit its own copy (and
// an unused-function warning wherever it is not referenced).
inline void FakePhiInferShape(const ::phi::MetaTensor& a,
                              const ::phi::MetaTensor& b,
                              bool arg_0,
                              bool arg_1,
                              ::phi::MetaTensor* c) {
  LOG(INFO) << "the ptr of c: " << c;
  LOG(INFO) << "c->numel(): " << c->numel();
}
// Stand-in kernel with the parameter shape (context, two input tensors, two
// bool attributes, one output tensor). It computes nothing: it prints a fixed
// marker line to stdout and logs the address and element count of the output
// tensor. `c` is dereferenced and must not be null.
//
// Declared `inline` rather than `static`: this definition lives in a header,
// and `static` would give every including translation unit its own copy (and
// an unused-function warning wherever it is not referenced).
inline void FakePhiKernel(const ::phi::CPUContext& /*Context*/,
                          const ::phi::DenseTensor& a,
                          const ::phi::DenseTensor& b,
                          bool arg_0,
                          bool arg_1,
                          ::phi::DenseTensor* c) {
  // The marker goes to stdout (not the log) so that output-matching tests can
  // detect that the kernel really ran; keep the text unchanged.
  std::cout << "@FakePhiKernel@" << std::endl;
  LOG(INFO) << "the ptr of c: " << c;
  LOG(INFO) << "c->numel(): " << c->numel();
}
template <typename KernelFunc, template <typename KernelFunc,
KernelFunc kernel, KernelFunc kernel,
typename InferShapedFunc, typename InferShapedFunc,
...@@ -31,10 +53,17 @@ class KernelLauncher : public InferShapedKernelLauncher { ...@@ -31,10 +53,17 @@ class KernelLauncher : public InferShapedKernelLauncher {
static const uint16_t num_input_tensors{InferShapeHelper<KernelFunc>::count}; static const uint16_t num_input_tensors{InferShapeHelper<KernelFunc>::count};
static const bool turn_on_infer_shape_cache{true}; static const bool turn_on_infer_shape_cache{true};
void Invoke(host_context::KernelFrame* frame) override { void Invoke(host_context::KernelFrame* frame) override {
#ifndef NDEBUG
LOG(INFO) << "Kernel.frame: " << frame->DumpArgTypes();
#endif
// Build the infershape KernelFrame if needed. // Build the infershape KernelFrame if needed.
// TODO(Superjomn) add unlikely here. // TODO(Superjomn) add unlikely here.
if (infershape_kernel_frame_builder.IsEmpty()) { if (infershape_kernel_frame_builder.IsEmpty()) {
CreateKernelFrameForInferShape(frame); CreateKernelFrameForInferShape(frame);
#ifndef NDEBUG
LOG(INFO) << "infershape.frame: "
<< infershape_kernel_frame_builder.DumpArgTypes();
#endif
} }
if (turn_on_infer_shape_cache) { if (turn_on_infer_shape_cache) {
if (!turn_on_infer_shape_cache || IsShapeChanged(num_input_tensors)) { if (!turn_on_infer_shape_cache || IsShapeChanged(num_input_tensors)) {
......
...@@ -43,17 +43,15 @@ void RegisterPhiKernels(host_context::KernelRegistry* registry) { ...@@ -43,17 +43,15 @@ void RegisterPhiKernels(host_context::KernelRegistry* registry) {
registry->AddKernel("phi_dt.fill_dense_tensor.f32", registry->AddKernel("phi_dt.fill_dense_tensor.f32",
INFRT_KERNEL(infrt::kernel::phi::FillDenseTensorF32)); INFRT_KERNEL(infrt::kernel::phi::FillDenseTensorF32));
registry->AddKernel( registry->AddKernel(
"phi.matmul.host.fp32", "phi_dt.fake_phi_kernel",
std::bind(&kernel::KernelLauncherFunc< std::bind(&KernelLauncherFunc<decltype(&FakePhiKernel),
decltype(&::phi::MatmulKernel<float, ::phi::CPUContext>), &FakePhiKernel,
&::phi::MatmulKernel<float, ::phi::CPUContext>, decltype(&FakePhiInferShape),
decltype(&::phi::MatmulInferMeta), &FakePhiInferShape>,
&::phi::MatmulInferMeta>, KernelLauncher<decltype(&FakePhiKernel),
kernel::KernelLauncher< &FakePhiKernel,
decltype(&::phi::MatmulKernel<float, ::phi::CPUContext>), decltype(&FakePhiInferShape),
&::phi::MatmulKernel<float, ::phi::CPUContext>, &FakePhiInferShape>(),
decltype(&::phi::MatmulInferMeta),
&::phi::MatmulInferMeta>(),
std::placeholders::_1)); std::placeholders::_1));
} }
......
...@@ -45,7 +45,7 @@ void PrintTensor(const DenseHostTensor &tensor) { ...@@ -45,7 +45,7 @@ void PrintTensor(const DenseHostTensor &tensor) {
} }
template <typename T> template <typename T>
void FillTensorWithConstant(DenseHostTensor *tensor, Attribute<T> v) { void FillTensorWithConstant(Attribute<T> v, DenseHostTensor *tensor) {
MutableDTArrayView<T>(tensor).Fill(v.get()); MutableDTArrayView<T>(tensor).Fill(v.get());
} }
...@@ -53,13 +53,11 @@ TensorMap LoadParams(const std::string &path) { ...@@ -53,13 +53,11 @@ TensorMap LoadParams(const std::string &path) {
return *(infrt::tensor::LoadParams(path)); return *(infrt::tensor::LoadParams(path));
} }
void TensorMapGetTensor(TensorMap map, DenseHostTensor TensorMapGetTensor(TensorMap map, Attribute<std::string> name) {
DenseHostTensor *out,
Attribute<std::string> name) {
auto it = map.find(name.get()); auto it = map.find(name.get());
CHECK(it != map.end()) << "No tensor called " << name.get() CHECK(it != map.end()) << "No tensor called " << name.get()
<< " in the TensorMap"; << " in the TensorMap";
*out = *it->second; return *it->second;
} }
int32_t TensorMapGetSize(TensorMap map) { return map.size(); } int32_t TensorMapGetSize(TensorMap map) { return map.size(); }
......
...@@ -136,12 +136,12 @@ class Variant { ...@@ -136,12 +136,12 @@ class Variant {
return nullptr; return nullptr;
} }
IndexT index() { return index_; } IndexT index() const { return index_; }
private:
template <typename T> template <typename T>
static constexpr size_t IndexOf = TupleIndexOf<T, Types>::value; static constexpr size_t IndexOf = TupleIndexOf<T, Types>::value;
private:
static constexpr size_t kStorageSize = std::max({sizeof(Ts)...}); static constexpr size_t kStorageSize = std::max({sizeof(Ts)...});
static constexpr size_t kAlignment = std::max({alignof(Ts)...}); static constexpr size_t kAlignment = std::max({alignof(Ts)...});
......
// RUN: infrtopt %s | FileCheck %s // RUN: infrtexec -i %s | FileCheck %s
// CHECK-LABEL: @basic_tensor // CHECK-LABEL: @fake_phi_kernel_execute
func @basic_tensor() { func @fake_phi_kernel_execute() {
%a = "phi_dt.create_allocator.cpu" (): () -> !phi.CPU_allocator %allocator = "phi_dt.create_allocator.cpu" (): () -> !phi.CPU_allocator
%b = "phi_dt.create_context.cpu" (): () -> !phi.CPU_context %ctx = "phi_dt.create_context.cpu" (): () -> !phi.CPU_context
%c = "phi_dt.create_dense_tensor.cpu.f32.nchw" (%a) {dims=[1:i64], lod=[1:i64]}: (!phi.CPU_allocator) -> (!infrt.dense_tensor<CPU, FP32, NCHW>) %t = "phi_dt.create_dense_tensor.cpu.f32.nchw" (%allocator) {dims=[1:i64], lod=[1:i64]}: (!phi.CPU_allocator) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
// "phi_dt.fill_dense_tensor.f32" (%c) {value=[1.0:f32]} : (!Infrt.tensor<CPU, FP32, NCHW>) -> ()
// CHECK: @FakePhiKernel@
%d = "phi_dt.fake_phi_kernel" (%ctx, %t, %t) {transpose_x=false, transpose_y=false} : (!phi.CPU_context, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
Infrt.return Infrt.return
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册