提交 efa4963e 编写于 作者: C Chon 提交者: GitHub

Merge pull request #942 from zhangyang0701/develop

Compress common code to function for FPGA track close #941
......@@ -249,5 +249,79 @@ void format_concat_output(framework::Tensor *out, int height, int width,
out->reset_data_ptr(data_ptr);
}
void fill_conv_arg(struct WrapperConvArgs *arg, framework::Tensor *input,
framework::Tensor *out, framework::Tensor *filter,
bool relu_enabled, int group_num, int stride_h, int stride_w,
int padding_h, int padding_w, float *bs_ptr) {
auto input_ptr = input->data<float>();
auto filter_ptr = filter->data<float>();
auto out_ptr = out->mutable_data<float>();
arg->group_num = (uint32_t)group_num;
arg->split_num = (uint32_t)fpga::get_plit_num(filter);
arg->filter_num = (uint32_t)filter->dims()[0];
arg->output.address = out_ptr;
arg->output.scale_address = out->scale;
arg->conv_args = (fpga::ConvArgs *)fpga::fpga_malloc(arg->split_num *
sizeof(fpga::ConvArgs));
arg->concat_arg.image_num = arg->split_num;
arg->concat_arg.image_out = out_ptr;
arg->concat_arg.scale_out = out->scale;
arg->concat_arg.height = (uint32_t)filter->dims()[2];
arg->concat_arg.width = (uint32_t)filter->dims()[3];
int n = arg->split_num;
arg->concat_arg.images_in = (half **)fpga::fpga_malloc(n * sizeof(int *));
arg->concat_arg.scales_in = (float **)fpga::fpga_malloc(n * sizeof(float *));
arg->concat_arg.channel_num =
(uint32_t *)fpga::fpga_malloc(n * sizeof(uint32_t));
arg->concat_arg.image_out = out_ptr;
const int channel = (int)out->dims()[1];
int element_num_per_div = fpga::get_element_num_per_div(filter, group_num);
int element_num = fpga::get_aligned_filter_element_num(
filter->dims()[1] * filter->dims()[2] * filter->dims()[3]);
for (int i = 0; i < n; i++) {
arg->conv_args[i].relu_enabled = relu_enabled;
arg->conv_args[i].group_num = (uint32_t)group_num;
arg->conv_args[i].kernel.stride_h = (uint32_t)stride_h;
arg->conv_args[i].kernel.stride_w = (uint32_t)stride_w;
arg->conv_args[i].kernel.height = (uint32_t)filter->dims()[2];
arg->conv_args[i].kernel.width = (uint32_t)filter->dims()[3];
arg->conv_args[i].image.address = input_ptr;
arg->conv_args[i].image.channels = (uint32_t)input->dims()[1];
arg->conv_args[i].image.height = (uint32_t)input->dims()[2];
arg->conv_args[i].image.width = (uint32_t)input->dims()[3];
arg->conv_args[i].image.scale_address = input->scale;
arg->conv_args[i].image.pad_height = (uint32_t)padding_h;
arg->conv_args[i].image.pad_width = (uint32_t)padding_w;
arg->conv_args[i].filter_address = &((int8_t *)filter_ptr)[i * element_num];
arg->conv_args[i].sb_address = &((int8_t *)bs_ptr)[i * element_num];
arg->conv_args[i].filter_num =
(uint32_t)(i == n - 1 ? fpga::get_aligned_filter_num(
channel - (n - 1) * element_num_per_div)
: element_num_per_div);
if (n > 1) {
arg->conv_args[i].output.scale_address =
(float *)fpga::fpga_malloc(2 * sizeof(float));
arg->conv_args[i].output.address =
fpga::fpga_malloc(input->dims()[2] * input->dims()[3] *
arg->conv_args[i].filter_num * sizeof(half));
}
else {
arg->conv_args[i].output.scale_address = out->scale;
arg->conv_args[i].output.address = out_ptr;
}
arg->concat_arg.images_in[i] = (half *)arg->conv_args[i].output.address;
arg->concat_arg.scales_in[i] = (float *)arg->conv_args[i].sb_address;
arg->concat_arg.channel_num[i] = arg->conv_args[i].filter_num;
}
}
} // namespace fpga
} // namespace paddle_mobile
......@@ -191,14 +191,15 @@ int ComputeFpgaEWAdd(const struct EWAddArgs& args);
int ComputeFPGAConcat(const struct ConcatArgs& args);
static inline int align_to_x(int num, int x) { return (num + x - 1) / x * x; }
void format_image(framework::Tensor* image_tensor);
void format_ofm(framework::Tensor* ofm_tensor); // only allocate memory
float filter_find_max(framework::Tensor* filter_tensor);
int get_element_num_per_div(framework::Tensor* filter_tensor, int group_num);
int get_plit_num(framework::Tensor* filter_tensor);
int get_aligned_filter_element_num(int chw);
int get_aligned_filter_num(int num);
void format_filter(framework::Tensor* filter_tensor, float max_value,
int group_num);
void format_bias_scale_array(float** bias_scale_array,
......@@ -206,5 +207,10 @@ void format_bias_scale_array(float** bias_scale_array,
void format_concat_output(framework::Tensor* out, int height, int width,
int image_num, uint32_t* channel_num);
void fill_conv_arg(struct WrapperConvArgs* arg, framework::Tensor* input,
framework::Tensor* out, framework::Tensor* filter,
bool relu_enabled, int group_num, int stride_h, int stride_w,
int padding_h, int padding_w, float* bs_ptr);
} // namespace fpga
} // namespace paddle_mobile
......@@ -64,7 +64,40 @@ void format_image(float **data_in, int channel, int height, int width) {
void concat_images(int16_t **images_in, float **scales_in, void *image_out,
float *scale_out, int image_num, uint32_t *channel_num,
int height, int width) {}
int height, int width) {
int i = 0;
int j = 0;
int k = 0;
int each_out_line_channel = 0;
int align_each_out_area_cw = 0;
int align_each_in_area_cw = 0;
int align_each_out_area_cw_differ = 0;
int tmp_channel = 0;
*scale_out = 0;
for (i = 0; i < image_num; i++) {
each_out_line_channel += channel_num[i];
*scale_out = std::max(*scale_out, scales_in[i][0]);
}
align_each_out_area_cw =
align_to_x(each_out_line_channel * width, IMAGE_ALIGNMENT);
align_each_out_area_cw_differ =
align_each_out_area_cw - each_out_line_channel * width;
for (k = 0; k < height; k++) {
for (j = 0; j < width; j++) {
for (i = 0; i < image_num; i++) {
align_each_in_area_cw =
align_to_x(channel_num[i] * width, IMAGE_ALIGNMENT);
memcpy((int16_t *)image_out + tmp_channel +
k * align_each_out_area_cw_differ,
images_in[i] + j * channel_num[i] + k * align_each_in_area_cw,
channel_num[i] * sizeof(int16_t));
tmp_channel += channel_num[i];
}
}
}
}
} // namespace image
} // namespace fpga
......
......@@ -60,84 +60,18 @@ bool ConvAddBNKernel<FPGA, float>::Init(FusionConvAddBNParam<FPGA> *param) {
float max_value = fpga::filter_find_max(filter);
fpga::format_filter(filter, max_value, param->Groups());
auto filter_ptr = filter->data<float>();
int element_num_per_div =
fpga::get_element_num_per_div(filter, param->Groups());
fpga::format_bias_scale_array(&bs_ptr, element_num_per_div, channel);
fpga::format_ofm(out);
auto out_ptr = out->mutable_data<float>();
fpga::WrapperConvArgs convArgs;
convArgs.group_num = (uint32_t)param->Groups();
convArgs.split_num = (uint32_t)fpga::get_plit_num(filter);
convArgs.filter_num = (uint32_t)filter->dims()[0];
convArgs.output.address = out_ptr;
convArgs.output.scale_address = out->scale;
convArgs.conv_args = (fpga::ConvArgs *)fpga::fpga_malloc(
convArgs.split_num * sizeof(fpga::ConvArgs));
convArgs.concat_arg.image_num = convArgs.split_num;
convArgs.concat_arg.image_out = out_ptr;
convArgs.concat_arg.scale_out = out->scale;
convArgs.concat_arg.height = (uint32_t)filter->dims()[2];
convArgs.concat_arg.width = (uint32_t)filter->dims()[3];
int n = convArgs.split_num;
convArgs.concat_arg.images_in = (half **)fpga::fpga_malloc(n * sizeof(int *));
convArgs.concat_arg.scales_in =
(float **)fpga::fpga_malloc(n * sizeof(float *));
convArgs.concat_arg.channel_num =
(uint32_t *)fpga::fpga_malloc(n * sizeof(uint32_t));
convArgs.concat_arg.image_out = out_ptr;
param->SetFpgaArgs(convArgs);
int element_num = fpga::get_aligned_filter_element_num(
filter->dims()[1] * filter->dims()[2] * filter->dims()[3]);
for (int i = 0; i < n; i++) {
convArgs.conv_args[i].relu_enabled = relu_enabled;
convArgs.conv_args[i].group_num = (uint32_t)param->Groups();
convArgs.conv_args[i].kernel.stride_h = (uint32_t)param->Strides()[0];
convArgs.conv_args[i].kernel.stride_w = (uint32_t)param->Strides()[1];
convArgs.conv_args[i].kernel.height = (uint32_t)filter->dims()[2];
convArgs.conv_args[i].kernel.width = (uint32_t)filter->dims()[3];
convArgs.conv_args[i].image.address = input_ptr;
convArgs.conv_args[i].image.channels = (uint32_t)input->dims()[1];
convArgs.conv_args[i].image.height = (uint32_t)input->dims()[2];
convArgs.conv_args[i].image.width = (uint32_t)input->dims()[3];
convArgs.conv_args[i].image.scale_address = input->scale;
convArgs.conv_args[i].image.pad_height = (uint32_t)param->Paddings()[0];
convArgs.conv_args[i].image.pad_width = (uint32_t)param->Paddings()[1];
convArgs.conv_args[i].filter_address =
&((int8_t *)filter_ptr)[i * element_num];
convArgs.conv_args[i].sb_address = &((int8_t *)bs_ptr)[i * element_num];
convArgs.conv_args[i].filter_num =
(uint32_t)(i == n - 1 ? fpga::get_aligned_filter_num(
channel - (n - 1) * element_num_per_div)
: element_num_per_div);
if (n > 1) {
convArgs.conv_args[i].output.scale_address =
(float *)fpga::fpga_malloc(2 * sizeof(float));
convArgs.conv_args[i].output.address =
fpga::fpga_malloc(input->dims()[2] * input->dims()[3] *
convArgs.conv_args[i].filter_num * sizeof(half));
}
else {
convArgs.conv_args[i].output.scale_address = out->scale;
convArgs.conv_args[i].output.address = out_ptr;
}
convArgs.concat_arg.images_in[i] =
(half *)convArgs.conv_args[i].output.address;
convArgs.concat_arg.scales_in[i] =
(float *)convArgs.conv_args[i].sb_address;
convArgs.concat_arg.channel_num[i] = convArgs.conv_args[i].filter_num;
}
fpga::WrapperConvArgs conv_arg;
fpga::fill_conv_arg(&conv_arg, input, out, filter, relu_enabled,
param->Groups(), param->Strides()[0], param->Strides()[1],
param->Paddings()[0], param->Paddings()[1], bs_ptr);
param->SetFpgaArgs(conv_arg);
return true;
}
......
......@@ -67,45 +67,11 @@ bool ConvAddBNReluKernel<FPGA, float>::Init(
fpga::format_ofm(out);
auto out_ptr = out->mutable_data<float>();
fpga::WrapperConvArgs convArgs;
convArgs.group_num = (uint32_t)param->Groups();
convArgs.split_num = (uint32_t)fpga::get_plit_num(filter);
convArgs.filter_num = (uint32_t)filter->dims()[0];
convArgs.output.address = out_ptr;
convArgs.output.scale_address = out->scale;
convArgs.conv_args = (fpga::ConvArgs *)fpga::fpga_malloc(
convArgs.split_num * sizeof(fpga::ConvArgs));
param->SetFpgaArgs(convArgs);
int element_num = fpga::get_aligned_filter_element_num(
filter->dims()[1] * filter->dims()[2] * filter->dims()[3]);
int n = convArgs.split_num;
for (int i = 0; i < n; i++) {
convArgs.conv_args[i].relu_enabled = relu_enabled;
convArgs.conv_args[i].group_num = (uint32_t)param->Groups();
convArgs.conv_args[i].kernel.stride_h = (uint32_t)param->Strides()[0];
convArgs.conv_args[i].kernel.stride_w = (uint32_t)param->Strides()[1];
convArgs.conv_args[i].kernel.height = (uint32_t)filter->dims()[2];
convArgs.conv_args[i].kernel.width = (uint32_t)filter->dims()[3];
convArgs.conv_args[i].image.address = input_ptr;
convArgs.conv_args[i].image.channels = (uint32_t)input->dims()[1];
convArgs.conv_args[i].image.height = (uint32_t)input->dims()[2];
convArgs.conv_args[i].image.width = (uint32_t)input->dims()[3];
convArgs.conv_args[i].image.pad_height = (uint32_t)param->Paddings()[0];
convArgs.conv_args[i].image.pad_width = (uint32_t)param->Paddings()[1];
convArgs.conv_args[i].filter_address =
&((int8_t *)filter_ptr)[i * element_num];
convArgs.conv_args[i].sb_address = &((int8_t *)bs_ptr)[i * element_num];
convArgs.conv_args[i].filter_num =
(uint32_t)(i == n - 1 ? fpga::get_aligned_filter_num(
channel - (n - 1) * element_num_per_div)
: element_num_per_div);
convArgs.conv_args[i].output.scale_address =
(float *)fpga::fpga_malloc(2 * sizeof(float));
convArgs.conv_args[i].image.scale_address = input->scale;
}
return true;
fpga::WrapperConvArgs conv_arg;
fpga::fill_conv_arg(&conv_arg, input, out, filter, relu_enabled,
param->Groups(), param->Strides()[0], param->Strides()[1],
param->Paddings()[0], param->Paddings()[1], bs_ptr);
param->SetFpgaArgs(conv_arg);
return true;
}
......
......@@ -49,44 +49,11 @@ bool ConvAddReluKernel<FPGA, float>::Init(FusionConvAddReluParam<FPGA> *param) {
fpga::format_ofm(out);
auto out_ptr = out->mutable_data<float>();
fpga::WrapperConvArgs convArgs;
convArgs.group_num = (uint32_t)param->Groups();
convArgs.split_num = (uint32_t)fpga::get_plit_num(filter);
convArgs.filter_num = (uint32_t)filter->dims()[0];
convArgs.output.address = out_ptr;
convArgs.output.scale_address = out->scale;
convArgs.conv_args = (fpga::ConvArgs *)fpga::fpga_malloc(
convArgs.split_num * sizeof(fpga::ConvArgs));
param->SetFpgaArgs(convArgs);
int element_num = fpga::get_aligned_filter_element_num(
filter->dims()[1] * filter->dims()[2] * filter->dims()[3]);
int n = convArgs.split_num;
for (int i = 0; i < n; i++) {
convArgs.conv_args[i].relu_enabled = relu_enabled;
convArgs.conv_args[i].group_num = (uint32_t)param->Groups();
convArgs.conv_args[i].kernel.stride_h = (uint32_t)param->Strides()[0];
convArgs.conv_args[i].kernel.stride_w = (uint32_t)param->Strides()[1];
convArgs.conv_args[i].kernel.height = (uint32_t)filter->dims()[2];
convArgs.conv_args[i].kernel.width = (uint32_t)filter->dims()[3];
convArgs.conv_args[i].image.address = input_ptr;
convArgs.conv_args[i].image.channels = (uint32_t)input->dims()[1];
convArgs.conv_args[i].image.height = (uint32_t)input->dims()[2];
convArgs.conv_args[i].image.width = (uint32_t)input->dims()[3];
convArgs.conv_args[i].image.pad_height = (uint32_t)param->Paddings()[0];
convArgs.conv_args[i].image.pad_width = (uint32_t)param->Paddings()[1];
convArgs.conv_args[i].filter_address =
&((int8_t *)filter_ptr)[i * element_num];
convArgs.conv_args[i].sb_address = &((int8_t *)bs_ptr)[i * element_num];
convArgs.conv_args[i].filter_num =
(uint32_t)(i == n - 1 ? fpga::get_aligned_filter_num(
channel - (n - 1) * element_num_per_div)
: element_num_per_div);
convArgs.conv_args[i].output.scale_address =
(float *)fpga::fpga_malloc(2 * sizeof(float));
convArgs.conv_args[i].image.scale_address = input->scale;
}
fpga::WrapperConvArgs conv_arg;
fpga::fill_conv_arg(&conv_arg, input, out, filter, relu_enabled,
param->Groups(), param->Strides()[0], param->Strides()[1],
param->Paddings()[0], param->Paddings()[1], bs_ptr);
param->SetFpgaArgs(conv_arg);
return true;
}
......
......@@ -64,44 +64,11 @@ bool ConvBNKernel<FPGA, float>::Init(FusionConvBNParam<FPGA> *param) {
fpga::format_ofm(out);
auto out_ptr = out->mutable_data<float>();
fpga::WrapperConvArgs convArgs;
convArgs.group_num = (uint32_t)param->Groups();
convArgs.split_num = (uint32_t)fpga::get_plit_num(filter);
convArgs.filter_num = (uint32_t)filter->dims()[0];
convArgs.output.address = out_ptr;
convArgs.output.scale_address = out->scale;
convArgs.conv_args = (fpga::ConvArgs *)fpga::fpga_malloc(
convArgs.split_num * sizeof(fpga::ConvArgs));
param->SetFpgaArgs(convArgs);
int element_num = fpga::get_aligned_filter_element_num(
filter->dims()[1] * filter->dims()[2] * filter->dims()[3]);
int n = convArgs.split_num;
for (int i = 0; i < n; i++) {
convArgs.conv_args[i].relu_enabled = relu_enabled;
convArgs.conv_args[i].group_num = (uint32_t)param->Groups();
convArgs.conv_args[i].kernel.stride_h = (uint32_t)param->Strides()[0];
convArgs.conv_args[i].kernel.stride_w = (uint32_t)param->Strides()[1];
convArgs.conv_args[i].kernel.height = (uint32_t)filter->dims()[2];
convArgs.conv_args[i].kernel.width = (uint32_t)filter->dims()[3];
convArgs.conv_args[i].image.address = input_ptr;
convArgs.conv_args[i].image.channels = (uint32_t)input->dims()[1];
convArgs.conv_args[i].image.height = (uint32_t)input->dims()[2];
convArgs.conv_args[i].image.width = (uint32_t)input->dims()[3];
convArgs.conv_args[i].image.pad_height = (uint32_t)param->Paddings()[0];
convArgs.conv_args[i].image.pad_width = (uint32_t)param->Paddings()[1];
convArgs.conv_args[i].filter_address =
&((int8_t *)filter_ptr)[i * element_num];
convArgs.conv_args[i].sb_address = &((int8_t *)bs_ptr)[i * element_num];
convArgs.conv_args[i].filter_num =
(uint32_t)(i == n - 1 ? fpga::get_aligned_filter_num(
channel - (n - 1) * element_num_per_div)
: element_num_per_div);
convArgs.conv_args[i].output.scale_address =
(float *)fpga::fpga_malloc(2 * sizeof(float));
convArgs.conv_args[i].image.scale_address = input->scale;
}
fpga::WrapperConvArgs conv_arg;
fpga::fill_conv_arg(&conv_arg, input, out, filter, relu_enabled,
param->Groups(), param->Strides()[0], param->Strides()[1],
param->Paddings()[0], param->Paddings()[1], bs_ptr);
param->SetFpgaArgs(conv_arg);
return true;
}
......
......@@ -74,31 +74,11 @@ bool ConvBNReluKernel<FPGA, float>::Init(FusionConvBNReluParam<FPGA> *param) {
int element_num = fpga::get_aligned_filter_element_num(
filter->dims()[1] * filter->dims()[2] * filter->dims()[3]);
int n = convArgs.split_num;
for (int i = 0; i < n; i++) {
convArgs.conv_args[i].relu_enabled = relu_enabled;
convArgs.conv_args[i].group_num = (uint32_t)param->Groups();
convArgs.conv_args[i].kernel.stride_h = (uint32_t)param->Strides()[0];
convArgs.conv_args[i].kernel.stride_w = (uint32_t)param->Strides()[1];
convArgs.conv_args[i].kernel.height = (uint32_t)filter->dims()[2];
convArgs.conv_args[i].kernel.width = (uint32_t)filter->dims()[3];
convArgs.conv_args[i].image.address = input_ptr;
convArgs.conv_args[i].image.channels = (uint32_t)input->dims()[1];
convArgs.conv_args[i].image.height = (uint32_t)input->dims()[2];
convArgs.conv_args[i].image.width = (uint32_t)input->dims()[3];
convArgs.conv_args[i].image.pad_height = (uint32_t)param->Paddings()[0];
convArgs.conv_args[i].image.pad_width = (uint32_t)param->Paddings()[1];
convArgs.conv_args[i].filter_address =
&((int8_t *)filter_ptr)[i * element_num];
convArgs.conv_args[i].sb_address = &((int8_t *)bs_ptr)[i * element_num];
convArgs.conv_args[i].filter_num =
(uint32_t)(i == n - 1 ? fpga::get_aligned_filter_num(
channel - (n - 1) * element_num_per_div)
: element_num_per_div);
convArgs.conv_args[i].output.scale_address =
(float *)fpga::fpga_malloc(2 * sizeof(float));
convArgs.conv_args[i].image.scale_address = input->scale;
}
fpga::WrapperConvArgs conv_arg;
fpga::fill_conv_arg(&conv_arg, input, out, filter, relu_enabled,
param->Groups(), param->Strides()[0], param->Strides()[1],
param->Paddings()[0], param->Paddings()[1], bs_ptr);
param->SetFpgaArgs(conv_arg);
return true;
}
......
......@@ -54,44 +54,10 @@ bool FusionFcReluKernel<FPGA, float>::Init(FusionFcReluParam<FPGA> *param) {
auto out_ptr = out->mutable_data<float>();
fpga::WrapperConvArgs convArgs;
convArgs.group_num = 1;
convArgs.split_num = (uint32_t)fpga::get_plit_num(filter);
convArgs.filter_num = (uint32_t)filter->dims()[0];
convArgs.output.address = out_ptr;
convArgs.output.scale_address = out->scale;
convArgs.conv_args = (fpga::ConvArgs *)fpga::fpga_malloc(
convArgs.split_num * sizeof(fpga::ConvArgs));
param->SetFpgaArgs(convArgs);
int element_num = fpga::get_aligned_filter_element_num(
filter->dims()[1] * filter->dims()[2] * filter->dims()[3]);
int n = convArgs.split_num;
for (int i = 0; i < n; i++) {
convArgs.conv_args[i].relu_enabled = relu_enabled;
convArgs.conv_args[i].group_num = 1;
convArgs.conv_args[i].kernel.stride_h = 1;
convArgs.conv_args[i].kernel.stride_w = 1;
convArgs.conv_args[i].kernel.height = (uint32_t)filter->dims()[2];
convArgs.conv_args[i].kernel.width = (uint32_t)filter->dims()[3];
convArgs.conv_args[i].image.address = input_x_ptr;
convArgs.conv_args[i].image.channels = (uint32_t)input_x->dims()[1];
convArgs.conv_args[i].image.height = (uint32_t)input_x->dims()[2];
convArgs.conv_args[i].image.width = (uint32_t)input_x->dims()[3];
convArgs.conv_args[i].image.pad_height = 0;
convArgs.conv_args[i].image.pad_width = 0;
convArgs.conv_args[i].filter_address =
&((int8_t *)filter_ptr)[i * element_num];
convArgs.conv_args[i].sb_address = &((int8_t *)bs_ptr)[i * element_num];
convArgs.conv_args[i].filter_num =
(uint32_t)(i == n - 1 ? fpga::get_aligned_filter_num(
channel - (n - 1) * element_num_per_div)
: element_num_per_div);
convArgs.conv_args[i].output.scale_address =
(float *)fpga::fpga_malloc(2 * sizeof(float));
convArgs.conv_args[i].image.scale_address = input_x->scale;
}
fpga::WrapperConvArgs conv_arg;
fpga::fill_conv_arg(&conv_arg, input_x, out, filter, relu_enabled, 1, 1, 1, 0,
0, bs_ptr);
param->SetFpgaArgs(conv_arg);
return true;
}
template <>
......
......@@ -55,44 +55,10 @@ bool FusionFcKernel<FPGA, float>::Init(FusionFcParam<FPGA> *param) {
auto out_ptr = out->mutable_data<float>();
fpga::WrapperConvArgs convArgs;
convArgs.group_num = 1;
convArgs.split_num = (uint32_t)fpga::get_plit_num(filter);
convArgs.filter_num = (uint32_t)filter->dims()[0];
convArgs.output.address = out_ptr;
convArgs.output.scale_address = out->scale;
convArgs.conv_args = (fpga::ConvArgs *)fpga::fpga_malloc(
convArgs.split_num * sizeof(fpga::ConvArgs));
param->SetFpgaArgs(convArgs);
int element_num = fpga::get_aligned_filter_element_num(
filter->dims()[1] * filter->dims()[2] * filter->dims()[3]);
int n = convArgs.split_num;
for (int i = 0; i < n; i++) {
convArgs.conv_args[i].relu_enabled = relu_enabled;
convArgs.conv_args[i].group_num = 1;
convArgs.conv_args[i].kernel.stride_h = 1;
convArgs.conv_args[i].kernel.stride_w = 1;
convArgs.conv_args[i].kernel.height = (uint32_t)filter->dims()[2];
convArgs.conv_args[i].kernel.width = (uint32_t)filter->dims()[3];
convArgs.conv_args[i].image.address = input_x_ptr;
convArgs.conv_args[i].image.channels = (uint32_t)input_x->dims()[1];
convArgs.conv_args[i].image.height = (uint32_t)input_x->dims()[2];
convArgs.conv_args[i].image.width = (uint32_t)input_x->dims()[3];
convArgs.conv_args[i].image.pad_height = 0;
convArgs.conv_args[i].image.pad_width = 0;
convArgs.conv_args[i].filter_address =
&((int8_t *)filter_ptr)[i * element_num];
convArgs.conv_args[i].sb_address = &((int8_t *)bs_ptr)[i * element_num];
convArgs.conv_args[i].filter_num =
(uint32_t)(i == n - 1 ? fpga::get_aligned_filter_num(
channel - (n - 1) * element_num_per_div)
: element_num_per_div);
convArgs.conv_args[i].output.scale_address =
(float *)fpga::fpga_malloc(2 * sizeof(float));
convArgs.conv_args[i].image.scale_address = input_x->scale;
}
fpga::WrapperConvArgs conv_arg;
fpga::fill_conv_arg(&conv_arg, input_x, out, filter, relu_enabled, 1, 1, 1, 0,
0, bs_ptr);
param->SetFpgaArgs(conv_arg);
return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册