提交 d64b527b 编写于 作者: E eclipsess

remove Init const

上级 35f359bd
......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define FUSION_CONVADDBNRELU_OP
#ifdef FUSION_CONVADDBNRELU_OP
#pragma once
......@@ -79,11 +78,13 @@ class FusionConvAddBNReluOp
};
#ifdef PADDLE_MOBILE_CPU
//#ifndef FUSION_CONV_ADD_BN_RELU_REGISTER
// static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
// new FusionConvAddBNReluMatcher());
//#define FUSION_CONV_ADD_BN_RELU_REGISTER
//#endif
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
......
......@@ -21,7 +21,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool BatchNormKernel<CPU, float>::Init(BatchNormParam *param) const {
bool BatchNormKernel<CPU, float>::Init(BatchNormParam *param) {
return true;
}
......
......@@ -111,7 +111,7 @@ void DecodeCenterSize(const framework::Tensor& target_box,
}
template <>
bool BoxCoderKernel<CPU, float>::Init(BoxCoderParam* param) const {
bool BoxCoderKernel<CPU, float>::Init(BoxCoderParam* param) {
return true;
}
......
......@@ -53,7 +53,7 @@ class ConcatFunctor {
};
template <>
bool ConcatKernel<CPU, float>::Init(ConcatParam *param) const {
bool ConcatKernel<CPU, float>::Init(ConcatParam *param) {
return true;
}
......
......@@ -21,8 +21,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ConvAddBNReluKernel<CPU, float>::Init(
FusionConvAddBNReluParam *param) const {
bool ConvAddBNReluKernel<CPU, float>::Init(FusionConvAddBNReluParam *param) {
const Tensor *mean = (*param).InputMean();
const Tensor *variance = (*param).InputVariance();
const Tensor *scale = (*param).InputScale();
......
......@@ -19,7 +19,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ConvAddKernel<CPU, float>::Init(FusionConvAddParam *param) const {
bool ConvAddKernel<CPU, float>::Init(FusionConvAddParam *param) {
return true;
}
......
......@@ -21,7 +21,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ConvAddReluKernel<CPU, float>::Init(FusionConvAddReluParam *param) const {
bool ConvAddReluKernel<CPU, float>::Init(FusionConvAddReluParam *param) {
return true;
}
......
......@@ -21,7 +21,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ConvKernel<CPU, float>::Init(ConvParam *param) const {
bool ConvKernel<CPU, float>::Init(ConvParam *param) {
return true;
}
......
......@@ -21,7 +21,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool DepthwiseConvKernel<CPU, float>::Init(ConvParam *param) const {
bool DepthwiseConvKernel<CPU, float>::Init(ConvParam *param) {
return true;
}
......
......@@ -27,7 +27,7 @@ struct AddFunctor {
};
template <>
bool ElementwiseAddKernel<CPU, float>::Init(ElementwiseAddParam *param) const {
bool ElementwiseAddKernel<CPU, float>::Init(ElementwiseAddParam *param) {
return true;
}
......
......@@ -22,7 +22,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool FusionFcKernel<CPU, float>::Init(FusionFcParam *param) const {
bool FusionFcKernel<CPU, float>::Init(FusionFcParam *param) {
return true;
}
......
......@@ -22,7 +22,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool LrnKernel<CPU, float>::Init(LrnParam *param) const {
bool LrnKernel<CPU, float>::Init(LrnParam *param) {
return true;
}
......
......@@ -22,7 +22,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool MulKernel<CPU, float>::Init(MulParam *param) const {
bool MulKernel<CPU, float>::Init(MulParam *param) {
return true;
}
......
......@@ -204,7 +204,7 @@ void MultiClassOutput(const Tensor& scores, const Tensor& bboxes,
}
template <>
bool MultiClassNMSKernel<CPU, float>::Init(MultiClassNMSParam* param) const {
bool MultiClassNMSKernel<CPU, float>::Init(MultiClassNMSParam* param) {
return true;
}
......
......@@ -36,7 +36,7 @@ inline void PoolBasic(std::string pooling_type, std::vector<int> ksize,
}
template <>
bool PoolKernel<CPU, float>::Init(PoolParam *param) const {
bool PoolKernel<CPU, float>::Init(PoolParam *param) {
return true;
}
......
......@@ -27,7 +27,7 @@ struct ClipFunctor {
};
template <>
bool PriorBoxKernel<CPU, float>::Init(PriorBoxParam *param) const {
bool PriorBoxKernel<CPU, float>::Init(PriorBoxParam *param) {
return true;
}
......
......@@ -26,7 +26,7 @@ struct ReluFunctor {
};
template <>
bool ReluKernel<CPU, float>::Init(ReluParam *param) const {
bool ReluKernel<CPU, float>::Init(ReluParam *param) {
return true;
}
......
......@@ -20,7 +20,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ReshapeKernel<CPU, float>::Init(ReshapeParam *param) const {
bool ReshapeKernel<CPU, float>::Init(ReshapeParam *param) {
return true;
}
......
......@@ -72,7 +72,7 @@ void sigmoid(const Tensor *X, Tensor *Y) {
}
template <>
bool SigmoidKernel<CPU, float>::Init(SigmoidParam *param) const {
bool SigmoidKernel<CPU, float>::Init(SigmoidParam *param) {
return true;
}
......
......@@ -20,7 +20,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool SoftmaxKernel<CPU, float>::Init(SoftmaxParam *param) const {
bool SoftmaxKernel<CPU, float>::Init(SoftmaxParam *param) {
return true;
}
......
......@@ -35,7 +35,7 @@ namespace operators {
// }
template <>
bool TransposeKernel<CPU, float>::Init(TransposeParam* param) const {
bool TransposeKernel<CPU, float>::Init(TransposeParam* param) {
return true;
}
......
......@@ -29,7 +29,7 @@ class BatchNormKernel
: public framework::OpKernelBase<DeviceType, BatchNormParam> {
public:
void Compute(const BatchNormParam &param) const;
bool Init(BatchNormParam *param) const;
bool Init(BatchNormParam *param);
};
} // namespace operators
......
......@@ -30,7 +30,7 @@ class BoxCoderKernel
: public framework::OpKernelBase<DeviceType, BoxCoderParam> {
public:
void Compute(const BoxCoderParam& param) const;
bool Init(BoxCoderParam* param) const;
bool Init(BoxCoderParam* param);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#pragma once
#include "operators/kernel/conv_add_bn_relu_kernel.h"
#include "operators/math/depthwiseconv3x3s1p1.h"
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
......@@ -24,23 +24,12 @@ namespace operators {
template <typename P>
void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
const Tensor *input = param.Input();
DLOG << "input: " << *input;
Tensor filter = *param.Filter();
DLOG << "filter: " << filter;
Tensor bias = *param.Bias();
DLOG << "bias: " << bias;
Tensor new_bias = *param.NewBias();
Tensor new_scale = *param.NewScale();
auto new_bias_ptr = new_bias.data<float>();
auto new_scale_ptr = new_scale.data<float>();
//
// for(int i = 0; i < new_scale.numel(); i++){
// std::cout << "new_scale " << new_scale_ptr[i] <<std::endl;
// }
// for(int i = 0; i < new_bias.numel(); i++){
// std::cout << "new_bias " << new_bias_ptr[i] <<std::endl;
// }
int axis = param.Axis();
int groups = param.Groups();
std::vector<int> strides = param.Strides();
......@@ -50,8 +39,8 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
if (filter_shape_vec[2] == 3 && strides[0] == 1 && groups > 1) {
math::DepthwiseConv3x3s1p1(input, filter, output, &bias, 1, &new_scale,
&new_bias, 1, 1);
math::DepthwiseConvAddBNRelu3x3s1p1(input, filter, output, &bias, 1,
&new_scale, &new_bias, 1, 1);
} else {
const int batch_size = static_cast<int>(input->dims()[0]);
......@@ -131,11 +120,12 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
auto output_ptr = output->data<float>();
for (int c = 0; c < output_matrix_shape[0]; c++) {
// int start = c * output_matrix_shape[1];
int start = c * output_matrix_shape[1];
for (int j = 0; j < output_matrix_shape[1]; j++) {
// output_ptr[start + j] = output_ptr[start
// +j]*new_scale_ptr[c]+new_bias_ptr[c]; output_ptr[start + j] =
// output_ptr[start+j]< 0 ? 0 : output_ptr[start +j];
output_ptr[start + j] =
output_ptr[start + j] * new_scale_ptr[c] + new_bias_ptr[c];
output_ptr[start + j] =
output_ptr[start + j] < 0 ? 0 : output_ptr[start + j];
}
}
}
......
......@@ -27,7 +27,7 @@ template <typename DeviceType, typename T>
class ConcatKernel : public framework::OpKernelBase<DeviceType, ConcatParam> {
public:
void Compute(const ConcatParam &param) const;
bool Init(ConcatParam *param) const;
bool Init(ConcatParam *param);
};
} // namespace operators
......
......@@ -36,7 +36,7 @@ class ConvAddBNReluKernel
: public OpKernelBase<DeviceType, FusionConvAddBNReluParam> {
public:
void Compute(const FusionConvAddBNReluParam &param) const;
bool Init(FusionConvAddBNReluParam *param) const;
bool Init(FusionConvAddBNReluParam *param);
};
} // namespace operators
......
......@@ -40,7 +40,7 @@ template <typename DeviceType, typename T>
class ConvAddKernel : public OpKernelBase<DeviceType, FusionConvAddParam> {
public:
void Compute(const FusionConvAddParam &param) const;
bool Init(FusionConvAddParam *param) const;
bool Init(FusionConvAddParam *param);
};
} // namespace operators
......
......@@ -36,7 +36,7 @@ class ConvAddReluKernel
: public OpKernelBase<DeviceType, FusionConvAddReluParam> {
public:
void Compute(const FusionConvAddReluParam &param) const;
bool Init(FusionConvAddReluParam *param) const;
bool Init(FusionConvAddReluParam *param);
};
} // namespace operators
......
......@@ -32,7 +32,7 @@ template <typename DeviceType, typename T>
class ConvKernel : public OpKernelBase<DeviceType, ConvParam> {
public:
void Compute(const ConvParam &param) const;
bool Init(ConvParam *param) const;
bool Init(ConvParam *param);
};
} // namespace operators
......
......@@ -31,7 +31,7 @@ template <typename DeviceType, typename T>
class DepthwiseConvKernel : public OpKernelBase<DeviceType, ConvParam> {
public:
void Compute(const ConvParam &param) const;
bool Init(ConvParam *param) const;
bool Init(ConvParam *param);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -30,7 +30,7 @@ class ElementwiseAddKernel
: public framework::OpKernelBase<DeviceType, ElementwiseAddParam> {
public:
void Compute(const ElementwiseAddParam &param) const;
bool Init(ElementwiseAddParam *param) const;
bool Init(ElementwiseAddParam *param);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -20,7 +20,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ConvKernel<FPGA, float>::Init(ConvParam *param) const {
bool ConvKernel<FPGA, float>::Init(ConvParam *param) {
return true;
}
......
......@@ -28,7 +28,7 @@ class FusionFcKernel
: public framework::OpKernelBase<DeviceType, FusionFcParam> {
public:
void Compute(const FusionFcParam& param) const;
bool Init(FusionFcParam* param) const;
bool Init(FusionFcParam* param);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -170,7 +170,7 @@ template <typename DeviceType, typename T>
class LrnKernel : public framework::OpKernelBase<DeviceType, LrnParam> {
public:
void Compute(const LrnParam &param) const;
bool Init(LrnParam *param) const;
bool Init(LrnParam *param);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -128,7 +128,7 @@ class AclBatchNormOp : public acl::ACLOperator {
};
template <>
bool BatchNormKernel<GPU_MALI, float>::Init(BatchNormParam* param) const {
bool BatchNormKernel<GPU_MALI, float>::Init(BatchNormParam* param) {
AclBatchNormOp<GPU_MALI, float>* acl_op =
reinterpret_cast<AclBatchNormOp<GPU_MALI, float>*>(this->GetAclOp());
if (acl_op == nullptr) {
......
......@@ -195,7 +195,7 @@ class AclConvOp : public acl::ACLOperator {
};
template <>
bool ConvKernel<GPU_MALI, float>::Init(ConvParam* param) const {
bool ConvKernel<GPU_MALI, float>::Init(ConvParam* param) {
AclConvOp<GPU_MALI, float>* acl_op =
reinterpret_cast<AclConvOp<GPU_MALI, float>*>(this->GetAclOp());
if (acl_op == nullptr) {
......
......@@ -27,8 +27,7 @@ struct AddFunctor {
};
template <>
bool ElementwiseAddKernel<GPU_MALI, float>::Init(
ElementwiseAddParam *param) const {
bool ElementwiseAddKernel<GPU_MALI, float>::Init(ElementwiseAddParam *param) {
return true;
}
......
......@@ -22,7 +22,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool FusionFcKernel<GPU_MALI, float>::Init(FusionFcParam *param) const {
bool FusionFcKernel<GPU_MALI, float>::Init(FusionFcParam *param) {
return true;
}
......
......@@ -22,7 +22,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool MulKernel<GPU_MALI, float>::Init(MulParam *param) const {
bool MulKernel<GPU_MALI, float>::Init(MulParam *param) {
return true;
}
......
......@@ -22,7 +22,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ReshapeKernel<GPU_MALI, float>::Init(ReshapeParam *param) const {
bool ReshapeKernel<GPU_MALI, float>::Init(ReshapeParam *param) {
return true;
}
......
......@@ -29,7 +29,7 @@ template <typename DeviceType, typename T>
class MulKernel : public framework::OpKernelBase<DeviceType, MulParam> {
public:
void Compute(const MulParam &param) const;
bool Init(MulParam *param) const;
bool Init(MulParam *param);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -28,7 +28,7 @@ class MultiClassNMSKernel
: public framework::OpKernelBase<DeviceType, MultiClassNMSParam> {
public:
void Compute(const MultiClassNMSParam& param) const;
bool Init(MultiClassNMSParam* param) const;
bool Init(MultiClassNMSParam* param);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -28,7 +28,7 @@ template <typename DeviceType, typename T>
class PoolKernel : public OpKernelBase<DeviceType, PoolParam> {
public:
void Compute(const PoolParam &param) const override;
bool Init(PoolParam *param) const;
bool Init(PoolParam *param);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -55,7 +55,7 @@ class PriorBoxKernel
: public framework::OpKernelBase<DeviceType, PriorBoxParam> {
public:
void Compute(const PriorBoxParam& param) const;
bool Init(PriorBoxParam* param) const;
bool Init(PriorBoxParam* param);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -27,7 +27,7 @@ template <typename DeviceType, typename T>
class ReluKernel : public framework::OpKernelBase<DeviceType, ReluParam> {
public:
void Compute(const ReluParam& param) const;
bool Init(ReluParam* param) const;
bool Init(ReluParam* param);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -71,7 +71,7 @@ template <typename DeviceType, typename T>
class ReshapeKernel : public framework::OpKernelBase<DeviceType, ReshapeParam> {
public:
void Compute(const ReshapeParam& param) const;
bool Init(ReshapeParam* param) const;
bool Init(ReshapeParam* param);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -26,7 +26,7 @@ template <typename DeviceType, typename T>
class SigmoidKernel : public OpKernelBase<DeviceType, SigmoidParam> {
public:
void Compute(const SigmoidParam& param) const override;
bool Init(SigmoidParam* param) const;
bool Init(SigmoidParam* param);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -29,7 +29,7 @@ template <typename DeviceType, typename T>
class SoftmaxKernel : public OpKernelBase<DeviceType, SoftmaxParam> {
public:
void Compute(const SoftmaxParam &param) const override;
bool Init(SoftmaxParam *param) const;
bool Init(SoftmaxParam *param);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -29,7 +29,7 @@ class TransposeKernel
: public framework::OpKernelBase<DeviceType, TransposeParam> {
public:
void Compute(const TransposeParam& param) const;
bool Init(TransposeParam* param) const;
bool Init(TransposeParam* param);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -502,6 +502,322 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
}
}
}
void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, Tensor filter,
Tensor *output, Tensor *bias, bool if_bias,
Tensor *new_scale, Tensor *new_bias,
bool if_bn, bool if_relu) {
const float *input_data = input->data<float>();
const float *filter_data = filter.data<float>();
float *output_data = output->data<float>();
const float *bias_data = bias->data<float>();
const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>();
const int h = static_cast<int>(input->dims()[2]);
const int w = static_cast<int>(input->dims()[3]);
const int l = h;
const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]);
const int hxw = h * w;
float32x4_t vbias = vdupq_n_f32(0.0);
float32x4_t vnewbias = vdupq_n_f32(0.0);
float32x4_t vnewscale = vdupq_n_f32(1.0);
float32x4_t vzero = vdupq_n_f32(0);
for (int b = 0; b < batch_size; ++b) {
const float *filter_data_tmp = filter_data;
for (int j = 0; j < c; ++j) {
if (if_bias) {
vbias = vdupq_n_f32(bias_data[j]);
}
if (if_bn) {
vnewbias = vdupq_n_f32(newbias_data[j]);
vnewscale = vdupq_n_f32(newscale_data[j]);
}
int l_mid = l - 2; // l=1->l_mid=-1,l=2->l_mid=0
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
output_data[0] =
(w11 * input_data[0] + w12 * input_data[1] + w21 * input_data[l] +
w22 * input_data[l + 1] + bias_data[j]) *
newscale_data[j] +
newbias_data[j];
output_data[l - 1] = (w10 * input_data[l - 2] + w11 * input_data[l - 1] +
w20 * input_data[2 * l - 2] +
w21 * input_data[2 * l - 1] + bias_data[j]) *
newscale_data[j] +
newbias_data[j];
output_data[(l - 1) * l] =
(w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] +
w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1] +
bias_data[j]) *
newscale_data[j] +
newbias_data[j];
output_data[l * l - 1] = (w00 * input_data[(l - 2) * (l + 1)] +
w01 * input_data[(l - 2) * (l + 1) + 1] +
w10 * input_data[l * l - 2] +
w11 * input_data[l * l - 1] + bias_data[j]) *
newscale_data[j] +
newbias_data[j];
if (if_relu) {
output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
output_data[l - 1] = output_data[l - 1] < 0 ? 0 : output_data[l - 1];
output_data[(l - 1) * l] =
output_data[(l - 1) * l] < 0 ? 0 : output_data[(l - 1) * l];
output_data[l * l - 1] =
output_data[l * l - 1] < 0 ? 0 : output_data[l * l - 1];
}
for (int i = 1; i < l - 1; ++i) {
output_data[i * l] =
(w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] +
w11 * input_data[i * l] + w12 * input_data[i * l + 1] +
w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1] +
bias_data[j]) *
newscale_data[j] +
newbias_data[j];
output_data[i * l + l - 1] =
(w00 * input_data[i * l + l - 1 - l - 1] +
w01 * input_data[i * l + l - 1 - l] +
w10 * input_data[i * l + l - 1 - 1] +
w11 * input_data[i * l + l - 1] +
w20 * input_data[i * l + l - 1 + l - 1] +
w21 * input_data[i * l + l - 1 + l] + bias_data[j]) *
newscale_data[j] +
newbias_data[j];
if (if_relu) {
output_data[i * l] = output_data[i * l] < 0 ? 0 : output_data[i * l];
output_data[i * l + l - 1] =
output_data[i * l + l - 1] < 0 ? 0 : output_data[i * l + l - 1];
}
}
// top 1 row and bottom 1 row
const float *input_tmp = input_data;
float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
tmp3, tmp4, tmp5, out0;
in0 = vld1q_f32(input_tmp);
in2 = vld1q_f32(input_tmp + l);
const float *input_tmp_end = input_tmp + (l - 2) * l;
in4 = vld1q_f32(input_tmp_end);
in6 = vld1q_f32(input_tmp_end + l);
int c_mid = l_mid;
auto output_ptr = output_data + 1;
for (; c_mid > 3; c_mid -= 4) {
in1 = vld1q_f32(input_tmp + 4);
in3 = vld1q_f32(input_tmp + l + 4);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vaddq_f32(out0, vbias);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
in5 = vld1q_f32(input_tmp_end + 4);
in7 = vld1q_f32(input_tmp_end + l + 4);
tmp0 = vextq_f32(in4, in5, 1);
tmp1 = vextq_f32(in4, in5, 2);
tmp2 = vextq_f32(in6, in7, 1);
tmp3 = vextq_f32(in6, in7, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vaddq_f32(out0, vbias);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr + (l - 1) * l, out0);
// can optimize to each 8 stride.
input_tmp += 4;
input_tmp_end += 4;
output_ptr += 4;
in0 = in1;
in2 = in3;
in4 = in5;
in6 = in7;
}
// top right pad
float32x4_t pad0 = vdupq_n_f32(input_data[l - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[2 * l - 1]);
tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2);
tmp2 = vextq_f32(in2, pad1, 1);
tmp3 = vextq_f32(in2, pad1, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vaddq_f32(out0, vbias);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
// bottom right pad
float32x4_t pad2 = vdupq_n_f32(input_data[l * l - 1 - l]);
float32x4_t pad3 = vdupq_n_f32(input_data[l * l - 1]);
tmp0 = vextq_f32(in4, pad2, 1);
tmp1 = vextq_f32(in4, pad2, 2);
tmp2 = vextq_f32(in6, pad3, 1);
tmp3 = vextq_f32(in6, pad3, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vaddq_f32(out0, vbias);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 2);
}
}
// mid
for (int i = 0; i < l - 2; ++i) {
auto output_ptr = output_data + (i + 1) * l + 1;
input_tmp = input_data + i * l;
auto in0_tmp = vld1q_f32(input_tmp);
auto in2_tmp = vld1q_f32(input_tmp + l);
auto in4_tmp = vld1q_f32(input_tmp + l + l);
c_mid = l_mid;
for (; c_mid > 3; c_mid -= 4) {
auto in1_tmp = vld1q_f32(input_tmp + 4);
auto in3_tmp = vld1q_f32(input_tmp + l + 4);
auto in5_tmp = vld1q_f32(input_tmp + l + l + 4);
tmp0 = vextq_f32(in0_tmp, in1_tmp, 1);
tmp1 = vextq_f32(in0_tmp, in1_tmp, 2);
tmp2 = vextq_f32(in2_tmp, in3_tmp, 1);
tmp3 = vextq_f32(in2_tmp, in3_tmp, 2);
tmp4 = vextq_f32(in4_tmp, in5_tmp, 1);
tmp5 = vextq_f32(in4_tmp, in5_tmp, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vaddq_f32(out0, vbias);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
output_ptr += 4;
input_tmp += 4;
in0_tmp = in1_tmp;
in2_tmp = in3_tmp;
in4_tmp = in5_tmp;
}
float32x4_t pad0 = vdupq_n_f32(input_data[i * l + l - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[i * l + l - 1 + l]);
float32x4_t pad2 = vdupq_n_f32(input_data[i * l + l - 1 + l + l]);
tmp0 = vextq_f32(in0_tmp, pad0, 1);
tmp1 = vextq_f32(in0_tmp, pad0, 2);
tmp2 = vextq_f32(in2_tmp, pad1, 1);
tmp3 = vextq_f32(in2_tmp, pad1, 2);
tmp4 = vextq_f32(in4_tmp, pad2, 1);
tmp5 = vextq_f32(in4_tmp, pad2, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vaddq_f32(out0, vbias);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
}
output_data += hxw;
input_data += hxw;
filter_data_tmp += 9;
}
}
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -32,6 +32,10 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
Tensor *output, bool if_bias);
void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor *bias, bool if_bias);
void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, Tensor filter,
Tensor *output, Tensor *bias, bool if_bias,
Tensor *new_scale, Tensor *new_bias,
bool if_bn, bool if_relu);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册