Commit 3e848f61 authored by: E eclipsess

convaddbnrelu fix

Parent e4ed8847
...
@@ -22,11 +22,11 @@ namespace operators {
 template <>
 bool ConvAddBNReluKernel<CPU, float>::Init(FusionConvAddBNReluParam *param) {
-  const Tensor *mean = (*param).InputMean();
-  const Tensor *variance = (*param).InputVariance();
-  const Tensor *scale = (*param).InputScale();
-  const Tensor *bias = (*param).InputBias();
-  const float epsilon = (*param).Epsilon();
+  const Tensor *mean = param->InputMean();
+  const Tensor *variance = param->InputVariance();
+  const Tensor *scale = param->InputScale();
+  const Tensor *bias = param->InputBias();
+  const float epsilon = param->Epsilon();
   auto mean_ptr = mean->data<float>();
   auto variance_ptr = variance->data<float>();
@@ -47,8 +47,8 @@ bool ConvAddBNReluKernel<CPU, float>::Init(FusionConvAddBNReluParam *param) {
     new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
     new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
   }
-  (*param).SetNewScale(new_scale);
-  (*param).SetNewBias(new_bias);
+  param->SetNewScale(new_scale);
+  param->SetNewBias(new_bias);
   return true;
 }
...
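For reference, the `new_scale` / `new_bias` values prepared in `Init` fold the batch-norm statistics into a single per-channel affine transform, assuming `inv_std_ptr[i]` holds 1 / sqrt(variance[i] + epsilon). A minimal standalone sketch of that folding; the helper name and the `std::vector` signature are illustrative only, not part of the kernel:

#include <cmath>
#include <cstddef>
#include <vector>

// Fold batch-norm statistics into one per-channel scale/bias pair, so that
// bn(x) = new_scale[c] * x + new_bias[c] for every value x in channel c.
void FoldBatchNorm(const std::vector<float> &scale, const std::vector<float> &bias,
                   const std::vector<float> &mean, const std::vector<float> &variance,
                   float epsilon, std::vector<float> *new_scale,
                   std::vector<float> *new_bias) {
  const std::size_t channels = scale.size();
  new_scale->resize(channels);
  new_bias->resize(channels);
  for (std::size_t i = 0; i < channels; ++i) {
    const float inv_std = 1.f / std::sqrt(variance[i] + epsilon);
    (*new_scale)[i] = inv_std * scale[i];                     // new_scale_ptr[i]
    (*new_bias)[i] = bias[i] - mean[i] * inv_std * scale[i];  // new_bias_ptr[i]
  }
}

With the statistics folded this way, the fused kernel only needs one multiply-add per output element instead of the full normalize-scale-shift sequence.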
...
@@ -556,29 +556,35 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
         float w21 = filter_data_tmp[7];
         float w22 = filter_data_tmp[8];
-        output_data[0] =
-            (w11 * input_data[0] + w12 * input_data[1] + w21 * input_data[l] +
-             w22 * input_data[l + 1] + bias_data[j]) *
-                newscale_data[j] +
-            newbias_data[j];
-        output_data[l - 1] = (w10 * input_data[l - 2] + w11 * input_data[l - 1] +
-                              w20 * input_data[2 * l - 2] +
-                              w21 * input_data[2 * l - 1] + bias_data[j]) *
-                                 newscale_data[j] +
-                             newbias_data[j];
-        output_data[(l - 1) * l] =
-            (w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] +
-             w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1] +
-             bias_data[j]) *
-                newscale_data[j] +
-            newbias_data[j];
-        output_data[l * l - 1] = (w00 * input_data[(l - 2) * (l + 1)] +
-                                  w01 * input_data[(l - 2) * (l + 1) + 1] +
-                                  w10 * input_data[l * l - 2] +
-                                  w11 * input_data[l * l - 1] + bias_data[j]) *
-                                     newscale_data[j] +
-                                 newbias_data[j];
+        output_data[0] = w11 * input_data[0] + w12 * input_data[1] +
+                         w21 * input_data[l] + w22 * input_data[l + 1];
+        output_data[l - 1] = w10 * input_data[l - 2] + w11 * input_data[l - 1] +
+                             w20 * input_data[2 * l - 2] +
+                             w21 * input_data[2 * l - 1];
+        output_data[(l - 1) * l] =
+            w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] +
+            w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1];
+        output_data[l * l - 1] = w00 * input_data[(l - 2) * (l + 1)] +
+                                 w01 * input_data[(l - 2) * (l + 1) + 1] +
+                                 w10 * input_data[l * l - 2] +
+                                 w11 * input_data[l * l - 1];
+        if (if_bias) {
+          output_data[0] += bias_data[j];
+          output_data[l - 1] += bias_data[j];
+          output_data[(l - 1) * l] += bias_data[j];
+          output_data[l * l - 1] += bias_data[j];
+        }
+        if (if_bn) {
+          output_data[0] = output_data[0] * newscale_data[j] + newbias_data[j];
+          output_data[l - 1] =
+              output_data[l - 1] * newscale_data[j] + newbias_data[j];
+          output_data[(l - 1) * l] =
+              output_data[(l - 1) * l] * newscale_data[j] + newbias_data[j];
+          output_data[l * l - 1] =
+              output_data[l * l - 1] * newscale_data[j] + newbias_data[j];
+        }
         if (if_relu) {
           output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
           output_data[l - 1] = output_data[l - 1] < 0 ? 0 : output_data[l - 1];
@@ -604,6 +610,16 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
                                 w21 * input_data[i * l + l - 1 + l] + bias_data[j]) *
                                    newscale_data[j] +
                                newbias_data[j];
+          if (if_bias) {
+            output_data[i * l] += bias_data[j];
+            output_data[i * l + l - 1] += bias_data[j];
+          }
+          if (if_bn) {
+            output_data[i * l] =
+                output_data[i * l] * newscale_data[j] + newbias_data[j];
+            output_data[i * l + l - 1] =
+                output_data[i * l + l - 1] * newscale_data[j] + newbias_data[j];
+          }
           if (if_relu) {
             output_data[i * l] = output_data[i * l] < 0 ? 0 : output_data[i * l];
             output_data[i * l + l - 1] =
...
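The depthwise-kernel change amounts to computing the raw 3x3 convolution sum first and then applying the bias add, the folded batch norm, and ReLU as independently gated stages. A minimal per-element sketch of that order of operations, assuming `if_bias`, `if_bn`, and `if_relu` gate the stages as in the kernel above; the helper name is illustrative:

// Post-process one convolution result in the same order as the refactored
// kernel: optional bias add, optional folded batch norm, optional ReLU.
inline float PostProcess(float conv_sum, float bias, float new_scale,
                         float new_bias, bool if_bias, bool if_bn, bool if_relu) {
  float v = conv_sum;
  if (if_bias) {
    v += bias;                     // output_data[...] += bias_data[j];
  }
  if (if_bn) {
    v = v * new_scale + new_bias;  // folded batch norm, per channel j
  }
  if (if_relu) {
    v = v < 0 ? 0 : v;             // clamp negatives to zero
  }
  return v;
}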