未验证 提交 d51a0718 编写于 作者: X xiebaiyuan 提交者: GitHub

Merge pull request #1081 from xiebaiyuan/develop

trans gemm to class && add multi instance support && to unit test
此差异已折叠。
......@@ -35,146 +35,166 @@ namespace paddle_mobile {
namespace operators {
namespace math {
/*
class Gemm {
public:
/*
// 将 A 矩阵分块复制到连续内存(ColMajor)
void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
float *buffer);
// 将 B 矩阵分块复制到连续内存(ColMajor)
void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
float *buffer);
*/
// 将 A 矩阵分块复制到连续内存(RowMajor)
void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
// 将 B 矩阵分块复制到连续内存(RowMajor)
void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
// 分块矩阵乘法
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
float beta, float *c, float *C, int ldc, bool relu);
void InnerKernelWithBias(int mc, int nc, float alpha, const float *a,
typedef void (Gemm::*FnPack)(int, int, int, const float *, int, float *);
typedef void (Gemm::*FnAddDot)(int, const float *, const float *, float *,
int);
FnPack procPackA;
FnPack procPackB;
FnAddDot procAddDot;
// 将 A 矩阵分块复制到连续内存(RowMajor)
void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
// 将 B 矩阵分块复制到连续内存(RowMajor)
void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
// 分块矩阵乘法
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
float beta, float *c, float *C, int ldc, bool relu);
void InnerKernelWithBias(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C,
int ldc, bool relu, float *bias);
void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C,
int ldc, bool relu, float *bias);
void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C, int ldc,
bool relu, float *new_scale, float *new_bias);
void InnerKernelWithBnAdd(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C,
int ldc, bool relu, float *new_scale, float *new_bias,
int ldc, bool relu, float *new_scale, float *new_bias);
void InnerKernelWithBnAdd(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C,
int ldc, bool relu, float *new_scale,
float *new_bias, float *bias);
void InnerKernelWithPRelu(int mc, int nc, const float *a, const float *b,
float *c, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1);
/*
// 向量矩阵乘法 (M = 1)
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu);
void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta, float
*C, int ldc, bool relu, float *new_scale, float *new_bias);
*/
// 计算一个更小的 C 矩阵分块
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc);
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc);
void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc);
void AddDot8x12(int k, const float *a, const float *b, float *c, int ldc);
void AddDot6x16(int k, const float *a, const float *b, float *c, int ldc);
// 分块矩阵乘法结果回写
// C = A * B
void WriteBasic(int mc, int nc, float *c, float *C, int ldc);
// C = alpha * A * B + beta * C
void WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C
void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + bias
void WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, float *bias);
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C,prelu(C)
void WriteWithAddPRelu(int mc, int nc, float *c, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1);
// C = A * B + bias ,relu(C)
void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
float *bias);
void InnerKernelWithPRelu(int mc, int nc, const float *a, const float *b,
float *c, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1);
/*
// 向量矩阵乘法 (M = 1)
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu);
void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta, float *C,
int ldc, bool relu, float *new_scale, float *new_bias);
*/
// C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias);
// C = A * B, batchnorm(C), relu(C)
void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias);
void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias, float *bias1);
/*
// 向量矩阵乘法结果回写
// C = A * B
void VecWriteBasic(int n, float *c, float *C, int ldc);
// C = alpha * A * B + beta * C
void VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc);
// C = A * B + C
void VecWriteWithAdd(int n, float *c, float *C, int ldc);
// C = A * B + C, relu(C)
void VecWriteWithAddRelu(int n, float *c, float *C, int ldc);
// C = A * B, batchnorm(C)
void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
// C = A * B, batchnorm(C), relu(C)
void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
*/
// 32位 float 矩阵乘法
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, bool relu,
float *bias);
// 32位 float 矩阵乘法, 并对结果进行 batchnrom
void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias, float *bias);
void SgemmWithPRelu(int m, int n, int k, const float *A, int lda,
const float *B, int ldb, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1);
// 32位 float 矩阵乘法(openmp 多线程版本)
void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *bias);
// 计算一个更小的 C 矩阵分块
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc);
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc);
void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc);
void AddDot8x12(int k, const float *a, const float *b, float *c, int ldc);
void AddDot6x16(int k, const float *a, const float *b, float *c, int ldc);
// 分块矩阵乘法结果回写
// C = A * B
void WriteBasic(int mc, int nc, float *c, float *C, int ldc);
// C = alpha * A * B + beta * C
void WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C
void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + bias
void WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, float *bias);
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C,prelu(C)
void WriteWithAddPRelu(int mc, int nc, float *c, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1);
// C = A * B + bias ,relu(C)
void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
float *bias);
// C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
// C = A * B, batchnorm(C), relu(C)
void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias);
void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias, float *bias1);
/*
// 向量矩阵乘法结果回写
// C = A * B
void VecWriteBasic(int n, float *c, float *C, int ldc);
// C = alpha * A * B + beta * C
void VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc);
// C = A * B + C
void VecWriteWithAdd(int n, float *c, float *C, int ldc);
// C = A * B + C, relu(C)
void VecWriteWithAddRelu(int n, float *c, float *C, int ldc);
// C = A * B, batchnorm(C)
void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
// C = A * B, batchnorm(C), relu(C)
void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
*/
// 32位 float 矩阵乘法, 并对结果进行 batchnrom(openmp 多线程版本)
void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta, float *C,
int ldc, bool relu, float *new_scale, float *new_bias,
float *bias);
void SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
const float *B, int ldb, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1);
// 32位 float 矩阵乘法
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, bool relu,
float *bias);
private:
int MC = 0;
int KC = 0;
int NC = 0;
// 32位 float 矩阵乘法, 并对结果进行 batchnrom
void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias, float *bias);
void SgemmWithPRelu(int m, int n, int k, const float *A, int lda,
const float *B, int ldb, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1);
// 32位 float 矩阵乘法(openmp 多线程版本)
void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *bias);
// 32位 float 矩阵乘法, 并对结果进行 batchnrom(openmp 多线程版本)
void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias, float *bias);
void SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
const float *B, int ldb, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1);
float *packedA;
float *packedB;
float *packedC;
float *zero;
};
} // namespace math
} // namespace operators
......
......@@ -28,19 +28,22 @@ struct GRUUnitFunctor<CPU, T> {
static void compute(GRUMetaValue<T> value, int frame_size, int batch_size,
const ActivationType active_node,
const ActivationType active_gate) {
Gemm gemm;
if (value.prev_out_value) {
Sgemm(batch_size, frame_size * 2, frame_size, 1, value.prev_out_value,
frame_size, value.gate_weight, frame_size * 2, 1, value.gate_value,
frame_size * 3, false, nullptr);
gemm.Sgemm(batch_size, frame_size * 2, frame_size, 1,
value.prev_out_value, frame_size, value.gate_weight,
frame_size * 2, 1, value.gate_value, frame_size * 3, false,
nullptr);
}
forward_reset_output(forward::gru_resetOutput<T>(), value, frame_size,
batch_size, active_gate);
if (value.prev_out_value) {
Sgemm(batch_size, frame_size, frame_size, 1, value.reset_output_value,
frame_size, value.state_weight, frame_size, 1,
value.gate_value + frame_size * 2, frame_size * 3, false, nullptr);
gemm.Sgemm(batch_size, frame_size, frame_size, 1,
value.reset_output_value, frame_size, value.state_weight,
frame_size, 1, value.gate_value + frame_size * 2,
frame_size * 3, false, nullptr);
}
forward_final_output(forward::gru_finalOutput<T>(), value, frame_size,
......
......@@ -36,6 +36,7 @@ void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
int M = dim_out[0];
int N = dim_out[1];
int K = (!trans_a) ? dim_a[1] : dim_a[0];
Gemm gemm;
if (trans_a) {
int numel = matrix_a.numel();
......@@ -50,20 +51,24 @@ void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
a[index++] = tmp[i * n + j];
}
}
#ifdef _OPENMP
Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<float>(), N, beta,
matrix_out->data<float>(), N, relu, bias);
gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<float>(), N, beta,
matrix_out->data<float>(), N, relu, bias);
#else
Sgemm(M, N, K, alpha, a, K, matrix_b.data<float>(), N, beta,
matrix_out->data<float>(), N, relu, bias);
gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<float>(), N, beta,
matrix_out->data<float>(), N, relu, bias);
#endif
} else {
#ifdef _OPENMP
Sgemm_omp(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(),
N, beta, matrix_out->data<float>(), N, relu, bias);
gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, beta, matrix_out->data<float>(),
N, relu, bias);
#else
Sgemm(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
beta, matrix_out->data<float>(), N, relu, bias);
gemm.Sgemm(M, N, K, alpha, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, beta, matrix_out->data<float>(), N,
relu, bias);
#endif
}
}
......@@ -74,6 +79,7 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
float alpha, framework::Tensor *matrix_out, float beta,
bool relu, framework::Tensor *new_scale,
framework::Tensor *new_bias, int group, float *bias) {
Gemm gemm;
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
......@@ -86,21 +92,22 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
int K = (!trans_a) ? dim_a[1] : dim_a[0];
#ifdef _OPENMP
SgemmWithBn_omp(M, N, K, alpha, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, beta, matrix_out->data<float>(), N,
relu, new_scale->data<float>() + group,
new_bias->data<float>() + group, bias);
gemm.SgemmWithBn_omp(
M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
beta, matrix_out->data<float>(), N, relu,
new_scale->data<float>() + group, new_bias->data<float>() + group, bias);
#else
SgemmWithBn(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(),
N, beta, matrix_out->data<float>(), N, relu,
new_scale->data<float>() + group, new_bias->data<float>() + group,
bias);
gemm.SgemmWithBn(M, N, K, alpha, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, beta, matrix_out->data<float>(),
N, relu, new_scale->data<float>() + group,
new_bias->data<float>() + group, bias);
#endif
}
void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
framework::Tensor *matrix_out, float *p, std::string mode,
float *bias, float *bias1) {
Gemm gemm;
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
......@@ -113,11 +120,13 @@ void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
int K = (!trans_a) ? dim_a[1] : dim_a[0];
#ifdef _OPENMP
SgemmWithPRelu_omp(M, N, K, matrix_a.data<float>(), K, matrix_b.data<float>(),
N, matrix_out->data<float>(), N, p, mode, bias, bias1);
gemm.SgemmWithPRelu_omp(M, N, K, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, matrix_out->data<float>(),
N, p, mode, bias, bias1);
#else
SgemmWithPRelu(M, N, K, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
matrix_out->data<float>(), N, p, mode, bias, bias1);
gemm.SgemmWithPRelu(M, N, K, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, matrix_out->data<float>(), N,
p, mode, bias, bias1);
#endif
}
......
......@@ -35,8 +35,8 @@ if (CON GREATER -1)
ADD_EXECUTABLE(test-yolo net/test_yolo.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-yolo paddle-mobile)
# gen test
ADD_EXECUTABLE(test_yolo_combined net/test_yolo_combined.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test_yolo_combined paddle-mobile)
ADD_EXECUTABLE(test-yolo-combined net/test_yolo_combined.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-yolo-combined paddle-mobile)
set(FOUND_MATCH ON)
endif ()
......@@ -323,5 +323,10 @@ if (NOT FOUND_MATCH)
target_link_libraries(test-fssd paddle-mobile)
# gen test
ADD_EXECUTABLE(test-multi-process net/test_multi_inference_predict.cpp test_helper.h test_include.h)
target_link_libraries(test-multi-process paddle-mobile)
#add_library(test-lib-size SHARED common/test_lib_size.h common/test_lib_size.cpp)
endif ()
......@@ -83,8 +83,9 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) {
}
}
paddle_mobile::operators::math::SgemmWithBn(
m, n, k, 0.9, a, lda, b, ldb, 0.3, c, ldc, relu, scale, bias, nullptr);
paddle_mobile::operators::math::Gemm gemm;
gemm.SgemmWithBn(m, n, k, 0.9, a, lda, b, ldb, 0.3, c, ldc, relu, scale, bias,
nullptr);
int eq = 0;
int neq = 0;
for (int i = 0; i < m * n; ++i) {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <thread> // NOLINT
#include "../test_helper.h"
#include "../test_include.h"
void fun_yolo();
int fun_mobilenet();
int main() {
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile2;
// fun_yolo();
// fun_mobilenet();
std::thread t1(fun_yolo);
std::thread t2(fun_mobilenet);
t1.join();
t2.join();
return 0;
}
void fun_yolo() {
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
paddle_mobile.SetThreadNum(4);
// ../../../test/models/googlenet
// ../../../test/models/mobilenet
auto time1 = time();
if (paddle_mobile.Load(g_yolo, true)) {
auto time2 = time();
std::cout << "load cost :" << time_diff(time1, time1) << "ms" << std::endl;
vector<int64_t> dims{1, 3, 227, 227};
Tensor input_tensor;
SetupTensor<float>(&input_tensor, {1, 3, 227, 227}, static_cast<float>(0),
static_cast<float>(1));
vector<float> input(input_tensor.data<float>(),
input_tensor.data<float>() + input_tensor.numel());
auto time3 = time();
for (int i = 0; i < 10; ++i) {
paddle_mobile.Predict(input, dims);
}
auto time4 = time();
std::cout << "thread 1: predict cost :" << time_diff(time3, time4) / 10
<< "ms" << std::endl;
}
}
int fun_mobilenet() {
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
paddle_mobile.SetThreadNum(4);
auto time1 = time();
// auto isok = paddle_mobile.Load(std::string(g_mobilenet_detect) + "/model",
// std::string(g_mobilenet_detect) + "/params", true);
auto isok = paddle_mobile.Load(g_mobilenet, true);
if (isok) {
auto time2 = time();
std::cout << "load cost :" << time_diff(time1, time1) << "ms" << std::endl;
vector<float> input;
vector<int64_t> dims{1, 3, 224, 224};
GetInput<float>(g_test_image_1x3x224x224_banana, &input, dims);
auto vec_result = paddle_mobile.Predict(input, dims);
auto biggest = max_element(begin(vec_result), end(vec_result));
std::cout << " Max element is " << *biggest << " at position "
<< distance(begin(vec_result), biggest) << std::endl;
// 预热十次
for (int i = 0; i < 10; ++i) {
auto vec_result = paddle_mobile.Predict(input, dims);
}
auto time3 = time();
for (int i = 0; i < 10; ++i) {
auto vec_result = paddle_mobile.Predict(input, dims);
}
DLOG << vec_result;
auto time4 = time();
std::cout << "thread 2: predict cost :" << time_diff(time3, time4) / 10
<< "ms" << std::endl;
}
std::cout << "如果结果Nan请查看: test/images/g_test_image_1x3x224x224_banana "
"是否存在?"
<< std::endl;
return 0;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册