提交 a58aea9e 编写于 作者: E eclipsess

fix the basic conv + BN + ReLU fused path for depthwise/group convolution (pass the group offset into matmulWithBn so new_scale/new_bias are indexed per group)

上级 f9b05d87
...@@ -104,7 +104,7 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) { ...@@ -104,7 +104,7 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
math::matmulWithBn<float>( math::matmulWithBn<float>(
filter_slice, false, col_matrix, false, static_cast<float>(1), filter_slice, false, col_matrix, false, static_cast<float>(1),
&out_slice, static_cast<float>(0), true, &new_scale, &new_bias); &out_slice, static_cast<float>(0), true, &new_scale, &new_bias, g);
} }
} }
} }
......
...@@ -101,23 +101,22 @@ void DWConvBNReluBasic(const FusionDWConvBNReluParam &param) { ...@@ -101,23 +101,22 @@ void DWConvBNReluBasic(const FusionDWConvBNReluParam &param) {
// gemm // gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
std::cout << "***************" << std::endl;
math::matmulWithBn<float>( math::matmulWithBn<float>(
filter_slice, false, col_matrix, false, static_cast<float>(1), filter_slice, false, col_matrix, false, static_cast<float>(1),
&out_slice, static_cast<float>(0), false, &new_scale, &new_bias); &out_slice, static_cast<float>(0), true, &new_scale, &new_bias, g);
} }
} }
} }
template <typename P> template <typename P>
void DWConvBNReluCompute(const FusionDWConvBNReluParam &param) { void DWConvBNReluCompute(const FusionDWConvBNReluParam &param) {
if (param.Groups() == param.Input()->dims()[1] && if (0&&param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
param.Output(), param.NewScale(), param.Output(), param.NewScale(),
param.NewBias(), true); param.NewBias(), true);
} else if (param.Groups() == param.Input()->dims()[1] && } else if (0&&param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
......
...@@ -50,7 +50,7 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a, ...@@ -50,7 +50,7 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, const framework::Tensor &matrix_b, bool trans_b,
float alpha, framework::Tensor *matrix_out, float beta, float alpha, framework::Tensor *matrix_out, float beta,
bool relu, framework::Tensor *new_scale, bool relu, framework::Tensor *new_scale,
framework::Tensor *new_bias) { framework::Tensor *new_bias, int group) {
auto dim_a = matrix_a.dims(); auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims(); auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims(); auto dim_out = matrix_out->dims();
...@@ -71,7 +71,7 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a, ...@@ -71,7 +71,7 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
SgemmWithBn(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), SgemmWithBn(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(),
N, beta, matrix_out->data<float>(), N, relu, N, beta, matrix_out->data<float>(), N, relu,
new_scale->data<float>(), new_bias->data<float>()); new_scale->data<float>()+group, new_bias->data<float>()+group);
} }
} // namespace math } // namespace math
......
...@@ -31,7 +31,7 @@ template <typename T> ...@@ -31,7 +31,7 @@ template <typename T>
void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha, const framework::Tensor &matrix_b, bool trans_b, T alpha,
framework::Tensor *matrix_out, T beta, bool relu, framework::Tensor *matrix_out, T beta, bool relu,
framework::Tensor *new_scale, framework::Tensor *new_bias); framework::Tensor *new_scale, framework::Tensor *new_bias, int group);
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册