提交 10650d3f 编写于 作者: Z ZhenWang

Merge branch 'develop' of https://github.com/PaddlePaddle/paddle-mobile into add_pooling_int8

......@@ -14,6 +14,7 @@ limitations under the License. */
#include "fpga/V1/api.h"
#include "fpga/V1/bias_scale.h"
#include "fpga/V1/deconv_filter.h"
#include "fpga/V1/filter.h"
#include "fpga/V1/image.h"
......@@ -124,6 +125,32 @@ void format_fc_filter(framework::Tensor *filter_tensor, float max_value) {
max_value);
filter_tensor->reset_data_ptr(new_data);
}
void format_deconv_filter(framework::Tensor *filter_tensor, float max_value,
int group_num, int stride) {
filter_tensor->scale[0] = float(max_value / 127.0); // NOLINT
filter_tensor->scale[1] = float(127.0 / max_value); // NOLINT
auto dims = filter_tensor->dims();
auto num = dims[0], channel = dims[1], height = dims[2], width = dims[3];
auto data_ptr = filter_tensor->data<float>();
size_t memory_size = num * channel * height * width * sizeof(float);
auto new_data = (float *)fpga_malloc(memory_size); // NOLINT
memcpy(new_data, data_ptr, memory_size);
int hw = height * width;
deconv_filter::deconv_NC_convert(&new_data, num, channel, hw);
num = dims[1];
channel = dims[0];
deconv_filter::deconv_format_filter(
&new_data, (int)num, (int)channel, // NOLINT
(int)height, // NOLINT
(int)width, group_num, max_value, stride); // NOLINT
framework::DDim dims_new =
framework::make_ddim({num, channel, height, width});
filter_tensor->Resize(dims_new);
filter_tensor->reset_data_ptr(new_data);
}
void format_bias_scale_array(float **bias_scale_array,
int element_num_per_division, int num) {
......@@ -240,6 +267,100 @@ void fill_split_arg(struct SplitConvArgs *arg, framework::Tensor *input,
filter->reset_data_ptr(nullptr);
fpga_free(bs_ptr);
}
void fill_deconv_arg(struct DeconvArgs *arg, framework::Tensor *input,
framework::Tensor *out, framework::Tensor *filter,
bool relu_enabled, int group_num, int stride_h,
int stride_w, int padding_h, int padding_w,
float *bs_ptr) {
auto input_ptr = input->data<float>();
auto filter_ptr = filter->data<float>();
auto out_ptr = out->data<float>();
arg->group_num = (uint32_t)group_num;
arg->sub_conv_num = stride_h;
arg->filter_num = (uint32_t)filter->dims()[0];
int sub_conv_num = arg->sub_conv_num;
int sub_stride = 1;
int sub_pad = deconv_filter::deconv_calc_sub_pad(filter->dims()[3], padding_w,
stride_w);
int sub_filter_width =
deconv_filter::deconv_get_sub_filter_axis(filter->dims()[3], stride_w);
int sub_output_width = deconv_filter::deconv_get_sub_out_axis(
input->dims()[3], sub_pad, sub_filter_width);
int sub_output_height = deconv_filter::deconv_get_sub_out_axis(
input->dims()[2], sub_pad, sub_filter_width);
arg->sub_output_width = sub_output_width;
arg->sub_output_height = sub_output_height;
arg->omit_size =
deconv_filter::deconv_get_omit(stride_w, filter->dims()[3], padding_w);
arg->conv_args = (ConvArgs *)fpga_malloc(sub_conv_num * sizeof(ConvArgs));
int sub_channels = (int32_t)input->dims()[1];
int omit_size = arg->omit_size;
int real_out_width = sub_output_width * sub_conv_num - 2 * omit_size;
int real_out_height = sub_output_height * sub_conv_num - 2 * omit_size;
int sub_filter_num = sub_conv_num * (arg->filter_num);
int conv_output_size =
(align_to_x(sub_output_width * sub_filter_num, IMAGE_ALIGNMENT)) *
sub_output_height;
int ouput_size = conv_output_size * sub_conv_num;
int align_sub_filter_num = align_to_x(sub_filter_num, FILTER_NUM_ALIGNMENT);
int align_sub_filter_count =
align_to_x(sub_filter_width * sub_filter_width * sub_channels,
FILTER_ELEMENT_ALIGNMENT);
int align_conv_sub_filter_count =
align_sub_filter_count * align_sub_filter_num;
for (int i = 0; i < sub_conv_num; ++i) {
arg->conv_args[i].filter_num = (arg->sub_conv_num) * (arg->filter_num);
arg->conv_args[i].group_num = group_num;
arg->conv_args[i].filter_scale_address = filter->scale;
arg->conv_args[i].relu_enabled = relu_enabled;
arg->conv_args[i].kernel.width = sub_filter_width;
arg->conv_args[i].kernel.height = sub_filter_width;
arg->conv_args[i].kernel.stride_w = 1;
arg->conv_args[i].kernel.stride_h = 1;
// DeconvParam.conv_args[i].image.address = (void*)ptr_image;
arg->conv_args[i].image.scale_address = input->scale;
arg->conv_args[i].image.channels = sub_channels;
arg->conv_args[i].image.width = (uint32_t)input->dims()[3];
arg->conv_args[i].image.height = (uint32_t)input->dims()[2];
arg->conv_args[i].image.pad_width = sub_pad;
arg->conv_args[i].image.pad_height = sub_pad;
arg->conv_args[i].image.address = input_ptr;
arg->conv_args[i].sb_address = (void *)bs_ptr;
char *filter_sub_space =
(char *)fpga_malloc(align_conv_sub_filter_count * sizeof(char));
fpga_copy(filter_sub_space,
(char *)filter_ptr + i * align_conv_sub_filter_count,
align_conv_sub_filter_count);
arg->conv_args[i].filter_address = (void *)(filter_sub_space);
fpga_flush(filter_sub_space, align_conv_sub_filter_count);
if (sub_conv_num == 1) {
arg->conv_args[i].output.address = out_ptr;
arg->conv_args[i].output.scale_address = out->scale;
} else {
half *ptr_output = (half *)fpga_malloc(conv_output_size * sizeof(half));
arg->conv_args[i].output.address = (void *)((half *)ptr_output);
float *ptr_output_scale = (float *)fpga_malloc(2 * sizeof(float));
arg->conv_args[i].output.scale_address = ptr_output_scale;
}
}
arg->output.address = out_ptr;
arg->output.scale_address = out->scale;
// fpga_free(filter_ptr);
}
} // namespace fpga
} // namespace paddle_mobile
......@@ -43,6 +43,25 @@ void fill_split_arg(struct SplitConvArgs* arg, framework::Tensor* input,
framework::Tensor* out, framework::Tensor* filter,
bool relu_enabled, int group_num, int stride_h,
int stride_w, int padding_h, int padding_w, float* bs_ptr);
void fill_deconv_arg(struct DeconvArgs* arg, framework::Tensor* input,
framework::Tensor* out, framework::Tensor* filter,
bool relu_enabled, int group_num, int stride_h,
int stride_w, int padding_h, int padding_w, float* bs_ptr);
void format_deconv_filter(framework::Tensor* filter_tensor, float max_value,
int group_num, int stride);
template <typename Dtype>
void savefile(std::string filename, void* buffer, int dataSize, Dtype tmp) {
float data;
std::ofstream out(filename.c_str());
for (int i = 0; i < dataSize; ++i) {
data = (((Dtype*)buffer)[i]);
out << data << std::endl;
}
out.close();
return;
}
} // namespace fpga
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "fpga/V1/deconv_bias_scale.h"
// #include "deconv_bias_scale.h"
#include "fpga/V1/bias_scale.h"
// #include "bias_scale.h"
#include <memory.h>
#include "fpga/V1/api.h"
// #include "fpga_api.h"
namespace paddle_mobile {
namespace fpga {
namespace deconv_bias_scale {
void deconv_bias_scale_expand(float** bias_scale_array, int num,
int sub_conv_n) {
int sub_num = num * sub_conv_n;
float* ptr_tmp = *bias_scale_array;
float* ptr_bias_scale_expand =
(float*)fpga_malloc(sizeof(float) * sub_num * 2);
int scale_base_offset = sub_num;
for (int i = 0; i < sub_conv_n; ++i) {
int offset = num * i;
// copy bias
fpga_copy(ptr_bias_scale_expand + offset, ptr_tmp, num * sizeof(float));
// copy scale
fpga_copy(ptr_bias_scale_expand + scale_base_offset + offset, ptr_tmp + num,
num * sizeof(float));
}
*bias_scale_array = ptr_bias_scale_expand;
fpga_free(ptr_tmp);
}
} // namespace deconv_bias_scale
} // namespace fpga
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#define BS_NUM_ALIGNMENT 8
namespace paddle_mobile {
namespace fpga {
namespace deconv_bias_scale {
void deconv_bias_scale_expand(float** bias_scale_array, int num,
int sub_conv_n);
} // namespace deconv_bias_scale
} // namespace fpga
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "fpga/V1/deconv_filter.h"
#include <memory.h>
#include <algorithm>
// #include "deconv_filter.h"
#include "fpga/V1/filter.h"
// #include "filter.h"
#include "fpga/V1/api.h"
// #include "fpga_api.h"
// just for test
//#include <string>
//#include "deconv.h"
//#include "deconv_api.h"
// using namespace std;
// using namespace paddle_mobile::fpga;
// using namespace baidu::fpga::deconv::api;
// namespace api = baidu::fpga::deconv::api;
namespace paddle_mobile {
namespace fpga {
namespace deconv_filter {
/*
inverse kernel weights of each channel for every filter
*/
void deconv_inverse_filter(float** data_in, int num, int channel, int width,
int height) {
float* tmp = *data_in;
// float fix_range = 127;// float scale = fix_range / max;
int data_size = num * channel * width * height;
int hw_len = height * width;
float* tmp_data = (float*)fpga_malloc(data_size * sizeof(float));
for (int i = 0; i < num; ++i) {
for (int j = 0; j < channel; ++j) {
for (int k = 0; k < hw_len; ++k) {
tmp_data[i * channel * hw_len + j * hw_len + k] =
(*data_in)[i * channel * hw_len + j * hw_len + hw_len - k - 1];
}
}
}
*data_in = (float*)tmp_data; //
fpga_free(tmp);
}
/*
calculate sub padding number
*/
int deconv_calc_sub_pad(int filter_axis, int pad, int stride) {
if (stride == 0 || ((filter_axis - pad - 1) < 0)) {
// error
return 0;
}
return (filter_axis - pad - 1) / stride;
}
int deconv_get_sub_filter_axis(int filter_axis, int stride) {
return (filter_axis / stride);
}
int deconv_get_sub_out_axis(int image_axis, int sub_pad, int sub_filter_axis) {
return ((image_axis + 2 * sub_pad - sub_filter_axis) + 1);
}
/*
(filter_width-pad,filter_width-pad) is the first pixel of sub-pixel image
position. so the omit rows or columns is (stride - )
*/
int deconv_get_omit(int stride, int filter_width, int pad) {
if (((filter_width - pad) <= 0)) { // ((filter_width-pad) > stride) ||
// error
return 0;
}
int idx = 1;
bool flag = false;
for (idx = 1; idx <= stride; ++idx) {
int j = idx;
for (; j <= filter_width;) {
if (j == filter_width - pad) {
flag = true;
break;
}
j = j + stride;
}
if (flag) {
break;
}
}
return (stride - idx);
}
int deconv_get_sub_filter_num(int filter_num, int stride) {
return filter_num * stride;
}
void deconv_get_sub_filter(char** data_in, int height, int width,
int sub_conv_n, int kernel_num, int channel) {
char* ptr_tmp = *data_in;
int sub_num = kernel_num * sub_conv_n;
int sub_h = height / sub_conv_n;
int sub_w = width / sub_conv_n;
int sub_filter_size =
kernel_num * sub_h * sub_w * channel * sub_conv_n * sub_conv_n;
char* ptr_sub_filter = (char*)fpga_malloc(sub_filter_size * sizeof(char));
for (int idx = 0; idx < sub_conv_n; ++idx) {
for (int nn = 0; nn < sub_num; ++nn) {
int ni = nn % kernel_num;
int woff = sub_conv_n - 1 - (nn / kernel_num); //
for (int hh = 0; hh < sub_h; ++hh) {
int hi = hh * sub_conv_n + idx % sub_conv_n;
for (int ww = 0; ww < sub_w; ++ww) {
int wi = ww * sub_conv_n + woff; // 1 0
int sidx = ((nn * sub_h + hh) * sub_w + ww) * channel; //
int kidx = ((ni * height + hi) * width + wi) * channel; //
fpga_copy(
ptr_sub_filter + idx * sub_h * sub_w * channel * sub_num + sidx,
(*data_in) + kidx, channel * sizeof(char));
// for (int cc =0; cc < channel; ++cc) {
// ptr_sub_filter[idx*sub_h*sub_w*channel*sub_num + sidx + cc] =
// (*data_in)[kidx + cc];
// }
}
}
}
}
*data_in = ptr_sub_filter;
fpga_free(ptr_tmp);
}
void deconv_NC_convert(float** filter_in, int kernel_num, int channels,
int hw) {
float* tmp = *filter_in;
float* ptr_filter = (float*)(paddle_mobile::fpga::fpga_malloc(
hw * kernel_num * channels * sizeof(float)));
for (int c = 0; c < channels; ++c) {
for (int n = 0; n < kernel_num; ++n) {
paddle_mobile::fpga::fpga_copy(ptr_filter + n * hw + kernel_num * hw * c,
tmp + n * channels * hw + c * hw,
hw * sizeof(float));
}
}
*filter_in = ptr_filter;
paddle_mobile::fpga::fpga_free(tmp);
}
void deconv_format_filter(float** data_in, int num, int channel, int height,
int width, int group_num, float max, int stride) {
int data_size = channel * height * width * num;
/*{
float result2 = (float)0;
string filename = "origin_filter_data";
api::savefile<float>(filename, (void *)*data_in, data_size, result2);
}*/
deconv_inverse_filter(data_in, num, channel, width, height);
/* {
float result2 = (float)0;
string filename = "inverse_filter_data";
api::savefile<float>(filename, (void *)*data_in, data_size, result2);
}*/
filter::quantize(data_in, data_size, max);
/* {
char result2 = (char)0;
string filename = "quantize_filter_data";
api::savefile<char>(filename, (void *)*data_in, data_size, result2);
}*/
char** quantize_data = (char**)data_in; // NOLINT
filter::convert_to_hwc(quantize_data, num, channel, height, width);
/*{
char result2 = (char)0;
string filename = "convert_to_hwc_filter_data";
api::savefile<char>(filename, (void *)*quantize_data, data_size,
result2);
}*/
deconv_get_sub_filter(quantize_data, height, width, stride, num, channel);
/*{
char result2 = (char)0;
string filename = "sub_filter_filter_data";
api::savefile<char>(filename, (void *)*quantize_data, data_size, result2);
}*/
int sub_conv_n = stride;
int sub_h = height / sub_conv_n;
int sub_w = width / sub_conv_n;
int sub_chw = sub_h * sub_w * channel;
int sub_num = sub_conv_n * num;
int division_capacity = filter::calc_division_capacity(sub_chw);
int num_per_div_before_alignment =
filter::calc_num_per_div(sub_num, group_num, division_capacity);
int num_per_div_after_alignment =
align_to_x(num_per_div_before_alignment, FILTER_NUM_ALIGNMENT);
int div_num = (sub_num + num_per_div_before_alignment - 1) /
num_per_div_before_alignment;
int residual = (sub_num) % num_per_div_before_alignment;
int num_after_alignment = num_per_div_after_alignment *
((residual == 0) ? div_num : (div_num - 1)) +
align_to_x(residual, FILTER_NUM_ALIGNMENT);
char** ptr_ptr_data = (char**)fpga_malloc(sub_conv_n * sizeof(char*));
int origin_offset = sub_chw * sub_num;
for (int i = 0; i < sub_conv_n; ++i) {
(ptr_ptr_data)[i] = (char*)fpga_malloc(origin_offset * sizeof(char));
fpga_copy((ptr_ptr_data)[i], (*quantize_data) + origin_offset * i,
origin_offset * sizeof(char));
/* char result2 = (char)0;
string filename = "ptr_ptr_data" + to_string(i);
api::savefile<char>(filename, (void *)(ptr_ptr_data[i]), origin_offset,
result2);
*/
}
// char result2 = (char)0;
// string filename = "interleave";
// api::savefile<char>(filename, (void *)*ptr_ptr_data, origin_offset,
// result2);
fpga_free(*quantize_data);
int align_offset =
align_to_x(sub_chw, FILTER_ELEMENT_ALIGNMENT) * num_after_alignment;
char* ptr_space = (char*)fpga_malloc(sub_conv_n * align_offset *
sizeof(char)); // continuous space
for (int i = 0; i < sub_conv_n; ++i) {
int offset = i * origin_offset;
char* ptr_tmp = (ptr_ptr_data)[i];
filter::align_element(&ptr_tmp, sub_num, sub_chw);
filter::align_num(&ptr_tmp, num_per_div_before_alignment, sub_num, sub_chw);
filter::reorder(&ptr_tmp, num_after_alignment, sub_chw);
filter::interleave(&ptr_tmp, num_after_alignment, sub_chw);
/* char result2 = (char)0;
string filename = "interleave" + to_string(i);
api::savefile<char>(filename, (void *)ptr_tmp, align_offset, result2);
*/
fpga_copy(ptr_space + i * align_offset, ptr_tmp, align_offset);
fpga_free(ptr_tmp);
}
*data_in = (float*)ptr_space;
/* {
char result2 = (char)0;
string filename = "ptr_space";
api::savefile<char>(filename, (void *)ptr_space, sub_conv_n *
align_offset, result2);
}*/
fpga_flush(ptr_space, sub_conv_n * align_offset * sizeof(char));
}
} // namespace deconv_filter
} // namespace fpga
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
namespace paddle_mobile {
namespace fpga {
namespace deconv_filter {
void deconv_inverse_filter(float** data_in, int num, int channel, int width,
int height);
int deconv_calc_sub_pad(int filter_axis, int pad, int stride);
int deconv_get_sub_filter_num(int filter_num, int stride);
int deconv_get_sub_filter_axis(int filter_axis, int stride);
int deconv_get_sub_out_axis(int image_axis, int sub_pad, int sub_filter_axis);
int deconv_get_omit(int stride, int filter_width, int pad);
void deconv_get_sub_filter(char** data_in, int height, int width,
int sub_conv_n, int kernel_num, int channel);
void deconv_format_filter(float** data_in, int num, int channel, int height,
int width, int group_num, float max, int stride);
void deconv_NC_convert(float** filter_in, int kernel_num, int channels, int hw);
} // namespace deconv_filter
} // namespace fpga
} // namespace paddle_mobile
......@@ -146,12 +146,16 @@ void align_num(char **data_in, int num_per_div_before_alignment, int num,
memset(data_tmp, 0, num_element * sizeof(char));
for (i = 0; i < div_num; i++) {
for (i = 0; i < div_num - 1; i++) {
memcpy(data_tmp + num_per_div_after_alignment * align_chw * i,
*data_in + num_per_div_before_alignment * align_chw * i,
num_per_div_before_alignment * align_chw);
}
memcpy(data_tmp + num_per_div_after_alignment * align_chw * i,
*data_in + num_per_div_before_alignment * align_chw * i,
(num - (div_num - 1) * num_per_div_before_alignment) * align_chw);
*data_in = data_tmp;
fpga_free(tmp);
}
......
......@@ -29,11 +29,11 @@ void convert_to_hwc(char** data_in, int num, int channel, int height,
int width);
float find_max(float* data_in, int data_size);
void quantize(float** data_in, int data_size, float max);
void align_element(float** data_in, int num, int chw);
void align_element(char** data_in, int num, int chw);
void align_num(char** data_in, int num_per_div_before_alignment, int num,
int chw);
void reorder(float** data_in, int num_after_alignment, int chw);
void interleave(float** data_in, int num_after_alignment, int chw);
void reorder(char** data_in, int num_after_alignment, int chw);
void interleave(char** data_in, int num_after_alignment, int chw);
void format_filter(float** data_in, int num, int channel, int height, int width,
int group_num, float max);
......
......@@ -56,6 +56,7 @@ class CLImage {
tensor_dims_ = dim;
}
bool isInit() { return initialized_; }
/*
* need call SetTensorData first
*
......
......@@ -55,6 +55,9 @@ REGISTER_FUSION_MATCHER(fusion_conv_bn_add_relu,
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_conv_bn_add_relu, ops::FusionConvBNAddReluOp);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(fusion_conv_bn_add_relu, ops::FusionConvBNAddReluOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(fusion_conv_bn_add_relu, ops::FusionConvBNAddReluOp);
#endif
......
......@@ -77,15 +77,25 @@ void BatchNormKernel<GPU_CL, float>::Compute(
auto new_scale = param.NewScale()->GetCLImage();
auto new_bias = param.NewBias()->GetCLImage();
const int out_width = default_work_size[1];
clSetKernelArg(kernel, 1, sizeof(int), &out_width);
clSetKernelArg(kernel, 2, sizeof(cl_mem), &input);
clSetKernelArg(kernel, 3, sizeof(cl_mem), &new_scale);
clSetKernelArg(kernel, 4, sizeof(cl_mem), &new_bias);
clSetKernelArg(kernel, 5, sizeof(cl_mem), &out);
// cl_event out_event = param.OutputY()->GetClEvent();
// cl_event wait_event = param.InputX()->GetClEvent();
DLOG << *param.InputX();
DLOG << *param.NewBias();
DLOG << *param.NewScale();
DLOG << default_work_size[0];
DLOG << default_work_size[1];
DLOG << default_work_size[2];
DLOG << out_width;
DLOG << *param.OutputY();
cl_int status;
clSetKernelArg(kernel, 0, sizeof(cl_int), &out_width);
CL_CHECK_ERRORS(status);
clSetKernelArg(kernel, 1, sizeof(cl_mem), &input);
CL_CHECK_ERRORS(status);
clSetKernelArg(kernel, 2, sizeof(cl_mem), &new_scale);
CL_CHECK_ERRORS(status);
clSetKernelArg(kernel, 3, sizeof(cl_mem), &new_bias);
CL_CHECK_ERRORS(status);
clSetKernelArg(kernel, 4, sizeof(cl_mem), &out);
CL_CHECK_ERRORS(status);
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define BATCH_NORM
#define BIASE
#define RELU
#include "conv_kernel.inc.cl"
......@@ -924,6 +924,387 @@ __kernel void conv_5x5(__private const int global_size_dim0,
write_imageh(output_image, (int2)(out_c * global_size_dim1 + out_w, out_nh), output);
}
__kernel void convBNAdd_3x3(__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2,
__read_only image2d_t input_image,
__read_only image2d_t filter,
#ifdef BIASE
__read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
__read_only image2d_t new_scale,
__read_only image2d_t new_biase,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int offset,
__private const int input_c,
__private const int dilation,
__private const int input_width,/* of one block */
__private const int input_height,/* of one block */
__private const int output_width,
__private const int output_height) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
if (out_c >= global_size_dim0 ||
out_w >= global_size_dim1 ||
out_nh >= global_size_dim2) {
return;
}
int2 stride_xy;
stride_xy.x = stride;
stride_xy.y = stride;
int2 ouput_pos_in_one_block;
ouput_pos_in_one_block.x = out_w;
ouput_pos_in_one_block.y = out_nh;
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
int2 in_pos_in_one_block;
in_pos_in_one_block.x = ouput_pos_in_one_block.x * stride + offset;
in_pos_in_one_block.y = ouput_pos_in_one_block.y * stride + offset;
half4 output = (half4)0.0f;
half4 input[9];
for (int i = 0; i < input_c; ++i) {
int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x, in_pos_in_one_block.y);
input[0] = select(read_imageh(input_image, sampler,
(int2)(pos_in.x - dilation, pos_in.y - dilation)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x - dilation < 0 || in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.y - dilation >= input_height) << 15));
input[1] = select(read_imageh(input_image, sampler,
(int2)(pos_in.x, pos_in.y - dilation)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y - dilation >= input_height) << 15));
input[2] = select(read_imageh(input_image, sampler,
(int2)(pos_in.x + dilation, pos_in.y - dilation)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x + dilation < 0 || in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.y - dilation >= input_height) << 15));
input[3] = select(read_imageh(input_image, sampler,
(int2)(pos_in.x - dilation, pos_in.y)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x - dilation < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.y >= input_height) << 15));
input[4] = select(read_imageh(input_image, sampler,
(int2)(pos_in.x, pos_in.y)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y >= input_height) << 15));
input[5] = select(read_imageh(input_image, sampler,
(int2)(pos_in.x + dilation, pos_in.y)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x + dilation < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.y >= input_height) << 15));
input[6] = select(read_imageh(input_image, sampler,
(int2)(pos_in.x - dilation, pos_in.y + dilation)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x - dilation < 0 || in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.y + dilation >= input_height) << 15));
input[7] = select(read_imageh(input_image, sampler,
(int2)(pos_in.x, pos_in.y + dilation)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y + dilation >= input_height) << 15));
input[8] = select(read_imageh(input_image, sampler,
(int2)(pos_in.x + dilation, pos_in.y + dilation)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x + dilation < 0 || in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.y + dilation >= input_height) << 15));
/*
for (int j = 0; j < 9; ++j) {
int2 pos_of_weight;
pos_of_weight.x = i * 3 + j % 3;
pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3;
half4 weight_x = read_imageh(filter, sampler, pos_of_weight);
output.x += dot(input[j], weight_x);
pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3;
half4 weight_y = read_imageh(filter, sampler, pos_of_weight);
output.y += dot(input[j], weight_y);
pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3;
half4 weight_z = read_imageh(filter, sampler, pos_of_weight);
output.z += dot(input[j], weight_z);
pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3;
half4 weight_w = read_imageh(filter, sampler, pos_of_weight);
output.w += dot(input[j], weight_w);
}
*/
int j = 0;
int2 pos_of_weight;
pos_of_weight.x = i * 3 + j % 3;
pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3;
half4 weight_x = read_imageh(filter, sampler, pos_of_weight);
output.x += dot(input[j], weight_x);
pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3;
half4 weight_y = read_imageh(filter, sampler, pos_of_weight);
output.y += dot(input[j], weight_y);
pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3;
half4 weight_z = read_imageh(filter, sampler, pos_of_weight);
output.z += dot(input[j], weight_z);
pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3;
half4 weight_w = read_imageh(filter, sampler, pos_of_weight);
output.w += dot(input[j], weight_w);
j = 1;
pos_of_weight.x = i * 3 + j % 3;
pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3;
weight_x = read_imageh(filter, sampler, pos_of_weight);
output.x += dot(input[j], weight_x);
pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3;
weight_y = read_imageh(filter, sampler, pos_of_weight);
output.y += dot(input[j], weight_y);
pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3;
weight_z = read_imageh(filter, sampler, pos_of_weight);
output.z += dot(input[j], weight_z);
pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3;
weight_w = read_imageh(filter, sampler, pos_of_weight);
output.w += dot(input[j], weight_w);
j = 2;
pos_of_weight.x = i * 3 + j % 3;
pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3;
weight_x = read_imageh(filter, sampler, pos_of_weight);
output.x += dot(input[j], weight_x);
pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3;
weight_y = read_imageh(filter, sampler, pos_of_weight);
output.y += dot(input[j], weight_y);
pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3;
weight_z = read_imageh(filter, sampler, pos_of_weight);
output.z += dot(input[j], weight_z);
pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3;
weight_w = read_imageh(filter, sampler, pos_of_weight);
output.w += dot(input[j], weight_w);
j = 3;
pos_of_weight.x = i * 3 + j % 3;
pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3;
weight_x = read_imageh(filter, sampler, pos_of_weight);
output.x += dot(input[j], weight_x);
pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3;
weight_y = read_imageh(filter, sampler, pos_of_weight);
output.y += dot(input[j], weight_y);
pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3;
weight_z = read_imageh(filter, sampler, pos_of_weight);
output.z += dot(input[j], weight_z);
pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3;
weight_w = read_imageh(filter, sampler, pos_of_weight);
output.w += dot(input[j], weight_w);
j = 4;
pos_of_weight.x = i * 3 + j % 3;
pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3;
weight_x = read_imageh(filter, sampler, pos_of_weight);
output.x += dot(input[j], weight_x);
pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3;
weight_y = read_imageh(filter, sampler, pos_of_weight);
output.y += dot(input[j], weight_y);
pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3;
weight_z = read_imageh(filter, sampler, pos_of_weight);
output.z += dot(input[j], weight_z);
pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3;
weight_w = read_imageh(filter, sampler, pos_of_weight);
output.w += dot(input[j], weight_w);
j = 5;
pos_of_weight.x = i * 3 + j % 3;
pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3;
weight_x = read_imageh(filter, sampler, pos_of_weight);
output.x += dot(input[j], weight_x);
pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3;
weight_y = read_imageh(filter, sampler, pos_of_weight);
output.y += dot(input[j], weight_y);
pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3;
weight_z = read_imageh(filter, sampler, pos_of_weight);
output.z += dot(input[j], weight_z);
pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3;
weight_w = read_imageh(filter, sampler, pos_of_weight);
output.w += dot(input[j], weight_w);
j = 6;
pos_of_weight.x = i * 3 + j % 3;
pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3;
weight_x = read_imageh(filter, sampler, pos_of_weight);
output.x += dot(input[j], weight_x);
pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3;
weight_y = read_imageh(filter, sampler, pos_of_weight);
output.y += dot(input[j], weight_y);
pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3;
weight_z = read_imageh(filter, sampler, pos_of_weight);
output.z += dot(input[j], weight_z);
pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3;
weight_w = read_imageh(filter, sampler, pos_of_weight);
output.w += dot(input[j], weight_w);
j = 7;
pos_of_weight.x = i * 3 + j % 3;
pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3;
weight_x = read_imageh(filter, sampler, pos_of_weight);
output.x += dot(input[j], weight_x);
pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3;
weight_y = read_imageh(filter, sampler, pos_of_weight);
output.y += dot(input[j], weight_y);
pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3;
weight_z = read_imageh(filter, sampler, pos_of_weight);
output.z += dot(input[j], weight_z);
pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3;
weight_w = read_imageh(filter, sampler, pos_of_weight);
output.w += dot(input[j], weight_w);
j = 8;
pos_of_weight.x = i * 3 + j % 3;
pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3;
weight_x = read_imageh(filter, sampler, pos_of_weight);
output.x += dot(input[j], weight_x);
pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3;
weight_y = read_imageh(filter, sampler, pos_of_weight);
output.y += dot(input[j], weight_y);
pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3;
weight_z = read_imageh(filter, sampler, pos_of_weight);
output.z += dot(input[j], weight_z);
pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3;
weight_w = read_imageh(filter, sampler, pos_of_weight);
output.w += dot(input[j], weight_w);
}
#ifdef BATCH_NORM
output = output * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + read_imageh(new_biase, sampler, (int2)(out_c, 0));
#endif
#ifdef BIASE
output += read_imageh(bias, sampler, (int2)(out_c * global_size_dim1 + out_w, out_nh));
#endif
#ifdef RELU
output = activation(output);
#endif
write_imageh(output_image, (int2)(out_c * global_size_dim1 + out_w, out_nh), output);
}
__kernel void convBNAdd_1x1(__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2,
__read_only image2d_t input_image,
__read_only image2d_t filter,
#ifdef BIASE
__read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
__read_only image2d_t new_scale,
__read_only image2d_t new_biase,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int offset,
__private const int input_c,
__private const int dilation,
__private const int input_width,/* of one block */
__private const int input_height,/* of one block */
__private const int output_width,
__private const int output_height) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
const uint kernelHXW = 1;
int2 stride_xy = (int2)(stride, stride);
int2 ouput_pos_in_one_block = (int2)(out_w, out_nh);
int2 in_pos_in_one_block = ouput_pos_in_one_block * stride_xy + (int2)(offset, offset);
half4 output = 0.0f;
for (int i = 0; i < input_c; ++i) {
int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x, in_pos_in_one_block.y);
half4 input = read_imageh(input_image, sampler, pos_in);
half4 weight0 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 0));
half4 weight1 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 1));
half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 2));
half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 3));
/*
output.x = dot(input, weight0);
output.y = dot(input, weight1);
output.z = dot(input, weight2);
output.w = dot(input, weight3);
*/
output = mad(input.x, weight0, output);
output = mad(input.y, weight1, output);
output = mad(input.z, weight2, output);
output = mad(input.w, weight3, output);
}
#ifdef BATCH_NORM
output = output * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + read_imageh(new_biase, sampler, (int2)(out_c, 0));
#endif
#ifdef BIASE
output += read_imageh(bias, sampler, (int2)(out_c * global_size_dim1 + out_w, out_nh));
#endif
#ifdef RELU
output = activation(output);
#endif
int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh);
write_imageh(output_image, output_pos, output);
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVBNADDRELU_OP
#include "operators/kernel/conv_bn_add_relu_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ConvBNAddReluKernel<GPU_CL, float>::Init(
FusionConvBNAddReluParam<GPU_CL> *param) {
PADDLE_MOBILE_ENFORCE(
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
const framework::CLImage *mean = param->InputMean();
const framework::CLImage *variance = param->InputVariance();
const framework::CLImage *scale = param->InputScale();
const framework::CLImage *bias = param->InputBias();
const float epsilon = param->Epsilon();
const int C = mean->numel();
auto mean_ptr = mean->data<float>();
auto variance_ptr = variance->data<float>();
auto scale_ptr = scale->data<float>();
auto bias_ptr = bias->data<float>();
float inv_std_ptr[C];
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
}
float *new_scale_ptr = new float[C];
float *new_bias_ptr = new float[C];
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
}
framework::CLImage *new_scale = new framework::CLImage();
// for (int j = 0; j < C; ++j) {
// DLOG << " new scale - " << j << new_scale_ptr[j];
// }
//
// for (int j = 0; j < C; ++j) {
// DLOG << " new bias - " << j << new_bias_ptr[j];
// }
new_scale->SetTensorData(new_scale_ptr, variance->dims());
new_scale->InitCLImage(this->cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
// DLOG << " climage - y bias: " << *(param->Bias());
//
// DLOG << " climage - new scale: " << *new_scale;
framework::CLImage *new_bias = new framework::CLImage();
new_bias->SetTensorData(new_bias_ptr, variance->dims());
new_bias->InitCLImage(this->cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
// DLOG << " climage - new bias: " << *new_bias;
//
// DLOG << " climage - filter: " << *(param->Filter());
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
delete[](new_scale_ptr);
delete[](new_bias_ptr);
PADDLE_MOBILE_ENFORCE(
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 -
static_cast<int>(param->Paddings()[1]);
param->SetOffset(offset);
if (param->Filter()->dims()[2] == 1 && param->Filter()->dims()[3] == 1) {
param->Filter()->InitNImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("convBNAdd_1x1", "conv_bn_add_relu_kernel.cl");
DLOG << " conv bn add relu conv 1x1";
} else if (param->Filter()->dims()[1] == 1 &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] == 3) {
param->Filter()->InitDWImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("depth_convBNAdd_3x3",
"conv_bn_add_relu_kernel.cl");
DLOG << " conv bn add relu depth_conv_3x3";
} else if (param->Filter()->dims()[2] == 3 &&
param->Filter()->dims()[3] == 3) {
param->Filter()->InitCLImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("convBNAdd_3x3", "conv_bn_add_relu_kernel.cl");
DLOG << " conv bn add relu conv_3x3";
} else {
PADDLE_MOBILE_THROW_EXCEPTION(" not support ");
}
return true;
}
template <>
void ConvBNAddReluKernel<GPU_CL, float>::Compute(
const FusionConvBNAddReluParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Output());
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
auto input = param.Input()->GetCLImage();
auto filter = param.Filter()->GetCLImage();
auto biase = param.Bias()->GetCLImage();
auto new_scale = param.NewScale()->GetCLImage();
auto new_bias = param.NewBias()->GetCLImage();
auto output = param.Output()->GetCLImage();
int stride = param.Strides()[0];
int offset = param.Offset();
int input_c = reinterpret_cast<framework::CLImageConverterFolder *>(
param.Input()->Converter())
->GetCBlock();
int dilation = param.Dilations()[0];
int input_width = param.Input()->dims()[3];
int input_height = param.Input()->dims()[2];
int output_width = param.Output()->dims()[3];
int output_height = param.Output()->dims()[2];
// DLOG << " c block " << c_block;
// DLOG << " w " << w;
// DLOG << " nh " << nh;
// DLOG << " stride " << stride;
// DLOG << " offset " << offset;
// DLOG << " input_c " << input_c;
// DLOG << " dilation " << dilation;
// DLOG << " input width " << input_width;
// DLOG << " input height " << input_height;
// DLOG << " output width " << output_width;
// DLOG << " output height " << output_height;
// DLOG << " input dim " << *param.Input();
// DLOG << " output dim " <<* param.Output();
// DLOG << " filter dim " << *param.Filter();
// DLOG<<*param.Bias();
cl_int status;
status = clSetKernelArg(kernel, 0, sizeof(int), &c_block);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(int), &w);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &nh);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &input);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &biase);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 6, sizeof(cl_mem), &new_scale);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 7, sizeof(cl_mem), &new_bias);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 8, sizeof(cl_mem), &output);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 9, sizeof(int), &stride);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 10, sizeof(int), &offset);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 11, sizeof(int), &input_c);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 12, sizeof(int), &dilation);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 13, sizeof(int), &input_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 14, sizeof(int), &input_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 15, sizeof(int), &output_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 16, sizeof(int), &output_height);
CL_CHECK_ERRORS(status);
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
}
template class ConvBNAddReluKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef DEPTHWISECONV_OP
#include "operators/kernel/depthwise_conv_kernel.h"
#include "operators/kernel/central-arm-func/depthwise_conv_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool DepthwiseConvKernel<GPU_CL, float>::Init(ConvParam<GPU_CL> *param) {
DLOG << " depthwise conv kernel init begin ";
PADDLE_MOBILE_ENFORCE(
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
param->Filter()->InitCLImage(cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 -
static_cast<int>(param->Paddings()[1]);
param->SetOffset(offset);
this->cl_helper_.AddKernel("depth_conv_3x3", "conv_add_bn_relu_kernel.cl");
DLOG << " depthwise conv kernel init end ";
return true;
}
template <>
void DepthwiseConvKernel<GPU_CL, float>::Compute(
const ConvParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Output());
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
auto input = param.Input()->GetCLImage();
auto filter = param.Filter()->GetCLImage();
auto output = param.Output()->GetCLImage();
int stride = param.Strides()[0];
int offset = param.Offset();
int input_c = reinterpret_cast<framework::CLImageConverterFolder *>(
param.Input()->Converter())
->GetCBlock();
int dilation = param.Dilations()[0];
int input_width = param.Input()->dims()[3];
int input_height = param.Input()->dims()[2];
int output_width = param.Output()->dims()[3];
int output_height = param.Output()->dims()[2];
cl_int status;
status = clSetKernelArg(kernel, 0, sizeof(int), &c_block);
status = clSetKernelArg(kernel, 1, sizeof(int), &w);
status = clSetKernelArg(kernel, 2, sizeof(int), &nh);
status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &input);
status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter);
status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &output);
status = clSetKernelArg(kernel, 6, sizeof(int), &stride);
status = clSetKernelArg(kernel, 7, sizeof(int), &offset);
status = clSetKernelArg(kernel, 8, sizeof(int), &input_c);
status = clSetKernelArg(kernel, 9, sizeof(int), &dilation);
status = clSetKernelArg(kernel, 10, sizeof(int), &input_width);
status = clSetKernelArg(kernel, 11, sizeof(int), &input_height);
status = clSetKernelArg(kernel, 12, sizeof(int), &output_width);
status = clSetKernelArg(kernel, 13, sizeof(int), &output_height);
CL_CHECK_ERRORS(status);
// cl_event out_event = param.Output()->GetClEvent();
// cl_event wait_event = param.Input()->GetClEvent();
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
}
template class DepthwiseConvKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
///* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License. */
//
//#ifdef DEQUANT_OP
//
//#include "operators/kernel/dequantize_kernel.h"
//
// namespace paddle_mobile {
// namespace operators {
//
// template <>
// bool DequantizeKernel<GPU_CL, float>::Init(DequantizeParam<GPU_CL> *param) {
// DLOG << " depthwise conv kernel init begin ";
// PADDLE_MOBILE_ENFORCE(
// param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
// param->Paddings()[0] == param->Paddings()[1],
// "need equal");
// param->Filter()->InitCLImage(cl_helper_.CLContext(),
// this->cl_helper_.CLCommandQueue());
// int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 -
// static_cast<int>(param->Paddings()[1]);
// param->SetOffset(offset);
// this->cl_helper_.AddKernel("depth_conv_3x3", "conv_add_bn_relu_kernel.cl");
// DLOG << " depthwise conv kernel init end ";
// return true;
//}
//
// template <>
// void DequantizeKernel<GPU_CL, float>::Compute(
// const DequantizeParam<GPU_CL> &param) {
// auto kernel = this->cl_helper_.KernelAt(0);
// auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Output());
// int c_block = default_work_size[0];
// int w = default_work_size[1];
// int nh = default_work_size[2];
// auto input = param.Input()->GetCLImage();
// auto filter = param.Filter()->GetCLImage();
// auto output = param.Output()->GetCLImage();
// int stride = param.Strides()[0];
// int offset = param.Offset();
// int input_c = reinterpret_cast<framework::CLImageConverterFolder *>(
// param.Input()->Converter())
// ->GetCBlock();
// int dilation = param.Dilations()[0];
//
// int input_width = param.Input()->dims()[3];
// int input_height = param.Input()->dims()[2];
// int output_width = param.Output()->dims()[3];
// int output_height = param.Output()->dims()[2];
//
// cl_int status;
//
// status = clSetKernelArg(kernel, 0, sizeof(int), &c_block);
// status = clSetKernelArg(kernel, 1, sizeof(int), &w);
// status = clSetKernelArg(kernel, 2, sizeof(int), &nh);
// status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &input);
// status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter);
// status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &output);
// status = clSetKernelArg(kernel, 6, sizeof(int), &stride);
// status = clSetKernelArg(kernel, 7, sizeof(int), &offset);
// status = clSetKernelArg(kernel, 8, sizeof(int), &input_c);
// status = clSetKernelArg(kernel, 9, sizeof(int), &dilation);
// status = clSetKernelArg(kernel, 10, sizeof(int), &input_width);
// status = clSetKernelArg(kernel, 11, sizeof(int), &input_height);
// status = clSetKernelArg(kernel, 12, sizeof(int), &output_width);
// status = clSetKernelArg(kernel, 13, sizeof(int), &output_height);
//
// CL_CHECK_ERRORS(status);
//
// // cl_event out_event = param.Output()->GetClEvent();
// // cl_event wait_event = param.Input()->GetClEvent();
//
// status = clEnqueueNDRangeKernel(
// this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(),
// NULL, default_work_size.data(), NULL, 0, NULL, NULL);
//
// CL_CHECK_ERRORS(status);
//}
//
// template class DepthwiseConvKernel<GPU_CL, float>;
//
//} // namespace operators
//} // namespace paddle_mobile
//
//#endif
......@@ -24,7 +24,11 @@ bool ElementwiseAddKernel<GPU_CL, float>::Init(
ElementwiseAddParam<GPU_CL> *param) {
DLOG << "-----init add-----";
CLImage *bias = (CLImage *)(param->InputY());
bias->InitCLImage(cl_helper_.CLContext(), this->cl_helper_.CLCommandQueue());
if (!bias->isInit()) {
bias->InitCLImage(cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
}
DLOG << " bias: " << *bias;
if (bias->dims().size() == 4) {
this->cl_helper_.AddKernel("elementwise_add", "elementwise_add_kernel.cl");
......
......@@ -23,12 +23,61 @@ namespace operators {
template <>
bool DeconvAddKernel<FPGA, float>::Init(FusionDeconvAddParam<FPGA> *param) {
bool relu_enabled = false;
auto input = const_cast<Tensor *>(param->Input());
const Tensor *bias = param->Bias();
auto bias_ptr = bias->data<float>();
auto filter = const_cast<Tensor *>(param->Filter());
auto out = param->Output();
PADDLE_MOBILE_ENFORCE(out->dims()[1] == bias->dims()[0],
"Output channel should be equal to bias number");
int channel = out->dims()[1];
int sub_conv_n = param->Strides()[0];
auto bs_ptr = (float *)fpga::fpga_malloc(2 * channel * sub_conv_n *
sizeof(float)); // NOLINT
for (int i = 0; i < channel * sub_conv_n; i++) {
bs_ptr[i + sub_conv_n * channel] = 1;
bs_ptr[i] = bias_ptr[i % (channel)];
}
PADDLE_MOBILE_ENFORCE(param->Strides()[1] == param->Strides()[0],
"stride_width should be equal to stride_height ");
PADDLE_MOBILE_ENFORCE(filter->dims()[2] == filter->dims()[3],
"filter width should be equal to filter height ");
PADDLE_MOBILE_ENFORCE(((filter->dims()[2] % param->Strides()[0]) == 0),
"filter axis should be the multiple of stride axis ");
float max_value = fpga::filter_find_max(filter);
fpga::format_deconv_filter(filter, max_value, param->Groups(),
param->Strides()[0]);
// int element_num_per_div =
// fpga::get_filter_num_per_div(filter, param->Groups());
// deconv only support group=1 && no spilt
fpga::format_bias_scale_array(&bs_ptr, channel * sub_conv_n,
channel * sub_conv_n);
fpga::format_fp16_ofm(out);
fpga::DeconvArgs deconv_arg = {0};
fpga::fill_deconv_arg(&deconv_arg, input, out, filter, relu_enabled,
param->Groups(), param->Strides()[0],
param->Strides()[1], param->Paddings()[0],
param->Paddings()[1], bs_ptr);
param->SetFpgaArgs(deconv_arg);
return true;
}
template <>
void DeconvAddKernel<FPGA, float>::Compute(
const FusionDeconvAddParam<FPGA> &param) {}
const FusionDeconvAddParam<FPGA> &param) {
fpga::ComputeFpgaDeconv(param.FpgaArgs());
}
} // namespace operators
} // namespace paddle_mobile
......
......@@ -24,12 +24,60 @@ namespace operators {
template <>
bool DeconvAddReluKernel<FPGA, float>::Init(
FusionDeconvAddReluParam<FPGA> *param) {
bool relu_enabled = true;
auto input = const_cast<Tensor *>(param->Input());
const Tensor *bias = param->Bias();
auto bias_ptr = bias->data<float>();
auto filter = const_cast<Tensor *>(param->Filter());
auto out = param->Output();
PADDLE_MOBILE_ENFORCE(out->dims()[1] == bias->dims()[0],
"Output channel should be equal to bias number");
int channel = out->dims()[1];
int sub_conv_n = param->Strides()[0];
auto bs_ptr = (float *)fpga::fpga_malloc(2 * channel * sub_conv_n *
sizeof(float)); // NOLINT
for (int i = 0; i < channel * sub_conv_n; i++) {
bs_ptr[i + sub_conv_n * channel] = 1;
bs_ptr[i] = bias_ptr[i % (channel)];
}
PADDLE_MOBILE_ENFORCE(param->Strides()[1] == param->Strides()[0],
"stride_width should be equal to stride_height ");
PADDLE_MOBILE_ENFORCE(filter->dims()[2] == filter->dims()[3],
"filter width should be equal to filter height ");
PADDLE_MOBILE_ENFORCE(((filter->dims()[2] % param->Strides()[0]) == 0),
"filter axis should be the multiple of stride axis ");
float max_value = fpga::filter_find_max(filter);
fpga::format_deconv_filter(filter, max_value, param->Groups(),
param->Strides()[0]);
// int element_num_per_div =
// fpga::get_filter_num_per_div(filter, param->Groups());
// deconv only support group=1 && no spilt
fpga::format_bias_scale_array(&bs_ptr, channel * sub_conv_n,
channel * sub_conv_n);
fpga::format_fp16_ofm(out);
fpga::DeconvArgs deconv_arg = {0};
fpga::fill_deconv_arg(&deconv_arg, input, out, filter, relu_enabled,
param->Groups(), param->Strides()[0],
param->Strides()[1], param->Paddings()[0],
param->Paddings()[1], bs_ptr);
param->SetFpgaArgs(deconv_arg);
return true;
}
template <>
void DeconvAddReluKernel<FPGA, float>::Compute(
const FusionDeconvAddReluParam<FPGA> &param) {}
const FusionDeconvAddReluParam<FPGA> &param) {
fpga::ComputeFpgaDeconv(param.FpgaArgs());
}
} // namespace operators
} // namespace paddle_mobile
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册