Commit d89d0cfd authored by yangfei

add CPU and GPU predict-time functions; optimize feed op kernel

Parent fc9e5bbc
......@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "io/paddle_mobile.h"
#include <CL/cl.h>
#include <cstring>  // memset
#include "common/common.h"
#include "framework/cl/cl_tensor.h"
#include "operators/math/gemm.h"
namespace paddle_mobile {
static std::mutex lc;
......@@ -119,6 +122,40 @@ void PaddleMobile<Dtype, P>::Clear() {
loader_ = nullptr;
}
template <typename Dtype, Precision P>
double PaddleMobile<Dtype, P>::GetCPUPredictTime() {
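// Time one SGEMM whose shape mirrors a 3x3, 3-channel -> 32-channel
// convolution over a 224x224 image: m = output channels, k = 3 * 3 * 3,
// n = output pixels.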
int m = 32;
int n = 224 * 224;
int k = 27;
int lda = k;
int ldb = n;
int ldc = n;
float *a =
static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m * k));
float *b =
static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * k * n));
float *c =
static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m * n));
int t1 = 1;
int t2 = 1;
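// With t2 == 1, rand() % t2 is always 0, so every element is exactly t1;
// the matrix contents do not affect the timing.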
for (int i = 0; i < m * k; ++i) {
a[i] = t1 + rand() % t2;
}
for (int i = 0; i < k * n; ++i) {
b[i] = t1 + rand() % t2;
}
paddle_mobile::operators::math::Gemm gemm;
auto time1 = paddle_mobile::time();
gemm.Sgemm(m, n, k, static_cast<float>(1), a, lda, b, ldb,
static_cast<float>(0), c, ldc, false, nullptr);
auto time2 = paddle_mobile::time();
double cost = paddle_mobile::time_diff(time1, time2);
paddle_mobile::memory::Free(a);
paddle_mobile::memory::Free(b);
paddle_mobile::memory::Free(c);
return cost;
}
template <typename Dtype, Precision P>
PaddleMobile<Dtype, P>::~PaddleMobile() {
executor_ = nullptr;
......@@ -167,6 +204,208 @@ void PaddleMobile<Dtype, P>::SetCLPath(std::string path) {
framework::CLEngine::Instance()->setClPath(path);
}
}
template <typename Dtype, Precision P>
double PaddleMobile<Dtype, P>::GetGPUPredictTime() {
cl_int status;
cl_uint nPlatform;
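// Use the first GPU device of the first available platform.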
clGetPlatformIDs(0, NULL, &nPlatform);
cl_platform_id *listPlatform =
(cl_platform_id *)malloc(nPlatform * sizeof(cl_platform_id));
clGetPlatformIDs(nPlatform, listPlatform, NULL);
cl_uint nDevice = 0;
clGetDeviceIDs(listPlatform[0], CL_DEVICE_TYPE_GPU, 0, NULL, &nDevice);
cl_device_id *listDevice =
(cl_device_id *)malloc(nDevice * sizeof(cl_device_id));
clGetDeviceIDs(listPlatform[0], CL_DEVICE_TYPE_GPU, nDevice, listDevice,
NULL);
cl_context context =
clCreateContext(NULL, nDevice, listDevice, NULL, NULL, &status);
cl_command_queue queue =
clCreateCommandQueue(context, listDevice[0], 0, &status);
int n = 1;
int c = 3;
int h = 224;
int w = 224;
float *input = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * 3 * 224 * 224));
float *filter = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * 32 * 27));
// Contents are irrelevant to the timing; zero-fill so the GPU never reads
// uninitialized host memory.
memset(input, 0, sizeof(float) * 3 * 224 * 224);
memset(filter, 0, sizeof(float) * 32 * 27);
// Image widths pack four channels per RGBA texel, so block the channel count
// first: w * ((c + 3) / 4), not (w * (c + 3)) / 4.
int input_w = w * ((c + 3) / 4);
int input_h = n * h;
int filter_w = 3 * ((3 + 3) / 4);
int filter_h = 32 * 3;
int output_w = 224 * ((32 + 3) / 4);
int output_h = 1 * 224;
framework::DDim input_dims = {1, 3, 224, 224};
framework::CLTensor input_cl_tensor(context, queue);
input_cl_tensor.Resize(input_dims);
cl_mem inputBuffer = input_cl_tensor.mutable_with_data<float>(input);
framework::DDim filter_dims = {32, 3, 3, 3};
framework::CLTensor filter_cl_tensor(context, queue);
filter_cl_tensor.Resize(filter_dims);
cl_mem filterBuffer = filter_cl_tensor.mutable_with_data<float>(filter);
cl_mem cl_filter_image = NULL;
cl_mem cl_input_image = NULL;
cl_mem cl_output_image = NULL;
cl_image_format cf = {.image_channel_order = CL_RGBA,
.image_channel_data_type = CL_HALF_FLOAT};
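// The feed kernel converts the float buffers into these RGBA half-float
// images.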
cl_input_image = clCreateImage2D(context, CL_MEM_READ_WRITE, &cf, input_w,
input_h, 0, NULL, &status);
CL_CHECK_ERRORS(status);
cl_filter_image = clCreateImage2D(context, CL_MEM_READ_WRITE, &cf, filter_w,
filter_h, 0, NULL, &status);
CL_CHECK_ERRORS(status);
cl_output_image = clCreateImage2D(context, CL_MEM_READ_WRITE, &cf, output_w,
output_h, 0, NULL, &status);
CL_CHECK_ERRORS(status);
char *code;
std::string path = framework::CLEngine::Instance()->GetCLPath() +
"/cl_kernel/feed_kernel.cl";
size_t length = readText(path.c_str(), &code);
cl_program program = clCreateProgramWithSource(
context, 1, (const char **)&code, &length, &status);
CL_CHECK_ERRORS(status);
free(code);  // the program keeps its own copy of the source
std::string build_options = "-cl-fast-relaxed-math -I " +
framework::CLEngine::Instance()->GetCLPath() +
"/cl_kernel";
status = clBuildProgram(program, 0, 0, build_options.c_str(), NULL, NULL);
CL_CHECK_ERRORS(status);
cl_kernel kernel = clCreateKernel(program, "feed", &status);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputBuffer);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_input_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(cl_int), &input_w);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(cl_int), &input_h);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(cl_int), &c);
CL_CHECK_ERRORS(status);
size_t global_work_size[2] = {static_cast<size_t>(input_w),
static_cast<size_t>(input_h)};
status = clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_work_size,
NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &filterBuffer);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_filter_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(cl_int), &filter_w);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(cl_int), &filter_h);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(cl_int), &c);
CL_CHECK_ERRORS(status);
size_t global_work_size1[2] = {static_cast<size_t>(filter_w),
static_cast<size_t>(filter_h)};
status = clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_work_size1,
NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
clFinish(queue);
// The feed work is done; release this queue and use a fresh one for the
// timed convolution.
clReleaseCommandQueue(queue);
queue = clCreateCommandQueue(context, listDevice[0], 0, &status);
CL_CHECK_ERRORS(status);
path = framework::CLEngine::Instance()->GetCLPath() +
"/cl_kernel/conv_kernel.cl";
size_t length1 = readText(path.c_str(), &code);
// Release the feed program and kernel before reusing the handles for conv_3x3.
clReleaseKernel(kernel);
clReleaseProgram(program);
program = clCreateProgramWithSource(context, 1, (const char **)&code,
&length1, &status);
CL_CHECK_ERRORS(status);
free(code);
status = clBuildProgram(program, 0, 0, build_options.c_str(), NULL, NULL);
CL_CHECK_ERRORS(status);
kernel = clCreateKernel(program, "conv_3x3", &status);
CL_CHECK_ERRORS(status);
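// conv_3x3 geometry: 3x3 filter, stride 1, dilation 1, no offset; channels
// are processed in blocks of four to match the RGBA image packing.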
int c_block = (32 + 3) / 4;
int nh = n * h;
int stride = 1;
int offset = 0;
int input_c = (c + 3) / 4;
int dilation = 1;
int input_width = 224;
int input_height = 224;
int output_width = 224;
int output_height = 224;
status = clSetKernelArg(kernel, 0, sizeof(int), &c_block);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(int), &w);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &nh);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &cl_input_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &cl_filter_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &cl_output_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 6, sizeof(int), &stride);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 7, sizeof(int), &offset);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 8, sizeof(int), &input_c);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 9, sizeof(int), &dilation);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 10, sizeof(int), &input_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 11, sizeof(int), &input_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 12, sizeof(int), &output_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 13, sizeof(int), &output_height);
CL_CHECK_ERRORS(status);
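// NDRange matches the kernel's (channel-block, width, n*h) indexing:
// 8 blocks x 224 x 224.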
size_t global_work_size2[3] = {8, 224, 224};
auto time1 = paddle_mobile::time();
status = clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size2,
NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
clFinish(queue);
auto time2 = paddle_mobile::time();
// Release everything created above; the CLTensor-owned buffers are freed by
// their destructors.
clReleaseKernel(kernel);
clReleaseProgram(program);
clReleaseMemObject(cl_input_image);
clReleaseMemObject(cl_filter_image);
clReleaseMemObject(cl_output_image);
clReleaseCommandQueue(queue);
clReleaseContext(context);
free(listDevice);
free(listPlatform);
paddle_mobile::memory::Free(input);
paddle_mobile::memory::Free(filter);
return paddle_mobile::time_diff(time1, time2);
}
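// Usage sketch (hypothetical; not part of this commit): the two timings can
// drive a runtime device choice, e.g.
//   paddle_mobile::PaddleMobile<paddle_mobile::GPU_CL> pm;
//   pm.SetCLPath("/data/local/tmp/bin");
//   bool prefer_gpu = pm.GetGPUPredictTime() < pm.GetCPUPredictTime();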
template <typename Dtype, Precision P>
int PaddleMobile<Dtype, P>::readText(
const char *kernelPath,
char **pcode)  // Reads a text file into *pcode; returns the buffer length
               // (file size plus one for the trailing '\0').
{
FILE *fp;
int size;
// printf("<readText> File: %s\n", kernelPath);
fp = fopen(kernelPath, "rb");
if (!fp) {
printf("<readText> Open file failed\n");
return -1;
}
if (fseek(fp, 0, SEEK_END) != 0) {
printf("<readText> Seek end of file failed\n");
fclose(fp);
return -1;
}
if ((size = ftell(fp)) < 0) {
printf("<readText> Get file position failed\n");
fclose(fp);
return -1;
}
rewind(fp);
if ((*pcode = (char *)malloc(size + 1)) == NULL) {
printf("<readText> Allocate space failed\n");
fclose(fp);
return -1;
}
if (fread(*pcode, 1, size, fp) != static_cast<size_t>(size)) {
printf("<readText> Read file failed\n");
free(*pcode);
fclose(fp);
return -1;
}
(*pcode)[size] = '\0';
fclose(fp);
return size + 1;
}
#endif
template class PaddleMobile<CPU, Precision::FP32>;
......
......@@ -65,6 +65,7 @@ class PaddleMobile {
void SetThreadNum(int num);
void Clear();
double GetCPUPredictTime();
~PaddleMobile();
......@@ -80,6 +81,9 @@ class PaddleMobile {
#ifdef PADDLE_MOBILE_CL
public:
void SetCLPath(std::string cl_path);
double GetGPUPredictTime();
int readText(const char *kernelPath,
             char **pcode);  // Reads a text file into *pcode; returns the
                             // buffer length.
#endif
private:
......
......@@ -13,14 +13,22 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void feed(__global float *in, __write_only image2d_t outputImage,
                   int h, int w, int c)
{
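// Each work-item packs up to three CHW float channels into one RGBA half4
// texel; channels the input lacks are zero-filled.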
int i = get_global_id(0);
int j = get_global_id(1);
half4 pixel;
pixel.x = convert_half(in[(i * w + j)]);
if (c >= 2) {
pixel.y = convert_half(in[h * w + (i * w + j)]);
} else {
pixel.y = 0.0f;
}
if (c >= 3) {
pixel.z = convert_half(in[2 * h * w + (i * w + j)]);
} else {
pixel.z = 0.0f;
}
pixel.w = 0.0f;
int2 coords;
coords.x = j;
......
......@@ -34,6 +34,7 @@ void FeedKernel<GPU_CL, float>::Compute(const FeedParam<GPU_CL> &param) {
const float *input_data = input->data<float>();
int numel = input->numel();
cl_mem cl_image = output->GetCLImage();
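// Forward the channel count so the feed kernel can zero-fill image channels
// the input lacks (c < 4).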
int c = input->dims()[1];
int height = output->dims()[2];
int width = output->dims()[3];
CLTensor input_cl_tensor(this->cl_helper_.CLContext(),
......@@ -49,6 +50,8 @@ void FeedKernel<GPU_CL, float>::Compute(const FeedParam<GPU_CL> &param) {
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(cl_int), &height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(cl_int), &c);
CL_CHECK_ERRORS(status);
size_t global_work_size[2] = {width, height};
......
......@@ -13,17 +13,74 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <thread>
#include "../../src/common/types.h"
#include "../test_helper.h"
#include "../test_include.h"
void t1() {
paddle_mobile::PaddleMobile<paddle_mobile::GPU_CL> paddle_mobile;
// paddle_mobile.SetThreadNum(4);
#ifdef PADDLE_MOBILE_CL
paddle_mobile.SetCLPath("/data/local/tmp/bin");
#endif
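// Raw device micro-benchmarks, independent of the model loaded below.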
printf("cpu time:%f\n", paddle_mobile.GetCPUPredictTime());
printf("gpu time:%f\n", paddle_mobile.GetGPUPredictTime());
auto time1 = paddle_mobile::time();
auto isok = paddle_mobile.Load(std::string(g_yolo_mul) + "/model",
std::string(g_yolo_mul) + "/params", true);
// auto isok = paddle_mobile.Load(std::string(g_yolo_mul), true);
if (isok) {
auto time2 = paddle_mobile::time();
std::cout << "load cost :" << paddle_mobile::time_diff(time1, time2) << "ms"
<< std::endl;
std::vector<float> input;
std::vector<int64_t> dims{1, 3, 416, 416};
GetInput<float>(g_yolo_img, &input, dims);
std::vector<float> vec_result;
// = paddle_mobile.Predict(input, dims);
auto time3 = paddle_mobile::time();
int max = 10;
for (int i = 0; i < max; ++i) {
vec_result = paddle_mobile.Predict(input, dims);
}
auto time4 = paddle_mobile::time();
// auto time3 = paddle_mobile::time();
// for (int i = 0; i < 10; ++i) {
// auto vec_result = paddle_mobile.Predict(input, dims);
// }
// auto time4 = paddle_mobile::time();
std::cout << "predict cost :"
<< paddle_mobile::time_diff(time3, time4) / max << "ms"
<< std::endl;
std::vector<float>::iterator biggest =
std::max_element(std::begin(vec_result), std::end(vec_result));
std::cout << " Max element is " << *biggest << " at position "
<< std::distance(std::begin(vec_result), biggest) << std::endl;
// for (float i : vec_result) {
// std::cout << i << std::endl;
// }
}
}
void t2() {
paddle_mobile::PaddleMobile<paddle_mobile::GPU_CL> paddle_mobile;
// paddle_mobile.SetThreadNum(4);
#ifdef PADDLE_MOBILE_CL
paddle_mobile.SetCLPath("/data/local/tmp/bin");
#endif
auto time1 = paddle_mobile::time();
// auto isok = paddle_mobile.Load(std::string(g_mobilenet_detect) + "/model",
// std::string(g_mobilenet_detect) + "/params", true);
auto isok = paddle_mobile.Load(std::string(g_yolo_mul) + "/model",
std::string(g_yolo_mul) + "/params", true);
// auto isok = paddle_mobile.Load(std::string(g_yolo_mul), true);
if (isok) {
auto time2 = paddle_mobile::time();
std::cout << "load cost :" << paddle_mobile::time_diff(time1, time2) << "ms"
......@@ -62,5 +119,64 @@ int main() {
// std::cout << i << std::endl;
// }
}
}
void t3() {
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
// paddle_mobile.SetThreadNum(4);
//#ifdef PADDLE_MOBILE_CL
// paddle_mobile.SetCLPath("/data/local/tmp/bin");
//#endif
auto time1 = paddle_mobile::time();
auto isok = paddle_mobile.Load(std::string(g_yolo_mul) + "/model",
std::string(g_yolo_mul) + "/params", true);
// auto isok = paddle_mobile.Load(std::string(g_yolo_mul), true);
if (isok) {
auto time2 = paddle_mobile::time();
std::cout << "load cost :" << paddle_mobile::time_diff(time1, time2) << "ms"
<< std::endl;
std::vector<float> input;
std::vector<int64_t> dims{1, 3, 416, 416};
GetInput<float>(g_yolo_img, &input, dims);
std::vector<float> vec_result = paddle_mobile.Predict(input, dims);
auto time3 = paddle_mobile::time();
int max = 10;
for (int i = 0; i < max; ++i) {
vec_result = paddle_mobile.Predict(input, dims);
}
auto time4 = paddle_mobile::time();
// auto time3 = paddle_mobile::time();
// for (int i = 0; i < 10; ++i) {
// auto vec_result = paddle_mobile.Predict(input, dims);
// }
// auto time4 = paddle_mobile::time();
std::cout << "predict cost :"
<< paddle_mobile::time_diff(time3, time4) / max << "ms"
<< std::endl;
std::vector<float>::iterator biggest =
std::max_element(std::begin(vec_result), std::end(vec_result));
std::cout << " Max element is " << *biggest << " at position "
<< std::distance(std::begin(vec_result), biggest) << std::endl;
// for (float i : vec_result) {
// std::cout << i << std::endl;
// }
}
}
int main() {
// Swap in t2/t3 (or launch them on extra threads) to benchmark other paths.
std::thread th1(t1);
th1.join();
return 0;
}