未验证 提交 9f0958fa 编写于 作者: 王明冬 提交者: GitHub

[infrt] add trt_graph_split_pass for infrt. test=develop (#38494)

上级 dfdc9960
if (NOT WITH_INFRT) if (NOT WITH_INFRT)
return() return()
endif() endif()
# compile flags
set(INFRT_FLAGS -Wno-comment)
foreach(flag ${INFRT_FLAGS})
safe_set_cflag(CMAKE_C_FLAGS ${flag})
safe_set_cxxflag(CMAKE_CXX_FLAGS ${flag})
endforeach()
set(INFRT_SOURCE_DIR "${PADDLE_SOURCE_DIR}/paddle/infrt" ) set(INFRT_SOURCE_DIR "${PADDLE_SOURCE_DIR}/paddle/infrt" )
set(INFRT_BINARY_DIR "${PADDLE_BINARY_DIR}/paddle/infrt" ) set(INFRT_BINARY_DIR "${PADDLE_BINARY_DIR}/paddle/infrt" )
set(INFRT_TEST_TARGETS CACHE INTERNAL "") set(INFRT_TEST_TARGETS CACHE INTERNAL "")
......
...@@ -36,7 +36,9 @@ mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context, ...@@ -36,7 +36,9 @@ mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
const std::string& mlir_source) { const std::string& mlir_source) {
// context->allowUnregisteredDialects(); // context->allowUnregisteredDialects();
RegisterCinnDialects(context->getDialectRegistry()); RegisterCinnDialects(context->getDialectRegistry());
context->getDialectRegistry().insert<mlir::StandardOpsDialect>(); // Currenetly, We only used the CinnDialect and mlir::BuiltinDialect is
// enough。Don't need StandardOpsDialect.
// context->getDialectRegistry().insert<mlir::StandardOpsDialect>();
mlir::ScopedDiagnosticHandler scope_handler( mlir::ScopedDiagnosticHandler scope_handler(
context, [](mlir::Diagnostic& diag) { context, [](mlir::Diagnostic& diag) {
......
func @ops() { func @ops() {
%a = pd.feed() : tensor<?xf32> %a = pd.feed() {name="input0"} : tensor<?xf32>
%b = pd.feed() : tensor<?xf32> %b = pd.feed() {name="input1"}: tensor<?xf32>
%c = "pd.matmul"(%a, %b) {transpose_x=true, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %c = "pd.matmul"(%a, %b) {transpose_x=true, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
......
// CHECK-LABEL: @main // CHECK-LABEL: @main
func @main() -> tensor<?xf32> { func @main() -> tensor<?xf32> {
%a = "pd.feed"() : () -> tensor<?xf32> %a = "pd.feed"() {name="input0"} : () -> tensor<?xf32>
%b = "pd.feed"() : () -> tensor<?xf32> %b = "pd.feed"() {name="input1"} : () -> tensor<?xf32>
%bias = "pd.feed"() : () -> tensor<?xf32> %bias = "pd.feed"() {name="input2"} : () -> tensor<?xf32>
%b1 = "pd.feed"() : () -> tensor<?xf32> %b1 = "pd.feed"() {name="input3"} : () -> tensor<?xf32>
%b2 = "pd.feed"() : () -> tensor<?xf32> %b2 = "pd.feed"() {name="input4"} : () -> tensor<?xf32>
%bias1 = "pd.feed"() : () -> tensor<?xf32> %bias1 = "pd.feed"() {name="input5"} : () -> tensor<?xf32>
%bias2 = "pd.feed"() : () -> tensor<?xf32> %bias2 = "pd.feed"() {name="input6"} : () -> tensor<?xf32>
%c = "pd.matmul"(%a, %b) {transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %c = "pd.matmul"(%a, %b) {transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%d = "pd.elementwise_add"(%c, %bias) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %d = "pd.elementwise_add"(%c, %bias) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
......
// CHECK-LABEL: @main // CHECK-LABEL: @main
func @main() -> tensor<?xf32> { func @main() -> tensor<?xf32> {
%a = "pd.feed"() : () -> tensor<?x3x256x256xf32> %a = "pd.feed"() {name="input0"} : () -> tensor<?x3x256x256xf32>
%filter = "pd.constant"(){value = dense<1.000000e+00> : tensor<3x64x3x3xf32>} : () -> tensor<3x64x3x3xf32> %filter = "pd.constant"(){value = dense<1.000000e+00> : tensor<3x64x3x3xf32>} : () -> tensor<3x64x3x3xf32>
%bias = "pd.constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32> %bias = "pd.constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32>
......
// CHECK-LABEL: @main // CHECK-LABEL: @main
func @main() -> tensor<?xf32> { func @main() -> tensor<?xf32> {
%a = "pd.feed"() : () -> tensor<?xf32> %bias = "pd.feed"() {name="input0"} : () -> tensor<?xf32>
%b = "pd.feed"() : () -> tensor<?xf32> %c = "pd.feed"() {name="input1"} : () -> tensor<?xf32>
%bias = "pd.feed"() : () -> tensor<?xf32> %b1 = "pd.feed"() {name="input2"} : () -> tensor<?xf32>
%c = "pd.feed"() : () -> tensor<?xf32> %b2 = "pd.feed"() {name="input3"} : () -> tensor<?xf32>
%b1 = "pd.feed"() : () -> tensor<?xf32> %bias1 = "pd.feed"() {name="input4"} : () -> tensor<?xf32>
%b2 = "pd.feed"() : () -> tensor<?xf32> %bias2 = "pd.feed"() {name="input5"} : () -> tensor<?xf32>
%bias1 = "pd.feed"() : () -> tensor<?xf32>
%bias2 = "pd.feed"() : () -> tensor<?xf32>
%d = "pd.elementwise_add"(%c, %bias) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %d = "pd.elementwise_add"(%c, %bias) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e = "pd.relu6"(%d) {} : (tensor<?xf32>) -> tensor<?xf32> %e = "pd.relu6"(%d) {} : (tensor<?xf32>) -> tensor<?xf32>
......
...@@ -6,14 +6,14 @@ include "mlir/Interfaces/LoopLikeInterface.td" ...@@ -6,14 +6,14 @@ include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/pd_op_base.td" include "paddle/infrt/dialect/pd_op_base.td"
def PD_FeedOp : PD_Op<"feed", [NoSideEffect]> { def PD_FeedOp : PD_Op<"feed"> {
let summary = "Feed Op"; let summary = "Feed Op";
let description = [{ let description = [{
Feed a tensor into the model. Feed a tensor into the model.
}]; }];
let arguments = (ins); let arguments = (ins StrAttr:$name);
let results = (outs PD_Tensor:$out); let results = (outs PD_Tensor:$out);
let assemblyFormat = [{ let assemblyFormat = [{
......
...@@ -4,5 +4,9 @@ gather_srcs(infrt_src SRCS ...@@ -4,5 +4,9 @@ gather_srcs(infrt_src SRCS
trt_ops.cc trt_ops.cc
trt_op_teller_pass.cc trt_op_teller_pass.cc
trt_graph_fuse_pass.cc trt_graph_fuse_pass.cc
trt_graph_split_pass.cc
) )
mlir_tablegen_on(trt_ops) mlir_tablegen_on(trt_ops)
add_executable(trt-exec trt_exec.cc)
target_link_libraries(trt-exec infrt ${MLIR_IR_LIBS})
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <string>
#include "llvm/Support/CommandLine.h"
#include "mlir/Pass/PassManager.h"
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/mlir_loader.h"
#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h"
int main(int argc, char** argv) {
static llvm::cl::opt<std::string> input_file(
llvm::cl::Positional,
llvm::cl::desc("Specify input filename"),
llvm::cl::init("-"));
llvm::cl::ParseCommandLineOptions(argc, argv);
mlir::MLIRContext* context = infrt::Global::getMLIRContext();
auto module = infrt::dialect::LoadMlirFile(input_file.c_str(), context);
module->dump();
mlir::PassManager pm(context);
mlir::OpPassManager& trt_pass_manager = pm.nest<mlir::FuncOp>();
trt_pass_manager.addPass(std::make_unique<infrt::trt::trtOpTellerPass>());
trt_pass_manager.addPass(std::make_unique<infrt::trt::trtGraphFusePass>());
trt_pass_manager.addPass(std::make_unique<infrt::trt::trtGraphSplitPass>(10));
if (mlir::failed(pm.run(*module))) {
std::cout << "\npass failed!\n" << std::endl;
return 4;
}
module->dump();
return 0;
}
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SetVector.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "paddle/infrt/dialect/pd_ops.h" #include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h" #include "paddle/infrt/dialect/tensorrt/trt_ops.h"
...@@ -25,42 +26,31 @@ ...@@ -25,42 +26,31 @@
namespace infrt { namespace infrt {
namespace trt { namespace trt {
namespace { namespace {
// ReverseDfs
// FlexibleDFS // do reverse dfs. calls "func" to search when visit a node.
// do reverse dfs. calls leave(node) after visiting all parents of node. // The elements in 'source' can't be nullptr.
// Reference the function with the same name but defined in: // Reference the function nameed "FlexibleDFS" but defined in:
// paddle/fluid/framework/ir/subgraph_detector.cc. // paddle/fluid/framework/ir/subgraph_detector.cc.
void FlexibleDFS(const std::vector<::mlir::Operation *> &source,
const std::function<bool(const ::mlir::Operation *)> &leave) {
typedef struct {
::mlir::Operation *node;
bool leave;
} FNode;
std::vector<FNode> stack; bool reverseDfs(std::vector<::mlir::Operation *> source,
for (auto &node : source) { const std::function<bool(const ::mlir::Operation *)> &func) {
stack.push_back(FNode{node, false});
}
std::unordered_set<const ::mlir::Operation *> visited; std::unordered_set<const ::mlir::Operation *> visited;
while (!stack.empty()) { while (!source.empty()) {
auto fnode = stack.back(); auto node = source.back();
stack.pop_back(); source.pop_back();
if (visited.count(node)) continue;
if (fnode.leave) { visited.insert(node);
if (leave && !leave(fnode.node)) return; if (func(node)) return true;
} auto values = node->getOperands();
if (visited.count(fnode.node)) continue;
visited.insert(fnode.node);
if (leave) stack.push_back(FNode{fnode.node, true});
auto values = fnode.node->getOperands();
for (auto value : values) { for (auto value : values) {
// if the value is a block argument, the node is nullptr.
::mlir::Operation *node = value.getDefiningOp(); ::mlir::Operation *node = value.getDefiningOp();
if (!visited.count(node)) { if (node != nullptr && !visited.count(node)) {
stack.push_back(FNode{node, false}); source.emplace_back(node);
} }
} }
} }
return false;
} }
// merge the first&second graph op to a new graph op. // merge the first&second graph op to a new graph op.
...@@ -136,6 +126,20 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT ...@@ -136,6 +126,20 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
second.erase(); second.erase();
} }
// Topological sort the function op.
void topoSortBlock(mlir::Block &body) { // NOLINT
llvm::SetVector<Operation *> toSort;
if (body.empty()) return;
for (auto it = body.rbegin(); it != body.rend(); ++it) {
toSort.insert(&*it);
}
llvm::SetVector<Operation *> result =
::mlir::topologicalSort(std::move(toSort));
for (auto *op : result) {
op->moveBefore(body.getTerminator());
}
}
} // namespace } // namespace
// Implementation of the trtGraphFusePass. // Implementation of the trtGraphFusePass.
...@@ -158,21 +162,14 @@ void trtGraphFusePass::runOnFunction() { ...@@ -158,21 +162,14 @@ void trtGraphFusePass::runOnFunction() {
std::vector<::mlir::Operation *> source_nodes; std::vector<::mlir::Operation *> source_nodes;
for (auto operand : user_op->getOperands()) { for (auto operand : user_op->getOperands()) {
auto input = operand.getDefiningOp(); auto input = operand.getDefiningOp();
if (input != &op) { if (input != &op && input != nullptr) {
source_nodes.push_back(input); source_nodes.push_back(input);
} }
} }
// Reverse DFS from the source_nodes. // Reverse DFS from the source_nodes.
bool have_excess_path = false; if (!reverseDfs(source_nodes, [&op](const ::mlir::Operation *n) {
FlexibleDFS(source_nodes, return n == &op;
[&have_excess_path, &op](const ::mlir::Operation *n) { })) {
if (n == &op) {
have_excess_path = true;
return false;
}
return true;
});
if (!have_excess_path) {
mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op); mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op);
changed = true; changed = true;
break; break;
...@@ -181,7 +178,7 @@ void trtGraphFusePass::runOnFunction() { ...@@ -181,7 +178,7 @@ void trtGraphFusePass::runOnFunction() {
if (changed) break; if (changed) break;
} }
} while (changed); } while (changed);
topoSortBlock(body);
} }
} // namespace trt } // namespace trt
} // namespace infrt } // namespace infrt
...@@ -25,7 +25,7 @@ namespace trt { ...@@ -25,7 +25,7 @@ namespace trt {
* source func: * source func:
* *
* func @main() -> tensor<?xf32> { * func @main() -> tensor<?xf32> {
* %a = "pd.feed"() * %a = "pd.feed"()...
* %c = "pd.graph"(%a) { * %c = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* "pd.fetch" %m * "pd.fetch" %m
...@@ -42,7 +42,7 @@ namespace trt { ...@@ -42,7 +42,7 @@ namespace trt {
* *
* destination func: * destination func:
* func @main() -> tensor<?xf32> { * func @main() -> tensor<?xf32> {
* %a = "pd.feed"() * %a = "pd.feed"()...
* %d, %f = "pd.graph"(%a) { * %d, %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)... * %n = "pd.conv3d"(%m)...
...@@ -58,6 +58,5 @@ class trtGraphFusePass ...@@ -58,6 +58,5 @@ class trtGraphFusePass
::llvm::StringRef getName() const override { return "trtGraphFusePass"; } ::llvm::StringRef getName() const override { return "trtGraphFusePass"; }
void runOnFunction() override; void runOnFunction() override;
}; };
} // namespace trt } // namespace trt
} // namespace infrt } // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h"
#include "mlir/IR/Builders.h"
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
namespace trt {
// Implementation of the trtGraphSplitPass。
void trtGraphSplitPass::runOnFunction() {
std::vector<::mlir::pd::GraphOp> worklist;
::mlir::Block& block = getFunction().front();
for (auto& op : block) {
::mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(&op);
if (nullptr != graph_op &&
graph_op.getBody()->getOperations().size() <= min_subgraph_size_) {
worklist.push_back(graph_op);
}
}
while (!worklist.empty()) {
::mlir::pd::GraphOp graph_op = worklist.back();
worklist.pop_back();
::mlir::Block* body = graph_op.getBody();
auto fetch_op = body->getTerminator();
graph_op.replaceAllUsesWith(fetch_op->getOperands());
auto copy_range = body->without_terminator();
block.getOperations().splice(::mlir::Block::iterator(graph_op),
body->getOperations(),
copy_range.begin(),
copy_range.end());
graph_op.erase();
}
}
} // namespace trt
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mlir/Pass/Pass.h"
namespace infrt {
namespace trt {
/*
* trtGraphSplitPass.
*
* Splite the graph op when the number of operations is too small.
* The feature is the opposite of 'trtOpTellerPass'.
*
* source func:
*
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()...
* %d, %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)...
* "pd.fetch" %n, %s
* } ...
* "pd.fetch" %d, %f
* }
*
* destination func:
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()...
* %c = "pd.conv2d"(%a) ...
* %d = "pd.conv3d"(%c) ...
* %f = "pd.conv2d"(%a) ...
* "pd.fetch" %d, %f
* }
*/
class trtGraphSplitPass
: public ::mlir::PassWrapper<trtGraphSplitPass, ::mlir::FunctionPass> {
public:
::llvm::StringRef getName() const override { return "trtGraphSplitPass"; }
void runOnFunction() override;
explicit trtGraphSplitPass(size_t min_subgraph_size = 3)
: min_subgraph_size_(min_subgraph_size) {}
private:
size_t min_subgraph_size_;
};
} // namespace trt
} // namespace infrt
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
namespace infrt { namespace infrt {
namespace trt { namespace trt {
// Implementation of the trtOpTellerPass。 // Implementation of the trtOpTellerPass。
void trtOpTellerPass::runOnFunction() { void trtOpTellerPass::runOnFunction() {
::mlir::Block &body = getFunction().front(); ::mlir::Block &body = getFunction().front();
...@@ -60,6 +59,5 @@ void trtOpTellerPass::runOnFunction() { ...@@ -60,6 +59,5 @@ void trtOpTellerPass::runOnFunction() {
builder.create<mlir::pd::FetchOp>(loc, op->getResults()); builder.create<mlir::pd::FetchOp>(loc, op->getResults());
} }
} }
} // namespace trt } // namespace trt
} // namespace infrt } // namespace infrt
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
namespace infrt { namespace infrt {
namespace trt { namespace trt {
/* /*
* trtOpTellerPass. * trtOpTellerPass.
* *
...@@ -26,7 +25,7 @@ namespace trt { ...@@ -26,7 +25,7 @@ namespace trt {
* source func: * source func:
* *
* func @main() -> tensor<?xf32> { * func @main() -> tensor<?xf32> {
* %a = "pd.feed"() * %a = "pd.feed"()...
* %c = "pd.conv2d"(%a) ... * %c = "pd.conv2d"(%a) ...
* %d = "pd.conv3d"(%c) ... * %d = "pd.conv3d"(%c) ...
* %f = "pd.conv2d"(%a) ... * %f = "pd.conv2d"(%a) ...
...@@ -35,7 +34,7 @@ namespace trt { ...@@ -35,7 +34,7 @@ namespace trt {
* *
* destination func: * destination func:
* func @main() -> tensor<?xf32> { * func @main() -> tensor<?xf32> {
* %a = "pd.feed"() * %a = "pd.feed"()...
* %c = "pd.graph"(%a) { * %c = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* "pd.fetch" %m * "pd.fetch" %m
...@@ -59,6 +58,5 @@ class trtOpTellerPass ...@@ -59,6 +58,5 @@ class trtOpTellerPass
::llvm::StringRef getName() const override { return "trtOpTellerPass"; } ::llvm::StringRef getName() const override { return "trtOpTellerPass"; }
void runOnFunction() override; void runOnFunction() override;
}; };
} // namespace trt } // namespace trt
} // namespace infrt } // namespace infrt
文件模式从 100755 更改为 100644
...@@ -65,12 +65,12 @@ function infrt_gen_and_build() { ...@@ -65,12 +65,12 @@ function infrt_gen_and_build() {
mkdir -p ${PADDLE_ROOT}/build mkdir -p ${PADDLE_ROOT}/build
cd ${PADDLE_ROOT}/build cd ${PADDLE_ROOT}/build
rm -f infrt_summary.txt rm -f infrt_summary.txt
cmake .. -DWITH_MKL=OFF -DWITH_GPU=OFF -DCMAKE_BUILD_TYPE=Release -DWITH_INFRT=ON -DWITH_PYTHON=OFF -DWITH_TESTING==${WITH_TESTING:-ON}; build_error=$? cmake .. -DWITH_MKL=OFF -DWITH_GPU=OFF -DWITH_CRYPTO=OFF -DCMAKE_BUILD_TYPE=Release -DWITH_INFRT=ON -DWITH_PYTHON=OFF -DWITH_TESTING==${WITH_TESTING:-ON}; build_error=$?
if [ "$build_error" != 0 ];then if [ "$build_error" != 0 ];then
exit 7; exit 7;
fi fi
make -j ${parallel_number} infrt infrtopt infrt-exec test_infrt_exec infrt_lib_dist;build_error=$? make -j ${parallel_number} infrt infrtopt infrt-exec test_infrt_exec trt-exec infrt_lib_dist;build_error=$?
if [ "$build_error" != 0 ];then if [ "$build_error" != 0 ];then
exit 7; exit 7;
fi fi
...@@ -115,6 +115,9 @@ function main() { ...@@ -115,6 +115,9 @@ function main() {
build_only) build_only)
infrt_gen_and_build ${parallel_number} infrt_gen_and_build ${parallel_number}
;; ;;
test_only)
test_infrt
;;
*) *)
print_usage print_usage
exit 1 exit 1
...@@ -126,7 +129,7 @@ function main() { ...@@ -126,7 +129,7 @@ function main() {
cat ${PADDLE_ROOT}/build/infrt_summary.txt cat ${PADDLE_ROOT}/build/infrt_summary.txt
echo "========================================================" echo "========================================================"
fi fi
echo "paddle_build script finished as expected" echo "paddle_build script finished as expected!"
} }
main $@ main $@
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册