Unverified commit 058e1d85, authored by 石晓伟, committed by GitHub

infrt runtime supports phi, test=develop (#39836)

* runtime supports pten kernels, test=develop

* fixes a bug, test=develop
Parent ca11a0e5
......@@ -23,6 +23,8 @@ class ContextTypeOf<string place, list<Trait> traits=[]>:
let summary = !strconcat("!phi.context_", place, " type");
}
def PhiOpTrait : NativeOpTrait<"PhiOpTrait">;
def CPU_Allocator : AllocatorTypeOf<"CPU">;
def GPU_Allocator : AllocatorTypeOf<"GPU">;
......
#ifndef PHI_KERNEL
#define PHI_KERNEL
include "paddle/infrt/dialect/phi/infrt_phi_tensor.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/infrt_base.td"
include "paddle/infrt/dialect/phi/infrt_phi_base.td"
def PHI_KernelDialect : Dialect {
let name = "phi_kernel";
......@@ -14,12 +17,7 @@ def PHI_KernelDialect : Dialect {
}
// PHI Kernel related ops.
class PDT_Kernel<string mnemonic, list<OpTrait> traits = []> : Op<PHI_KernelDialect, mnemonic, !listconcat(traits, [IsolatedFromAbove])> {
}
def FakeKernelOp : PDT_Kernel<"phi.matmul.host.fp32"> {
let arguments = (ins CPU_Context:$dev_ctx, DenseTensor:$x, DenseTensor:$y, BoolAttr:$transpose_x, BoolAttr:$transpose_y);
let results = (outs DenseTensor:$output);
class PDT_Kernel<string mnemonic, list<OpTrait> traits = []> : Op<PHI_KernelDialect, mnemonic, !listconcat(traits, [PhiOpTrait, IsolatedFromAbove])> {
}
def PDCK_AbsOp : PDT_Kernel<"phi.abs.host.fp32"> {
......
......@@ -18,7 +18,7 @@ def PHI_DenseTensorDialect : Dialect {
}
// PHI DenseTensor related Op.
class PDT_Op<string mnemonic, list<OpTrait> traits = []> : Op<PHI_DenseTensorDialect, mnemonic, !listconcat(traits, [IsolatedFromAbove])> {
class PDT_Op<string mnemonic, list<OpTrait> traits = []> : Op<PHI_DenseTensorDialect, mnemonic, !listconcat(traits, [PhiOpTrait, IsolatedFromAbove])> {
}
class CreateDenseTensorOp<string place, string dtype, string layout>
......@@ -53,4 +53,9 @@ def PDT_FillDenseTensorOp_f32 : FillDenseTensorOp<F32ArrayAttr, "f32">;
def PDT_CreateAllocatorOp_cpu : CreateCPUAllocatorOp;
def PDT_CreateContextOp_cpu : CreateCPUContextOp;
def FakeKernelOp : PDT_Op<"fake_phi_kernel"> {
let arguments = (ins CPU_Context:$dev_ctx, DenseTensor:$x, DenseTensor:$y, BoolAttr:$transpose_x, BoolAttr:$transpose_y);
let results = (outs DenseTensor:$output);
}
#endif
......@@ -25,6 +25,20 @@
#define GET_TYPEDEF_CLASSES
#include "paddle/infrt/dialect/phi/infrt_phi_baseTypes.h.inc"
namespace mlir {
namespace OpTrait {
template <typename ConcreteType>
class PhiOpTrait : public OpTrait::TraitBase<ConcreteType, PhiOpTrait> {
public:
static LogicalResult verifyTrait(Operation *op) {
return LogicalResult::success();
}
};
} // namespace OpTrait
} // namespace mlir
namespace infrt {
namespace phi {} // namespace phi
} // namespace infrt
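The new PhiOpTrait above has a no-op verifier; its purpose is to mark phi ops so later passes can recognize them. A minimal usage sketch (hypothetical, not part of this patch; it assumes the PhiOpTrait declaration shown above is visible through its own header):
#include "mlir/IR/Operation.h"
// Note: the PhiOpTrait declaration from the diff above must also be included;
// no concrete include path is implied here.
bool IsPhiOp(mlir::Operation *op) {
  // hasTrait<> is a standard mlir::Operation query; it returns true for any
  // op whose ODS definition lists PhiOpTrait, e.g. the PDT_Op / PDT_Kernel
  // classes changed in this commit.
  return op->hasTrait<mlir::OpTrait::PhiOpTrait>();
}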
......@@ -15,6 +15,7 @@
#include "paddle/infrt/host_context/kernel_frame.h"
#include <memory>
#include <sstream>
namespace infrt {
namespace host_context {
......@@ -25,5 +26,36 @@ std::ostream& operator<<(std::ostream& os, const KernelFrame& frame) {
return os;
}
#ifndef NDEBUG
std::string KernelFrame::DumpArgTypes() const {
std::stringstream ss;
for (auto* value : GetValues(0, GetNumElements())) {
if (value->is_type<bool>()) {
ss << "bool (" << &value->get<bool>() << "), ";
} else if (value->is_type<tensor::DenseHostTensor>()) {
ss << "DenseHostTensor(" << &value->get<tensor::DenseHostTensor>()
<< "), ";
} else if (value->is_type<float>()) {
ss << "float(" << &value->get<float>() << "), ";
} else if (value->is_type<int>()) {
ss << "int(" << &value->get<int>() << "), ";
} else if (value->is_type<phi::DenseTensor>()) {
ss << "phi::DenseTensor(" << &value->get<phi::DenseTensor>() << "), ";
} else if (value->is_type<phi::MetaTensor>()) {
ss << "phi::MetaTensor(" << &value->get<phi::MetaTensor>() << "), ";
} else if (value->is_type<::phi::CPUContext>()) {
ss << "phi::CPUContext(" << &value->get<::phi::CPUContext>() << "), ";
} else if (value->is_type<host_context::None>()) {
ss << "none(" << &value->get<host_context::None>() << "), ";
} else if (value->is_type<backends::CpuPhiContext>()) {
ss << "CpuPhiContext(" << &value->get<backends::CpuPhiContext>() << "), ";
} else {
ss << "typeid: " << value->index() << ", ";
}
}
return ss.str();
}
#endif
} // namespace host_context
} // namespace infrt
......@@ -31,20 +31,24 @@ namespace host_context {
class KernelFrame {
public:
int GetNumArgs() const { return num_arguments_; }
int GetNumResults() const { return num_results_ == -1 ? 0 : num_results_; }
int GetNumAttributes() const {
return value_or_attrs_.size() - num_arguments_ -
(num_results_ == -1 ? 0 : num_results_);
int GetNumResults() const {
return value_or_attrs_.size() - num_arguments_ - GetNumAttributes();
}
int GetNumAttributes() const { return num_attrs_ == -1 ? 0 : num_attrs_; }
//! Get something at a specific position \p index. The element might be an
//! argument, an attribute or a result.
template <typename T>
T& GetElementAt(int index) {
CHECK_LT(index, GetNumArgs() + GetNumAttributes() + GetNumResults());
CHECK_LT(static_cast<size_t>(index), GetNumElements());
return value_or_attrs_[index]->template get_or_default<T>();
}
Value* GetElementAt(int index) {
CHECK_LT(static_cast<size_t>(index), GetNumElements());
return value_or_attrs_[index];
}
// Get number of elements, either input, attributes or results.
size_t GetNumElements() const { return value_or_attrs_.size(); }
......@@ -70,18 +74,21 @@ class KernelFrame {
}
Value* GetAttributeAt(int idx) {
CHECK_NE(num_results_, -1)
<< "Must call SetNumResults before GetAttributeAt";
CHECK_LT(idx,
static_cast<int>(value_or_attrs_.size() - num_arguments_ -
num_results_));
return value_or_attrs_[num_arguments_ + num_results_ + idx];
// CHECK_NE(num_results_, -1)
//<< "Must call SetNumResults before GetAttributeAt";
CHECK_LT(idx, GetNumAttributes());
return value_or_attrs_[num_arguments_ + idx];
}
void AddAttribute(Value* v) {
CHECK_NE(num_results_, -1)
<< "Must call SetNumResults before calling AddAttribute";
CHECK_LE(num_results_, 0)
<< "Must call SetNumResults after calling AddAttribute";
value_or_attrs_.emplace_back(v);
if (num_attrs_ == -1) num_attrs_ = 0;
num_attrs_++;
CHECK_EQ(value_or_attrs_.size(),
static_cast<size_t>(num_arguments_ + num_attrs_));
}
template <typename T, typename... Args>
......@@ -96,35 +103,43 @@ class KernelFrame {
template <typename T>
void SetResultAt(int index, T&& value) {
CHECK_LT(index, num_results_) << "Invalid result index";
CHECK(value_or_attrs_[num_arguments_ + index]);
value_or_attrs_[num_arguments_ + index]->set(std::move(value));
CHECK_LT(index, GetNumResults()) << "Invalid result index";
CHECK(value_or_attrs_[num_arguments_ + GetNumAttributes() + index]);
value_or_attrs_[num_arguments_ + GetNumAttributes() + index]->set(
std::move(value));
}
llvm::ArrayRef<Value*> GetResults() const {
return GetValues(num_arguments_, num_results_);
CHECK_GE(num_results_, 0) << "Invalid results num";
return GetValues(num_arguments_ + GetNumAttributes(), num_results_);
}
llvm::MutableArrayRef<Value*> GetResults() {
return GetMutableValues(num_arguments_, num_results_);
CHECK_GE(num_results_, 0) << "Invalid results num";
return GetMutableValues(num_arguments_ + GetNumAttributes(), num_results_);
}
llvm::ArrayRef<Value*> GetValues(size_t from, size_t length) const {
CHECK_LE(static_cast<int>(from + length), num_arguments_ + num_results_);
CHECK_LE(from + length, GetNumElements());
if (length == 0) return {};
return llvm::makeArrayRef(&value_or_attrs_[from], length);
}
llvm::MutableArrayRef<Value*> GetMutableValues(size_t from, size_t length) {
CHECK_LE(static_cast<int>(from + length), num_arguments_ + num_results_);
CHECK_LE(from + length, GetNumElements());
if (length == 0) return {};
return llvm::makeMutableArrayRef(&value_or_attrs_[from], length);
}
#ifndef NDEBUG
std::string DumpArgTypes() const;
#endif
bool IsEmpty() const { return value_or_attrs_.empty(); }
protected:
int num_arguments_{};
int num_attrs_{-1};
int num_results_{-1};
llvm::SmallVector<Value*, 8> value_or_attrs_;
......@@ -136,15 +151,15 @@ class KernelFrameBuilder : public KernelFrame {
public:
void AddArgument(Value* value) {
CHECK(value);
CHECK_EQ(num_results_, -1)
<< "Should call AddArgument before calling SetNumResults";
CHECK_EQ(num_attrs_, -1)
<< "Should call AddArgument before calling SetAttributes";
value_or_attrs_.push_back(value);
++num_arguments_;
}
void SetResults(llvm::ArrayRef<Value*> values) {
CHECK_EQ(num_arguments_, static_cast<int>(value_or_attrs_.size()));
CHECK_EQ(num_results_, -1);
CHECK_EQ(num_arguments_ + GetNumAttributes(),
static_cast<int>(value_or_attrs_.size()));
for (Value* x : values) {
value_or_attrs_.push_back(x);
}
......@@ -152,28 +167,30 @@ class KernelFrameBuilder : public KernelFrame {
}
void SetNumResults(size_t n) {
CHECK_EQ(num_arguments_, static_cast<int>(value_or_attrs_.size()));
CHECK_EQ(num_results_, -1);
num_results_ = n;
CHECK_EQ(num_arguments_ + GetNumAttributes(),
static_cast<int>(value_or_attrs_.size()));
for (size_t i = 0; i < n; i++) {
value_or_attrs_.emplace_back(new Value);
}
num_results_ = n;
}
void SetResultAt(int result_id, Value* value) {
CHECK_EQ(static_cast<int>(value_or_attrs_.size()),
num_arguments_ + num_results_)
num_arguments_ + GetNumAttributes() + num_results_)
<< "Call SetNumResults first";
CHECK_LT(result_id + num_arguments_,
CHECK_LT(result_id + num_arguments_ + GetNumAttributes(),
static_cast<int>(value_or_attrs_.size()));
CHECK(value);
value_or_attrs_[num_arguments_ + result_id]->set(value);
value_or_attrs_[num_arguments_ + GetNumAttributes() + result_id]->set(
value);
}
void Reset() {
value_or_attrs_.clear();
num_arguments_ = 0;
num_results_ = -1;
num_attrs_ = -1;
}
};
......
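The net effect of the kernel_frame.h changes is a new storage layout: value_or_attrs_ now holds [arguments | attributes | results] instead of [arguments | results | attributes], with num_attrs_ tracked explicitly. A standalone sketch of the resulting index arithmetic (hypothetical helper, not from the patch):
#include <cassert>
#include <cstddef>

// Mirrors the new KernelFrame layout: [arguments | attributes | results].
struct FrameLayoutSketch {
  int num_arguments{0};
  int num_attrs{0};
  int num_results{0};
  size_t AttributeIndex(int i) const { return num_arguments + i; }
  size_t ResultIndex(int i) const { return num_arguments + num_attrs + i; }
  size_t NumElements() const { return num_arguments + num_attrs + num_results; }
};

int main() {
  // Same shape as the KernelRegistry test further below: 3 args, 1 attr, 2 results.
  FrameLayoutSketch f{3, 1, 2};
  assert(f.AttributeIndex(0) == 3);  // where GetAttributeAt(0) now looks
  assert(f.ResultIndex(0) == 4);     // first result slot
  assert(f.NumElements() == 6);      // matches GetNumElements() in the test
  return 0;
}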
......@@ -209,9 +209,11 @@ struct KernelImpl<Return (*)(Args...), impl_fn> {
static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) {
static_assert(out_idx != -1,
"Do not place Results after RemainingResults");
static_assert(const_idx == 0,
"Arguments and results should appear before attributes");
Result<Head> arg(&frame->GetResults()[out_idx]);
// static_assert(const_idx == 0,
// "Arguments and results should appear before attributes");
// Result<Head> arg(&frame->GetResults()[out_idx]);
Result<Head> arg(new ValueRef());
KernelCallHelper<
Tail...>::template Invoke<in_idx, out_idx + 1, const_idx>(frame,
pargs...,
......@@ -224,8 +226,8 @@ struct KernelImpl<Return (*)(Args...), impl_fn> {
struct KernelCallHelper<Attribute<Head>, Tail...> {
template <int in_idx, int out_idx, int const_idx, typename... PreviousArgs>
static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) {
static_assert(const_idx != -1,
"Do not place Attributes after RemainingAttributes");
// static_assert(const_idx != -1,
// "Do not place Attributes after RemainingAttributes");
Attribute<Head> arg(frame->GetAttributeAt(const_idx));
KernelCallHelper<
Tail...>::template Invoke<in_idx, out_idx, const_idx + 1>(frame,
......@@ -242,8 +244,8 @@ struct KernelImpl<Return (*)(Args...), impl_fn> {
static_assert(in_idx != -1,
"Do not place Arguments after RemainingArguments");
static_assert(out_idx == 0, "Arguments should appear before results");
static_assert(const_idx == 0,
"Arguments and results should appear before attributes.");
// static_assert(const_idx == 0,
// "Arguments and results should appear before attributes.");
auto* arg = &frame->template GetElementAt<Head>(in_idx);
KernelCallHelper<
Tail...>::template Invoke<in_idx + 1, out_idx, const_idx>(frame,
......@@ -265,7 +267,7 @@ struct KernelImpl<Return (*)(Args...), impl_fn> {
static_assert(const_idx == 0,
"Arguments and results should appear before attributes.");
auto* value = frame->GetArgAt(in_idx);
auto* value = frame->GetElementAt(in_idx);
auto&& arg = value->get<ArgT>();
KernelCallHelper<
......
......@@ -67,5 +67,45 @@ TEST(KernelImpl, pair) {
ASSERT_EQ(results[1]->get<float>(), 3.f);
}
void TestFunc(const std::string& arg_0,
const std::string& arg_1,
const std::string& arg_2,
Attribute<std::string> attr_0,
Result<std::string> res_0,
Result<std::string> res_1) {
CHECK_EQ(arg_0, "arg_0");
CHECK_EQ(arg_1, "arg_1");
CHECK_EQ(arg_2, "arg_2");
CHECK_EQ(attr_0.get(), "attr_0");
// res_0.Set(Argument<std::string>(ValueRef(new Value())));
// res_1.Set(Argument<std::string>(ValueRef(new Value())));
}
TEST(KernelRegistry, basic) {
KernelFrameBuilder kernel_frame;
Value arg_0(std::string{"arg_0"});
Value arg_1(std::string{"arg_1"});
Value arg_2(std::string{"arg_2"});
Value attr_0(std::string{"attr_0"});
kernel_frame.AddArgument(&arg_0);
kernel_frame.AddArgument(&arg_1);
kernel_frame.AddArgument(&arg_2);
kernel_frame.AddAttribute(&attr_0);
kernel_frame.SetNumResults(2);
CHECK_EQ(kernel_frame.GetNumArgs(), 3);
CHECK_EQ(kernel_frame.GetNumResults(), 2);
CHECK_EQ(kernel_frame.GetNumAttributes(), 1);
CHECK_EQ(kernel_frame.GetNumElements(), 6UL);
CHECK_EQ(kernel_frame.GetArgAt<std::string>(2), "arg_2");
CHECK_EQ(kernel_frame.GetAttributeAt(0)->get<std::string>(), "attr_0");
KernelImpl<decltype(&TestFunc), TestFunc>::Invoke(&kernel_frame);
}
} // namespace host_context
} // namespace infrt
......@@ -31,6 +31,7 @@
#include "boost/optional.hpp"
#include "paddle/infrt/common/string.h"
#include "paddle/infrt/dialect/dense_tensor.h"
#include "paddle/infrt/dialect/mlir_loader.h"
#include "paddle/infrt/dialect/tensor_shape.h"
#include "paddle/infrt/host_context/core_runtime.h"
......@@ -150,6 +151,17 @@ boost::optional<float> MlirToRuntimeTranslator::EmitAttribute(
return boost::none;
}
template <>
boost::optional<bool> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute& attr) {
if (!attr.isa<mlir::BoolAttr>()) return boost::none;
if (attr.isa<mlir::BoolAttr>()) {
auto val = attr.cast<mlir::BoolAttr>();
return val.getValue();
}
return boost::none;
}
template <>
boost::optional<double> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute& attr) {
......@@ -187,6 +199,7 @@ boost::optional<std::string> MlirToRuntimeTranslator::EmitAttribute(
return res; \
}
PROCESS_ARRAY_INT(bool, 1);
PROCESS_ARRAY_INT(int16_t, 16);
PROCESS_ARRAY_INT(int32_t, 32);
PROCESS_ARRAY_INT(int64_t, 64);
......@@ -262,25 +275,6 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
<< GetValue(operand) << " vs " << arg_value;
}
// process results
llvm::SmallVector<Value*, 4> res_values;
for (int i = 0, e = op->getNumResults(); i < e; i++) {
auto res = op->getResult(i);
res_values.push_back(AddValue(res));
VLOG(3) << "* op mlir res: " << DumpToString(res) << " " << GetValue(res);
}
impl_->cur_op->SetResults(res_values);
#ifdef INFRT_DEBUG
{
VLOG(3) << "check result";
for (int i = 0; i < impl_->cur_op->frame().GetNumResults(); i++) {
VLOG(3) << "+ res value: " << impl_->cur_op->frame().GetResults()[i];
}
}
#endif
// process attributes
auto attrs = op->getAttrs();
......@@ -296,6 +290,8 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<std::string>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<bool>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<std::vector<int16_t>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<int32_t>>(attr.getValue())) {
......@@ -311,6 +307,33 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
}
}
// process results
llvm::SmallVector<Value*, 4> res_values;
for (int i = 0, e = op->getNumResults(); i < e; i++) {
auto res = op->getResult(i);
if (res.getType().isa<::infrt::DenseTensorType>()) {
auto r = impl_->value_map.try_emplace(
res, ValueRef(new Value{::phi::DenseTensor()}));
CHECK(r.second) << "Duplicate add mlir value [" << DumpToString(res)
<< "]";
res_values.push_back(r.first->second.get());
} else {
res_values.push_back(AddValue(res));
}
VLOG(3) << "* op mlir res: " << DumpToString(res) << " " << GetValue(res);
}
impl_->cur_op->SetResults(res_values);
#ifdef INFRT_DEBUG
{
VLOG(3) << "check result";
for (int i = 0; i < impl_->cur_op->frame().GetNumResults(); i++) {
VLOG(3) << "+ res value: " << impl_->cur_op->frame().GetResults()[i];
}
}
#endif
// process regions, we treat regions as attribute.
auto num_regions = op->getNumRegions();
if (num_regions > 0) {
......@@ -440,14 +463,6 @@ bool MlirToRuntimeTranslator::EmitCallOp(mlir::Operation* op,
impl_->cur_op->AppendArgument(arg_value);
}
// process results
llvm::SmallVector<Value*, 4> res_values;
for (int i = 0, e = op->getNumResults(); i < e; i++) {
auto res = op->getResult(i);
res_values.push_back(AddValue(res));
}
impl_->cur_op->SetResults(res_values);
// process attribute
auto& table = function_table ? *function_table : impl_->func_defs;
{
......@@ -460,6 +475,14 @@ bool MlirToRuntimeTranslator::EmitCallOp(mlir::Operation* op,
impl_->cur_op->AppendAttribute(new Value(function));
}
// process results
llvm::SmallVector<Value*, 4> res_values;
for (int i = 0, e = op->getNumResults(); i < e; i++) {
auto res = op->getResult(i);
res_values.push_back(AddValue(res));
}
impl_->cur_op->SetResults(res_values);
VLOG(3) << "Emit call " << callee_name.getValue().str() << " "
<< impl_->cur_op->frame();
return true;
......
......@@ -133,7 +133,8 @@ void OpExecutable::Execute() {
VLOG(3) << "execute " << name()
<< " --- frame args: " << impl_->frame.GetNumArgs() << " results "
<< impl_->frame.GetNumResults() << " attributes "
<< impl_->frame.GetNumAttributes();
<< impl_->frame.GetNumAttributes() << "\n"
<< frame().DumpArgTypes();
for (int i = 0; i < impl_->frame.GetNumArgs(); i++) {
VLOG(3) << "function arg: " << impl_->frame.GetArgAt(i);
}
......
......@@ -45,10 +45,13 @@
namespace infrt {
namespace host_context {
struct None {};
struct MlirFunctionExecutable;
using ValueVariantType =
Variant<int16_t,
Variant<None,
int16_t,
int32_t,
int64_t,
float,
......@@ -118,13 +121,15 @@ class Value : public common::Object {
template <typename T>
const T& get() const {
CHECK(data.template is<T>());
CHECK(data.template is<T>()) << "typeid: " << data.index()
<< " != " << ValueVariantType::IndexOf<T>;
return data.get<T>();
}
template <typename T>
T& get() {
CHECK(data.template is<T>());
CHECK(data.template is<T>()) << "typeid: " << data.index()
<< " != " << ValueVariantType::IndexOf<T>;
return data.get<T>();
}
......@@ -153,6 +158,8 @@ class Value : public common::Object {
const char* type_info() const override;
ValueVariantType::IndexT index() const { return data.index(); }
friend void CopyTo(const Value& from, Value* to);
private:
......
......@@ -18,7 +18,7 @@ namespace infrt {
namespace kernel {
namespace phi {
backends::CpuPhiContext CreateCpuContext() { return {}; }
::phi::CPUContext CreateCpuContext() { return {}; }
} // namespace phi
} // namespace kernel
......
......@@ -21,7 +21,7 @@ namespace infrt {
namespace kernel {
namespace phi {
backends::CpuPhiContext CreateCpuContext();
::phi::CPUContext CreateCpuContext();
} // namespace phi
} // namespace kernel
......
......@@ -26,9 +26,6 @@ void InferShapedKernelLauncher::CreateKernelFrameForInferShape(
if (value->is_type<::phi::DenseTensor>()) {
values.emplace_back(::phi::MetaTensor{&value->get<::phi::DenseTensor>()});
infershape_kernel_frame_builder.AddArgument(values.back().get());
} else if (value->is_type<phi::DenseTensor>()) {
values.emplace_back(phi::MetaTensor{&value->get<phi::DenseTensor>()});
infershape_kernel_frame_builder.AddArgument(values.back().get());
} else {
infershape_kernel_frame_builder.AddArgument(value);
}
......
......@@ -14,7 +14,9 @@
#pragma once
#include <llvm/ADT/SmallVector.h>
#include <iostream>
#include "paddle/infrt/backends/host/phi_context.h"
#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.h"
#include "paddle/infrt/kernel/phi/infershaped/infershaped_utils.h"
......@@ -22,6 +24,26 @@
namespace infrt {
namespace kernel {
static void FakePhiInferShape(const ::phi::MetaTensor& a,
const ::phi::MetaTensor& b,
bool arg_0,
bool arg_1,
::phi::MetaTensor* c) {
LOG(INFO) << "the ptr of c: " << c;
LOG(INFO) << "c->numel(): " << c->numel();
}
static void FakePhiKernel(const ::phi::CPUContext& /*Context*/,
const ::phi::DenseTensor& a,
const ::phi::DenseTensor& b,
bool arg_0,
bool arg_1,
::phi::DenseTensor* c) {
std::cout << "@FakePhiKernel@" << std::endl;
LOG(INFO) << "the ptr of c: " << c;
LOG(INFO) << "c->numel(): " << c->numel();
}
template <typename KernelFunc,
KernelFunc kernel,
typename InferShapedFunc,
......@@ -31,10 +53,17 @@ class KernelLauncher : public InferShapedKernelLauncher {
static const uint16_t num_input_tensors{InferShapeHelper<KernelFunc>::count};
static const bool turn_on_infer_shape_cache{true};
void Invoke(host_context::KernelFrame* frame) override {
#ifndef NDEBUG
LOG(INFO) << "Kernel.frame: " << frame->DumpArgTypes();
#endif
// Build the infershape KernelFrame if needed.
// TODO(Superjomn) add unlikely here.
if (infershape_kernel_frame_builder.IsEmpty()) {
CreateKernelFrameForInferShape(frame);
#ifndef NDEBUG
LOG(INFO) << "infershape.frame: "
<< infershape_kernel_frame_builder.DumpArgTypes();
#endif
}
if (turn_on_infer_shape_cache) {
if (!turn_on_infer_shape_cache || IsShapeChanged(num_input_tensors)) {
......
......@@ -43,17 +43,15 @@ void RegisterPhiKernels(host_context::KernelRegistry* registry) {
registry->AddKernel("phi_dt.fill_dense_tensor.f32",
INFRT_KERNEL(infrt::kernel::phi::FillDenseTensorF32));
registry->AddKernel(
"phi.matmul.host.fp32",
std::bind(&kernel::KernelLauncherFunc<
decltype(&::phi::MatmulKernel<float, ::phi::CPUContext>),
&::phi::MatmulKernel<float, ::phi::CPUContext>,
decltype(&::phi::MatmulInferMeta),
&::phi::MatmulInferMeta>,
kernel::KernelLauncher<
decltype(&::phi::MatmulKernel<float, ::phi::CPUContext>),
&::phi::MatmulKernel<float, ::phi::CPUContext>,
decltype(&::phi::MatmulInferMeta),
&::phi::MatmulInferMeta>(),
"phi_dt.fake_phi_kernel",
std::bind(&KernelLauncherFunc<decltype(&FakePhiKernel),
&FakePhiKernel,
decltype(&FakePhiInferShape),
&FakePhiInferShape>,
KernelLauncher<decltype(&FakePhiKernel),
&FakePhiKernel,
decltype(&FakePhiInferShape),
&FakePhiInferShape>(),
std::placeholders::_1));
}
......
......@@ -45,7 +45,7 @@ void PrintTensor(const DenseHostTensor &tensor) {
}
template <typename T>
void FillTensorWithConstant(DenseHostTensor *tensor, Attribute<T> v) {
void FillTensorWithConstant(Attribute<T> v, DenseHostTensor *tensor) {
MutableDTArrayView<T>(tensor).Fill(v.get());
}
......@@ -53,13 +53,11 @@ TensorMap LoadParams(const std::string &path) {
return *(infrt::tensor::LoadParams(path));
}
void TensorMapGetTensor(TensorMap map,
DenseHostTensor *out,
Attribute<std::string> name) {
DenseHostTensor TensorMapGetTensor(TensorMap map, Attribute<std::string> name) {
auto it = map.find(name.get());
CHECK(it != map.end()) << "No tensor called " << name.get()
<< " in the TensorMap";
*out = *it->second;
return *it->second;
}
int32_t TensorMapGetSize(TensorMap map) { return map.size(); }
......
......@@ -136,12 +136,12 @@ class Variant {
return nullptr;
}
IndexT index() { return index_; }
IndexT index() const { return index_; }
private:
template <typename T>
static constexpr size_t IndexOf = TupleIndexOf<T, Types>::value;
private:
static constexpr size_t kStorageSize = std::max({sizeof(Ts)...});
static constexpr size_t kAlignment = std::max({alignof(Ts)...});
......
// RUN: infrtopt %s | FileCheck %s
// RUN: infrtexec -i %s | FileCheck %s
// CHECK-LABEL: @basic_tensor
func @basic_tensor() {
%a = "phi_dt.create_allocator.cpu" (): () -> !phi.CPU_allocator
%b = "phi_dt.create_context.cpu" (): () -> !phi.CPU_context
%c = "phi_dt.create_dense_tensor.cpu.f32.nchw" (%a) {dims=[1:i64], lod=[1:i64]}: (!phi.CPU_allocator) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
// "phi_dt.fill_dense_tensor.f32" (%c) {value=[1.0:f32]} : (!Infrt.tensor<CPU, FP32, NCHW>) -> ()
// CHECK-LABEL: @fake_phi_kernel_execute
func @fake_phi_kernel_execute() {
%allocator = "phi_dt.create_allocator.cpu" (): () -> !phi.CPU_allocator
%ctx = "phi_dt.create_context.cpu" (): () -> !phi.CPU_context
%t = "phi_dt.create_dense_tensor.cpu.f32.nchw" (%allocator) {dims=[1:i64], lod=[1:i64]}: (!phi.CPU_allocator) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
// CHECK: @FakePhiKernel@
%d = "phi_dt.fake_phi_kernel" (%ctx, %t, %t) {transpose_x=false, transpose_y=false} : (!phi.CPU_context, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
Infrt.return
}