未验证 提交 bdef57cd 编写于 作者: W Wilber 提交者: GitHub

add weight unfold pass and handle trt fc op (#41088)

* add weight unfold pass and handle trt fc op

* update

* add kernel

* update

* update
上级 4b9e748a
...@@ -48,6 +48,10 @@ ...@@ -48,6 +48,10 @@
#include "paddle/infrt/kernel/test_kernels.h" #include "paddle/infrt/kernel/test_kernels.h"
#include "paddle/infrt/tensor/tensor_map.h" #include "paddle/infrt/tensor/tensor_map.h"
#if defined(INFRT_WITH_GPU) && defined(INFRT_WITH_TRT)
#include "paddle/infrt/kernel/tensorrt/registry.h"
#endif
using namespace infrt::host_context; // NOLINT using namespace infrt::host_context; // NOLINT
using namespace infrt::tensor; // NOLINT using namespace infrt::tensor; // NOLINT
using namespace infrt::tensor; // NOLINT using namespace infrt::tensor; // NOLINT
......
...@@ -2,6 +2,7 @@ core_gather_headers() ...@@ -2,6 +2,7 @@ core_gather_headers()
gather_srcs(infrt_src SRCS gather_srcs(infrt_src SRCS
infrt_op_fuse_pass.cc infrt_op_fuse_pass.cc
infrt_weights_unfold_pass.cc
) )
mlir_add_rewriter(infrt_op_fuse) mlir_add_rewriter(infrt_op_fuse)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/infrt/pass/infrt_weights_unfold_pass.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/UseDefLists.h"
#include "mlir/IR/Value.h"
#include "paddle/infrt/dialect/infrt/common/types.h"
#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
#include "paddle/infrt/dialect/phi/ir/phi_base.h"
#include "paddle/infrt/paddle/model_parser.h"
#include "paddle/infrt/tensor/phi/tensor_map.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"
namespace infrt {
namespace kernel {
namespace phi {
// Forward declaration only; the definition is not part of this file.
// Takes a model path and a params path and returns a DenseTensorMap; the pass
// below looks weights up in that map by name via GetDenseTensor().
// NOTE(review): presumably declared here to avoid including the phi kernel
// header from the dialect layer -- confirm against the build dependencies.
::infrt::phi::DenseTensorMap LoadCombinedParameters(
const std::string& model_path, const std::string& params_path);
} // namespace phi
} // namespace kernel
} // namespace infrt
namespace {
// Function-level pass that turns weights fetched from a combined-params
// tensor map into ops carrying the weight values as attributes.
// Exposed to callers only through infrt::CreateInfrtWeightsUnfoldPass().
class InfrtWeightsFoldPass
    : public mlir::PassWrapper<InfrtWeightsFoldPass, mlir::FunctionPass> {
 public:
  // Performs the rewrite on the function currently being processed.
  void runOnFunction() override;

  // Name reported for this pass.
  ::llvm::StringRef getName() const override {
    return "InfrtWeightsFoldPass";
  }
};
void InfrtWeightsFoldPass::runOnFunction() {
mlir::Block& block = getFunction().body().front();
mlir::OpBuilder builder(&block, block.begin());
::llvm::StringRef model_path, params_path;
std::vector<mlir::Operation*> delete_op_list;
// Insert cpu context. If the pass failed, the context op will be removed by
// CanonicalizerPass.
auto context_op = builder.create<infrt::phi::CreateCPUContextOp>(
block.front().getLoc(),
infrt::phi::ContextType::get(builder.getContext(),
infrt::TargetType::CPU));
for (auto& org_op : block) {
if (auto op = llvm::dyn_cast<::infrt::phi::LoadCombinedParamsOp>(org_op)) {
model_path = op.model_path();
params_path = op.params_path();
// Load params.
auto map = ::infrt::kernel::phi::LoadCombinedParameters(
model_path.str(), params_path.str());
bool delete_load_combined_op{false};
// Find all use of map.
for (auto map_arg : op.getODSResults(0)) {
for (mlir::Operation* user_op : map_arg.getUsers()) {
if (auto tensor_map_get_op =
llvm::dyn_cast<::infrt::phi::TensorMapGetTensorOp>(user_op)) {
::llvm::StringRef arg_name = tensor_map_get_op.name();
::phi::DenseTensor* tensor = map.GetDenseTensor(arg_name.str());
if (tensor->dtype() != ::phi::DataType::FLOAT32) {
CHECK(false)
<< "the weight tensor type now only support float32.";
}
builder.setInsertionPoint(tensor_map_get_op);
auto inited_weight_op =
builder.create<::infrt::phi::CreateHostInitedDenseTensorOp>(
tensor_map_get_op.getLoc(),
tensor_map_get_op.output().getType(),
context_op.output(),
builder.getI64ArrayAttr(
{tensor->dims().Get(),
tensor->dims().Get() + tensor->dims().size()}),
::infrt::LayoutAttr::get(builder.getContext(),
::infrt::LayoutType::NCHW),
builder.getI64ArrayAttr({0}),
builder.getF32ArrayAttr(
{tensor->data<float>(),
static_cast<size_t>(tensor->numel())}));
tensor_map_get_op.replaceAllUsesWith(inited_weight_op.output());
delete_load_combined_op = true;
delete_op_list.push_back(tensor_map_get_op);
}
}
}
if (delete_load_combined_op) {
delete_op_list.push_back(op);
}
}
}
// remove all map releation op.
for (size_t i = 0; i < delete_op_list.size(); ++i) {
delete_op_list[i]->erase();
}
}
} // namespace
// Factory for the weights-unfold pass; the concrete pass class is private to
// this translation unit.
std::unique_ptr<mlir::Pass> infrt::CreateInfrtWeightsUnfoldPass() {
  std::unique_ptr<mlir::Pass> pass = std::make_unique<InfrtWeightsFoldPass>();
  return pass;
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mlir/Pass/Pass.h>
namespace infrt {
/*
 * Creates the weights-unfold pass.
 *
 * The returned pass is implemented by InfrtWeightsFoldPass in
 * infrt_weights_unfold_pass.cc. It rewrites weights read from a
 * combined-params tensor map into ops that carry the weight values as
 * attributes.
 */
std::unique_ptr<mlir::Pass> CreateInfrtWeightsUnfoldPass();
} // namespace infrt
...@@ -28,6 +28,17 @@ class CreateDenseTensorOp<string target> ...@@ -28,6 +28,17 @@ class CreateDenseTensorOp<string target>
let results = (outs DenseTensor:$output); let results = (outs DenseTensor:$output);
} }
// phi_dt.create_host_inited_dense_tensor.f32
// Produces a DenseTensor from a context operand plus four attributes:
// `dims`, `layout`, `lod` and the float `values` to store in the tensor.
// Declared NoSideEffect.
// NOTE(review): the attribute order here (dims, layout, lod, values) differs
// from the kernel parameter order (dims, lod, layout, values); the kernel is
// registered with attribute names, so this presumably binds by name -- confirm.
def CreateHostInitedDenseTensorOp : PDT_Op<"create_host_inited_dense_tensor.f32", [NoSideEffect]> {
let arguments = (ins
Context:$context,
I64ArrayAttr:$dims,
LayoutAttr:$layout,
I64ArrayAttr:$lod,
F32ArrayAttr:$values
);
let results = (outs DenseTensor:$output);
}
def CreateInitedCpuFLOAT32DenseTensorOp def CreateInitedCpuFLOAT32DenseTensorOp
: PDT_Op<"create_inited_dense_tensor.cpu.f32", [NoSideEffect]> { : PDT_Op<"create_inited_dense_tensor.cpu.f32", [NoSideEffect]> {
let arguments = (ins Context:$context, I64ArrayAttr:$dims, let arguments = (ins Context:$context, I64ArrayAttr:$dims,
......
...@@ -13,14 +13,19 @@ ...@@ -13,14 +13,19 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <glog/logging.h>
#include <mlir/IR/Builders.h> #include <mlir/IR/Builders.h>
#include <mlir/Transforms/DialectConversion.h> #include <mlir/Transforms/DialectConversion.h>
#include "paddle/infrt/dialect/infrt/common/types.h"
#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h" #include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
#include "paddle/infrt/dialect/pd/ir/pd_ops.h"
#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h" #include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt { namespace infrt {
namespace trt { namespace trt {
static mlir::Value createTRTConv2dOp(mlir::PatternRewriter &rewriter, static mlir::Value createTRTConv2dOp(mlir::PatternRewriter &rewriter, // NOLINT
mlir::Operation *op) { mlir::Operation *op) {
auto conv_op = ::llvm::dyn_cast<infrt::pd::Conv2dOp>(op); auto conv_op = ::llvm::dyn_cast<infrt::pd::Conv2dOp>(op);
::mlir::SmallVector<::mlir::Value, 4> operands; ::mlir::SmallVector<::mlir::Value, 4> operands;
...@@ -75,11 +80,93 @@ static mlir::Value createTRTConv2dOp(mlir::PatternRewriter &rewriter, ...@@ -75,11 +80,93 @@ static mlir::Value createTRTConv2dOp(mlir::PatternRewriter &rewriter,
op->getLoc(), resultTypes, operands, attributes); op->getLoc(), resultTypes, operands, attributes);
} }
static mlir::Value createTRTShuffledOp(mlir::PatternRewriter &rewriter, static inline mlir::ArrayAttr TransposeWeight(
mlir::Operation *op, mlir::PatternRewriter &builder, // NOLINT
const mlir::Value &input, const mlir::ArrayAttr &weight,
const mlir::Attribute &start, const mlir::ArrayAttr &dims) {
const mlir::Attribute &stop) { CHECK_EQ(dims.size(), 2U);
CHECK(!dims.empty());
CHECK(dims[0].getType().isInteger(64));
CHECK(!weight.empty());
CHECK(weight[0].getType().isF32());
int row = dims[0].cast<mlir::IntegerAttr>().getInt();
int col = dims[1].cast<mlir::IntegerAttr>().getInt();
std::vector<float> trans_weight(weight.size());
for (int i = 0; i < row; ++i) {
for (int j = 0; j < col; ++j) {
trans_weight[j * row + i] =
weight[i * col + j].cast<mlir::FloatAttr>().getValueAsDouble();
}
}
return builder.getF32ArrayAttr(trans_weight);
}
// Builds a trt FullyConnectedOp for the matched
// `elementwise_add(matmul_v2(matmul_x, matmul_y), elt_y)` pattern.
//
// matmul_y and elt_y are weights. matmul_y must be produced by a
// CreateHostInitedDenseTensorOp with 2-D dims; that op is replaced by a new one
// whose dims are reversed and whose values are transposed, and the new op
// feeds the FullyConnectedOp.
//
// NOTE(review): the original weight op is replaced for *all* of its users, so
// any other consumer of matmul_y will also see the transposed weight. Confirm
// that weights consumed here are never shared.
//
// Returns the results of the created FullyConnectedOp.
inline ::llvm::SmallVector<::mlir::Value, 4> createTrtFcOp(
    mlir::PatternRewriter &builder,  // NOLINT
    mlir::Value matmul_x,
    mlir::Value matmul_y,
    mlir::Value elt_y,
    mlir::Value elt_out) {
  // getDefiningOp<OpTy>() yields a null op both when matmul_y is a block
  // argument and when it is produced by a different kind of op.
  auto create_inited_tensor_op =
      matmul_y.getDefiningOp<::infrt::phi::CreateHostInitedDenseTensorOp>();
  CHECK(static_cast<bool>(create_inited_tensor_op))
      << "the weight of matmul must be produced by "
         "create_host_inited_dense_tensor op.";

  // Read the location now: the producer is replaced below and must not be
  // touched afterwards.
  const mlir::Location weight_loc = create_inited_tensor_op->getLoc();

  mlir::ArrayAttr dims = create_inited_tensor_op.dims();
  CHECK_EQ(dims.size(), 2U);
  CHECK(dims[0].getType().isIntOrIndex());
  // Reversed dims: [rows, cols] -> [cols, rows].
  std::vector<int64_t> new_dims(dims.size());
  for (size_t i = 0; i < new_dims.size(); ++i) {
    new_dims[i] = dims[dims.size() - 1 - i].cast<mlir::IntegerAttr>().getInt();
  }

  // Create the transposed weight at the position of the original weight op,
  // then restore the rewriter's insertion point for the fc op.
  auto insert_point = builder.saveInsertionPoint();
  builder.setInsertionPoint(create_inited_tensor_op);
  auto new_inited_op =
      builder.create<::infrt::phi::CreateHostInitedDenseTensorOp>(
          weight_loc,
          create_inited_tensor_op.output().getType(),
          create_inited_tensor_op.context(),
          builder.getI64ArrayAttr(new_dims),
          ::infrt::LayoutAttr::get(builder.getContext(),
                                   ::infrt::LayoutType::NCHW),
          create_inited_tensor_op.lod(),
          TransposeWeight(builder, create_inited_tensor_op.values(), dims));
  builder.replaceOp(create_inited_tensor_op, new_inited_op->getResults());
  builder.restoreInsertionPoint(insert_point);

  auto ods_loc = builder.getFusedLoc({weight_loc});
  // The trailing SI32 attribute is taken from the first reversed dim, i.e. the
  // column count of the original weight.
  auto fc_op = builder.create<::infrt::trt::FullyConnectedOp>(
      ods_loc,
      elt_out.getType(),
      matmul_x,
      new_inited_op.output(),
      elt_y,
      builder.getSI32IntegerAttr(new_dims[0]));

  ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;
  for (auto v : fc_op.getODSResults(0)) {
    tblgen_repl_values.push_back(v);
  }
  return tblgen_repl_values;
}
static mlir::Value createTRTShuffledOp(
mlir::PatternRewriter &rewriter, // NOLINT
mlir::Operation *op,
const mlir::Value &input,
const mlir::Attribute &start,
const mlir::Attribute &stop) {
auto flatten_op = ::llvm::dyn_cast<infrt::pd::Flatten_contiguous_rangeOp>(op); auto flatten_op = ::llvm::dyn_cast<infrt::pd::Flatten_contiguous_rangeOp>(op);
::mlir::SmallVector<::mlir::Value, 4> operands; ::mlir::SmallVector<::mlir::Value, 4> operands;
operands.push_back(input); operands.push_back(input);
......
...@@ -43,6 +43,12 @@ def PD2TRT_SoftMax_Lower : Pat< ...@@ -43,6 +43,12 @@ def PD2TRT_SoftMax_Lower : Pat<
(PD_SoftmaxOp $Input, $axis, $_), (PD_SoftmaxOp $Input, $axis, $_),
(TRT_SoftMaxOp $Input, $axis)>; (TRT_SoftMaxOp $Input, $axis)>;
// pd.matmul_v2 + pd.elementwise_add -> trt.fc
// Only matches when neither matmul operand is transposed: createTrtFcOp does
// not receive trans_x/trans_y and always transposes the weight itself, so a
// matmul with either flag set must not be lowered through this pattern.
def createTrtFcOp : NativeCodeCall<"::infrt::trt::createTrtFcOp($_builder, $0, $1, $2, $3)">;
def PD2TRT_Fc_Lower : Pat<
(PD_Elementwise_addOp:$elt_out (PD_Matmul_v2Op $X, $Y, ConstBoolAttrFalse:$trans_x, ConstBoolAttrFalse:$trans_y), $elt_y, $axis),
(createTrtFcOp $X, $Y, $elt_y, $elt_out)>;
def createTRTShuffledOp : NativeCodeCall<"createTRTShuffledOp($_builder, $0.getDefiningOp(), $1, $2, $3)">; def createTRTShuffledOp : NativeCodeCall<"createTRTShuffledOp($_builder, $0.getDefiningOp(), $1, $2, $3)">;
def PD2TRT_Flatten_contiguous_range_Lower : Pat< def PD2TRT_Flatten_contiguous_range_Lower : Pat<
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <iostream> #include <iostream>
#include <string> #include <string>
#include "paddle/infrt/common/global.h" #include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/infrt/pass/infrt_weights_unfold_pass.h"
#include "paddle/infrt/dialect/mlir_loader.h" #include "paddle/infrt/dialect/mlir_loader.h"
#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h" #include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h" #include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h"
...@@ -42,6 +43,8 @@ ...@@ -42,6 +43,8 @@
#include "paddle/infrt/kernel/phi/registry.h" #include "paddle/infrt/kernel/phi/registry.h"
#endif #endif
#include <mlir/Transforms/Passes.h>
int main(int argc, char** argv) { int main(int argc, char** argv) {
static llvm::cl::opt<std::string> input_file( static llvm::cl::opt<std::string> input_file(
llvm::cl::Positional, llvm::cl::Positional,
...@@ -73,11 +76,13 @@ int main(int argc, char** argv) { ...@@ -73,11 +76,13 @@ int main(int argc, char** argv) {
mlir::PassManager pm(context); mlir::PassManager pm(context);
mlir::OpPassManager& trt_pass_manager = pm.nest<mlir::FuncOp>(); mlir::OpPassManager& trt_pass_manager = pm.nest<mlir::FuncOp>();
trt_pass_manager.addPass(::infrt::CreateInfrtWeightsUnfoldPass());
trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTOpTellerPass>()); trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTOpTellerPass>());
trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTGraphFusePass>()); trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTGraphFusePass>());
trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTGraphSplitPass>(1)); trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTGraphSplitPass>(1));
trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTOpConverterPass>()); trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTOpConverterPass>());
trt_pass_manager.addPass(infrt::trt::createTrtTypeConvertPass()); trt_pass_manager.addPass(infrt::trt::createTrtTypeConvertPass());
trt_pass_manager.addPass(::mlir::createCanonicalizerPass());
if (mlir::failed(pm.run(*module))) { if (mlir::failed(pm.run(*module))) {
std::cout << "\npass failed!\n" << std::endl; std::cout << "\npass failed!\n" << std::endl;
return 4; return 4;
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h" #include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
#include <glog/logging.h>
#include <llvm/ADT/SetVector.h> #include <llvm/ADT/SetVector.h>
#include <mlir/Analysis/SliceAnalysis.h> #include <mlir/Analysis/SliceAnalysis.h>
#include <mlir/IR/Builders.h> #include <mlir/IR/Builders.h>
...@@ -133,8 +134,7 @@ void topoSortBlock(mlir::Block &body) { // NOLINT ...@@ -133,8 +134,7 @@ void topoSortBlock(mlir::Block &body) { // NOLINT
for (auto it = body.rbegin(); it != body.rend(); ++it) { for (auto it = body.rbegin(); it != body.rend(); ++it) {
toSort.insert(&*it); toSort.insert(&*it);
} }
llvm::SetVector<mlir::Operation *> result = llvm::SetVector<mlir::Operation *> result = mlir::topologicalSort(toSort);
mlir::topologicalSort(std::move(toSort));
for (auto *op : result) { for (auto *op : result) {
op->moveBefore(body.getTerminator()); op->moveBefore(body.getTerminator());
} }
...@@ -177,7 +177,9 @@ void TRTGraphFusePass::runOnFunction() { ...@@ -177,7 +177,9 @@ void TRTGraphFusePass::runOnFunction() {
if (changed) break; if (changed) break;
} }
} while (changed); } while (changed);
topoSortBlock(body);
// TODO(wilber): Implement a toposort for efficiency.
// topoSortBlock(body);
} }
} // namespace trt } // namespace trt
} // namespace infrt } // namespace infrt
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h" #include "paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h"
#include <glog/logging.h> #include <glog/logging.h>
#include <set>
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
...@@ -57,6 +58,10 @@ void TrtTypeConvertPass::runOnFunction() { ...@@ -57,6 +58,10 @@ void TrtTypeConvertPass::runOnFunction() {
::infrt::LayoutType layout = ::infrt::LayoutType::NCHW; ::infrt::LayoutType layout = ::infrt::LayoutType::NCHW;
::infrt::TargetType target = ::infrt::TargetType::GPU; ::infrt::TargetType target = ::infrt::TargetType::GPU;
const std::set<std::string> inited_op_repr{
"phi_dt.tensor_map_get_tensor",
"phi_dt.create_inited_dense_tensor.cpu.f32",
"phi_dt.create_host_inited_dense_tensor.f32"};
for (auto& op : worklist) { for (auto& op : worklist) {
if (auto tensor_map_get_op = if (auto tensor_map_get_op =
llvm::dyn_cast<::infrt::phi::TensorMapGetTensorOp>(op)) { llvm::dyn_cast<::infrt::phi::TensorMapGetTensorOp>(op)) {
...@@ -66,16 +71,25 @@ void TrtTypeConvertPass::runOnFunction() { ...@@ -66,16 +71,25 @@ void TrtTypeConvertPass::runOnFunction() {
mlir_ctx, t.getTarget(), t.getPrecision(), layout); mlir_ctx, t.getTarget(), t.getPrecision(), layout);
res.setType(replace_type); res.setType(replace_type);
} }
} } else if (auto create_inited_tensor_op =
if (auto create_engine = llvm::dyn_cast<::infrt::trt::CreateEngineOp>(op)) { llvm::dyn_cast<::infrt::phi::CreateHostInitedDenseTensorOp>(
op)) {
auto res = create_inited_tensor_op.output();
if (auto t = res.getType().dyn_cast<::infrt::DenseTensorType>()) {
auto replace_type = ::infrt::DenseTensorType::get(
mlir_ctx, t.getTarget(), t.getPrecision(), layout);
res.setType(replace_type);
}
} else if (auto create_engine =
llvm::dyn_cast<::infrt::trt::CreateEngineOp>(op)) {
// Insert `infrt.gpu.memcpy` op. // Insert `infrt.gpu.memcpy` op.
for (auto arg : create_engine.getOperands()) { for (auto arg : create_engine.getOperands()) {
if (mlir::Operation* producer = arg.getDefiningOp()) { if (mlir::Operation* producer = arg.getDefiningOp()) {
if (arg.getType().isa<::infrt::DenseTensorType>()) { if (arg.getType().isa<::infrt::DenseTensorType>()) {
builder.setInsertionPointAfter(producer); builder.setInsertionPointAfter(producer);
auto t = arg.getType().dyn_cast<::infrt::DenseTensorType>(); auto t = arg.getType().dyn_cast<::infrt::DenseTensorType>();
if (producer->getName().getStringRef() != if (!inited_op_repr.count(
"phi_dt.tensor_map_get_tensor" && producer->getName().getStringRef().str()) &&
t.getTarget() != ::infrt::TargetType::GPU) { t.getTarget() != ::infrt::TargetType::GPU) {
auto replace_type = ::infrt::DenseTensorType::get( auto replace_type = ::infrt::DenseTensorType::get(
mlir_ctx, target, t.getPrecision(), layout); mlir_ctx, target, t.getPrecision(), layout);
...@@ -86,7 +100,7 @@ void TrtTypeConvertPass::runOnFunction() { ...@@ -86,7 +100,7 @@ void TrtTypeConvertPass::runOnFunction() {
arg, arg,
llvm::dyn_cast<::infrt::phi::CreateGPUContextOp>(ctx_op) llvm::dyn_cast<::infrt::phi::CreateGPUContextOp>(ctx_op)
.output(), .output(),
mlir::BoolAttr::get(mlir_ctx, /*d2h*/ false)); builder.getBoolAttr(false));
arg.replaceAllUsesExcept(mem_cpy_op.output(), mem_cpy_op); arg.replaceAllUsesExcept(mem_cpy_op.output(), mem_cpy_op);
} }
} }
...@@ -104,7 +118,7 @@ void TrtTypeConvertPass::runOnFunction() { ...@@ -104,7 +118,7 @@ void TrtTypeConvertPass::runOnFunction() {
blockArg, blockArg,
llvm::dyn_cast<::infrt::phi::CreateGPUContextOp>(ctx_op) llvm::dyn_cast<::infrt::phi::CreateGPUContextOp>(ctx_op)
.output(), .output(),
mlir::BoolAttr::get(mlir_ctx, /*d2h*/ false)); builder.getBoolAttr(false));
arg.replaceAllUsesExcept(mem_cpy_op.output(), mem_cpy_op); arg.replaceAllUsesExcept(mem_cpy_op.output(), mem_cpy_op);
} }
} }
...@@ -147,7 +161,7 @@ void TrtTypeConvertPass::runOnFunction() { ...@@ -147,7 +161,7 @@ void TrtTypeConvertPass::runOnFunction() {
arg, arg,
llvm::dyn_cast<::infrt::phi::CreateGPUContextOp>(ctx_op) llvm::dyn_cast<::infrt::phi::CreateGPUContextOp>(ctx_op)
.output(), .output(),
mlir::BoolAttr::get(mlir_ctx, /*d2h*/ true)); builder.getBoolAttr(true));
arg.replaceAllUsesExcept(mem_cpy_op.output(), mem_cpy_op); arg.replaceAllUsesExcept(mem_cpy_op.output(), mem_cpy_op);
} }
} }
......
...@@ -77,6 +77,27 @@ namespace phi { ...@@ -77,6 +77,27 @@ namespace phi {
return dense_tensor; return dense_tensor;
} }
// Kernel for `phi_dt.create_host_inited_dense_tensor.f32`.
// Builds a float32 DenseTensor with the given dims and layout, using the CPU
// context's allocator, and fills it element by element from `values`.
// The number of values must equal the tensor's element count (CHECK_EQ).
// `lod` is accepted but not read; the tensor meta is built with an empty lod.
::phi::DenseTensor CreateHostInitedDenseTensorF32(
    const ::phi::CPUContext& context,
    host_context::Attribute<std::vector<int64_t>> dims,
    host_context::Attribute<std::vector<int64_t>> lod,
    host_context::Attribute<::infrt::LayoutType> layout,
    host_context::Attribute<std::vector<float>> values) {
  ::phi::DenseTensor dense_tensor(
      const_cast<::phi::Allocator*>(&context.GetAllocator()),
      ::phi::DenseTensorMeta(
          ConvertPrecisionToPhi(::infrt::PrecisionType::FLOAT32),
          ::phi::make_ddim(dims.get()),
          ConvertLayoutToPhi(layout.get()),
          {}));
  CHECK_EQ(dense_tensor.numel(), static_cast<int64_t>(values.get().size()));

  // Copy the host values into the freshly allocated CPU buffer.
  float* dst = dense_tensor.mutable_data<float>(::phi::CPUPlace());
  const int64_t element_count = dense_tensor.numel();
  int64_t idx = 0;
  while (idx < element_count) {
    dst[idx] = values.get()[idx];
    ++idx;
  }
  return dense_tensor;
}
::phi::DenseTensor CreateGPUDenseTensor( ::phi::DenseTensor CreateGPUDenseTensor(
const ::phi::GPUContext& context, const ::phi::GPUContext& context,
host_context::Attribute<std::vector<int64_t>> dims, host_context::Attribute<std::vector<int64_t>> dims,
......
...@@ -39,6 +39,13 @@ namespace phi { ...@@ -39,6 +39,13 @@ namespace phi {
host_context::Attribute<::infrt::LayoutType> layout, host_context::Attribute<::infrt::LayoutType> layout,
host_context::Attribute<float> value); host_context::Attribute<float> value);
// Creates a float32 DenseTensor on the CPU context, shaped by `dims` with the
// given `layout`, and initialised from `values`.
// The definition checks that values.size() equals the tensor's element count.
// `lod` is part of the signature but is not read by the definition.
::phi::DenseTensor CreateHostInitedDenseTensorF32(
const ::phi::CPUContext& context,
host_context::Attribute<std::vector<int64_t>> dims,
host_context::Attribute<std::vector<int64_t>> lod,
host_context::Attribute<::infrt::LayoutType> layout,
host_context::Attribute<std::vector<float>> values);
::phi::DenseTensor CreateGPUDenseTensor( ::phi::DenseTensor CreateGPUDenseTensor(
const ::phi::GPUContext& context, const ::phi::GPUContext& context,
host_context::Attribute<std::vector<int64_t>> dims, host_context::Attribute<std::vector<int64_t>> dims,
......
...@@ -43,9 +43,15 @@ void RegisterPhiKernels(host_context::KernelRegistry* registry) { ...@@ -43,9 +43,15 @@ void RegisterPhiKernels(host_context::KernelRegistry* registry) {
INFRT_KERNEL(infrt::kernel::phi::CreateInitedDenseTensorF32), INFRT_KERNEL(infrt::kernel::phi::CreateInitedDenseTensorF32),
{"dims", "lod", "layout", "value"}); {"dims", "lod", "layout", "value"});
registry->AddKernel(
"phi_dt.create_host_inited_dense_tensor.f32",
INFRT_KERNEL(infrt::kernel::phi::CreateHostInitedDenseTensorF32),
{"dims", "lod", "layout", "values"});
registry->AddKernel("phi_dt.fill_dense_tensor.f32", registry->AddKernel("phi_dt.fill_dense_tensor.f32",
INFRT_KERNEL(infrt::kernel::phi::FillDenseTensorF32), INFRT_KERNEL(infrt::kernel::phi::FillDenseTensorF32),
{"value"}); {"value"});
registry->AddKernel("phi_dt.print_tensor", registry->AddKernel("phi_dt.print_tensor",
INFRT_KERNEL(infrt::kernel::phi::PrintDenseTensor)); INFRT_KERNEL(infrt::kernel::phi::PrintDenseTensor));
......
...@@ -23,7 +23,8 @@ namespace kernel { ...@@ -23,7 +23,8 @@ namespace kernel {
void RegisterTrtKernels(host_context::KernelRegistry* registry) { void RegisterTrtKernels(host_context::KernelRegistry* registry) {
registry->AddKernel("trt.create_engine", registry->AddKernel("trt.create_engine",
INFRT_KERNEL(tensorrt::CreateTrtEngine)); INFRT_KERNEL(tensorrt::CreateTrtEngine),
{"run_once"});
registry->AddKernel("trt.inspect_engine", registry->AddKernel("trt.inspect_engine",
INFRT_KERNEL(tensorrt::PrintTrtLayer)); INFRT_KERNEL(tensorrt::PrintTrtLayer));
registry->AddKernel("trt.compute", INFRT_KERNEL(tensorrt::TrtEngineCompute)); registry->AddKernel("trt.compute", INFRT_KERNEL(tensorrt::TrtEngineCompute));
......
...@@ -108,7 +108,10 @@ namespace tensorrt { ...@@ -108,7 +108,10 @@ namespace tensorrt {
} else { } else {
// TODO(wilber): Replace with the op name that generates the weights. // TODO(wilber): Replace with the op name that generates the weights.
std::unordered_set<std::string> weight_flags{ std::unordered_set<std::string> weight_flags{
"phi_dt.tensor_map_get_tensor", "phi_dt.create_dense_tensor.cpu"}; "phi_dt.tensor_map_get_tensor",
"phi_dt.create_dense_tensor.cpu",
"phi_dt.create_inited_dense_tensor.cpu.f32",
"phi_dt.create_host_inited_dense_tensor.f32"};
if (!weight_flags.count( if (!weight_flags.count(
operand.getDefiningOp()->getName().getStringRef().str())) { operand.getDefiningOp()->getName().getStringRef().str())) {
trt_bind_inputs[input_name] = t; trt_bind_inputs[input_name] = t;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册