Commit 8884755c authored by hjchen2

Merge branch 'dev-latest' of https://github.com/hjchen2/paddle-mobile into dev-latest

......@@ -196,19 +196,35 @@ void fill_split_arg(struct SplitConvArgs *arg, framework::Tensor *input,
arg->conv_arg[i].image.pad_height = (uint32_t)padding_h;
arg->conv_arg[i].image.pad_width = (uint32_t)padding_w;
arg->conv_arg[i].filter_scale_address = filter->scale;
arg->conv_arg[i].filter_address = &(
(int8_t *)filter_ptr)[i * element_num * filter_num_per_div]; // NOLINT
arg->conv_arg[i].sb_address = &bs_ptr[i * filter_num_per_div * 2];
// arg->conv_arg[i].filter_address = &(
// (int8_t *)filter_ptr)[i * element_num * filter_num_per_div]; //
// NOLINT
// arg->conv_arg[i].sb_address = &bs_ptr[i * filter_num_per_div * 2];
arg->conv_arg[i].filter_num = (uint32_t)(
i == n - 1 ? channel - (n - 1) * filter_num_per_div // NOLINT
: filter_num_per_div);
size_t filter_size =
element_num * arg->conv_arg[i].filter_num * sizeof(int8_t);
auto filter_head =
&((int8_t *)filter_ptr)[i * element_num * filter_num_per_div];
arg->conv_arg[i].filter_address = fpga_malloc(filter_size);
memcpy(arg->conv_arg[i].filter_address, filter_head, filter_size);
fpga_flush(arg->conv_arg[i].filter_address, filter_size);
size_t bs_size = 2 * arg->conv_arg[i].filter_num * sizeof(float);
auto bs_head = &bs_ptr[i * filter_num_per_div * 2];
arg->conv_arg[i].sb_address = fpga_malloc(bs_size);
memcpy(arg->conv_arg[i].sb_address, bs_head, bs_size);
fpga_flush(arg->conv_arg[i].sb_address, bs_size);
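// Each split now owns a private fpga_malloc'ed copy of its filter slice and
// bias/scale (sb) slice, flushed to device memory, instead of aliasing into
// the shared buffers (see the commented-out assignments above); the shared
// bs_ptr is freed once after the loop.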
if (n > 1) {
arg->conv_arg[i].output.scale_address =
(float *)fpga_malloc(2 * sizeof(float)); // NOLINT
arg->conv_arg[i].output.address =
fpga_malloc(input->dims()[2] *
align_to_x(input->dims()[3] * arg->conv_arg[i].filter_num,
fpga_malloc(out->dims()[2] *
align_to_x(out->dims()[3] * arg->conv_arg[i].filter_num,
IMAGE_ALIGNMENT) *
sizeof(half));
} else {
......@@ -221,6 +237,8 @@ void fill_split_arg(struct SplitConvArgs *arg, framework::Tensor *input,
arg->concat_arg.scales_in[i] = arg->conv_arg[i].output.scale_address;
arg->concat_arg.channel_num[i] = arg->conv_arg[i].filter_num;
}
filter->reset_data_ptr(nullptr);
fpga_free(bs_ptr);
}
} // namespace fpga
......
......@@ -137,24 +137,23 @@ void align_num(char **data_in, int num_per_div_before_alignment, int num,
int align_chw = align_to_x(chw, FILTER_ELEMENT_ALIGNMENT);
int num_per_div_after_alignment =
align_to_x(num_per_div_before_alignment, FILTER_NUM_ALIGNMENT);
if (num_per_div_after_alignment != num_per_div_before_alignment) {
char *tmp = *data_in;
int div_num =
(num + num_per_div_before_alignment - 1) / num_per_div_before_alignment;
int num_element = div_num * num_per_div_after_alignment * align_chw;
char *data_tmp = (char *)fpga_malloc(num_element * sizeof(char)); // NOLINT
memset(data_tmp, 0, num_element * sizeof(char));
char *tmp = *data_in;
int div_num =
(num + num_per_div_before_alignment - 1) / num_per_div_before_alignment;
int num_element = div_num * num_per_div_after_alignment * align_chw;
char *data_tmp = (char *)fpga_malloc(num_element * sizeof(char)); // NOLINT
for (i = 0; i < div_num; i++) {
memcpy(data_tmp + num_per_div_after_alignment * align_chw * i,
*data_in + num_per_div_before_alignment * align_chw * i,
num_per_div_before_alignment * align_chw);
}
memset(data_tmp, 0, num_element * sizeof(char));
*data_in = data_tmp;
fpga_free(tmp);
for (i = 0; i < div_num; i++) {
memcpy(data_tmp + num_per_div_after_alignment * align_chw * i,
*data_in + num_per_div_before_alignment * align_chw * i,
num_per_div_before_alignment * align_chw);
}
*data_in = data_tmp;
fpga_free(tmp);
}
void reorder(char **data_in, int num_after_alignment, int chw) {
......@@ -223,7 +222,10 @@ void format_filter(float **data_in, int num, int channel, int height, int width,
char **quantize_data = (char **)data_in; // NOLINT
convert_to_hwc(quantize_data, num, channel, height, width);
align_element(quantize_data, num, chw);
align_num(quantize_data, num_per_div_before_alignment, num, chw);
if (num_after_alignment != num) {
align_num(quantize_data, num_per_div_before_alignment, num, chw);
}
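// align_num only needs to run when alignment actually changed the filter
// count; the check mirrors the one added in format_fc_filter below.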
reorder(quantize_data, num_after_alignment, chw);
interleave(quantize_data, num_after_alignment, chw);
fpga_flush(*quantize_data, align_to_x(chw, FILTER_ELEMENT_ALIGNMENT) *
......@@ -254,15 +256,18 @@ void format_fc_filter(float **data_in, int num, int channel, int height,
align_to_x(num_per_div_before_alignment, FILTER_NUM_ALIGNMENT);
int div_num =
(num + num_per_div_before_alignment - 1) / num_per_div_before_alignment;
int num_after_alignment = num_per_div_after_alignment * div_num;
int residual = num % num_per_div_before_alignment;
int num_after_alignment = num_per_div_after_alignment *
((residual == 0) ? div_num : (div_num - 1)) +
align_to_x(residual, FILTER_NUM_ALIGNMENT);
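// Only the last, partially filled group is padded to FILTER_NUM_ALIGNMENT;
// the div_num - 1 full groups already span num_per_div_after_alignment each.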
quantize(data_in, data_size, max);
char **quantize_data = (char **)data_in; // NOLINT
convert_fc_filter(quantize_data, num, chw);
align_element(quantize_data, num, chw);
align_num(quantize_data, num_per_div_before_alignment, num, chw);
if (num_after_alignment != num) {
align_num(quantize_data, num_per_div_before_alignment, num, chw);
}
reorder(quantize_data, num_after_alignment, chw);
interleave(quantize_data, num_after_alignment, chw);
fpga_flush(*quantize_data, align_to_x(chw, FILTER_ELEMENT_ALIGNMENT) *
......
This diff is collapsed.
......@@ -132,11 +132,11 @@ void format_concat_output(framework::Tensor *out, int height, int width,
}
int format_conv_data(framework::Tensor *filter_tensor,
framework::Tensor *ofm_tensor, float *bs_ptr, int group) {
framework::Tensor *ofm_tensor, float **bs_ptr, int group) {
float max_value = fpga::filter_find_max(filter_tensor);
fpga::format_filter(filter_tensor, max_value, group);
int aligned_num = get_aligned_filter_num(filter_tensor);
fpga::format_bias_scale_array(&bs_ptr,
fpga::format_bias_scale_array(bs_ptr,
(int)filter_tensor->dims()[0], // NOLINT
aligned_num);
int aligned_channel = fpga::get_conv_output_channel(filter_tensor);
......
......@@ -39,7 +39,7 @@ void format_bias_scale_array(float** bias_scale_array, int filter_num,
void format_concat_output(framework::Tensor* out, int height, int width,
uint32_t out_channel);
int format_conv_data(framework::Tensor* filter_tensor,
framework::Tensor* ofm_tensor, float* bs_ptr, int group);
framework::Tensor* ofm_tensor, float** bs_ptr, int group);
int format_fc_data(framework::Tensor* filter_tensor,
framework::Tensor* ofm_tensor, float* bs_ptr);
void fill_split_arg(struct SplitConvArgs* arg, framework::Tensor* input,
......
......@@ -137,11 +137,13 @@ int fpga_regpoll(uint64_t reg, uint64_t val, int time) {
for (i = 0; i < timeout; i++) {
if (val == reg_readq(reg)) {
std::cout << "fpga_regpoll:" << i << "val:" << val << "reg:" << reg
<< std::endl;
break;
}
}
if (i <= timeout) {
if (i < timeout) {
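// i == timeout means the poll loop exhausted without the value matching,
// so the success test must be strictly less-than.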
return 0;
} else {
return -1;
......@@ -153,6 +155,12 @@ int memory_request(struct fpga_memory *memory, size_t size, uint64_t *addr) {
uint64_t _nr = DIV_ROUND_UP(size, FPGA_PAGE_SIZE);
unsigned int nr = (unsigned int)_nr;
int ret = 0;
DLOG << size;
DLOG << _nr;
DLOG << nr;
uint64_t a_size = FPGA_PAGE_SIZE * nr;
DLOG << a_size;
pthread_mutex_lock(&memory->mutex);
......@@ -166,6 +174,7 @@ int memory_request(struct fpga_memory *memory, size_t size, uint64_t *addr) {
*addr = address_ofset;
} else {
DLOG << "memory request failed!";
ret = -ENOMEM;
}
......@@ -282,7 +291,7 @@ uint64_t vaddr_to_paddr(void *address) {
if (iter != g_fpgainfo.fpga_vaddr2paddr_map.end()) {
paddr = iter->second;
} else {
DLOG << "Invalid pointer";
DLOG << "Invalid pointer: " << address;
}
return paddr;
......@@ -348,6 +357,11 @@ void fpga_free_driver(void *ptr) {
fpga_bitmap::bitmap_clear(g_fpgainfo.memory_info->bitmap, pos,
g_fpgainfo.memory_info->nr[pos]);
pthread_mutex_unlock(&g_fpgainfo.memory_info->mutex);
auto iter = g_fpgainfo.fpga_vaddr2paddr_map.find(ptr);
if (iter != g_fpgainfo.fpga_vaddr2paddr_map.end()) {
g_fpgainfo.fpga_vaddr2paddr_map.erase(iter);
}
} else {
DLOG << "Invalid pointer";
}
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <ctype.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <cstring>
#include <map>
......@@ -44,7 +45,7 @@ const int PE_IDX_POOLING = 1;
const int PE_IDX_EW = 2;
const int PE_IDX_BYPASS = 3;
enum pe_status { IDLE = 0, BUSY = 1 };
enum pe_status { IDLE = 0, BUSY = 1, ERROR = 2 };
struct MemoryCacheArgs {
void *offset;
......@@ -58,7 +59,7 @@ struct MemoryCacheArgs {
struct fpga_pe {
char type_name[MAX_TYPE_NAME_LENTH + 1];
struct pe_data_s *outer;
pe_status status; // 0=idle 1=busy -1=fail
pe_status status;
uint64_t interrupt_cnt;
};
......@@ -106,6 +107,8 @@ inline uint64_t reg_readq(uint32_t offset) {
uint64_t value =
*(volatile uint64_t *)((uint8_t *)g_fpgainfo.FpgaRegVirAddr + // NOLINT
offset); // NOLINT
// DLOG << "read end";
usleep(10);
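// NOTE: the 10 us sleep after every register access is presumably meant to
// pace back-to-back MMIO operations (the diff itself does not say why).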
return value;
}
......@@ -114,6 +117,8 @@ inline void reg_writeq(uint64_t value, uint32_t offset) {
// DLOG << "offset : " << offset << ", value : " << value;
*(volatile uint64_t *)((uint8_t *)g_fpgainfo.FpgaRegVirAddr + // NOLINT
offset) = value;
// DLOG << "write end";
usleep(10);
}
int open_device_driver();
......
......@@ -74,12 +74,21 @@ struct ConcatArgs {
void* image_out;
float* scale_out;
uint32_t* channel_num;
// uint32_t* aligned_channel_num;
// uint32_t out_channel;
uint32_t* aligned_channel_num;
uint32_t out_channel;
uint32_t height;
uint32_t width;
};
struct SplitConvArgs {
uint32_t split_num;
uint32_t group_num;
uint32_t filter_num;
struct ImageOutputArgs output;
struct ConvArgs* conv_arg;
struct ConcatArgs concat_arg;
};
struct SplitArgs {
uint32_t image_num;
int16_t* image_in;
......@@ -91,15 +100,6 @@ struct SplitArgs {
uint32_t width;
};
struct SplitConvArgs {
uint32_t split_num;
uint32_t group_num;
uint32_t filter_num;
struct ImageOutputArgs output;
struct ConvArgs* conv_arg;
struct ConcatArgs concat_arg;
};
struct PoolingArgs {
int16_t mode; // mode: 0:max, 1:avg
int16_t kernel_reciprocal;
......@@ -127,7 +127,14 @@ struct BypassArgs {
};
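// The new fields below suggest the deconvolution is decomposed into
// sub_conv_num ordinary convolutions whose sub_output_width x
// sub_output_height tiles are interleaved into the final output, with
// omit_size border elements dropped -- an interpretation of the field
// names, not something this diff states.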
struct DeconvArgs {
struct ConvArgs conv_arg;
uint32_t sub_conv_num;
uint32_t group_num;
uint32_t filter_num;
uint32_t omit_size;
uint32_t sub_output_width;
uint32_t sub_output_height;
struct ImageOutputArgs output;
struct ConvArgs* conv_args;
};
static inline int align_to_x(int num, int x) { return (num + x - 1) / x * x; }
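// align_to_x rounds num up to the next multiple of x, e.g. align_to_x(10, 8)
// == 16 and align_to_x(16, 8) == 16; all IMAGE/FILTER alignment above goes
// through this helper.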
......
......@@ -68,6 +68,13 @@ class CLImage {
InitCLImage(context, command_queue, folder_converter);
}
void InitNormalCLImage(cl_context context, cl_command_queue command_queue) {
PADDLE_MOBILE_ENFORCE(tensor_data_ != nullptr,
" need call SetTensorData first");
CLImageConverterNormal *normal_converter = new CLImageConverterNormal();
InitCLImage(context, command_queue, normal_converter);
}
void InitCLImage(cl_context context, cl_command_queue command_queue,
CLImageConverterBase *converter) {
if (image_converter_ != nullptr) {
......
......@@ -22,7 +22,6 @@ void FeedOp<DeviceType, T>::InferShape() const {
auto out_dims = this->param_.Out()->dims();
out_dims[0] = this->param_.BatchSize();
auto input_dims = this->param_.InputX()->dims();
DLOG << input_dims.size();
if (input_dims.size() == 4) {
this->param_.Out()->Resize(input_dims);
} else {
......
......@@ -60,6 +60,9 @@ REGISTER_FUSION_MATCHER(fusion_fc, ops::FusionFcMatcher);
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_fc, ops::FusionFcOp);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(fusion_fc, ops::FusionFcOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
REGISTER_OPERATOR_MALI_GPU(fusion_fc, ops::FusionFcOp);
#endif
......
/* Copyright (c) 201f8 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -167,8 +167,9 @@ void DequantBNQuantCompute(const FusionDequantAddBNQuantParam<CPU> *param) {
// if (param->is_static_) {
if (true) {
max_abs = param->static_scale_;
max_abs = param->offline_scale_->data<float>()[0];
float quant_scale = 127.f / max_abs;
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < batch_size; ++batch) {
for (int c = 0; c < channels; ++c) {
......@@ -287,8 +288,9 @@ void DequantBNReluQuantCompute(
// if (param->is_static_) {
if (true) {
max_abs = param->static_scale_;
max_abs = param->offline_scale_->data<float>()[0];
float quant_scale = 127.f / max_abs;
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < batch_size; ++batch) {
for (int c = 0; c < channels; ++c) {
......
......@@ -107,15 +107,9 @@ inline void GemmConv(const ConvParam<CPU> &param) {
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
if (param.Input()->type() == typeid(int8_t)) {
math::matmul_int8(filter_slice, false, col_matrix, false,
math::matmul<Itype>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0));
} else {
math::matmul<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0));
}
}
}
}
......
......@@ -73,8 +73,8 @@ void MulCompute(const MulParam<CPU> &param) {
}
if (param.InputX()->type() == typeid(int8_t)) {
out->mutable_data<int32_t>();
math::matmul_int8(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(0));
math::matmul<int8_t>(x_matrix, false, y_matrix, false,
static_cast<int8_t>(1), out, static_cast<int8_t>(0));
} else {
out->mutable_data<float>();
......
......@@ -13,7 +13,27 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
/*
__kernel void concatByC0(__read_only image2d_t input_image,
__write_only image2d_t output_image,
__private const int out_W) {
const int in_c = get_global_id(0);
const int in_w = get_global_id(1);
const int in_nh = get_global_id(2);
int2 input_pos ;
input_pos.x = in_c * out_W + in_w;
input_pos.y = in_nh;
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
half4 input;
input = read_imageh(input_image, sampler,input_pos);
write_imageh(output_image, input_pos, input);
}
__kernel void concatByC(__read_only image2d_t input_image1,
__read_only image2d_t input_image2,
......@@ -24,13 +44,13 @@ __kernel void concatByC(__read_only image2d_t input_image1,
__private const int out_C_Start,
__private const int in_W,
__private const int in_H,
__private const int int_C1,
__private const int int_C2) {
__private const int in_C1,
__private const int in_C2) {
const int in_c = get_global_id(0);
const int in_w = get_global_id(1);
const int in_nh = get_global_id(2);
int out_c1 = (out_C_Start)/4 + in_c;
int out_c1 = (out_C_Start + 3)/4 -1 + in_c;
int out_c2 = out_c1 + 1;
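// (out_C_Start + 3)/4 - 1 is ceil(out_C_Start/4) - 1: the last half4 block
// already partly occupied by input1; work-item in_c == 0 rewrites that
// boundary block, and later work-items fill the following blocks.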
......@@ -45,7 +65,7 @@ __kernel void concatByC(__read_only image2d_t input_image1,
int2 input_pos1;
if(in_c==0){
input_pos1.x = ((in_C1-1)/4) * in_W + in_w;
input_pos1.x = ((in_C1 + 3)/4-1) * in_W + in_w;
}else{
input_pos1.x = (in_c - 1) * in_W + in_w;
}
......@@ -103,26 +123,6 @@ __kernel void concatByC(__read_only image2d_t input_image1,
write_imageh(output_image, output_pos2, output2);
}
__kernel void concatByW0(__read_only image2d_t input_image,
__write_only image2d_t output_image,
__private const int out_W) {
const int in_c = get_global_id(0);
const int in_w = get_global_id(1);
const int in_nh = get_global_id(2);
int2 input_pos = in_c * out_W + in_w;
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
half4 input;
input = read_imageh(input_image, sampler,input_pos);
write_imageh(output_image, input_pos, input);
}
*/
__kernel void concatByH(__read_only image2d_t input_image,
__write_only image2d_t output_image,
......
......@@ -692,6 +692,238 @@ __kernel void conv_1x1_4(__private const int global_size_dim0,
*/
__kernel void conv_7x7(__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#ifdef BIASE
__read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
__read_only image2d_t new_scale,
__read_only image2d_t new_biase,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int offset,
__private const int input_c,
__private const int dilation,
__private const int input_width,/* of one block */
__private const int input_height,/* of one block */
__private const int output_width,
__private const int output_height) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
if (out_c >= global_size_dim0 ||
out_w >= global_size_dim1 ||
out_nh >= global_size_dim2) {
return;
}
const int filter_n0 = 4 * out_c + 0;
const int filter_n1 = 4 * out_c + 1;
const int filter_n2 = 4 * out_c + 2;
const int filter_n3 = 4 * out_c + 3;
int2 stride_xy;
stride_xy.x = stride;
stride_xy.y = stride;
int2 ouput_pos_in_one_block;
ouput_pos_in_one_block.x = out_w;
ouput_pos_in_one_block.y = out_nh;
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
int2 in_pos_in_one_block;
in_pos_in_one_block.x = ouput_pos_in_one_block.x * stride + offset;
in_pos_in_one_block.y = ouput_pos_in_one_block.y * stride + offset;
#ifdef BIASE
half4 output = read_imageh(bias, sampler, (int2)(out_c, 0));
#else
half4 output = 0.0f;
#endif
half4 input;
half4 filter[4];
int2 filter_pos0;
int2 filter_pos1;
int2 filter_pos2;
int2 filter_pos3;
for (int i = 0; i < input_c; ++i) {
int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x, in_pos_in_one_block.y);
for(int j = 0; j < 7; j++){
for(int k = 0; k < 7; k++){
input = select(read_imageh(input_image, sampler,
(int2)(pos_in.x + (j - 3) * dilation, pos_in.y + (k - 3) * dilation)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x + (j - 3) * dilation < 0 || in_pos_in_one_block.y + (k - 3) * dilation < 0 || in_pos_in_one_block.x + (j - 3) * dilation >= input_width || in_pos_in_one_block.y + (k - 3) * dilation >= input_height) << 15));
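// select() takes the (half4)(0.0f) lane wherever bit 15 (the MSB select
// bit for a ushort4 mask) is set, so out-of-range taps read as zero padding.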
int filter_h = k;
int filter_w = j;
int filter_c = i;
filter_pos0.x = filter_c * 7 + filter_w;
filter_pos0.y = filter_n0 * 7 + filter_h;
filter_pos1.x = filter_c * 7 + filter_w;
filter_pos1.y = filter_n1 * 7 + filter_h;
filter_pos2.x = filter_c * 7 + filter_w;
filter_pos2.y = filter_n2 * 7 + filter_h;
filter_pos3.x = filter_c * 7 + filter_w;
filter_pos3.y = filter_n3 * 7 + filter_h;
filter[0] = read_imageh(filter_image, sampler, filter_pos0);
filter[1] = read_imageh(filter_image, sampler, filter_pos1);
filter[2] = read_imageh(filter_image, sampler, filter_pos2);
filter[3] = read_imageh(filter_image, sampler, filter_pos3);
output.x += dot(input, filter[0]);
output.y += dot(input, filter[1]);
output.z += dot(input, filter[2]);
output.w += dot(input, filter[3]);
}
}
}
#ifdef BATCH_NORM
output = output * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + read_imageh(new_biase, sampler, (int2)(out_c, 0));
#endif
#ifdef RELU
output = activation(output);
#endif
write_imageh(output_image, (int2)(out_c * global_size_dim1 + out_w, out_nh), output);
}
__kernel void conv_5x5(__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#ifdef BIASE
__read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
__read_only image2d_t new_scale,
__read_only image2d_t new_biase,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int offset,
__private const int input_c,
__private const int dilation,
__private const int input_width,/* of one block */
__private const int input_height,/* of one block */
__private const int output_width,
__private const int output_height) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
if (out_c >= global_size_dim0 ||
out_w >= global_size_dim1 ||
out_nh >= global_size_dim2) {
return;
}
const int filter_n0 = 4 * out_c + 0;
const int filter_n1 = 4 * out_c + 1;
const int filter_n2 = 4 * out_c + 2;
const int filter_n3 = 4 * out_c + 3;
int2 stride_xy;
stride_xy.x = stride;
stride_xy.y = stride;
int2 ouput_pos_in_one_block;
ouput_pos_in_one_block.x = out_w;
ouput_pos_in_one_block.y = out_nh;
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
int2 in_pos_in_one_block;
in_pos_in_one_block.x = ouput_pos_in_one_block.x * stride + offset;
in_pos_in_one_block.y = ouput_pos_in_one_block.y * stride + offset;
#ifdef BIASE
half4 output = read_imageh(bias, sampler, (int2)(out_c, 0));
#else
half4 output = 0.0f;
#endif
half4 input;
half4 filter[4];
int2 filter_pos0;
int2 filter_pos1;
int2 filter_pos2;
int2 filter_pos3;
for (int i = 0; i < input_c; ++i) {
int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x, in_pos_in_one_block.y);
for(int j = 0; j < 5; j++){
for(int k = 0; k < 5; k++){
input = select(read_imageh(input_image, sampler,
(int2)(pos_in.x + (j - 2) * dilation, pos_in.y + (k - 2) * dilation)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x + (j - 2) * dilation < 0 || in_pos_in_one_block.y + (k - 2) * dilation < 0 || in_pos_in_one_block.x + (j - 2) * dilation >= input_width || in_pos_in_one_block.y + (k - 2) * dilation >= input_height) << 15));
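// same MSB-mask trick as conv_7x7: out-of-range taps become zero padding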
int filter_h = k;
int filter_w = j;
int filter_c = i;
filter_pos0.x = filter_c * 5 + filter_w;
filter_pos0.y = filter_n0 * 5 + filter_h;
filter_pos1.x = filter_c * 5 + filter_w;
filter_pos1.y = filter_n1 * 5 + filter_h;
filter_pos2.x = filter_c * 5 + filter_w;
filter_pos2.y = filter_n2 * 5 + filter_h;
filter_pos3.x = filter_c * 5 + filter_w;
filter_pos3.y = filter_n3 * 5 + filter_h;
filter[0] = read_imageh(filter_image, sampler, filter_pos0);
filter[1] = read_imageh(filter_image, sampler, filter_pos1);
filter[2] = read_imageh(filter_image, sampler, filter_pos2);
filter[3] = read_imageh(filter_image, sampler, filter_pos3);
output.x += dot(input, filter[0]);
output.y += dot(input, filter[1]);
output.z += dot(input, filter[2]);
output.w += dot(input, filter[3]);
}
}
}
#ifdef BATCH_NORM
output = output * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + read_imageh(new_biase, sampler, (int2)(out_c, 0));
#endif
#ifdef RELU
output = activation(output);
#endif
write_imageh(output_image, (int2)(out_c * global_size_dim1 + out_w, out_nh), output);
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void lrn(__read_only image2d_t input_image,
__write_only image2d_t output_image,
__private const int out_C,
__private const int out_W,
__private const int n,
__private const float k,
__private const float alpha,
__private const float beta){
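// Cross-channel LRN: out[c] = in[c] / pow(k + alpha * sum_{j in window} in[j]^2, beta),
// with channels packed four per half4 texel, hence the per-lane bookkeeping below.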
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
const int out_c0 = out_c * 4;
const int out_c1 = out_c * 4 + 1;
const int out_c2 = out_c * 4+ 2;
const int out_c3 = out_c * 4+ 3;
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
const int start = -(n-1)/2;
const int end = start + n;
float sqr_sum0 = 0.0f;
float sqr_sum1 = 0.0f;
float sqr_sum2 = 0.0f;
float sqr_sum3 = 0.0f;
int input_c0,input_c1,input_c2,input_c3;
int2 input_pos0,input_pos1,input_pos2,input_pos3;
float4 input0,input1,input2,input3;
for(int i = start; i < end ;i++){
if(out_c0 + i>=0&&out_c0 + i<out_C){
input_c0 = (out_c0 + i)/4;
input_pos0.x = input_c0 * out_W + out_w;
input_pos0.y = out_nh;
input0 = convert_float4(read_imageh(input_image, sampler,input_pos0));
if((out_c0 + i)%4 == 0){
sqr_sum0 += input0.x * input0.x;
}else if((out_c0 + i)%4 == 1){
sqr_sum0 += input0.y * input0.y;
}else if((out_c0 + i)%4 == 2){
sqr_sum0 += input0.z * input0.z;
}else{
sqr_sum0 += input0.w * input0.w;
}
}
if(out_c1 + i>=0&&out_c1 + i<out_C){
input_c1 = (out_c1 + i)/4;
input_pos1.x = input_c1 * out_W + out_w;
input_pos1.y = out_nh;
input1 = convert_float4(read_imageh(input_image, sampler,input_pos1));
if((out_c1 + i)%4 == 0){
sqr_sum1 += input1.x * input1.x;
}else if((out_c1 + i)%4 == 1){
sqr_sum1 += input1.y * input1.y;
}else if((out_c1 + i)%4 == 2){
sqr_sum1 += input1.z * input1.z;
}else{
sqr_sum1 += input1.w * input1.w;
}
}
if(out_c2 + i>=0&&out_c2 + i<out_C){
input_c2 = (out_c2 + i)/4;
input_pos2.x = input_c2 * out_W + out_w;
input_pos2.y = out_nh;
input2 = convert_float4(read_imageh(input_image, sampler,input_pos2));
if((out_c2 + i)%4 == 0){
sqr_sum2 += input2.x * input2.x;
}else if((out_c2 + i)%4 == 1){
sqr_sum2 += input2.y * input2.y;
}else if((out_c2 + i)%4 == 2){
sqr_sum2 += input2.z * input2.z;
}else{
sqr_sum2 += input2.w * input2.w;
}
}
if(out_c3 + i>=0&&out_c3 + i<out_C){
input_c3 = (out_c3 + i)/4;
input_pos3.x = input_c3 * out_W + out_w;
input_pos3.y = out_nh;
input3 = convert_float4(read_imageh(input_image, sampler,input_pos3));
if((out_c3 + i)%4 == 0){
sqr_sum3 += input3.x * input3.x;
}else if((out_c3 + i)%4 == 1){
sqr_sum3 += input3.y * input3.y;
}else if((out_c3 + i)%4 == 2){
sqr_sum3 += input3.z * input3.z;
}else{
sqr_sum3 += input3.w * input3.w;
}
}
}
float4 output = (float4)0.0f;
float4 input;
int2 output_pos;
output_pos.x = out_c * out_W + out_w;
output_pos.y = out_nh;
input = convert_float4(read_imageh(input_image, sampler,output_pos));
output.x = input.x / (pow(k + alpha * (sqr_sum0),beta));
if(out_C - 4 * out_c>=2){
output.y = input.y / (pow(k + alpha * (sqr_sum1),beta));
}
if(out_C - 4 * out_c>=3){
output.z = input.z / (pow(k + alpha * (sqr_sum2),beta));
}
if(out_C - 4 * out_c>=4){
output.w = input.w / (pow(k + alpha * (sqr_sum3),beta));
}
half4 tmp = convert_half4(output);
write_imageh(output_image, output_pos, tmp);
}
\ No newline at end of file
......@@ -31,11 +31,13 @@ __kernel void pool_max(
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int start_h = max(out_h * stride_h - pad_top, 0);
int start_h = out_h * stride_h - pad_top;
int end_h = min(start_h + ksize_h, in_height);
start_h = max(start_h,0);
int start_w = max(out_w * stride_w - pad_left, 0);
int start_w = out_w * stride_w - pad_left;
int end_w = min(start_w + ksize_w, in_width);
start_w = max(start_w,0);
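// end_{h,w} must be derived from the unclamped window origin so the window
// keeps its full ksize extent; clamping afterwards only trims the padded
// region.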
const int pos_in_x = out_c * in_width;
const int pos_in_y = out_n * in_height;
......
......@@ -23,12 +23,17 @@ template <>
bool ConcatKernel<GPU_CL, float>::Init(ConcatParam<GPU_CL> *param) {
if (param->Out()->dims().size() < 4) {
this->cl_helper_.AddKernel("concatByH", "concat_kernel.cl");
} else if (param->Out()->dims().size() == 4) {
this->cl_helper_.AddKernel("concatByC0", "concat_kernel.cl");
this->cl_helper_.AddKernel("concatByC", "concat_kernel.cl");
}
return true;
}
template <>
void ConcatKernel<GPU_CL, float>::Compute(const ConcatParam<GPU_CL> &param) {
DLOG << "yangfei50";
DLOG << param.Out()->dims();
if (param.Out()->dims().size() < 4) {
auto kernel = this->cl_helper_.KernelAt(0);
auto inputs = param.Inputs();
......@@ -62,6 +67,76 @@ void ConcatKernel<GPU_CL, float>::Compute(const ConcatParam<GPU_CL> &param) {
out_H_Start += inputs[i]->dims()[0];
}
}
} else {
auto kernel0 = this->cl_helper_.KernelAt(0);
auto kernel1 = this->cl_helper_.KernelAt(1);
auto inputs = param.Inputs();
auto *output_image = param.Out()->GetCLImage();
int out_C_Start = 0;
auto input_image = inputs[0]->GetCLImage();
auto default_work_size = this->cl_helper_.DefaultWorkSize(*inputs[0]);
int out_W = param.Out()->dims()[3];
cl_int status;
status = clSetKernelArg(kernel0, 0, sizeof(cl_mem), &input_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel0, 1, sizeof(cl_mem), &output_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel0, 2, sizeof(int), &out_W);
CL_CHECK_ERRORS(status);
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel0, default_work_size.size(),
NULL, default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
out_C_Start += inputs[0]->dims()[1];
for (int i = 1; i < inputs.size(); i++) {
auto input_image1 = inputs[i - 1]->GetCLImage();
auto input_image2 = inputs[i]->GetCLImage();
default_work_size = this->cl_helper_.DefaultWorkSize(*inputs[i]);
int out_C = param.Out()->dims()[1];
int out_H = param.Out()->dims()[2];
int in_W = inputs[i]->dims()[3];
int in_H = inputs[i]->dims()[2];
int in_C1 = inputs[i - 1]->dims()[1];
int in_C2 = inputs[i]->dims()[1];
DLOG << "第" << i << "个";
DLOG << "out_C=" << out_C;
DLOG << "out_H=" << out_H;
DLOG << "in_W=" << in_W;
DLOG << "in_H=" << in_H;
DLOG << "in_C1=" << in_C1;
DLOG << "in_C2=" << in_C2;
DLOG << "out_C_Start = " << out_C_Start;
status = clSetKernelArg(kernel1, 0, sizeof(cl_mem), &input_image1);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel1, 1, sizeof(cl_mem), &input_image2);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel1, 2, sizeof(cl_mem), &output_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel1, 3, sizeof(int), &out_C);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel1, 4, sizeof(int), &out_H);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel1, 5, sizeof(int), &out_W);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel1, 6, sizeof(int), &out_C_Start);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel1, 7, sizeof(int), &in_W);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel1, 8, sizeof(int), &in_H);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel1, 9, sizeof(int), &in_C1);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel1, 10, sizeof(int), &in_C2);
CL_CHECK_ERRORS(status);
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel1, default_work_size.size(),
NULL, default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
out_C_Start += inputs[i]->dims()[1];
}
}
}
......
......@@ -51,8 +51,16 @@ bool ConvAddKernel<GPU_CL, float>::Init(FusionConvAddParam<GPU_CL> *param) {
this->cl_helper_.AddKernel("conv_3x3", "conv_add_kernel.cl");
} else {
PADDLE_MOBILE_THROW_EXCEPTION(" not support ");
} else if (param->Filter()->dims()[2] == 7 &&
param->Filter()->dims()[3] == 7) {
param->Filter()->InitCLImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("conv_7x7", "conv_add_kernel.cl");
} else if (param->Filter()->dims()[2] == 5 &&
param->Filter()->dims()[3] == 5) {
param->Filter()->InitCLImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("conv_5x5", "conv_add_kernel.cl");
}
return true;
......
......@@ -52,6 +52,16 @@ bool ConvAddReluKernel<GPU_CL, float>::Init(
this->cl_helper_.AddKernel("conv_3x3", "conv_add_relu_kernel.cl");
} else if (param->Filter()->dims()[2] == 7 &&
param->Filter()->dims()[3] == 7) {
param->Filter()->InitCLImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("conv_7x7", "conv_add_relu_kernel.cl");
} else if (param->Filter()->dims()[2] == 5 &&
param->Filter()->dims()[3] == 5) {
param->Filter()->InitCLImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("conv_5x5", "conv_add_relu_kernel.cl");
} else {
PADDLE_MOBILE_THROW_EXCEPTION(" not support ");
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_FC_OP
#include "operators/kernel/fusion_fc_kernel.h"
#include "operators/math/math_function.h"
namespace paddle_mobile {
namespace operators {
template <>
bool FusionFcKernel<GPU_CL, float>::Init(FusionFcParam<GPU_CL> *param) {
param->InputY()->InitNormalCLImage(cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
param->InputZ()->InitNormalCLImage(cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl");
this->cl_helper_.AddKernel("feed", "feed_kernel.cl");
return true;
}
template <typename P>
void FusionFcCompute(const FusionFcParam<GPU_CL> &param, cl_context context,
cl_command_queue commandQueue, cl_kernel kernel0,
cl_kernel kernel1) {
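// CPU fallback path: copy the CL images to host tensors, run the fc
// (bias broadcast + math::matmul) on the CPU, then write the result back
// into a CL image through the fetch/feed kernels bound above.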
auto *input_x_image = param.InputX();
auto *input_y_image = param.InputY();
auto *input_z_image = param.InputZ();
int axis = param.Axis();
auto *out_image = param.Out();
Tensor *input_x = new Tensor();
input_x->Resize(input_x_image->dims());
input_x->mutable_data<float>();
framework::CLImageToTensor(input_x_image, input_x, context, commandQueue,
kernel0);
Tensor *input_y = new Tensor();
input_y->Resize(input_y_image->dims());
input_y->mutable_data<float>();
framework::CLImageToTensor(input_y_image, input_y, context, commandQueue,
kernel0);
Tensor *input_z = new Tensor();
input_z->Resize(input_z_image->dims());
input_z->mutable_data<float>();
framework::CLImageToTensor(input_z_image, input_z, context, commandQueue,
kernel0);
auto *input_z_data = input_z->data<float>();
DLOG << *input_x;
DLOG << *input_y;
DLOG << *input_z;
Tensor *out = new Tensor();
out->Resize(out_image->dims());
out->mutable_data<float>();
auto *out_data = out->mutable_data<float>();
const Tensor x_matrix =
input_x->dims().size() > 2
? framework::ReshapeToMatrix(*input_x, param.XNumColDims())
: *input_x;
const Tensor y_matrix =
input_y->dims().size() > 2
? framework::ReshapeToMatrix(*input_y, param.YNumColDims())
: *input_y;
auto out_dim = out->dims();
if (out_dim.size() != 2) {
out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
}
PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2.");
PADDLE_MOBILE_ENFORCE(input_z->dims().size() == 1, "input_z size must be 1");
PADDLE_MOBILE_ENFORCE(out_dim[1] == input_z->dims()[0],
" out_dim[1] must equal input_z->dims()[0].");
axis = (axis == -1 ? out_dim.size() - input_z->dims().size() : axis);
PADDLE_MOBILE_ENFORCE(axis == 1, " to fit broadcast, axis = 1. ");
int64_t classes = input_z->numel();
for (int i = 0; i < out_dim[0]; i++) {
memory::Copy(out_data + i * classes, input_z_data, sizeof(float) * classes);
}
// for (int i = 0; i < out->numel(); i++) {
// DLOG << out_data[i];
// }
// bias_data and out have the same dimensions
math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(1), false);
out_image->InitEmptyImage(context, commandQueue, out->dims());
framework::TensorToCLImage(out, out_image, context, commandQueue, kernel1);
DLOG << *out;
delete (input_x);
delete (input_y);
delete (input_z);
delete (out);
PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2.");
// if (out_dim.size() != 2) {
// out->Resize(out_dim);
// }
}
template <>
void FusionFcKernel<GPU_CL, float>::Compute(
const FusionFcParam<GPU_CL> &param) {
auto kernel0 = this->cl_helper_.KernelAt(0);
auto kernel1 = this->cl_helper_.KernelAt(1);
FusionFcCompute<float>(param, this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue(), kernel0, kernel1);
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef LRN_OP
#include "operators/kernel/lrn_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool LrnKernel<GPU_CL, float>::Init(LrnParam<GPU_CL> *param) {
this->cl_helper_.AddKernel("lrn", "lrn_kernel.cl");
return true;
}
template <>
void LrnKernel<GPU_CL, float>::Compute(const LrnParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Out());
auto input_image = param.InputX()->GetCLImage();
auto x_dims = param.InputX()->dims();
auto output_image = param.Out()->GetCLImage();
const int N = x_dims[0];
const int C = x_dims[1];
const int H = x_dims[2];
const int W = x_dims[3];
const int n = param.N();
const float alpha = param.Alpha();
const float beta = param.Beta();
const float k = param.K();
DLOG << "n=" << n;
DLOG << "alpha=" << alpha;
DLOG << "beta=" << beta;
DLOG << "k=" << k;
DLOG << default_work_size;
DLOG << C;
DLOG << W;
cl_int status;
status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &input_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &output_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &C);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(int), &W);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(int), &n);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(float), &k);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 6, sizeof(float), &alpha);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 7, sizeof(float), &beta);
CL_CHECK_ERRORS(status);
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
}
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -58,7 +58,7 @@ bool ConvAddBNKernel<FPGA, float>::Init(FusionConvAddBNParam<FPGA> *param) {
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
fpga::format_conv_data(filter, out, bs_ptr, param->Groups());
fpga::format_conv_data(filter, out, &bs_ptr, param->Groups());
fpga::SplitConvArgs conv_arg = {0};
fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
......
......@@ -56,7 +56,7 @@ bool ConvAddBNReluKernel<FPGA, float>::Init(
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
fpga::format_conv_data(filter, out, bs_ptr, param->Groups());
fpga::format_conv_data(filter, out, &bs_ptr, param->Groups());
fpga::SplitConvArgs conv_arg = {0};
fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
......
......@@ -38,7 +38,7 @@ bool ConvAddKernel<FPGA, float>::Init(FusionConvAddParam<FPGA> *param) {
bs_ptr[i] = bias_ptr[i];
}
fpga::format_conv_data(filter, out, bs_ptr, param->Groups());
fpga::format_conv_data(filter, out, &bs_ptr, param->Groups());
fpga::SplitConvArgs conv_arg = {0};
fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
......
......@@ -38,7 +38,7 @@ bool ConvAddReluKernel<FPGA, float>::Init(FusionConvAddReluParam<FPGA> *param) {
bs_ptr[i] = bias_ptr[i];
}
fpga::format_conv_data(filter, out, bs_ptr, param->Groups());
fpga::format_conv_data(filter, out, &bs_ptr, param->Groups());
fpga::SplitConvArgs conv_arg = {0};
fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
......
......@@ -50,7 +50,7 @@ bool ConvBNKernel<FPGA, float>::Init(FusionConvBNParam<FPGA> *param) {
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
fpga::format_conv_data(filter, out, bs_ptr, param->Groups());
fpga::format_conv_data(filter, out, &bs_ptr, param->Groups());
fpga::SplitConvArgs conv_arg = {0};
fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#ifdef FUSION_CONVBNRELU_OP
#include "operators/kernel/conv_bn_relu_kernel.h"
#include "fpga/V2/filter.h"
namespace paddle_mobile {
namespace operators {
......@@ -50,7 +51,7 @@ bool ConvBNReluKernel<FPGA, float>::Init(FusionConvBNReluParam<FPGA> *param) {
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
fpga::format_conv_data(filter, out, bs_ptr, param->Groups());
fpga::format_conv_data(filter, out, &bs_ptr, param->Groups());
fpga::SplitConvArgs conv_arg = {0};
fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
......
......@@ -14,7 +14,7 @@ limitations under the License. */
#ifdef LRN_OP
#include "lrn_op.h"
#include "operators/lrn_op.h"
namespace paddle_mobile {
namespace operators {
......@@ -32,6 +32,9 @@ namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(lrn, ops::LrnOp);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(lrn, ops::LrnOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
REGISTER_OPERATOR_MALI_GPU(lrn, ops::LrnOp);
#endif
......
......@@ -23,12 +23,10 @@ limitations under the License. */
#if __aarch64__
#define MR_INT8 4
#define NR_INT8 2
#define MR 6
#define NR 16
#else
#define MR_INT8 4
#define NR_INT8 2
#define MR 6
#define NR 8
#endif
......@@ -195,58 +193,52 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
// 8 bits int small block inner product
void AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc);
void AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc);
void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc);
// 8 bits int inner product
void InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a,
const int8_t *b, float beta, int32_t *c, int32_t *C,
int32_t ldc, bool relu);
void InnerKernelWithBias(int32_t mc, int32_t nc, float alpha, const int8_t *a,
const int8_t *b, float beta, int32_t *c, int8_t *C,
int32_t ldc, bool relu, int32_t *bias);
void InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha,
const int8_t *a, const int8_t *b, int8_t beta,
int32_t *c, int32_t *C, int32_t ldc, bool relu,
int8_t *bias);
// 8 bits int pack function
void PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer);
void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer);
void PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer);
void PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail,
const int8_t *A, int32_t lda, int8_t *buffer);
void PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer);
// 8 bits int matrix product
void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, float beta, int32_t *C,
int32_t ldc, bool relu, int32_t *bias);
void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, float beta, int8_t *C,
int32_t ldc, bool relu, int32_t *bias);
void Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, float beta,
int32_t *C, int32_t ldc, bool relu, int32_t *bias);
void Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, int32_t *C,
int32_t ldc, bool relu, int8_t *bias);
void Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, int8_t beta,
int32_t *C, int32_t ldc, bool relu, int8_t *bias);
// 8 bits int write back
// C = alpha * A * B + beta * C
void WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc);
// C = A * B
void WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc);
// C = A * B + bias, scale * relu(C)
void WriteWithAddReluScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
int32_t ldc, int32_t *bias, float scale);
// C = A * B + bias, scale * C
void WriteWithAddScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
int32_t ldc, int32_t *bias, float scale);
// C = A * B + C
void WriteWithAdd(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc);
// C = A * B + bias
void WriteWithAddV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc, int8_t *bias);
// C = A * B + C, relu(C)
void WriteWithAddRelu(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc);
// C = A * B + bias, relu(C)
void WriteWithAddReluV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc, int8_t *bias);
private:
int MC = 0;
......@@ -262,7 +254,7 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
// 8 bits int
int8_t *packedA_int8;
int8_t *packedB_int8;
int32_t *packedC_int32;
int32_t *packedC_int8;
int8_t *zero_int8;
};
......
This diff is collapsed.
......@@ -28,10 +28,10 @@ namespace operators {
namespace math {
// 8 bits int matrix product (m*k x k*n)
void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha,
const int8_t *A, int32_t lda, const int8_t *B, int32_t ldb,
float beta, int32_t *C, int32_t ldc, bool relu,
int32_t *bias) {
int8_t beta, int32_t *C, int32_t ldc, bool relu,
int8_t *bias) {
#ifdef _OPENMP
int32_t max_threads = omp_get_max_threads();
#else
......@@ -39,11 +39,10 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
#endif
int32_t L1 = 64 / max_threads * 1024;
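// per-thread share of a 64 KB L1 blocking budget, used to size MC/NC below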
const int32_t k_complete = (k + 15) - ((k + 15) & 15);
KC = k_complete;
KC = k;
zero_int8 =
static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * k));
memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * k);
static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * KC));
memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * KC);
if (m > n) {
// tile matrix A
MC = L1 / (KC * sizeof(int8_t));
......@@ -55,14 +54,14 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8;
}
// pad B
NC = (n + NR_INT8 - 1) / NR_INT8 * NR_INT8;
NC = (n + NR - 1) / NR * NR;
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
#if __aarch64__
// TODO(wzzju)
#else
PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8);
PackMatrixB_omp_8c(KC, n, n % NR, B, ldb, packedB_int8);
#endif
packedA_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC * max_threads));
......@@ -70,11 +69,11 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
// tile matrix B
NC = L1 / (KC * sizeof(int8_t));
if (NC == 0) {
NC = NR_INT8;
NC = NR;
} else {
int32_t nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8;
NC = (NC + NR - 1) / NR * NR;
}
// pad A
MC = (m + MR_INT8 - 1) / MR_INT8 * MR_INT8;
......@@ -84,12 +83,12 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
#if __aarch64__
// TODO(wzzju)
#else
PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8);
PackMatrixA_omp_4r(m, KC, m % MR_INT8, A, lda, packedA_int8);
#endif
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC * max_threads));
}
packedC_int32 = static_cast<int32_t *>(
packedC_int8 = static_cast<int32_t *>(
paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC * max_threads));
if (m > n) {
......@@ -104,19 +103,14 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
int32_t mc;
mc = s_min(m - i, MC);
int8_t *local_A = packedA_int8 + MC * KC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads;
int32_t *local_C = packedC_int8 + MC * NC * local_threads;
#if __aarch64__
// TODO(wzzju)
#else
PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A);
PackMatrixA_4r(mc, KC, mc % MR_INT8, &A(i, 0), lda, local_A);
#endif
// InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta,
// local_C,
// &C(i, 0), ldc, relu, bias + i);
if (bias == nullptr) {
InnerKernel(mc, n, alpha, local_A, packedB_int8, beta, local_C,
&C(i, 0), ldc, relu);
}
InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta, local_C,
&C(i, 0), ldc, relu, bias + i);
}
} else {
#pragma omp parallel for
......@@ -129,25 +123,20 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
int32_t nc;
nc = s_min(n - j, NC);
int8_t *local_B = packedB_int8 + KC * NC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads;
int32_t *local_C = packedC_int8 + MC * NC * local_threads;
#if __aarch64__
// TODO(wzzju)
#else
PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B);
PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, local_B);
#endif
// InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta,
// local_C,
// &C(0, j), ldc, relu, bias);
if (bias == nullptr) {
InnerKernel(m, nc, alpha, packedA_int8, local_B, beta, local_C,
&C(0, j), ldc, relu);
}
InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta, local_C,
&C(0, j), ldc, relu, bias);
}
}
paddle_mobile::memory::Free(packedA_int8);
paddle_mobile::memory::Free(packedB_int8);
paddle_mobile::memory::Free(packedC_int32);
paddle_mobile::memory::Free(packedC_int8);
paddle_mobile::memory::Free(zero_int8);
}
......@@ -155,7 +144,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer) {
const int32_t j_length = n - n_tail;
#pragma omp parallel for
for (int32_t j = 0; j < j_length; j += 8) {
for (int32_t j = 0; j < j_length; j += NR) {
int8_t *local_buffer = buffer + j * k;
for (int32_t i = 0; i < k; ++i) {
const int8_t *b0 = &B(i, j);
......@@ -190,7 +179,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
for (int32_t j = j_length; j < n; ++j) {
*local_buffer++ = *b0++;
}
for (int32_t j = n; j < j_length + 8; ++j) {
for (int32_t j = n; j < j_length + NR; ++j) {
*local_buffer++ = 0;
}
}
......@@ -199,9 +188,9 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail,
const int8_t *A, int32_t lda, int8_t *buffer) {
const int32_t i_length = m - m_tail;
const int i_length = m - m_tail;
#pragma omp parallel for
for (int32_t i = 0; i < i_length; i += 4) {
for (int32_t i = 0; i < i_length; i += MR_INT8) {
const int8_t *a0 = A + i * lda;
const int8_t *a1 = A + (i + 1) * lda;
const int8_t *a2 = A + (i + 2) * lda;
......@@ -232,7 +221,7 @@ void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail,
default:
break;
}
for (int32_t j = 0; j < k; ++j) {
for (int j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
......@@ -241,232 +230,6 @@ void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail,
}
}
// 8 bits int PackMatrixA_4r
void Gemm::PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail,
const int8_t *A, int32_t lda, int8_t *buffer) {
const int32_t i_length = m - m_tail;
const int32_t k_count = k >> 4;
const int32_t k_tail = k & 15;
#pragma omp parallel for
for (int32_t i = 0; i < i_length; i += 4) {
const int8_t *a0 = A + i * lda;
const int8_t *a1 = A + (i + 1) * lda;
const int8_t *a2 = A + (i + 2) * lda;
const int8_t *a3 = A + (i + 3) * lda;
int8_t *local_buffer = buffer + i * KC;
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
"vld1.s8 {d2, d3}, [%[a1]]! \n\t"
"vld1.s8 {d4, d5}, [%[a2]]! \n\t"
"vld1.s8 {d6, d7}, [%[a3]]! \n\t"
"vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t"
"vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t"
"vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t"
"vst1.s8 {d6, d7}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1),
[a2] "+r"(a2), [a3] "+r"(a3)
:
: "memory", "q0", "q1", "q2", "q3");
#endif // __aarch64__
#else
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a0++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a1++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a2++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a3++;
}
#endif // __ARM_NEON
}
if (k_tail != 0) {
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a0++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a1++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a2++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a3++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
if (m_tail != 0) {
const int8_t *a0 = &A(i_length, 0);
const int8_t *a1 = a0 + lda;
const int8_t *a2 = a0 + 2 * lda;
const int8_t *a3 = a0 + 3 * lda;
int8_t *local_buffer = buffer + i_length * KC;
switch (m_tail) {
case 1:
a1 = zero_int8;
case 2:
a2 = zero_int8;
case 3:
a3 = zero_int8;
break;
default:
break;
}
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
"vld1.s8 {d2, d3}, [%[a1]]! \n\t"
"vld1.s8 {d4, d5}, [%[a2]]! \n\t"
"vld1.s8 {d6, d7}, [%[a3]]! \n\t"
"vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t"
"vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t"
"vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t"
"vst1.s8 {d6, d7}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1),
[a2] "+r"(a2), [a3] "+r"(a3)
:
: "memory", "q0", "q1", "q2", "q3");
#endif // __aarch64__
#else
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a0++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a1++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a2++;
}
for (int32_t l = 0; l < 16; ++l) {
*local_buffer++ = *a3++;
}
#endif // __ARM_NEON
}
if (k_tail != 0) {
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a0++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a1++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a2++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *a3++;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
}
// 8 bits int PackMatrixB
void Gemm::PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail,
const int8_t *B, int32_t ldb, int8_t *buffer) {
const int32_t j_length = n - n_tail;
const int32_t k_count = k >> 4;
const int32_t k_tail = k & 15;
#pragma omp parallel for
for (int32_t j = 0; j < j_length; j += 2) {
int8_t *local_buffer = buffer + j * KC;
for (int32_t i = 0; i < k_count; ++i) {
const int8_t *b0 = &B((i << 4), j);
const int8_t *b1 = &B((i << 4), j + 1);
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b1;
b1 += ldb;
}
}
if (k_tail != 0) {
const int8_t *b0 = &B((k_count << 4), j);
const int8_t *b1 = &B((k_count << 4), j + 1);
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b1;
b1 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
if (n_tail != 0) {
int8_t *local_buffer = buffer + j_length * KC;
for (int32_t i = 0; i < k_count; ++i) {
const int8_t *b0 = &B((i << 4), j_length);
for (int m = 0; m < 16; ++m) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int m = 0; m < 16; ++m) {
*local_buffer++ = 0;
}
}
if (k_tail != 0) {
const int8_t *b0 = &B((k_count << 4), j_length);
for (int32_t j = k_count << 4; j < k; ++j) {
*local_buffer++ = *b0;
b0 += ldb;
}
for (int32_t j = k; j < KC; ++j) {
*local_buffer++ = 0;
}
for (int32_t j = k_count << 4; j < KC; ++j) {
*local_buffer++ = 0;
}
}
}
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -28,12 +28,7 @@ template <typename T>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha,
framework::Tensor *matrix_out, T beta, bool relu = false,
float *bias = nullptr);
void matmul_int8(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu = false,
int32_t *bias = nullptr);
T *bias = nullptr);
template <typename T>
void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
......
......@@ -20,10 +20,11 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
namespace math {
void matmul_int8(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu,
int32_t *bias) {
template <>
void matmul<int8_t>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
int8_t alpha, framework::Tensor *matrix_out, int8_t beta,
bool relu, int8_t *bias) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
......@@ -51,45 +52,21 @@ void matmul_int8(const framework::Tensor &matrix_a, bool trans_a,
}
#ifdef _OPENMP
if (bias != nullptr) {
// TODO(wzzju): gemm.Sgemm_omp_with_bias, now use single thread instead.
gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int8_t>(), N, relu, bias);
} else {
gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
}
gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
#else
if (bias != nullptr) {
gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int8_t>(), N, relu, bias);
} else {
gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
}
gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
#endif
} else {
#ifdef _OPENMP
if (bias != nullptr) {
// TODO(wzzju): gemm.Sgemm_omp_with_bias, now use single thread instead.
gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta, matrix_out->data<int8_t>(),
N, relu, bias);
} else {
gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
}
gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
#else
if (bias != nullptr) {
gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta, matrix_out->data<int8_t>(),
N, relu, bias);
} else {
gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta, matrix_out->data<int32_t>(),
N, relu, bias);
}
gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta, matrix_out->data<int32_t>(), N,
relu, bias);
#endif
}
}
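// Note that although T is int8_t here, the output is read as int32_t: each
// output element accumulates k products of two int8 values, and already at
// k = 2 the worst case (2 * (-128) * (-128) = 32768) overflows int16, so the
// kernels widen into 32-bit lanes. A quick sanity bound (the depth 131071 is
// illustrative, not a limit taken from this codebase):
static_assert(131071LL * 128 * 128 <= 2147483647LL,
              "a worst-case int8 dot product of depth <= 131071 fits in int32");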
......
......@@ -1633,11 +1633,11 @@ class FusionFcParam : public OpParam {
y_num_col_dims_ = GetAttr<int>("y_num_col_dims", attrs);
axis_ = GetAttr<int>("axis", attrs);
}
const GType *InputX() const { return input_x_; }
GType *InputX() const { return input_x_; }
const RType *InputY() const { return input_y_; }
RType *InputY() const { return input_y_; }
const RType *InputZ() const { return input_z_; }
RType *InputZ() const { return input_z_; }
GType *Out() const { return out_; }
......
......@@ -28,7 +28,7 @@ limitations under the License. */
int main() {
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
paddle_mobile.SetThreadNum(4);
paddle_mobile.SetThreadNum(8);
Tensor aa, bb, cc;
auto aaptr = aa.mutable_data<float>({m, k});
auto bbptr = bb.mutable_data<float>({k, n});
......@@ -44,12 +44,10 @@ int main() {
ccptr[i] = 2;
}
Tensor aa_int8, bb_int8, cc_int32, cc_int8;
Tensor aa_int8, bb_int8, cc_int8;
auto aaptr_int8 = aa_int8.mutable_data<int8_t>({m, k});
auto bbptr_int8 = bb_int8.mutable_data<int8_t>({k, n});
auto ccptr_int32 = cc_int32.mutable_data<int32_t>({m, n});
auto ccptr_int8 = cc_int8.mutable_data<int8_t>({m, n});
int32_t* bias_data = new int32_t[m];
auto ccptr_int8 = cc_int8.mutable_data<int32_t>({m, n});
for (int i = 0; i < m * k; ++i) {
aaptr_int8[i] = static_cast<int8_t>(2);
......@@ -58,11 +56,7 @@ int main() {
bbptr_int8[i] = static_cast<int8_t>(2);
}
for (int i = 0; i < m * n; ++i) {
ccptr_int32[i] = static_cast<int32_t>(2);
}
for (int i = 0; i < m; ++i) {
bias_data[i] = 2;
ccptr_int8[i] = static_cast<int32_t>(2);
}
// float
......@@ -82,41 +76,22 @@ int main() {
auto time2 = time();
std::cout << "float gemm cost :" << time_diff(time1, time2) / 10 << "ms\n";
// int8_t without bias
// int8_t
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul_int8(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
static_cast<float>(0), false, nullptr);
paddle_mobile::operators::math::matmul<int8_t>(
aa_int8, false, bb_int8, false, static_cast<int8_t>(1), &cc_int8,
static_cast<int8_t>(0), false, nullptr);
}
auto time3 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul_int8(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
static_cast<float>(0), false, nullptr);
paddle_mobile::operators::math::matmul<int8_t>(
aa_int8, false, bb_int8, false, static_cast<int8_t>(1), &cc_int8,
static_cast<int8_t>(0), false, nullptr);
}
auto time4 = time();
std::cout << "int8_t gemm cost :" << time_diff(time3, time4) / 10 << "ms\n";
// int8_t with bias&relu
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul_int8(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int8,
static_cast<float>(0), true, &bias_data[0]);
}
auto time5 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul_int8(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int8,
static_cast<float>(0), true, &bias_data[0]);
}
auto time6 = time();
std::cout << "int8_t gemm_with_bias_relu cost :"
<< time_diff(time5, time6) / 10 << "ms\n";
delete[] bias_data;
return 0;
}
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fstream>
#include <iomanip>
#include <iostream>
#include "../test_include.h"
#ifdef PADDLE_MOBILE_FPGA_V1
......@@ -87,26 +89,29 @@ int main() {
paddle_mobile::PaddleMobile<paddle_mobile::FPGA> paddle_mobile;
if (paddle_mobile.Load(std::string(g_resnet50), true)) {
Tensor input_tensor;
SetupTensor<float>(&input_tensor, {1, 3, 224, 224}, static_cast<float>(0),
static_cast<float>(1));
SetupTensor<float>(&input_tensor, {1, 3, 224, 224}, static_cast<float>(2),
static_cast<float>(2));
readStream(g_image_src_float,
input_tensor.mutable_data<float>({1, 3, 224, 224}));
paddle_mobile.FeedData(input_tensor);
paddle_mobile.Predict_To(-1);
/*for(int i = 0; i < 73; i++)
{
for (int i = 0; i < 73; i++) {
auto tensor_ptr = paddle_mobile.FetchResult(i);
std::string saveName = "resnet50_result_" + std::to_string (i);
std::string saveName = "resnet50_result_" + std::to_string(i);
paddle_mobile::fpga::fpga_invalidate((*tensor_ptr).data<float>(),
tensor_ptr->numel()); dump_stride(saveName, (*tensor_ptr), 20);
//dump(saveName, (*tensor_ptr));
}*/
tensor_ptr->numel() * sizeof(half));
dump_stride(saveName, (*tensor_ptr), 20);
// dump(saveName, (*tensor_ptr));
}
/*std::shared_ptr<Tensor> output_tensor = paddle_mobile.FetchResult(73);
(*output_tensor).dump<float>("resnet50_result_73");
std::shared_ptr<Tensor> output_tensor = paddle_mobile.FetchResult(73);
//(*output_tensor).dump<float>("resnet50_result_73");
output_tensor = paddle_mobile.FetchResult(74);
(*output_tensor).dump<float>("resnet50_result_74");*/
std::shared_ptr<Tensor> output_tensor = paddle_mobile.FetchResult(74);
//(*output_tensor).dump<float>("resnet50_result_74");
// std::shared_ptr<Tensor> output_tensor = paddle_mobile.FetchResult(74);
// output_tensor = paddle_mobile.FetchResult(74);
float max = 0;
auto data_ptr = output_tensor->data<float>();
int maximumIdx = 0;
......@@ -116,7 +121,7 @@ int main() {
max = data_ptr[i];
}
}
std::cout << "index : " << maximumIdx << ", value : " << max
std::cout << "index : " << std::dec << maximumIdx << ", value : " << max
<< std::endl;
std::cout << "Computation done" << std::endl;
return 0;
......