提交 898f7376 编写于 作者: E eclipsess

Add fusion bnrelu (fused batch norm + ReLU)

上级 b68d873f
...@@ -68,11 +68,11 @@ class FusionConvAddOp : public framework::OperatorWithKernel< ...@@ -68,11 +68,11 @@ class FusionConvAddOp : public framework::OperatorWithKernel<
#ifdef PADDLE_MOBILE_CPU #ifdef PADDLE_MOBILE_CPU
//#ifndef CONV_ADD_REGISTER #ifndef CONV_ADD_REGISTER
//static framework::FusionOpRegistrar convadd_registrar( static framework::FusionOpRegistrar convadd_registrar(
// new FusionConvAddMatcher()); new FusionConvAddMatcher());
//#define CONV_ADD_REGISTER #define CONV_ADD_REGISTER
//#endif #endif
#endif #endif
......
...@@ -80,7 +80,7 @@ class FusionConvAddBNReluOp ...@@ -80,7 +80,7 @@ class FusionConvAddBNReluOp
#ifdef PADDLE_MOBILE_CPU #ifdef PADDLE_MOBILE_CPU
#ifndef FUSION_CONV_ADD_BN_RELU_REGISTER #ifndef FUSION_CONV_ADD_BN_RELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar( static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
new FusionConvAddBNReluMatcher()); new FusionConvAddBNReluMatcher());
#define FUSION_CONV_ADD_BN_RELU_REGISTER #define FUSION_CONV_ADD_BN_RELU_REGISTER
#endif #endif
......
...@@ -125,10 +125,10 @@ void ConvAddCompute(const FusionConvAddParam &param) { ...@@ -125,10 +125,10 @@ void ConvAddCompute(const FusionConvAddParam &param) {
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConv3x3(param.Input(), param.Strides(), // math::DepthwiseConv3x3(param.Input(), param.Strides(),
// param.Paddings(), // param.Paddings(),
// param.Filter(), param.Bias(), param.Output(), // param.Filter(), param.Bias(),
// false); // param.Output(), false);
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), param.Output(), math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), param.Output(),
*param.Bias(), true); *param.Bias(), true);
......
...@@ -26,8 +26,6 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) { ...@@ -26,8 +26,6 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
Tensor bias = *param.Bias(); Tensor bias = *param.Bias();
Tensor new_bias = *param.NewBias(); Tensor new_bias = *param.NewBias();
Tensor new_scale = *param.NewScale(); Tensor new_scale = *param.NewScale();
auto new_bias_ptr = new_bias.data<float>();
auto new_scale_ptr = new_scale.data<float>();
int axis = param.Axis(); int axis = param.Axis();
Tensor *output = param.Output(); Tensor *output = param.Output();
math::expand_bias(bias, axis, output->dims()); math::expand_bias(bias, axis, output->dims());
...@@ -106,25 +104,12 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) { ...@@ -106,25 +104,12 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
// gemm // gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
// math::matmul<float>(filter_slice, false, col_matrix, false,
// static_cast<float>(1), &out_slice, math::matmulWithBn<float>(
// static_cast<float>(0)); filter_slice, false, col_matrix, false, static_cast<float>(1),
math::matmulWithBn<float>(filter_slice, false, col_matrix, false, &out_slice, static_cast<float>(0), true, &new_scale, &new_bias);
static_cast<float>(1), &out_slice,
static_cast<float>(0),true, &new_scale,&new_bias);
} }
} }
/// todo : use neon in special case instead of 2for(300ms)
// auto output_ptr = output->data<float>();
// for (int c = 0; c < output_matrix_shape[0]; c++) {
// int start = c * output_matrix_shape[1];
// for (int j = 0; j < output_matrix_shape[1]; j++) {
// output_ptr[start + j] =
// output_ptr[start + j] * new_scale_ptr[c] + new_bias_ptr[c];
// output_ptr[start + j] =
// output_ptr[start + j] < 0 ? 0 : output_ptr[start + j];
// }
// }
} }
template <typename P> template <typename P>
void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) { void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
......
...@@ -30,7 +30,7 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a, ...@@ -30,7 +30,7 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a,
template <typename T> template <typename T>
void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha, const framework::Tensor &matrix_b, bool trans_b, T alpha,
framework::Tensor *matrix_out, T beta, bool relu = true, framework::Tensor *matrix_out, T beta, bool relu,
framework::Tensor *new_scale, framework::Tensor *new_bias); framework::Tensor *new_scale, framework::Tensor *new_bias);
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -52,8 +52,9 @@ int main() { ...@@ -52,8 +52,9 @@ int main() {
} }
auto time1 = time(); auto time1 = time();
// paddle_mobile::operators::math::Sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3, c, // paddle_mobile::operators::math::Sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3,
// ldc); // c,
// ldc);
auto time2 = time(); auto time2 = time();
DLOG << "gemm cost :" << time_diff(time1, time2) << "ms\n"; DLOG << "gemm cost :" << time_diff(time1, time2) << "ms\n";
for (int i = 0; i < m * n; ++i) { for (int i = 0; i < m * n; ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册