diff --git a/src/io/api_paddle_mobile.cc b/src/io/api_paddle_mobile.cc index 144cf127a44c78279ca1d95815646a4f01fed6bd..c5da3993d18d6c21c46c923e99609b4c290fb668 100644 --- a/src/io/api_paddle_mobile.cc +++ b/src/io/api_paddle_mobile.cc @@ -52,6 +52,10 @@ bool PaddleMobilePredictor::Init(const PaddleMobileConfig &config) { paddle_mobile_->SetThreadNum(config.thread_num); return true; } +template +double PaddleMobilePredictor::CaculatePredictTime() { + return paddle_mobile_->GetPredictTime(); +}; template bool PaddleMobilePredictor::Run( diff --git a/src/io/api_paddle_mobile.h b/src/io/api_paddle_mobile.h index bdeb7e18653843ec9547f027068768532ba04fb2..d8e5f856c6bae870f89d6957aafa97c34bfad5dd 100644 --- a/src/io/api_paddle_mobile.h +++ b/src/io/api_paddle_mobile.h @@ -40,6 +40,8 @@ class PaddleMobilePredictor : public PaddlePredictor { std::vector* output_data, int batch_size = -1) override; + double CaculatePredictTime() override; + ~PaddleMobilePredictor() override; private: diff --git a/src/io/paddle_inference_api.h b/src/io/paddle_inference_api.h index 3c9ffa00c7e749d1c9d77562b2db0b42ee605164..33a166f2c5cb9d668a411db1a03e1a766b3cfe9d 100644 --- a/src/io/paddle_inference_api.h +++ b/src/io/paddle_inference_api.h @@ -98,7 +98,7 @@ class PaddlePredictor { virtual bool Run(const std::vector& inputs, std::vector* output_data, int batch_size = -1) = 0; - + virtual double CaculatePredictTime() = 0; // Destroy the Predictor. virtual ~PaddlePredictor() = default; diff --git a/src/io/paddle_mobile.cpp b/src/io/paddle_mobile.cpp index cfd1a1c87671cfb598aad586b421f046830b10d9..fca870860ec1156aa7d3d8503951cfb8a2e84821 100644 --- a/src/io/paddle_mobile.cpp +++ b/src/io/paddle_mobile.cpp @@ -13,7 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "io/paddle_mobile.h" - +#ifdef PADDLE_MOBILE_CL +#include +#include "framework/cl/cl_tensor.h" +#endif +#include "common/common.h" +#include "operators/math/gemm.h" namespace paddle_mobile { static std::mutex lc; @@ -119,6 +124,40 @@ void PaddleMobile::Clear() { loader_ = nullptr; } +template +double PaddleMobile::GetPredictTime() { + int m = 32; + int n = 224 * 224; + int k = 27; + int lda = k; + int ldb = n; + int ldc = n; + float *a = + static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m * k)); + float *b = + static_cast(paddle_mobile::memory::Alloc(sizeof(float) * k * n)); + float *c = + static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m * n)); + int t1 = 1; + int t2 = 1; + for (int i = 0; i < m * k; ++i) { + a[i] = t1 + rand() % t2; + } + for (int i = 0; i < k * n; ++i) { + b[i] = t1 + rand() % t2; + } + paddle_mobile::operators::math::Gemm gemm; + auto time1 = paddle_mobile::time(); + gemm.Sgemm(m, n, k, static_cast(1), a, lda, b, ldb, + static_cast(0), c, ldc, false, nullptr); + auto time2 = paddle_mobile::time(); + double cost = paddle_mobile::time_diff(time1, time2); + paddle_mobile::memory::Free(a); + paddle_mobile::memory::Free(b); + paddle_mobile::memory::Free(c); + return cost; +} + template PaddleMobile::~PaddleMobile() { executor_ = nullptr; @@ -167,6 +206,208 @@ void PaddleMobile::SetCLPath(std::string path) { framework::CLEngine::Instance()->setClPath(path); } } +template <> +double PaddleMobile::GetPredictTime() { + cl_int status; + cl_uint nPlatform; + clGetPlatformIDs(0, NULL, &nPlatform); + cl_platform_id *listPlatform = + (cl_platform_id *)malloc(nPlatform * sizeof(cl_platform_id)); + clGetPlatformIDs(nPlatform, listPlatform, NULL); + cl_uint nDevice = 0; + clGetDeviceIDs(listPlatform[0], CL_DEVICE_TYPE_GPU, 0, NULL, &nDevice); + cl_device_id *listDevice = + (cl_device_id *)malloc(nDevice * sizeof(cl_device_id)); + clGetDeviceIDs(listPlatform[0], CL_DEVICE_TYPE_GPU, nDevice, listDevice, + NULL); + cl_context context = + clCreateContext(NULL, nDevice, listDevice, NULL, NULL, &status); + cl_command_queue queue = + clCreateCommandQueue(context, listDevice[0], 0, &status); + + int n = 1; + int c = 3; + int h = 224; + int w = 224; + float *input = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * 3 * 224 * 224)); + float *filter = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * 32 * 27)); + int input_w = w * (c + 3) / 4; + int input_h = n * h; + int filter_w = 3 * (3 + 3) / 4; + int filter_h = 32 * 3; + int output_w = 224 * (32 + 3) / 4; + int output_h = 1 * 224; + + framework::DDim input_dims = {1, 3, 224, 224}; + framework::CLTensor input_cl_tensor(context, queue); + input_cl_tensor.Resize(input_dims); + cl_mem inputBuffer = input_cl_tensor.mutable_with_data(input); + + framework::DDim filter_dims = {32, 3, 3, 3}; + framework::CLTensor filter_cl_tensor(context, queue); + input_cl_tensor.Resize(filter_dims); + cl_mem filterBuffer = filter_cl_tensor.mutable_with_data(filter); + + cl_mem cl_filter_image = NULL; + cl_mem cl_input_image = NULL; + cl_mem cl_output_image = NULL; + cl_image_format cf = {.image_channel_order = CL_RGBA, + .image_channel_data_type = CL_HALF_FLOAT}; + cl_input_image = clCreateImage2D(context, CL_MEM_READ_WRITE | 0, &cf, input_w, + input_h, 0, NULL, &status); + cl_filter_image = clCreateImage2D(context, CL_MEM_READ_WRITE | 0, &cf, + filter_w, filter_h, 0, NULL, &status); + cl_output_image = clCreateImage2D(context, CL_MEM_READ_WRITE | 0, &cf, + output_w, output_h, 0, NULL, &status); + char *code; + std::string path = framework::CLEngine::Instance()->GetCLPath() + + "/cl_kernel/feed_kernel.cl"; + size_t length = readText(path.c_str(), &code); + cl_program program = clCreateProgramWithSource( + context, 1, (const char **)&code, &length, NULL); + std::string path1 = "-cl-fast-relaxed-math -I " + + framework::CLEngine::Instance()->GetCLPath() + + "/cl_kernel"; + clBuildProgram(program, 0, 0, path1.c_str(), NULL, NULL); + cl_kernel kernel = clCreateKernel(program, "feed", &status); + + status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputBuffer); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_input_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 2, sizeof(cl_int), &input_w); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 3, sizeof(cl_int), &input_h); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 4, sizeof(cl_int), &c); + CL_CHECK_ERRORS(status); + + size_t global_work_size[2] = {input_w, input_h}; + + // cl_event out_event = param.Out()->GetClEvent(); + + status = clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_work_size, + NULL, 0, NULL, NULL); + CL_CHECK_ERRORS(status); + + status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &filterBuffer); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_filter_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 2, sizeof(cl_int), &filter_w); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 3, sizeof(cl_int), &filter_h); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 4, sizeof(cl_int), &c); + CL_CHECK_ERRORS(status); + + size_t global_work_size1[2] = {filter_w, filter_h}; + + // cl_event out_event = param.Out()->GetClEvent(); + + status = clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_work_size1, + NULL, 0, NULL, NULL); + CL_CHECK_ERRORS(status); + + clFinish(queue); + queue = clCreateCommandQueue(context, listDevice[0], 0, &status); + + path = framework::CLEngine::Instance()->GetCLPath() + + "/cl_kernel/conv_kernel.cl"; + size_t length1 = readText(path.c_str(), &code); + program = clCreateProgramWithSource(context, 1, (const char **)&code, + &length1, &status); + CL_CHECK_ERRORS(status); + clBuildProgram(program, 0, 0, path1.c_str(), NULL, NULL); + kernel = clCreateKernel(program, "conv_3x3", &status); + CL_CHECK_ERRORS(status); + + int c_block = (32 + 3) / 4; + int nh = n * h; + int stride = 1; + int offset = 0; + int input_c = (c + 3) / 4; + int dilation = 1; + int input_width = 224; + int input_height = 224; + int output_width = 224; + int output_height = 224; + status = clSetKernelArg(kernel, 0, sizeof(int), &c_block); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 1, sizeof(int), &w); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 2, sizeof(int), &nh); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &cl_input_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &cl_filter_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &cl_output_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 6, sizeof(int), &stride); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 7, sizeof(int), &offset); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 8, sizeof(int), &input_c); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 9, sizeof(int), &dilation); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 10, sizeof(int), &input_width); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 11, sizeof(int), &input_height); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 12, sizeof(int), &output_width); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 13, sizeof(int), &output_height); + CL_CHECK_ERRORS(status); + + // cl_event out_event = param.Output()->GetClEvent(); + // cl_event wait_event = param.Input()->GetClEvent(); + size_t global_work_size2[3] = {8, 224, 224}; + auto time1 = paddle_mobile::time(); + status = clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size2, + NULL, 0, NULL, NULL); + CL_CHECK_ERRORS(status); + clFinish(queue); + auto time2 = paddle_mobile::time(); + paddle_mobile::memory::Free(input); + paddle_mobile::memory::Free(filter); + return paddle_mobile::time_diff(time1, time2); +} +template +int PaddleMobile::readText( + const char *kernelPath, + char **pcode) // 读取文本文件放入 pcode,返回字符串长度 +{ + FILE *fp; + int size; + // printf(" File: %s\n", kernelPath); + fp = fopen(kernelPath, "rb"); + if (!fp) { + printf(" Open file failed\n"); + return -1; + } + if (fseek(fp, 0, SEEK_END) != 0) { + printf(" Seek end of file failed\n"); + return -1; + } + if ((size = ftell(fp)) < 0) { + printf(" Get file position failed\n"); + return -1; + } + rewind(fp); + if ((*pcode = (char *)malloc(size + 1)) == NULL) { + printf(" Allocate space failed\n"); + return -1; + } + fread(*pcode, 1, size, fp); + (*pcode)[size] = '\0'; + fclose(fp); + return size + 1; +} + #endif template class PaddleMobile; diff --git a/src/io/paddle_mobile.h b/src/io/paddle_mobile.h index 778b173f3e64f27f6bdf8329a2979ebbdf955633..ab148e7361c160bc658403d4696b806323595c54 100644 --- a/src/io/paddle_mobile.h +++ b/src/io/paddle_mobile.h @@ -65,6 +65,7 @@ class PaddleMobile { void SetThreadNum(int num); void Clear(); + double GetPredictTime(); ~PaddleMobile(); @@ -80,6 +81,8 @@ class PaddleMobile { #ifdef PADDLE_MOBILE_CL public: void SetCLPath(std::string cl_path); + int readText(const char *kernelPath, + char **pcode); // 读取文本文件放入 pcode,返回字符串长度 #endif private: diff --git a/src/operators/kernel/cl/cl_kernel/feed_kernel.cl b/src/operators/kernel/cl/cl_kernel/feed_kernel.cl index 80d741d859af633299120bfec9f4cfeeaeb47194..200a221c9bda49c42f2caff374fc24d6e4df27e5 100644 --- a/src/operators/kernel/cl/cl_kernel/feed_kernel.cl +++ b/src/operators/kernel/cl/cl_kernel/feed_kernel.cl @@ -13,14 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma OPENCL EXTENSION cl_khr_fp16 : enable -__kernel void feed(__global float *in, __write_only image2d_t outputImage,int h,int w) +__kernel void feed(__global float *in, __write_only image2d_t outputImage,int h,int w,int c) { int i = get_global_id(0); int j = get_global_id(1); half4 pixel; pixel.x = convert_half(in[(i * w + j)]); + if(c>=2){ pixel.y = convert_half(in[h * w + (i * w + j)]); + }else{ + pixel.y = 0.0; + } + if(c>=3){ pixel.z = convert_half(in[2 * h * w + (i * w + j)]); + }else{ + pixel.z = 0.0; + } pixel.w = 0.0; int2 coords; coords.x = j; diff --git a/src/operators/kernel/cl/feed_kernel.cpp b/src/operators/kernel/cl/feed_kernel.cpp index ad5fb9cdbcd00dad56579297c010c3912e3dca24..78f04357a23c70595595cc24489fd96e994162fb 100644 --- a/src/operators/kernel/cl/feed_kernel.cpp +++ b/src/operators/kernel/cl/feed_kernel.cpp @@ -34,6 +34,7 @@ void FeedKernel::Compute(const FeedParam ¶m) { const float *input_data = input->data(); int numel = input->numel(); cl_mem cl_image = output->GetCLImage(); + int c = input->dims()[1]; int height = output->dims()[2]; int width = output->dims()[3]; CLTensor input_cl_tensor(this->cl_helper_.CLContext(), @@ -49,6 +50,8 @@ void FeedKernel::Compute(const FeedParam ¶m) { CL_CHECK_ERRORS(status); status = clSetKernelArg(kernel, 3, sizeof(cl_int), &height); CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 4, sizeof(cl_int), &c); + CL_CHECK_ERRORS(status); size_t global_work_size[2] = {width, height}; diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index 605fa17c3c70ec3151cc1a2fb249edab336548a1..d3e6de3134ff91f47c66c927194a5ba688e931b0 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -3230,6 +3230,8 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, int L1 = 64 / max_threads * 1024; KC = k; + zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); + memset(static_cast(zero), 0, sizeof(float) * KC); if (m > n) { // 对 A 分块 MC = L1 / (KC * sizeof(float)); @@ -3255,7 +3257,7 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, packedB = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); - (*this.*procPackB)(KC, NC, NC % NR, B, ldb, packedB); + (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB); packedA = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); } else { @@ -3284,12 +3286,10 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, packedA = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); - (*this.*procPackA)(MC, KC, MC % MR, A, lda, packedA); + (*this.*procPackA)(m, KC, m % MR, A, lda, packedA); packedB = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); } - zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); - memset(static_cast(zero), 0, sizeof(float) * KC); packedC = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads)); @@ -3352,6 +3352,8 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int L1 = 64 / max_threads * 1024; KC = k; + zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); + memset(static_cast(zero), 0, sizeof(float) * KC); if (m > n) { // 对 A 分块 MC = L1 / (KC * sizeof(float)); @@ -3377,7 +3379,7 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, packedB = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); - (*this.*procPackB)(KC, NC, NC % NR, B, ldb, packedB); + (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB); packedA = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); } else { @@ -3405,12 +3407,10 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, packedA = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); - (*this.*procPackA)(MC, KC, MC % MR, A, lda, packedA); + (*this.*procPackA)(m, KC, m % MR, A, lda, packedA); packedB = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); } - zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); - memset(static_cast(zero), 0, sizeof(float) * KC); packedC = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads)); @@ -3480,6 +3480,8 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, int L1 = 8 * 1024; KC = k; + zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); + memset(static_cast(zero), 0, sizeof(float) * KC); if (m > n) { // 对 A 分块 MC = L1 / (KC * sizeof(float)); @@ -3505,7 +3507,7 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, packedB = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); - (*this.*procPackB)(KC, NC, NC % NR, B, ldb, packedB); + (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB); packedA = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); } else { @@ -3533,12 +3535,10 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, packedA = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); - (*this.*procPackA)(MC, KC, MC % MR, A, lda, packedA); + (*this.*procPackA)(m, KC, m % MR, A, lda, packedA); packedB = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); } - zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); - memset(static_cast(zero), 0, sizeof(float) * KC); packedC = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads)); diff --git a/test/net/test_yologpu.cpp b/test/net/test_yologpu.cpp index b00cbef0277f44c65ab951227176721599b0559e..e77861cabad8baca7bfe5bf673ba9b01af97498d 100644 --- a/test/net/test_yologpu.cpp +++ b/test/net/test_yologpu.cpp @@ -13,17 +13,75 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include +#include "../../src/common/types.h" #include "../test_helper.h" #include "../test_include.h" +void t1() { + paddle_mobile::PaddleMobile paddle_mobile_gpu; + paddle_mobile::PaddleMobile paddle_mobile_cpu; + // paddle_mobile.SetThreadNum(4); +#ifdef PADDLE_MOBILE_CL + paddle_mobile_gpu.SetCLPath("/data/local/tmp/bin"); +#endif + printf("cpu time:%f\n", paddle_mobile_cpu.GetPredictTime()); + printf("gpu time:%f\n", paddle_mobile_gpu.GetPredictTime()); + auto time1 = paddle_mobile::time(); + auto isok = paddle_mobile_gpu.Load(std::string(g_yolo_mul) + "/model", + std::string(g_yolo_mul) + "/params", true); -int main() { + // auto isok = paddle_mobile.Load(std::string(g_yolo_mul), true); + if (isok) { + auto time2 = paddle_mobile::time(); + std::cout << "load cost :" << paddle_mobile::time_diff(time1, time2) << "ms" + << std::endl; + + std::vector input; + std::vector dims{1, 3, 416, 416}; + GetInput(g_yolo_img, &input, dims); + + std::vector vec_result; + // = paddle_mobile.Predict(input, dims); + + auto time3 = paddle_mobile::time(); + int max = 10; + for (int i = 0; i < max; ++i) { + vec_result = paddle_mobile_gpu.Predict(input, dims); + } + auto time4 = paddle_mobile::time(); + + // auto time3 = paddle_mobile::time(); + + // for (int i = 0; i < 10; ++i) { + // auto vec_result = paddle_mobile.Predict(input, dims); + // } + + // auto time4 = paddle_mobile::time(); + + std::cout << "predict cost :" + << paddle_mobile::time_diff(time3, time4) / max << "ms" + << std::endl; + std::vector::iterator biggest = + std::max_element(std::begin(vec_result), std::end(vec_result)); + std::cout << " Max element is " << *biggest << " at position " + << std::distance(std::begin(vec_result), biggest) << std::endl; + // for (float i : vec_result) { + // std::cout << i << std::endl; + // } + } +} + +void t2() { paddle_mobile::PaddleMobile paddle_mobile; // paddle_mobile.SetThreadNum(4); +#ifdef PADDLE_MOBILE_CL + paddle_mobile.SetCLPath("/data/local/tmp/bin"); +#endif auto time1 = paddle_mobile::time(); - // auto isok = paddle_mobile.Load(std::string(g_mobilenet_detect) + "/model", - // std::string(g_mobilenet_detect) + "/params", true); + auto isok = paddle_mobile.Load(std::string(g_yolo_mul) + "/model", + std::string(g_yolo_mul) + "/params", true); - auto isok = paddle_mobile.Load(std::string(g_yolo_mul), true); + // auto isok = paddle_mobile.Load(std::string(g_yolo_mul), true); if (isok) { auto time2 = paddle_mobile::time(); std::cout << "load cost :" << paddle_mobile::time_diff(time1, time2) << "ms" @@ -62,5 +120,66 @@ int main() { // std::cout << i << std::endl; // } } +} + +void t3() { + paddle_mobile::PaddleMobile paddle_mobile; + // paddle_mobile.SetThreadNum(4); + //#ifdef PADDLE_MOBILE_CL + // paddle_mobile.SetCLPath("/data/local/tmp/bin"); + //#endif + auto time1 = paddle_mobile::time(); + auto isok = paddle_mobile.Load(std::string(g_yolo_mul) + "/model", + std::string(g_yolo_mul) + "/params", true); + + // auto isok = paddle_mobile.Load(std::string(g_yolo_mul), true); + if (isok) { + auto time2 = paddle_mobile::time(); + std::cout << "load cost :" << paddle_mobile::time_diff(time1, time2) << "ms" + << std::endl; + + std::vector input; + std::vector dims{1, 3, 416, 416}; + GetInput(g_yolo_img, &input, dims); + + std::vector vec_result = paddle_mobile.Predict(input, dims); + + auto time3 = paddle_mobile::time(); + int max = 10; + for (int i = 0; i < max; ++i) { + vec_result = paddle_mobile.Predict(input, dims); + } + auto time4 = paddle_mobile::time(); + + // auto time3 = paddle_mobile::time(); + + // for (int i = 0; i < 10; ++i) { + // auto vec_result = paddle_mobile.Predict(input, dims); + // } + + // auto time4 = paddle_mobile::time(); + + std::cout << "predict cost :" + << paddle_mobile::time_diff(time3, time4) / max << "ms" + << std::endl; + std::vector::iterator biggest = + std::max_element(std::begin(vec_result), std::end(vec_result)); + std::cout << " Max element is " << *biggest << " at position " + << std::distance(std::begin(vec_result), biggest) << std::endl; + // for (float i : vec_result) { + // std::cout << i << std::endl; + // } + } +} + +int main() { + // std::thread th1(t1); + // std::thread th2(t2); + std::thread th3(t3); + // std::thread th1(t1); + // th1.join(); + // th2.join(); + th3.join(); + // th1.join(); return 0; }