提交 4f9e2106 编写于 作者: E eclipsess

fix basicConv+bnrelu in dwgroup: re-enable the DepthwiseConvAddBNRelu3x3 fast paths (stride 1 and stride 2) that had been disabled with a `0 &&` debug guard

上级 cf32b41b
......@@ -109,14 +109,14 @@ void DWConvBNReluBasic(const FusionDWConvBNReluParam &param) {
}
template <typename P>
void DWConvBNReluCompute(const FusionDWConvBNReluParam &param) {
if (0&&param.Groups() == param.Input()->dims()[1] &&
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
} else if (0&&param.Groups() == param.Input()->dims()[1] &&
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
......
......@@ -71,7 +71,8 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
SgemmWithBn(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(),
N, beta, matrix_out->data<float>(), N, relu,
new_scale->data<float>()+group, new_bias->data<float>()+group);
new_scale->data<float>() + group,
new_bias->data<float>() + group);
}
} // namespace math
......
......@@ -31,7 +31,8 @@ template <typename T>
void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha,
framework::Tensor *matrix_out, T beta, bool relu,
framework::Tensor *new_scale, framework::Tensor *new_bias, int group);
framework::Tensor *new_scale, framework::Tensor *new_bias,
int group);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册