提交 61155043 编写于 作者: 李寅

Make bias optional

上级 f4a96655
......@@ -73,10 +73,12 @@ class Conv2dFunctor {
#pragma omp parallel for collapse(2)
for (int n = 0; n < batch; ++n) {
for (int c = 0; c < channels; ++c) {
T bias_channel = bias ? bias[c] : 0;
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
index_t offset = n * channels * height * width +
c * height * width + h * width + w;
output[offset] = bias_channel;
T sum = 0;
const T *filter_ptr = filter + c * kernel_size;
for (int inc = 0; inc < input_channels; ++inc) {
......@@ -102,8 +104,8 @@ class Conv2dFunctor {
++filter_ptr;
}
}
output[offset] = sum + bias[c];
}
output[offset] += sum;
}
}
}
......
......@@ -75,10 +75,12 @@ class DepthwiseConv2dFunctor {
#pragma omp parallel for collapse(2)
for (int n = 0; n < batch; ++n) {
for (int c = 0; c < channels; ++c) {
T bias_channel = bias ? bias[c] : 0;
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
index_t offset = n * channels * height * width +
c * height * width + h * width + w;
output[offset] = bias_channel;
T sum = 0;
const T *filter_ptr = filter + c * kernel_size;
for (int kh = 0; kh < kernel_h; ++kh) {
......@@ -103,7 +105,7 @@ class DepthwiseConv2dFunctor {
++filter_ptr;
}
}
output[offset] = sum + bias[c];
output[offset] += sum;
}
}
}
......
......@@ -47,9 +47,7 @@ void Conv2dNeonK1x1S1(const float *input, // NCHW
// Fill with bias
float *output_ptr = channel_output_start;
for (index_t ptr = 0; ptr < total_pixels; ++ptr) {
output_ptr[ptr] = bias[c]; // TODO can we avoid this?
}
std::fill(output_ptr, output_ptr + total_pixels, bias ? bias[c] : 0);
index_t inc = 0;
// Process 4 input channels in batch
......
......@@ -28,7 +28,7 @@ namespace kernels {
input_ptr += (oc / multiplier) * input_height * input_width; \
} \
float *output_ptr = output_ptr_base + oc * output_height * output_width; \
std::fill(output_ptr, output_ptr + output_height * output_width, bias[oc]); \
std::fill(output_ptr, output_ptr + output_height * output_width, bias ? bias[oc] : 0); \
for (int ic = 0; ic < filter_in_channels; ++ic) { \
float32x4_t n_filter_v[3] = {vld1q_f32(filter_ptr), vld1q_f32(filter_ptr+3), vld1q_f32(filter_ptr+6)};
......
......@@ -45,9 +45,8 @@ void Conv2dNeonK5x5S1(const float *input, // NCHW
const float *input_ptr = input + n * input_total_pixels_per_batch;
// Fill with bias
for (index_t i = 0; i < output_total_pixels_per_channel; ++i) {
output_ptr[i] = bias[c];
}
std::fill(output_ptr, output_ptr + output_total_pixels_per_channel,
bias ? bias[c] : 0);
for (index_t inc = 0; inc < input_channels; ++inc) {
float *outptr = output_ptr;
......
......@@ -27,7 +27,12 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
bool Run() override {
const Tensor *input = this->Input(INPUT);
const Tensor *filter = this->Input(FILTER);
const Tensor *bias = this->Input(BIAS);
const T *bias_data = nullptr;
if (this->InputSize() >= 3) {
const Tensor *bias = this->Input(BIAS);
bias_data = bias->data<T>();
}
Tensor *output = this->Output(OUTPUT);
std::vector<index_t> output_shape(4);
......@@ -36,7 +41,7 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
output->Resize(output_shape);
functor_(input->data<T>(), input->shape().data(), filter->data<T>(),
filter->shape().data(), bias->data<T>(), output->mutable_data<T>(),
filter->shape().data(), bias_data, output->mutable_data<T>(),
output->shape().data());
return true;
......
......@@ -26,7 +26,11 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
bool Run() override {
const Tensor *input = this->Input(INPUT);
const Tensor *filter = this->Input(FILTER);
const Tensor *bias = this->Input(BIAS);
const T *bias_data = nullptr;
if (this->InputSize() >= 3) {
const Tensor *bias = this->Input(BIAS);
bias_data = bias->data<T>();
}
Tensor *output = this->Output(OUTPUT);
// resize filter shape.
......@@ -38,7 +42,7 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
output->Resize(output_shape);
functor_(input->data<T>(), input->shape().data(), filter->data<T>(),
filter_shape.data(), bias->data<T>(), output->mutable_data<T>(),
filter_shape.data(), bias_data, output->mutable_data<T>(),
output->shape().data());
return true;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册