diff --git a/src/common/types.cpp b/src/common/types.cpp index 6cea95546d03bc61f3dd9e6d811b3591db0bf7f2..fcffcae1ab0322b839ace5447885e87fdb78fbf8 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -72,6 +72,8 @@ const char *G_OP_TYPE_SUM = "sum"; const char *G_OP_TYPE_QUANTIZE = "quantize"; const char *G_OP_TYPE_DEQUANTIZE = "dequantize"; +const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU = "fusion_dequant_add_bn_relu"; + const char *G_OP_TYPE_TANH = "tanh"; const char *G_OP_TYPE_FUSION_DECONV_RELU = "fusion_deconv_relu"; const char *G_OP_TYPE_FUSION_DECONV_ADD = "fusion_deconv_add"; @@ -136,6 +138,7 @@ std::unordered_map< {G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}}, {G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}}, {G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}}, + {G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU, {{"X", "Scale"}, {"Out"}}}, {G_OP_TYPE_TANH, {{"X"}, {"Out"}}}, {G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_FUSION_DECONV_ADD, {{"Input"}, {"Out"}}}, diff --git a/src/common/types.h b/src/common/types.h index a1a9185733cb859fab76ab00a96d7bc2824e344b..b84f802cb81678c76da8ca29ce36f43e13618c23 100644 --- a/src/common/types.h +++ b/src/common/types.h @@ -139,6 +139,7 @@ extern const char *G_OP_TYPE_ELEMENTWISE_MUL; extern const char *G_OP_TYPE_QUANTIZE; extern const char *G_OP_TYPE_DEQUANTIZE; +extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU; extern const char *G_OP_TYPE_TANH; extern const char *G_OP_TYPE_FUSION_DECONV_RELU; diff --git a/src/fpga/V2/api.cpp b/src/fpga/V2/api.cpp index 5bfd34104600668ce63a9c7d684d4482d5d804fb..d58e780c279e03b90b4ebe3731c6693615107ec4 100644 --- a/src/fpga/V2/api.cpp +++ b/src/fpga/V2/api.cpp @@ -132,11 +132,11 @@ void format_concat_output(framework::Tensor *out, int height, int width, } int format_conv_data(framework::Tensor *filter_tensor, - framework::Tensor *ofm_tensor, float *bs_ptr, int group) { + framework::Tensor *ofm_tensor, float **bs_ptr, int group) { float max_value = fpga::filter_find_max(filter_tensor); fpga::format_filter(filter_tensor, max_value, group); int aligned_num = get_aligned_filter_num(filter_tensor); - fpga::format_bias_scale_array(&bs_ptr, + fpga::format_bias_scale_array(bs_ptr, (int)filter_tensor->dims()[0], // NOLINT aligned_num); int aligned_channel = fpga::get_conv_output_channel(filter_tensor); diff --git a/src/fpga/V2/api.h b/src/fpga/V2/api.h index 1386810164d72ef849162b76a8b83fcf32082907..59c1b006183e4355ebe9316766773215b6edf12f 100644 --- a/src/fpga/V2/api.h +++ b/src/fpga/V2/api.h @@ -39,7 +39,7 @@ void format_bias_scale_array(float** bias_scale_array, int filter_num, void format_concat_output(framework::Tensor* out, int height, int width, uint32_t out_channel); int format_conv_data(framework::Tensor* filter_tensor, - framework::Tensor* ofm_tensor, float* bs_ptr, int group); + framework::Tensor* ofm_tensor, float** bs_ptr, int group); int format_fc_data(framework::Tensor* filter_tensor, framework::Tensor* ofm_tensor, float* bs_ptr); void fill_split_arg(struct SplitConvArgs* arg, framework::Tensor* input, diff --git a/src/framework/cl/cl_image.h b/src/framework/cl/cl_image.h index f94eba187f2c5610d7a20098e95015244b420ce2..1a906ba4a4f43e1e1b57bbb3652fdc19fa052a78 100644 --- a/src/framework/cl/cl_image.h +++ b/src/framework/cl/cl_image.h @@ -68,6 +68,13 @@ class CLImage { InitCLImage(context, command_queue, folder_converter); } + void InitNormalCLImage(cl_context context, cl_command_queue command_queue) { + PADDLE_MOBILE_ENFORCE(tensor_data_ != nullptr, + " need call 
SetTensorData first"); + CLImageConverterNormal *normal_converter = new CLImageConverterNormal(); + InitCLImage(context, command_queue, normal_converter); + } + void InitCLImage(cl_context context, cl_command_queue command_queue, CLImageConverterBase *converter) { if (image_converter_ != nullptr) { diff --git a/src/framework/load_ops.h b/src/framework/load_ops.h index 982f1c0f3525afde8475866c0121343fafc9d5a0..135ef9083e42271fe63cdc29ee53e876f532c287 100644 --- a/src/framework/load_ops.h +++ b/src/framework/load_ops.h @@ -233,3 +233,7 @@ LOAD_OP1(quantize, CPU); #ifdef DEQUANT_OP LOAD_OP1(dequantize, CPU); #endif +#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP +LOAD_OP1(fusion_dequant_add_bn_relu, CPU); +LOAD_FUSION_MATCHER(fusion_dequant_add_bn_relu); +#endif diff --git a/src/operators/feed_op.cpp b/src/operators/feed_op.cpp index ac707d22696dd0a62902137607fb64c141341d77..4e496fb51d16c47d801eabada7c36dbdefdd2140 100644 --- a/src/operators/feed_op.cpp +++ b/src/operators/feed_op.cpp @@ -22,7 +22,6 @@ void FeedOp::InferShape() const { auto out_dims = this->param_.Out()->dims(); out_dims[0] = this->param_.BatchSize(); auto input_dims = this->param_.InputX()->dims(); - DLOG << input_dims.size(); if (input_dims.size() == 4) { this->param_.Out()->Resize(input_dims); } else { diff --git a/src/operators/fusion_dequant_add_bn_relu_op.cpp b/src/operators/fusion_dequant_add_bn_relu_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..80d9040afb29b7a42c742b821e9d7522c1a12827 --- /dev/null +++ b/src/operators/fusion_dequant_add_bn_relu_op.cpp @@ -0,0 +1,40 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP + +#include "operators/fusion_dequant_add_bn_relu_op.h" + +namespace paddle_mobile { +namespace operators { + +template +void FusionDequantAddBNReluOp::InferShape() const { + const auto& input_dims = this->param_.input_->dims(); + this->param_.output_->Resize(input_dims); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +REGISTER_FUSION_MATCHER(fusion_dequant_add_bn_relu, + ops::FusionDequantAddBNReluMatcher); + +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(fusion_dequant_add_bn_relu, + ops::FusionDequantAddBNReluOp); +#endif + +#endif diff --git a/src/operators/fusion_dequant_add_bn_relu_op.h b/src/operators/fusion_dequant_add_bn_relu_op.h new file mode 100644 index 0000000000000000000000000000000000000000..dbd9ad0de2ece751ffd4da05cb09f0091a5755aa --- /dev/null +++ b/src/operators/fusion_dequant_add_bn_relu_op.h @@ -0,0 +1,76 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include "framework/operator.h"
+#include "framework/program/program-optimize/fusion_op_register.h"
+#include "operators/kernel/dequant_add_bn_relu_kernel.h"
+#include "operators/op_param.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+class FusionDequantAddBNReluMatcher : public framework::FusionOpMatcher {
+ public:
+  FusionDequantAddBNReluMatcher() {
+    node_ = framework::Node(G_OP_TYPE_DEQUANTIZE);
+    node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
+        std::make_shared<framework::Node>(G_OP_TYPE_BATCHNORM) >
+        std::make_shared<framework::Node>(G_OP_TYPE_RELU);
+  }
+
+  void FolderNodes(
+      framework::Node *node,
+      std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
+    node->Folder(node_.Depth(), Type(),
+                 {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}},
+                  {G_OP_TYPE_BATCHNORM,
+                   {{"Scale", "BNScale"},
+                    {"Mean", "BNMean"},
+                    {"Bias", "BNBias"},
+                    {"Variance", "BNVariance"}}}},
+                 removed_nodes);
+  }
+
+  std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU; }
+};
+
+template <typename DeviceType, typename T>
+class FusionDequantAddBNReluOp
+    : public framework::OperatorWithKernel<
+          DeviceType, FusionDequantAddBNReluParam<DeviceType>,
+          operators::FusionDequantAddBNReluKernel<DeviceType, T>> {
+ public:
+  FusionDequantAddBNReluOp(const std::string &type,
+                           const VariableNameMap &inputs,
+                           const VariableNameMap &outputs,
+                           const framework::AttributeMap &attrs,
+                           std::shared_ptr<framework::Scope> scope)
+      : framework::OperatorWithKernel<
+            DeviceType, FusionDequantAddBNReluParam<DeviceType>,
+            operators::FusionDequantAddBNReluKernel<DeviceType, T>>(
+            type, inputs, outputs, attrs, scope) {}
+  // inference output shape
+  void InferShape() const override;
+};
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif
diff --git a/src/operators/fusion_fc_op.cpp b/src/operators/fusion_fc_op.cpp
index 928a4d8541db11886986ffbb695cdf54b5f12c51..f2e98b2b4ceae283ddbe04af06e8926f1b8bb47f 100644
--- a/src/operators/fusion_fc_op.cpp
+++ b/src/operators/fusion_fc_op.cpp
@@ -60,6 +60,9 @@ REGISTER_FUSION_MATCHER(fusion_fc, ops::FusionFcMatcher);
 #ifdef PADDLE_MOBILE_CPU
 REGISTER_OPERATOR_CPU(fusion_fc, ops::FusionFcOp);
 #endif
+#ifdef PADDLE_MOBILE_CL
+REGISTER_OPERATOR_CL(fusion_fc, ops::FusionFcOp);
+#endif
 #ifdef PADDLE_MOBILE_MALI_GPU
 REGISTER_OPERATOR_MALI_GPU(fusion_fc, ops::FusionFcOp);
 #endif
diff --git a/src/operators/kernel/arm/dequant_add_bn_relu_kernel.cpp b/src/operators/kernel/arm/dequant_add_bn_relu_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..bfe1935c216f94d660997b1bfa42f18e63295992
--- /dev/null
+++ b/src/operators/kernel/arm/dequant_add_bn_relu_kernel.cpp
@@ -0,0 +1,116 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
+
+#include "operators/kernel/dequant_add_bn_relu_kernel.h"
+#include <cmath>
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#include <arm_neon.h>
+#endif
+
+namespace paddle_mobile {
+namespace operators {
+
+template <>
+bool FusionDequantAddBNReluKernel<CPU, float>::Init(
+    FusionDequantAddBNReluParam<CPU> *param) {
+  // elementwise add params
+  const Tensor *bias = param->bias_;
+  // batch norm params
+  const Tensor *bn_mean = param->bn_mean_;
+  const Tensor *bn_variance = param->bn_variance_;
+  Tensor *bn_scale = param->bn_scale_;
+  Tensor *bn_bias = param->bn_bias_;
+  const float epsilon = param->epsilon_;
+
+  const float *bias_ptr = bias->data<float>();
+  const float *mean_ptr = bn_mean->data<float>();
+  const float *var_ptr = bn_variance->data<float>();
+  float *bn_scale_ptr = bn_scale->mutable_data<float>();
+  float *bn_bias_ptr = bn_bias->mutable_data<float>();
+  for (int c = 0; c < bn_scale->numel(); ++c) {
+    float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon));
+    bn_scale_ptr[c] = inv_scale;
+    bn_bias_ptr[c] = inv_scale * (bias_ptr[c] - mean_ptr[c]) + bn_bias_ptr[c];
+  }
+  return true;
+}
+
+template <>
+void FusionDequantAddBNReluKernel<CPU, float>::Compute(
+    const FusionDequantAddBNReluParam<CPU> &param) {
+  const int32_t *input = param.input_->data<int32_t>();
+  const float *bn_scale = param.bn_scale_->data<float>();
+  const float *bn_bias = param.bn_bias_->data<float>();
+  // dequantize params
+  const float activation_scale = param.activation_scale_->data<float>()[0];
+  const float weight_scale = param.weight_scale_;
+  const float dequant_scale = activation_scale / weight_scale;
+
+  float *output = param.output_->mutable_data<float>();
+  int batch_size = param.input_->dims()[0];
+  int channels = param.input_->dims()[1];
+  size_t spatial_size = param.input_->dims()[2] * param.input_->dims()[3];
+
+  #pragma omp parallel for collapse(2)
+  for (int batch = 0; batch < batch_size; ++batch) {
+    for (int c = 0; c < channels; ++c) {
+      float scale = bn_scale[c] * dequant_scale;
+      float bias = bn_bias[c];
+      size_t offset = (batch * channels + c) * spatial_size;
+      const int32_t *x = input + offset;
+      float *y = output + offset;
+      size_t remain = spatial_size;
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+      int loop = spatial_size >> 4;
+      remain = spatial_size & 0xF;
+      float32x4_t __scale = vdupq_n_f32(scale);
+      float32x4_t __bias = vdupq_n_f32(bias);
+      float32x4_t __zero = vdupq_n_f32(0.f);
+
+      for (int k = 0; k < loop; ++k, x += 16, y += 16) {
+        int32x4_t r0 = vld1q_s32(x);
+        int32x4_t r1 = vld1q_s32(x + 4);
+        int32x4_t r2 = vld1q_s32(x + 8);
+        int32x4_t r3 = vld1q_s32(x + 12);
+        float32x4_t f0 = vcvtq_f32_s32(r0);
+        float32x4_t f1 = vcvtq_f32_s32(r1);
+        float32x4_t f2 = vcvtq_f32_s32(r2);
+        float32x4_t f3 = vcvtq_f32_s32(r3);
+        f0 = vmlaq_f32(__bias, __scale, f0);
+        f1 = vmlaq_f32(__bias, __scale, f1);
+        f2 = vmlaq_f32(__bias, __scale, f2);
+        f3 = vmlaq_f32(__bias, __scale, f3);
+        f0 = vmaxq_f32(__zero, f0);
+        f1 = vmaxq_f32(__zero, f1);
+        f2 = vmaxq_f32(__zero, f2);
+        f3 = vmaxq_f32(__zero, f3);
+        vst1q_f32(y, f0);
+        vst1q_f32(y + 4, f1);
+        vst1q_f32(y + 8, f2);
+        vst1q_f32(y + 12, f3);
+      }
+#endif  // __ARM_NEON__
+      for (int k = 0; k < remain; ++k) {
+        y[k] = std::max(scale * x[k] + bias, 0.f);
+      }
+    }
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif  // FUSION_DEQUANT_ADD_BN_RELU_OP
diff --git a/src/operators/kernel/arm/quantize_kernel.cpp b/src/operators/kernel/arm/quantize_kernel.cpp
index
e0e6d44d5226237d931c07f571610ff945b40e16..1e7623436a1a73644aca61e4634a7cd405bd64ad 100644 --- a/src/operators/kernel/arm/quantize_kernel.cpp +++ b/src/operators/kernel/arm/quantize_kernel.cpp @@ -379,8 +379,8 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, const float *x3 = input3 + h * input_w; int loop = input_w >> 4; int remain = input_w & 0xF; - int pad_loop = paddings[1] >> 1; - int pad_remain = paddings[1] & 0x1; + int pad_loop = paddings[1] >> 1; // (paddings[1] << 1) >> 2 + int pad_remain = (paddings[1] << 1) & 0x3; int remain_steps = remain; asm volatile( "vdup.f32 q0, %[scale] \n" @@ -596,7 +596,7 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, "store_pad_2w_%=: \n" "cmp %[pad_remain], #2 \n" - "ble store_pad_1w_%= \n" + "blt store_pad_1w_%= \n" "vst1.16 {d0[0]}, [%[y0]]! \n" "vst1.16 {d0[0]}, [%[y1]]! \n" "vst1.16 {d0[0]}, [%[y2]]! \n" @@ -605,7 +605,7 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, "store_pad_1w_%=: \n" "cmp %[pad_remain], #1 \n" - "ble end_%= \n" + "blt end_%= \n" "vst1.8 {d0[0]}, [%[y0]]! \n" "vst1.8 {d0[0]}, [%[y1]]! \n" "vst1.8 {d0[0]}, [%[y2]]! \n" @@ -669,8 +669,8 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, const float *x0 = input0 + h * input_w; int loop = input_w >> 4; int remain = input_w & 0xF; - int pad_loop = paddings[1] >> 1; - int pad_remain = paddings[1] & 0x1; + int pad_loop = paddings[1] >> 1; // (paddings[1] << 1) >> 2 + int pad_remain = (paddings[1] << 1) & 0x3; asm volatile( "vdup.f32 q0, %[scale] \n" "cmp %[loop], #0 \n" @@ -754,14 +754,14 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, "pad_remain_%=: \n" "cmp %[pad_remain], #2 \n" - "ble store_pad_1w_%= \n" + "blt store_pad_1w_%= \n" "vst1.16 {d0[0]}, [%[y0]]! \n" "sub %[pad_remain], #2 \n" "store_pad_1w_%=: \n" "cmp %[pad_remain], #1 \n" - "ble end_%= \n" - "vst1.8 {d0[0]}, [%[y0]]! \n" + "blt end_%= \n" + "vst1.8 {d0[0]}, [%[y0]]! \n" "end_%=: \n" : [x0] "+r"(x0), [y0] "+r"(y0), [loop] "+r"(loop), [remain] "+r"(remain), [pad_loop] "+r"(pad_loop), @@ -795,10 +795,10 @@ void QuantizeKernel::Compute(const QuantizeParam ¶m) { // only support int8 currently float scale = 127 / max_abs; param.online_scale_->mutable_data()[0] = max_abs; - // const auto &paddings = param.paddings_; - std::vector paddings = {0, 0}; - // const auto padding_val = param.padding_val_; - int8_t padding_val = 127; + const auto &paddings = param.paddings_; + // std::vector paddings = {0, 0}; + // const auto padding_val = param.padding_val_; + int8_t padding_val = 0; switch (param.round_type_) { case ROUND_NEAREST_TO_EVEN: quantize_round_to_even(input, scale, paddings, padding_val, output); diff --git a/src/operators/kernel/cl/cl_kernel/concat_kernel.cl b/src/operators/kernel/cl/cl_kernel/concat_kernel.cl index b07ee4d819b25ef77729ed868c54b19a3d8699ae..20cf7b4c48db4191a2bc95b0d952fbaf0ea1dc18 100644 --- a/src/operators/kernel/cl/cl_kernel/concat_kernel.cl +++ b/src/operators/kernel/cl/cl_kernel/concat_kernel.cl @@ -13,7 +13,27 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma OPENCL EXTENSION cl_khr_fp16 : enable -/* + +__kernel void concatByC0(__read_only image2d_t input_image, + __write_only image2d_t output_image, + __private const int out_W) { + + const int in_c = get_global_id(0); + const int in_w = get_global_id(1); + const int in_nh = get_global_id(2); + + int2 input_pos ; + input_pos.x = in_c * out_W + in_w; + input_pos.y = in_nh; + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + half4 input; + input = read_imageh(input_image, sampler,input_pos); + + write_imageh(output_image, input_pos, input); + +} __kernel void concatByC(__read_only image2d_t input_image1, __read_only image2d_t input_image2, @@ -24,13 +44,13 @@ __kernel void concatByC(__read_only image2d_t input_image1, __private const int out_C_Start, __private const int in_W, __private const int in_H, - __private const int int_C1, - __private const int int_C2) { + __private const int in_C1, + __private const int in_C2) { const int in_c = get_global_id(0); const int in_w = get_global_id(1); const int in_nh = get_global_id(2); - int out_c1 = (out_C_Start)/4 + in_c; + int out_c1 = (out_C_Start + 3)/4 -1 + in_c; int out_c2 = out_c1 + 1; @@ -45,7 +65,7 @@ __kernel void concatByC(__read_only image2d_t input_image1, int2 input_pos1; if(in_c==0){ - input_pos1.x = ((in_C1-1)/4) * in_W + in_w; + input_pos1.x = ((in_C1 + 3)/4-1) * in_W + in_w; }else{ input_pos1.x = (in_c - 1) * in_W + in_w; } @@ -103,26 +123,6 @@ __kernel void concatByC(__read_only image2d_t input_image1, write_imageh(output_image, output_pos2, output2); } -__kernel void concatByW0(__read_only image2d_t input_image, - __write_only image2d_t output_image, - __private const int out_W) { - - const int in_c = get_global_id(0); - const int in_w = get_global_id(1); - const int in_nh = get_global_id(2); - - int2 input_pos = in_c * out_W + in_w; - - const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | - CLK_ADDRESS_CLAMP | - CLK_FILTER_NEAREST; - half4 input; - input = read_imageh(input_image, sampler,input_pos); - - write_imageh(output_image, input_pos, input); - -} -*/ __kernel void concatByH(__read_only image2d_t input_image, __write_only image2d_t output_image, diff --git a/src/operators/kernel/cl/cl_kernel/conv_kernel.inc.cl b/src/operators/kernel/cl/cl_kernel/conv_kernel.inc.cl index 2247df59fb77a67a87a00bd26de014f94e86a378..1085e97c10d27aa99583a86a2e2d70ae11d2d68d 100644 --- a/src/operators/kernel/cl/cl_kernel/conv_kernel.inc.cl +++ b/src/operators/kernel/cl/cl_kernel/conv_kernel.inc.cl @@ -692,6 +692,238 @@ __kernel void conv_1x1_4(__private const int global_size_dim0, */ +__kernel void conv_7x7(__private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, + __read_only image2d_t input_image, + __read_only image2d_t filter_image, + +#ifdef BIASE + __read_only image2d_t bias, +#endif + +#ifdef BATCH_NORM + __read_only image2d_t new_scale, + __read_only image2d_t new_biase, +#endif + + __write_only image2d_t output_image, + __private const int stride, + __private const int offset, + __private const int input_c, + __private const int dilation, + __private const int input_width,/* of one block */ + __private const int input_height,/* of one block */ + __private const int output_width, + __private const int output_height) { + + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + + if (out_c >= global_size_dim0 || + out_w >= global_size_dim1 || + out_nh >= 
global_size_dim2) { + return; + } + const filter_n0 = 4 * out_c + 0; + const filter_n1 = 4 * out_c + 1; + const filter_n2 = 4 * out_c + 2; + const filter_n3 = 4 * out_c + 3; + + int2 stride_xy; + stride_xy.x = stride; + stride_xy.y = stride; + + int2 ouput_pos_in_one_block; + ouput_pos_in_one_block.x = out_w; + ouput_pos_in_one_block.y = out_nh; + + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + + int2 in_pos_in_one_block; + in_pos_in_one_block.x = ouput_pos_in_one_block.x * stride + offset; + in_pos_in_one_block.y = ouput_pos_in_one_block.y * stride + offset; + +#ifdef BIASE + half4 output = read_imageh(bias, sampler, (int2)(out_c, 0)); +#else + half4 output = 0.0f; +#endif + + half4 input; + half4 filter[4]; + int2 filter_pos0; + int2 filter_pos1; + int2 filter_pos2; + int2 filter_pos3; + for (int i = 0; i < input_c; ++i) { + int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x, in_pos_in_one_block.y); + for(int j = 0; j < 7; j++){ + for(int k = 0; k < 7; k++){ + input = select(read_imageh(input_image, sampler, + (int2)(pos_in.x + (j - 3) * dilation, pos_in.y + (k - 3) * dilation)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x + (j - 3) * dilation < 0 || in_pos_in_one_block.y + (k - 3) * dilation < 0 || in_pos_in_one_block.x + (j - 3) * dilation >= input_width || in_pos_in_one_block.y + (k - 3) * dilation >= input_height) << 15)); + int filter_h = k; + int filter_w = j; + int filter_c = i; + + filter_pos0.x = filter_c * 7 + filter_w; + filter_pos0.y = filter_n0 * 7 + filter_h; + + filter_pos1.x = filter_c * 7 + filter_w; + filter_pos1.y = filter_n1 * 7 + filter_h; + + filter_pos2.x = filter_c * 7 + filter_w; + filter_pos2.y = filter_n2 * 7 + filter_h; + + filter_pos3.x = filter_c * 7 + filter_w; + filter_pos3.y = filter_n3 * 7 + filter_h; + + filter[0] = read_imageh(filter_image, sampler, filter_pos0); + filter[1] = read_imageh(filter_image, sampler, filter_pos1); + filter[2] = read_imageh(filter_image, sampler, filter_pos2); + filter[3] = read_imageh(filter_image, sampler, filter_pos3); + + output.x += dot(input, filter[0]); + output.y += dot(input, filter[1]); + output.z += dot(input, filter[2]); + output.w += dot(input, filter[3]); + } + } + } + +#ifdef BATCH_NORM + output = output * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + read_imageh(new_biase, sampler, (int2)(out_c, 0)); +#endif + +#ifdef RELU + output = activation(output); +#endif + + write_imageh(output_image, (int2)(out_c * global_size_dim1 + out_w, out_nh), output); +} + +__kernel void conv_5x5(__private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, + __read_only image2d_t input_image, + __read_only image2d_t filter_image, + +#ifdef BIASE + __read_only image2d_t bias, +#endif + +#ifdef BATCH_NORM + __read_only image2d_t new_scale, + __read_only image2d_t new_biase, +#endif + + __write_only image2d_t output_image, + __private const int stride, + __private const int offset, + __private const int input_c, + __private const int dilation, + __private const int input_width,/* of one block */ + __private const int input_height,/* of one block */ + __private const int output_width, + __private const int output_height) { + + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + + if (out_c >= global_size_dim0 || + out_w >= global_size_dim1 || + out_nh >= global_size_dim2) { + return; + } + const filter_n0 = 4 * out_c + 0; + const 
filter_n1 = 4 * out_c + 1; + const filter_n2 = 4 * out_c + 2; + const filter_n3 = 4 * out_c + 3; + + int2 stride_xy; + stride_xy.x = stride; + stride_xy.y = stride; + + int2 ouput_pos_in_one_block; + ouput_pos_in_one_block.x = out_w; + ouput_pos_in_one_block.y = out_nh; + + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + + int2 in_pos_in_one_block; + in_pos_in_one_block.x = ouput_pos_in_one_block.x * stride + offset; + in_pos_in_one_block.y = ouput_pos_in_one_block.y * stride + offset; + +#ifdef BIASE + half4 output = read_imageh(bias, sampler, (int2)(out_c, 0)); +#else + half4 output = 0.0f; +#endif + + half4 input; + half4 filter[4]; + int2 filter_pos0; + int2 filter_pos1; + int2 filter_pos2; + int2 filter_pos3; + for (int i = 0; i < input_c; ++i) { + int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x, in_pos_in_one_block.y); + for(int j = 0; j < 5; j++){ + for(int k = 0; k < 5; k++){ + input = select(read_imageh(input_image, sampler, + (int2)(pos_in.x + (j - 2) * dilation, pos_in.y + (k - 2) * dilation)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x + (j - 2) * dilation < 0 || in_pos_in_one_block.y + (k - 2) * dilation < 0 || in_pos_in_one_block.x + (j - 2) * dilation >= input_width || in_pos_in_one_block.y + (k - 2) * dilation >= input_height) << 15)); + int filter_h = k; + int filter_w = j; + int filter_c = i; + + filter_pos0.x = filter_c * 5 + filter_w; + filter_pos0.y = filter_n0 * 5 + filter_h; + + filter_pos1.x = filter_c * 5 + filter_w; + filter_pos1.y = filter_n1 * 5 + filter_h; + + filter_pos2.x = filter_c * 5 + filter_w; + filter_pos2.y = filter_n2 * 5 + filter_h; + + filter_pos3.x = filter_c * 5 + filter_w; + filter_pos3.y = filter_n3 * 5 + filter_h; + + filter[0] = read_imageh(filter_image, sampler, filter_pos0); + filter[1] = read_imageh(filter_image, sampler, filter_pos1); + filter[2] = read_imageh(filter_image, sampler, filter_pos2); + filter[3] = read_imageh(filter_image, sampler, filter_pos3); + + output.x += dot(input, filter[0]); + output.y += dot(input, filter[1]); + output.z += dot(input, filter[2]); + output.w += dot(input, filter[3]); + } + } + } + +#ifdef BATCH_NORM + output = output * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + read_imageh(new_biase, sampler, (int2)(out_c, 0)); +#endif + +#ifdef RELU + output = activation(output); +#endif + + write_imageh(output_image, (int2)(out_c * global_size_dim1 + out_w, out_nh), output); +} + diff --git a/src/operators/kernel/cl/cl_kernel/lrn_kernel.cl b/src/operators/kernel/cl/cl_kernel/lrn_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..080928b23586b0aa3e639a0cc9b5577355863639 --- /dev/null +++ b/src/operators/kernel/cl/cl_kernel/lrn_kernel.cl @@ -0,0 +1,136 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +__kernel void lrn(__read_only image2d_t input_image, + __write_only image2d_t output_image, + __private const int out_C, + __private const int out_W, + __private const int n, + __private const float k, + __private const float alpha, + __private const float beta){ + + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + + const int out_c0 = out_c * 4; + const int out_c1 = out_c * 4 + 1; + const int out_c2 = out_c * 4+ 2; + const int out_c3 = out_c * 4+ 3; + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + + const int start = -(n-1)/2; + const end = start + n; + float sqr_sum0 = 0.0f; + float sqr_sum1 = 0.0f; + float sqr_sum2 = 0.0f; + float sqr_sum3 = 0.0f; + int input_c0,input_c1,input_c2,input_c3; + int2 input_pos0,input_pos1,input_pos2,input_pos3; + float4 input0,input1,input2,input3; + for(int i = start; i < end ;i++){ + if(out_c0 + i>=0&&out_c0 + i=0&&out_c1 + i=0&&out_c2 + i=0&&out_c3 + i=2){ + output.y = input.y / (pow(k + alpha * (sqr_sum1),beta)); + } + if(out_C - 4 * out_c>=3){ + output.z = input.z / (pow(k + alpha * (sqr_sum2),beta)); + } + if(out_C - 4 * out_c>=4){ + output.w = input.w / (pow(k + alpha * (sqr_sum3),beta)); + } + half4 tmp = convert_half4(output); + write_imageh(output_image, output_pos, tmp); + +} \ No newline at end of file diff --git a/src/operators/kernel/cl/cl_kernel/pool_kernel.cl b/src/operators/kernel/cl/cl_kernel/pool_kernel.cl index fc660941f8863a0056c4618f0207ae69533d3242..a6a4da690fa921d281786fcddebf7362d3c52119 100644 --- a/src/operators/kernel/cl/cl_kernel/pool_kernel.cl +++ b/src/operators/kernel/cl/cl_kernel/pool_kernel.cl @@ -31,11 +31,13 @@ __kernel void pool_max( const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; - int start_h = max(out_h * stride_h - pad_top, 0); + int start_h = out_h * stride_h - pad_top; int end_h = min(start_h + ksize_h, in_height); + start_h = max(start_h,0); - int start_w = max(out_w * stride_w - pad_left, 0); + int start_w = out_w * stride_w - pad_left; int end_w = min(start_w + ksize_w, in_width); + start_w = max(start_w,0); const int pos_in_x = out_c * in_width; const int pos_in_y = out_n * in_height; diff --git a/src/operators/kernel/cl/concat_kernel.cpp b/src/operators/kernel/cl/concat_kernel.cpp index 3deb31e7aa0c408cc2b87c523d324001f75ade88..c8ff448b3be79c1acfac7e8cd4e32ea4e3c2b3f5 100644 --- a/src/operators/kernel/cl/concat_kernel.cpp +++ b/src/operators/kernel/cl/concat_kernel.cpp @@ -23,12 +23,17 @@ template <> bool ConcatKernel::Init(ConcatParam *param) { if (param->Out()->dims().size() < 4) { this->cl_helper_.AddKernel("concatByH", "concat_kernel.cl"); + } else if (param->Out()->dims().size() == 4) { + this->cl_helper_.AddKernel("concatByC0", "concat_kernel.cl"); + this->cl_helper_.AddKernel("concatByC", "concat_kernel.cl"); } return true; } template <> void ConcatKernel::Compute(const ConcatParam ¶m) { + DLOG << "yangfei50"; + DLOG << param.Out()->dims(); if (param.Out()->dims().size() < 4) { auto kernel = this->cl_helper_.KernelAt(0); auto inputs = param.Inputs(); @@ -62,6 +67,76 @@ void ConcatKernel::Compute(const ConcatParam ¶m) { out_H_Start += inputs[i]->dims()[0]; } } + } else { + auto kernel0 = this->cl_helper_.KernelAt(0); + auto kernel1 = this->cl_helper_.KernelAt(1); + auto inputs = param.Inputs(); + auto *output_image = param.Out()->GetCLImage(); + + int out_C_Start = 0; + auto input_image 
= inputs[0]->GetCLImage(); + auto default_work_size = this->cl_helper_.DefaultWorkSize(*inputs[0]); + int out_W = param.Out()->dims()[3]; + cl_int status; + status = clSetKernelArg(kernel0, 0, sizeof(cl_mem), &input_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel0, 1, sizeof(cl_mem), &output_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel0, 2, sizeof(int), &out_W); + CL_CHECK_ERRORS(status); + status = clEnqueueNDRangeKernel( + this->cl_helper_.CLCommandQueue(), kernel0, default_work_size.size(), + NULL, default_work_size.data(), NULL, 0, NULL, NULL); + CL_CHECK_ERRORS(status); + out_C_Start += inputs[0]->dims()[1]; + for (int i = 1; i < inputs.size(); i++) { + auto input_image1 = inputs[i - 1]->GetCLImage(); + auto input_image2 = inputs[i]->GetCLImage(); + default_work_size = this->cl_helper_.DefaultWorkSize(*inputs[i]); + int out_C = param.Out()->dims()[1]; + int out_H = param.Out()->dims()[2]; + int in_W = inputs[i]->dims()[3]; + int in_H = inputs[i]->dims()[2]; + int in_C1 = inputs[i - 1]->dims()[1]; + int in_C2 = inputs[i]->dims()[1]; + DLOG << "第" << i << "个"; + DLOG << "out_C=" << out_C; + DLOG << "out_H=" << out_H; + DLOG << "in_W=" << in_W; + DLOG << "in_H=" << in_H; + DLOG << "in_C1=" << in_C1; + DLOG << "in_C2=" << in_C2; + DLOG << "out_C_Start = " << out_C_Start; + status = clSetKernelArg(kernel1, 0, sizeof(cl_mem), &input_image1); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel1, 1, sizeof(cl_mem), &input_image2); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel1, 2, sizeof(cl_mem), &output_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel1, 3, sizeof(int), &out_C); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel1, 4, sizeof(int), &out_H); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel1, 5, sizeof(int), &out_W); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel1, 6, sizeof(int), &out_C_Start); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel1, 7, sizeof(int), &in_W); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel1, 8, sizeof(int), &in_H); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel1, 9, sizeof(int), &in_C1); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel1, 10, sizeof(int), &in_C2); + CL_CHECK_ERRORS(status); + + status = clEnqueueNDRangeKernel( + this->cl_helper_.CLCommandQueue(), kernel1, default_work_size.size(), + NULL, default_work_size.data(), NULL, 0, NULL, NULL); + CL_CHECK_ERRORS(status); + + out_C_Start += inputs[i]->dims()[1]; + } } } diff --git a/src/operators/kernel/cl/conv_add_kernel.cpp b/src/operators/kernel/cl/conv_add_kernel.cpp index 3292cc7ccd2febc4d1e5b8f5e4991f8348b25196..9485644dea3fbbfb983ca104e6dbc04832e2afe6 100644 --- a/src/operators/kernel/cl/conv_add_kernel.cpp +++ b/src/operators/kernel/cl/conv_add_kernel.cpp @@ -51,8 +51,16 @@ bool ConvAddKernel::Init(FusionConvAddParam *param) { this->cl_helper_.AddKernel("conv_3x3", "conv_add_kernel.cl"); - } else { - PADDLE_MOBILE_THROW_EXCEPTION(" not support "); + } else if (param->Filter()->dims()[2] == 7 && + param->Filter()->dims()[3] == 7) { + param->Filter()->InitCLImage(cl_helper_.CLContext(), + cl_helper_.CLCommandQueue()); + this->cl_helper_.AddKernel("conv_7x7", "conv_add_kernel.cl"); + } else if (param->Filter()->dims()[2] == 5 && + param->Filter()->dims()[3] == 5) { + param->Filter()->InitCLImage(cl_helper_.CLContext(), + cl_helper_.CLCommandQueue()); + this->cl_helper_.AddKernel("conv_5x5", "conv_add_kernel.cl"); } return true; 
diff --git a/src/operators/kernel/cl/conv_add_relu_kernel.cpp b/src/operators/kernel/cl/conv_add_relu_kernel.cpp index 814cff634cb0c4c2d5dd6e6706b558bb1cd64f22..88de4ae2e308f2b55020c314d18551ebe8ae1ea7 100644 --- a/src/operators/kernel/cl/conv_add_relu_kernel.cpp +++ b/src/operators/kernel/cl/conv_add_relu_kernel.cpp @@ -52,6 +52,16 @@ bool ConvAddReluKernel::Init( this->cl_helper_.AddKernel("conv_3x3", "conv_add_relu_kernel.cl"); + } else if (param->Filter()->dims()[2] == 7 && + param->Filter()->dims()[3] == 7) { + param->Filter()->InitCLImage(cl_helper_.CLContext(), + cl_helper_.CLCommandQueue()); + this->cl_helper_.AddKernel("conv_7x7", "conv_add_relu_kernel.cl"); + } else if (param->Filter()->dims()[2] == 5 && + param->Filter()->dims()[3] == 5) { + param->Filter()->InitCLImage(cl_helper_.CLContext(), + cl_helper_.CLCommandQueue()); + this->cl_helper_.AddKernel("conv_5x5", "conv_add_relu_kernel.cl"); } else { PADDLE_MOBILE_THROW_EXCEPTION(" not support "); } diff --git a/src/operators/kernel/cl/fusion_fc_kernel.cpp b/src/operators/kernel/cl/fusion_fc_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7d85becea601878de577b59a5c671b3ea04f9370 --- /dev/null +++ b/src/operators/kernel/cl/fusion_fc_kernel.cpp @@ -0,0 +1,130 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef FUSION_FC_OP + +#include "operators/kernel/fusion_fc_kernel.h" +#include "operators/math/math_function.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool FusionFcKernel::Init(FusionFcParam *param) { + param->InputY()->InitNormalCLImage(cl_helper_.CLContext(), + this->cl_helper_.CLCommandQueue()); + param->InputZ()->InitNormalCLImage(cl_helper_.CLContext(), + this->cl_helper_.CLCommandQueue()); + this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl"); + this->cl_helper_.AddKernel("feed", "feed_kernel.cl"); + return true; +} + +template +void FusionFcCompute(const FusionFcParam ¶m, cl_context context, + cl_command_queue commandQueue, cl_kernel kernel0, + cl_kernel kernel1) { + auto *input_x_image = param.InputX(); + auto *input_y_image = param.InputY(); + auto *input_z_image = param.InputZ(); + + int axis = param.Axis(); + auto *out_image = param.Out(); + + Tensor *input_x = new Tensor(); + input_x->Resize(input_x_image->dims()); + input_x->mutable_data(); + framework::CLImageToTensor(input_x_image, input_x, context, commandQueue, + kernel0); + + Tensor *input_y = new Tensor(); + input_y->Resize(input_y_image->dims()); + input_y->mutable_data(); + framework::CLImageToTensor(input_y_image, input_y, context, commandQueue, + kernel0); + + Tensor *input_z = new Tensor(); + input_z->Resize(input_z_image->dims()); + input_z->mutable_data(); + framework::CLImageToTensor(input_z_image, input_z, context, commandQueue, + kernel0); + auto *input_z_data = input_z->data(); + + DLOG << *input_x; + DLOG << *input_y; + DLOG << *input_z; + + Tensor *out = new Tensor(); + out->Resize(out_image->dims()); + out->mutable_data(); + auto *out_data = out->mutable_data(); + + const Tensor x_matrix = + input_x->dims().size() > 2 + ? framework::ReshapeToMatrix(*input_x, param.XNumColDims()) + : *input_x; + const Tensor y_matrix = + input_y->dims().size() > 2 + ? framework::ReshapeToMatrix(*input_y, param.YNumColDims()) + : *input_y; + auto out_dim = out->dims(); + if (out_dim.size() != 2) { + out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); + } + PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2."); + PADDLE_MOBILE_ENFORCE(input_z->dims().size() == 1, "inpu_z size must be 1"); + PADDLE_MOBILE_ENFORCE(out_dim[1] == input_z->dims()[0], + " out_dim.size must be 2."); + axis = (axis == -1 ? out_dim.size() - input_z->dims().size() : axis); + PADDLE_MOBILE_ENFORCE(axis == 1, " to fit broadcast, axis = 1. 
"); + + int64_t classes = input_z->numel(); + for (int i = 0; i < out_dim[0]; i++) { + memory::Copy(out_data + i * classes, input_z_data, sizeof(float) * classes); + } + + // for (int i = 0; i < out->numel(); i++) { + // DLOG << out_data[i]; + // } + // bias_data的维度和out的维度一致 + math::matmul(x_matrix, false, y_matrix, false, static_cast(1), + out, static_cast(1), false); + + out_image->InitEmptyImage(context, commandQueue, out->dims()); + framework::TensorToCLImage(out, out_image, context, commandQueue, kernel1); + + DLOG << *out; + + delete (input_x); + delete (input_y); + delete (input_z); + delete (out); + PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2."); + // if (out_dim.size() != 2) { + // out->Resize(out_dim); + // } +} +template <> +void FusionFcKernel::Compute( + const FusionFcParam ¶m) { + auto kernel0 = this->cl_helper_.KernelAt(0); + auto kernel1 = this->cl_helper_.KernelAt(1); + FusionFcCompute(param, this->cl_helper_.CLContext(), + this->cl_helper_.CLCommandQueue(), kernel0, kernel1); +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/cl/lrn_kernel.cpp b/src/operators/kernel/cl/lrn_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e7e949e5ab5e8a8c8e17d76ee839767173251edc --- /dev/null +++ b/src/operators/kernel/cl/lrn_kernel.cpp @@ -0,0 +1,79 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef LRN_OP + +#include "operators/kernel/lrn_kernel.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool LrnKernel::Init(LrnParam *param) { + this->cl_helper_.AddKernel("lrn", "lrn_kernel.cl"); + return true; +} + +template <> +void LrnKernel::Compute(const LrnParam ¶m) { + auto kernel = this->cl_helper_.KernelAt(0); + auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Out()); + + auto input_image = param.InputX()->GetCLImage(); + auto x_dims = param.InputX()->dims(); + auto output_image = param.Out()->GetCLImage(); + + const int N = x_dims[0]; + const int C = x_dims[1]; + const int H = x_dims[2]; + const int W = x_dims[3]; + + const int n = param.N(); + const float alpha = param.Alpha(); + const float beta = param.Beta(); + const float k = param.K(); + DLOG << "n=" << n; + DLOG << "alpha=" << alpha; + DLOG << "beta=" << beta; + DLOG << "k=" << k; + DLOG << default_work_size; + DLOG << C; + DLOG << W; + cl_int status; + status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &input_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &output_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 2, sizeof(int), &C); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 3, sizeof(int), &W); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 4, sizeof(int), &n); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 5, sizeof(float), &k); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 6, sizeof(float), &alpha); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 7, sizeof(float), &beta); + + status = clEnqueueNDRangeKernel( + this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL, + default_work_size.data(), NULL, 0, NULL, NULL); + CL_CHECK_ERRORS(status); +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/dequant_add_bn_relu_kernel.h b/src/operators/kernel/dequant_add_bn_relu_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..7138e5c415caca6766913f9959bd41def0943d34 --- /dev/null +++ b/src/operators/kernel/dequant_add_bn_relu_kernel.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#pragma once
+
+#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
+
+#include "framework/operator.h"
+#include "operators/op_param.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <typename DeviceType, typename T>
+class FusionDequantAddBNReluKernel
+    : public framework::OpKernelBase<DeviceType,
+                                     FusionDequantAddBNReluParam<DeviceType>> {
+ public:
+  void Compute(const FusionDequantAddBNReluParam<DeviceType> &param);
+  bool Init(FusionDequantAddBNReluParam<DeviceType> *param);
+};
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif
diff --git a/src/operators/kernel/fpga/V2/conv_add_bn_kernel.cpp b/src/operators/kernel/fpga/V2/conv_add_bn_kernel.cpp
index 7c03daf7797dbc09ba85a4f4e32e983571d192df..82cb872055aed84d28c798e413b86478de6ca0a6 100644
--- a/src/operators/kernel/fpga/V2/conv_add_bn_kernel.cpp
+++ b/src/operators/kernel/fpga/V2/conv_add_bn_kernel.cpp
@@ -58,7 +58,7 @@ bool ConvAddBNKernel<FPGA, float>::Init(FusionConvAddBNParam<FPGA> *param) {
   param->SetNewScale(new_scale);
   param->SetNewBias(new_bias);
-  fpga::format_conv_data(filter, out, bs_ptr, param->Groups());
+  fpga::format_conv_data(filter, out, &bs_ptr, param->Groups());
   fpga::SplitConvArgs conv_arg = {0};
   fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
diff --git a/src/operators/kernel/fpga/V2/conv_add_bn_relu_kernel.cpp b/src/operators/kernel/fpga/V2/conv_add_bn_relu_kernel.cpp
index 8737554e6f8c343491656ca7659e1850d84ea246..266ebe012e0db3ef3b2ac21f81f4436d143ece59 100644
--- a/src/operators/kernel/fpga/V2/conv_add_bn_relu_kernel.cpp
+++ b/src/operators/kernel/fpga/V2/conv_add_bn_relu_kernel.cpp
@@ -56,7 +56,7 @@ bool ConvAddBNReluKernel<FPGA, float>::Init(
   param->SetNewScale(new_scale);
   param->SetNewBias(new_bias);
-  fpga::format_conv_data(filter, out, bs_ptr, param->Groups());
+  fpga::format_conv_data(filter, out, &bs_ptr, param->Groups());
   fpga::SplitConvArgs conv_arg = {0};
   fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
diff --git a/src/operators/kernel/fpga/V2/conv_add_kernel.cpp b/src/operators/kernel/fpga/V2/conv_add_kernel.cpp
index 22841e705c255433bebeab479a2e2b8d3a3b7187..e9c5032779b4e6b63f82355cd2a5634c1fae88de 100644
--- a/src/operators/kernel/fpga/V2/conv_add_kernel.cpp
+++ b/src/operators/kernel/fpga/V2/conv_add_kernel.cpp
@@ -38,7 +38,7 @@ bool ConvAddKernel<FPGA, float>::Init(FusionConvAddParam<FPGA> *param) {
     bs_ptr[i] = bias_ptr[i];
   }
-  fpga::format_conv_data(filter, out, bs_ptr, param->Groups());
+  fpga::format_conv_data(filter, out, &bs_ptr, param->Groups());
   fpga::SplitConvArgs conv_arg = {0};
   fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
diff --git a/src/operators/kernel/fpga/V2/conv_add_relu_kernel.cpp b/src/operators/kernel/fpga/V2/conv_add_relu_kernel.cpp
index a3c4443645e421ee0dce10f53914600fb7af75bf..1002a358434046b05fee41b60281cc594a093808 100644
--- a/src/operators/kernel/fpga/V2/conv_add_relu_kernel.cpp
+++ b/src/operators/kernel/fpga/V2/conv_add_relu_kernel.cpp
@@ -38,7 +38,7 @@ bool ConvAddReluKernel<FPGA, float>::Init(FusionConvAddReluParam<FPGA> *param) {
     bs_ptr[i] = bias_ptr[i];
   }
-  fpga::format_conv_data(filter, out, bs_ptr, param->Groups());
+  fpga::format_conv_data(filter, out, &bs_ptr, param->Groups());
   fpga::SplitConvArgs conv_arg = {0};
   fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
diff --git a/src/operators/kernel/fpga/V2/conv_bn_kernel.cpp b/src/operators/kernel/fpga/V2/conv_bn_kernel.cpp
index 070fce98b9e5f0c7055943447602dba8ae78c7c4..cb32c0fe040b9c55de660269fbfc3598ea9722bf 100644
--- a/src/operators/kernel/fpga/V2/conv_bn_kernel.cpp
+++ b/src/operators/kernel/fpga/V2/conv_bn_kernel.cpp
@@ -50,7 +50,7 @@ bool ConvBNKernel<FPGA, float>::Init(FusionConvBNParam<FPGA> *param) {
param->SetNewScale(new_scale); param->SetNewBias(new_bias); - fpga::format_conv_data(filter, out, bs_ptr, param->Groups()); + fpga::format_conv_data(filter, out, &bs_ptr, param->Groups()); fpga::SplitConvArgs conv_arg = {0}; fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled, diff --git a/src/operators/kernel/fpga/V2/conv_bn_relu_kernel.cpp b/src/operators/kernel/fpga/V2/conv_bn_relu_kernel.cpp index 95ac74cbf87fe20ef419e748f8a8a04df20c98e3..918b65bd347811f9a2cc6b1182c54d9f39a9082e 100644 --- a/src/operators/kernel/fpga/V2/conv_bn_relu_kernel.cpp +++ b/src/operators/kernel/fpga/V2/conv_bn_relu_kernel.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #ifdef FUSION_CONVBNRELU_OP #include "operators/kernel/conv_bn_relu_kernel.h" +#include "fpga/V2/filter.h" namespace paddle_mobile { namespace operators { @@ -50,7 +51,7 @@ bool ConvBNReluKernel::Init(FusionConvBNReluParam *param) { param->SetNewScale(new_scale); param->SetNewBias(new_bias); - fpga::format_conv_data(filter, out, bs_ptr, param->Groups()); + fpga::format_conv_data(filter, out, &bs_ptr, param->Groups()); fpga::SplitConvArgs conv_arg = {0}; fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled, diff --git a/src/operators/lrn_op.cpp b/src/operators/lrn_op.cpp index faa9ccb6132e70e01e5c076554455d9424c68086..b63d2f2fbe594fc35cd580ea772562a263c97bd5 100644 --- a/src/operators/lrn_op.cpp +++ b/src/operators/lrn_op.cpp @@ -14,7 +14,7 @@ limitations under the License. */ #ifdef LRN_OP -#include "lrn_op.h" +#include "operators/lrn_op.h" namespace paddle_mobile { namespace operators { @@ -32,6 +32,9 @@ namespace ops = paddle_mobile::operators; #ifdef PADDLE_MOBILE_CPU REGISTER_OPERATOR_CPU(lrn, ops::LrnOp); #endif +#ifdef PADDLE_MOBILE_CL +REGISTER_OPERATOR_CL(lrn, ops::LrnOp); +#endif #ifdef PADDLE_MOBILE_MALI_GPU REGISTER_OPERATOR_MALI_GPU(lrn, ops::LrnOp); #endif diff --git a/src/operators/op_param.h b/src/operators/op_param.h index ea79a3af2dfe97b0385c3c1cdec671bedf2cd7ae..12c26aed3ac8685a4f8b662e3bb39ff711a7019a 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -1631,11 +1631,11 @@ class FusionFcParam : public OpParam { y_num_col_dims_ = GetAttr("y_num_col_dims", attrs); axis_ = GetAttr("axis", attrs); } - const GType *InputX() const { return input_x_; } + GType *InputX() const { return input_x_; } - const RType *InputY() const { return input_y_; } + RType *InputY() const { return input_y_; } - const RType *InputZ() const { return input_z_; } + RType *InputZ() const { return input_z_; } GType *Out() const { return out_; } @@ -2555,7 +2555,7 @@ class QuantizeParam : public OpParam { output_ = OutFrom(outputs, scope); // online // scale = max(abs(x)) - online_scale_ = GetVarValue("OutScale", outputs, scope); + online_scale_ = OpParam::GetVarValue("OutScale", outputs, scope); // offline if (HasAttr("static_scale", attrs)) { is_static_ = true; @@ -2565,6 +2565,11 @@ class QuantizeParam : public OpParam { if (HasAttr("round_type", attrs)) { round_type_ = GetAttr("round_type", attrs); } + // get paddings + paddings_ = std::vector({0, 0}); + if (HasAttr("paddings", attrs)) { + paddings_ = GetAttr>("paddings", attrs); + } } public: @@ -2598,7 +2603,7 @@ class DequantizeParam : public OpParam { const AttributeMap &attrs, const Scope &scope) { input_ = InputXFrom(inputs, scope); output_ = OutFrom(outputs, scope); - activation_scale_ = GetVarValue("Scale", inputs, scope); + activation_scale_ = OpParam::GetVarValue("Scale", inputs, scope); // dequantization is performed as x = x / static_scale / 
online_scale
     if (HasAttr("weight_scale", attrs)) {
       weight_scale_ = GetAttr<float>("weight_scale", attrs);
@@ -2617,5 +2622,44 @@ class DequantizeParam : public OpParam {
 };
 #endif
+#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
+template <typename Dtype>
+class FusionDequantAddBNReluParam : public DequantizeParam<Dtype> {
+  typedef typename DtypeTensorTrait<Dtype>::gtype GType;
+  typedef typename DtypeTensorTrait<Dtype>::rtype RType;
+
+ public:
+  FusionDequantAddBNReluParam(const VariableNameMap &inputs,
+                              const VariableNameMap &outputs,
+                              const AttributeMap &attrs, const Scope &scope)
+      : DequantizeParam<Dtype>(inputs, outputs, attrs, scope) {
+    // element wise add params
+    axis_ = OpParam::GetAttr<int>("axis", attrs);
+    bias_ = OpParam::InputYFrom<GType>(inputs, scope);
+    // batch norm params
+    bn_mean_ = OpParam::GetVarValue<GType>("BNMean", inputs, scope);
+    bn_variance_ = OpParam::GetVarValue<GType>("BNVariance", inputs, scope);
+    bn_scale_ = OpParam::GetVarValue<GType>("BNScale", inputs, scope);
+    bn_bias_ = OpParam::GetVarValue<GType>("BNBias", inputs, scope);
+    epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
+    // output
+    output_ = OpParam::OutFrom<GType>(outputs, scope);
+  }
+
+ public:
+  // elementwise add
+  int axis_;
+  RType *bias_;
+  // batch norm
+  RType *bn_mean_;
+  RType *bn_variance_;
+  RType *bn_scale_;
+  RType *bn_bias_;
+  float epsilon_;
+  // output
+  RType *output_;
+};
+#endif
+
 }  // namespace operators
 }  // namespace paddle_mobile
diff --git a/src/operators/quantize_op.cpp b/src/operators/quantize_op.cpp
index bde99cfd5ab4b1c2f86c55cdb39dbf559c66d576..6dd9d75af463753008b273b93253cb986eb90e80 100644
--- a/src/operators/quantize_op.cpp
+++ b/src/operators/quantize_op.cpp
@@ -22,7 +22,10 @@ namespace operators {
 template <typename Dtype, typename T>
 void QuantizeOp<Dtype, T>::InferShape() const {
-  const auto &input_dims = this->param_.input_->dims();
+  auto input_dims = this->param_.input_->dims();
+  const std::vector<int> &paddings = this->param_.paddings_;
+  input_dims[2] += 2 * paddings[0];
+  input_dims[3] += 2 * paddings[1];
   this->param_.output_->Resize(input_dims);
   auto scale_dims = framework::make_ddim(std::vector<int>{1});
   this->param_.online_scale_->Resize(scale_dims);
diff --git a/test/operators/test_quantize_op.cpp b/test/operators/test_quantize_op.cpp
index 5b1f276bebb0b956a7907a500645612c5aeaf8f9..9988661bcb898daa5e79b6d22d65d90cfa03c668 100644
--- a/test/operators/test_quantize_op.cpp
+++ b/test/operators/test_quantize_op.cpp
@@ -12,58 +12,131 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
*/ +#include #include "../test_helper.h" #include "../test_include.h" #include "operators/quantize_op.h" namespace paddle_mobile { - -static float find_abs_max(const Tensor *input) { - float max_abs = 0.f; - const float *x = input->data(); - size_t size = input->numel(); - for (size_t i = 0; i < size; ++i) { - float value = std::abs(x[i]); - if (value > max_abs) { - max_abs = value; - } - } - return max_abs; +namespace round { +enum RoundType { + RoundToEven = 0, + RoundAwayZero = 1, + RoundTowardsZero = 2, +}; } -static void quantize_round_to_even(const Tensor *input, const float scale, - Tensor *output) { - const float *x = input->data(); - int8_t *y = output->mutable_data(); - size_t size = input->numel(); - for (size_t i = 0; i < size; ++i) { - float value = x[i] * scale; - float v = round(value); +template +struct Round { + int8_t operator()(float x); +}; + +template <> +struct Round { + int8_t operator()(float x) { return std::round(x); } +}; + +template <> +struct Round { + int8_t operator()(float x) { return int8_t(x); } +}; + +template <> +struct Round { + int8_t operator()(float x) { + int8_t ret = 0; + float v = std::round(x); int32_t q = (int32_t)v; - if (abs(abs(q - value) - 0.5) > 0) { - y[i] = q; + if (abs(abs(q - x) - 0.5) > 0) { + ret = q; } else { if (abs(q) % 2 == 0) { - y[i] = q; + ret = q; } else { - y[i] = q + ((q > 0) ? -1 : 1); + ret = q + ((q > 0) ? -1 : 1); + } + } + return ret; + } +}; + +template +static void quantize(const Tensor *input, const float scale, const int pad, + const int8_t pad_val, Tensor *output) { + int batch_size = input->dims()[0]; + int channels = input->dims()[1]; + int input_h = input->dims()[2]; + int input_w = input->dims()[3]; + int output_h = output->dims()[2]; + int output_w = output->dims()[3]; + size_t input_spatial = input_h * input_w; + size_t output_spatial = output_h * output_w; + const float *x = input->data(); + int8_t *y = output->mutable_data(); + + for (int nc = 0; nc < batch_size * channels; ++nc) { + const float *xh = x + nc * input_spatial; + int8_t *yh = y + nc * output_spatial; + // pad top + for (int h = 0; h < pad; ++h, yh += output_w) { + for (int w = 0; w < output_w; ++w) { + yh[w] = pad_val; + } + } + for (int h = 0; h < input_h; ++h, yh += output_w, xh += input_w) { + // pad left + for (int w = 0; w < pad; ++w) { + yh[w] = pad_val; + } + for (int w = 0; w < input_w; ++w) { + yh[w + pad] = Round()(xh[w] * scale); + } + // pad right + for (int w = 0; w < pad; ++w) { + yh[pad + input_w + w] = pad_val; + } + } + // pad bottom + for (int h = 0; h < pad; ++h, yh += output_w) { + for (int w = 0; w < output_w; ++w) { + yh[w] = pad_val; } } } } -static void quantize_round_to_nearest(const Tensor *input, const float scale, - Tensor *output) { +static float find_abs_max(const Tensor *input) { + float max_abs = 0.f; const float *x = input->data(); - int8_t *y = output->mutable_data(); size_t size = input->numel(); for (size_t i = 0; i < size; ++i) { - y[i] = round(x[i] * scale); + float value = std::abs(x[i]); + if (value > max_abs) { + max_abs = value; + } } + return max_abs; } -int TestQuqntizeOp() { - framework::DDim dim = framework::make_ddim({1, 3, 224, 224}); +int TestQuqntizeOp(int argc, char *argv[]) { + if (argc < 5) { + std::cout + << "Usage: ./test-quantize-op batch_size channel height width [pad]" + << std::endl; + return 1; + } + int pad = 0; + int batch_size = atoi(argv[1]); + int channel = atoi(argv[2]); + int height = atoi(argv[3]); + int width = atoi(argv[4]); + if (argc == 6) { + pad = atoi(argv[5]); + } + 
std::cout << "batch_size: " << batch_size << ", channel: " << channel + << ", height: " << height << ", width: " << width << std::endl; + framework::DDim dim = + framework::make_ddim({batch_size, channel, height, width}); VariableNameMap inputs; VariableNameMap outputs; @@ -80,6 +153,7 @@ int TestQuqntizeOp() { auto output_scale_var = scope.get()->Var("output_scale"); framework::AttributeMap attrs; + attrs["paddings"].Set>(std::vector({pad, pad})); auto *op = new operators::QuantizeOp("quantize", inputs, outputs, attrs, scope); op->InferShape(); @@ -96,10 +170,11 @@ int TestQuqntizeOp() { output_scale_cmp, output_scale_data[0]); framework::Tensor output_cmp; - output_cmp.Resize(dim); + output_cmp.Resize(output->dims()); float scale = 127 / output_scale_cmp; - // quantize_round_to_even(input, scale, &output_cmp); - quantize_round_to_nearest(input, scale, &output_cmp); + // quantize(input, scale, pad, 0, &output_cmp); + // quantize(input, scale, pad, 0, &output_cmp); + quantize(input, scale, pad, 0, &output_cmp); int8_t *output_cmp_data = output_cmp.data(); for (int i = 0; i < output->numel(); ++i) { PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i], @@ -113,4 +188,6 @@ int TestQuqntizeOp() { } // namespace paddle_mobile -int main() { return paddle_mobile::TestQuqntizeOp(); } +int main(int argc, char *argv[]) { + return paddle_mobile::TestQuqntizeOp(argc, argv); +} diff --git a/tools/op.cmake b/tools/op.cmake index 45dbcdcf058553840ad70805eda59306fa4ec36d..ce9a9079c682b5e9e3dff1754d1d37e98f3a5f3b 100644 --- a/tools/op.cmake +++ b/tools/op.cmake @@ -250,6 +250,7 @@ if(NOT FOUND_MATCH) set(SUM_OP ON) set(QUANT_OP ON) set(DEQUANT_OP ON) + set(FUSION_DEQUANT_ADD_BN_RELU ON) endif() # option(BATCHNORM_OP "" ON) @@ -454,6 +455,9 @@ endif() if (DEQUANT_OP) add_definitions(-DDEQUANT_OP) endif() +if (FUSION_DEQUANT_ADD_BN_RELU) + add_definitions(-DFUSION_DEQUANT_ADD_BN_RELU_OP) +endif() if (TANH_OP) add_definitions(-DTANH_OP) @@ -466,4 +470,4 @@ if (FUSION_DECONVADD_OP) endif() if (FUSION_DECONVADDRELU_OP) add_definitions(-DFUSION_DECONVADDRELU_OP) -endif() \ No newline at end of file +endif()
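
For reference, the fused kernel added in dequant_add_bn_relu_kernel.cpp folds the
elementwise bias and the batch-norm statistics into one per-channel scale/bias pair in
Init(), then applies y = max(dequant_scale * bn_scale[c] * x + bn_bias[c], 0) in
Compute(). The snippet below is a minimal scalar reference of that math and is not part
of the patch; the function name and the flat NCHW buffers are illustrative assumptions.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Scalar reference for fusion_dequant_add_bn_relu (one image, NCHW layout).
// x holds the int32 output of a quantized convolution; activation_scale and
// weight_scale are the scales produced/used by the quantize and dequantize ops.
void DequantAddBNReluReference(const std::vector<int32_t> &x, int channels,
                               int spatial_size, float activation_scale,
                               float weight_scale,
                               const std::vector<float> &elt_bias,
                               const std::vector<float> &bn_mean,
                               const std::vector<float> &bn_var,
                               const std::vector<float> &bn_gamma,
                               const std::vector<float> &bn_beta, float epsilon,
                               std::vector<float> *y) {
  const float dequant_scale = activation_scale / weight_scale;
  for (int c = 0; c < channels; ++c) {
    // same folding as Init(): scale = gamma / sqrt(var + eps),
    // bias = scale * (elementwise_bias - mean) + beta
    const float inv_std = 1.f / std::sqrt(bn_var[c] + epsilon);
    const float scale = bn_gamma[c] * inv_std * dequant_scale;
    const float bias =
        bn_gamma[c] * inv_std * (elt_bias[c] - bn_mean[c]) + bn_beta[c];
    for (int i = 0; i < spatial_size; ++i) {
      const int idx = c * spatial_size + i;
      // dequantize, add folded bias, then ReLU
      (*y)[idx] = std::max(scale * static_cast<float>(x[idx]) + bias, 0.f);
    }
  }
}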