提交 ce4d99cd 作者: eclipsess

gemm bn relu v1

上级 00417d29
...@@ -68,11 +68,11 @@ class FusionConvAddOp : public framework::OperatorWithKernel< ...@@ -68,11 +68,11 @@ class FusionConvAddOp : public framework::OperatorWithKernel<
#ifdef PADDLE_MOBILE_CPU #ifdef PADDLE_MOBILE_CPU
#ifndef CONV_ADD_REGISTER //#ifndef CONV_ADD_REGISTER
static framework::FusionOpRegistrar convadd_registrar( //static framework::FusionOpRegistrar convadd_registrar(
new FusionConvAddMatcher()); // new FusionConvAddMatcher());
#define CONV_ADD_REGISTER //#define CONV_ADD_REGISTER
#endif //#endif
#endif #endif
......
...@@ -79,11 +79,11 @@ class FusionConvAddBNReluOp ...@@ -79,11 +79,11 @@ class FusionConvAddBNReluOp
#ifdef PADDLE_MOBILE_CPU #ifdef PADDLE_MOBILE_CPU
//#ifndef FUSION_CONV_ADD_BN_RELU_REGISTER #ifndef FUSION_CONV_ADD_BN_RELU_REGISTER
// static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar( static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
// new FusionConvAddBNReluMatcher()); new FusionConvAddBNReluMatcher());
//#define FUSION_CONV_ADD_BN_RELU_REGISTER #define FUSION_CONV_ADD_BN_RELU_REGISTER
//#endif #endif
#endif #endif
......
...@@ -125,10 +125,10 @@ void ConvAddCompute(const FusionConvAddParam &param) { ...@@ -125,10 +125,10 @@ void ConvAddCompute(const FusionConvAddParam &param) {
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConv3x3(param.Input(), param.Strides(), // math::DepthwiseConv3x3(param.Input(), param.Strides(),
// param.Paddings(), // param.Paddings(),
// param.Filter(), param.Bias(), param.Output(), // param.Filter(), param.Bias(), param.Output(),
// false); // false);
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), param.Output(), math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), param.Output(),
*param.Bias(), true); *param.Bias(), true);
......
...@@ -106,22 +106,25 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) { ...@@ -106,22 +106,25 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
// gemm // gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<float>(filter_slice, false, col_matrix, false, // math::matmul<float>(filter_slice, false, col_matrix, false,
// static_cast<float>(1), &out_slice,
// static_cast<float>(0));
math::matmulWithBn<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice, static_cast<float>(1), &out_slice,
static_cast<float>(0)); static_cast<float>(0),true, &new_scale,&new_bias);
} }
} }
/// TODO: use NEON in the special case instead of the two nested for loops (300ms) /// TODO: use NEON in the special case instead of the two nested for loops (300ms)
auto output_ptr = output->data<float>(); // auto output_ptr = output->data<float>();
for (int c = 0; c < output_matrix_shape[0]; c++) { // for (int c = 0; c < output_matrix_shape[0]; c++) {
int start = c * output_matrix_shape[1]; // int start = c * output_matrix_shape[1];
for (int j = 0; j < output_matrix_shape[1]; j++) { // for (int j = 0; j < output_matrix_shape[1]; j++) {
output_ptr[start + j] = // output_ptr[start + j] =
output_ptr[start + j] * new_scale_ptr[c] + new_bias_ptr[c]; // output_ptr[start + j] * new_scale_ptr[c] + new_bias_ptr[c];
output_ptr[start + j] = // output_ptr[start + j] =
output_ptr[start + j] < 0 ? 0 : output_ptr[start + j]; // output_ptr[start + j] < 0 ? 0 : output_ptr[start + j];
} // }
} // }
} }
template <typename P> template <typename P>
void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) { void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
......
...@@ -1036,8 +1036,7 @@ void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter, ...@@ -1036,8 +1036,7 @@ void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
float32x4_t vbias = vdupq_n_f32(0.0); float32x4_t vbias = vdupq_n_f32(0.0);
float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1], float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1];
input_buff_top[w_times + 1];
float32x4_t elewise_res0, elewise_res1, elewise_res2, res3; float32x4_t elewise_res0, elewise_res1, elewise_res2, res3;
int out2in_mid; int out2in_mid;
float32x4_t zero = vdupq_n_f32(0.0); float32x4_t zero = vdupq_n_f32(0.0);
......
...@@ -30,7 +30,7 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a, ...@@ -30,7 +30,7 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a,
template <typename T> template <typename T>
void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha, const framework::Tensor &matrix_b, bool trans_b, T alpha,
framework::Tensor *matrix_out, T beta, bool relu, framework::Tensor *matrix_out, T beta, bool relu = true,
framework::Tensor *new_scale, framework::Tensor *new_bias); framework::Tensor *new_scale, framework::Tensor *new_bias);
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -52,8 +52,8 @@ int main() { ...@@ -52,8 +52,8 @@ int main() {
} }
auto time1 = time(); auto time1 = time();
paddle_mobile::operators::math::sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3, c, // paddle_mobile::operators::math::Sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3, c,
ldc); // ldc);
auto time2 = time(); auto time2 = time();
DLOG << "gemm cost :" << time_diff(time1, time2) << "ms\n"; DLOG << "gemm cost :" << time_diff(time1, time2) << "ms\n";
for (int i = 0; i < m * n; ++i) { for (int i = 0; i < m * n; ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册