From 9f0958fa9a868f943f7997416b4ae077682cb5d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Tue, 4 Jan 2022 13:03:12 +0800 Subject: [PATCH] [infrt] add trt_graph_split_pass for infrt. test=develop (#38494) --- paddle/infrt/CMakeLists.txt | 8 ++ paddle/infrt/dialect/mlir_loader.cc | 4 +- .../infrt/dialect/mlir_tests/paddle_ops.mlir | 4 +- paddle/infrt/dialect/mlir_tests/rewrite.mlir | 14 ++-- .../dialect/mlir_tests/rewrite_conv_bn.mlir | 2 +- paddle/infrt/dialect/mlir_tests/trt_ops.mlir | 14 ++-- paddle/infrt/dialect/pd_ops.td | 4 +- paddle/infrt/dialect/tensorrt/CMakeLists.txt | 4 + paddle/infrt/dialect/tensorrt/trt_exec.cc | 48 ++++++++++++ .../dialect/tensorrt/trt_graph_fuse_pass.cc | 77 +++++++++---------- .../dialect/tensorrt/trt_graph_fuse_pass.h | 5 +- .../dialect/tensorrt/trt_graph_split_pass.cc | 50 ++++++++++++ .../dialect/tensorrt/trt_graph_split_pass.h | 60 +++++++++++++++ .../dialect/tensorrt/trt_op_teller_pass.cc | 2 - .../dialect/tensorrt/trt_op_teller_pass.h | 6 +- paddle/infrt/dialect/tensorrt/trt_ops.h | 0 paddle/scripts/infrt_build.sh | 9 ++- 17 files changed, 238 insertions(+), 73 deletions(-) create mode 100644 paddle/infrt/dialect/tensorrt/trt_exec.cc create mode 100644 paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc create mode 100644 paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h mode change 100755 => 100644 paddle/infrt/dialect/tensorrt/trt_ops.h diff --git a/paddle/infrt/CMakeLists.txt b/paddle/infrt/CMakeLists.txt index 3bcc9f59b2..8f05d286bf 100644 --- a/paddle/infrt/CMakeLists.txt +++ b/paddle/infrt/CMakeLists.txt @@ -1,6 +1,14 @@ if (NOT WITH_INFRT) return() endif() + +# compile flags +set(INFRT_FLAGS -Wno-comment) +foreach(flag ${INFRT_FLAGS}) + safe_set_cflag(CMAKE_C_FLAGS ${flag}) + safe_set_cxxflag(CMAKE_CXX_FLAGS ${flag}) +endforeach() + set(INFRT_SOURCE_DIR "${PADDLE_SOURCE_DIR}/paddle/infrt" ) set(INFRT_BINARY_DIR 
"${PADDLE_BINARY_DIR}/paddle/infrt" ) set(INFRT_TEST_TARGETS CACHE INTERNAL "") diff --git a/paddle/infrt/dialect/mlir_loader.cc b/paddle/infrt/dialect/mlir_loader.cc index 5a6654b6c9..b318a6a763 100644 --- a/paddle/infrt/dialect/mlir_loader.cc +++ b/paddle/infrt/dialect/mlir_loader.cc @@ -36,7 +36,9 @@ mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context, const std::string& mlir_source) { // context->allowUnregisteredDialects(); RegisterCinnDialects(context->getDialectRegistry()); - context->getDialectRegistry().insert(); + // Currently, we only use the CinnDialect and mlir::BuiltinDialect is + // enough. Don't need StandardOpsDialect. + // context->getDialectRegistry().insert(); mlir::ScopedDiagnosticHandler scope_handler( context, [](mlir::Diagnostic& diag) { diff --git a/paddle/infrt/dialect/mlir_tests/paddle_ops.mlir b/paddle/infrt/dialect/mlir_tests/paddle_ops.mlir index ee9fb1740a..6618fe66bd 100644 --- a/paddle/infrt/dialect/mlir_tests/paddle_ops.mlir +++ b/paddle/infrt/dialect/mlir_tests/paddle_ops.mlir @@ -1,6 +1,6 @@ func @ops() { - %a = pd.feed() : tensor - %b = pd.feed() : tensor + %a = pd.feed() {name="input0"} : tensor + %b = pd.feed() {name="input1"}: tensor %c = "pd.matmul"(%a, %b) {transpose_x=true, transpose_y=false} : (tensor, tensor) -> tensor diff --git a/paddle/infrt/dialect/mlir_tests/rewrite.mlir b/paddle/infrt/dialect/mlir_tests/rewrite.mlir index 39f5e8f595..bfad9d1f69 100644 --- a/paddle/infrt/dialect/mlir_tests/rewrite.mlir +++ b/paddle/infrt/dialect/mlir_tests/rewrite.mlir @@ -1,13 +1,13 @@ // CHECK-LABEL: @main func @main() -> tensor { - %a = "pd.feed"() : () -> tensor - %b = "pd.feed"() : () -> tensor - %bias = "pd.feed"() : () -> tensor - %b1 = "pd.feed"() : () -> tensor - %b2 = "pd.feed"() : () -> tensor - %bias1 = "pd.feed"() : () -> tensor - %bias2 = "pd.feed"() : () 
-> tensor + %b1 = "pd.feed"() {name="input3"} : () -> tensor + %b2 = "pd.feed"() {name="input4"} : () -> tensor + %bias1 = "pd.feed"() {name="input5"} : () -> tensor + %bias2 = "pd.feed"() {name="input6"} : () -> tensor %c = "pd.matmul"(%a, %b) {transpose_y=false} : (tensor, tensor) -> tensor %d = "pd.elementwise_add"(%c, %bias) {axis=1:i32} : (tensor, tensor) -> tensor diff --git a/paddle/infrt/dialect/mlir_tests/rewrite_conv_bn.mlir b/paddle/infrt/dialect/mlir_tests/rewrite_conv_bn.mlir index 1360efe17b..9ea1ec0ebc 100644 --- a/paddle/infrt/dialect/mlir_tests/rewrite_conv_bn.mlir +++ b/paddle/infrt/dialect/mlir_tests/rewrite_conv_bn.mlir @@ -1,6 +1,6 @@ // CHECK-LABEL: @main func @main() -> tensor { - %a = "pd.feed"() : () -> tensor + %a = "pd.feed"() {name="input0"} : () -> tensor %filter = "pd.constant"(){value = dense<1.000000e+00> : tensor<3x64x3x3xf32>} : () -> tensor<3x64x3x3xf32> %bias = "pd.constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32> diff --git a/paddle/infrt/dialect/mlir_tests/trt_ops.mlir b/paddle/infrt/dialect/mlir_tests/trt_ops.mlir index 539ad875f7..009b6d1c19 100644 --- a/paddle/infrt/dialect/mlir_tests/trt_ops.mlir +++ b/paddle/infrt/dialect/mlir_tests/trt_ops.mlir @@ -1,13 +1,11 @@ // CHECK-LABEL: @main func @main() -> tensor { - %a = "pd.feed"() : () -> tensor - %b = "pd.feed"() : () -> tensor - %bias = "pd.feed"() : () -> tensor - %c = "pd.feed"() : () -> tensor - %b1 = "pd.feed"() : () -> tensor - %b2 = "pd.feed"() : () -> tensor - %bias1 = "pd.feed"() : () -> tensor - %bias2 = "pd.feed"() : () -> tensor + %bias = "pd.feed"() {name="input0"} : () -> tensor + %c = "pd.feed"() {name="input1"} : () -> tensor + %b1 = "pd.feed"() {name="input2"} : () -> tensor + %b2 = "pd.feed"() {name="input3"} : () -> tensor + %bias1 = "pd.feed"() {name="input4"} : () -> tensor + %bias2 = "pd.feed"() {name="input5"} : () -> tensor %d = "pd.elementwise_add"(%c, %bias) {axis=1:i32} : (tensor, tensor) -> tensor %e = "pd.relu6"(%d) 
{} : (tensor) -> tensor diff --git a/paddle/infrt/dialect/pd_ops.td b/paddle/infrt/dialect/pd_ops.td index ff049689ed..b020b7ad5d 100644 --- a/paddle/infrt/dialect/pd_ops.td +++ b/paddle/infrt/dialect/pd_ops.td @@ -6,14 +6,14 @@ include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/IR/OpBase.td" include "paddle/infrt/dialect/pd_op_base.td" -def PD_FeedOp : PD_Op<"feed", [NoSideEffect]> { +def PD_FeedOp : PD_Op<"feed"> { let summary = "Feed Op"; let description = [{ Feed a tensor into the model. }]; - let arguments = (ins); + let arguments = (ins StrAttr:$name); let results = (outs PD_Tensor:$out); let assemblyFormat = [{ diff --git a/paddle/infrt/dialect/tensorrt/CMakeLists.txt b/paddle/infrt/dialect/tensorrt/CMakeLists.txt index e1e3f40075..794266513e 100755 --- a/paddle/infrt/dialect/tensorrt/CMakeLists.txt +++ b/paddle/infrt/dialect/tensorrt/CMakeLists.txt @@ -4,5 +4,9 @@ gather_srcs(infrt_src SRCS trt_ops.cc trt_op_teller_pass.cc trt_graph_fuse_pass.cc + trt_graph_split_pass.cc ) mlir_tablegen_on(trt_ops) + +add_executable(trt-exec trt_exec.cc) +target_link_libraries(trt-exec infrt ${MLIR_IR_LIBS}) diff --git a/paddle/infrt/dialect/tensorrt/trt_exec.cc b/paddle/infrt/dialect/tensorrt/trt_exec.cc new file mode 100644 index 0000000000..dc0f2acb2b --- /dev/null +++ b/paddle/infrt/dialect/tensorrt/trt_exec.cc @@ -0,0 +1,48 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include +#include +#include "llvm/Support/CommandLine.h" +#include "mlir/Pass/PassManager.h" +#include "paddle/infrt/common/global.h" +#include "paddle/infrt/dialect/mlir_loader.h" +#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h" +#include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h" +#include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h" + +int main(int argc, char** argv) { + static llvm::cl::opt input_file( + llvm::cl::Positional, + llvm::cl::desc("Specify input filename"), + llvm::cl::init("-")); + + llvm::cl::ParseCommandLineOptions(argc, argv); + + mlir::MLIRContext* context = infrt::Global::getMLIRContext(); + auto module = infrt::dialect::LoadMlirFile(input_file.c_str(), context); + + module->dump(); + mlir::PassManager pm(context); + + mlir::OpPassManager& trt_pass_manager = pm.nest(); + trt_pass_manager.addPass(std::make_unique()); + trt_pass_manager.addPass(std::make_unique()); + trt_pass_manager.addPass(std::make_unique(10)); + if (mlir::failed(pm.run(*module))) { + std::cout << "\npass failed!\n" << std::endl; + return 4; + } + module->dump(); + return 0; +} diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc index a44fcece43..181f462962 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc @@ -18,6 +18,7 @@ #include #include #include "llvm/ADT/SetVector.h" +#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/IR/Builders.h" #include "paddle/infrt/dialect/pd_ops.h" #include "paddle/infrt/dialect/tensorrt/trt_ops.h" @@ -25,42 +26,31 @@ namespace infrt { namespace trt { namespace { - -// FlexibleDFS -// do reverse dfs. calls leave(node) after visiting all parents of node. -// Reference the function with the same name but defined in: +// ReverseDfs +// do reverse dfs. calls "func" to search when visit a node. +// The elements in 'source' can't be nullptr. 
+// Reference the function named "FlexibleDFS" but defined in: paddle/fluid/framework/ir/subgraph_detector.cc. -void FlexibleDFS(const std::vector<::mlir::Operation *> &source, - const std::function &leave) { - typedef struct { - ::mlir::Operation *node; - bool leave; - } FNode; - std::vector stack; - for (auto &node : source) { - stack.push_back(FNode{node, false}); - } +bool reverseDfs(std::vector<::mlir::Operation *> source, + const std::function &func) { std::unordered_set visited; - while (!stack.empty()) { - auto fnode = stack.back(); - stack.pop_back(); - - if (fnode.leave) { - if (leave && !leave(fnode.node)) return; - } - if (visited.count(fnode.node)) continue; - visited.insert(fnode.node); - - if (leave) stack.push_back(FNode{fnode.node, true}); - auto values = fnode.node->getOperands(); + while (!source.empty()) { + auto node = source.back(); + source.pop_back(); + if (visited.count(node)) continue; + visited.insert(node); + if (func(node)) return true; + auto values = node->getOperands(); for (auto value : values) { + // if the value is a block argument, the node is nullptr. ::mlir::Operation *node = value.getDefiningOp(); - if (!visited.count(node)) { - stack.push_back(FNode{node, false}); + if (node != nullptr && !visited.count(node)) { + source.emplace_back(node); } } } + return false; } // merge the first&second graph op to a new graph op. @@ -136,6 +126,20 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT second.erase(); } +// Topological sort the function op. +void topoSortBlock(mlir::Block &body) { // NOLINT + llvm::SetVector toSort; + if (body.empty()) return; + for (auto it = body.rbegin(); it != body.rend(); ++it) { + toSort.insert(&*it); + } + llvm::SetVector result = + ::mlir::topologicalSort(std::move(toSort)); + for (auto *op : result) { + op->moveBefore(body.getTerminator()); + } +} + } // namespace // Implementation of the trtGraphFusePass. 
@@ -158,21 +162,14 @@ void trtGraphFusePass::runOnFunction() { std::vector<::mlir::Operation *> source_nodes; for (auto operand : user_op->getOperands()) { auto input = operand.getDefiningOp(); - if (input != &op) { + if (input != &op && input != nullptr) { source_nodes.push_back(input); } } // Reverse DFS from the source_nodes. - bool have_excess_path = false; - FlexibleDFS(source_nodes, - [&have_excess_path, &op](const ::mlir::Operation *n) { - if (n == &op) { - have_excess_path = true; - return false; - } - return true; - }); - if (!have_excess_path) { + if (!reverseDfs(source_nodes, [&op](const ::mlir::Operation *n) { + return n == &op; + })) { mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op); changed = true; break; @@ -181,7 +178,7 @@ void trtGraphFusePass::runOnFunction() { if (changed) break; } } while (changed); + topoSortBlock(body); } - } // namespace trt } // namespace infrt diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h index d5a019eff8..e7134e88f3 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h @@ -25,7 +25,7 @@ namespace trt { * source func: * * func @main() -> tensor { - * %a = "pd.feed"() + * %a = "pd.feed"()... * %c = "pd.graph"(%a) { * %m = "pd.conv2d"(%a)... * "pd.fetch" %m @@ -42,7 +42,7 @@ namespace trt { * * destination func: * func @main() -> tensor { - * %a = "pd.feed"() + * %a = "pd.feed"()... * %d, %f = "pd.graph"(%a) { * %m = "pd.conv2d"(%a)... * %n = "pd.conv3d"(%m)... 
@@ -58,6 +58,5 @@ class trtGraphFusePass ::llvm::StringRef getName() const override { return "trtGraphFusePass"; } void runOnFunction() override; }; - } // namespace trt } // namespace infrt diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc new file mode 100644 index 0000000000..2b45364de2 --- /dev/null +++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc @@ -0,0 +1,50 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h" + +#include "mlir/IR/Builders.h" +#include "paddle/infrt/dialect/pd_ops.h" +#include "paddle/infrt/dialect/tensorrt/trt_ops.h" + +namespace infrt { +namespace trt { +// Implementation of the trtGraphSplitPass. +void trtGraphSplitPass::runOnFunction() { + std::vector<::mlir::pd::GraphOp> worklist; + ::mlir::Block& block = getFunction().front(); + for (auto& op : block) { + ::mlir::pd::GraphOp graph_op = + ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(&op); + if (nullptr != graph_op && + graph_op.getBody()->getOperations().size() <= min_subgraph_size_) { + worklist.push_back(graph_op); + } + } + while (!worklist.empty()) { + ::mlir::pd::GraphOp graph_op = worklist.back(); + worklist.pop_back(); + ::mlir::Block* body = graph_op.getBody(); + auto fetch_op = body->getTerminator(); + graph_op.replaceAllUsesWith(fetch_op->getOperands()); + auto copy_range = body->without_terminator(); + block.getOperations().splice(::mlir::Block::iterator(graph_op), + body->getOperations(), + copy_range.begin(), + copy_range.end()); + graph_op.erase(); + } +} +} // namespace trt +} // namespace infrt diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h new file mode 100644 index 0000000000..092df0cf83 --- /dev/null +++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h @@ -0,0 +1,60 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "mlir/Pass/Pass.h" + +namespace infrt { +namespace trt { +/* + * trtGraphSplitPass. + * + * Split the graph op when the number of operations is too small. + * The feature is the opposite of 'trtOpTellerPass'. + * + * source func: + * + * func @main() -> tensor { + * %a = "pd.feed"()... + * %d, %f = "pd.graph"(%a) { + * %m = "pd.conv2d"(%a)... + * %n = "pd.conv3d"(%m)... + * %s = "pd.conv2d"(%a)... + * "pd.fetch" %n, %s + * } ... + * "pd.fetch" %d, %f + * } + * + * destination func: + * func @main() -> tensor { + * %a = "pd.feed"()... + * %c = "pd.conv2d"(%a) ... + * %d = "pd.conv3d"(%c) ... + * %f = "pd.conv2d"(%a) ... + * "pd.fetch" %d, %f + * } + */ +class trtGraphSplitPass + : public ::mlir::PassWrapper { + public: + ::llvm::StringRef getName() const override { return "trtGraphSplitPass"; } + void runOnFunction() override; + explicit trtGraphSplitPass(size_t min_subgraph_size = 3) + : min_subgraph_size_(min_subgraph_size) {} + + private: + size_t min_subgraph_size_; +}; +} // namespace trt +} // namespace infrt diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc index 14cffdbf98..7b7fbb05c1 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc @@ -20,7 +20,6 @@ namespace infrt { namespace trt { - // Implementation of the trtOpTellerPass。 void trtOpTellerPass::runOnFunction() { ::mlir::Block &body = getFunction().front(); @@ -60,6 +59,5 @@ void trtOpTellerPass::runOnFunction() { builder.create(loc, op->getResults()); } } - } // namespace trt } // namespace infrt diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h index e9cebaf36d..b03945b345 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h +++ 
b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h @@ -17,7 +17,6 @@ namespace infrt { namespace trt { - /* * trtOpTellerPass. * @@ -26,7 +25,7 @@ namespace trt { * source func: * * func @main() -> tensor { - * %a = "pd.feed"() + * %a = "pd.feed"()... * %c = "pd.conv2d"(%a) ... * %d = "pd.conv3d"(%c) ... * %f = "pd.conv2d"(%a) ... @@ -35,7 +34,7 @@ * * destination func: * func @main() -> tensor { - * %a = "pd.feed"() + * %a = "pd.feed"()... * %c = "pd.graph"(%a) { * %m = "pd.conv2d"(%a)... * "pd.fetch" %m @@ -59,6 +58,5 @@ class trtOpTellerPass ::llvm::StringRef getName() const override { return "trtOpTellerPass"; } void runOnFunction() override; }; - } // namespace trt } // namespace infrt diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.h b/paddle/infrt/dialect/tensorrt/trt_ops.h old mode 100755 new mode 100644 diff --git a/paddle/scripts/infrt_build.sh b/paddle/scripts/infrt_build.sh index cc948e7a8d..74f690da76 100644 --- a/paddle/scripts/infrt_build.sh +++ b/paddle/scripts/infrt_build.sh @@ -65,12 +65,12 @@ function infrt_gen_and_build() { mkdir -p ${PADDLE_ROOT}/build cd ${PADDLE_ROOT}/build rm -f infrt_summary.txt - cmake .. -DWITH_MKL=OFF -DWITH_GPU=OFF -DCMAKE_BUILD_TYPE=Release -DWITH_INFRT=ON -DWITH_PYTHON=OFF -DWITH_TESTING==${WITH_TESTING:-ON}; build_error=$? + cmake .. -DWITH_MKL=OFF -DWITH_GPU=OFF -DWITH_CRYPTO=OFF -DCMAKE_BUILD_TYPE=Release -DWITH_INFRT=ON -DWITH_PYTHON=OFF -DWITH_TESTING=${WITH_TESTING:-ON}; build_error=$? if [ "$build_error" != 0 ];then exit 7; fi - make -j ${parallel_number} infrt infrtopt infrt-exec test_infrt_exec infrt_lib_dist;build_error=$? + make -j ${parallel_number} infrt infrtopt infrt-exec test_infrt_exec trt-exec infrt_lib_dist;build_error=$? 
if [ "$build_error" != 0 ];then exit 7; fi @@ -115,6 +115,9 @@ function main() { build_only) infrt_gen_and_build ${parallel_number} ;; + test_only) + test_infrt + ;; *) print_usage exit 1 @@ -126,7 +129,7 @@ function main() { cat ${PADDLE_ROOT}/build/infrt_summary.txt echo "========================================================" fi - echo "paddle_build script finished as expected" + echo "paddle_build script finished as expected!" } main $@ -- GitLab