diff --git a/src/common/common.h b/src/common/common.h index 12157b5e946490d041f0cc0d235142a13a3a2527..c59e6b7932e73bd19b56d4dd081adff8689d5cf3 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -15,6 +15,8 @@ limitations under the License. */ #pragma once #include +namespace paddle_mobile { + using Time = decltype(std::chrono::high_resolution_clock::now()); inline Time time() { return std::chrono::high_resolution_clock::now(); } @@ -25,3 +27,5 @@ inline double time_diff(Time t1, Time t2) { ms counter = std::chrono::duration_cast(diff); return counter.count() / 1000.0; } + +} diff --git a/src/framework/cl/cl_engine.h b/src/framework/cl/cl_engine.h index dc5e8aa60eea33d63df5b024ff0383d2414ce1dc..ee671a1ff276b6597535a0f0bf20b02c46bf5eac 100644 --- a/src/framework/cl/cl_engine.h +++ b/src/framework/cl/cl_engine.h @@ -18,8 +18,8 @@ limitations under the License. */ #include #include "CL/cl.h" -#include "common/log.h" #include "common/enforce.h" +#include "common/log.h" #include "framework/cl/cl_deleter.h" #include "framework/cl/cl_tool.h" diff --git a/src/framework/cl/cl_tool.h b/src/framework/cl/cl_tool.h index 74a20f48185af34c2d509c6e8de23ecab42601cc..25d5bfc584b59e4fe9d22a922b601f8c32892fd1 100644 --- a/src/framework/cl/cl_tool.h +++ b/src/framework/cl/cl_tool.h @@ -21,12 +21,13 @@ namespace framework { const char* opencl_error_to_str(cl_int error); -#define CL_CHECK_ERRORS(ERR) \ - if (ERR != CL_SUCCESS) { \ - printf( \ - "OpenCL error with code %s happened in file %s at line %d. " \ - "Exiting.\n", \ - opencl_error_to_str(ERR), __FILE__, __LINE__); \ +#define CL_CHECK_ERRORS(ERR) \ + if (ERR != CL_SUCCESS) { \ + printf( \ + "OpenCL error with code %s happened in file %s at line %d. " \ + "Exiting.\n", \ + paddle_mobile::framework::opencl_error_to_str(ERR), __FILE__, \ + __LINE__); \ } } // namespace framework diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp index ef76623bfd72b91935a5c3ffc39e8ebf37906b1c..274a6a1d82bc1a8b939226a01d74f7b1172c81cd 100644 --- a/src/framework/executor.cpp +++ b/src/framework/executor.cpp @@ -908,10 +908,14 @@ void Executor::InitMemory() { for (const auto &var_desc : block->Vars()) { auto var = program_.scope->Var(var_desc->Name()); if (var_desc->Persistable()) { - auto cl_image = var->template GetMutable(); + CLImage *cl_image = nullptr; if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { + var->template GetMutable(); continue; + } else { + cl_image = var->template GetMutable(); } + char *origin_data = Get_binary_data(program_.model_path + "/" + var_desc->Name()); char *data = origin_data; @@ -928,7 +932,8 @@ void Executor::InitMemory() { framework::DDim ddim = framework::make_ddim(desc.Dims()); - cl_image->Init(context, tensorInput, ddim); + // has not init + cl_image->SetTensorData(tensorInput, ddim); delete origin_data; // paddle_mobile::memory::Free(tensorInput); @@ -941,7 +946,7 @@ void Executor::InitMemory() { // framework::DDim ddim = framework::make_ddim(desc.Dims()); framework::DDim ddim = cl_image->dims(); DLOG << var_desc->Name(); - cl_image->Init(context, ddim); + cl_image->InitEmptyImage(context, ddim); } } } @@ -965,9 +970,12 @@ void Executor::InitCombineMemory() { for (const auto &var_desc : block->Vars()) { auto var = program_.scope->Var(var_desc->Name()); if (var_desc->Persistable()) { - auto cl_image = var->template GetMutable(); + CLImage *cl_image = nullptr; if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { + var->template GetMutable(); continue; + } else { + cl_image = var->template GetMutable(); } cl_context context = program_.scope->GetCLScpoe()->Context(); @@ -982,7 +990,10 @@ void Executor::InitCombineMemory() { float *tensorInput = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * numel)); LoadMemory(*var_desc, tensorInput, &origin_data); - cl_image->Init(context, tensorInput, ddim); + + // has not init + cl_image->SetTensorData(tensorInput, ddim); + paddle_mobile::memory::Free(tensorInput); } else { auto cl_image = var->template GetMutable(); @@ -991,8 +1002,7 @@ void Executor::InitCombineMemory() { const framework::TensorDesc &desc = var_desc->Tensor_desc(); framework::DDim ddim = cl_image->dims(); // framework::DDim ddim = framework::make_ddim(desc.Dims()); - - cl_image->Init(context, ddim); + cl_image->InitEmptyImage(context, ddim); } } } diff --git a/src/operators/kernel/cl/batchnorm_kernel.cpp b/src/operators/kernel/cl/batchnorm_kernel.cpp index a096fae81d0e3d2b03ee582e85f49c1b84627ae2..8770ce70191197790c4e0b1dfbd4523ef83e5d4c 100644 --- a/src/operators/kernel/cl/batchnorm_kernel.cpp +++ b/src/operators/kernel/cl/batchnorm_kernel.cpp @@ -21,12 +21,67 @@ namespace operators { template <> bool BatchNormKernel::Init(BatchNormParam *param) { + this->cl_helper_.AddKernel("batchnorm", "batchnorm_kernel.cl"); + const framework::CLImage *mean = param->InputMean(); + const framework::CLImage *variance = param->InputVariance(); + const framework::CLImage *scale = param->InputScale(); + const framework::CLImage *bias = param->InputBias(); + const float epsilon = param->Epsilon(); + + auto mean_ptr = mean->data(); + auto variance_ptr = variance->data(); + auto scale_ptr = scale->data(); + auto bias_ptr = bias->data(); + + const int C = mean->numel(); + float inv_std_ptr[C]; + for (int i = 0; i < C; i++) { + inv_std_ptr[i] = + 1 / static_cast(pow((variance_ptr[i] + epsilon), 0.5)); + } + float *new_scale_ptr = new float[C]; + float *new_bias_ptr = new float[C]; + + for (int i = 0; i < C; i++) { + new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i]; + new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i]; + } + + delete[](new_scale_ptr); + delete[](new_bias_ptr); + + framework::CLImage *new_scale = new framework::CLImage(); + framework::CLImage *new_bias = new framework::CLImage(); + + param->SetNewScale(new_scale); + param->SetNewBias(new_bias); + return true; } template <> void BatchNormKernel::Compute( - const BatchNormParam ¶m) {} + const BatchNormParam ¶m) { + auto kernel = this->cl_helper_.KernelAt(0); + auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.OutputY()); + + auto input = param.InputX()->GetCLImage(); + auto out = param.OutputY()->GetCLImage(); + auto new_scale = param.NewScale()->GetCLImage(); + auto new_bias = param.NewBias()->GetCLImage(); + const int out_height = param.OutputY()->HeightOfOneBlock(); + const int out_width = param.OutputY()->WidthOfOneBlock(); + + clSetKernelArg(kernel, 0, sizeof(int), &out_height); + clSetKernelArg(kernel, 1, sizeof(int), &out_width); + clSetKernelArg(kernel, 2, sizeof(cl_mem), &input); + clSetKernelArg(kernel, 3, sizeof(cl_mem), &new_scale); + clSetKernelArg(kernel, 4, sizeof(cl_mem), &new_bias); + clSetKernelArg(kernel, 5, sizeof(cl_mem), &out); + + clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, + default_work_size.data(), NULL, 0, NULL, NULL); +} template class BatchNormKernel; diff --git a/src/operators/kernel/cl/cl_kernel/batchnorm_kernel.cl b/src/operators/kernel/cl/cl_kernel/batchnorm_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..d2cc2151422255f48f81550f7424ec2dccb3be41 --- /dev/null +++ b/src/operators/kernel/cl/cl_kernel/batchnorm_kernel.cl @@ -0,0 +1,24 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +__kernel void batchnorm(__private const int out_height, + __private const int out_width, + __read_only image2d_t input, + __read_only image2d_t new_scale, + __read_only image2d_t new_bias, + __write_only image2d_t output) { + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + half4 new_scale = read_imageh(bn_scale, sampler, (int2)(out_c, 0)); + half4 new_bias = read_imageh(bn_bias, sampler, (int2)(out_c, 0)); + + int pos_x = mad24(out_c, out_width, out_w); + half4 in = read_imageh(input, sampler, (int2)(pos_x, out_nh)); + half4 out = mad(in, new_scale, new_bias); + + write_imageh(output, (int2)(pos_x, nh), out); +} diff --git a/src/operators/kernel/cl/cl_kernel/fetch_kernel.cl b/src/operators/kernel/cl/cl_kernel/fetch_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..10f39f9cf9549a6c1a5abe2af905f94f7355220e --- /dev/null +++ b/src/operators/kernel/cl/cl_kernel/fetch_kernel.cl @@ -0,0 +1,27 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +__kernel void fetch(__private const int in_height, + __private const int in_width, + __private const int size_ch, + __private const int size_block, + __private const int size_batch, + __read_only image2d_t input, + __global float* out) { + const int in_c = get_global_id(0); + const int in_w = get_global_id(1); + const int in_nh = get_global_id(2); + const int in_n = in_nh / in_height; + const int in_h = in_nh % in_height; + + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + const int pos_x = mad24(in_c, in_width, in_w); + half4 in = read_imageh(input, sampler, (int2)(pos_x, in_nh)); + + const int index = in_n * size_batch + in_c * size_block + in_h * in_width + in_w; + out[index] = convert_float(in.x); + out[index + size_ch] = convert_float(in.y); + out[index + size_ch * 2] = convert_float(in.z); + out[index + size_ch * 3] = convert_float(in.w); +} diff --git a/src/operators/kernel/cl/cl_kernel/pool_kernel.cl b/src/operators/kernel/cl/cl_kernel/pool_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..18246fddcfb803adeae5cc9e2efeba1a4362aa2e --- /dev/null +++ b/src/operators/kernel/cl/cl_kernel/pool_kernel.cl @@ -0,0 +1,75 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#define MIN_VALUE -FLT_MAX + +__kernel void pool_max( + __private const int in_height, __private const int in_width, + __private const int out_height, __private const int out_width, + __private const int pad_top, __private const int pad_left, + __private const int stride_h, __private const int stride_w, + __private const int ksize_h, __private const int ksize_w, + __read_only image2d_t input, __write_only image2d_t output) { + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + const int out_n = out_nh / out_height; + const int out_h = out_nh % out_height; + + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + int start_h = max(out_h * stride_h - pad_top, 0); + int end_h = min(start_h + ksize_h, in_height); + + int start_w = max(out_w * stride_w - pad_left, 0); + int end_w = min(start_w + ksize_w, in_width); + + const int pos_in_x = out_c * in_width; + const int pos_in_y = out_n * in_height; + half4 max_value = (half4)(MIN_VALUE); + for (int y = start_h; y < end_h; ++y) { + for (int x = start_w; x < end_w; ++x) { + half4 tmp = read_imageh(input, sampler, (int2)(pos_in_x + x, pos_in_y + y)); + max_value = max(max_value, tmp); + } + } + + const int pos_out_x = mad24(out_c, out_width, out_w); + write_imageh(output, (int2)(pos_out_x, out_nh), max_value); +} + +__kernel void pool_avg( + __private const int in_height, __private const int in_width, + __private const int out_height, __private const int out_width, + __private const int pad_top, __private const int pad_left, + __private const int stride_h, __private const int stride_w, + __private const int ksize_h, __private const int ksize_w, + __read_only image2d_t input, __write_only image2d_t output) { + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + const int out_n = out_nh / out_height; + const int out_h = out_nh % out_height; + + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + int start_h = max(out_h * stride_h - pad_top, 0); + int end_h = min(start_h + ksize_h, in_height); + + int start_w = max(out_w * stride_w - pad_left, 0); + int end_w = min(start_w + ksize_w, in_width); + + const int pos_in_x = out_c * in_width; + const int pos_in_y = out_n * in_height; + half4 sum = (half4)(0.0f); + int num = 0; + for (int y = start_h; y < end_h; ++y) { + for (int x = start_w; x < end_w; ++x) { + sum += read_imageh(input, sampler, (int2)(pos_in_x + x, pos_in_y + y)); + num++; + } + } + half4 avg = sum / num; + const int pos_out_x = mad24(out_c, out_width, out_w); + write_imageh(output, (int2)(pos_out_x, out_nh), avg); +} \ No newline at end of file diff --git a/src/operators/kernel/cl/cl_kernel/relu.cl b/src/operators/kernel/cl/cl_kernel/relu.cl new file mode 100644 index 0000000000000000000000000000000000000000..e773d1c2577461abb35fabfa752ffc272970492b --- /dev/null +++ b/src/operators/kernel/cl/cl_kernel/relu.cl @@ -0,0 +1,25 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +__kernel void relu(__read_only image2d_t input, + __write_only image2d_t output) + const int x = get_global_id(0); + const int y = get_global_id(1); + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + half4 r = read_imageh(input, sampler, int2(x, y)); + r = max(half4(0, 0, 0, 0), r); + write_imageh(output, int2(x, y), r); +} \ No newline at end of file diff --git a/src/operators/kernel/cl/cl_kernel/reshape.cl b/src/operators/kernel/cl/cl_kernel/reshape.cl new file mode 100644 index 0000000000000000000000000000000000000000..4055445d1576b2ca54919ed03ad187d08cff14c2 --- /dev/null +++ b/src/operators/kernel/cl/cl_kernel/reshape.cl @@ -0,0 +1,49 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +__kernel void reshape(__read_only image2d_t input, + __write_only image2d_t output, + __private const int d0, + __private const int d1, + __private const int d2, + __private const int d3, + __private const int x0, + __private const int x1, + __private const int x2, + __private const int x3) { + const int x = get_global_id(0); + const int y = get_global_id(1); + int obx = x / x3; + int oby = y / x2; + int ox = x % x3; + int oy = y % x2; + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + half4 r; + for (int i = 0; i < 4; i++) { + int t = obx * 4 + i; + if (t > x1) break; + int oindex = oby * x1 * x2 * x3 + t * x2 * x3 + ox * x3 + oy; + int i0, i1, i2, i3; + int i3 = oindex % d3; oindex /= d3; + int i2 = oindex % d2; oindex /= d2; + int i1 = oindex % d1; oindex /= d1; + int i0 = oindex; + int ix = (i1 / 4) * d3 + i3; + int iy = i0 * d2 + i2; + r[i] = read_imageh(input, sampler, int2(ix, iy))[i1%4]; + } + write_imageh(output, int2(x, y), r); +} \ No newline at end of file diff --git a/src/operators/kernel/cl/cl_kernel/softmax.cl b/src/operators/kernel/cl/cl_kernel/softmax.cl new file mode 100644 index 0000000000000000000000000000000000000000..60f0cf409596632b67817cd236f9621010522571 --- /dev/null +++ b/src/operators/kernel/cl/cl_kernel/softmax.cl @@ -0,0 +1,41 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +__kernel void softmax(__read_only image2d_t input, + __write_only image2d_t output, + __private const int d0, + __private const int d1, + __private const int d2, + __private const int d3) { + const int z = get_global_id(0); + const int x = get_global_id(1); + const int y = get_global_id(2); + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + half4 maxv = read_imageh(input, sampler, int2(z * d3, y)); + half4 buf[d3] = {piece}; + for (int i = 1; i < d3; i++) { + buf[i] = read_imageh(input, sampler, int2(z * d3 + i, y)); + maxv = max(maxv, buf[i]); + } + float4 sum = 0; + for (int i = 0; i < d3; i++) { + buf[i] = exp(buf[i] - maxv); + sum += buf[i]; + } + half4 r = buf[x] / sum; + + write_imageh(output, int2(z * d3 + x, y), r); +} diff --git a/src/operators/kernel/cl/conv_add_bn_relu_kernel.cpp b/src/operators/kernel/cl/conv_add_bn_relu_kernel.cpp index fd846be8024bd2742f6825f08993f17dfcd3509a..29b13b6abc3943ff23ce3f5a98962ecb4d9c2d7a 100644 --- a/src/operators/kernel/cl/conv_add_bn_relu_kernel.cpp +++ b/src/operators/kernel/cl/conv_add_bn_relu_kernel.cpp @@ -16,6 +16,7 @@ limitations under the License. */ #include "operators/kernel/conv_add_bn_relu_kernel.h" #include "framework/cl/cl_image.h" +#include "framework/cl/cl_tool.h" namespace paddle_mobile { namespace operators { @@ -56,15 +57,15 @@ bool ConvAddBNReluKernel::Init( framework::CLImage *new_scale = new framework::CLImage(); - new_scale->Init(this->cl_helper_.CLContext(), new_scale_ptr, - variance->dims()); + new_scale->SetTensorData(new_scale_ptr, variance->dims()); + new_scale->InitCLImage(this->cl_helper_.CLContext()); framework::CLImage *new_bias = new framework::CLImage(); - new_bias->Init(this->cl_helper_.CLContext(), new_bias_ptr, variance->dims()); + new_bias->SetTensorData(new_bias_ptr, variance->dims()); + new_bias->InitCLImage(this->cl_helper_.CLContext()); param->SetNewScale(new_scale); - param->SetNewBias(new_bias); PADDLE_MOBILE_ENFORCE( @@ -115,26 +116,32 @@ void ConvAddBNReluKernel::Compute( int output_width = param.Output()->WidthOfOneBlock(); int output_height = param.Output()->HeightOfOneBlock(); - clSetKernelArg(kernel, 0, sizeof(int), &c_block); - clSetKernelArg(kernel, 1, sizeof(int), &w); - clSetKernelArg(kernel, 2, sizeof(int), &nh); - clSetKernelArg(kernel, 3, sizeof(cl_mem), &input); - clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter); - clSetKernelArg(kernel, 5, sizeof(cl_mem), &biase); - clSetKernelArg(kernel, 6, sizeof(cl_mem), &new_scale); - clSetKernelArg(kernel, 7, sizeof(cl_mem), &new_bias); - clSetKernelArg(kernel, 8, sizeof(cl_mem), &output); - clSetKernelArg(kernel, 9, sizeof(int), &stride); - clSetKernelArg(kernel, 10, sizeof(int), &offset); - clSetKernelArg(kernel, 11, sizeof(int), &input_c); - clSetKernelArg(kernel, 12, sizeof(int), &dilation); - clSetKernelArg(kernel, 13, sizeof(int), &input_width); - clSetKernelArg(kernel, 14, sizeof(int), &input_height); - clSetKernelArg(kernel, 15, sizeof(int), &output_width); - clSetKernelArg(kernel, 16, sizeof(int), &output_height); - - clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, - default_work_size.data(), NULL, 0, NULL, NULL); + cl_int status; + + status = clSetKernelArg(kernel, 0, sizeof(int), &c_block); + status = clSetKernelArg(kernel, 1, sizeof(int), &w); + status = clSetKernelArg(kernel, 2, sizeof(int), &nh); + status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &input); + status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter); + status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &biase); + status = clSetKernelArg(kernel, 6, sizeof(cl_mem), &new_scale); + status = clSetKernelArg(kernel, 7, sizeof(cl_mem), &new_bias); + status = clSetKernelArg(kernel, 8, sizeof(cl_mem), &output); + status = clSetKernelArg(kernel, 9, sizeof(int), &stride); + status = clSetKernelArg(kernel, 10, sizeof(int), &offset); + status = clSetKernelArg(kernel, 11, sizeof(int), &input_c); + status = clSetKernelArg(kernel, 12, sizeof(int), &dilation); + status = clSetKernelArg(kernel, 13, sizeof(int), &input_width); + status = clSetKernelArg(kernel, 14, sizeof(int), &input_height); + status = clSetKernelArg(kernel, 15, sizeof(int), &output_width); + status = clSetKernelArg(kernel, 16, sizeof(int), &output_height); + + CL_CHECK_ERRORS(status); + + status = + clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, + default_work_size.data(), NULL, 0, NULL, NULL); + CL_CHECK_ERRORS(status); } template class ConvAddBNReluKernel; diff --git a/src/operators/kernel/cl/conv_add_kernel.cpp b/src/operators/kernel/cl/conv_add_kernel.cpp index 696ae01bcc24f180ff26d10c13ecc81da51bb10e..b5fd82c47a9af4fc383cfe276d08dfb365b6bff3 100644 --- a/src/operators/kernel/cl/conv_add_kernel.cpp +++ b/src/operators/kernel/cl/conv_add_kernel.cpp @@ -65,24 +65,31 @@ void ConvAddKernel::Compute( int output_width = param.Output()->WidthOfOneBlock(); int output_height = param.Output()->HeightOfOneBlock(); - clSetKernelArg(kernel, 0, sizeof(int), &c_block); - clSetKernelArg(kernel, 1, sizeof(int), &w); - clSetKernelArg(kernel, 2, sizeof(int), &nh); - clSetKernelArg(kernel, 3, sizeof(cl_mem), &input); - clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter); - clSetKernelArg(kernel, 5, sizeof(cl_mem), &biase); - clSetKernelArg(kernel, 6, sizeof(cl_mem), &output); - clSetKernelArg(kernel, 7, sizeof(int), &stride); - clSetKernelArg(kernel, 8, sizeof(int), &offset); - clSetKernelArg(kernel, 9, sizeof(int), &input_c); - clSetKernelArg(kernel, 10, sizeof(int), &dilation); - clSetKernelArg(kernel, 11, sizeof(int), &input_width); - clSetKernelArg(kernel, 12, sizeof(int), &input_height); - clSetKernelArg(kernel, 13, sizeof(int), &output_width); - clSetKernelArg(kernel, 14, sizeof(int), &output_height); - - clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, - default_work_size.data(), NULL, 0, NULL, NULL); + cl_int status; + + status = clSetKernelArg(kernel, 0, sizeof(int), &c_block); + status = clSetKernelArg(kernel, 1, sizeof(int), &w); + status = clSetKernelArg(kernel, 2, sizeof(int), &nh); + status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &input); + status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter); + status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &biase); + status = clSetKernelArg(kernel, 6, sizeof(cl_mem), &output); + status = clSetKernelArg(kernel, 7, sizeof(int), &stride); + status = clSetKernelArg(kernel, 8, sizeof(int), &offset); + status = clSetKernelArg(kernel, 9, sizeof(int), &input_c); + status = clSetKernelArg(kernel, 10, sizeof(int), &dilation); + status = clSetKernelArg(kernel, 11, sizeof(int), &input_width); + status = clSetKernelArg(kernel, 12, sizeof(int), &input_height); + status = clSetKernelArg(kernel, 13, sizeof(int), &output_width); + status = clSetKernelArg(kernel, 14, sizeof(int), &output_height); + + CL_CHECK_ERRORS(status); + + status = + clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, + default_work_size.data(), NULL, 0, NULL, NULL); + + CL_CHECK_ERRORS(status); } template class ConvAddKernel; diff --git a/src/operators/kernel/cl/conv_kernel.cpp b/src/operators/kernel/cl/conv_kernel.cpp index ee7b56629a58eb616453c1e9a4065b2cb8cf3d8f..d31553b60ef2827e9e818443e49a4be148305cf4 100644 --- a/src/operators/kernel/cl/conv_kernel.cpp +++ b/src/operators/kernel/cl/conv_kernel.cpp @@ -21,63 +21,69 @@ namespace operators { template <> bool ConvKernel::Init(ConvParam *param) { - // PADDLE_MOBILE_ENFORCE( - // param->Filter()->dims()[2] == param->Filter()->dims()[3] && - // param->Paddings()[0] == param->Paddings()[1], - // "need equal"); - // int offset = static_cast(param->Filter()->dims()[2]) / 2 - - // static_cast(param->Paddings()[1]); - // param->SetOffset(offset); - // - // if (param->Filter()->WidthOfOneBlock() == 1 && - // param->Filter()->HeightOfOneBlock() == 1) { - // this->cl_helper_.AddKernel("conv_1x1", "conv_add_bn_relu_kernel.cl"); - // } else if (param->Filter()->dims()[1] == 1) { - // this->cl_helper_.AddKernel("depth_conv_3x3", - // "conv_add_bn_relu_kernel.cl"); - // } else if (param->Filter()->WidthOfOneBlock() == 3 && - // param->Filter()->HeightOfOneBlock() == 3) { - // this->cl_helper_.AddKernel("conv_3x3", "conv_add_bn_relu_kernel.cl"); - // } else { - // PADDLE_MOBILE_THROW_EXCEPTION(" not support "); - // } + PADDLE_MOBILE_ENFORCE( + param->Filter()->dims()[2] == param->Filter()->dims()[3] && + param->Paddings()[0] == param->Paddings()[1], + "need equal"); + + int offset = static_cast(param->Filter()->dims()[2]) / 2 - + static_cast(param->Paddings()[1]); + param->SetOffset(offset); + + if (param->Filter()->WidthOfOneBlock() == 1 && + param->Filter()->HeightOfOneBlock() == 1) { + this->cl_helper_.AddKernel("conv_1x1", "conv_add_bn_relu_kernel.cl"); + } else if (param->Filter()->dims()[1] == 1) { + this->cl_helper_.AddKernel("depth_conv_3x3", "conv_add_bn_relu_kernel.cl"); + } else if (param->Filter()->WidthOfOneBlock() == 3 && + param->Filter()->HeightOfOneBlock() == 3) { + this->cl_helper_.AddKernel("conv_3x3", "conv_add_bn_relu_kernel.cl"); + } else { + PADDLE_MOBILE_THROW_EXCEPTION(" not support "); + } return true; } template <> void ConvKernel::Compute(const ConvParam ¶m) { - // auto kernel = this->cl_helper_.KernelAt(0); - // auto default_work_size = - // this->cl_helper_.DefaultWorkSize(*param.Output()); int c_block = - // default_work_size[0]; int w = default_work_size[1]; int nh = - // default_work_size[2]; auto input = param.Input()->GetCLImage(); auto - // filter = param.Filter()->GetCLImage(); auto output = param.Output(); int - // stride = param.Strides()[0]; int offset = param.Offset(); int input_c = - // param.Input()->CBlock(); int dilation = param.Dilations()[0]; int - // input_width = param.Input()->WidthOfOneBlock(); int input_height = - // param.Input()->HeightOfOneBlock(); - // - // clSetKernelArg(kernel, 0, sizeof(int), &c_block); - // clSetKernelArg(kernel, 1, sizeof(int), &w); - // clSetKernelArg(kernel, 2, sizeof(int), &nh); - // clSetKernelArg(kernel, 3, sizeof(cl_mem), &input); - // clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter); - // clSetKernelArg(kernel, 5, sizeof(cl_mem), &output); - // clSetKernelArg(kernel, 6, sizeof(int), &stride); - // clSetKernelArg(kernel, 7, sizeof(int), &offset); - // clSetKernelArg(kernel, 8, sizeof(int), &input_c); - // clSetKernelArg(kernel, 9, sizeof(int), &dilation); - // clSetKernelArg(kernel, 10, sizeof(int), &input_width); - // clSetKernelArg(kernel, 11, sizeof(int), &input_height); - // - // clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, - // default_work_size.data(), NULL, 0, NULL, NULL); - - // auto kernel = this->cl_helper_.KernelAt(0); - // size_t global_work_size[3] = {1, 2, 3}; - // clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, - // global_work_size, NULL, 0, NULL, NULL); + auto kernel = this->cl_helper_.KernelAt(0); + auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Output()); + int c_block = default_work_size[0]; + int w = default_work_size[1]; + int nh = default_work_size[2]; + auto input = param.Input()->GetCLImage(); + auto filter = param.Filter()->GetCLImage(); + auto output = param.Output(); + int stride = param.Strides()[0]; + int offset = param.Offset(); + int input_c = param.Input()->CBlock(); + int dilation = param.Dilations()[0]; + int input_width = param.Input()->WidthOfOneBlock(); + int input_height = param.Input()->HeightOfOneBlock(); + + cl_int status; + + status = clSetKernelArg(kernel, 0, sizeof(int), &c_block); + status = clSetKernelArg(kernel, 1, sizeof(int), &w); + status = clSetKernelArg(kernel, 2, sizeof(int), &nh); + status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &input); + status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter); + status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &output); + status = clSetKernelArg(kernel, 6, sizeof(int), &stride); + status = clSetKernelArg(kernel, 7, sizeof(int), &offset); + status = clSetKernelArg(kernel, 8, sizeof(int), &input_c); + status = clSetKernelArg(kernel, 9, sizeof(int), &dilation); + status = clSetKernelArg(kernel, 10, sizeof(int), &input_width); + status = clSetKernelArg(kernel, 11, sizeof(int), &input_height); + + CL_CHECK_ERRORS(status); + + status = + clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, + default_work_size.data(), NULL, 0, NULL, NULL); + + CL_CHECK_ERRORS(status); } template class ConvKernel; diff --git a/src/operators/kernel/cl/depthwise_conv_kernel.cpp b/src/operators/kernel/cl/depthwise_conv_kernel.cpp index f292dff7a7dea2795108d74663750a51b6ba877b..99b5a714d6e2d5d6951b800f8f2cdc56c9241e79 100644 --- a/src/operators/kernel/cl/depthwise_conv_kernel.cpp +++ b/src/operators/kernel/cl/depthwise_conv_kernel.cpp @@ -55,23 +55,30 @@ void DepthwiseConvKernel::Compute( int output_width = param.Output()->WidthOfOneBlock(); int output_height = param.Output()->HeightOfOneBlock(); - clSetKernelArg(kernel, 0, sizeof(int), &c_block); - clSetKernelArg(kernel, 1, sizeof(int), &w); - clSetKernelArg(kernel, 2, sizeof(int), &nh); - clSetKernelArg(kernel, 3, sizeof(cl_mem), &input); - clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter); - clSetKernelArg(kernel, 5, sizeof(cl_mem), &output); - clSetKernelArg(kernel, 6, sizeof(int), &stride); - clSetKernelArg(kernel, 7, sizeof(int), &offset); - clSetKernelArg(kernel, 8, sizeof(int), &input_c); - clSetKernelArg(kernel, 9, sizeof(int), &dilation); - clSetKernelArg(kernel, 10, sizeof(int), &input_width); - clSetKernelArg(kernel, 11, sizeof(int), &input_height); - clSetKernelArg(kernel, 12, sizeof(int), &output_width); - clSetKernelArg(kernel, 13, sizeof(int), &output_height); - - clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, - default_work_size.data(), NULL, 0, NULL, NULL); + cl_int status; + + status = clSetKernelArg(kernel, 0, sizeof(int), &c_block); + status = clSetKernelArg(kernel, 1, sizeof(int), &w); + status = clSetKernelArg(kernel, 2, sizeof(int), &nh); + status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &input); + status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter); + status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &output); + status = clSetKernelArg(kernel, 6, sizeof(int), &stride); + status = clSetKernelArg(kernel, 7, sizeof(int), &offset); + status = clSetKernelArg(kernel, 8, sizeof(int), &input_c); + status = clSetKernelArg(kernel, 9, sizeof(int), &dilation); + status = clSetKernelArg(kernel, 10, sizeof(int), &input_width); + status = clSetKernelArg(kernel, 11, sizeof(int), &input_height); + status = clSetKernelArg(kernel, 12, sizeof(int), &output_width); + status = clSetKernelArg(kernel, 13, sizeof(int), &output_height); + + CL_CHECK_ERRORS(status); + + status = + clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, + default_work_size.data(), NULL, 0, NULL, NULL); + + CL_CHECK_ERRORS(status); } template class DepthwiseConvKernel; diff --git a/src/operators/kernel/cl/feed_kernel.cpp b/src/operators/kernel/cl/feed_kernel.cpp index 4cbdd6e5d1e872b2d98b0653f3172800deb8c9cb..7eac3e8fbee5462e1158d0c942e5300d597e780f 100644 --- a/src/operators/kernel/cl/feed_kernel.cpp +++ b/src/operators/kernel/cl/feed_kernel.cpp @@ -27,6 +27,7 @@ bool FeedKernel::Init(FeedParam *param) { template <> void FeedKernel::Compute(const FeedParam ¶m) { +<<<<<<< HEAD auto kernel = this->cl_helper_.KernelAt(0); cl_int status; auto output = param.Out(); @@ -38,6 +39,19 @@ void FeedKernel::Compute(const FeedParam ¶m) { int height = output->dims()[2]; int width = output->dims()[3]; DLOG << output->dims(); +======= + DLOG << "feed_kernel"; + auto kernel = this->cl_helper_.KernelAt(0); + cl_int status; + auto output = param.Out(); + auto input = param.InputX(); + DLOG << " input: " << input; + + const float *input_data = input->data(); + cl_mem cl_image = output->GetCLImage(); + int height = output->dims()[2]; + int width = output->dims()[3]; +>>>>>>> df230944d11f0f09aea4c2c6bc0489d8667fa8ca status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &input_data); status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_image); status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &width); diff --git a/src/operators/kernel/cl/fetch_kernel.cpp b/src/operators/kernel/cl/fetch_kernel.cpp index d10bfe7a4bd64c8eb0aaa6ae85f531d3d3dce169..995713ce5afaf0a93bc6b8ddd9928d7cee1c55ff 100644 --- a/src/operators/kernel/cl/fetch_kernel.cpp +++ b/src/operators/kernel/cl/fetch_kernel.cpp @@ -19,11 +19,45 @@ namespace operators { template <> bool FetchKernel::Init(FetchParam *param) { + this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl"); return true; } template <> -void FetchKernel::Compute(const FetchParam ¶m) {} +void FetchKernel::Compute(const FetchParam ¶m) { + auto kernel = this->cl_helper_.KernelAt(0); + auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.InputX()); + + auto input = param.InputX()->GetCLImage(); + auto *out = param.Out(); + + const auto &dims = param.InputX()->dims(); + const int N = dims[0]; + const int C = dims[1]; + const int in_height = dims[2]; + const int in_width = dims[3]; + + int size_ch = in_height * in_width; + int size_block = size_ch * 4; + int size_batch = size_ch * C; + + // need create outputBuffer + cl_image_format imageFormat; + imageFormat.image_channel_order = CL_RGBA; + imageFormat.image_channel_data_type = CL_FLOAT; + cl_mem outputBuffer; + + clSetKernelArg(kernel, 0, sizeof(int), &in_height); + clSetKernelArg(kernel, 1, sizeof(int), &in_width); + clSetKernelArg(kernel, 2, sizeof(int), &size_ch); + clSetKernelArg(kernel, 3, sizeof(int), &size_block); + clSetKernelArg(kernel, 4, sizeof(int), &size_batch); + clSetKernelArg(kernel, 5, sizeof(cl_mem), &input); + clSetKernelArg(kernel, 6, sizeof(cl_mem), &outputBuffer); + + clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, + default_work_size.data(), NULL, 0, NULL, NULL); +} template class FetchKernel; diff --git a/src/operators/kernel/cl/pool_kernel.cpp b/src/operators/kernel/cl/pool_kernel.cpp index c24a1babf106afe07e3b3dd30727ed1419af5bf8..802de26e6147aa0bf5d9467c6c6cab0f0148fe59 100644 --- a/src/operators/kernel/cl/pool_kernel.cpp +++ b/src/operators/kernel/cl/pool_kernel.cpp @@ -21,11 +21,51 @@ namespace operators { template <> bool PoolKernel::Init(PoolParam *param) { + std::string pooling_type = param->PoolingType(); + this->cl_helper_.AddKernel("pool_" + pooling_type, "pool_kernel.cl"); return true; } template <> -void PoolKernel::Compute(const PoolParam ¶m) {} +void PoolKernel::Compute(const PoolParam ¶m) { + auto kernel = this->cl_helper_.KernelAt(0); + auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Output()); + + auto input = param.Input()->GetCLImage(); + auto out = param.Output()->GetCLImage(); + + const int in_height = param.Input()->HeightOfOneBlock(); + const int in_width = param.Input()->WidthOfOneBlock(); + const int out_height = param.Output()->HeightOfOneBlock(); + const int out_width = param.Output()->WidthOfOneBlock(); + + std::string pooling_type = param.PoolingType(); + std::vector ksize = param.Ksize(); + std::vector strides = param.Strides(); + std::vector paddings = param.Paddings(); + const int pad_top = paddings[0]; + const int pad_left = paddings[1]; + const int stride_h = strides[0]; + const int stride_w = strides[1]; + const int ksize_h = ksize[0]; + const int ksize_w = ksize[1]; + + clSetKernelArg(kernel, 0, sizeof(cl_int), &in_height); + clSetKernelArg(kernel, 1, sizeof(cl_int), &in_width); + clSetKernelArg(kernel, 2, sizeof(cl_int), &out_height); + clSetKernelArg(kernel, 3, sizeof(cl_int), &out_width); + clSetKernelArg(kernel, 4, sizeof(cl_int), &pad_top); + clSetKernelArg(kernel, 5, sizeof(cl_int), &pad_left); + clSetKernelArg(kernel, 6, sizeof(cl_int), &stride_h); + clSetKernelArg(kernel, 7, sizeof(cl_int), &stride_w); + clSetKernelArg(kernel, 8, sizeof(cl_int), &ksize_h); + clSetKernelArg(kernel, 9, sizeof(cl_int), &ksize_w); + clSetKernelArg(kernel, 10, sizeof(cl_mem), &input); + clSetKernelArg(kernel, 11, sizeof(cl_mem), &out); + + clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, + default_work_size.data(), NULL, 0, NULL, NULL); +} template class PoolKernel; diff --git a/src/operators/kernel/cl/relu_kernel.cpp b/src/operators/kernel/cl/relu_kernel.cpp index f38c29f1827cd61b18a0dd59773e63169a4445a7..71304b9c307f36f7a3db754a7a41958e206f77cd 100644 --- a/src/operators/kernel/cl/relu_kernel.cpp +++ b/src/operators/kernel/cl/relu_kernel.cpp @@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#ifdef RELU_OP #include "operators/kernel/relu_kernel.h" @@ -18,14 +19,28 @@ namespace paddle_mobile { namespace operators { template <> -bool ReluKernel::Init(ReluParam *param) { +bool ReluKernel::Init(ReluParam* param) { + this->cl_helper_.AddKernel("relu", "relu.cl"); return true; } template <> -void ReluKernel::Compute(const ReluParam ¶m) {} +void ReluKernel::Compute(const ReluParam& param) { + auto kernel = this->cl_helper_.KernelAt(0); + const auto* input = param.InputX(); + auto* output = param.Out(); + auto default_work_size = this->cl_helper_.DefaultWorkSize(*output); + auto inputImage = input->GetCLImage(); + auto outputImage = output->GetCLImage(); + clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputImage); + clSetKernelArg(kernel, 1, sizeof(cl_mem), &outputImage); + const size_t work_size[2] = {input->ImageWidth(), input->ImageHeight()}; + clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, + work_size, NULL, 0, NULL, NULL); +} template class ReluKernel; } // namespace operators } // namespace paddle_mobile +#endif diff --git a/src/operators/kernel/cl/reshape_kernel.cpp b/src/operators/kernel/cl/reshape_kernel.cpp index bc6eb2834b3cff1720ddb7ffb8b4272cf8abbbeb..877a325636a223913ebd10eb5724509719ac5717 100644 --- a/src/operators/kernel/cl/reshape_kernel.cpp +++ b/src/operators/kernel/cl/reshape_kernel.cpp @@ -19,11 +19,36 @@ namespace operators { template <> bool ReshapeKernel::Init(ReshapeParam *param) { + this->cl_helper_.AddKernel("reshape", "reshape.cl"); return true; } template <> -void ReshapeKernel::Compute(const ReshapeParam ¶m) {} +void ReshapeKernel::Compute(const ReshapeParam ¶m) { + auto kernel = this->cl_helper_.KernelAt(0); + const auto *input = param.InputX(); + auto *output = param.Out(); + auto inputImage = input->GetCLImage(); + auto outputImage = output->GetCLImage(); + clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputImage); + clSetKernelArg(kernel, 1, sizeof(cl_mem), &outputImage); + const auto &inputDim = input->dims(); + const auto &outputDim = output->dims(); + int dims[4] = {inputDim[0], inputDim[1], inputDim[2], inputDim[3]}; + int odims[4] = {outputDim[0], outputDim[1], outputDim[2], outputDim[3]}; + clSetKernelArg(kernel, 2, sizeof(int), dims); + clSetKernelArg(kernel, 3, sizeof(int), dims + 1); + clSetKernelArg(kernel, 4, sizeof(int), dims + 2); + clSetKernelArg(kernel, 5, sizeof(int), dims + 3); + clSetKernelArg(kernel, 6, sizeof(int), odims); + clSetKernelArg(kernel, 7, sizeof(int), odims + 1); + clSetKernelArg(kernel, 8, sizeof(int), odims + 2); + clSetKernelArg(kernel, 9, sizeof(int), odims + 3); + const size_t work_size[2] = {output->ImageWidth(), output->ImageHeight()}; + + clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2, NULL, + work_size, NULL, 0, NULL, NULL); +} template class ReshapeKernel; diff --git a/src/operators/kernel/cl/softmax_kernel.cpp b/src/operators/kernel/cl/softmax_kernel.cpp index d0a97cf076c5fe22c7b2612629616053c63dec6c..1404ea40c703c8da2db09551fc6da440771f7366 100644 --- a/src/operators/kernel/cl/softmax_kernel.cpp +++ b/src/operators/kernel/cl/softmax_kernel.cpp @@ -21,11 +21,30 @@ namespace operators { template <> bool SoftmaxKernel::Init(SoftmaxParam *param) { + this->cl_helper_.AddKernel("softmax", "softmax.cl"); return true; } template <> -void SoftmaxKernel::Compute(const SoftmaxParam ¶m) {} +void SoftmaxKernel::Compute(const SoftmaxParam ¶m) { + auto kernel = this->cl_helper_.KernelAt(0); + auto default_work_size = this->cl_helper_.DefaultWorkSize(*(param.Out())); + const auto *input = param.InputX(); + auto *output = param.Out(); + auto inputImage = input->GetCLImage(); + auto outputImage = output->GetCLImage(); + clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputImage); + clSetKernelArg(kernel, 1, sizeof(cl_mem), &outputImage); + const auto &inputDim = input->dims(); + int dims[4] = {inputDim[0], inputDim[1], inputDim[2], inputDim[3]}; + clSetKernelArg(kernel, 2, sizeof(int), dims); + clSetKernelArg(kernel, 3, sizeof(int), dims + 1); + clSetKernelArg(kernel, 4, sizeof(int), dims + 2); + clSetKernelArg(kernel, 5, sizeof(int), dims + 3); + + clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, + default_work_size.data(), NULL, 0, NULL, NULL); +} template class SoftmaxKernel; diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 689eec0da950fbc8a1e7892c0740bff1790fd1ab..85587e45e95f8e9e9f5e0dc819ee67438a8e4b4a 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -614,6 +614,14 @@ class BatchNormParam : OpParam { const string &DataFormat() const { return data_format_; } + void SetNewScale(RType *new_scale) { new_scale_ = new_scale; } + + void SetNewBias(RType *new_bias) { new_bias_ = new_bias; } + + const RType *NewScale() const { return new_scale_; } + + const RType *NewBias() const { return new_bias_; } + private: RType *input_x_; RType *output_y_; @@ -625,6 +633,8 @@ class BatchNormParam : OpParam { float momentum_; bool is_test_; string data_format_; + RType *new_bias_; + RType *new_scale_; }; #endif @@ -936,10 +946,21 @@ class FetchParam : public OpParam { FetchParam(const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs, const Scope &scope) { input_x_ = InputXFrom(inputs, scope); +<<<<<<< HEAD out_ = OutFrom(outputs, scope); } const RType *InputX() const { return input_x_; } Tensor *Out() const { return out_; } +======= + out_ = OutFrom(outputs, scope); + } + const RType *InputX() const { return input_x_; } + Tensor *Out() const { return out_; } + + static Tensor *OutFrom(const VariableNameMap &outputs, const Scope &scope) { + return GetVarValue("Out", outputs, scope); + } +>>>>>>> df230944d11f0f09aea4c2c6bc0489d8667fa8ca private: RType *input_x_; diff --git a/test/net/test_googlenet.cpp b/test/net/test_googlenet.cpp index a2f030eeac5c2584b33fad2b082b9d5513707260..9e826d3a747c5207f81baa5973d7da6aabc2103f 100644 --- a/test/net/test_googlenet.cpp +++ b/test/net/test_googlenet.cpp @@ -29,8 +29,8 @@ int main() { bool optimize = true; auto time1 = time(); if (paddle_mobile.Load(g_googlenet, optimize)) { - auto time2 = time(); - std::cout << "load cost :" << time_diff(time1, time2) << "ms" << std::endl; + auto time2 = paddle_mobile::time(); + std::cout << "load cost :" << paddle_mobile::time_diff(time1, time2) << "ms" << std::endl; std::vector input; std::vector dims{1, 3, 224, 224}; GetInput(g_test_image_1x3x224x224, &input, dims); diff --git a/test/net/test_mobilenet_GPU.cpp b/test/net/test_mobilenet_GPU.cpp index f0994855faed337bf2e2e557c10108e053ea7e71..f65e1890f362e4c8e4ebeba9c5a59d79da1b2791 100644 --- a/test/net/test_mobilenet_GPU.cpp +++ b/test/net/test_mobilenet_GPU.cpp @@ -19,14 +19,14 @@ limitations under the License. */ int main() { paddle_mobile::PaddleMobile paddle_mobile; // paddle_mobile.SetThreadNum(4); - auto time1 = time(); + auto time1 = paddle_mobile::time(); // auto isok = paddle_mobile.Load(std::string(g_mobilenet_detect) + "/model", // std::string(g_mobilenet_detect) + "/params", true); auto isok = paddle_mobile.Load(g_mobilenet, false); if (isok) { - auto time2 = time(); - std::cout << "load cost :" << time_diff(time1, time1) << "ms" << std::endl; + auto time2 = paddle_mobile::time(); + std::cout << "load cost :" << paddle_mobile::time_diff(time1, time1) << "ms" << std::endl; std::vector input; std::vector dims{1, 3, 224, 224}; @@ -42,13 +42,13 @@ int main() { for (int i = 0; i < 10; ++i) { auto vec_result = paddle_mobile.Predict(input, dims); } - auto time3 = time(); + auto time3 = paddle_mobile::time(); for (int i = 0; i < 10; ++i) { auto vec_result = paddle_mobile.Predict(input, dims); } DLOG << vec_result; - auto time4 = time(); - std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms" + auto time4 = paddle_mobile::time(); + std::cout << "predict cost :" << paddle_mobile::time_diff(time3, time4) / 10 << "ms" << std::endl; } diff --git a/tools/web-exporter/CMakeLists.txt b/tools/web-exporter/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..e9cddecd6794fd047a2d0e79719373adbe7f5959 --- /dev/null +++ b/tools/web-exporter/CMakeLists.txt @@ -0,0 +1,20 @@ +cmake_minimum_required(VERSION 3.6) + +project(web-exporter) + +set(CMAKE_CXX_STANDARD 11) + +file(GLOB PADDLE_MOBILE_CPP_FILES + "../../src/common/*.c" + "../../src/common/*.cpp" + "../../src/memory/*.cpp" + "../../src/framework/*.c" + "../../src/framework/*.cpp" + "../../src/framework/program/*.cpp" + "../../src/framework/program/program-optimize/*.cpp" +) +file(GLOB EXPORT_CPP_FILES "*.cpp") + +add_executable(web-exporter ${PADDLE_MOBILE_CPP_FILES} ${EXPORT_CPP_FILES}) +target_include_directories(web-exporter PRIVATE "../../src") +target_link_libraries(web-exporter) \ No newline at end of file diff --git a/tools/web-exporter/export-nodejs.cpp b/tools/web-exporter/export-nodejs.cpp new file mode 100644 index 0000000000000000000000000000000000000000..023d9e5874e5871cdeb4e2b568c63c69436dee6e --- /dev/null +++ b/tools/web-exporter/export-nodejs.cpp @@ -0,0 +1,49 @@ +#include "export.h" + +inline std::string indent(int i) { + return std::string(i, ' '); +} +void export_nodejs(ProgramPtr program, ScopePtr scope, std::ostream & os) { + os << "module.exports.program = {\n"; + os << indent(2) << var2str("blocks") << ": [\n"; + for (const auto& block: program->Blocks()) { + os << indent(4) << "{\n"; + os << indent(6) << var2str("vars") << ": {\n"; + for (const auto& var: block->Vars()) { + const auto& dim = var->Tensor_desc().Dims(); + os << indent(8) << var2str(var->Name()) << ": {\n"; + os << indent(10) << var2str("dim") << ": " << var2str(dim) << ",\n"; + os << indent(10) << var2str("persistable") << ": " << var2str(var->Persistable()) << "\n"; + os << indent(8) << "},\n"; + } + os << indent(6) << "},\n"; + os << indent(6) << var2str("ops") << ": [\n"; + for (const auto& op: block->Ops()) { + os << indent(8) << "{\n"; + os << indent(10) << var2str("type") << ": " << var2str(op->Type()) << ",\n"; + os << indent(10) << var2str("inputs") << ": {\n"; + for (const auto& kv: op->GetInputs()) { + os << indent(12) << var2str(kv.first) << ": " << var2str(kv.second) << ",\n"; + } + os << indent(10) << "},\n"; + + os << indent(10) << var2str("outputs") << ": {\n"; + for (const auto& kv: op->GetInputs()) { + os << indent(12) << var2str(kv.first) << ": " << var2str(kv.second) << ",\n"; + } + os << indent(10) << "},\n"; + + os << indent(10) << var2str("attrs") << ": {\n"; + for (const auto& kv: op->GetAttrMap()) { + os << indent(12) << var2str(kv.first) << ": "; + os << decltype(kv.second)::ApplyVistor(VarVisitor(), kv.second) << ",\n"; + } + os << indent(10) << "},\n"; + os << indent(8) << "},\n"; + } + os << indent(6) << "],\n"; + os << indent(4) << "},\n"; + } + os << indent(2) << "]\n"; + os << "}\n"; +} diff --git a/tools/web-exporter/export-scope.cpp b/tools/web-exporter/export-scope.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d5c2492ac74129fce25eb56dc9fc870e66f2dccc --- /dev/null +++ b/tools/web-exporter/export-scope.cpp @@ -0,0 +1,34 @@ +#include +#include "export.h" + +void export_scope(ProgramPtr program, ScopePtr scope, const std::string & dirname) { + for (const auto& block: program->Blocks()) { + for (const auto& var: block->Vars()) { + if (var->Name() == "feed" || var->Name() == "fetch") { + continue; + } + if (var->Persistable()) { + auto* v = scope->FindVar(var->Name()); + assert(v != nullptr); + int count = 1; + for (auto n: var->Tensor_desc().Dims()) { + count *= n; + } + + auto* tensor = v->GetMutable(); + const float * p = tensor->mutable_data(); + + std::string para_file_name = dirname + '/' + var->Name(); + FILE *para_file = fopen(para_file_name.c_str(), "w"); + assert(p != nullptr); + fwrite(p, sizeof(float), count, para_file); + fclose(para_file); + // std::cout << "==> " << var->Name() << " " << count << "\n"; + // for (int i = 0; i < count; i++) { + // std::cout << p[i] << ", "; + // } + // std::cout << "\n"; + } + } + } +} diff --git a/tools/web-exporter/export.cpp b/tools/web-exporter/export.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1f7c678b69e6af9b4f6694489304b85853bf6215 --- /dev/null +++ b/tools/web-exporter/export.cpp @@ -0,0 +1,52 @@ +#include "export.h" +#include +#include + +class FakeExecutor : public paddle_mobile::framework::Executor { +public: + FakeExecutor(const paddle_mobile::framework::Program p) { + program_ = p; + batch_size_ = 1; + use_optimize_ = true; + loddable_ = false; + if (use_optimize_) { + to_predict_program_ = program_.optimizeProgram; + } else { + to_predict_program_ = program_.originProgram; + } + auto *variable_ptr = program_.scope->Var("batch_size"); + variable_ptr[0].SetValue(1); + if (program_.combined) { + InitCombineMemory(); + } else { + InitMemory(); + } + } +}; + +int main(int argc, char** argv) { + if (argc != 3) { + std::cout << "Usage: " << argv[0] << " \n"; + return -1; + } + std::string model_dir = argv[1]; + std::string model_path = model_dir + "/model"; + std::string para_path = model_dir + "/params"; + + std::string out_dir = argv[2]; + std::string out_model_js = out_dir + "/model.js"; + std::string out_para_dir = out_dir + "/paras"; + mkdir(out_dir.c_str(), S_IRWXU|S_IRWXG|S_IRWXO); + mkdir(out_para_dir.c_str(), S_IRWXU|S_IRWXG|S_IRWXO); + + std::cout << "loading " << model_path << " & " << para_path << "\n"; + paddle_mobile::framework::Loader<> loader; + auto program = loader.Load(model_path, para_path, true); + FakeExecutor executor(program); + auto optimizedProgram = program.optimizeProgram; + export_scope(optimizedProgram, program.scope, out_para_dir); + std::ofstream fs(out_model_js.c_str()); + export_nodejs(optimizedProgram, program.scope, fs); + fs.close(); + return 0; +} diff --git a/tools/web-exporter/export.h b/tools/web-exporter/export.h new file mode 100644 index 0000000000000000000000000000000000000000..d9db3b31dfa490b4404baccd6336df456cc84755 --- /dev/null +++ b/tools/web-exporter/export.h @@ -0,0 +1,61 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "framework/loader.h" +#include "framework/executor.h" +#include "framework/scope.h" +#include "framework/program/program_desc.h" + +// using paddle_mobile::framework::ProgramDesc; +// using paddle_mobile::framework::Scope; + +using ProgramPtr = std::shared_ptr; +using ScopePtr = std::shared_ptr; + +void export_nodejs(ProgramPtr program, ScopePtr scope, std::ostream & os = std::cout); +void export_scope(ProgramPtr program, ScopePtr scope, const std::string & dirname = "."); + + +template +inline std::string var2str(const T & v) { + return std::to_string(v); +} + +template <> +inline std::string var2str(const std::string & v) { + return "\"" + v + "\""; +} + +inline std::string var2str(const char* v) { + return var2str(v); +} + +inline std::string var2str(const bool v) { + return v ? "true" : "false"; +} + +template +std::string var2str(const std::vector & v) { + std::string r = "["; + auto s = v.size(); + for (int i = 0; i < s; i++) { + if (i) r += ", "; + r += var2str(v[i]); + } + return r + "]"; +} + +struct VarVisitor { + using type_t = decltype(var2str(0)); + + template + type_t operator()(const T & v) { + return var2str(v); + } +}; \ No newline at end of file