From bdef57cd621267ec99340bf9313dfdfef80f42a9 Mon Sep 17 00:00:00 2001 From: Wilber Date: Thu, 31 Mar 2022 10:05:10 +0800 Subject: [PATCH] add weight unfold pass and handle trt fc op (#41088) * add weight unfold pass and handle trt fc op * update * add kernel * update * update --- paddle/infrt/api/infrt_api.cc | 4 + .../infrt/dialect/infrt/pass/CMakeLists.txt | 1 + .../infrt/pass/infrt_weights_unfold_pass.cc | 126 ++++++++++++++++++ .../infrt/pass/infrt_weights_unfold_pass.h | 25 ++++ .../infrt/dialect/phi/ir/infrt_phi_tensor.td | 11 ++ paddle/infrt/dialect/tensorrt/convert.h | 99 +++++++++++++- .../infrt/dialect/tensorrt/pd_lower_to_trt.td | 6 + paddle/infrt/dialect/tensorrt/trt_exec.cc | 5 + .../dialect/tensorrt/trt_graph_fuse_pass.cc | 8 +- .../dialect/tensorrt/trt_type_convert_pass.cc | 28 +++- .../infrt/kernel/phi/dense_tensor_kernels.cc | 21 +++ .../infrt/kernel/phi/dense_tensor_kernels.h | 7 + paddle/infrt/kernel/phi/registry.cc | 6 + paddle/infrt/kernel/tensorrt/registry.cc | 3 +- paddle/infrt/kernel/tensorrt/trt_kernels.cc | 5 +- 15 files changed, 337 insertions(+), 18 deletions(-) create mode 100644 paddle/infrt/dialect/infrt/pass/infrt_weights_unfold_pass.cc create mode 100644 paddle/infrt/dialect/infrt/pass/infrt_weights_unfold_pass.h diff --git a/paddle/infrt/api/infrt_api.cc b/paddle/infrt/api/infrt_api.cc index ca6b6f81e38..2e8b64f768f 100644 --- a/paddle/infrt/api/infrt_api.cc +++ b/paddle/infrt/api/infrt_api.cc @@ -48,6 +48,10 @@ #include "paddle/infrt/kernel/test_kernels.h" #include "paddle/infrt/tensor/tensor_map.h" +#if defined(INFRT_WITH_GPU) && defined(INFRT_WITH_TRT) +#include "paddle/infrt/kernel/tensorrt/registry.h" +#endif + using namespace infrt::host_context; // NOLINT using namespace infrt::tensor; // NOLINT using namespace infrt::tensor; // NOLINT diff --git a/paddle/infrt/dialect/infrt/pass/CMakeLists.txt b/paddle/infrt/dialect/infrt/pass/CMakeLists.txt index 19c12251a2e..ab06c00d143 100644 --- 
a/paddle/infrt/dialect/infrt/pass/CMakeLists.txt +++ b/paddle/infrt/dialect/infrt/pass/CMakeLists.txt @@ -2,6 +2,7 @@ core_gather_headers() gather_srcs(infrt_src SRCS infrt_op_fuse_pass.cc + infrt_weights_unfold_pass.cc ) mlir_add_rewriter(infrt_op_fuse) diff --git a/paddle/infrt/dialect/infrt/pass/infrt_weights_unfold_pass.cc b/paddle/infrt/dialect/infrt/pass/infrt_weights_unfold_pass.cc new file mode 100644 index 00000000000..6a9f828dc95 --- /dev/null +++ b/paddle/infrt/dialect/infrt/pass/infrt_weights_unfold_pass.cc @@ -0,0 +1,126 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/infrt/dialect/infrt/pass/infrt_weights_unfold_pass.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/UseDefLists.h" +#include "mlir/IR/Value.h" +#include "paddle/infrt/dialect/infrt/common/types.h" +#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h" +#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h" +#include "paddle/infrt/dialect/phi/ir/phi_base.h" +#include "paddle/infrt/paddle/model_parser.h" +#include "paddle/infrt/tensor/phi/tensor_map.h" +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace infrt { +namespace kernel { +namespace phi { +::infrt::phi::DenseTensorMap LoadCombinedParameters( + const std::string& model_path, const std::string& params_path); +} // namespace phi +} // namespace kernel +} // namespace infrt + +namespace { + +class InfrtWeightsFoldPass + : public mlir::PassWrapper { + public: + ::llvm::StringRef getName() const override { return "InfrtWeightsFoldPass"; } + + void runOnFunction() override; +}; + +void InfrtWeightsFoldPass::runOnFunction() { + mlir::Block& block = getFunction().body().front(); + mlir::OpBuilder builder(&block, block.begin()); + + ::llvm::StringRef model_path, params_path; + std::vector delete_op_list; + // Insert cpu context. If the pass failed, the context op will be removed by + // CanonicalizerPass. + auto context_op = builder.create( + block.front().getLoc(), + infrt::phi::ContextType::get(builder.getContext(), + infrt::TargetType::CPU)); + + for (auto& org_op : block) { + if (auto op = llvm::dyn_cast<::infrt::phi::LoadCombinedParamsOp>(org_op)) { + model_path = op.model_path(); + params_path = op.params_path(); + + // Load params. 
+ auto map = ::infrt::kernel::phi::LoadCombinedParameters( + model_path.str(), params_path.str()); + bool delete_load_combined_op{false}; + // Find all uses of map. + for (auto map_arg : op.getODSResults(0)) { + for (mlir::Operation* user_op : map_arg.getUsers()) { + if (auto tensor_map_get_op = + llvm::dyn_cast<::infrt::phi::TensorMapGetTensorOp>(user_op)) { + ::llvm::StringRef arg_name = tensor_map_get_op.name(); + ::phi::DenseTensor* tensor = map.GetDenseTensor(arg_name.str()); + if (tensor->dtype() != ::phi::DataType::FLOAT32) { + CHECK(false) + << "the weight tensor type now only support float32."; + } + + builder.setInsertionPoint(tensor_map_get_op); + auto inited_weight_op = + builder.create<::infrt::phi::CreateHostInitedDenseTensorOp>( + tensor_map_get_op.getLoc(), + tensor_map_get_op.output().getType(), + context_op.output(), + builder.getI64ArrayAttr( + {tensor->dims().Get(), + tensor->dims().Get() + tensor->dims().size()}), + ::infrt::LayoutAttr::get(builder.getContext(), + ::infrt::LayoutType::NCHW), + builder.getI64ArrayAttr({0}), + builder.getF32ArrayAttr( + {tensor->data(), + static_cast(tensor->numel())})); + tensor_map_get_op.replaceAllUsesWith(inited_weight_op.output()); + delete_load_combined_op = true; + delete_op_list.push_back(tensor_map_get_op); + } + } + } + if (delete_load_combined_op) { + delete_op_list.push_back(op); + } + } + } + + // Remove all map-related ops. + for (size_t i = 0; i < delete_op_list.size(); ++i) { + delete_op_list[i]->erase(); + } +} + +} // namespace + +std::unique_ptr infrt::CreateInfrtWeightsUnfoldPass() { + return std::make_unique(); +} diff --git a/paddle/infrt/dialect/infrt/pass/infrt_weights_unfold_pass.h b/paddle/infrt/dialect/infrt/pass/infrt_weights_unfold_pass.h new file mode 100644 index 00000000000..09effe54e69 --- /dev/null +++ b/paddle/infrt/dialect/infrt/pass/infrt_weights_unfold_pass.h @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace infrt { +/* + * InfrtWeightsFoldPass. + */ +std::unique_ptr CreateInfrtWeightsUnfoldPass(); + +} // namespace infrt diff --git a/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td b/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td index dbf3b853307..c4707c367bc 100644 --- a/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td +++ b/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td @@ -28,6 +28,17 @@ class CreateDenseTensorOp let results = (outs DenseTensor:$output); } +def CreateHostInitedDenseTensorOp : PDT_Op<"create_host_inited_dense_tensor.f32", [NoSideEffect]> { + let arguments = (ins + Context:$context, + I64ArrayAttr:$dims, + LayoutAttr:$layout, + I64ArrayAttr:$lod, + F32ArrayAttr:$values + ); + let results = (outs DenseTensor:$output); +} + def CreateInitedCpuFLOAT32DenseTensorOp : PDT_Op<"create_inited_dense_tensor.cpu.f32", [NoSideEffect]> { let arguments = (ins Context:$context, I64ArrayAttr:$dims, diff --git a/paddle/infrt/dialect/tensorrt/convert.h b/paddle/infrt/dialect/tensorrt/convert.h index 1890c839eff..fc607aa1127 100644 --- a/paddle/infrt/dialect/tensorrt/convert.h +++ b/paddle/infrt/dialect/tensorrt/convert.h @@ -13,14 +13,19 @@ // limitations under the License. 
#pragma once +#include #include #include + +#include "paddle/infrt/dialect/infrt/common/types.h" #include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h" +#include "paddle/infrt/dialect/pd/ir/pd_ops.h" +#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h" #include "paddle/infrt/dialect/tensorrt/trt_ops.h" namespace infrt { namespace trt { -static mlir::Value createTRTConv2dOp(mlir::PatternRewriter &rewriter, +static mlir::Value createTRTConv2dOp(mlir::PatternRewriter &rewriter, // NOLINT mlir::Operation *op) { auto conv_op = ::llvm::dyn_cast(op); ::mlir::SmallVector<::mlir::Value, 4> operands; @@ -75,11 +80,93 @@ static mlir::Value createTRTConv2dOp(mlir::PatternRewriter &rewriter, op->getLoc(), resultTypes, operands, attributes); } -static mlir::Value createTRTShuffledOp(mlir::PatternRewriter &rewriter, - mlir::Operation *op, - const mlir::Value &input, - const mlir::Attribute &start, - const mlir::Attribute &stop) { +static inline mlir::ArrayAttr TransposeWeight( + mlir::PatternRewriter &builder, // NOLINT + const mlir::ArrayAttr &weight, + const mlir::ArrayAttr &dims) { + CHECK_EQ(dims.size(), 2U); + CHECK(!dims.empty()); + CHECK(dims[0].getType().isInteger(64)); + CHECK(!weight.empty()); + CHECK(weight[0].getType().isF32()); + + int row = dims[0].cast().getInt(); + int col = dims[1].cast().getInt(); + std::vector trans_weight(weight.size()); + for (int i = 0; i < row; ++i) { + for (int j = 0; j < col; ++j) { + trans_weight[j * row + i] = + weight[i * col + j].cast().getValueAsDouble(); + } + } + return builder.getF32ArrayAttr(trans_weight); +} + +// matmul_y and elt_y are weights. 
+inline ::llvm::SmallVector<::mlir::Value, 4> createTrtFcOp( + mlir::PatternRewriter &builder, // NOLINT + mlir::Value matmul_x, + mlir::Value matmul_y, + mlir::Value elt_y, + mlir::Value elt_out) { + ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops; + + auto *y_producer = matmul_y.getDefiningOp(); + auto create_inited_tensor_op = + llvm::dyn_cast<::infrt::phi::CreateHostInitedDenseTensorOp>(y_producer); + CHECK_NOTNULL(create_inited_tensor_op); + + mlir::ArrayAttr dims = create_inited_tensor_op.dims(); + CHECK_EQ(dims.size(), 2U); + + std::vector new_dims(dims.size()); + CHECK(!dims.empty()); + CHECK(dims[0].getType().isIntOrIndex()); + for (size_t i = 0; i < new_dims.size(); ++i) { + new_dims[i] = dims[dims.size() - 1 - i].cast().getInt(); + } + auto insert_point = builder.saveInsertionPoint(); + builder.setInsertionPoint(create_inited_tensor_op); + auto new_inited_op = + builder.create<::infrt::phi::CreateHostInitedDenseTensorOp>( + create_inited_tensor_op->getLoc(), + create_inited_tensor_op.output().getType(), + create_inited_tensor_op.context(), + builder.getI64ArrayAttr(new_dims), + ::infrt::LayoutAttr::get(builder.getContext(), + ::infrt::LayoutType::NCHW), + create_inited_tensor_op.lod(), + TransposeWeight(builder, create_inited_tensor_op.values(), dims)); + builder.replaceOp(create_inited_tensor_op, new_inited_op->getResults()); + builder.restoreInsertionPoint(insert_point); + + auto ods_loc = builder.getFusedLoc({y_producer->getLoc()}); + ::infrt::trt::FullyConnectedOp fc_op; + { + ::mlir::SmallVector<::mlir::Type, 4> tblgen_types; + + fc_op = builder.create<::infrt::trt::FullyConnectedOp>( + ods_loc, + elt_out.getType(), + matmul_x, + new_inited_op.output(), + elt_y, + builder.getSI32IntegerAttr(new_dims[0])); + } + + ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; + for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{fc_op.getODSResults(0)}) { + tblgen_repl_values.push_back(v); + } + return tblgen_repl_values; +} + +static mlir::Value 
createTRTShuffledOp( + mlir::PatternRewriter &rewriter, // NOLINT + mlir::Operation *op, + const mlir::Value &input, + const mlir::Attribute &start, + const mlir::Attribute &stop) { auto flatten_op = ::llvm::dyn_cast(op); ::mlir::SmallVector<::mlir::Value, 4> operands; operands.push_back(input); diff --git a/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td b/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td index b153e84b53f..ad60906ecec 100644 --- a/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td +++ b/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td @@ -43,6 +43,12 @@ def PD2TRT_SoftMax_Lower : Pat< (PD_SoftmaxOp $Input, $axis, $_), (TRT_SoftMaxOp $Input, $axis)>; +// pd.matmul_v2 + pd.elementwise_add -> trt.fc +def createTrtFcOp : NativeCodeCall<"::infrt::trt::createTrtFcOp($_builder, $0, $1, $2, $3)">; +def PD2TRT_Fc_Lower : Pat< + (PD_Elementwise_addOp:$elt_out (PD_Matmul_v2Op $X, $Y, $trans_x, $trans_y), $elt_y, $axis), + (createTrtFcOp $X, $Y, $elt_y, $elt_out)>; + def createTRTShuffledOp : NativeCodeCall<"createTRTShuffledOp($_builder, $0.getDefiningOp(), $1, $2, $3)">; def PD2TRT_Flatten_contiguous_range_Lower : Pat< diff --git a/paddle/infrt/dialect/tensorrt/trt_exec.cc b/paddle/infrt/dialect/tensorrt/trt_exec.cc index be239255ffb..b37186ada6d 100644 --- a/paddle/infrt/dialect/tensorrt/trt_exec.cc +++ b/paddle/infrt/dialect/tensorrt/trt_exec.cc @@ -16,6 +16,7 @@ #include #include #include "paddle/infrt/common/global.h" +#include "paddle/infrt/dialect/infrt/pass/infrt_weights_unfold_pass.h" #include "paddle/infrt/dialect/mlir_loader.h" #include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h" #include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h" @@ -42,6 +43,8 @@ #include "paddle/infrt/kernel/phi/registry.h" #endif +#include + int main(int argc, char** argv) { static llvm::cl::opt input_file( llvm::cl::Positional, @@ -73,11 +76,13 @@ int main(int argc, char** argv) { mlir::PassManager pm(context); mlir::OpPassManager& trt_pass_manager = 
pm.nest(); + trt_pass_manager.addPass(::infrt::CreateInfrtWeightsUnfoldPass()); trt_pass_manager.addPass(std::make_unique()); trt_pass_manager.addPass(std::make_unique()); trt_pass_manager.addPass(std::make_unique(1)); trt_pass_manager.addPass(std::make_unique()); trt_pass_manager.addPass(infrt::trt::createTrtTypeConvertPass()); + trt_pass_manager.addPass(::mlir::createCanonicalizerPass()); if (mlir::failed(pm.run(*module))) { std::cout << "\npass failed!\n" << std::endl; return 4; diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc index c575d05949a..55964b77e21 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc @@ -14,6 +14,7 @@ #include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h" +#include #include #include #include @@ -133,8 +134,7 @@ void topoSortBlock(mlir::Block &body) { // NOLINT for (auto it = body.rbegin(); it != body.rend(); ++it) { toSort.insert(&*it); } - llvm::SetVector result = - mlir::topologicalSort(std::move(toSort)); + llvm::SetVector result = mlir::topologicalSort(toSort); for (auto *op : result) { op->moveBefore(body.getTerminator()); } @@ -177,7 +177,9 @@ void TRTGraphFusePass::runOnFunction() { if (changed) break; } } while (changed); - topoSortBlock(body); + + // TODO(wilber): Implement a toposort for efficiency. 
+ // topoSortBlock(body); } } // namespace trt } // namespace infrt diff --git a/paddle/infrt/dialect/tensorrt/trt_type_convert_pass.cc b/paddle/infrt/dialect/tensorrt/trt_type_convert_pass.cc index cd55fef696a..0ed79c79db6 100644 --- a/paddle/infrt/dialect/tensorrt/trt_type_convert_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_type_convert_pass.cc @@ -15,6 +15,7 @@ #include "paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h" #include +#include #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" @@ -57,6 +58,10 @@ void TrtTypeConvertPass::runOnFunction() { ::infrt::LayoutType layout = ::infrt::LayoutType::NCHW; ::infrt::TargetType target = ::infrt::TargetType::GPU; + const std::set inited_op_repr{ + "phi_dt.tensor_map_get_tensor", + "phi_dt.create_inited_dense_tensor.cpu.f32", + "phi_dt.create_host_inited_dense_tensor.f32"}; for (auto& op : worklist) { if (auto tensor_map_get_op = llvm::dyn_cast<::infrt::phi::TensorMapGetTensorOp>(op)) { @@ -66,16 +71,25 @@ void TrtTypeConvertPass::runOnFunction() { mlir_ctx, t.getTarget(), t.getPrecision(), layout); res.setType(replace_type); } - } - if (auto create_engine = llvm::dyn_cast<::infrt::trt::CreateEngineOp>(op)) { + } else if (auto create_inited_tensor_op = + llvm::dyn_cast<::infrt::phi::CreateHostInitedDenseTensorOp>( + op)) { + auto res = create_inited_tensor_op.output(); + if (auto t = res.getType().dyn_cast<::infrt::DenseTensorType>()) { + auto replace_type = ::infrt::DenseTensorType::get( + mlir_ctx, t.getTarget(), t.getPrecision(), layout); + res.setType(replace_type); + } + } else if (auto create_engine = + llvm::dyn_cast<::infrt::trt::CreateEngineOp>(op)) { // Insert `infrt.gpu.memcpy` op. 
for (auto arg : create_engine.getOperands()) { if (mlir::Operation* producer = arg.getDefiningOp()) { if (arg.getType().isa<::infrt::DenseTensorType>()) { builder.setInsertionPointAfter(producer); auto t = arg.getType().dyn_cast<::infrt::DenseTensorType>(); - if (producer->getName().getStringRef() != - "phi_dt.tensor_map_get_tensor" && + if (!inited_op_repr.count( + producer->getName().getStringRef().str()) && t.getTarget() != ::infrt::TargetType::GPU) { auto replace_type = ::infrt::DenseTensorType::get( mlir_ctx, target, t.getPrecision(), layout); @@ -86,7 +100,7 @@ void TrtTypeConvertPass::runOnFunction() { arg, llvm::dyn_cast<::infrt::phi::CreateGPUContextOp>(ctx_op) .output(), - mlir::BoolAttr::get(mlir_ctx, /*d2h*/ false)); + builder.getBoolAttr(false)); arg.replaceAllUsesExcept(mem_cpy_op.output(), mem_cpy_op); } } @@ -104,7 +118,7 @@ void TrtTypeConvertPass::runOnFunction() { blockArg, llvm::dyn_cast<::infrt::phi::CreateGPUContextOp>(ctx_op) .output(), - mlir::BoolAttr::get(mlir_ctx, /*d2h*/ false)); + builder.getBoolAttr(false)); arg.replaceAllUsesExcept(mem_cpy_op.output(), mem_cpy_op); } } @@ -147,7 +161,7 @@ void TrtTypeConvertPass::runOnFunction() { arg, llvm::dyn_cast<::infrt::phi::CreateGPUContextOp>(ctx_op) .output(), - mlir::BoolAttr::get(mlir_ctx, /*d2h*/ true)); + builder.getBoolAttr(true)); arg.replaceAllUsesExcept(mem_cpy_op.output(), mem_cpy_op); } } diff --git a/paddle/infrt/kernel/phi/dense_tensor_kernels.cc b/paddle/infrt/kernel/phi/dense_tensor_kernels.cc index 844db8aecb2..fe1cda0e100 100644 --- a/paddle/infrt/kernel/phi/dense_tensor_kernels.cc +++ b/paddle/infrt/kernel/phi/dense_tensor_kernels.cc @@ -77,6 +77,27 @@ namespace phi { return dense_tensor; } +::phi::DenseTensor CreateHostInitedDenseTensorF32( + const ::phi::CPUContext& context, + host_context::Attribute> dims, + host_context::Attribute> lod, + host_context::Attribute<::infrt::LayoutType> layout, + host_context::Attribute> values) { + ::phi::DenseTensor dense_tensor( + 
const_cast<::phi::Allocator*>(&context.GetAllocator()), + ::phi::DenseTensorMeta( + ConvertPrecisionToPhi(::infrt::PrecisionType::FLOAT32), + ::phi::make_ddim(dims.get()), + ConvertLayoutToPhi(layout.get()), + {})); + CHECK_EQ(dense_tensor.numel(), static_cast(values.get().size())); + float* data = dense_tensor.mutable_data(::phi::CPUPlace()); + for (int64_t i = 0; i < dense_tensor.numel(); ++i) { + data[i] = values.get()[i]; + } + return dense_tensor; +} + ::phi::DenseTensor CreateGPUDenseTensor( const ::phi::GPUContext& context, host_context::Attribute> dims, diff --git a/paddle/infrt/kernel/phi/dense_tensor_kernels.h b/paddle/infrt/kernel/phi/dense_tensor_kernels.h index 60cc63a928f..b1075444731 100644 --- a/paddle/infrt/kernel/phi/dense_tensor_kernels.h +++ b/paddle/infrt/kernel/phi/dense_tensor_kernels.h @@ -39,6 +39,13 @@ namespace phi { host_context::Attribute<::infrt::LayoutType> layout, host_context::Attribute value); +::phi::DenseTensor CreateHostInitedDenseTensorF32( + const ::phi::CPUContext& context, + host_context::Attribute> dims, + host_context::Attribute> lod, + host_context::Attribute<::infrt::LayoutType> layout, + host_context::Attribute> values); + ::phi::DenseTensor CreateGPUDenseTensor( const ::phi::GPUContext& context, host_context::Attribute> dims, diff --git a/paddle/infrt/kernel/phi/registry.cc b/paddle/infrt/kernel/phi/registry.cc index 04778811250..928209ab182 100644 --- a/paddle/infrt/kernel/phi/registry.cc +++ b/paddle/infrt/kernel/phi/registry.cc @@ -43,9 +43,15 @@ void RegisterPhiKernels(host_context::KernelRegistry* registry) { INFRT_KERNEL(infrt::kernel::phi::CreateInitedDenseTensorF32), {"dims", "lod", "layout", "value"}); + registry->AddKernel( + "phi_dt.create_host_inited_dense_tensor.f32", + INFRT_KERNEL(infrt::kernel::phi::CreateHostInitedDenseTensorF32), + {"dims", "lod", "layout", "values"}); + registry->AddKernel("phi_dt.fill_dense_tensor.f32", INFRT_KERNEL(infrt::kernel::phi::FillDenseTensorF32), {"value"}); + 
registry->AddKernel("phi_dt.print_tensor", INFRT_KERNEL(infrt::kernel::phi::PrintDenseTensor)); diff --git a/paddle/infrt/kernel/tensorrt/registry.cc b/paddle/infrt/kernel/tensorrt/registry.cc index a37e3c0f7f2..197eb1ecb8a 100644 --- a/paddle/infrt/kernel/tensorrt/registry.cc +++ b/paddle/infrt/kernel/tensorrt/registry.cc @@ -23,7 +23,8 @@ namespace kernel { void RegisterTrtKernels(host_context::KernelRegistry* registry) { registry->AddKernel("trt.create_engine", - INFRT_KERNEL(tensorrt::CreateTrtEngine)); + INFRT_KERNEL(tensorrt::CreateTrtEngine), + {"run_once"}); registry->AddKernel("trt.inspect_engine", INFRT_KERNEL(tensorrt::PrintTrtLayer)); registry->AddKernel("trt.compute", INFRT_KERNEL(tensorrt::TrtEngineCompute)); diff --git a/paddle/infrt/kernel/tensorrt/trt_kernels.cc b/paddle/infrt/kernel/tensorrt/trt_kernels.cc index 2f73c6b13f4..a6d740f0184 100644 --- a/paddle/infrt/kernel/tensorrt/trt_kernels.cc +++ b/paddle/infrt/kernel/tensorrt/trt_kernels.cc @@ -108,7 +108,10 @@ namespace tensorrt { } else { // TODO(wilber): Replace with the op name that generates the weights. std::unordered_set weight_flags{ - "phi_dt.tensor_map_get_tensor", "phi_dt.create_dense_tensor.cpu"}; + "phi_dt.tensor_map_get_tensor", + "phi_dt.create_dense_tensor.cpu", + "phi_dt.create_inited_dense_tensor.cpu.f32", + "phi_dt.create_host_inited_dense_tensor.f32"}; if (!weight_flags.count( operand.getDefiningOp()->getName().getStringRef().str())) { trt_bind_inputs[input_name] = t; -- GitLab