提交 ce4d99cd 编写于 作者: E eclipsess

gemm bn relu v1

上级 00417d29
......@@ -68,11 +68,11 @@ class FusionConvAddOp : public framework::OperatorWithKernel<
// Section compiled only when PADDLE_MOBILE_CPU is defined.
#ifdef PADDLE_MOBILE_CPU
// Define the static registrar for the conv+add fusion matcher. The
// CONV_ADD_REGISTER macro is defined right after the registrar so that this
// #ifndef section is skipped if it is processed again in the same
// translation unit.
#ifndef CONV_ADD_REGISTER
static framework::FusionOpRegistrar convadd_registrar(
new FusionConvAddMatcher());
#define CONV_ADD_REGISTER
#endif
// NOTE(review): the commented-out lines below repeat the registration above.
// This text comes from a diff view, so they presumably show the other side of
// the change (registrar disabled) -- confirm which state is intended.
//#ifndef CONV_ADD_REGISTER
//static framework::FusionOpRegistrar convadd_registrar(
// new FusionConvAddMatcher());
//#define CONV_ADD_REGISTER
//#endif
#endif
......
......@@ -79,11 +79,11 @@ class FusionConvAddBNReluOp
// Section compiled only when PADDLE_MOBILE_CPU is defined.
#ifdef PADDLE_MOBILE_CPU
// NOTE(review): the commented-out lines below repeat the live registration
// that follows them. This text comes from a diff view, so they presumably
// show the other side of the change (registrar disabled) -- confirm which
// state is intended.
//#ifndef FUSION_CONV_ADD_BN_RELU_REGISTER
// static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
// new FusionConvAddBNReluMatcher());
//#define FUSION_CONV_ADD_BN_RELU_REGISTER
//#endif
// Define the static registrar for the conv+add+bn+relu fusion matcher. The
// FUSION_CONV_ADD_BN_RELU_REGISTER macro is defined right after the registrar
// so that this #ifndef section is skipped if it is processed again in the
// same translation unit.
#ifndef FUSION_CONV_ADD_BN_RELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
new FusionConvAddBNReluMatcher());
#define FUSION_CONV_ADD_BN_RELU_REGISTER
#endif
#endif
......
......@@ -125,10 +125,10 @@ void ConvAddCompute(const FusionConvAddParam &param) {
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConv3x3(param.Input(), param.Strides(),
// param.Paddings(),
// param.Filter(), param.Bias(), param.Output(),
// false);
// math::DepthwiseConv3x3(param.Input(), param.Strides(),
// param.Paddings(),
// param.Filter(), param.Bias(), param.Output(),
// false);
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), param.Output(),
*param.Bias(), true);
......
......@@ -106,22 +106,25 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<float>(filter_slice, false, col_matrix, false,
// math::matmul<float>(filter_slice, false, col_matrix, false,
// static_cast<float>(1), &out_slice,
// static_cast<float>(0));
math::matmulWithBn<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0));
static_cast<float>(0),true, &new_scale,&new_bias);
}
}
/// TODO: use NEON in the special case instead of the two nested for loops (300ms)
auto output_ptr = output->data<float>();
for (int c = 0; c < output_matrix_shape[0]; c++) {
int start = c * output_matrix_shape[1];
for (int j = 0; j < output_matrix_shape[1]; j++) {
output_ptr[start + j] =
output_ptr[start + j] * new_scale_ptr[c] + new_bias_ptr[c];
output_ptr[start + j] =
output_ptr[start + j] < 0 ? 0 : output_ptr[start + j];
}
}
// auto output_ptr = output->data<float>();
// for (int c = 0; c < output_matrix_shape[0]; c++) {
// int start = c * output_matrix_shape[1];
// for (int j = 0; j < output_matrix_shape[1]; j++) {
// output_ptr[start + j] =
// output_ptr[start + j] * new_scale_ptr[c] + new_bias_ptr[c];
// output_ptr[start + j] =
// output_ptr[start + j] < 0 ? 0 : output_ptr[start + j];
// }
// }
}
template <typename P>
void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
......
......@@ -1036,8 +1036,7 @@ void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
float32x4_t vbias = vdupq_n_f32(0.0);
float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1],
input_buff_top[w_times + 1];
float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1];
float32x4_t elewise_res0, elewise_res1, elewise_res2, res3;
int out2in_mid;
float32x4_t zero = vdupq_n_f32(0.0);
......
......@@ -30,7 +30,7 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a,
/**
 * Matrix multiply with a fused batch-norm (and optional relu) step.
 *
 * NOTE(review): only the declaration is visible here. The semantics below are
 * inferred from the parameter names and from the matmul declaration style,
 * and must be confirmed against the implementation.
 *
 * @param matrix_a   Left-hand input matrix.
 * @param trans_a    Whether matrix_a is used transposed.
 * @param matrix_b   Right-hand input matrix.
 * @param trans_b    Whether matrix_b is used transposed.
 * @param alpha      Scalar applied to the product (presumably GEMM alpha).
 * @param matrix_out Output matrix, written in place.
 * @param beta       Scalar applied to the existing output (presumably GEMM
 *                   beta).
 * @param relu       Whether a relu is applied to the result.
 * @param new_scale  Batch-norm scale tensor (presumably one value per output
 *                   row/channel).
 * @param new_bias   Batch-norm bias tensor (presumably one value per output
 *                   row/channel).
 *
 * `relu` deliberately has no default argument: C++ requires every parameter
 * after a defaulted one to be defaulted as well, and `new_scale` / `new_bias`
 * are mandatory, so `bool relu = true` in this position is ill-formed.
 */
template <typename T>
void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
                  const framework::Tensor &matrix_b, bool trans_b, T alpha,
                  framework::Tensor *matrix_out, T beta, bool relu,
                  framework::Tensor *new_scale, framework::Tensor *new_bias);
} // namespace math
} // namespace operators
......
......@@ -52,8 +52,8 @@ int main() {
}
auto time1 = time();
paddle_mobile::operators::math::sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3, c,
ldc);
// paddle_mobile::operators::math::Sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3, c,
// ldc);
auto time2 = time();
DLOG << "gemm cost :" << time_diff(time1, time2) << "ms\n";
for (int i = 0; i < m * n; ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册