提交 898f7376 编写于 作者: E eclipsess

add fusion bnrelu

上级 b68d873f
......@@ -68,11 +68,11 @@ class FusionConvAddOp : public framework::OperatorWithKernel<
#ifdef PADDLE_MOBILE_CPU
//#ifndef CONV_ADD_REGISTER
//static framework::FusionOpRegistrar convadd_registrar(
// new FusionConvAddMatcher());
//#define CONV_ADD_REGISTER
//#endif
#ifndef CONV_ADD_REGISTER
static framework::FusionOpRegistrar convadd_registrar(
new FusionConvAddMatcher());
#define CONV_ADD_REGISTER
#endif
#endif
......
......@@ -80,7 +80,7 @@ class FusionConvAddBNReluOp
#ifdef PADDLE_MOBILE_CPU
#ifndef FUSION_CONV_ADD_BN_RELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
new FusionConvAddBNReluMatcher());
#define FUSION_CONV_ADD_BN_RELU_REGISTER
#endif
......
......@@ -125,10 +125,10 @@ void ConvAddCompute(const FusionConvAddParam &param) {
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConv3x3(param.Input(), param.Strides(),
// param.Paddings(),
// param.Filter(), param.Bias(), param.Output(),
// false);
// math::DepthwiseConv3x3(param.Input(), param.Strides(),
// param.Paddings(),
// param.Filter(), param.Bias(),
// param.Output(), false);
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), param.Output(),
*param.Bias(), true);
......
......@@ -26,8 +26,6 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
Tensor bias = *param.Bias();
Tensor new_bias = *param.NewBias();
Tensor new_scale = *param.NewScale();
auto new_bias_ptr = new_bias.data<float>();
auto new_scale_ptr = new_scale.data<float>();
int axis = param.Axis();
Tensor *output = param.Output();
math::expand_bias(bias, axis, output->dims());
......@@ -106,25 +104,12 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
// math::matmul<float>(filter_slice, false, col_matrix, false,
// static_cast<float>(1), &out_slice,
// static_cast<float>(0));
math::matmulWithBn<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0),true, &new_scale,&new_bias);
math::matmulWithBn<float>(
filter_slice, false, col_matrix, false, static_cast<float>(1),
&out_slice, static_cast<float>(0), true, &new_scale, &new_bias);
}
}
/// todo : use neon in special case instead of 2for(300ms)
// auto output_ptr = output->data<float>();
// for (int c = 0; c < output_matrix_shape[0]; c++) {
// int start = c * output_matrix_shape[1];
// for (int j = 0; j < output_matrix_shape[1]; j++) {
// output_ptr[start + j] =
// output_ptr[start + j] * new_scale_ptr[c] + new_bias_ptr[c];
// output_ptr[start + j] =
// output_ptr[start + j] < 0 ? 0 : output_ptr[start + j];
// }
// }
}
template <typename P>
void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
......
......@@ -30,7 +30,7 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a,
template <typename T>
void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha,
framework::Tensor *matrix_out, T beta, bool relu = true,
framework::Tensor *matrix_out, T beta, bool relu,
framework::Tensor *new_scale, framework::Tensor *new_bias);
} // namespace math
} // namespace operators
......
......@@ -52,8 +52,9 @@ int main() {
}
auto time1 = time();
// paddle_mobile::operators::math::Sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3, c,
// ldc);
// paddle_mobile::operators::math::Sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3,
// c,
// ldc);
auto time2 = time();
DLOG << "gemm cost :" << time_diff(time1, time2) << "ms\n";
for (int i = 0; i < m * n; ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册