Commit 98087f3f authored by Zhen Wang

Merge branch 'develop' of https://github.com/PaddlePaddle/paddle-mobile into fusion_conv_add_relu_int8_op
@@ -72,6 +72,8 @@ const char *G_OP_TYPE_SUM = "sum";
const char *G_OP_TYPE_QUANTIZE = "quantize";
const char *G_OP_TYPE_DEQUANTIZE = "dequantize";
const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU = "fusion_dequant_add_bn_relu";
const char *G_OP_TYPE_TANH = "tanh";
const char *G_OP_TYPE_FUSION_DECONV_RELU = "fusion_deconv_relu";
const char *G_OP_TYPE_FUSION_DECONV_ADD = "fusion_deconv_add";
@@ -136,6 +138,7 @@ std::unordered_map<
{G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}},
{G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}},
{G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU, {{"X", "Scale"}, {"Out"}}},
{G_OP_TYPE_TANH, {{"X"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_ADD, {{"Input"}, {"Out"}}},
......
@@ -139,6 +139,7 @@ extern const char *G_OP_TYPE_ELEMENTWISE_MUL;
extern const char *G_OP_TYPE_QUANTIZE;
extern const char *G_OP_TYPE_DEQUANTIZE;
extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU;
extern const char *G_OP_TYPE_TANH;
extern const char *G_OP_TYPE_FUSION_DECONV_RELU;
......
@@ -132,11 +132,11 @@ void format_concat_output(framework::Tensor *out, int height, int width,
}
int format_conv_data(framework::Tensor *filter_tensor,
framework::Tensor *ofm_tensor, float *bs_ptr, int group) {
framework::Tensor *ofm_tensor, float **bs_ptr, int group) {
float max_value = fpga::filter_find_max(filter_tensor);
fpga::format_filter(filter_tensor, max_value, group);
int aligned_num = get_aligned_filter_num(filter_tensor);
fpga::format_bias_scale_array(&bs_ptr,
fpga::format_bias_scale_array(bs_ptr,
(int)filter_tensor->dims()[0], // NOLINT
aligned_num);
int aligned_channel = fpga::get_conv_output_channel(filter_tensor);
......
@@ -39,7 +39,7 @@ void format_bias_scale_array(float** bias_scale_array, int filter_num,
void format_concat_output(framework::Tensor* out, int height, int width,
uint32_t out_channel);
int format_conv_data(framework::Tensor* filter_tensor,
framework::Tensor* ofm_tensor, float* bs_ptr, int group);
framework::Tensor* ofm_tensor, float** bs_ptr, int group);
int format_fc_data(framework::Tensor* filter_tensor,
framework::Tensor* ofm_tensor, float* bs_ptr);
void fill_split_arg(struct SplitConvArgs* arg, framework::Tensor* input,
......
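Note on the float * -> float ** change above (mirrored at every FPGA call site below, which now pass &bs_ptr): taking the address of the pointer lets format_bias_scale_array replace the caller's bias/scale buffer, e.g. with an aligned copy, and have the caller see the new pointer. A minimal sketch of the idiom, with a hypothetical format_array helper (not the FPGA implementation):

#include <cstdlib>
#include <cstring>

// Hypothetical sketch of the out-parameter idiom behind float **bs_ptr:
// the callee replaces the caller's buffer with an aligned copy.
void format_array(float **data, int num, int aligned_num) {
  float *aligned = static_cast<float *>(calloc(aligned_num, sizeof(float)));
  memcpy(aligned, *data, num * sizeof(float));
  free(*data);
  *data = aligned;  // the caller's pointer now refers to the aligned buffer
}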
@@ -68,6 +68,13 @@ class CLImage {
InitCLImage(context, command_queue, folder_converter);
}
void InitNormalCLImage(cl_context context, cl_command_queue command_queue) {
PADDLE_MOBILE_ENFORCE(tensor_data_ != nullptr,
" need to call SetTensorData first");
CLImageConverterNormal *normal_converter = new CLImageConverterNormal();
InitCLImage(context, command_queue, normal_converter);
}
void InitCLImage(cl_context context, cl_command_queue command_queue,
CLImageConverterBase *converter) {
if (image_converter_ != nullptr) {
......
@@ -233,3 +233,7 @@ LOAD_OP1(quantize, CPU);
#ifdef DEQUANT_OP
LOAD_OP1(dequantize, CPU);
#endif
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
LOAD_OP1(fusion_dequant_add_bn_relu, CPU);
LOAD_FUSION_MATCHER(fusion_dequant_add_bn_relu);
#endif
@@ -22,7 +22,6 @@ void FeedOp<DeviceType, T>::InferShape() const {
auto out_dims = this->param_.Out()->dims();
out_dims[0] = this->param_.BatchSize();
auto input_dims = this->param_.InputX()->dims();
DLOG << input_dims.size();
if (input_dims.size() == 4) {
this->param_.Out()->Resize(input_dims);
} else {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#include "operators/fusion_dequant_add_bn_relu_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void FusionDequantAddBNReluOp<Dtype, T>::InferShape() const {
const auto& input_dims = this->param_.input_->dims();
this->param_.output_->Resize(input_dims);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
REGISTER_FUSION_MATCHER(fusion_dequant_add_bn_relu,
ops::FusionDequantAddBNReluMatcher);
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_dequant_add_bn_relu,
ops::FusionDequantAddBNReluOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/dequant_add_bn_relu_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
class FusionDequantAddBNReluMatcher : public framework::FusionOpMatcher {
public:
FusionDequantAddBNReluMatcher() {
node_ = framework::Node(G_OP_TYPE_DEQUANTIZE);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
std::make_shared<framework::Node>(G_OP_TYPE_BATCHNORM) >
std::make_shared<framework::Node>(G_OP_TYPE_RELU);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}},
{G_OP_TYPE_BATCHNORM,
{{"Scale", "BNScale"},
{"Mean", "BNMean"},
{"Bias", "BNBias"},
{"Variance", "BNVariance"}}}},
removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU; }
};
template <typename DeviceType, typename T>
class FusionDequantAddBNReluOp
: public framework::OperatorWithKernel<
DeviceType, FusionDequantAddBNReluParam<DeviceType>,
operators::FusionDequantAddBNReluKernel<DeviceType, T>> {
public:
FusionDequantAddBNReluOp(const std::string &type,
const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionDequantAddBNReluParam<DeviceType>,
operators::FusionDequantAddBNReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
// inference output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif
@@ -60,6 +60,9 @@ REGISTER_FUSION_MATCHER(fusion_fc, ops::FusionFcMatcher);
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_fc, ops::FusionFcOp);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(fusion_fc, ops::FusionFcOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
REGISTER_OPERATOR_MALI_GPU(fusion_fc, ops::FusionFcOp);
#endif
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#include "operators/kernel/dequant_add_bn_relu_kernel.h"
#include <cmath>
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#endif
namespace paddle_mobile {
namespace operators {
template <>
bool FusionDequantAddBNReluKernel<CPU, float>::Init(
FusionDequantAddBNReluParam<CPU> *param) {
// elementwise add params
const Tensor *bias = param->bias_;
// batch norm params
const Tensor *bn_mean = param->bn_mean_;
const Tensor *bn_variance = param->bn_variance_;
Tensor *bn_scale = param->bn_scale_;
Tensor *bn_bias = param->bn_bias_;
const float epsilon = param->epsilon_;
const float *bias_ptr = bias->data<float>();
const float *mean_ptr = bn_mean->data<float>();
const float *var_ptr = bn_variance->data<float>();
float *bn_scale_ptr = bn_scale->mutable_data<float>();
float *bn_bias_ptr = bn_bias->mutable_data<float>();
for (int c = 0; c < bn_scale->numel(); ++c) {
float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon));
bn_scale_ptr[c] = inv_scale;
bn_bias_ptr[c] = inv_scale * (bias_ptr[c] - mean_ptr[c]) + bn_bias_ptr[c];
}
return true;
}
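The Init() loop above folds the elementwise-add bias and the batch-norm statistics into a single per-channel scale and bias. A minimal standalone check of that identity (hypothetical values, not part of the kernel):

#include <cmath>
#include <cstdio>

// bn(x + bias) = gamma * (x + bias - mean) / sqrt(var + eps) + beta
//             = inv_scale * x + (inv_scale * (bias - mean) + beta)
int main() {
  const float eps = 1e-5f, var = 0.25f, mean = 0.1f;
  const float gamma = 2.0f, beta = 0.5f, bias = 0.3f, x = 1.7f;
  float unfolded = gamma * (x + bias - mean) / std::sqrt(var + eps) + beta;
  float inv_scale = gamma / std::sqrt(var + eps);     // new bn_scale_ptr[c]
  float new_bias = inv_scale * (bias - mean) + beta;  // new bn_bias_ptr[c]
  std::printf("unfolded=%f folded=%f\n", unfolded, inv_scale * x + new_bias);
  return 0;
}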
template <>
void FusionDequantAddBNReluKernel<CPU, float>::Compute(
const FusionDequantAddBNReluParam<CPU> &param) {
const int32_t *input = param.input_->data<int32_t>();
const float *bn_scale = param.bn_scale_->data<float>();
const float *bn_bias = param.bn_bias_->data<float>();
// dequantize params
const float activation_scale = param.activation_scale_->data<float>()[0];
const float weight_scale = param.weight_scale_;
const float dequant_scale = activation_scale / weight_scale;
float *output = param.output_->mutable_data<float>();
int batch_size = param.input_->dims()[0];
int channels = param.input_->dims()[1];
size_t spatial_size = param.input_->dims()[2] * param.input_->dims()[3];
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < batch_size; ++batch) {
for (int c = 0; c < channels; ++c) {
float scale = bn_scale[c] * dequant_scale;
float bias = bn_bias[c];
size_t offset = (batch * channels + c) * spatial_size;
const int32_t *x = input + offset;
float *y = output + offset;
size_t remain = spatial_size;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
int loop = spatial_size >> 4;
remain = spatial_size & 0xF;
float32x4_t __scale = vdupq_n_f32(scale);
float32x4_t __bias = vdupq_n_f32(bias);
float32x4_t __zero = vdupq_n_f32(0.f);
for (int k = 0; k < loop; ++k, x += 16, y += 16) {
int32x4_t r0 = vld1q_s32(x);
int32x4_t r1 = vld1q_s32(x + 4);
int32x4_t r2 = vld1q_s32(x + 8);
int32x4_t r3 = vld1q_s32(x + 12);
float32x4_t f0 = vcvtq_f32_s32(r0);
float32x4_t f1 = vcvtq_f32_s32(r1);
float32x4_t f2 = vcvtq_f32_s32(r2);
float32x4_t f3 = vcvtq_f32_s32(r3);
f0 = vmlaq_f32(__bias, __scale, f0);
f1 = vmlaq_f32(__bias, __scale, f1);
f2 = vmlaq_f32(__bias, __scale, f2);
f3 = vmlaq_f32(__bias, __scale, f3);
f0 = vmaxq_f32(__zero, f0);
f1 = vmaxq_f32(__zero, f1);
f2 = vmaxq_f32(__zero, f2);
f3 = vmaxq_f32(__zero, f3);
vst1q_f32(y, f0);
vst1q_f32(y + 4, f1);
vst1q_f32(y + 8, f2);
vst1q_f32(y + 12, f3);
}
#endif // __ARM_NEON__
for (int k = 0; k < remain; ++k) {
y[k] = std::max(scale * x[k] + bias, 0.f);
}
}
}
}
} // namespace operators
} // namespace paddle_mobile
#endif // FUSION_DEQUANT_ADD_BN_RELU_OP
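For the NEON block above: vmlaq_f32(a, b, c) computes a + b * c per lane and vmaxq_f32 applies the ReLU clamp, so each vector lane follows exactly the same path as the scalar tail loop. A scalar model of one lane, for reference:

#include <algorithm>
#include <cstdint>

// Scalar model of one NEON lane in Compute():
inline float dequant_bn_relu_lane(int32_t x, float scale, float bias) {
  float f = static_cast<float>(x);  // vcvtq_f32_s32
  f = bias + scale * f;             // vmlaq_f32(__bias, __scale, f)
  return std::max(f, 0.0f);         // vmaxq_f32(__zero, f)
}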
@@ -379,8 +379,8 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
const float *x3 = input3 + h * input_w;
int loop = input_w >> 4;
int remain = input_w & 0xF;
int pad_loop = paddings[1] >> 1;
int pad_remain = paddings[1] & 0x1;
int pad_loop = paddings[1] >> 1; // (paddings[1] << 1) >> 2
int pad_remain = (paddings[1] << 1) & 0x3;
int remain_steps = remain;
asm volatile(
"vdup.f32 q0, %[scale] \n"
@@ -596,7 +596,7 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
"store_pad_2w_%=: \n"
"cmp %[pad_remain], #2 \n"
"ble store_pad_1w_%= \n"
"blt store_pad_1w_%= \n"
"vst1.16 {d0[0]}, [%[y0]]! \n"
"vst1.16 {d0[0]}, [%[y1]]! \n"
"vst1.16 {d0[0]}, [%[y2]]! \n"
@@ -605,7 +605,7 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
"store_pad_1w_%=: \n"
"cmp %[pad_remain], #1 \n"
"ble end_%= \n"
"blt end_%= \n"
"vst1.8 {d0[0]}, [%[y0]]! \n"
"vst1.8 {d0[0]}, [%[y1]]! \n"
"vst1.8 {d0[0]}, [%[y2]]! \n"
@@ -669,8 +669,8 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
const float *x0 = input0 + h * input_w;
int loop = input_w >> 4;
int remain = input_w & 0xF;
int pad_loop = paddings[1] >> 1;
int pad_remain = paddings[1] & 0x1;
int pad_loop = paddings[1] >> 1; // (paddings[1] << 1) >> 2
int pad_remain = (paddings[1] << 1) & 0x3;
asm volatile(
"vdup.f32 q0, %[scale] \n"
"cmp %[loop], #0 \n"
@@ -754,14 +754,14 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
"pad_remain_%=: \n"
"cmp %[pad_remain], #2 \n"
"ble store_pad_1w_%= \n"
"blt store_pad_1w_%= \n"
"vst1.16 {d0[0]}, [%[y0]]! \n"
"sub %[pad_remain], #2 \n"
"store_pad_1w_%=: \n"
"cmp %[pad_remain], #1 \n"
"ble end_%= \n"
"vst1.8 {d0[0]}, [%[y0]]! \n"
"blt end_%= \n"
"vst1.8 {d0[0]}, [%[y0]]! \n"
"end_%=: \n"
: [x0] "+r"(x0), [y0] "+r"(y0), [loop] "+r"(loop),
[remain] "+r"(remain), [pad_loop] "+r"(pad_loop),
@@ -795,10 +795,10 @@ void QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) {
// only support int8 currently
float scale = 127 / max_abs;
param.online_scale_->mutable_data<float>()[0] = max_abs;
// const auto &paddings = param.paddings_;
std::vector<int> paddings = {0, 0};
// const auto padding_val = param.padding_val_;
int8_t padding_val = 127;
const auto &paddings = param.paddings_;
// std::vector<int> paddings = {0, 0};
// const auto padding_val = param.padding_val_;
int8_t padding_val = 0;
switch (param.round_type_) {
case ROUND_NEAREST_TO_EVEN:
quantize_round_to_even(input, scale, paddings, padding_val, output);
......
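Two behavioral changes land in this file: paddings are now read from param.paddings_ instead of being hard-coded to {0, 0}, and padding_val becomes 0, so padded cells dequantize back to 0.0 and behave like ordinary zero padding for the following int8 convolution. The pad_loop/pad_remain rewrite appears to count the 2 * paddings[1] int8 zeros written contiguously between rows (the right pad of one row plus the left pad of the next) in 4-byte chunks plus a 2-byte remainder, which matches the (paddings[1] << 1) comments and the changed blt branches. The output geometry this implies (assumption based on the QuantizeOp::InferShape() change later in this commit):

#include <cstdio>

// Sketch of the padded quantize output shape: H and W grow by 2 * pad.
int main() {
  int in_h = 224, in_w = 224;
  int pad_h = 1, pad_w = 1;  // paddings_[0], paddings_[1]
  std::printf("out: %d x %d, border filled with int8 0\n",
              in_h + 2 * pad_h, in_w + 2 * pad_w);
  return 0;
}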
@@ -13,7 +13,27 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
/*
__kernel void concatByC0(__read_only image2d_t input_image,
__write_only image2d_t output_image,
__private const int out_W) {
const int in_c = get_global_id(0);
const int in_w = get_global_id(1);
const int in_nh = get_global_id(2);
int2 input_pos ;
input_pos.x = in_c * out_W + in_w;
input_pos.y = in_nh;
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
half4 input;
input = read_imageh(input_image, sampler,input_pos);
write_imageh(output_image, input_pos, input);
}
__kernel void concatByC(__read_only image2d_t input_image1,
__read_only image2d_t input_image2,
@@ -24,13 +44,13 @@ __kernel void concatByC(__read_only image2d_t input_image1,
__private const int out_C_Start,
__private const int in_W,
__private const int in_H,
__private const int int_C1,
__private const int int_C2) {
__private const int in_C1,
__private const int in_C2) {
const int in_c = get_global_id(0);
const int in_w = get_global_id(1);
const int in_nh = get_global_id(2);
int out_c1 = (out_C_Start)/4 + in_c;
int out_c1 = (out_C_Start + 3)/4 -1 + in_c;
int out_c2 = out_c1 + 1;
@@ -45,7 +65,7 @@ __kernel void concatByC(__read_only image2d_t input_image1,
int2 input_pos1;
if(in_c==0){
input_pos1.x = ((in_C1-1)/4) * in_W + in_w;
input_pos1.x = ((in_C1 + 3)/4-1) * in_W + in_w;
}else{
input_pos1.x = (in_c - 1) * in_W + in_w;
}
@@ -103,26 +123,6 @@ __kernel void concatByC(__read_only image2d_t input_image1,
write_imageh(output_image, output_pos2, output2);
}
__kernel void concatByW0(__read_only image2d_t input_image,
__write_only image2d_t output_image,
__private const int out_W) {
const int in_c = get_global_id(0);
const int in_w = get_global_id(1);
const int in_nh = get_global_id(2);
int2 input_pos = in_c * out_W + in_w;
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
half4 input;
input = read_imageh(input_image, sampler,input_pos);
write_imageh(output_image, input_pos, input);
}
*/
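The index fixes in concatByC replace floor-style division with ceiling division: channels are packed four per half4 texel, so C channels occupy (C + 3) / 4 texels and the last occupied texel has index (C + 3) / 4 - 1. For example:

#include <cstdio>

int main() {
  for (int C = 1; C <= 8; ++C) {
    int texels = (C + 3) / 4;  // ceil(C / 4): 1, 1, 1, 1, 2, 2, 2, 2
    int last = texels - 1;     // index used for input_pos1.x above
    std::printf("C=%d texels=%d last=%d\n", C, texels, last);
  }
  return 0;
}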
__kernel void concatByH(__read_only image2d_t input_image,
__write_only image2d_t output_image,
......
@@ -692,6 +692,238 @@ __kernel void conv_1x1_4(__private const int global_size_dim0,
*/
__kernel void conv_7x7(__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#ifdef BIASE
__read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
__read_only image2d_t new_scale,
__read_only image2d_t new_biase,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int offset,
__private const int input_c,
__private const int dilation,
__private const int input_width,/* of one block */
__private const int input_height,/* of one block */
__private const int output_width,
__private const int output_height) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
if (out_c >= global_size_dim0 ||
out_w >= global_size_dim1 ||
out_nh >= global_size_dim2) {
return;
}
const int filter_n0 = 4 * out_c + 0;
const int filter_n1 = 4 * out_c + 1;
const int filter_n2 = 4 * out_c + 2;
const int filter_n3 = 4 * out_c + 3;
int2 stride_xy;
stride_xy.x = stride;
stride_xy.y = stride;
int2 output_pos_in_one_block;
output_pos_in_one_block.x = out_w;
output_pos_in_one_block.y = out_nh;
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
int2 in_pos_in_one_block;
in_pos_in_one_block.x = output_pos_in_one_block.x * stride + offset;
in_pos_in_one_block.y = output_pos_in_one_block.y * stride + offset;
#ifdef BIASE
half4 output = read_imageh(bias, sampler, (int2)(out_c, 0));
#else
half4 output = 0.0f;
#endif
half4 input;
half4 filter[4];
int2 filter_pos0;
int2 filter_pos1;
int2 filter_pos2;
int2 filter_pos3;
for (int i = 0; i < input_c; ++i) {
int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x, in_pos_in_one_block.y);
for(int j = 0; j < 7; j++){
for(int k = 0; k < 7; k++){
input = select(read_imageh(input_image, sampler,
(int2)(pos_in.x + (j - 3) * dilation, pos_in.y + (k - 3) * dilation)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x + (j - 3) * dilation < 0 || in_pos_in_one_block.y + (k - 3) * dilation < 0 || in_pos_in_one_block.x + (j - 3) * dilation >= input_width || in_pos_in_one_block.y + (k - 3) * dilation >= input_height) << 15));
int filter_h = k;
int filter_w = j;
int filter_c = i;
filter_pos0.x = filter_c * 7 + filter_w;
filter_pos0.y = filter_n0 * 7 + filter_h;
filter_pos1.x = filter_c * 7 + filter_w;
filter_pos1.y = filter_n1 * 7 + filter_h;
filter_pos2.x = filter_c * 7 + filter_w;
filter_pos2.y = filter_n2 * 7 + filter_h;
filter_pos3.x = filter_c * 7 + filter_w;
filter_pos3.y = filter_n3 * 7 + filter_h;
filter[0] = read_imageh(filter_image, sampler, filter_pos0);
filter[1] = read_imageh(filter_image, sampler, filter_pos1);
filter[2] = read_imageh(filter_image, sampler, filter_pos2);
filter[3] = read_imageh(filter_image, sampler, filter_pos3);
output.x += dot(input, filter[0]);
output.y += dot(input, filter[1]);
output.z += dot(input, filter[2]);
output.w += dot(input, filter[3]);
}
}
}
#ifdef BATCH_NORM
output = output * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + read_imageh(new_biase, sampler, (int2)(out_c, 0));
#endif
#ifdef RELU
output = activation(output);
#endif
write_imageh(output_image, (int2)(out_c * global_size_dim1 + out_w, out_nh), output);
}
__kernel void conv_5x5(__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#ifdef BIASE
__read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
__read_only image2d_t new_scale,
__read_only image2d_t new_biase,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int offset,
__private const int input_c,
__private const int dilation,
__private const int input_width,/* of one block */
__private const int input_height,/* of one block */
__private const int output_width,
__private const int output_height) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
if (out_c >= global_size_dim0 ||
out_w >= global_size_dim1 ||
out_nh >= global_size_dim2) {
return;
}
const int filter_n0 = 4 * out_c + 0;
const int filter_n1 = 4 * out_c + 1;
const int filter_n2 = 4 * out_c + 2;
const int filter_n3 = 4 * out_c + 3;
int2 stride_xy;
stride_xy.x = stride;
stride_xy.y = stride;
int2 output_pos_in_one_block;
output_pos_in_one_block.x = out_w;
output_pos_in_one_block.y = out_nh;
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
int2 in_pos_in_one_block;
in_pos_in_one_block.x = output_pos_in_one_block.x * stride + offset;
in_pos_in_one_block.y = output_pos_in_one_block.y * stride + offset;
#ifdef BIASE
half4 output = read_imageh(bias, sampler, (int2)(out_c, 0));
#else
half4 output = 0.0f;
#endif
half4 input;
half4 filter[4];
int2 filter_pos0;
int2 filter_pos1;
int2 filter_pos2;
int2 filter_pos3;
for (int i = 0; i < input_c; ++i) {
int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x, in_pos_in_one_block.y);
for(int j = 0; j < 5; j++){
for(int k = 0; k < 5; k++){
input = select(read_imageh(input_image, sampler,
(int2)(pos_in.x + (j - 2) * dilation, pos_in.y + (k - 2) * dilation)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x + (j - 2) * dilation < 0 || in_pos_in_one_block.y + (k - 2) * dilation < 0 || in_pos_in_one_block.x + (j - 2) * dilation >= input_width || in_pos_in_one_block.y + (k - 2) * dilation >= input_height) << 15));
int filter_h = k;
int filter_w = j;
int filter_c = i;
filter_pos0.x = filter_c * 5 + filter_w;
filter_pos0.y = filter_n0 * 5 + filter_h;
filter_pos1.x = filter_c * 5 + filter_w;
filter_pos1.y = filter_n1 * 5 + filter_h;
filter_pos2.x = filter_c * 5 + filter_w;
filter_pos2.y = filter_n2 * 5 + filter_h;
filter_pos3.x = filter_c * 5 + filter_w;
filter_pos3.y = filter_n3 * 5 + filter_h;
filter[0] = read_imageh(filter_image, sampler, filter_pos0);
filter[1] = read_imageh(filter_image, sampler, filter_pos1);
filter[2] = read_imageh(filter_image, sampler, filter_pos2);
filter[3] = read_imageh(filter_image, sampler, filter_pos3);
output.x += dot(input, filter[0]);
output.y += dot(input, filter[1]);
output.z += dot(input, filter[2]);
output.w += dot(input, filter[3]);
}
}
}
#ifdef BATCH_NORM
output = output * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + read_imageh(new_biase, sampler, (int2)(out_c, 0));
#endif
#ifdef RELU
output = activation(output);
#endif
write_imageh(output_image, (int2)(out_c * global_size_dim1 + out_w, out_nh), output);
}
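Both new kernels assume the usual CLImage packing: four consecutive channels share one half4 texel, the output texel for channel block c at (w, nh) sits at x = c * W + w, and a K x K filter stores tap (n, c, h, w) at x = c * K + w, y = n * K + h, which is exactly the filter_pos arithmetic above. A host-side sketch of these mappings (assumed, for illustration only):

#include <cstdio>

// Assumed CLImage coordinate mappings mirroring the kernel indexing above.
void output_coord(int c_block, int w, int nh, int W, int *x, int *y) {
  *x = c_block * W + w;  // channel block and width interleave on x
  *y = nh;               // batch and height interleave on y
}

void filter_coord(int n, int c, int h, int w, int K, int *x, int *y) {
  *x = c * K + w;  // filter_pos.x
  *y = n * K + h;  // filter_pos.y
}

int main() {
  int x, y;
  filter_coord(/*n=*/1, /*c=*/2, /*h=*/3, /*w=*/4, /*K=*/7, &x, &y);
  std::printf("filter tap at (%d, %d)\n", x, y);  // (18, 10)
  return 0;
}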
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void lrn(__read_only image2d_t input_image,
__write_only image2d_t output_image,
__private const int out_C,
__private const int out_W,
__private const int n,
__private const float k,
__private const float alpha,
__private const float beta){
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
const int out_c0 = out_c * 4;
const int out_c1 = out_c * 4 + 1;
const int out_c2 = out_c * 4 + 2;
const int out_c3 = out_c * 4 + 3;
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
const int start = -(n-1)/2;
const int end = start + n;
float sqr_sum0 = 0.0f;
float sqr_sum1 = 0.0f;
float sqr_sum2 = 0.0f;
float sqr_sum3 = 0.0f;
int input_c0,input_c1,input_c2,input_c3;
int2 input_pos0,input_pos1,input_pos2,input_pos3;
float4 input0,input1,input2,input3;
for(int i = start; i < end ;i++){
if(out_c0 + i>=0&&out_c0 + i<out_C){
input_c0 = (out_c0 + i)/4;
input_pos0.x = input_c0 * out_W + out_w;
input_pos0.y = out_nh;
input0 = convert_float4(read_imageh(input_image, sampler,input_pos0));
if((out_c0 + i)%4 == 0){
sqr_sum0 += input0.x * input0.x;
}else if((out_c0 + i)%4 == 1){
sqr_sum0 += input0.y * input0.y;
}else if((out_c0 + i)%4 == 2){
sqr_sum0 += input0.z * input0.z;
}else{
sqr_sum0 += input0.w * input0.w;
}
}
if(out_c1 + i>=0&&out_c1 + i<out_C){
input_c1 = (out_c1 + i)/4;
input_pos1.x = input_c1 * out_W + out_w;
input_pos1.y = out_nh;
input1 = convert_float4(read_imageh(input_image, sampler,input_pos1));
if((out_c1 + i)%4 == 0){
sqr_sum1 += input1.x * input1.x;
}else if((out_c1 + i)%4 == 1){
sqr_sum1 += input1.y * input1.y;
}else if((out_c1 + i)%4 == 2){
sqr_sum1 += input1.z * input1.z;
}else{
sqr_sum1 += input1.w * input1.w;
}
}
if(out_c2 + i>=0&&out_c2 + i<out_C){
input_c2 = (out_c2 + i)/4;
input_pos2.x = input_c2 * out_W + out_w;
input_pos2.y = out_nh;
input2 = convert_float4(read_imageh(input_image, sampler,input_pos2));
if((out_c2 + i)%4 == 0){
sqr_sum2 += input2.x * input2.x;
}else if((out_c2 + i)%4 == 1){
sqr_sum2 += input2.y * input2.y;
}else if((out_c2 + i)%4 == 2){
sqr_sum2 += input2.z * input2.z;
}else{
sqr_sum2 += input2.w * input2.w;
}
}
if(out_c3 + i>=0&&out_c3 + i<out_C){
input_c3 = (out_c3 + i)/4;
input_pos3.x = input_c3 * out_W + out_w;
input_pos3.y = out_nh;
input3 = convert_float4(read_imageh(input_image, sampler,input_pos3));
if((out_c3 + i)%4 == 0){
sqr_sum3 += input3.x * input3.x;
}else if((out_c3 + i)%4 == 1){
sqr_sum3 += input3.y * input3.y;
}else if((out_c3 + i)%4 == 2){
sqr_sum3 += input3.z * input3.z;
}else{
sqr_sum3 += input3.w * input3.w;
}
}
}
float4 output = (float4)0.0f;
float4 input;
int2 output_pos;
output_pos.x = out_c * out_W + out_w;
output_pos.y = out_nh;
input = convert_float4(read_imageh(input_image, sampler,output_pos));
output.x = input.x / (pow(k + alpha * (sqr_sum0),beta));
if(out_C - 4 * out_c>=2){
output.y = input.y / (pow(k + alpha * (sqr_sum1),beta));
}
if(out_C - 4 * out_c>=3){
output.z = input.z / (pow(k + alpha * (sqr_sum2),beta));
}
if(out_C - 4 * out_c>=4){
output.w = input.w / (pow(k + alpha * (sqr_sum3),beta));
}
half4 tmp = convert_half4(output);
write_imageh(output_image, output_pos, tmp);
}
\ No newline at end of file
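In formula form, the kernel above computes across-channel LRN; for each spatial position and output channel c, with an odd window size n,

b_c = a_c / ( k + \alpha \sum_{c' = \max(0,\, c - (n-1)/2)}^{\min(C-1,\, c + (n-1)/2)} a_{c'}^2 )^{\beta}

which matches the start = -(n-1)/2, end = start + n loop together with the per-channel bounds checks.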
@@ -31,11 +31,13 @@ __kernel void pool_max(
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int start_h = max(out_h * stride_h - pad_top, 0);
int start_h = out_h * stride_h - pad_top;
int end_h = min(start_h + ksize_h, in_height);
start_h = max(start_h,0);
int start_w = max(out_w * stride_w - pad_left, 0);
int start_w = out_w * stride_w - pad_left;
int end_w = min(start_w + ksize_w, in_width);
start_w = max(start_w,0);
const int pos_in_x = out_c * in_width;
const int pos_in_y = out_n * in_height;
......
@@ -23,12 +23,17 @@ template <>
bool ConcatKernel<GPU_CL, float>::Init(ConcatParam<GPU_CL> *param) {
if (param->Out()->dims().size() < 4) {
this->cl_helper_.AddKernel("concatByH", "concat_kernel.cl");
} else if (param->Out()->dims().size() == 4) {
this->cl_helper_.AddKernel("concatByC0", "concat_kernel.cl");
this->cl_helper_.AddKernel("concatByC", "concat_kernel.cl");
}
return true;
}
template <>
void ConcatKernel<GPU_CL, float>::Compute(const ConcatParam<GPU_CL> &param) {
DLOG << "yangfei50";
DLOG << param.Out()->dims();
if (param.Out()->dims().size() < 4) {
auto kernel = this->cl_helper_.KernelAt(0);
auto inputs = param.Inputs();
@@ -62,6 +67,76 @@ void ConcatKernel<GPU_CL, float>::Compute(const ConcatParam<GPU_CL> &param) {
out_H_Start += inputs[i]->dims()[0];
}
}
} else {
auto kernel0 = this->cl_helper_.KernelAt(0);
auto kernel1 = this->cl_helper_.KernelAt(1);
auto inputs = param.Inputs();
auto *output_image = param.Out()->GetCLImage();
int out_C_Start = 0;
auto input_image = inputs[0]->GetCLImage();
auto default_work_size = this->cl_helper_.DefaultWorkSize(*inputs[0]);
int out_W = param.Out()->dims()[3];
cl_int status;
status = clSetKernelArg(kernel0, 0, sizeof(cl_mem), &input_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel0, 1, sizeof(cl_mem), &output_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel0, 2, sizeof(int), &out_W);
CL_CHECK_ERRORS(status);
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel0, default_work_size.size(),
NULL, default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
out_C_Start += inputs[0]->dims()[1];
for (int i = 1; i < inputs.size(); i++) {
auto input_image1 = inputs[i - 1]->GetCLImage();
auto input_image2 = inputs[i]->GetCLImage();
default_work_size = this->cl_helper_.DefaultWorkSize(*inputs[i]);
int out_C = param.Out()->dims()[1];
int out_H = param.Out()->dims()[2];
int in_W = inputs[i]->dims()[3];
int in_H = inputs[i]->dims()[2];
int in_C1 = inputs[i - 1]->dims()[1];
int in_C2 = inputs[i]->dims()[1];
DLOG << "第" << i << "个";
DLOG << "out_C=" << out_C;
DLOG << "out_H=" << out_H;
DLOG << "in_W=" << in_W;
DLOG << "in_H=" << in_H;
DLOG << "in_C1=" << in_C1;
DLOG << "in_C2=" << in_C2;
DLOG << "out_C_Start = " << out_C_Start;
status = clSetKernelArg(kernel1, 0, sizeof(cl_mem), &input_image1);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel1, 1, sizeof(cl_mem), &input_image2);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel1, 2, sizeof(cl_mem), &output_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel1, 3, sizeof(int), &out_C);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel1, 4, sizeof(int), &out_H);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel1, 5, sizeof(int), &out_W);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel1, 6, sizeof(int), &out_C_Start);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel1, 7, sizeof(int), &in_W);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel1, 8, sizeof(int), &in_H);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel1, 9, sizeof(int), &in_C1);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel1, 10, sizeof(int), &in_C2);
CL_CHECK_ERRORS(status);
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel1, default_work_size.size(),
NULL, default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
out_C_Start += inputs[i]->dims()[1];
}
}
}
......
@@ -51,8 +51,16 @@ bool ConvAddKernel<GPU_CL, float>::Init(FusionConvAddParam<GPU_CL> *param) {
this->cl_helper_.AddKernel("conv_3x3", "conv_add_kernel.cl");
} else {
PADDLE_MOBILE_THROW_EXCEPTION(" not support ");
} else if (param->Filter()->dims()[2] == 7 &&
param->Filter()->dims()[3] == 7) {
param->Filter()->InitCLImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("conv_7x7", "conv_add_kernel.cl");
} else if (param->Filter()->dims()[2] == 5 &&
param->Filter()->dims()[3] == 5) {
param->Filter()->InitCLImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("conv_5x5", "conv_add_kernel.cl");
}
return true;
......
@@ -52,6 +52,16 @@ bool ConvAddReluKernel<GPU_CL, float>::Init(
this->cl_helper_.AddKernel("conv_3x3", "conv_add_relu_kernel.cl");
} else if (param->Filter()->dims()[2] == 7 &&
param->Filter()->dims()[3] == 7) {
param->Filter()->InitCLImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("conv_7x7", "conv_add_relu_kernel.cl");
} else if (param->Filter()->dims()[2] == 5 &&
param->Filter()->dims()[3] == 5) {
param->Filter()->InitCLImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("conv_5x5", "conv_add_relu_kernel.cl");
} else {
PADDLE_MOBILE_THROW_EXCEPTION(" not support ");
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_FC_OP
#include "operators/kernel/fusion_fc_kernel.h"
#include "operators/math/math_function.h"
namespace paddle_mobile {
namespace operators {
template <>
bool FusionFcKernel<GPU_CL, float>::Init(FusionFcParam<GPU_CL> *param) {
param->InputY()->InitNormalCLImage(cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
param->InputZ()->InitNormalCLImage(cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl");
this->cl_helper_.AddKernel("feed", "feed_kernel.cl");
return true;
}
template <typename P>
void FusionFcCompute(const FusionFcParam<GPU_CL> &param, cl_context context,
cl_command_queue commandQueue, cl_kernel kernel0,
cl_kernel kernel1) {
auto *input_x_image = param.InputX();
auto *input_y_image = param.InputY();
auto *input_z_image = param.InputZ();
int axis = param.Axis();
auto *out_image = param.Out();
Tensor *input_x = new Tensor();
input_x->Resize(input_x_image->dims());
input_x->mutable_data<float>();
framework::CLImageToTensor(input_x_image, input_x, context, commandQueue,
kernel0);
Tensor *input_y = new Tensor();
input_y->Resize(input_y_image->dims());
input_y->mutable_data<float>();
framework::CLImageToTensor(input_y_image, input_y, context, commandQueue,
kernel0);
Tensor *input_z = new Tensor();
input_z->Resize(input_z_image->dims());
input_z->mutable_data<float>();
framework::CLImageToTensor(input_z_image, input_z, context, commandQueue,
kernel0);
auto *input_z_data = input_z->data<float>();
DLOG << *input_x;
DLOG << *input_y;
DLOG << *input_z;
Tensor *out = new Tensor();
out->Resize(out_image->dims());
out->mutable_data<float>();
auto *out_data = out->mutable_data<float>();
const Tensor x_matrix =
input_x->dims().size() > 2
? framework::ReshapeToMatrix(*input_x, param.XNumColDims())
: *input_x;
const Tensor y_matrix =
input_y->dims().size() > 2
? framework::ReshapeToMatrix(*input_y, param.YNumColDims())
: *input_y;
auto out_dim = out->dims();
if (out_dim.size() != 2) {
out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
}
PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2.");
PADDLE_MOBILE_ENFORCE(input_z->dims().size() == 1, "input_z dims size must be 1");
PADDLE_MOBILE_ENFORCE(out_dim[1] == input_z->dims()[0],
" out_dim[1] must equal input_z->dims()[0]. ");
axis = (axis == -1 ? out_dim.size() - input_z->dims().size() : axis);
PADDLE_MOBILE_ENFORCE(axis == 1, " to fit broadcast, axis = 1. ");
int64_t classes = input_z->numel();
for (int i = 0; i < out_dim[0]; i++) {
memory::Copy(out_data + i * classes, input_z_data, sizeof(float) * classes);
}
// for (int i = 0; i < out->numel(); i++) {
// DLOG << out_data[i];
// }
// bias_data has the same dimensions as out
math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(1), false);
out_image->InitEmptyImage(context, commandQueue, out->dims());
framework::TensorToCLImage(out, out_image, context, commandQueue, kernel1);
DLOG << *out;
delete (input_x);
delete (input_y);
delete (input_z);
delete (out);
PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2.");
// if (out_dim.size() != 2) {
// out->Resize(out_dim);
// }
}
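Net effect of the block above: both CLImage operands are staged back to host tensors, the 1-D bias z is copied into every row of out, and math::matmul accumulates with alpha = beta = 1, i.e.

Out = X_{2d} \, Y_{2d} + \mathbf{1} \, z^{\top}

where X_{2d} and Y_{2d} are the inputs reshaped via ReshapeToMatrix and z is broadcast along axis 1.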
template <>
void FusionFcKernel<GPU_CL, float>::Compute(
const FusionFcParam<GPU_CL> &param) {
auto kernel0 = this->cl_helper_.KernelAt(0);
auto kernel1 = this->cl_helper_.KernelAt(1);
FusionFcCompute<float>(param, this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue(), kernel0, kernel1);
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef LRN_OP
#include "operators/kernel/lrn_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool LrnKernel<GPU_CL, float>::Init(LrnParam<GPU_CL> *param) {
this->cl_helper_.AddKernel("lrn", "lrn_kernel.cl");
return true;
}
template <>
void LrnKernel<GPU_CL, float>::Compute(const LrnParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Out());
auto input_image = param.InputX()->GetCLImage();
auto x_dims = param.InputX()->dims();
auto output_image = param.Out()->GetCLImage();
const int N = x_dims[0];
const int C = x_dims[1];
const int H = x_dims[2];
const int W = x_dims[3];
const int n = param.N();
const float alpha = param.Alpha();
const float beta = param.Beta();
const float k = param.K();
DLOG << "n=" << n;
DLOG << "alpha=" << alpha;
DLOG << "beta=" << beta;
DLOG << "k=" << k;
DLOG << default_work_size;
DLOG << C;
DLOG << W;
cl_int status;
status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &input_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &output_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &C);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(int), &W);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(int), &n);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(float), &k);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 6, sizeof(float), &alpha);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 7, sizeof(float), &beta);
CL_CHECK_ERRORS(status);
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class FusionDequantAddBNReluKernel
: public framework::OpKernelBase<DeviceType,
FusionDequantAddBNReluParam<DeviceType>> {
public:
void Compute(const FusionDequantAddBNReluParam<DeviceType> &param);
bool Init(FusionDequantAddBNReluParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
@@ -58,7 +58,7 @@ bool ConvAddBNKernel<FPGA, float>::Init(FusionConvAddBNParam<FPGA> *param) {
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
fpga::format_conv_data(filter, out, bs_ptr, param->Groups());
fpga::format_conv_data(filter, out, &bs_ptr, param->Groups());
fpga::SplitConvArgs conv_arg = {0};
fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
......
@@ -56,7 +56,7 @@ bool ConvAddBNReluKernel<FPGA, float>::Init(
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
fpga::format_conv_data(filter, out, bs_ptr, param->Groups());
fpga::format_conv_data(filter, out, &bs_ptr, param->Groups());
fpga::SplitConvArgs conv_arg = {0};
fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
......
@@ -38,7 +38,7 @@ bool ConvAddKernel<FPGA, float>::Init(FusionConvAddParam<FPGA> *param) {
bs_ptr[i] = bias_ptr[i];
}
fpga::format_conv_data(filter, out, bs_ptr, param->Groups());
fpga::format_conv_data(filter, out, &bs_ptr, param->Groups());
fpga::SplitConvArgs conv_arg = {0};
fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
......
@@ -38,7 +38,7 @@ bool ConvAddReluKernel<FPGA, float>::Init(FusionConvAddReluParam<FPGA> *param) {
bs_ptr[i] = bias_ptr[i];
}
fpga::format_conv_data(filter, out, bs_ptr, param->Groups());
fpga::format_conv_data(filter, out, &bs_ptr, param->Groups());
fpga::SplitConvArgs conv_arg = {0};
fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
......
@@ -50,7 +50,7 @@ bool ConvBNKernel<FPGA, float>::Init(FusionConvBNParam<FPGA> *param) {
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
fpga::format_conv_data(filter, out, bs_ptr, param->Groups());
fpga::format_conv_data(filter, out, &bs_ptr, param->Groups());
fpga::SplitConvArgs conv_arg = {0};
fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
......
@@ -15,6 +15,7 @@ limitations under the License. */
#ifdef FUSION_CONVBNRELU_OP
#include "operators/kernel/conv_bn_relu_kernel.h"
#include "fpga/V2/filter.h"
namespace paddle_mobile {
namespace operators {
@@ -50,7 +51,7 @@ bool ConvBNReluKernel<FPGA, float>::Init(FusionConvBNReluParam<FPGA> *param) {
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
fpga::format_conv_data(filter, out, bs_ptr, param->Groups());
fpga::format_conv_data(filter, out, &bs_ptr, param->Groups());
fpga::SplitConvArgs conv_arg = {0};
fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
......
@@ -14,7 +14,7 @@ limitations under the License. */
#ifdef LRN_OP
#include "lrn_op.h"
#include "operators/lrn_op.h"
namespace paddle_mobile {
namespace operators {
@@ -32,6 +32,9 @@ namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(lrn, ops::LrnOp);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(lrn, ops::LrnOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
REGISTER_OPERATOR_MALI_GPU(lrn, ops::LrnOp);
#endif
......
@@ -1631,11 +1631,11 @@ class FusionFcParam : public OpParam {
y_num_col_dims_ = GetAttr<int>("y_num_col_dims", attrs);
axis_ = GetAttr<int>("axis", attrs);
}
const GType *InputX() const { return input_x_; }
GType *InputX() const { return input_x_; }
const RType *InputY() const { return input_y_; }
RType *InputY() const { return input_y_; }
const RType *InputZ() const { return input_z_; }
RType *InputZ() const { return input_z_; }
GType *Out() const { return out_; }
@@ -2555,7 +2555,7 @@ class QuantizeParam : public OpParam {
output_ = OutFrom<GType>(outputs, scope);
// online
// scale = max(abs(x))
online_scale_ = GetVarValue<GType>("OutScale", outputs, scope);
online_scale_ = OpParam::GetVarValue<GType>("OutScale", outputs, scope);
// offline
if (HasAttr("static_scale", attrs)) {
is_static_ = true;
@@ -2565,6 +2565,11 @@ class QuantizeParam : public OpParam {
if (HasAttr("round_type", attrs)) {
round_type_ = GetAttr<RoundType>("round_type", attrs);
}
// get paddings
paddings_ = std::vector<int>({0, 0});
if (HasAttr("paddings", attrs)) {
paddings_ = GetAttr<vector<int>>("paddings", attrs);
}
}
public:
@@ -2598,7 +2603,7 @@ class DequantizeParam : public OpParam {
const AttributeMap &attrs, const Scope &scope) {
input_ = InputXFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
activation_scale_ = GetVarValue<GType>("Scale", inputs, scope);
activation_scale_ = OpParam::GetVarValue<GType>("Scale", inputs, scope);
// dequantization is performed as x = x / static_scale / online_scale
if (HasAttr("weight_scale", attrs)) {
weight_scale_ = GetAttr<float>("weight_scale", attrs);
@@ -2617,5 +2622,44 @@
};
#endif
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
template <typename Dtype>
class FusionDequantAddBNReluParam : public DequantizeParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionDequantAddBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: DequantizeParam<Dtype>(inputs, outputs, attrs, scope) {
// element wise add params
axis_ = OpParam::GetAttr<int>("axis", attrs);
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
// batch norm params
bn_mean_ = OpParam::GetVarValue<GType>("BNMean", inputs, scope);
bn_variance_ = OpParam::GetVarValue<GType>("BNVariance", inputs, scope);
bn_scale_ = OpParam::GetVarValue<GType>("BNScale", inputs, scope);
bn_bias_ = OpParam::GetVarValue<GType>("BNBias", inputs, scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
// output
output_ = OpParam::OutFrom<GType>(outputs, scope);
}
public:
// elementwise add
int axis_;
RType *bias_;
// batch norm
RType *bn_mean_;
RType *bn_variance_;
RType *bn_scale_;
RType *bn_bias_;
float epsilon_;
// output
RType *output_;
};
#endif
} // namespace operators
} // namespace paddle_mobile
@@ -22,7 +22,10 @@ namespace operators {
template <typename DeviceType, typename T>
void QuantizeOp<DeviceType, T>::InferShape() const {
const auto &input_dims = this->param_.input_->dims();
auto input_dims = this->param_.input_->dims();
const std::vector<int> &paddings = this->param_.paddings_;
input_dims[2] += 2 * paddings[0];
input_dims[3] += 2 * paddings[1];
this->param_.output_->Resize(input_dims);
auto scale_dims = framework::make_ddim(std::vector<int>{1});
this->param_.online_scale_->Resize(scale_dims);
......
@@ -12,58 +12,131 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "../test_helper.h"
#include "../test_include.h"
#include "operators/quantize_op.h"
namespace paddle_mobile {
static float find_abs_max(const Tensor *input) {
float max_abs = 0.f;
const float *x = input->data<const float>();
size_t size = input->numel();
for (size_t i = 0; i < size; ++i) {
float value = std::abs(x[i]);
if (value > max_abs) {
max_abs = value;
}
}
return max_abs;
namespace round {
enum RoundType {
RoundToEven = 0,
RoundAwayZero = 1,
RoundTowardsZero = 2,
};
}
static void quantize_round_to_even(const Tensor *input, const float scale,
Tensor *output) {
const float *x = input->data<const float>();
int8_t *y = output->mutable_data<int8_t>();
size_t size = input->numel();
for (size_t i = 0; i < size; ++i) {
float value = x[i] * scale;
float v = round(value);
template <round::RoundType T>
struct Round {
int8_t operator()(float x);
};
template <>
struct Round<round::RoundAwayZero> {
int8_t operator()(float x) { return std::round(x); }
};
template <>
struct Round<round::RoundTowardsZero> {
int8_t operator()(float x) { return int8_t(x); }
};
template <>
struct Round<round::RoundToEven> {
int8_t operator()(float x) {
int8_t ret = 0;
float v = std::round(x);
int32_t q = (int32_t)v;
if (abs(abs(q - value) - 0.5) > 0) {
y[i] = q;
if (std::abs(std::abs(q - x) - 0.5) > 0) {
ret = q;
} else {
if (abs(q) % 2 == 0) {
y[i] = q;
ret = q;
} else {
y[i] = q + ((q > 0) ? -1 : 1);
ret = q + ((q > 0) ? -1 : 1);
}
}
return ret;
}
};
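A few sample values for the three modes (ties go to the even neighbor only for RoundToEven); a small usage sketch assuming the functors above are in scope, with a hypothetical demo_round helper:

#include <cstdio>

int demo_round() {
  std::printf("%d %d %d\n",
              Round<round::RoundToEven>()(2.5f),    // 2
              Round<round::RoundToEven>()(3.5f),    // 4
              Round<round::RoundToEven>()(-2.5f));  // -2
  std::printf("%d %d\n",
              Round<round::RoundAwayZero>()(2.5f),      // 3
              Round<round::RoundTowardsZero>()(2.5f));  // 2
  return 0;
}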
template <round::RoundType T>
static void quantize(const Tensor *input, const float scale, const int pad,
const int8_t pad_val, Tensor *output) {
int batch_size = input->dims()[0];
int channels = input->dims()[1];
int input_h = input->dims()[2];
int input_w = input->dims()[3];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
size_t input_spatial = input_h * input_w;
size_t output_spatial = output_h * output_w;
const float *x = input->data<const float>();
int8_t *y = output->mutable_data<int8_t>();
for (int nc = 0; nc < batch_size * channels; ++nc) {
const float *xh = x + nc * input_spatial;
int8_t *yh = y + nc * output_spatial;
// pad top
for (int h = 0; h < pad; ++h, yh += output_w) {
for (int w = 0; w < output_w; ++w) {
yh[w] = pad_val;
}
}
for (int h = 0; h < input_h; ++h, yh += output_w, xh += input_w) {
// pad left
for (int w = 0; w < pad; ++w) {
yh[w] = pad_val;
}
for (int w = 0; w < input_w; ++w) {
yh[w + pad] = Round<T>()(xh[w] * scale);
}
// pad right
for (int w = 0; w < pad; ++w) {
yh[pad + input_w + w] = pad_val;
}
}
// pad bottom
for (int h = 0; h < pad; ++h, yh += output_w) {
for (int w = 0; w < output_w; ++w) {
yh[w] = pad_val;
}
}
}
}
static void quantize_round_to_nearest(const Tensor *input, const float scale,
Tensor *output) {
static float find_abs_max(const Tensor *input) {
float max_abs = 0.f;
const float *x = input->data<const float>();
int8_t *y = output->mutable_data<int8_t>();
size_t size = input->numel();
for (size_t i = 0; i < size; ++i) {
y[i] = round(x[i] * scale);
float value = std::abs(x[i]);
if (value > max_abs) {
max_abs = value;
}
}
return max_abs;
}
int TestQuqntizeOp() {
framework::DDim dim = framework::make_ddim({1, 3, 224, 224});
int TestQuqntizeOp(int argc, char *argv[]) {
if (argc < 5) {
std::cout
<< "Usage: ./test-quantize-op batch_size channel height width [pad]"
<< std::endl;
return 1;
}
int pad = 0;
int batch_size = atoi(argv[1]);
int channel = atoi(argv[2]);
int height = atoi(argv[3]);
int width = atoi(argv[4]);
if (argc == 6) {
pad = atoi(argv[5]);
}
std::cout << "batch_size: " << batch_size << ", channel: " << channel
<< ", height: " << height << ", width: " << width << std::endl;
framework::DDim dim =
framework::make_ddim({batch_size, channel, height, width});
VariableNameMap inputs;
VariableNameMap outputs;
@@ -80,6 +153,7 @@ int TestQuqntizeOp() {
auto output_scale_var = scope.get()->Var("output_scale");
framework::AttributeMap attrs;
attrs["paddings"].Set<vector<int>>(std::vector<int>({pad, pad}));
auto *op = new operators::QuantizeOp<CPU, float>("quantize", inputs, outputs,
attrs, scope);
op->InferShape();
@@ -96,10 +170,11 @@ int TestQuqntizeOp() {
output_scale_cmp, output_scale_data[0]);
framework::Tensor output_cmp;
output_cmp.Resize(dim);
output_cmp.Resize(output->dims());
float scale = 127 / output_scale_cmp;
// quantize_round_to_even(input, scale, &output_cmp);
quantize_round_to_nearest(input, scale, &output_cmp);
// quantize<round::RoundToEven>(input, scale, pad, 0, &output_cmp);
// quantize<round::RoundAwayZero>(input, scale, pad, 0, &output_cmp);
quantize<round::RoundTowardsZero>(input, scale, pad, 0, &output_cmp);
int8_t *output_cmp_data = output_cmp.data<int8_t>();
for (int i = 0; i < output->numel(); ++i) {
PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i],
@@ -113,4 +188,6 @@ int TestQuqntizeOp() {
} // namespace paddle_mobile
int main() { return paddle_mobile::TestQuqntizeOp(); }
int main(int argc, char *argv[]) {
return paddle_mobile::TestQuqntizeOp(argc, argv);
}
@@ -250,6 +250,7 @@ if(NOT FOUND_MATCH)
set(SUM_OP ON)
set(QUANT_OP ON)
set(DEQUANT_OP ON)
set(FUSION_DEQUANT_ADD_BN_RELU ON)
endif()
# option(BATCHNORM_OP "" ON)
@@ -454,6 +455,9 @@ endif()
if (DEQUANT_OP)
add_definitions(-DDEQUANT_OP)
endif()
if (FUSION_DEQUANT_ADD_BN_RELU)
add_definitions(-DFUSION_DEQUANT_ADD_BN_RELU_OP)
endif()
if (TANH_OP)
add_definitions(-DTANH_OP)
@@ -466,4 +470,4 @@ if (FUSION_DECONVADD_OP)
endif()
if (FUSION_DECONVADDRELU_OP)
add_definitions(-DFUSION_DECONVADDRELU_OP)
endif()
\ No newline at end of file
endif()