提交 d64b527b 编写于 作者: E eclipsess

remove Init const

上级 35f359bd
...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define FUSION_CONVADDBNRELU_OP
#ifdef FUSION_CONVADDBNRELU_OP #ifdef FUSION_CONVADDBNRELU_OP
#pragma once #pragma once
...@@ -79,11 +78,13 @@ class FusionConvAddBNReluOp ...@@ -79,11 +78,13 @@ class FusionConvAddBNReluOp
}; };
#ifdef PADDLE_MOBILE_CPU #ifdef PADDLE_MOBILE_CPU
//#ifndef FUSION_CONV_ADD_BN_RELU_REGISTER //#ifndef FUSION_CONV_ADD_BN_RELU_REGISTER
// static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar( // static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
// new FusionConvAddBNReluMatcher()); // new FusionConvAddBNReluMatcher());
//#define FUSION_CONV_ADD_BN_RELU_REGISTER //#define FUSION_CONV_ADD_BN_RELU_REGISTER
//#endif //#endif
#endif #endif
#ifdef PADDLE_MOBILE_MALI_GPU #ifdef PADDLE_MOBILE_MALI_GPU
......
...@@ -21,7 +21,7 @@ namespace paddle_mobile { ...@@ -21,7 +21,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool BatchNormKernel<CPU, float>::Init(BatchNormParam *param) const { bool BatchNormKernel<CPU, float>::Init(BatchNormParam *param) {
return true; return true;
} }
......
...@@ -111,7 +111,7 @@ void DecodeCenterSize(const framework::Tensor& target_box, ...@@ -111,7 +111,7 @@ void DecodeCenterSize(const framework::Tensor& target_box,
} }
template <> template <>
bool BoxCoderKernel<CPU, float>::Init(BoxCoderParam* param) const { bool BoxCoderKernel<CPU, float>::Init(BoxCoderParam* param) {
return true; return true;
} }
......
...@@ -53,7 +53,7 @@ class ConcatFunctor { ...@@ -53,7 +53,7 @@ class ConcatFunctor {
}; };
template <> template <>
bool ConcatKernel<CPU, float>::Init(ConcatParam *param) const { bool ConcatKernel<CPU, float>::Init(ConcatParam *param) {
return true; return true;
} }
......
...@@ -21,8 +21,7 @@ namespace paddle_mobile { ...@@ -21,8 +21,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ConvAddBNReluKernel<CPU, float>::Init( bool ConvAddBNReluKernel<CPU, float>::Init(FusionConvAddBNReluParam *param) {
FusionConvAddBNReluParam *param) const {
const Tensor *mean = (*param).InputMean(); const Tensor *mean = (*param).InputMean();
const Tensor *variance = (*param).InputVariance(); const Tensor *variance = (*param).InputVariance();
const Tensor *scale = (*param).InputScale(); const Tensor *scale = (*param).InputScale();
......
...@@ -19,7 +19,7 @@ namespace paddle_mobile { ...@@ -19,7 +19,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ConvAddKernel<CPU, float>::Init(FusionConvAddParam *param) const { bool ConvAddKernel<CPU, float>::Init(FusionConvAddParam *param) {
return true; return true;
} }
......
...@@ -21,7 +21,7 @@ namespace paddle_mobile { ...@@ -21,7 +21,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ConvAddReluKernel<CPU, float>::Init(FusionConvAddReluParam *param) const { bool ConvAddReluKernel<CPU, float>::Init(FusionConvAddReluParam *param) {
return true; return true;
} }
......
...@@ -21,7 +21,7 @@ namespace paddle_mobile { ...@@ -21,7 +21,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ConvKernel<CPU, float>::Init(ConvParam *param) const { bool ConvKernel<CPU, float>::Init(ConvParam *param) {
return true; return true;
} }
......
...@@ -21,7 +21,7 @@ namespace paddle_mobile { ...@@ -21,7 +21,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool DepthwiseConvKernel<CPU, float>::Init(ConvParam *param) const { bool DepthwiseConvKernel<CPU, float>::Init(ConvParam *param) {
return true; return true;
} }
......
...@@ -27,7 +27,7 @@ struct AddFunctor { ...@@ -27,7 +27,7 @@ struct AddFunctor {
}; };
template <> template <>
bool ElementwiseAddKernel<CPU, float>::Init(ElementwiseAddParam *param) const { bool ElementwiseAddKernel<CPU, float>::Init(ElementwiseAddParam *param) {
return true; return true;
} }
......
...@@ -22,7 +22,7 @@ namespace paddle_mobile { ...@@ -22,7 +22,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool FusionFcKernel<CPU, float>::Init(FusionFcParam *param) const { bool FusionFcKernel<CPU, float>::Init(FusionFcParam *param) {
return true; return true;
} }
......
...@@ -22,7 +22,7 @@ namespace paddle_mobile { ...@@ -22,7 +22,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool LrnKernel<CPU, float>::Init(LrnParam *param) const { bool LrnKernel<CPU, float>::Init(LrnParam *param) {
return true; return true;
} }
......
...@@ -22,7 +22,7 @@ namespace paddle_mobile { ...@@ -22,7 +22,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool MulKernel<CPU, float>::Init(MulParam *param) const { bool MulKernel<CPU, float>::Init(MulParam *param) {
return true; return true;
} }
......
...@@ -204,7 +204,7 @@ void MultiClassOutput(const Tensor& scores, const Tensor& bboxes, ...@@ -204,7 +204,7 @@ void MultiClassOutput(const Tensor& scores, const Tensor& bboxes,
} }
template <> template <>
bool MultiClassNMSKernel<CPU, float>::Init(MultiClassNMSParam* param) const { bool MultiClassNMSKernel<CPU, float>::Init(MultiClassNMSParam* param) {
return true; return true;
} }
......
...@@ -36,7 +36,7 @@ inline void PoolBasic(std::string pooling_type, std::vector<int> ksize, ...@@ -36,7 +36,7 @@ inline void PoolBasic(std::string pooling_type, std::vector<int> ksize,
} }
template <> template <>
bool PoolKernel<CPU, float>::Init(PoolParam *param) const { bool PoolKernel<CPU, float>::Init(PoolParam *param) {
return true; return true;
} }
......
...@@ -27,7 +27,7 @@ struct ClipFunctor { ...@@ -27,7 +27,7 @@ struct ClipFunctor {
}; };
template <> template <>
bool PriorBoxKernel<CPU, float>::Init(PriorBoxParam *param) const { bool PriorBoxKernel<CPU, float>::Init(PriorBoxParam *param) {
return true; return true;
} }
......
...@@ -26,7 +26,7 @@ struct ReluFunctor { ...@@ -26,7 +26,7 @@ struct ReluFunctor {
}; };
template <> template <>
bool ReluKernel<CPU, float>::Init(ReluParam *param) const { bool ReluKernel<CPU, float>::Init(ReluParam *param) {
return true; return true;
} }
......
...@@ -20,7 +20,7 @@ namespace paddle_mobile { ...@@ -20,7 +20,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ReshapeKernel<CPU, float>::Init(ReshapeParam *param) const { bool ReshapeKernel<CPU, float>::Init(ReshapeParam *param) {
return true; return true;
} }
......
...@@ -72,7 +72,7 @@ void sigmoid(const Tensor *X, Tensor *Y) { ...@@ -72,7 +72,7 @@ void sigmoid(const Tensor *X, Tensor *Y) {
} }
template <> template <>
bool SigmoidKernel<CPU, float>::Init(SigmoidParam *param) const { bool SigmoidKernel<CPU, float>::Init(SigmoidParam *param) {
return true; return true;
} }
......
...@@ -20,7 +20,7 @@ namespace paddle_mobile { ...@@ -20,7 +20,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool SoftmaxKernel<CPU, float>::Init(SoftmaxParam *param) const { bool SoftmaxKernel<CPU, float>::Init(SoftmaxParam *param) {
return true; return true;
} }
......
...@@ -35,7 +35,7 @@ namespace operators { ...@@ -35,7 +35,7 @@ namespace operators {
// } // }
template <> template <>
bool TransposeKernel<CPU, float>::Init(TransposeParam* param) const { bool TransposeKernel<CPU, float>::Init(TransposeParam* param) {
return true; return true;
} }
......
...@@ -29,7 +29,7 @@ class BatchNormKernel ...@@ -29,7 +29,7 @@ class BatchNormKernel
: public framework::OpKernelBase<DeviceType, BatchNormParam> { : public framework::OpKernelBase<DeviceType, BatchNormParam> {
public: public:
void Compute(const BatchNormParam &param) const; void Compute(const BatchNormParam &param) const;
bool Init(BatchNormParam *param) const; bool Init(BatchNormParam *param);
}; };
} // namespace operators } // namespace operators
......
...@@ -30,7 +30,7 @@ class BoxCoderKernel ...@@ -30,7 +30,7 @@ class BoxCoderKernel
: public framework::OpKernelBase<DeviceType, BoxCoderParam> { : public framework::OpKernelBase<DeviceType, BoxCoderParam> {
public: public:
void Compute(const BoxCoderParam& param) const; void Compute(const BoxCoderParam& param) const;
bool Init(BoxCoderParam* param) const; bool Init(BoxCoderParam* param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
#pragma once #pragma once
#include "operators/kernel/conv_add_bn_relu_kernel.h" #include "operators/kernel/conv_add_bn_relu_kernel.h"
#include "operators/math/depthwiseconv3x3s1p1.h" #include "operators/math/depthwise_conv_3x3.h"
#include "operators/op_param.h" #include "operators/op_param.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
...@@ -24,23 +24,12 @@ namespace operators { ...@@ -24,23 +24,12 @@ namespace operators {
template <typename P> template <typename P>
void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) { void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
const Tensor *input = param.Input(); const Tensor *input = param.Input();
DLOG << "input: " << *input;
Tensor filter = *param.Filter(); Tensor filter = *param.Filter();
DLOG << "filter: " << filter;
Tensor bias = *param.Bias(); Tensor bias = *param.Bias();
DLOG << "bias: " << bias;
Tensor new_bias = *param.NewBias(); Tensor new_bias = *param.NewBias();
Tensor new_scale = *param.NewScale(); Tensor new_scale = *param.NewScale();
auto new_bias_ptr = new_bias.data<float>(); auto new_bias_ptr = new_bias.data<float>();
auto new_scale_ptr = new_scale.data<float>(); auto new_scale_ptr = new_scale.data<float>();
//
// for(int i = 0; i < new_scale.numel(); i++){
// std::cout << "new_scale " << new_scale_ptr[i] <<std::endl;
// }
// for(int i = 0; i < new_bias.numel(); i++){
// std::cout << "new_bias " << new_bias_ptr[i] <<std::endl;
// }
int axis = param.Axis(); int axis = param.Axis();
int groups = param.Groups(); int groups = param.Groups();
std::vector<int> strides = param.Strides(); std::vector<int> strides = param.Strides();
...@@ -50,8 +39,8 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) { ...@@ -50,8 +39,8 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims())); std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
if (filter_shape_vec[2] == 3 && strides[0] == 1 && groups > 1) { if (filter_shape_vec[2] == 3 && strides[0] == 1 && groups > 1) {
math::DepthwiseConv3x3s1p1(input, filter, output, &bias, 1, &new_scale, math::DepthwiseConvAddBNRelu3x3s1p1(input, filter, output, &bias, 1,
&new_bias, 1, 1); &new_scale, &new_bias, 1, 1);
} else { } else {
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
...@@ -131,11 +120,12 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) { ...@@ -131,11 +120,12 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
auto output_ptr = output->data<float>(); auto output_ptr = output->data<float>();
for (int c = 0; c < output_matrix_shape[0]; c++) { for (int c = 0; c < output_matrix_shape[0]; c++) {
// int start = c * output_matrix_shape[1]; int start = c * output_matrix_shape[1];
for (int j = 0; j < output_matrix_shape[1]; j++) { for (int j = 0; j < output_matrix_shape[1]; j++) {
// output_ptr[start + j] = output_ptr[start output_ptr[start + j] =
// +j]*new_scale_ptr[c]+new_bias_ptr[c]; output_ptr[start + j] = output_ptr[start + j] * new_scale_ptr[c] + new_bias_ptr[c];
// output_ptr[start+j]< 0 ? 0 : output_ptr[start +j]; output_ptr[start + j] =
output_ptr[start + j] < 0 ? 0 : output_ptr[start + j];
} }
} }
} }
......
...@@ -27,7 +27,7 @@ template <typename DeviceType, typename T> ...@@ -27,7 +27,7 @@ template <typename DeviceType, typename T>
class ConcatKernel : public framework::OpKernelBase<DeviceType, ConcatParam> { class ConcatKernel : public framework::OpKernelBase<DeviceType, ConcatParam> {
public: public:
void Compute(const ConcatParam &param) const; void Compute(const ConcatParam &param) const;
bool Init(ConcatParam *param) const; bool Init(ConcatParam *param);
}; };
} // namespace operators } // namespace operators
......
...@@ -36,7 +36,7 @@ class ConvAddBNReluKernel ...@@ -36,7 +36,7 @@ class ConvAddBNReluKernel
: public OpKernelBase<DeviceType, FusionConvAddBNReluParam> { : public OpKernelBase<DeviceType, FusionConvAddBNReluParam> {
public: public:
void Compute(const FusionConvAddBNReluParam &param) const; void Compute(const FusionConvAddBNReluParam &param) const;
bool Init(FusionConvAddBNReluParam *param) const; bool Init(FusionConvAddBNReluParam *param);
}; };
} // namespace operators } // namespace operators
......
...@@ -40,7 +40,7 @@ template <typename DeviceType, typename T> ...@@ -40,7 +40,7 @@ template <typename DeviceType, typename T>
class ConvAddKernel : public OpKernelBase<DeviceType, FusionConvAddParam> { class ConvAddKernel : public OpKernelBase<DeviceType, FusionConvAddParam> {
public: public:
void Compute(const FusionConvAddParam &param) const; void Compute(const FusionConvAddParam &param) const;
bool Init(FusionConvAddParam *param) const; bool Init(FusionConvAddParam *param);
}; };
} // namespace operators } // namespace operators
......
...@@ -36,7 +36,7 @@ class ConvAddReluKernel ...@@ -36,7 +36,7 @@ class ConvAddReluKernel
: public OpKernelBase<DeviceType, FusionConvAddReluParam> { : public OpKernelBase<DeviceType, FusionConvAddReluParam> {
public: public:
void Compute(const FusionConvAddReluParam &param) const; void Compute(const FusionConvAddReluParam &param) const;
bool Init(FusionConvAddReluParam *param) const; bool Init(FusionConvAddReluParam *param);
}; };
} // namespace operators } // namespace operators
......
...@@ -32,7 +32,7 @@ template <typename DeviceType, typename T> ...@@ -32,7 +32,7 @@ template <typename DeviceType, typename T>
class ConvKernel : public OpKernelBase<DeviceType, ConvParam> { class ConvKernel : public OpKernelBase<DeviceType, ConvParam> {
public: public:
void Compute(const ConvParam &param) const; void Compute(const ConvParam &param) const;
bool Init(ConvParam *param) const; bool Init(ConvParam *param);
}; };
} // namespace operators } // namespace operators
......
...@@ -31,7 +31,7 @@ template <typename DeviceType, typename T> ...@@ -31,7 +31,7 @@ template <typename DeviceType, typename T>
class DepthwiseConvKernel : public OpKernelBase<DeviceType, ConvParam> { class DepthwiseConvKernel : public OpKernelBase<DeviceType, ConvParam> {
public: public:
void Compute(const ConvParam &param) const; void Compute(const ConvParam &param) const;
bool Init(ConvParam *param) const; bool Init(ConvParam *param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -30,7 +30,7 @@ class ElementwiseAddKernel ...@@ -30,7 +30,7 @@ class ElementwiseAddKernel
: public framework::OpKernelBase<DeviceType, ElementwiseAddParam> { : public framework::OpKernelBase<DeviceType, ElementwiseAddParam> {
public: public:
void Compute(const ElementwiseAddParam &param) const; void Compute(const ElementwiseAddParam &param) const;
bool Init(ElementwiseAddParam *param) const; bool Init(ElementwiseAddParam *param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -20,7 +20,7 @@ namespace paddle_mobile { ...@@ -20,7 +20,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ConvKernel<FPGA, float>::Init(ConvParam *param) const { bool ConvKernel<FPGA, float>::Init(ConvParam *param) {
return true; return true;
} }
......
...@@ -28,7 +28,7 @@ class FusionFcKernel ...@@ -28,7 +28,7 @@ class FusionFcKernel
: public framework::OpKernelBase<DeviceType, FusionFcParam> { : public framework::OpKernelBase<DeviceType, FusionFcParam> {
public: public:
void Compute(const FusionFcParam& param) const; void Compute(const FusionFcParam& param) const;
bool Init(FusionFcParam* param) const; bool Init(FusionFcParam* param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -170,7 +170,7 @@ template <typename DeviceType, typename T> ...@@ -170,7 +170,7 @@ template <typename DeviceType, typename T>
class LrnKernel : public framework::OpKernelBase<DeviceType, LrnParam> { class LrnKernel : public framework::OpKernelBase<DeviceType, LrnParam> {
public: public:
void Compute(const LrnParam &param) const; void Compute(const LrnParam &param) const;
bool Init(LrnParam *param) const; bool Init(LrnParam *param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -128,7 +128,7 @@ class AclBatchNormOp : public acl::ACLOperator { ...@@ -128,7 +128,7 @@ class AclBatchNormOp : public acl::ACLOperator {
}; };
template <> template <>
bool BatchNormKernel<GPU_MALI, float>::Init(BatchNormParam* param) const { bool BatchNormKernel<GPU_MALI, float>::Init(BatchNormParam* param) {
AclBatchNormOp<GPU_MALI, float>* acl_op = AclBatchNormOp<GPU_MALI, float>* acl_op =
reinterpret_cast<AclBatchNormOp<GPU_MALI, float>*>(this->GetAclOp()); reinterpret_cast<AclBatchNormOp<GPU_MALI, float>*>(this->GetAclOp());
if (acl_op == nullptr) { if (acl_op == nullptr) {
......
...@@ -195,7 +195,7 @@ class AclConvOp : public acl::ACLOperator { ...@@ -195,7 +195,7 @@ class AclConvOp : public acl::ACLOperator {
}; };
template <> template <>
bool ConvKernel<GPU_MALI, float>::Init(ConvParam* param) const { bool ConvKernel<GPU_MALI, float>::Init(ConvParam* param) {
AclConvOp<GPU_MALI, float>* acl_op = AclConvOp<GPU_MALI, float>* acl_op =
reinterpret_cast<AclConvOp<GPU_MALI, float>*>(this->GetAclOp()); reinterpret_cast<AclConvOp<GPU_MALI, float>*>(this->GetAclOp());
if (acl_op == nullptr) { if (acl_op == nullptr) {
......
...@@ -27,8 +27,7 @@ struct AddFunctor { ...@@ -27,8 +27,7 @@ struct AddFunctor {
}; };
template <> template <>
bool ElementwiseAddKernel<GPU_MALI, float>::Init( bool ElementwiseAddKernel<GPU_MALI, float>::Init(ElementwiseAddParam *param) {
ElementwiseAddParam *param) const {
return true; return true;
} }
......
...@@ -22,7 +22,7 @@ namespace paddle_mobile { ...@@ -22,7 +22,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool FusionFcKernel<GPU_MALI, float>::Init(FusionFcParam *param) const { bool FusionFcKernel<GPU_MALI, float>::Init(FusionFcParam *param) {
return true; return true;
} }
......
...@@ -22,7 +22,7 @@ namespace paddle_mobile { ...@@ -22,7 +22,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool MulKernel<GPU_MALI, float>::Init(MulParam *param) const { bool MulKernel<GPU_MALI, float>::Init(MulParam *param) {
return true; return true;
} }
......
...@@ -22,7 +22,7 @@ namespace paddle_mobile { ...@@ -22,7 +22,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ReshapeKernel<GPU_MALI, float>::Init(ReshapeParam *param) const { bool ReshapeKernel<GPU_MALI, float>::Init(ReshapeParam *param) {
return true; return true;
} }
......
...@@ -29,7 +29,7 @@ template <typename DeviceType, typename T> ...@@ -29,7 +29,7 @@ template <typename DeviceType, typename T>
class MulKernel : public framework::OpKernelBase<DeviceType, MulParam> { class MulKernel : public framework::OpKernelBase<DeviceType, MulParam> {
public: public:
void Compute(const MulParam &param) const; void Compute(const MulParam &param) const;
bool Init(MulParam *param) const; bool Init(MulParam *param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -28,7 +28,7 @@ class MultiClassNMSKernel ...@@ -28,7 +28,7 @@ class MultiClassNMSKernel
: public framework::OpKernelBase<DeviceType, MultiClassNMSParam> { : public framework::OpKernelBase<DeviceType, MultiClassNMSParam> {
public: public:
void Compute(const MultiClassNMSParam& param) const; void Compute(const MultiClassNMSParam& param) const;
bool Init(MultiClassNMSParam* param) const; bool Init(MultiClassNMSParam* param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -28,7 +28,7 @@ template <typename DeviceType, typename T> ...@@ -28,7 +28,7 @@ template <typename DeviceType, typename T>
class PoolKernel : public OpKernelBase<DeviceType, PoolParam> { class PoolKernel : public OpKernelBase<DeviceType, PoolParam> {
public: public:
void Compute(const PoolParam &param) const override; void Compute(const PoolParam &param) const override;
bool Init(PoolParam *param) const; bool Init(PoolParam *param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -55,7 +55,7 @@ class PriorBoxKernel ...@@ -55,7 +55,7 @@ class PriorBoxKernel
: public framework::OpKernelBase<DeviceType, PriorBoxParam> { : public framework::OpKernelBase<DeviceType, PriorBoxParam> {
public: public:
void Compute(const PriorBoxParam& param) const; void Compute(const PriorBoxParam& param) const;
bool Init(PriorBoxParam* param) const; bool Init(PriorBoxParam* param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -27,7 +27,7 @@ template <typename DeviceType, typename T> ...@@ -27,7 +27,7 @@ template <typename DeviceType, typename T>
class ReluKernel : public framework::OpKernelBase<DeviceType, ReluParam> { class ReluKernel : public framework::OpKernelBase<DeviceType, ReluParam> {
public: public:
void Compute(const ReluParam& param) const; void Compute(const ReluParam& param) const;
bool Init(ReluParam* param) const; bool Init(ReluParam* param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -71,7 +71,7 @@ template <typename DeviceType, typename T> ...@@ -71,7 +71,7 @@ template <typename DeviceType, typename T>
class ReshapeKernel : public framework::OpKernelBase<DeviceType, ReshapeParam> { class ReshapeKernel : public framework::OpKernelBase<DeviceType, ReshapeParam> {
public: public:
void Compute(const ReshapeParam& param) const; void Compute(const ReshapeParam& param) const;
bool Init(ReshapeParam* param) const; bool Init(ReshapeParam* param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -26,7 +26,7 @@ template <typename DeviceType, typename T> ...@@ -26,7 +26,7 @@ template <typename DeviceType, typename T>
class SigmoidKernel : public OpKernelBase<DeviceType, SigmoidParam> { class SigmoidKernel : public OpKernelBase<DeviceType, SigmoidParam> {
public: public:
void Compute(const SigmoidParam& param) const override; void Compute(const SigmoidParam& param) const override;
bool Init(SigmoidParam* param) const; bool Init(SigmoidParam* param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -29,7 +29,7 @@ template <typename DeviceType, typename T> ...@@ -29,7 +29,7 @@ template <typename DeviceType, typename T>
class SoftmaxKernel : public OpKernelBase<DeviceType, SoftmaxParam> { class SoftmaxKernel : public OpKernelBase<DeviceType, SoftmaxParam> {
public: public:
void Compute(const SoftmaxParam &param) const override; void Compute(const SoftmaxParam &param) const override;
bool Init(SoftmaxParam *param) const; bool Init(SoftmaxParam *param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -29,7 +29,7 @@ class TransposeKernel ...@@ -29,7 +29,7 @@ class TransposeKernel
: public framework::OpKernelBase<DeviceType, TransposeParam> { : public framework::OpKernelBase<DeviceType, TransposeParam> {
public: public:
void Compute(const TransposeParam& param) const; void Compute(const TransposeParam& param) const;
bool Init(TransposeParam* param) const; bool Init(TransposeParam* param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -502,6 +502,322 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -502,6 +502,322 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
} }
} }
} }
void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, Tensor filter,
Tensor *output, Tensor *bias, bool if_bias,
Tensor *new_scale, Tensor *new_bias,
bool if_bn, bool if_relu) {
const float *input_data = input->data<float>();
const float *filter_data = filter.data<float>();
float *output_data = output->data<float>();
const float *bias_data = bias->data<float>();
const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>();
const int h = static_cast<int>(input->dims()[2]);
const int w = static_cast<int>(input->dims()[3]);
const int l = h;
const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]);
const int hxw = h * w;
float32x4_t vbias = vdupq_n_f32(0.0);
float32x4_t vnewbias = vdupq_n_f32(0.0);
float32x4_t vnewscale = vdupq_n_f32(1.0);
float32x4_t vzero = vdupq_n_f32(0);
for (int b = 0; b < batch_size; ++b) {
const float *filter_data_tmp = filter_data;
for (int j = 0; j < c; ++j) {
if (if_bias) {
vbias = vdupq_n_f32(bias_data[j]);
}
if (if_bn) {
vnewbias = vdupq_n_f32(newbias_data[j]);
vnewscale = vdupq_n_f32(newscale_data[j]);
}
int l_mid = l - 2; // l=1->l_mid=-1,l=2->l_mid=0
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
output_data[0] =
(w11 * input_data[0] + w12 * input_data[1] + w21 * input_data[l] +
w22 * input_data[l + 1] + bias_data[j]) *
newscale_data[j] +
newbias_data[j];
output_data[l - 1] = (w10 * input_data[l - 2] + w11 * input_data[l - 1] +
w20 * input_data[2 * l - 2] +
w21 * input_data[2 * l - 1] + bias_data[j]) *
newscale_data[j] +
newbias_data[j];
output_data[(l - 1) * l] =
(w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] +
w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1] +
bias_data[j]) *
newscale_data[j] +
newbias_data[j];
output_data[l * l - 1] = (w00 * input_data[(l - 2) * (l + 1)] +
w01 * input_data[(l - 2) * (l + 1) + 1] +
w10 * input_data[l * l - 2] +
w11 * input_data[l * l - 1] + bias_data[j]) *
newscale_data[j] +
newbias_data[j];
if (if_relu) {
output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
output_data[l - 1] = output_data[l - 1] < 0 ? 0 : output_data[l - 1];
output_data[(l - 1) * l] =
output_data[(l - 1) * l] < 0 ? 0 : output_data[(l - 1) * l];
output_data[l * l - 1] =
output_data[l * l - 1] < 0 ? 0 : output_data[l * l - 1];
}
for (int i = 1; i < l - 1; ++i) {
output_data[i * l] =
(w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] +
w11 * input_data[i * l] + w12 * input_data[i * l + 1] +
w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1] +
bias_data[j]) *
newscale_data[j] +
newbias_data[j];
output_data[i * l + l - 1] =
(w00 * input_data[i * l + l - 1 - l - 1] +
w01 * input_data[i * l + l - 1 - l] +
w10 * input_data[i * l + l - 1 - 1] +
w11 * input_data[i * l + l - 1] +
w20 * input_data[i * l + l - 1 + l - 1] +
w21 * input_data[i * l + l - 1 + l] + bias_data[j]) *
newscale_data[j] +
newbias_data[j];
if (if_relu) {
output_data[i * l] = output_data[i * l] < 0 ? 0 : output_data[i * l];
output_data[i * l + l - 1] =
output_data[i * l + l - 1] < 0 ? 0 : output_data[i * l + l - 1];
}
}
// top 1 row and bottom 1 row
const float *input_tmp = input_data;
float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
tmp3, tmp4, tmp5, out0;
in0 = vld1q_f32(input_tmp);
in2 = vld1q_f32(input_tmp + l);
const float *input_tmp_end = input_tmp + (l - 2) * l;
in4 = vld1q_f32(input_tmp_end);
in6 = vld1q_f32(input_tmp_end + l);
int c_mid = l_mid;
auto output_ptr = output_data + 1;
for (; c_mid > 3; c_mid -= 4) {
in1 = vld1q_f32(input_tmp + 4);
in3 = vld1q_f32(input_tmp + l + 4);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vaddq_f32(out0, vbias);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
in5 = vld1q_f32(input_tmp_end + 4);
in7 = vld1q_f32(input_tmp_end + l + 4);
tmp0 = vextq_f32(in4, in5, 1);
tmp1 = vextq_f32(in4, in5, 2);
tmp2 = vextq_f32(in6, in7, 1);
tmp3 = vextq_f32(in6, in7, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vaddq_f32(out0, vbias);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr + (l - 1) * l, out0);
// can optimize to each 8 stride.
input_tmp += 4;
input_tmp_end += 4;
output_ptr += 4;
in0 = in1;
in2 = in3;
in4 = in5;
in6 = in7;
}
// top right pad
float32x4_t pad0 = vdupq_n_f32(input_data[l - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[2 * l - 1]);
tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2);
tmp2 = vextq_f32(in2, pad1, 1);
tmp3 = vextq_f32(in2, pad1, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vaddq_f32(out0, vbias);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
// bottom right pad
float32x4_t pad2 = vdupq_n_f32(input_data[l * l - 1 - l]);
float32x4_t pad3 = vdupq_n_f32(input_data[l * l - 1]);
tmp0 = vextq_f32(in4, pad2, 1);
tmp1 = vextq_f32(in4, pad2, 2);
tmp2 = vextq_f32(in6, pad3, 1);
tmp3 = vextq_f32(in6, pad3, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vaddq_f32(out0, vbias);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 2);
}
}
// mid
for (int i = 0; i < l - 2; ++i) {
auto output_ptr = output_data + (i + 1) * l + 1;
input_tmp = input_data + i * l;
auto in0_tmp = vld1q_f32(input_tmp);
auto in2_tmp = vld1q_f32(input_tmp + l);
auto in4_tmp = vld1q_f32(input_tmp + l + l);
c_mid = l_mid;
for (; c_mid > 3; c_mid -= 4) {
auto in1_tmp = vld1q_f32(input_tmp + 4);
auto in3_tmp = vld1q_f32(input_tmp + l + 4);
auto in5_tmp = vld1q_f32(input_tmp + l + l + 4);
tmp0 = vextq_f32(in0_tmp, in1_tmp, 1);
tmp1 = vextq_f32(in0_tmp, in1_tmp, 2);
tmp2 = vextq_f32(in2_tmp, in3_tmp, 1);
tmp3 = vextq_f32(in2_tmp, in3_tmp, 2);
tmp4 = vextq_f32(in4_tmp, in5_tmp, 1);
tmp5 = vextq_f32(in4_tmp, in5_tmp, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vaddq_f32(out0, vbias);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
output_ptr += 4;
input_tmp += 4;
in0_tmp = in1_tmp;
in2_tmp = in3_tmp;
in4_tmp = in5_tmp;
}
float32x4_t pad0 = vdupq_n_f32(input_data[i * l + l - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[i * l + l - 1 + l]);
float32x4_t pad2 = vdupq_n_f32(input_data[i * l + l - 1 + l + l]);
tmp0 = vextq_f32(in0_tmp, pad0, 1);
tmp1 = vextq_f32(in0_tmp, pad0, 2);
tmp2 = vextq_f32(in2_tmp, pad1, 1);
tmp3 = vextq_f32(in2_tmp, pad1, 2);
tmp4 = vextq_f32(in4_tmp, pad2, 1);
tmp5 = vextq_f32(in4_tmp, pad2, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vaddq_f32(out0, vbias);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
}
output_data += hxw;
input_data += hxw;
filter_data_tmp += 9;
}
}
}
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -32,6 +32,10 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides, ...@@ -32,6 +32,10 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
Tensor *output, bool if_bias); Tensor *output, bool if_bias);
void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor *bias, bool if_bias); Tensor *output, Tensor *bias, bool if_bias);
// Fused variant of the 3x3 stride-1 pad-1 depthwise convolution that, in
// addition to the bias tensor, takes per-channel scale (new_scale) and offset
// (new_bias) tensors plus flags for the bias, scale/offset and ReLU stages.
// Note that `filter` is taken by value while the other tensors are pointers.
void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, Tensor filter,
Tensor *output, Tensor *bias, bool if_bias,
Tensor *new_scale, Tensor *new_bias,
bool if_bn, bool if_relu);
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册