Unverified commit 3082ed46, authored by W Wilber, committed by GitHub

Trt engine. (#40532)

* infrt add trt engine

* fix register

* file generate

* fix ci error

* fix conflict

* add copyright

* update

* update

* update

* update engine name

* refactor trt code

* update

* update

* update

* update

* fix conflict

* update

* fix compile with cuda
Parent 46abe798
......@@ -3,12 +3,22 @@ if (NOT WITH_INFRT)
endif()
option(INFRT_WITH_PHI "Compile INFRT with PHI" ON)
option(INFRT_WITH_GPU "Compile INFRT with GPU" OFF)
option(INFRT_WITH_TRT "Compile INFRT with TensorRT" OFF)
#TODO(xiaowei) remove fluid
include_directories(${PADDLE_SOURCE_DIR}/paddle/fluid/platform)
if (INFRT_WITH_PHI)
add_definitions("-DINFRT_WITH_PHI")
add_definitions("-DINFRT_WITH_PHI")
# TODO(wilber): Infrt gpu/trt currently depends on phi's components; modify the compile dependency options later.
if (INFRT_WITH_GPU)
add_definitions("-DINFRT_WITH_GPU")
if (INFRT_WITH_TRT)
add_definitions("-DINFRT_WITH_TRT")
endif()
endif()
endif()
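The TensorRT kernels are only built when all three options are enabled together (see the tensorrt kernel CMake further down); a typical configure line, assuming the option names above, is:
cmake .. -DWITH_INFRT=ON -DINFRT_WITH_PHI=ON -DINFRT_WITH_GPU=ON -DINFRT_WITH_TRT=ON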
# compile flags
......@@ -106,6 +116,9 @@ if (INFRT_WITH_PHI)
endif()
cc_library(infrt SHARED SRCS ${infrt_src} DEPS glog boost ${mlir_libs} ${phi_libs} paddle_framework_proto infrt_naive)
if (INFRT_WITH_TRT)
target_link_libraries(infrt infrt_trt)
endif()
cc_library(infrt_static SRCS ${infrt_src} DEPS glog boost ${mlir_libs} ${phi_libs} paddle_framework_proto)
add_dependencies(infrt ${infrt_mlir_incs} mlir-headers)
......
......@@ -13,6 +13,10 @@ limitations under the License. */
#include "paddle/phi/core/allocator.h"
#ifdef INFRT_WITH_GPU
#include <cuda_runtime.h>
#endif
namespace infrt {
namespace backends {
......@@ -29,5 +33,22 @@ class CpuPhiAllocator : public phi::Allocator {
}
};
#ifdef INFRT_WITH_GPU
// TODO(wilber): Just for the demo test. We need a more efficient gpu allocator.
class GpuPhiAllocator : public phi::Allocator {
public:
static void deleter(phi::Allocation* ptr) { cudaFree(ptr->ptr()); }
AllocationPtr Allocate(size_t bytes_size) {
void* ptr;
cudaMalloc(&ptr, bytes_size);
return AllocationPtr(
new phi::Allocation(
ptr, bytes_size, phi::Place(phi::AllocationType::GPU)),
deleter);
}
};
#endif
} // namespace backends
} // namespace infrt
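A minimal usage sketch of the demo GPU allocator above; the function and sizes are illustrative, assuming the allocator is driven directly rather than through a phi context:

#include "paddle/infrt/backends/host/phi_allocator.h"

// Sketch: allocate device memory through the demo allocator; the deleter
// releases it via cudaFree when the handle goes out of scope.
void DemoGpuAlloc() {
  infrt::backends::GpuPhiAllocator allocator;
  auto allocation = allocator.Allocate(1024 * sizeof(float));  // 1024 floats on device
  float* dev_ptr = static_cast<float*>(allocation->ptr());
  (void)dev_ptr;  // in real code, use via cudaMemcpy or kernel launches
}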
......@@ -13,6 +13,7 @@ limitations under the License. */
#include "paddle/infrt/backends/host/phi_allocator.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace infrt {
namespace backends {
......@@ -31,5 +32,16 @@ class CpuPhiContext : public phi::CPUContext {
std::unique_ptr<phi::Allocator> alloc_{std::make_unique<CpuPhiAllocator>()};
};
class GpuPhiContext : public phi::GPUContext {
public:
using Base = phi::GPUContext;
using phi::GPUContext::SetStream;
using phi::GPUContext::SetEigenDevice;
using phi::GPUContext::SetBlasHandle;
using phi::GPUContext::SetDnnHandle;
using phi::GPUContext::SetSolverHandle;
using phi::GPUContext::SetSparseHandle;
};
} // namespace backends
} // namespace infrt
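A hedged sketch of why the protected setters are re-exposed here: the runtime can wire externally owned handles, such as a CUDA stream, into the context. The stream handling below is illustrative (header path assumed by analogy with phi_allocator.h) and is not how infrt actually initializes the context; see CreateGPUContext further down, which uses PartialInitWithoutAllocator/SetAllocator instead.

#include <cuda_runtime.h>
#include "paddle/infrt/backends/host/phi_context.h"

// Sketch (assumption): attach an externally owned CUDA stream via the re-exposed setter.
void AttachStreamDemo() {
  infrt::backends::GpuPhiContext ctx;
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  ctx.SetStream(stream);  // visible because of `using phi::GPUContext::SetStream`
  // The caller keeps ownership of `stream` and must destroy it after the context is done.
}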
......@@ -37,9 +37,9 @@ namespace infrt {
namespace backends {
namespace tensorrt {
const char* model_input = "model_input";
const char* model_output = "model_output1";
const char* model_output2 = "model_output2";
const char* model_input = "input_0";
const char* model_output = "output_0";
const char* model_output2 = "output_1";
TrtUniquePtr<nvinfer1::INetworkDefinition> ConstructNetwork(
nvinfer1::IBuilder* builder, nvinfer1::Dims dims, bool is_static_shape) {
......@@ -122,27 +122,26 @@ TEST(trt, run_static) {
std::unordered_map<std::string, phi::DenseTensor*> inputs;
inputs.emplace(std::make_pair(model_input, &input));
phi::DenseTensor output, output2;
std::unordered_map<std::string, phi::DenseTensor*> outputs;
outputs.emplace(std::make_pair(model_output, &output));
outputs.emplace(std::make_pair(model_output2, &output2));
static_trt_engine.SetUpInference(inference_options, inputs, &outputs);
static_trt_engine.PrepareOutputHandle("output_0");
static_trt_engine.PrepareOutputHandle("output_1");
static_trt_engine.SetUpInference(inference_options, inputs);
static_trt_engine.GetEngineInfo();
static_trt_engine.Run(context);
phi::DenseTensor* output0 = static_trt_engine.GetOutput("output_0");
phi::DenseTensor* output1 = static_trt_engine.GetOutput("output_1");
std::vector<float> output_data1(inference_options.batch * 1 * 28 * 28, 0);
std::vector<float> output_data2(inference_options.batch * 2 * 28 * 28, 0);
paddle::memory::Copy(phi::CPUPlace(),
output_data1.data(),
place,
output.data<float>(),
output0->data<float>(),
sizeof(float) * output_data1.size(),
context.stream());
paddle::memory::Copy(phi::CPUPlace(),
output_data2.data(),
place,
output2.data<float>(),
output1->data<float>(),
sizeof(float) * output_data2.size(),
context.stream());
cudaStreamSynchronize(context.stream());
......@@ -208,27 +207,27 @@ TEST(trt, run_dynamic) {
context.stream());
std::unordered_map<std::string, phi::DenseTensor*> inputs;
std::unordered_map<std::string, phi::DenseTensor*> outputs;
inputs.emplace(std::make_pair(model_input, &input));
outputs.emplace(std::make_pair(model_output, &output));
outputs.emplace(std::make_pair(model_output2, &output2));
engine.SetUpInference(inference_options, inputs, &outputs);
engine.PrepareOutputHandle("output_0");
engine.PrepareOutputHandle("output_1");
engine.SetUpInference(inference_options, inputs);
engine.GetEngineInfo();
engine.Run(context);
phi::DenseTensor* output0 = engine.GetOutput("output_0");
phi::DenseTensor* output1 = engine.GetOutput("output_1");
std::vector<float> output_data1(inference_options.batch * 1 * 16 * 16, 0);
std::vector<float> output_data2(inference_options.batch * 2 * 16 * 16, 0);
paddle::memory::Copy(phi::CPUPlace(),
output_data1.data(),
place,
output.data<float>(),
output0->data<float>(),
sizeof(float) * output_data1.size(),
context.stream());
paddle::memory::Copy(phi::CPUPlace(),
output_data2.data(),
place,
output2.data<float>(),
output1->data<float>(),
sizeof(float) * output_data2.size(),
context.stream());
cudaStreamSynchronize(context.stream());
......
......@@ -21,6 +21,7 @@
#include "paddle/phi/backends/dynload/tensorrt.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
namespace infrt {
namespace backends {
......@@ -235,10 +236,20 @@ bool TrtEngine::SetupNetworkAndConfig(const BuildOptions& build,
return true;
}
void TrtEngine::PrepareOutputHandle(const std::string& out_name) {
phi::DenseTensor t;
outputs_.emplace(out_name, t);
}
phi::DenseTensor* TrtEngine::GetOutput(const std::string& name) {
return &outputs_[name];
}
size_t TrtEngine::GetOutputNum() const { return outputs_.size(); }
bool TrtEngine::SetUpInference(
const InferenceOptions& inference,
const std::unordered_map<std::string, phi::DenseTensor*>& inputs,
std::unordered_map<std::string, phi::DenseTensor*>* outputs) {
const std::unordered_map<std::string, phi::DenseTensor*>& inputs) {
// TODO(wilber): now only create one exec_context
FreshDeviceId();
CHECK(engine_ != nullptr);
......@@ -252,10 +263,10 @@ bool TrtEngine::SetUpInference(
bindings_.front()->AddBinding(
bind_index, it.first, true, it.second, nvinfer1::DataType::kFLOAT);
}
for (auto& it : *outputs) {
for (auto& it : outputs_) {
const int bind_index = engine_->getBindingIndex(it.first.c_str());
bindings_.front()->AddBinding(
bind_index, it.first, false, it.second, nvinfer1::DataType::kFLOAT);
bind_index, it.first, false, &it.second, nvinfer1::DataType::kFLOAT);
}
return true;
......@@ -290,11 +301,13 @@ void TrtEngine::StaticRun(const phi::GPUContext& ctx) {
const int bind_index = engine_->getBindingIndex(bind.name.c_str());
std::vector<int32_t> ddim;
auto dims = engine_->getBindingDimensions(bind_index);
CHECK_NE(runtime_batch, -1) << "runtime_batch should not be -1.";
ddim.push_back(runtime_batch);
for (int i = 0; i < dims.nbDims; ++i) {
ddim.push_back(dims.d[i]);
}
bind.buffer->Resize(phi::make_ddim(ddim));
// TODO(wilber): now only support float output.
ctx.Alloc<float>(bind.buffer, sizeof(float) * bind.buffer->numel());
buffers[bind_index] = static_cast<void*>(bind.buffer->data<float>());
}
......
......@@ -81,11 +81,17 @@ class TrtEngine {
// TODO(wilber): How to support multiple execution contexts?
bool SetUpInference(
const InferenceOptions& inference,
const std::unordered_map<std::string, phi::DenseTensor*>& inputs,
std::unordered_map<std::string, phi::DenseTensor*>* outputs);
const std::unordered_map<std::string, phi::DenseTensor*>& inputs);
void GetEngineInfo();
void PrepareOutputHandle(const std::string& out_name);
// TODO(wilber): The output tensor names are: output_0, output_1, ...
phi::DenseTensor* GetOutput(const std::string&);
size_t GetOutputNum() const;
private:
void FreshDeviceId();
......@@ -112,6 +118,7 @@ class TrtEngine {
std::vector<std::unique_ptr<Bindings>> bindings_;
int device_id_{0};
bool is_dynamic_shape_{false};
std::unordered_map<std::string, phi::DenseTensor> outputs_;
};
} // namespace tensorrt
......
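The changes above move output ownership into the engine: callers no longer pass an outputs map to SetUpInference, but instead declare output handles up front and fetch them by name after Run. A condensed sketch of the new call sequence, mirroring the updated test earlier (the engine is assumed to be already built):

#include <string>
#include <unordered_map>
#include "paddle/infrt/backends/tensorrt/trt_engine.h"
#include "paddle/infrt/backends/tensorrt/trt_options.h"

// Sketch: drive an already-built engine with the new engine-owned outputs.
void RunWithOwnedOutputs(
    infrt::backends::tensorrt::TrtEngine* engine,
    const infrt::backends::tensorrt::InferenceOptions& options,
    const std::unordered_map<std::string, phi::DenseTensor*>& inputs,
    const phi::GPUContext& context) {
  engine->PrepareOutputHandle("output_0");   // engine allocates and owns this tensor
  engine->SetUpInference(options, inputs);   // outputs are no longer passed in
  engine->GetEngineInfo();
  engine->Run(context);
  phi::DenseTensor* out0 = engine->GetOutput("output_0");
  size_t num_outputs = engine->GetOutputNum();
  (void)out0;
  (void)num_outputs;
}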
......@@ -130,7 +130,7 @@ def TensorMapGetTensorOp : DT_Op<"tensor_map_get_tensor", [NoSideEffect]> {
}
def TensorMapGetSizeOp : DT_Op<"tensor_map_get_size", [NoSideEffect]> {
let summary = "ddt.tensor_map_get_size operation";
let summary = "dt.tensor_map_get_size operation";
let description = [{
An operation that gets the size of a TensorMap.
......@@ -141,6 +141,32 @@ def TensorMapGetSizeOp : DT_Op<"tensor_map_get_size", [NoSideEffect]> {
let assemblyFormat = "`(` $map `)` attr-dict `->` type($size)";
}
def Infrt_TensorListGetTensorOp : DT_Op<"tensor_list_get_tensor", [NoSideEffect]> {
let summary = "dt.tensor_list_get_tensor operation";
let description = [{
An operation that can get a tensor from a TensorList.
}];
let arguments = (ins
DenseTensorList:$l,
I32Attr:$id
);
let results = (outs DenseTensor:$output);
let verifier = ?;
}
def TensorListGetSizeOp : DT_Op<"tensor_list_get_size", [NoSideEffect]> {
let summary = "dt.tensor_list_get_size operation";
let description = [{
An operation that gets the size of a TensorList.
}];
let arguments = (ins DenseTensorList:$map);
let results = (outs I32:$size);
}
def GetTensorShapeOp : DT_Op<"get_tensor_shape", [NoSideEffect]> {
let summary = "dt.get_tensor_shape operation";
......
......@@ -89,6 +89,13 @@ def DenseTensorMap : Infrt_Type<"DenseTensorMap"> {
let parameters = (ins);
}
// TODO(wilber): Add !infrt.vec type.
def DenseTensorList : Infrt_Type<"DenseTensorList"> {
let summary = "infrt dense tensor map";
let description = [{dense_tensor map}];
let parameters = (ins);
}
// Type Constrait for concrete DenseTensor type.
class DenseTensor<string target, string precision, string layout> :
Type<CPred<"$_self == ::infrt::DenseTensorType::get($_self.getContext(), ::infrt::TargetType::"#target#",::infrt::PrecisionType::"#precision#",::infrt::LayoutType::"#layout#")">,
......
......@@ -138,6 +138,10 @@ mlir::Type InfrtDialect::parseType(::mlir::DialectAsmParser &parser) const {
parser.getContext(), *targetType, *precisionType, *layoutType);
}
if (keyword == "tensor_list") {
return infrt::DenseTensorListType::get(parser.getContext());
}
if (keyword == "dense_tensor_map") {
return DenseTensorMapType::get(parser.getContext());
}
......@@ -175,6 +179,9 @@ void InfrtDialect::printType(::mlir::Type type,
return;
}
if (type.isa<infrt::DenseTensorListType>()) {
os << "tensor_list";
}
// print DenseTensorType, for example: !infrt.dense_tensor<CPU, FP32, NCHW>
if (type.isa<DenseTensorMapType>()) {
os << "dense_tensor_map";
......
......@@ -26,6 +26,7 @@
#include "paddle/infrt/dialect/phi/ir/phi_kernels.h"
#include "paddle/infrt/dialect/tensor_shape.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
void registerCinnDialects(mlir::DialectRegistry &registry) { // NOLINT
......@@ -37,7 +38,8 @@ void registerCinnDialects(mlir::DialectRegistry &registry) { // NOLINT
phi::PHIDenseTensorDialect,
phi::PHICPUKernelDialect,
phi::PHIGPUKernelDialect,
phi::PHIDialect
phi::PHIDialect,
infrt::trt::TensorRTDialect
#endif
>();
}
......
......@@ -21,8 +21,8 @@ def PHI_DenseTensorDialect : Dialect {
class PDT_Op<string mnemonic, list<OpTrait> traits = []> : Op<PHI_DenseTensorDialect,
mnemonic, !listconcat(traits, [PhiOpTrait, IsolatedFromAbove])> {}
class CreateDenseTensorOp
: PDT_Op<"create_dense_tensor", [NoSideEffect]> {
class CreateDenseTensorOp<string target>
: PDT_Op<"create_dense_tensor." # target, [NoSideEffect]> {
let arguments = (ins Context:$context, I64ArrayAttr:$dims,
LayoutAttr:$layout, I64ArrayAttr:$lod, PrecisionAttr:$precision);
let results = (outs DenseTensor:$output);
......@@ -51,9 +51,11 @@ class CreateContextOp<string target>
let results = (outs Context:$output);
}
def PDT_CreateDenseTensorOp : CreateDenseTensorOp;
def PDT_CreateCPUDenseTensorOp : CreateDenseTensorOp<"cpu">;
def PDT_CreateGPUDenseTensorOp : CreateDenseTensorOp<"gpu">;
def PDT_FillDenseTensorOp_f32 : FillDenseTensorOp<F32ArrayAttr, "f32">;
def PDT_CreateCPUContextOp : CreateContextOp<"cpu">;
def PDT_CreateGPUContextOp : CreateContextOp<"gpu">;
def PDT_PrintDenseTensor : PrintDenseTensorOp;
def FakeKernelOp : PDT_Op<"fake_phi_kernel"> {
......
......@@ -21,6 +21,10 @@
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/tensorrt/trt_dialect_types.h"
#include "paddle/infrt/dialect/dense_tensor.h"
#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
#include "paddle/infrt/dialect/phi/ir/phi_base.h"
namespace infrt {
namespace trt {
......
......@@ -7,6 +7,8 @@ include "mlir/Interfaces/CallInterfaces.td"
include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/tensorrt/trt_op_base.td"
include "paddle/infrt/dialect/infrt/ir/infrt_base.td"
include "paddle/infrt/dialect/phi/ir/infrt_phi_base.td"
def TRT_CreateEngineOp : TRT_Op<"create_engine", [SingleBlockImplicitTerminator<"::infrt::ReturnOp">]> {
let summary = "trt CreateEngine Op";
......@@ -14,8 +16,8 @@ def TRT_CreateEngineOp : TRT_Op<"create_engine", [SingleBlockImplicitTerminator<
Describe a tensorrt subgraph.
}];
let regions = (region SizedRegion<1>:$body);
let arguments = (ins Variadic<TRT_Tensor>:$inputs, DefaultValuedAttr<BoolAttr, "true">:$run_once);
let results = (outs TRT_EngineType:$output);
let arguments = (ins Variadic<DenseTensor>:$inputs, DefaultValuedAttr<BoolAttr, "true">:$run_once);
let results = (outs TRT_EngineType:$engine);
}
def TRT_ExecuteOp : TRT_Op<"execute", [NoSideEffect]> {
......@@ -23,8 +25,25 @@ def TRT_ExecuteOp : TRT_Op<"execute", [NoSideEffect]> {
let description = [{
Describe a tensorrt runtime.
}];
let arguments = (ins TRT_EngineType:$engine, Variadic<TRT_Tensor>:$inputs);
let results = (outs Variadic<TRT_Tensor>:$output);
let arguments = (ins TRT_EngineType:$engine, Variadic<DenseTensor>:$inputs);
let results = (outs Variadic<DenseTensor>:$output);
}
def TRT_EngineComputeOp : TRT_Op<"compute", [NoSideEffect]> {
let summary = "trt compute engine";
let description = [{
Execute the engine and return its outputs as a tensor list.
}];
let arguments = (ins TRT_EngineType:$engine, Context:$context);
let results = (outs DenseTensorList:$outputs);
}
def TRT_InspectEngineOp : TRT_Op<"inspect_engine", [NoSideEffect]> {
let summary = "trt inspect engine";
let description = [{
Show engine information.
}];
let arguments = (ins TRT_EngineType:$engine);
}
def TRT_ActivationOp : TRT_Op<"Activation", [NoSideEffect]> {
......@@ -34,11 +53,11 @@ def TRT_ActivationOp : TRT_Op<"Activation", [NoSideEffect]> {
TensorRT IActivationLayer.
}];
let arguments = (ins TRT_Tensor:$input, SI32Attr:$activation_type,
let arguments = (ins DenseTensor:$input, SI32Attr:$activation_type,
DefaultValuedAttr<F32Attr, "0.0">:$alpha,
DefaultValuedAttr<F32Attr, "0.0">:$beta);
let results = (outs TRT_Tensor:$output);
let results = (outs DenseTensor:$output);
}
def TRT_ElementWiseOp : TRT_Op<"ElementWise", [NoSideEffect]> {
......@@ -48,9 +67,9 @@ def TRT_ElementWiseOp : TRT_Op<"ElementWise", [NoSideEffect]> {
TensorRT IElementWiseLayer.
}];
let arguments = (ins TRT_Tensor:$input1, TRT_Tensor:$input2, SI32Attr:$elementwise_operation);
let arguments = (ins DenseTensor:$input1, DenseTensor:$input2, SI32Attr:$elementwise_operation);
let results = (outs TRT_Tensor:$output);
let results = (outs DenseTensor:$output);
}
def TRT_MatrixMultiplyOp : TRT_Op<"MatrixMultiply", [NoSideEffect]> {
......@@ -60,10 +79,10 @@ def TRT_MatrixMultiplyOp : TRT_Op<"MatrixMultiply", [NoSideEffect]> {
TensorRT IMatrixMultiplyLayer.
}];
let arguments = (ins TRT_Tensor:$input1, BoolAttr:$transpose1,
TRT_Tensor:$input2, BoolAttr:$transpose2);
let arguments = (ins DenseTensor:$input1, BoolAttr:$transpose1,
DenseTensor:$input2, BoolAttr:$transpose2);
let results = (outs TRT_Tensor:$output);
let results = (outs DenseTensor:$output);
}
#endif // TRT_OPS
......@@ -33,7 +33,10 @@
#include "paddle/infrt/dialect/phi/pass/phi_op_convert_pass.h"
#include "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launchers.h"
#include "paddle/infrt/kernel/phi/registry.h"
#endif
#if defined(INFRT_WITH_GPU) && defined(INFRT_WITH_TRT)
#include "paddle/infrt/kernel/tensorrt/registry.h"
#endif // INFRT_WITH_GPU && INFRT_WITH_TRT
#endif // INFRT_WITH_PHI
static llvm::cl::list<std::string> cl_shared_libs( // NOLINT
"shared_libs",
......@@ -62,6 +65,9 @@ int main(int argc, char** argv) {
#ifdef INFRT_WITH_PHI
kernel::RegisterPhiKernels(&registry);
kernel::RegisterInferShapeLaunchers(&registry);
#if defined(INFRT_WITH_GPU) && defined(INFRT_WITH_TRT)
kernel::RegisterTrtKernels(&registry);
#endif // INFRT_WITH_GPU && INFRT_WITH_TRT
#endif
// load extra shared library
......
......@@ -16,12 +16,14 @@
#include <llvm/Support/SourceMgr.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Diagnostics.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/Parser.h>
#include <glog/logging.h>
#include <iostream>
#include <memory>
#include <string>
......@@ -42,6 +44,13 @@
#include "paddle/infrt/host_context/value.h"
#include "paddle/infrt/tensor/tensor_shape.h"
#ifdef INFRT_WITH_PHI
#ifdef INFRT_WITH_TRT
#include "paddle/infrt/kernel/tensorrt/trt_kernels.h"
#endif
#include "paddle/phi/core/dense_tensor.h"
#endif
namespace infrt {
namespace host_context {
......@@ -277,33 +286,58 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(
impl_->runtime->NewOpExecutable(op->getName().getStringRef().str());
VLOG(3) << "processing general op : " << op->getName().getStringRef().str();
// TODO(wilber): Find a more appropriate way to handle special cases.
if (op->getName().getStringRef() == "trt.create_engine") {
#ifdef INFRT_WITH_TRT
auto* symbols = impl_->runtime->symbol_table();
::infrt::kernel::tensorrt::MlirOperationWithInfrtSymbol mlir_operation;
mlir_operation.operation = op;
mlir_operation.symbol_table = symbols;
impl_->cur_op->AppendArgument(new Value(mlir_operation));
// TODO(wilber): how to pass DenseTensor to create_engine op? Temporarily
// add a naive implementation.
for (int i = 0, e = op->getNumOperands(); i < e; ++i) {
auto operand = op->getOperand(i);
if (operand.isa<mlir::BlockArgument>()) {
mlir::BlockArgument arg = operand.dyn_cast<mlir::BlockArgument>();
Value* arg_value = GetValue(arg);
if (arg_value->is_type<phi::DenseTensor>()) {
impl_->runtime->FeedInArgs(
std::make_pair(std::to_string(i), ValueRef(arg_value)));
}
}
}
#else
CHECK(false) << "should not reach here";
#endif
} else {
// process operands
for (int i = 0, e = op->getNumOperands(); i < e; i++) {
// function argument as value
auto operand = op->getOperand(i);
/// if (operand.getKind() == mlir::Value::Kind::BlockArgument) {
if (operand.isa<mlir::BlockArgument>()) {
mlir::BlockArgument arg = operand.dyn_cast<mlir::BlockArgument>();
Value* arg_value = GetValue(arg);
impl_->cur_op->AppendArgument(arg_value);
VLOG(3) << "* op mlir operand: " << DumpToString(arg) << " "
<< GetValue(arg);
continue;
}
// process operands
for (int i = 0, e = op->getNumOperands(); i < e; i++) {
// function argument as value
auto operand = op->getOperand(i);
/// if (operand.getKind() == mlir::Value::Kind::BlockArgument) {
if (operand.isa<mlir::BlockArgument>()) {
mlir::BlockArgument arg = operand.dyn_cast<mlir::BlockArgument>();
Value* arg_value = GetValue(arg);
// normal value
Value* arg_value = GetValue(operand);
if (!arg_value) {
auto upstream_op = operand.getDefiningOp();
arg_value = GetOpResult(upstream_op);
}
CHECK(arg_value) << "No-exist argument value found: "
<< DumpToString(operand);
impl_->cur_op->AppendArgument(arg_value);
VLOG(3) << "* op mlir operand: " << DumpToString(arg) << " "
<< GetValue(arg);
continue;
}
// normal value
Value* arg_value = GetValue(operand);
if (!arg_value) {
auto upstream_op = operand.getDefiningOp();
arg_value = GetOpResult(upstream_op);
VLOG(3) << "* op mlir operand: " << DumpToString(operand) << " "
<< GetValue(operand) << " vs " << arg_value;
}
CHECK(arg_value) << "No-exist argument value found: "
<< DumpToString(operand);
impl_->cur_op->AppendArgument(arg_value);
VLOG(3) << "* op mlir operand: " << DumpToString(operand) << " "
<< GetValue(operand) << " vs " << arg_value;
}
// process attributes
......@@ -383,33 +417,6 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(
impl_->cur_op->AppendAttribute(tmp[i]);
}
// process results
llvm::SmallVector<Value*, 4> res_values;
for (int i = 0, e = op->getNumResults(); i < e; i++) {
auto res = op->getResult(i);
if (res.getType().isa<::infrt::DenseTensorType>()) {
auto r = impl_->value_map.try_emplace(
res, ValueRef(new Value{::phi::DenseTensor()}));
CHECK(r.second) << "Duplicate add mlir value [" << DumpToString(res)
<< "]";
res_values.push_back(r.first->second.get());
} else {
res_values.push_back(AddValue(res));
}
VLOG(3) << "* op mlir res: " << DumpToString(res) << " " << GetValue(res);
}
impl_->cur_op->SetResults(res_values);
#ifdef INFRT_DEBUG
{
VLOG(3) << "check result";
for (int i = 0; i < impl_->cur_op->frame().GetNumResults(); i++) {
VLOG(3) << "+ res value: " << impl_->cur_op->frame().GetResults()[i];
}
}
#endif
// process regions, we treat regions as attribute.
auto num_regions = op->getNumRegions();
if (num_regions > 0) {
......@@ -438,6 +445,33 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(
impl_->cur_op->AppendAttribute(new Value(function));
}
// process results
llvm::SmallVector<Value*, 4> res_values;
for (int i = 0, e = op->getNumResults(); i < e; i++) {
auto res = op->getResult(i);
if (res.getType().isa<::infrt::DenseTensorType>()) {
auto r = impl_->value_map.try_emplace(
res, ValueRef(new Value{::phi::DenseTensor()}));
CHECK(r.second) << "Duplicate add mlir value [" << DumpToString(res)
<< "]";
res_values.push_back(r.first->second.get());
} else {
res_values.push_back(AddValue(res));
}
VLOG(3) << "* op mlir res: " << DumpToString(res) << " " << GetValue(res);
}
impl_->cur_op->SetResults(res_values);
#ifdef INFRT_DEBUG
{
VLOG(3) << "check result";
for (int i = 0; i < impl_->cur_op->frame().GetNumResults(); i++) {
VLOG(3) << "+ res value: " << impl_->cur_op->frame().GetResults()[i];
}
}
#endif
return true;
}
......
......@@ -24,6 +24,7 @@
#include "paddle/infrt/common/shared.h"
#include "paddle/infrt/dialect/infrt/common/types.h"
#include "paddle/infrt/host_context/function.h"
#include "paddle/infrt/host_context/symbol_table.h"
#include "paddle/infrt/support/variant.h"
#include "paddle/infrt/tensor/dense_host_tensor.h"
#include "paddle/infrt/tensor/dense_tensor_view.h"
......@@ -41,7 +42,15 @@
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/meta_tensor.h"
#endif
#ifdef INFRT_WITH_GPU
#include "paddle/phi/backends/gpu/gpu_context.h"
#endif // INFRT_WITH_GPU
#ifdef INFRT_WITH_TRT
#include "paddle/infrt/backends/tensorrt/trt_engine.h"
#include "paddle/infrt/kernel/tensorrt/trt_kernels.h"
#endif // INFRT_WITH_TRT
#endif // INFRT_WITH_PHI
namespace infrt {
namespace host_context {
......@@ -72,8 +81,13 @@ using ValueVariantType =
::phi::MetaTensor,
::phi::DenseTensor,
backends::CpuPhiContext,
#ifdef INFRT_WITH_GPU
backends::GpuPhiContext,
::phi::GPUContext,
#endif
::phi::CPUContext,
std::vector<const phi::DenseTensor*>,
std::vector<phi::DenseTensor*>,
paddle::experimental::ScalarBase<phi::DenseTensor>,
paddle::experimental::ScalarArrayBase<phi::DenseTensor>,
std::vector<phi::MetaTensor*>,
......@@ -81,6 +95,10 @@ using ValueVariantType =
paddle::experimental::Backend,
paddle::experimental::DataLayout,
paddle::experimental::DataType,
#ifdef INFRT_WITH_TRT
::infrt::backends::tensorrt::TrtEngine,
::infrt::kernel::tensorrt::MlirOperationWithInfrtSymbol,
#endif // INFRT_WITH_TRT
#endif
std::vector<int16_t>,
std::vector<int32_t>,
......@@ -120,8 +138,18 @@ class Value : public common::Object {
#ifdef INFRT_WITH_PHI
explicit Value(::phi::CPUContext&& x) : data(std::move(x)) {}
explicit Value(backends::CpuPhiContext&& x) : data(std::move(x)) {}
#ifdef INFRT_WITH_GPU
explicit Value(::phi::GPUContext&& x) : data(std::move(x)) {}
explicit Value(backends::GpuPhiContext&& x) : data(std::move(x)) {}
#endif
explicit Value(::phi::DenseTensor&& x) : data(std::move(x)) {}
explicit Value(::phi::MetaTensor&& x) : data(std::move(x)) {}
#ifdef INFRT_WITH_TRT
explicit Value(::infrt::backends::tensorrt::TrtEngine&& x)
: data(std::move(x)) {}
explicit Value(::infrt::kernel::tensorrt::MlirOperationWithInfrtSymbol x)
: data(x) {}
#endif // INFRT_WITH_TRT
#endif
template <typename T>
......
add_subdirectory(phi)
add_subdirectory(tensorrt)
core_gather_headers()
......
......@@ -25,6 +25,16 @@ namespace phi {
return ctx;
}
#ifdef INFRT_WITH_GPU
::phi::GPUContext CreateGPUContext() {
::phi::GPUContext context;
context.PartialInitWithoutAllocator();
context.SetAllocator(new ::infrt::backends::GpuPhiAllocator{});
context.PartialInitWithAllocator();
return context;
}
#endif
} // namespace phi
} // namespace kernel
} // namespace infrt
......@@ -25,6 +25,10 @@ namespace phi {
::phi::CPUContext CreateCPUContext();
#ifdef INFRT_WITH_GPU
::phi::GPUContext CreateGPUContext();
#endif
} // namespace phi
} // namespace kernel
} // namespace infrt
......@@ -15,6 +15,12 @@
#include "paddle/infrt/kernel/phi/dense_tensor_kernels.h"
#include "paddle/infrt/dialect/phi/data_type.h"
#include "paddle/infrt/kernel/phi/context_kernels.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/common/place.h"
#ifdef INFRT_WITH_GPU
#include <cuda_runtime.h>
#endif
namespace infrt {
namespace kernel {
......@@ -34,26 +40,83 @@ namespace phi {
{}));
}
::phi::DenseTensor CreateGPUDenseTensor(
const ::phi::GPUContext& context,
host_context::Attribute<std::vector<int64_t>> dims,
host_context::Attribute<std::vector<int64_t>> lod,
host_context::Attribute<::infrt::LayoutType> layout,
host_context::Attribute<::infrt::PrecisionType> precision) {
return ::phi::DenseTensor(
const_cast<::phi::Allocator*>(&context.GetAllocator()),
::phi::DenseTensorMeta(ConvertPrecisionToPhi(precision.get()),
::phi::make_ddim(dims.get()),
ConvertLayoutToPhi(layout.get()),
{}));
}
void FillDenseTensorF32(::phi::DenseTensor* dense_tensor,
host_context::Attribute<std::vector<float>> value) {
auto place = ::phi::CPUPlace();
auto place = dense_tensor->place();
float* a_data = dense_tensor->mutable_data<float>(place);
for (int64_t i = 0; i < dense_tensor->numel(); ++i) {
a_data[i] = (value.get())[i];
if (place.GetType() == ::phi::AllocationType::CPU) {
for (int64_t i = 0; i < dense_tensor->numel(); ++i) {
a_data[i] = (value.get())[i];
}
} else if (place.GetType() == ::phi::AllocationType::GPU) {
#ifdef INFRT_WITH_GPU
// TODO(wilber): how to set the stream parameter so the copy runs on a stream.
cudaMemcpy(a_data,
value.get().data(),
sizeof(float) * value.get().size(),
cudaMemcpyHostToDevice);
#endif
} else {
llvm_unreachable("temporarily not support other target.");
}
}
void PrintDenseTensor(::phi::DenseTensor* dense_tensor) {
#define PRINT_META_DATA(PHI_DATATYPE, DTYPE) \
case ::phi::DataType::PHI_DATATYPE: { \
DTYPE* data = dense_tensor->data<DTYPE>(); \
if (dense_tensor->numel() == 0) break; \
std::cout << data[0]; \
for (int64_t i = 1; i < dense_tensor->numel(); i++) { \
std::cout << "," << data[i]; \
} \
break; \
#ifndef INFRT_WITH_GPU
#define PRINT_META_DATA(PHI_DATATYPE, DTYPE) \
case ::phi::DataType::PHI_DATATYPE: { \
auto place = dense_tensor->place(); \
if (place.GetType() == ::phi::AllocationType::CPU) { \
DTYPE* data = dense_tensor->data<DTYPE>(); \
if (dense_tensor->numel() == 0) break; \
std::cout << data[0]; \
for (int64_t i = 1; i < dense_tensor->numel(); i++) { \
std::cout << "," << data[i]; \
} \
} \
break; \
}
#else
#define PRINT_META_DATA(PHI_DATATYPE, DTYPE) \
case ::phi::DataType::PHI_DATATYPE: { \
auto place = dense_tensor->place(); \
DTYPE* data = dense_tensor->data<DTYPE>(); \
if (dense_tensor->numel() == 0) break; \
if (place.GetType() == ::phi::AllocationType::CPU) { \
std::cout << data[0]; \
for (int64_t i = 1; i < dense_tensor->numel(); i++) { \
std::cout << "," << data[i]; \
} \
} else if (place.GetType() == ::phi::AllocationType::GPU) { \
std::vector<DTYPE> host_data(dense_tensor->numel(), 0); \
cudaMemcpy(host_data.data(), \
data, \
sizeof(DTYPE) * dense_tensor->numel(), \
cudaMemcpyDeviceToHost); \
std::cout << host_data[0]; \
for (int64_t i = 1; i < dense_tensor->numel(); i++) { \
std::cout << "," << host_data[i]; \
} \
} else { \
llvm_unreachable("temporarily not support other target."); \
} \
break; \
}
#endif
::phi::DDim dims = dense_tensor->dims();
std::cout << "dense_tensor: shape=shape" << dims.to_str() << ","
......
......@@ -30,6 +30,13 @@ namespace phi {
host_context::Attribute<::infrt::LayoutType> layout,
host_context::Attribute<::infrt::PrecisionType> precision);
::phi::DenseTensor CreateGPUDenseTensor(
const ::phi::GPUContext& context,
host_context::Attribute<std::vector<int64_t>> dims,
host_context::Attribute<std::vector<int64_t>> lod,
host_context::Attribute<::infrt::LayoutType> layout,
host_context::Attribute<::infrt::PrecisionType> precision);
void FillDenseTensorF32(::phi::DenseTensor* dense_tensor,
host_context::Attribute<std::vector<float>> values);
void PrintDenseTensor(::phi::DenseTensor* dense_tensor);
......
......@@ -35,7 +35,7 @@ void RegisterPhiKernels(host_context::KernelRegistry* registry) {
registry->AddKernel("phi_dt.create_context.cpu",
INFRT_KERNEL(infrt::kernel::phi::CreateCPUContext));
registry->AddKernelWithAttrs(
"phi_dt.create_dense_tensor",
"phi_dt.create_dense_tensor.cpu",
INFRT_KERNEL(infrt::kernel::phi::CreateDenseTensor),
{"dims", "lod", "layout", "precision"});
registry->AddKernelWithAttrs(
......@@ -44,6 +44,15 @@ void RegisterPhiKernels(host_context::KernelRegistry* registry) {
{"value"});
registry->AddKernel("phi_dt.print_tensor",
INFRT_KERNEL(infrt::kernel::phi::PrintDenseTensor));
#ifdef INFRT_WITH_GPU
registry->AddKernel("phi_dt.create_context.gpu",
INFRT_KERNEL(infrt::kernel::phi::CreateGPUContext));
registry->AddKernelWithAttrs(
"phi_dt.create_dense_tensor.gpu",
INFRT_KERNEL(infrt::kernel::phi::CreateGPUDenseTensor),
{"dims", "lod", "layout", "precision"});
#endif
}
} // namespace kernel
......
......@@ -25,6 +25,10 @@
#include "paddle/infrt/tensor/tensor_map.h"
#include "paddle/infrt/tensor/tensor_shape.h"
#ifdef INFRT_WITH_PHI
#include "paddle/phi/core/dense_tensor.h"
#endif
namespace infrt {
namespace kernel {
using namespace host_context; // NOLINT
......@@ -62,6 +66,20 @@ DenseHostTensor TensorMapGetTensor(TensorMap map, Attribute<std::string> name) {
int32_t TensorMapGetSize(TensorMap map) { return map.size(); }
// TODO(wilber): Maybe we should place TensorList type in dt dialect.
#ifdef INFRT_WITH_PHI
phi::DenseTensor TensorListGetTensor(std::vector<phi::DenseTensor *> list,
Attribute<int32_t> idx) {
CHECK_LT(idx.get(), static_cast<int>(list.size()))
<< "idx should less than list size";
return *list[idx.get()];
}
int32_t TensorListGetSize(const std::vector<phi::DenseTensor *> &list) {
return list.size();
}
#endif
DenseHostTensor ShallowCopyTensor(DenseHostTensor v) { return v; }
template <typename T>
......@@ -126,6 +144,14 @@ void RegisterTensorKernels(host_context::KernelRegistry *registry) {
INFRT_KERNEL(TensorMapGetTensor));
registry->AddKernel("dt.tensor_map_get_size", INFRT_KERNEL(TensorMapGetSize));
// TensorList related methods.
#ifdef INFRT_WITH_PHI
registry->AddKernel("dt.tensor_list_get_tensor",
INFRT_KERNEL(TensorListGetTensor));
registry->AddKernel("dt.tensor_list_get_size",
INFRT_KERNEL(TensorListGetSize));
#endif
registry->AddKernel("dt.shallow_copy_tensor",
INFRT_KERNEL(ShallowCopyTensor));
......
if (NOT (INFRT_WITH_PHI AND INFRT_WITH_GPU AND INFRT_WITH_TRT))
return()
endif()
core_gather_headers()
gather_srcs(infrt_src SRCS
registry.cc
trt_kernels.cc
)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/kernel/tensorrt/registry.h"
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/kernel/tensorrt/trt_kernels.h"
namespace infrt {
namespace kernel {
void RegisterTrtKernels(host_context::KernelRegistry* registry) {
registry->AddKernel("trt.create_engine",
INFRT_KERNEL(tensorrt::CreateTrtEngine));
registry->AddKernel("trt.inspect_engine",
INFRT_KERNEL(tensorrt::PrintTrtLayer));
registry->AddKernel("trt.compute", INFRT_KERNEL(tensorrt::TrtEngineCompute));
}
} // namespace kernel
} // namespace infrt
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
namespace infrt {
namespace host_context {
struct KernelRegistry;
} // namespace host_context
} // namespace infrt
namespace infrt {
namespace kernel {
/**
* Register all the trt kernels to the registry.
*/
void RegisterTrtKernels(host_context::KernelRegistry* registry);
} // namespace kernel
} // namespace infrt
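A minimal sketch of wiring these kernels into a registry, as infrtexec does in its main function; the locally constructed registry here is only illustrative:

#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/kernel/tensorrt/registry.h"

// Sketch: after this call, trt.create_engine / trt.inspect_engine / trt.compute
// resolve to the kernels registered in registry.cc below.
void RegisterTrtForDemo() {
  infrt::host_context::KernelRegistry registry;
  infrt::kernel::RegisterTrtKernels(&registry);
}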
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/kernel/tensorrt/trt_kernels.h"
#include <string>
#include "NvInfer.h"
#include "NvInferRuntime.h"
#include "NvInferRuntimeCommon.h"
#include "glog/logging.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "paddle/infrt/backends/tensorrt/trt_engine.h"
#include "paddle/infrt/backends/tensorrt/trt_options.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
#include "paddle/infrt/host_context/symbol_table.h"
#include "paddle/phi/core/dense_tensor.h"
namespace infrt {
namespace kernel {
namespace tensorrt {
::infrt::backends::tensorrt::TrtEngine CreateTrtEngine(
MlirOperationWithInfrtSymbol
create_engine_op /*, input_tensors, output_tensors, weights*/) {
// TODO(wilber): The device_id needs to be obtained from mlir.
int device_id = 0;
backends::tensorrt::TrtEngine engine(device_id);
auto* builder = engine.GetTrtBuilder();
// TODO(wilber): How to process weights?
backends::tensorrt::TrtUniquePtr<nvinfer1::INetworkDefinition> network;
// TODO(wilber): static_shape or dynamic_shape network? The code below is just
// a static_shape test.
network.reset(builder->createNetworkV2(0));
// TODO(wilber): The build options should be filled from mlir info.
backends::tensorrt::BuildOptions options;
options.max_batch = 4;
// Parse mlir Region which only has one block.
mlir::Operation& operation = *create_engine_op.operation;
auto* symbol_table = create_engine_op.symbol_table;
CHECK_NOTNULL(symbol_table);
unsigned int num_regions = operation.getNumRegions();
CHECK_EQ(num_regions, 1U) << "only support one region case.";
auto& region = operation.getRegion(0);
auto& block = region.getBlocks().front();
llvm::DenseMap<mlir::Value, nvinfer1::ITensor*> map_info;
std::unordered_map<std::string, phi::DenseTensor*> trt_bind_inputs;
for (auto index_operand : llvm::enumerate(operation.getOperands())) {
mlir::Value operand = index_operand.value();
size_t idx = index_operand.index();
const std::string input_name = "input_" + std::to_string(idx);
auto* v = symbol_table->GetValue(std::to_string(idx));
CHECK_NOTNULL(v);
auto* t = &v->get<phi::DenseTensor>();
trt_bind_inputs[input_name] = t;
// TODO(wilber): get input info from mlir.
// TODO(wilber): input dims, now only support static_shape, and just remove
// the first dimension.
// TODO(wilber): now only support float input.
nvinfer1::Dims dims;
dims.nbDims = t->dims().size() - 1;
for (int i = 0; i < dims.nbDims; ++i) {
dims.d[i] = t->dims()[i + 1];
}
auto* in =
network->addInput(input_name.c_str(), nvinfer1::DataType::kFLOAT, dims);
map_info[operand] = in;
}
// TODO(wilber): Find a way to add layers.
for (auto& inner_op : block.without_terminator()) {
if (inner_op.getName().getStringRef() == "trt.Activation") {
trt::ActivationOp act_op = llvm::dyn_cast<trt::ActivationOp>(inner_op);
auto in_arg = act_op.getOperand();
if (!map_info.count(in_arg)) {
CHECK(false) << "map_info not has in_arg.";
}
nvinfer1::ActivationType act_type =
static_cast<nvinfer1::ActivationType>(act_op.activation_type());
auto* act_layer = network->addActivation(*map_info[in_arg], act_type);
act_layer->setAlpha(act_op.alpha().convertToFloat());
act_layer->setBeta(act_op.beta().convertToFloat());
for (size_t i = 0; i < act_op->getNumResults(); ++i) {
nvinfer1::ITensor* act_out_tensor = act_layer->getOutput(i);
mlir::Value act_out = act_op->getResult(i);
map_info[act_out] = act_out_tensor;
}
}
// if (inner_op.getName().getStringRef() == "trt.Constant") {
// trt::ConstantOp op = llvm::dyn_cast<trt::ConstantOp>(inner_op);
// mlir::Value op_out = op.getResult();
// std::vector<float> weight_data{1};
// auto* layer = network->addConstant(nvinfer1::Dims2(1, 1),
// nvinfer1::Weights{nvinfer1::DataType::kFLOAT, weight_data.data(), 1});
// auto* op_out_tensor = layer->getOutput(0);
// map_info[op_out] = op_out_tensor;
// }
}
for (auto& inner_op : block.without_terminator()) {
for (mlir::Value v : inner_op.getResults()) {
for (mlir::Operation* user : v.getUsers()) {
if (user->getName().getStringRef() == "infrt.return") {
if (!map_info.count(v)) {
CHECK(false) << "map_info not has value";
}
network->markOutput(*map_info[v]);
}
}
}
}
// std::unordered_map<std::string, phi::DenseTensor*> trt_bind_outputs;
mlir::Operation* ret = block.getTerminator();
for (unsigned int i = 0; i < ret->getNumOperands(); ++i) {
mlir::Value arg = ret->getOperand(i);
CHECK(map_info.count(arg));
map_info[arg]->setName(("output_" + std::to_string(i)).c_str());
}
for (int i = 0; i < network->getNbOutputs(); ++i) {
engine.PrepareOutputHandle(network->getOutput(i)->getName());
}
VLOG(3) << "trt engine build start.";
engine.Build(std::move(network), options);
VLOG(3) << "trt engine build done.";
// TODO(wilber): get inference options from mlir.
backends::tensorrt::InferenceOptions inference_options;
inference_options.batch = 1;
// TODO(wilber): bind trt input/output tensors.
engine.SetUpInference(inference_options, trt_bind_inputs);
return engine;
}
void PrintTrtLayer(backends::tensorrt::TrtEngine* engine) {
engine->GetEngineInfo();
}
std::vector<phi::DenseTensor*> TrtEngineCompute(
backends::tensorrt::TrtEngine* engine, const phi::GPUContext& context) {
engine->Run(context);
std::vector<phi::DenseTensor*> res;
for (size_t i = 0; i < engine->GetOutputNum(); ++i) {
res.push_back(engine->GetOutput("output_" + std::to_string(i)));
}
return res;
}
} // namespace tensorrt
} // namespace kernel
} // namespace infrt
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <tuple>
#include <utility>
#include "mlir/IR/Operation.h"
#include "paddle/infrt/backends/tensorrt/trt_engine.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace infrt {
namespace host_context {
class SymbolTable;
} // namespace host_context
namespace kernel {
namespace tensorrt {
struct MlirOperationWithInfrtSymbol {
mlir::Operation* operation;
::infrt::host_context::SymbolTable* symbol_table;
};
::infrt::backends::tensorrt::TrtEngine CreateTrtEngine(
MlirOperationWithInfrtSymbol engine_op);
void PrintTrtLayer(backends::tensorrt::TrtEngine* engine);
std::vector<phi::DenseTensor*> TrtEngineCompute(
backends::tensorrt::TrtEngine* engine, const phi::GPUContext& context);
} // namespace tensorrt
} // namespace kernel
} // namespace infrt
// RUN: infrtexec -i %s | FileCheck %s
// CHECK-LABEL: @run_trt
func @run_trt(%0 : !infrt.dense_tensor<GPU, FP32, NCHW>, %ctx : !phi.context<GPU>) {
%a = "trt.create_engine"(%0) ({
%1 = "trt.Activation"(%0) {activation_type = 1 : si32, alpha = 1.0 : f32, beta = 6.0 : f32} : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
"infrt.return"(%1) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
}) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> !trt.engine
"trt.inspect_engine"(%a) {} : (!trt.engine) -> ()
%res = "trt.compute"(%a, %ctx) {} : (!trt.engine, !phi.context<GPU>) -> (!infrt.tensor_list)
%size = "dt.tensor_list_get_size"(%res) {} : (!infrt.tensor_list) -> (i32)
"infrt.print.i32"(%size) {} : (i32) -> ()
%ts0 = "dt.tensor_list_get_tensor"(%res) {id = 0 : i32} : (!infrt.tensor_list) -> (!infrt.dense_tensor<GPU, FP32, NCHW>)
"phi_dt.print_tensor" (%ts0) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
infrt.return
}
// CHECK-LABEL: @main
func @main() {
%ctx = "phi_dt.create_context.gpu" (): () -> !phi.context<GPU>
%t = "phi_dt.create_dense_tensor.gpu" (%ctx) {
precision=#infrt.precision<FP32>,
layout=#infrt.layout<NCHW>,
dims=[1:i64, 3:i64, 1:i64, 1:i64], lod=[1:i64]}: (!phi.context<GPU>) -> (!infrt.dense_tensor<GPU, FP32, NCHW>)
"phi_dt.fill_dense_tensor.f32"(%t) {value=[3.8:f32, 2.4:f32, 1.3:f32]} : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
"phi_dt.print_tensor" (%t) : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> ()
//%res =
infrt.call @run_trt(%t, %ctx) : (!infrt.dense_tensor<GPU, FP32, NCHW>, !phi.context<GPU>) -> ()
//-> (!infrt.dense_tensor<GPU, FP32, NCHW>)
infrt.return
}
......@@ -3,7 +3,7 @@
// CHECK-LABEL: @sign_any_float32_execute
func @sign_any_float32_execute() {
%ctx = "phi_dt.create_context.cpu" (): () -> !phi.context<CPU>
%t = "phi_dt.create_dense_tensor" (%ctx) {
%t = "phi_dt.create_dense_tensor.cpu" (%ctx) {
precision=#infrt.precision<FP32>,
layout=#infrt.layout<NCHW>, lod=[1:i64], dims=[1:i64]}: (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
"phi_dt.fill_dense_tensor.f32"(%t) {value=[3.8:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
......
......@@ -6,7 +6,7 @@ module {
}
func @main() {
%ctx = "phi_dt.create_context.cpu" (): () -> !phi.context<CPU>
%t = "phi_dt.create_dense_tensor" (%ctx) {precision=#infrt.precision<FP32>, layout=#infrt.layout<NCHW>, lod=[1:i64], dims=[1:i64]}: (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
%t = "phi_dt.create_dense_tensor.cpu" (%ctx) {precision=#infrt.precision<FP32>, layout=#infrt.layout<NCHW>, lod=[1:i64], dims=[1:i64]}: (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
"phi_dt.fill_dense_tensor.f32"(%t) {value=[3.8:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
%2 = infrt.call@predict(%t) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
phi_dt.print_tensor(%2 : !infrt.dense_tensor<CPU, FP32, NCHW>)
......
// RUN: trt-exec %s
// CHECK-LABEL: @main
func @main(%bias:tensor<?xf32>, %c:tensor<?xf32>, %b1:tensor<?xf32>, %b2:tensor<?xf32>, %bias1:tensor<?xf32>, %bias2:tensor<?xf32>) -> tensor<?xf32> {
%d = "pd.elementwise_add"(%c, %bias) {axis=-1:si32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e = "pd.relu6"(%d) {} : (tensor<?xf32>) -> tensor<?xf32>
func @main(%bias:!infrt.dense_tensor<GPU, FP32, NCHW>, %c:!infrt.dense_tensor<GPU, FP32, NCHW>, %b1:!infrt.dense_tensor<GPU, FP32, NCHW>, %b2:!infrt.dense_tensor<GPU, FP32, NCHW>, %bias1:!infrt.dense_tensor<GPU, FP32, NCHW>, %bias2:!infrt.dense_tensor<GPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW> {
%d = "pd.elementwise_add"(%c, %bias) {axis=-1:si32} : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<GPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
%e = "pd.relu6"(%d) {} : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
%c1 = "pd.matmul"(%e, %b1) {transpose_x=false, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%d1 = "pd.elementwise_add"(%c1, %bias1) {axis=-1:si32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e1 = "pd.relu"(%d1) {} : (tensor<?xf32>) -> tensor<?xf32>
%c1 = "pd.matmul"(%e, %b1) {transpose_x=false, transpose_y=false} : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<GPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
%d1 = "pd.elementwise_add"(%c1, %bias1) {axis=-1:si32} : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<GPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
%e1 = "pd.relu"(%d1) {} : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
%c2 = "pd.matmul"(%e1, %b2) {transpose_x=true, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%d2 = "pd.elementwise_add"(%c2, %bias2) {axis=-1:si32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e2 = "pd.relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32>
%c2 = "pd.matmul"(%e1, %b2) {transpose_x=true, transpose_y=false} : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<GPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
%d2 = "pd.elementwise_add"(%c2, %bias2) {axis=-1:si32} : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<GPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
%e2 = "pd.relu"(%d2) {} : (!infrt.dense_tensor<GPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
infrt.return %e2 : tensor<?xf32>
infrt.return %e2 : !infrt.dense_tensor<GPU, FP32, NCHW>
}
......@@ -741,6 +741,10 @@ struct GPUContext::Impl {
GPUContext::GPUContext() : DeviceContext(), impl_(std::make_unique<Impl>()) {}
GPUContext::GPUContext(GPUContext&&) = default;
GPUContext& GPUContext::operator=(GPUContext&&) = default;
GPUContext::GPUContext(const GPUPlace& place)
: DeviceContext(), impl_(std::make_unique<Impl>(place)) {}
......
......@@ -77,6 +77,8 @@ class DnnWorkspaceHandle {
class GPUContext : public DeviceContext {
public:
GPUContext();
GPUContext(GPUContext&&);
GPUContext& operator=(GPUContext&&);
explicit GPUContext(const GPUPlace& place);
......