Unverified commit 8e11ee09, authored by R Ray Liu, committed by GitHub

Merge branch 'develop' into fill_constant_op-dev

......@@ -12,14 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cstdlib>
#include <cstring>
#include <string>
#include "common/enforce.h"
#include "common/log.h"
#pragma once
namespace paddle_mobile {
template <int ID, typename Type>
struct IDToType {
typedef Type type_t;
......
......@@ -156,7 +156,7 @@ class AttrReader {
template <typename T>
inline T Get(const string &name) const {
PADDLE_MOBILE_ENFORCE(attrs_.count(name) != 0,
"%s should be in AttributeMap", name);
"%s should be in AttributeMap", name.c_str());
return ((Attribute)attrs_.at(name)).Get<T>();
}
......
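The .c_str() fix above matters because the "%s" placeholder suggests PADDLE_MOBILE_ENFORCE forwards its message to a printf-style formatter, and passing a std::string object through C varargs is undefined behavior. A minimal standalone illustration (hypothetical, not the macro's actual implementation):

#include <cstdio>
#include <string>

int main() {
  std::string name = "axis";
  // std::printf("%s should be in AttributeMap", name);        // UB: non-trivial type through varargs
  std::printf("%s should be in AttributeMap", name.c_str());   // fixed: pass the raw C string
  return 0;
}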
......@@ -18,9 +18,9 @@ limitations under the License. */
#include <vector>
#include "framework/lod_tensor.h"
#include "framework/mixed_vector.h"
#include "framework/tensor.h"
#include "memory/t_malloc.h"
#include "mixed_vector.h"
namespace paddle_mobile {
namespace framework {
......
......@@ -343,7 +343,9 @@ inline Print &operator<<(Print &printer, const Tensor &tensor) {
} else if (tensor.type() == typeid(int64_t)) {
printer << tensor.data<int64_t>()[i] << " ";
} else if (tensor.type() == typeid(int8_t)) {
printer << static_cast<int32_t>(tensor.data<int8_t>()[i]) << " ";
printer << static_cast<int>(tensor.data<int8_t>()[i]) << " ";
} else if (tensor.type() == typeid(int32_t)) {
printer << tensor.data<int32_t>()[i] << " ";
}
}
#endif
......
......@@ -80,12 +80,13 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size,
}
template <typename Dtype>
void LoadMemInternal(void **data, framework::LoDTensor *tensor) {
static void LoadMemInternal(void **data, framework::LoDTensor *tensor,
bool quant_uint8 = false) {
char **data_buf = reinterpret_cast<char **>(data);
int64_t size = tensor->numel();
Dtype *tensor_data = tensor->mutable_data<Dtype>();
if (0) {
// TODO(hjchen2) should be moved into operator init function
if (quant_uint8) {
// should be moved into operator init function
float min_value;
float max_value;
memcpy(&min_value, data_buf, sizeof(float));
......@@ -141,7 +142,8 @@ void Executor<Dtype, P>::LoadMemory(
// parse tensor from stream
switch (tensor_desc.DataType()) {
case framework::VARTYPE_TYPE_FP32:
LoadMemInternal<float>(reinterpret_cast<void **>(data_buf), tensor);
LoadMemInternal<float>(reinterpret_cast<void **>(data_buf), tensor,
program_.quantification);
break;
case framework::VARTYPE_TYPE_INT8:
LoadMemInternal<int8_t>(reinterpret_cast<void **>(data_buf), tensor);
......
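The new quant_uint8 flag threads program_.quantification into weight loading so that uint8-quantized FP32 weights can be expanded back to float while the model is read (the hunk above shows the min/max header being parsed). A minimal sketch of that expansion, assuming a linear [min, max] quantization scheme; the helper name is illustrative:

#include <cstdint>

static void DequantizeUint8Weights(const uint8_t *src, float min_value,
                                   float max_value, int64_t size, float *dst) {
  // assumed linear scheme: q in [0, 255] maps back to [min_value, max_value]
  const float factor = (max_value - min_value) / 255.f;
  for (int64_t i = 0; i < size; ++i) {
    dst[i] = min_value + factor * static_cast<float>(src[i]);
  }
}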
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef DEQUANT_OP
#include "operators/dequantize_op.h"
namespace paddle_mobile {
......@@ -30,3 +32,5 @@ namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(dequantize, ops::DequantizeOp);
#endif
#endif
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef DEQUANT_OP
#pragma once
#include <string>
......@@ -41,3 +43,5 @@ class DequantizeOp
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -14,7 +14,7 @@ limitations under the License. */
#ifdef ELEMENTWISEMUL_OP
#include "elementwise_mul_op.h"
#include "operators/elementwise_mul_op.h"
namespace paddle_mobile {
namespace operators {
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_MOBILE_CPU
#ifdef DEQUANT_OP
#include "operators/kernel/dequantize_kernel.h"
......@@ -38,7 +38,8 @@ void DequantizeKernel<CPU, float>::Compute(
const int32_t *x = input->data<const int32_t>();
float *y = output->mutable_data<float>();
size_t size = output->numel();
float scale = 1.f / (activation_scale * weight_scale);
// float scale = 1.f / (activation_scale * weight_scale);
float scale = activation_scale / weight_scale;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t loop = size >> 4;
size_t remain = size & 0xF;
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_MOBILE_CPU
#ifdef QUANT_OP
#include "operators/kernel/quantize_kernel.h"
#include <cmath>
......@@ -225,7 +225,7 @@ static void quantize_round_to_nearest(const Tensor *input, const float scale,
const float *x = input->data<const float>();
int8_t *y = output->mutable_data<int8_t>();
size_t size = input->numel();
#ifdef defined(__ARM_NEON__) || defined(__ARM_NEON)
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t loop = size >> 4;
size_t remain = size & 0xF;
for (size_t i = 0; i < loop; ++i) {
......@@ -280,17 +280,18 @@ void QuantizeKernel<CPU, float>::Compute(
}
max_abs = std::max(max_abs, 1e-6f);
// only support int8 currently
float online_scale = 127 / max_abs;
param.online_scale_->mutable_data<float>()[0] = online_scale;
float scale = 127 / max_abs;
param.online_scale_->mutable_data<float>()[0] = max_abs;
switch (param.round_type_) {
case ROUND_NEAREST_TO_EVEN:
quantize_round_to_even(input, online_scale, output);
quantize_round_to_even(input, scale, output);
break;
case ROUND_NEAREST_TOWARDS_ZERO:
quantize_round_to_zero(input, online_scale, output);
quantize_round_to_zero(input, scale, output);
break;
case ROUND_NEAREST_AWAY_ZERO:
quantize_round_to_nearest(input, online_scale, output);
quantize_round_to_nearest(input, scale, output);
break;
default:
LOG(kLOG_ERROR) << "round type is not supported.";
break;
......
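Two related scale changes appear in this diff: the quantize kernel now writes the raw max_abs into online_scale_ (instead of 127 / max_abs), and the dequantize kernel recomputes its scale as activation_scale / weight_scale rather than the reciprocal form. How that constant folds against weight_scale depends on how weights are quantized elsewhere in the tree, so the following is only an illustrative symmetric int8 round trip, not the kernels' exact convention:

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const float max_abs = 2.5f;   // what the kernel now stores in online_scale_
  const float x = 1.3f;
  const float scale = 127.f / max_abs;                           // quantize scale
  const int8_t q = static_cast<int8_t>(std::round(x * scale));   // q = 66
  const float x_hat = static_cast<float>(q) * max_abs / 127.f;   // ~1.299
  std::printf("q=%d x_hat=%f\n", q, x_hat);
  return 0;
}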
......@@ -16,24 +16,27 @@ limitations under the License. */
#pragma once
#include <vector>
#include "operators/math/conv_arm_int8.h"
#include "operators/math/conv_func.h"
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/pad.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype>
inline void ConvBasic(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor *output = param.Output();
output->mutable_data<float>();
int groups = param.Groups();
std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings();
std::vector<int> dilations = param.Dilations();
const std::vector<int> strides = param.Strides();
const std::vector<int> paddings = param.Paddings();
const std::vector<int> dilations = param.Dilations();
const int batch_size = static_cast<int>(input->dims()[0]);
......@@ -57,7 +60,7 @@ inline void ConvBasic(const ConvParam<CPU> &param) {
Tensor col;
Tensor col_matrix;
if (is_expand) {
col.mutable_data<float>(col_shape);
col.mutable_data<Dtype>(col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
......@@ -76,8 +79,8 @@ inline void ConvBasic(const ConvParam<CPU> &param) {
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups;
math::Vol2ColFunctor<CPU, float> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
math::Vol2ColFunctor<CPU, Dtype> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, Dtype> im2col;
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
......@@ -96,6 +99,7 @@ inline void ConvBasic(const ConvParam<CPU> &param) {
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(in_slice, dilations, strides, paddings, &col);
......@@ -104,29 +108,85 @@ inline void ConvBasic(const ConvParam<CPU> &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<float>(filter_slice, false, col_matrix, false,
math::matmul<Dtype>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0));
}
}
}
inline void ConvCompute_int8(const ConvParam<CPU> &param) {
typedef void (*ConvFunc)(const Tensor &input, const Tensor &kernel,
Tensor *output);
static ConvFunc conv_funcs_table[7][5] = {
{0, 0, 0, 0, 0}, // k = 1
{0, 0, 0, 0, 0}, {conv3x3s1_int8, 0, 0, 0, 0}, // k = 3
{0, 0, 0, 0, 0}, {conv5x5s1_int8, 0, 0, 0, 0}, // k = 5
{0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, // k = 7
};
const Tensor *input = param.Input();
Tensor *filter = param.Filter();
Tensor *output = param.Output();
int groups = param.Groups();
const std::vector<int> &strides = param.Strides();
const std::vector<int> &paddings = param.Paddings();
const std::vector<int> &dilations = param.Dilations();
int kernel_h = filter->dims()[2];
int kernel_w = filter->dims()[3];
output->mutable_data<int32_t>();
ConvFunc conv_func = 0;
if (strides[1] == strides[0] && strides[1] < 6 && kernel_h == kernel_w &&
kernel_h < 8 && groups == 1 && dilations[0] == dilations[1] &&
dilations[1] == 1) {
conv_func = conv_funcs_table[kernel_h - 1][strides[0] - 1];
}
if (conv_func) {
int batch_size = input->dims()[0];
math::PadFunctor<CPU, int8_t> pad;
Tensor input_pad;
for (int i = 0; i < batch_size; ++i) {
Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1);
if (paddings[0] == 0 && paddings[1] == 0) {
input_pad = in_batch;
} else {
framework::DDim pad_shape = in_batch.dims();
pad_shape[2] += 2 * paddings[0];
pad_shape[3] += 2 * paddings[1];
input_pad.mutable_data<int8_t>(pad_shape);
pad(in_batch, paddings[0], paddings[1], &input_pad);
}
conv_func(input_pad, *filter, &out_batch);
}
} else {
ConvBasic<int8_t>(param);
}
}
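The dispatch above only reaches a specialized kernel when the table entry is non-null; every other shape falls through to ConvBasic<int8_t>. A hedged summary of which shapes currently hit the fast path, derived from conv_funcs_table (helper name is illustrative):

static bool UsesInt8FastPath(int kernel_h, int kernel_w, int stride_h,
                             int stride_w, int dilation, int groups) {
  // Only the 3x3/stride-1 and 5x5/stride-1 entries are populated today.
  return kernel_h == kernel_w && (kernel_h == 3 || kernel_h == 5) &&
         stride_h == stride_w && stride_h == 1 && dilation == 1 && groups == 1;
}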
template <typename P>
void ConvCompute(const ConvParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3) {
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
if (param.Input()->type() == typeid(int8_t)) {
ConvCompute_int8(param);
} else {
ConvBasic(param);
param.Output()->mutable_data<float>();
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3) {
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
} else {
ConvBasic<float>(param);
}
}
}
......
......@@ -44,7 +44,7 @@ void DepthwiseConvCompute(const ConvParam<CPU> &param) {
Bias, false);
} else {
ConvBasic(param);
ConvBasic<float>(param);
}
}
......
......@@ -15,8 +15,12 @@ limitations under the License. */
#ifdef ELEMENTWISEADD_OP
#pragma once
#include "operators/math/elementwise_op_function.h"
#include "operators/op_param.h"
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#endif
namespace paddle_mobile {
namespace operators {
......@@ -33,8 +37,61 @@ void ElementwiseAddCompute(const ElementwiseAddParam<CPU> &param) {
Tensor *Out = param.Out();
Out->mutable_data<float>();
int axis = param.Axis();
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
const auto &x_dims = input_x->dims();
const auto &y_dims = input_y->dims();
/// axis = -1 represents the last dimension.
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
size_t batch = 1;
size_t channels = 1;
size_t elementwise_num = 1;
for (int i = 0; i < axis; ++i) {
batch *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
channels *= y_dims[i];
}
for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) {
elementwise_num *= x_dims[i];
}
const float *bias_data = input_y->data<float>();
const float *input_data = input_x->data<float>();
float *output_data = Out->mutable_data<float>();
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
size_t offset = (i * channels + j) * elementwise_num;
const float *input = input_data + offset;
const float *bias = bias_data + j;
float *output = output_data + offset;
int loop = elementwise_num >> 0x4;
int remain = elementwise_num & 0xF;
for (int k = 0; k < loop; ++k) {
float32x4_t rb = vdupq_n_f32(*bias);
float32x4_t r0 = vld1q_f32(input);
float32x4_t r1 = vld1q_f32(input + 4);
float32x4_t r2 = vld1q_f32(input + 8);
float32x4_t r3 = vld1q_f32(input + 12);
r0 = vaddq_f32(r0, rb);
r1 = vaddq_f32(r1, rb);
r2 = vaddq_f32(r2, rb);
r3 = vaddq_f32(r3, rb);
vst1q_f32(output, r0);
vst1q_f32(output + 4, r1);
vst1q_f32(output + 8, r2);
vst1q_f32(output + 12, r3);
input += 16;
output += 16;
}
for (int k = 0; k < remain; ++k) {
output[k] = input[k] + *bias;
}
}
}
#else
ElementwiseComputeEx<AddFunctor<float>, float>(input_x, input_y, axis,
AddFunctor<float>(), Out);
#endif
}
template class ElementwiseAddKernel<CPU, float>;
......
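The NEON branch above folds the broadcast add into three loop bounds. A small helper mirroring that decomposition (illustrative; for example x dims {1, 3, 224, 224} with y dims {3} and axis 1 gives batch = 1, channels = 3, elementwise_num = 224 * 224):

static void BroadcastCounts(const framework::DDim &x_dims,
                            const framework::DDim &y_dims, int axis,
                            size_t *batch, size_t *channels,
                            size_t *elementwise_num) {
  *batch = *channels = *elementwise_num = 1;
  for (int i = 0; i < axis; ++i) *batch *= x_dims[i];
  for (int i = 0; i < y_dims.size(); ++i) *channels *= y_dims[i];
  for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) {
    *elementwise_num *= x_dims[i];
  }
}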
......@@ -17,6 +17,9 @@ limitations under the License. */
#include <operators/math/transform.h>
#include "operators/op_param.h"
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#endif
namespace paddle_mobile {
namespace operators {
......@@ -37,71 +40,100 @@ void ReluCompute(const ReluParam<CPU> &param) {
auto *out_ptr = out->mutable_data<float>();
int numel = input_x->numel();
// if (numel > 64) {
// asm volatile(
// "pld [%[input_x_ptr], #0] \n\t"
// "vmov.f32 q8, #0.0 \n\t"
// "subs %[num], %[num], #32 \n\t"
// "blt end_num_%= \n\t"
// "loop_num_%=: \n\t"
// "pld [%[input_x_ptr], #1024] \n\t"
//
// "vld1.32 {q0, q1}, [%[input_x_ptr]]! \n\t"
// "vld1.32 {q2, q3}, [%[input_x_ptr]]! \n\t"
// "vld1.32 {q4, q5}, [%[input_x_ptr]]! \n\t"
// "vld1.32 {q6, q7}, [%[input_x_ptr]]! \n\t"
//
// "vmax.f32 q0, q0, q8 \n\t"
// "vmax.f32 q1, q1, q8 \n\t"
// "vmax.f32 q2, q2, q8 \n\t"
// "vmax.f32 q3, q3, q8 \n\t"
// "vmax.f32 q4, q4, q8 \n\t"
// "vmax.f32 q5, q5, q8 \n\t"
// "vmax.f32 q6, q6, q8 \n\t"
// "vmax.f32 q7, q7, q8 \n\t"
//
// "vst1.32 {q0, q1}, [%[out_ptr]]! \n\t"
// "vst1.32 {q2, q3}, [%[out_ptr]]! \n\t"
// "vst1.32 {q4, q5}, [%[out_ptr]]! \n\t"
// "vst1.32 {q6, q7}, [%[out_ptr]]! \n\t"
//
// "subs %[num], %[num], #32 \n\t"
// "bge loop_num_%= \n\t"
// "end_num_%=: \n\t"
// "cmp %[num], #0 \n\t"
// "bge end_%= \n\t"
// "mov r6, #4 \n\t"
// "mul r5, %[num], r6 \n\t"
// "add %[input_x_ptr], %[input_x_ptr], r5 \n\t"
// "vld1.32 {q0, q1}, [%[input_x_ptr]]! \n\t"
// "vld1.32 {q2, q3}, [%[input_x_ptr]]! \n\t"
// "vld1.32 {q4, q5}, [%[input_x_ptr]]! \n\t"
// "vld1.32 {q6, q7}, [%[input_x_ptr]]! \n\t"
// "vmax.f32 q0, q0, q8 \n\t"
// "vmax.f32 q1, q1, q8 \n\t"
// "vmax.f32 q2, q2, q8 \n\t"
// "vmax.f32 q3, q3, q8 \n\t"
// "vmax.f32 q4, q4, q8 \n\t"
// "vmax.f32 q5, q5, q8 \n\t"
// "vmax.f32 q6, q6, q8 \n\t"
// "vmax.f32 q7, q7, q8 \n\t"
// "add %[out_ptr], %[out_ptr], r5 \n\t"
// "vst1.32 {q0, q1}, [%[out_ptr]]! \n\t"
// "vst1.32 {q2, q3}, [%[out_ptr]]! \n\t"
// "vst1.32 {q4, q5}, [%[out_ptr]]! \n\t"
// "vst1.32 {q6, q7}, [%[out_ptr]]! \n\t"
// "end_%=: \n\t"
// :
// :
// [out_ptr] "r"(out_ptr), [input_x_ptr] "r"(input_x_ptr), [num]
// "r"(numel) : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
// "q7", "q8", "r5",
// "r6");
// } else {
ReluFunctor<float> func_;
math::Transform trans;
trans(input_x_ptr, input_x_ptr + numel, out_ptr, func_);
// }
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#if __aarch64__
if (numel > 0) {
int loop = numel >> 0x4;
int remain = numel & 0xF;
float32x4_t zero = vdupq_n_f32(0.f);
for (int i = 0; i < loop; ++i) {
float32x4_t r0 = vld1q_f32(input_x_ptr);
float32x4_t r1 = vld1q_f32(input_x_ptr + 4);
float32x4_t r2 = vld1q_f32(input_x_ptr + 8);
float32x4_t r3 = vld1q_f32(input_x_ptr + 12);
r0 = vmaxq_f32(r0, zero);
r1 = vmaxq_f32(r1, zero);
r2 = vmaxq_f32(r2, zero);
r3 = vmaxq_f32(r3, zero);
vst1q_f32(out_ptr, r0);
vst1q_f32(out_ptr + 4, r1);
vst1q_f32(out_ptr + 8, r2);
vst1q_f32(out_ptr + 12, r3);
input_x_ptr += 16;
out_ptr += 16;
}
for (int i = 0; i < remain; ++i) {
out_ptr[i] = (input_x_ptr[i] > 0) * input_x_ptr[i];
}
#else
if (numel > 64) {
asm volatile(
"pld [%[input_x_ptr], #0] \n\t"
"vmov.f32 q8, #0.0 \n\t"
"subs %[num], %[num], #32 \n\t"
"blt end_num_%= \n\t"
"loop_num_%=: \n\t"
"pld [%[input_x_ptr], #1024] \n\t"
"vld1.32 {q0, q1}, [%[input_x_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[input_x_ptr]]! \n\t"
"vld1.32 {q4, q5}, [%[input_x_ptr]]! \n\t"
"vld1.32 {q6, q7}, [%[input_x_ptr]]! \n\t"
"vmax.f32 q0, q0, q8 \n\t"
"vmax.f32 q1, q1, q8 \n\t"
"vmax.f32 q2, q2, q8 \n\t"
"vmax.f32 q3, q3, q8 \n\t"
"vmax.f32 q4, q4, q8 \n\t"
"vmax.f32 q5, q5, q8 \n\t"
"vmax.f32 q6, q6, q8 \n\t"
"vmax.f32 q7, q7, q8 \n\t"
"vst1.32 {q0, q1}, [%[out_ptr]]! \n\t"
"vst1.32 {q2, q3}, [%[out_ptr]]! \n\t"
"vst1.32 {q4, q5}, [%[out_ptr]]! \n\t"
"vst1.32 {q6, q7}, [%[out_ptr]]! \n\t"
"subs %[num], %[num], #32 \n\t"
"bge loop_num_%= \n\t"
"end_num_%=: \n\t"
"cmp %[num], #0 \n\t"
"bge end_%= \n\t"
"mov r6, #4 \n\t"
"mul r5, %[num], r6 \n\t"
"add %[input_x_ptr], %[input_x_ptr], r5 \n\t"
"vld1.32 {q0, q1}, [%[input_x_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[input_x_ptr]]! \n\t"
"vld1.32 {q4, q5}, [%[input_x_ptr]]! \n\t"
"vld1.32 {q6, q7}, [%[input_x_ptr]]! \n\t"
"vmax.f32 q0, q0, q8 \n\t"
"vmax.f32 q1, q1, q8 \n\t"
"vmax.f32 q2, q2, q8 \n\t"
"vmax.f32 q3, q3, q8 \n\t"
"vmax.f32 q4, q4, q8 \n\t"
"vmax.f32 q5, q5, q8 \n\t"
"vmax.f32 q6, q6, q8 \n\t"
"vmax.f32 q7, q7, q8 \n\t"
"add %[out_ptr], %[out_ptr], r5 \n\t"
"vst1.32 {q0, q1}, [%[out_ptr]]! \n\t"
"vst1.32 {q2, q3}, [%[out_ptr]]! \n\t"
"vst1.32 {q4, q5}, [%[out_ptr]]! \n\t"
"vst1.32 {q6, q7}, [%[out_ptr]]! \n\t"
"end_%=: \n\t"
:
:
[out_ptr] "r"(out_ptr), [input_x_ptr] "r"(input_x_ptr), [num] "r"(numel)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "r5",
"r6");
#endif
} else {
#endif
ReluFunctor<float> func_;
math::Transform trans;
trans(input_x_ptr, input_x_ptr + numel, out_ptr, func_);
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
}
#endif
}
} // namespace operators
} // namespace paddle_mobile
......
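Both vector paths above compute max(x, 0); the aarch64 remainder loop uses the branchless form (x > 0) * x, which is equivalent for finite floats. A scalar reference, for illustration only:

static void ReluReference(const float *x, float *y, int n) {
  for (int i = 0; i < n; ++i) {
    y[i] = x[i] > 0.f ? x[i] : 0.f;  // max(x, 0)
  }
}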
......@@ -15,11 +15,14 @@ limitations under the License. */
#ifdef SUM_OP
#pragma once
#include <vector>
#include "operators/math/selected_rows_functor.h"
namespace paddle_mobile {
namespace operators {
using LoDTensorArray = std::vector<LoDTensor>;
template <typename P>
void SumCompute(const SumParam<CPU> &param) {
auto inputsvars = param.InputsVars();
......@@ -63,31 +66,21 @@ void SumCompute(const SumParam<CPU> &param) {
std::unique_ptr<framework::SelectedRows> in0;
if (in_place) {
// If is in_place, we store the input[0] to in0
auto *in_sel0 = inputsvars[0]->Get<SelectedRows>();
auto *in_sel0 = inputsvars[0]->Get<framework::SelectedRows>();
auto &rows = in_sel0->rows();
//#ifdef PADDLE_WITH_CUDA
// std::vector<int64_t> rows_in_cpu;
// rows_in_cpu.reserve(rows.size());
// for (auto item : rows) {
// rows_in_cpu.push_back(item);
// }
// in0.reset(new framework::SelectedRows(rows_in_cpu,
// in_sel0.height()));
//#else
in0.reset(new framework::SelectedRows(rows, in_sel0->height()));
//#endif
in0->mutable_value()->ShareDataWith(in_sel0->value());
}
auto get_selected_row = [&](size_t i) -> const SelectedRows & {
auto get_selected_row = [&](size_t i) -> const framework::SelectedRows & {
if (i == 0 && in0) {
return *in0.get();
} else {
return *(inputsvars[i]->Get<SelectedRows>());
return *(inputsvars[i]->Get<framework::SelectedRows>());
}
};
auto *out = outvar->GetMutable<SelectedRows>();
auto *out = outvar->GetMutable<framework::SelectedRows>();
out->mutable_rows()->clear();
auto *out_value = out->mutable_value();
......@@ -150,8 +143,6 @@ void SumCompute(const SumParam<CPU> &param) {
}
}
} else {
if (outvar->IsType<framework::Tensor>()) {
}
PADDLE_MOBILE_THROW_EXCEPTION(
"Unexpected branch, output variable type is %s", outvar->Type().name());
}
......
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef DEQUANT_OP
#pragma once
#include "framework/operator.h"
......@@ -30,3 +32,5 @@ class DequantizeKernel
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -23,8 +23,6 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
using namespace framework;
template <typename DeviceType, typename T>
class ElementwiseMulKernel
: public framework::OpKernelBase<DeviceType,
......
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef QUANT_OP
#pragma once
#include "framework/operator.h"
......@@ -30,3 +32,5 @@ class QuantizeKernel
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -21,8 +21,6 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
using namespace framework;
template <typename DeviceType, typename T>
class SumKernel
: public framework::OpKernelBase<DeviceType, SumParam<DeviceType>> {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CONV_OP
#include "operators/math/conv_arm_int8.h"
namespace paddle_mobile {
namespace operators {
void conv3x3s1_int8(const framework::Tensor& input,
const framework::Tensor& weight,
framework::Tensor* output) {
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
const int8_t* in_data = input.data<int8_t>();
const int8_t* w_data = weight.data<int8_t>();
int32_t* out_data = output->mutable_data<int32_t>();
// make sure that batch size is 1
int input_c = input.dims()[1];
int input_h = input.dims()[2];
int input_w = input.dims()[3];
int output_c = output->dims()[1];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
int image_size = input_h * input_w;
int out_image_size = output_h * output_w;
memset(out_data, 0, output_c * out_image_size * sizeof(int32_t));
#if __aarch64__
// TODO(hjchen2)
#else
int oc = 0;
#pragma omp parallel for
for (; oc < output_c - 1; oc += 2) {
for (int ic = 0; ic < input_c; ++ic) {
const int8_t* kernel0 = w_data + (oc * input_c + ic) * 9;
const int8_t* kernel1 = w_data + ((oc + 1) * input_c + ic) * 9;
int32_t* output0 = out_data + oc * out_image_size;
int32_t* output0n = output0 + output_w;
int32_t* output1 = out_data + (oc + 1) * out_image_size;
int32_t* output1n = output1 + output_w;
int oh = 0;
for (; oh < output_h - 1; oh += 2) {
const int8_t* r0 = in_data + ic * image_size + oh * input_w;
const int8_t* r1 = r0 + input_w;
const int8_t* r2 = r1 + input_w;
const int8_t* r3 = r2 + input_w;
int ow = output_w >> 3;
int remain = output_w & 0x7;
if (ow > 0) {
asm volatile(
"vld1.8 {d0}, [%[kernel0]] \n"
"ldr r5, [%[kernel0], #8] \n"
"vld1.8 {d1}, [%[kernel1]] \n"
"ldr r6, [%[kernel1], #8] \n"
"0: \n"
"vld1.8 {d2-d3}, [%[r0]] \n" // r0
"add %[r0], #8 \n"
"vext.8 d4, d2, d3, #1 \n"
"vext.8 d5, d2, d3, #2 \n"
"vdup.s8 d6, d0[0] \n"
"vdup.s8 d7, d0[1] \n"
"vdup.s8 d8, d0[2] \n"
"vdup.s8 d9, d1[0] \n"
"vdup.s8 d10, d1[1] \n"
"vdup.s8 d11, d1[2] \n"
"vmull.s8 q6, d2, d6 \n"
"vmull.s8 q7, d4, d7 \n"
"vmlal.s8 q6, d5, d8 \n"
"vaddl.s16 q12, d12, d14 \n"
"vaddl.s16 q13, d13, d15 \n"
"vmull.s8 q6, d2, d9 \n"
"vmull.s8 q7, d4, d10 \n"
"vmlal.s8 q6, d5, d11 \n"
"vaddl.s16 q14, d12, d14 \n"
"vaddl.s16 q15, d13, d15 \n"
"vld1.8 {d2-d3}, [%[r1]] \n" // r1
"add %[r1], #8 \n"
"vext.8 d4, d2, d3, #1 \n"
"vext.8 d5, d2, d3, #2 \n"
"vmull.s8 q6, d2, d6 \n" // next row
"vmull.s8 q7, d4, d7 \n"
"vmlal.s8 q6, d5, d8 \n"
"vaddl.s16 q8, d12, d14 \n"
"vaddl.s16 q9, d13, d15 \n"
"vmull.s8 q6, d2, d9 \n"
"vmull.s8 q7, d4, d10 \n"
"vmlal.s8 q6, d5, d11 \n"
"vaddl.s16 q10, d12, d14 \n"
"vaddl.s16 q11, d13, d15 \n"
"vdup.s8 d6, d0[3] \n"
"vdup.s8 d7, d0[4] \n"
"vdup.s8 d8, d0[5] \n"
"vdup.s8 d9, d1[3] \n"
"vdup.s8 d10, d1[4] \n"
"vdup.s8 d11, d1[5] \n"
"vmull.s8 q6, d2, d6 \n"
"vmull.s8 q7, d4, d7 \n"
"vmlal.s8 q6, d5, d8 \n"
"vaddw.s16 q12, q12, d12 \n"
"vaddw.s16 q13, q13, d13 \n"
"vaddw.s16 q12, q12, d14 \n"
"vaddw.s16 q13, q13, d15 \n"
"vmull.s8 q6, d2, d9 \n"
"vmull.s8 q7, d4, d10 \n"
"vmlal.s8 q6, d5, d11 \n"
"vaddw.s16 q14, q14, d12 \n"
"vaddw.s16 q15, q15, d13 \n"
"vaddw.s16 q14, q14, d14 \n"
"vaddw.s16 q15, q15, d15 \n"
"vld1.8 {d2-d3}, [%[r2]] \n" // r2
"add %[r2], #8 \n"
"vext.8 d4, d2, d3, #1 \n"
"vext.8 d5, d2, d3, #2 \n"
"vmull.s8 q6, d2, d6 \n" // next row
"vmull.s8 q7, d4, d7 \n"
"vmlal.s8 q6, d5, d8 \n"
"vaddw.s16 q8, q8, d12 \n"
"vaddw.s16 q8, q8, d14 \n"
"vaddw.s16 q9, q9, d13 \n"
"vaddw.s16 q9, q9, d15 \n"
"vmull.s8 q6, d2, d9 \n"
"vmull.s8 q7, d4, d10 \n"
"vmlal.s8 q6, d5, d11 \n"
"vaddw.s16 q10, q10, d12 \n"
"vaddw.s16 q11, q11, d13 \n"
"vaddw.s16 q10, q10, d14 \n"
"vaddw.s16 q11, q11, d15 \n"
"vdup.s8 d6, d0[6] \n"
"vdup.s8 d7, d0[7] \n"
"vdup.s8 d8, r5 \n"
"vdup.s8 d9, d1[6] \n"
"vdup.s8 d10, d1[7] \n"
"vdup.s8 d11, r6 \n"
"vmull.s8 q6, d2, d6 \n"
"vmull.s8 q7, d4, d7 \n"
"vmlal.s8 q6, d5, d8 \n"
"vaddw.s16 q12, q12, d12 \n"
"vaddw.s16 q13, q13, d13 \n"
"vaddw.s16 q12, q12, d14 \n"
"vaddw.s16 q13, q13, d15 \n"
"vld1.32 {d12-d15}, [%[output0]] \n"
"vadd.s32 q6, q6, q12 \n"
"vadd.s32 q7, q7, q13 \n"
"vst1.32 {d12-d15}, [%[output0]]! \n"
"vmull.s8 q6, d2, d9 \n"
"vmull.s8 q7, d4, d10 \n"
"vmlal.s8 q6, d5, d11 \n"
"vaddw.s16 q14, q14, d12 \n"
"vaddw.s16 q15, q15, d13 \n"
"vaddw.s16 q14, q14, d14 \n"
"vaddw.s16 q15, q15, d15 \n"
"vld1.32 {d12-d15}, [%[output1]] \n"
"vadd.s32 q6, q6, q14 \n"
"vadd.s32 q7, q7, q15 \n"
"vst1.32 {d12-d15}, [%[output1]]! \n"
"vld1.8 {d2-d3}, [%[r3]] \n" // r3
"add %[r3], #8 \n"
"vext.8 d4, d2, d3, #1 \n"
"vext.8 d5, d2, d3, #2 \n"
"vmull.s8 q6, d2, d6 \n" // next row
"vmull.s8 q7, d4, d7 \n"
"vmlal.s8 q6, d5, d8 \n"
"vaddw.s16 q8, q8, d12 \n"
"vaddw.s16 q9, q9, d15 \n"
"vaddw.s16 q8, q8, d14 \n"
"vaddw.s16 q9, q9, d13 \n"
"vld1.32 {d12-d15}, [%[output0n]] \n"
"vadd.s32 q6, q6, q8 \n"
"vadd.s32 q7, q7, q9 \n"
"vst1.32 {d12-d15}, [%[output0n]]! \n"
"vmull.s8 q6, d2, d9 \n"
"vmull.s8 q7, d4, d10 \n"
"vmlal.s8 q6, d5, d11 \n"
"vaddw.s16 q10, q10, d12 \n"
"vaddw.s16 q11, q11, d15 \n"
"vaddw.s16 q10, q10, d14 \n"
"vaddw.s16 q11, q11, d13 \n"
"vld1.32 {d12-d15}, [%[output1n]] \n"
"vadd.s32 q6, q6, q10 \n"
"vadd.s32 q7, q7, q11 \n"
"vst1.32 {d12-d15}, [%[output1n]]! \n"
"subs %[ow], #1 \n"
"bne 0b \n"
: [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), [r3] "+r"(r3),
[ow] "+r"(ow), [output0] "+r"(output0), [output1] "+r"(output1),
[output0n] "+r"(output0n), [output1n] "+r"(output1n)
: [kernel0] "r"(kernel0), [kernel1] "r"(kernel1)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r5",
"r6");
}
if (remain > 0) {
asm volatile(
"vld1.8 {d0}, [%[kernel0]] \n"
"ldr r5, [%[kernel0], #8] \n"
"vld1.8 {d1}, [%[kernel1]] \n"
"ldr r6, [%[kernel1], #8] \n"
"0: \n"
"vld1.8 d4, [%[r0]] \n"
"vld1.8 d5, [%[r1]] \n"
"vld1.8 d6, [%[r2]] \n"
"vld1.8 d7, [%[r3]] \n"
"add %[r0], #1 \n"
"add %[r1], #1 \n"
"add %[r2], #1 \n"
"add %[r3], #1 \n"
"vdup.s8 d2, r5 \n"
"vdup.s8 d3, r6 \n"
"vext.8 d8, d0, d2, #3 \n"
"vext.8 d9, d0, d2, #6 \n"
"vext.8 d10, d1, d3, #3 \n"
"vext.8 d11, d1, d3, #6 \n"
"vmull.s8 q6, d4, d0 \n"
"vmull.s8 q7, d5, d8 \n"
"vmlal.s8 q6, d6, d9 \n"
"vaddl.s16 q12, d12, d14 \n"
"vdup.s32 d2, d24[1] \n"
"vadd.s32 d24, d24, d2 \n"
"vadd.s32 d24, d24, d25 \n"
"vmull.s8 q6, d4, d1 \n"
"vmull.s8 q7, d5, d10 \n"
"vmlal.s8 q6, d6, d11 \n"
"vaddl.s16 q13, d12, d14 \n"
"vdup.s32 d2, d26[1] \n"
"vadd.s32 d26, d26, d2 \n"
"vadd.s32 d26, d26, d27 \n"
"ldr r7, [%[output0]] \n"
"vdup.s32 d14, r7 \n"
"vadd.s32 d14, d14, d24 \n"
"vst1.32 d14[0], [%[output0]]! \n"
"ldr r7, [%[output1]] \n"
"vdup.s32 d14, r7 \n"
"vadd.s32 d14, d14, d26 \n"
"vst1.32 d14[0], [%[output1]]! \n"
"vmull.s8 q6, d5, d0 \n"
"vmull.s8 q7, d6, d8 \n"
"vmlal.s8 q6, d7, d9 \n"
"vaddl.s16 q12, d12, d14 \n"
"vdup.s32 d2, d24[1] \n"
"vadd.s32 d24, d24, d2 \n"
"vadd.s32 d24, d24, d25 \n"
"vmull.s8 q6, d5, d1 \n"
"vmull.s8 q7, d6, d10 \n"
"vmlal.s8 q6, d7, d11 \n"
"vaddl.s16 q13, d12, d14 \n"
"vdup.s32 d2, d26[1] \n"
"vadd.s32 d26, d26, d2 \n"
"vadd.s32 d26, d26, d27 \n"
"ldr r7, [%[output0n]] \n"
"vdup.s32 d14, r7 \n"
"vadd.s32 d14, d14, d24 \n"
"vst1.32 d14[0], [%[output0n]]! \n"
"ldr r7, [%[output1n]] \n"
"vdup.s32 d14, r7 \n"
"vadd.s32 d14, d14, d26 \n"
"vst1.32 d14[0], [%[output1n]]! \n"
"subs %[remain], #1 \n"
"bne 0b \n"
: [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), [r3] "+r"(r3),
[remain] "+r"(remain), [output0] "+r"(output0),
[output1] "+r"(output1), [output0n] "+r"(output0n),
[output1n] "+r"(output1n)
: [kernel0] "r"(kernel0), [kernel1] "r"(kernel1)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "r5", "r6", "r7");
}
output0 += output_w;
output1 += output_w;
output0n += output_w;
output1n += output_w;
}
// remain output height
for (; oh < output_h; ++oh) {
const int8_t* r0 = in_data + ic * image_size + oh * input_w;
const int8_t* r1 = r0 + input_w;
const int8_t* r2 = r1 + input_w;
const int8_t* r3 = r2 + input_w;
const int8_t* r4 = r3 + input_w;
int ow = output_w >> 3;
int remain = output_w & 0x7;
if (ow > 0) {
asm volatile(
"vld1.8 {d0}, [%[kernel0]] \n"
"ldr r5, [%[kernel0], #8] \n"
"vld1.8 {d1}, [%[kernel1]] \n"
"ldr r6, [%[kernel1], #8] \n"
"0: \n"
"vld1.8 {d2-d3}, [%[r0]] \n" // r0
"add %[r0], #8 \n"
"vext.8 d4, d2, d3, #1 \n"
"vext.8 d5, d2, d3, #2 \n"
"vdup.s8 d6, d0[0] \n"
"vdup.s8 d7, d0[1] \n"
"vdup.s8 d8, d0[2] \n"
"vdup.s8 d9, d1[0] \n"
"vdup.s8 d10, d1[1] \n"
"vdup.s8 d11, d1[2] \n"
"vmull.s8 q6, d2, d6 \n"
"vmull.s8 q7, d4, d7 \n"
"vmlal.s8 q6, d5, d8 \n"
"vaddl.s16 q12, d12, d14 \n"
"vaddl.s16 q13, d13, d15 \n"
"vmull.s8 q6, d2, d9 \n"
"vmull.s8 q7, d4, d10 \n"
"vmlal.s8 q6, d5, d11 \n"
"vaddl.s16 q14, d12, d14 \n"
"vaddl.s16 q15, d13, d15 \n"
"vld1.8 {d2-d3}, [%[r1]] \n" // r1
"add %[r1], #8 \n"
"vext.8 d4, d2, d3, #1 \n"
"vext.8 d5, d2, d3, #2 \n"
"vdup.s8 d6, d0[3] \n"
"vdup.s8 d7, d0[4] \n"
"vdup.s8 d8, d0[5] \n"
"vdup.s8 d9, d1[3] \n"
"vdup.s8 d10, d1[4] \n"
"vdup.s8 d11, d1[5] \n"
"vmull.s8 q6, d2, d6 \n"
"vmull.s8 q7, d4, d7 \n"
"vmlal.s8 q6, d5, d8 \n"
"vaddw.s16 q12, q12, d12 \n"
"vaddw.s16 q12, q12, d14 \n"
"vaddw.s16 q13, q13, d13 \n"
"vaddw.s16 q13, q13, d15 \n"
"vmull.s8 q6, d2, d9 \n"
"vmull.s8 q7, d4, d10 \n"
"vmlal.s8 q6, d5, d11 \n"
"vaddw.s16 q14, q14, d12 \n"
"vaddw.s16 q14, q14, d14 \n"
"vaddw.s16 q15, q15, d13 \n"
"vaddw.s16 q15, q15, d15 \n"
"vld1.8 {d2-d3}, [%[r2]] \n" // r2
"add %[r2], #8 \n"
"vext.8 d4, d2, d3, #1 \n"
"vext.8 d5, d2, d3, #2 \n"
"vdup.s8 d6, d0[6] \n"
"vdup.s8 d7, d0[7] \n"
"vdup.s8 d8, r5 \n"
"vdup.s8 d9, d1[6] \n"
"vdup.s8 d10, d1[7] \n"
"vdup.s8 d11, r6 \n"
"vmull.s8 q6, d2, d6 \n"
"vmull.s8 q7, d4, d7 \n"
"vmlal.s8 q6, d5, d8 \n"
"vaddw.s16 q12, q12, d12 \n"
"vaddw.s16 q12, q12, d14 \n"
"vaddw.s16 q13, q13, d13 \n"
"vaddw.s16 q13, q13, d15 \n"
"vmull.s8 q6, d2, d9 \n"
"vmull.s8 q7, d4, d10 \n"
"vmlal.s8 q6, d5, d11 \n"
"vaddw.s16 q14, q14, d12 \n"
"vaddw.s16 q14, q14, d14 \n"
"vaddw.s16 q15, q15, d13 \n"
"vaddw.s16 q15, q15, d15 \n"
"vld1.32 {d12-d15}, [%[output0]] \n"
"vadd.s32 q6, q6, q12 \n"
"vadd.s32 q7, q7, q13 \n"
"vst1.32 {d12-d15}, [%[output0]]! \n"
"vld1.32 {d12-d15}, [%[output1]] \n"
"vadd.s32 q6, q6, q14 \n"
"vadd.s32 q7, q7, q15 \n"
"vst1.32 {d12-d15}, [%[output1]]! \n"
"subs %[ow], #1 \n"
"bne 0b \n"
: [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), [ow] "+r"(ow),
[output0] "+r"(output0), [output1] "+r"(output1)
: [kernel0] "r"(kernel0), [kernel1] "r"(kernel1)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r5",
"r6");
}
if (remain > 0) {
asm volatile(
"vld1.8 {d0}, [%[kernel0]] \n"
"ldr r5, [%[kernel0], #8] \n"
"vld1.8 {d1}, [%[kernel1]] \n"
"ldr r6, [%[kernel1], #8] \n"
"0: \n"
"vld1.8 d4, [%[r0]] \n"
"vld1.8 d5, [%[r1]] \n"
"vld1.8 d6, [%[r2]] \n"
"add %[r0], #1 \n"
"add %[r1], #1 \n"
"add %[r2], #1 \n"
"vdup.s8 d2, r5 \n"
"vdup.s8 d3, r6 \n"
"vext.8 d8, d0, d2, #3 \n"
"vext.8 d9, d0, d2, #6 \n"
"vext.8 d10, d1, d3, #3 \n"
"vext.8 d11, d1, d3, #6 \n"
"vmull.s8 q6, d4, d0 \n"
"vmull.s8 q7, d5, d8 \n"
"vmlal.s8 q6, d6, d9 \n"
"vaddl.s16 q12, d12, d14 \n"
"vdup.s32 d2, d24[1] \n"
"vadd.s32 d24, d24, d2 \n"
"vadd.s32 d24, d24, d25 \n"
"vmull.s8 q6, d4, d1 \n"
"vmull.s8 q7, d5, d10 \n"
"vmlal.s8 q6, d6, d11 \n"
"vaddl.s16 q13, d12, d14 \n"
"vdup.s32 d2, d26[1] \n"
"vadd.s32 d26, d26, d2 \n"
"vadd.s32 d26, d26, d27 \n"
"ldr r7, [%[output0]] \n"
"vdup.s32 d14, r7 \n"
"vadd.s32 d14, d14, d24 \n"
"vst1.32 d14[0], [%[output0]]! \n"
"ldr r7, [%[output1]] \n"
"vdup.s32 d14, r7 \n"
"vadd.s32 d14, d14, d26 \n"
"vst1.32 d14[0], [%[output1]]! \n"
"subs %[remain], #1 \n"
"bne 0b \n"
: [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2),
[remain] "+r"(remain), [output0] "+r"(output0),
[output1] "+r"(output1)
: [kernel0] "r"(kernel0), [kernel1] "r"(kernel1)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "r5", "r6", "r7");
}
}
}
}
for (; oc < output_c; ++oc) {
for (int ic = 0; ic < input_c; ++ic) {
const int8_t* kernel0 = w_data + (oc * input_c + ic) * 9;
int32_t* output0 = out_data + oc * out_image_size;
int32_t* output0n = output0 + output_w;
int oh = 0;
for (; oh < output_h - 1; oh += 2) {
const int8_t* r0 = in_data + ic * image_size + oh * input_w;
const int8_t* r1 = r0 + input_w;
const int8_t* r2 = r1 + input_w;
const int8_t* r3 = r2 + input_w;
int ow = output_w >> 3;
int remain = output_w & 0x7;
if (ow > 0) {
asm volatile(
"vld1.8 {d0}, [%[kernel0]] \n"
"ldr r5, [%[kernel0], #8] \n"
"0: \n"
"vld1.8 {d2-d3}, [%[r0]] \n" // r0
"add %[r0], #8 \n"
"vext.8 d4, d2, d3, #1 \n"
"vext.8 d5, d2, d3, #2 \n"
"vdup.s8 d6, d0[0] \n"
"vdup.s8 d7, d0[1] \n"
"vdup.s8 d8, d0[2] \n"
"vmull.s8 q6, d2, d6 \n"
"vmull.s8 q7, d4, d7 \n"
"vmlal.s8 q6, d5, d8 \n"
"vaddl.s16 q12, d12, d14 \n"
"vaddl.s16 q13, d13, d15 \n"
"vld1.8 {d2-d3}, [%[r1]] \n" // r1
"add %[r1], #8 \n"
"vext.8 d4, d2, d3, #1 \n"
"vext.8 d5, d2, d3, #2 \n"
"vmull.s8 q6, d2, d6 \n" // next row
"vmull.s8 q7, d4, d7 \n"
"vmlal.s8 q6, d5, d8 \n"
"vaddl.s16 q8, d12, d14 \n"
"vaddl.s16 q9, d13, d15 \n"
"vdup.s8 d6, d0[3] \n"
"vdup.s8 d7, d0[4] \n"
"vdup.s8 d8, d0[5] \n"
"vmull.s8 q6, d2, d6 \n"
"vmull.s8 q7, d4, d7 \n"
"vmlal.s8 q6, d5, d8 \n"
"vaddw.s16 q12, q12, d12 \n"
"vaddw.s16 q12, q12, d14 \n"
"vaddw.s16 q13, q13, d13 \n"
"vaddw.s16 q13, q13, d15 \n"
"vld1.8 {d2-d3}, [%[r2]] \n" // r2
"add %[r2], #8 \n"
"vext.8 d4, d2, d3, #1 \n"
"vext.8 d5, d2, d3, #2 \n"
"vmull.s8 q6, d2, d6 \n" // next row
"vmull.s8 q7, d4, d7 \n"
"vmlal.s8 q6, d5, d8 \n"
"vaddw.s16 q8, q8, d12 \n"
"vaddw.s16 q8, q8, d14 \n"
"vaddw.s16 q9, q9, d13 \n"
"vaddw.s16 q9, q9, d15 \n"
"vdup.s8 d6, d0[6] \n"
"vdup.s8 d7, d0[7] \n"
"vdup.s8 d8, r5 \n"
"vmull.s8 q6, d2, d6 \n"
"vmull.s8 q7, d4, d7 \n"
"vmlal.s8 q6, d5, d8 \n"
"vaddw.s16 q12, q12, d12 \n"
"vaddw.s16 q12, q12, d14 \n"
"vaddw.s16 q13, q13, d13 \n"
"vaddw.s16 q13, q13, d15 \n"
"vld1.32 {d12-d15}, [%[output0]] \n"
"vadd.s32 q6, q6, q12 \n"
"vadd.s32 q7, q7, q13 \n"
"vst1.32 {d12-d15}, [%[output0]]! \n"
"vld1.8 {d2-d3}, [%[r3]] \n" // r3
"add %[r3], #8 \n"
"vext.8 d4, d2, d3, #1 \n"
"vext.8 d5, d2, d3, #2 \n"
"vmull.s8 q6, d2, d6 \n" // next row
"vmull.s8 q7, d4, d7 \n"
"vmlal.s8 q6, d5, d8 \n"
"vaddw.s16 q8, q8, d12 \n"
"vaddw.s16 q8, q8, d14 \n"
"vaddw.s16 q9, q9, d13 \n"
"vaddw.s16 q9, q9, d15 \n"
"vld1.32 {d12-d15}, [%[output0n]] \n"
"vadd.s32 q6, q6, q8 \n"
"vadd.s32 q7, q7, q9 \n"
"vst1.32 {d12-d15}, [%[output0n]]! \n"
"subs %[ow], #1 \n"
"bne 0b \n"
: [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), [r3] "+r"(r3),
[ow] "+r"(ow), [output0] "+r"(output0),
[output0n] "+r"(output0n)
: [kernel0] "r"(kernel0)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r5");
}
if (remain > 0) {
asm volatile(
"vld1.8 {d0}, [%[kernel0]] \n"
"ldr r5, [%[kernel0], #8] \n"
"0: \n"
"vld1.8 d4, [%[r0]] \n"
"vld1.8 d5, [%[r1]] \n"
"vld1.8 d6, [%[r2]] \n"
"vld1.8 d7, [%[r3]] \n"
"add %[r0], #1 \n"
"add %[r1], #1 \n"
"add %[r2], #1 \n"
"add %[r3], #1 \n"
"vdup.s8 d2, r5 \n"
"vext.8 d8, d0, d2, #3 \n"
"vext.8 d9, d0, d2, #6 \n"
"vmull.s8 q6, d4, d0 \n"
"vmull.s8 q7, d5, d8 \n"
"vmlal.s8 q6, d6, d9 \n"
"vaddl.s16 q12, d12, d14 \n"
"vdup.s32 d2, d24[1] \n"
"vadd.s32 d24, d24, d2 \n"
"vadd.s32 d24, d24, d25 \n"
"ldr r7, [%[output0]] \n"
"vdup.s32 d14, r7 \n"
"vadd.s32 d14, d14, d24 \n"
"vst1.32 d14[0], [%[output0]]! \n"
"vmull.s8 q6, d5, d0 \n"
"vmull.s8 q7, d6, d8 \n"
"vmlal.s8 q6, d7, d9 \n"
"vaddl.s16 q12, d12, d14 \n"
"vdup.s32 d2, d24[1] \n"
"vadd.s32 d24, d24, d2 \n"
"vadd.s32 d24, d24, d25 \n"
"ldr r7, [%[output0n]] \n"
"vdup.s32 d14, r7 \n"
"vadd.s32 d14, d14, d24 \n"
"vst1.32 d14[0], [%[output0n]]! \n"
"subs %[remain], #1 \n"
"bne 0b \n"
: [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), [r3] "+r"(r3),
[remain] "+r"(remain), [output0] "+r"(output0),
[output0n] "+r"(output0n)
: [kernel0] "r"(kernel0)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "r5", "r7");
}
output0 += output_w;
output0n += output_w;
}
// remain output height
for (; oh < output_h; ++oh) {
const int8_t* r0 = in_data + ic * image_size + oh * input_w;
const int8_t* r1 = r0 + input_w;
const int8_t* r2 = r1 + input_w;
int ow = output_w >> 3;
int remain = output_w & 0x7;
if (ow > 0) {
asm volatile(
"vld1.8 {d0}, [%[kernel0]] \n"
"ldr r5, [%[kernel0], #8] \n"
"0: \n"
"vld1.8 {d2-d3}, [%[r0]] \n" // r0
"add %[r0], #8 \n"
"vext.8 d4, d2, d3, #1 \n"
"vext.8 d5, d2, d3, #2 \n"
"vdup.s8 d6, d0[0] \n"
"vdup.s8 d7, d0[1] \n"
"vdup.s8 d8, d0[2] \n"
"vmull.s8 q6, d2, d6 \n"
"vmull.s8 q7, d4, d7 \n"
"vmlal.s8 q6, d5, d8 \n"
"vaddl.s16 q12, d12, d14 \n"
"vaddl.s16 q13, d13, d15 \n"
"vld1.8 {d2-d3}, [%[r1]] \n" // r1
"add %[r1], #8 \n"
"vext.8 d4, d2, d3, #1 \n"
"vext.8 d5, d2, d3, #2 \n"
"vdup.s8 d6, d0[3] \n"
"vdup.s8 d7, d0[4] \n"
"vdup.s8 d8, d0[5] \n"
"vmull.s8 q6, d2, d6 \n"
"vmull.s8 q7, d4, d7 \n"
"vmlal.s8 q6, d5, d8 \n"
"vaddw.s16 q12, q12, d12 \n"
"vaddw.s16 q12, q12, d14 \n"
"vaddw.s16 q13, q13, d13 \n"
"vaddw.s16 q13, q13, d15 \n"
"vld1.8 {d2-d3}, [%[r2]] \n" // r2
"add %[r2], #8 \n"
"vext.8 d4, d2, d3, #1 \n"
"vext.8 d5, d2, d3, #2 \n"
"vdup.s8 d6, d0[6] \n"
"vdup.s8 d7, d0[7] \n"
"vdup.s8 d8, r5 \n"
"vmull.s8 q6, d2, d6 \n"
"vmull.s8 q7, d4, d7 \n"
"vmlal.s8 q6, d5, d8 \n"
"vaddw.s16 q12, q12, d12 \n"
"vaddw.s16 q12, q12, d14 \n"
"vaddw.s16 q13, q13, d13 \n"
"vaddw.s16 q13, q13, d15 \n"
"vld1.32 {d12-d15}, [%[output0]] \n"
"vadd.s32 q6, q6, q12 \n"
"vadd.s32 q7, q7, q13 \n"
"vst1.32 {d12-d15}, [%[output0]]! \n"
"subs %[ow], #1 \n"
"bne 0b \n"
: [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), [ow] "+r"(ow),
[output0] "+r"(output0)
: [kernel0] "r"(kernel0)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r5");
}
if (remain > 0) {
asm volatile(
"vld1.8 {d0}, [%[kernel0]] \n"
"ldr r5, [%[kernel0], #8] \n"
"0: \n"
"vld1.8 d4, [%[r0]] \n"
"vld1.8 d5, [%[r1]] \n"
"vld1.8 d6, [%[r2]] \n"
"add %[r0], #1 \n"
"add %[r1], #1 \n"
"add %[r2], #1 \n"
"vdup.s8 d2, r5 \n"
"vext.8 d8, d0, d2, #3 \n"
"vext.8 d9, d0, d2, #6 \n"
"vmull.s8 q6, d4, d0 \n"
"vmull.s8 q7, d5, d8 \n"
"vmlal.s8 q6, d6, d9 \n"
"vaddl.s16 q12, d12, d14 \n"
"vdup.s32 d2, d24[1] \n"
"vadd.s32 d24, d24, d2 \n"
"vadd.s32 d24, d24, d25 \n"
"ldr r7, [%[output0]] \n"
"vdup.s32 d14, r7 \n"
"vadd.s32 d14, d14, d24 \n"
"vst1.32 d14[0], [%[output0]]! \n"
"subs %[remain], #1 \n"
"bne 0b \n"
: [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2),
[remain] "+r"(remain), [output0] "+r"(output0)
: [kernel0] "r"(kernel0)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "r5", "r7");
}
}
}
}
#endif
#else
// TODO(hjchen2)
#endif
}
} // namespace operators
} // namespace paddle_mobile
#endif
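A scalar reference of what the assembly in conv3x3s1_int8 computes may help when reading it: a stride-1, no-padding 3x3 convolution whose int32 accumulators sum over all input channels (illustrative only, assuming NCHW input with batch size 1 and OIHW weights; not a drop-in replacement for the NEON path):

#include <cstdint>

static void Conv3x3s1Int8Reference(const int8_t *in, const int8_t *weights,
                                   int32_t *out, int input_c, int input_h,
                                   int input_w, int output_c) {
  const int output_h = input_h - 2;  // stride 1, no padding
  const int output_w = input_w - 2;
  const int image_size = input_h * input_w;
  for (int oc = 0; oc < output_c; ++oc) {
    for (int oh = 0; oh < output_h; ++oh) {
      for (int ow = 0; ow < output_w; ++ow) {
        int32_t acc = 0;
        for (int ic = 0; ic < input_c; ++ic) {
          const int8_t *k = weights + (oc * input_c + ic) * 9;  // 3x3 kernel per (oc, ic)
          for (int kh = 0; kh < 3; ++kh) {
            for (int kw = 0; kw < 3; ++kw) {
              acc += static_cast<int32_t>(
                         in[ic * image_size + (oh + kh) * input_w + ow + kw]) *
                     static_cast<int32_t>(k[kh * 3 + kw]);
            }
          }
        }
        out[oc * output_h * output_w + oh * output_w + ow] = acc;
      }
    }
  }
}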
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CONV_OP
#include "operators/math/conv_arm_int8.h"
namespace paddle_mobile {
namespace operators {
void conv5x5s1_int8(const framework::Tensor& input,
const framework::Tensor& weight,
framework::Tensor* output) {
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
const int8_t* in_data = input.data<int8_t>();
const int8_t* w_data = weight.data<int8_t>();
int32_t* out_data = output->mutable_data<int32_t>();
// make sure that batch size is 1
int input_c = input.dims()[1];
int input_h = input.dims()[2];
int input_w = input.dims()[3];
int output_c = output->dims()[1];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
int image_size = input_h * input_w;
int out_image_size = output_h * output_w;
memset(out_data, 0, output_c * out_image_size * sizeof(int32_t));
#if __aarch64__
// TODO(hjchen2)
#else
#pragma omp parallel for
for (int oc = 0; oc < output_c; ++oc) {
for (int ic = 0; ic < input_c; ++ic) {
const int8_t* kernel = w_data + (oc * input_c + ic) * 25;
int32_t* output0 = out_data + oc * out_image_size;
int32_t* output1 = output0 + output_w;
int oh = 0;
for (; oh < output_h - 1; oh += 2) {
const int8_t* r0 = in_data + ic * image_size + oh * input_w;
const int8_t* r1 = r0 + input_w;
const int8_t* r2 = r1 + input_w;
const int8_t* r3 = r2 + input_w;
const int8_t* r4 = r3 + input_w;
const int8_t* r5 = r4 + input_w;
int ow = output_w >> 3;
int remain = output_w & 0x7;
if (ow > 0) {
asm volatile("vld1.8 {d0-d3}, [%[kernel]] \n"
: [kernel] "+r"(kernel)
:
: "cc", "memory", "q0", "q1");
asm volatile(
"0: \n"
"vld1.8 {d4-d5}, [%[r0]] \n" // r0
"add %[r0], #8 \n"
"vext.8 d6, d4, d5, #1 \n"
"vext.8 d7, d4, d5, #2 \n"
"vext.8 d8, d4, d5, #3 \n"
"vext.8 d9, d4, d5, #4 \n"
"vdup.s8 d10, d0[0] \n"
"vdup.s8 d11, d0[1] \n"
"vdup.s8 d12, d0[2] \n"
"vdup.s8 d13, d0[3] \n"
"vdup.s8 d14, d0[4] \n"
"vmull.s8 q8, d4, d10 \n"
"vmull.s8 q9, d6, d11 \n"
"vmlal.s8 q8, d7, d12 \n"
"vmlal.s8 q9, d8, d13 \n"
"vaddl.s16 q14, d16, d18 \n"
"vaddl.s16 q15, d17, d19 \n"
"vmull.s8 q8, d9, d14 \n"
"vaddw.s16 q14, q14, d16 \n"
"vaddw.s16 q15, q15, d17 \n"
"vld1.8 {d4-d5}, [%[r1]] \n" // r1
"add %[r1], #8 \n"
"vext.8 d6, d4, d5, #1 \n"
"vext.8 d7, d4, d5, #2 \n"
"vext.8 d8, d4, d5, #3 \n"
"vext.8 d9, d4, d5, #4 \n"
"vmull.s8 q8, d4, d10 \n" // next row
"vmull.s8 q9, d6, d11 \n"
"vmlal.s8 q8, d7, d12 \n"
"vmlal.s8 q9, d8, d13 \n"
"vaddl.s16 q10, d16, d18 \n"
"vaddl.s16 q11, d17, d19 \n"
"vmull.s8 q8, d9, d14 \n"
"vaddw.s16 q10, q10, d16 \n"
"vaddw.s16 q11, q11, d17 \n"
"vdup.s8 d10, d0[5] \n"
"vdup.s8 d11, d0[6] \n"
"vdup.s8 d12, d0[7] \n"
"vdup.s8 d13, d1[0] \n"
"vdup.s8 d14, d1[1] \n"
"vmull.s8 q8, d4, d10 \n"
"vmull.s8 q9, d6, d11 \n"
"vmlal.s8 q8, d7, d12 \n"
"vmlal.s8 q9, d8, d13 \n"
"vaddl.s16 q12, d16, d18 \n"
"vaddl.s16 q13, d17, d19 \n"
"vmull.s8 q8, d9, d14 \n"
"vaddw.s16 q12, q12, d16 \n"
"vaddw.s16 q13, q13, d17 \n"
"vadd.s32 q14, q14, q12 \n"
"vadd.s32 q15, q15, q13 \n"
"vld1.8 {d4-d5}, [%[r2]] \n" // r2
"add %[r2], #8 \n"
"vext.8 d6, d4, d5, #1 \n"
"vext.8 d7, d4, d5, #2 \n"
"vext.8 d8, d4, d5, #3 \n"
"vext.8 d9, d4, d5, #4 \n"
"vmull.s8 q8, d4, d10 \n" // next row
"vmull.s8 q9, d6, d11 \n"
"vmlal.s8 q8, d7, d12 \n"
"vmlal.s8 q9, d8, d13 \n"
"vaddl.s16 q12, d16, d18 \n"
"vaddl.s16 q13, d17, d19 \n"
"vmull.s8 q8, d9, d14 \n"
"vaddw.s16 q12, q12, d16 \n"
"vaddw.s16 q13, q13, d17 \n"
"vadd.s32 q10, q10, q12 \n"
"vadd.s32 q11, q11, q13 \n"
"vdup.s8 d10, d1[2] \n"
"vdup.s8 d11, d1[3] \n"
"vdup.s8 d12, d1[4] \n"
"vdup.s8 d13, d1[5] \n"
"vdup.s8 d14, d1[6] \n"
"vmull.s8 q8, d4, d10 \n"
"vmull.s8 q9, d6, d11 \n"
"vmlal.s8 q8, d7, d12 \n"
"vmlal.s8 q9, d8, d13 \n"
"vaddl.s16 q12, d16, d18 \n"
"vaddl.s16 q13, d17, d19 \n"
"vmull.s8 q8, d9, d14 \n"
"vaddw.s16 q12, q12, d16 \n"
"vaddw.s16 q13, q13, d17 \n"
"vadd.s32 q14, q14, q12 \n"
"vadd.s32 q15, q15, q13 \n"
"vld1.8 {d4-d5}, [%[r3]] \n" // r3
"add %[r3], #8 \n"
"vext.8 d6, d4, d5, #1 \n"
"vext.8 d7, d4, d5, #2 \n"
"vext.8 d8, d4, d5, #3 \n"
"vext.8 d9, d4, d5, #4 \n"
"vmull.s8 q8, d4, d10 \n" // next row
"vmull.s8 q9, d6, d11 \n"
"vmlal.s8 q8, d7, d12 \n"
"vmlal.s8 q9, d8, d13 \n"
"vaddl.s16 q12, d16, d18 \n"
"vaddl.s16 q13, d17, d19 \n"
"vmull.s8 q8, d9, d14 \n"
"vaddw.s16 q12, q12, d16 \n"
"vaddw.s16 q13, q13, d17 \n"
"vadd.s32 q10, q10, q12 \n"
"vadd.s32 q11, q11, q13 \n"
"vdup.s8 d10, d1[7] \n"
"vdup.s8 d11, d2[0] \n"
"vdup.s8 d12, d2[1] \n"
"vdup.s8 d13, d2[2] \n"
"vdup.s8 d14, d2[3] \n"
"vmull.s8 q8, d4, d10 \n"
"vmull.s8 q9, d6, d11 \n"
"vmlal.s8 q8, d7, d12 \n"
"vmlal.s8 q9, d8, d13 \n"
"vaddl.s16 q12, d16, d18 \n"
"vaddl.s16 q13, d17, d19 \n"
"vmull.s8 q8, d9, d14 \n"
"vaddw.s16 q12, q12, d16 \n"
"vaddw.s16 q13, q13, d17 \n"
"vadd.s32 q14, q14, q12 \n"
"vadd.s32 q15, q15, q13 \n"
"vld1.8 {d4-d5}, [%[r4]] \n" // r4
"add %[r4], #8 \n"
"vext.8 d6, d4, d5, #1 \n"
"vext.8 d7, d4, d5, #2 \n"
"vext.8 d8, d4, d5, #3 \n"
"vext.8 d9, d4, d5, #4 \n"
"vmull.s8 q8, d4, d10 \n" // next row
"vmull.s8 q9, d6, d11 \n"
"vmlal.s8 q8, d7, d12 \n"
"vmlal.s8 q9, d8, d13 \n"
"vaddl.s16 q12, d16, d18 \n"
"vaddl.s16 q13, d17, d19 \n"
"vmull.s8 q8, d9, d14 \n"
"vaddw.s16 q12, q12, d16 \n"
"vaddw.s16 q13, q13, d17 \n"
"vadd.s32 q10, q10, q12 \n"
"vadd.s32 q11, q11, q13 \n"
"vdup.s8 d10, d2[4] \n"
"vdup.s8 d11, d2[5] \n"
"vdup.s8 d12, d2[6] \n"
"vdup.s8 d13, d2[7] \n"
"vdup.s8 d14, d3[0] \n"
"vmull.s8 q8, d4, d10 \n"
"vmull.s8 q9, d6, d11 \n"
"vmlal.s8 q8, d7, d12 \n"
"vmlal.s8 q9, d8, d13 \n"
"vaddl.s16 q12, d16, d18 \n"
"vaddl.s16 q13, d17, d19 \n"
"vmull.s8 q8, d9, d14 \n"
"vaddw.s16 q12, q12, d16 \n"
"vaddw.s16 q13, q13, d17 \n"
"vadd.s32 q14, q14, q12 \n"
"vadd.s32 q15, q15, q13 \n"
"vld1.32 {d24-d27}, [%[output0]] \n"
"vadd.s32 q12, q12, q14 \n"
"vadd.s32 q13, q13, q15 \n"
"vst1.32 {d24-d27}, [%[output0]]! \n"
"vld1.8 {d4-d5}, [%[r5]] \n" // row 5
"add %[r5], #8 \n"
"vext.8 d6, d4, d5, #1 \n"
"vext.8 d7, d4, d5, #2 \n"
"vext.8 d8, d4, d5, #3 \n"
"vext.8 d9, d4, d5, #4 \n"
"vmull.s8 q8, d4, d10 \n"
"vmull.s8 q9, d6, d11 \n"
"vmlal.s8 q8, d7, d12 \n"
"vmlal.s8 q9, d8, d13 \n"
"vaddl.s16 q12, d16, d18 \n"
"vaddl.s16 q13, d17, d19 \n"
"vmull.s8 q8, d9, d14 \n"
"vaddw.s16 q12, q12, d16 \n"
"vaddw.s16 q13, q13, d17 \n"
"vadd.s32 q10, q10, q12 \n"
"vadd.s32 q11, q11, q13 \n"
"vld1.32 {d24-d27}, [%[output1]] \n"
"vadd.s32 q12, q12, q10 \n"
"vadd.s32 q13, q13, q11 \n"
"vst1.32 {d24-d27}, [%[output1]]! \n"
"subs %[ow], #1 \n"
"bne 0b \n"
: [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), [r3] "+r"(r3),
[r4] "+r"(r4), [r5] "+r"(r5), [ow] "+r"(ow),
[output0] "+r"(output0), [output1] "+r"(output1)
:
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
}
if (remain > 0) {
asm volatile("vld1.8 {d0-d3}, [%[kernel]] \n"
: [kernel] "+r"(kernel)
:
: "cc", "memory", "q0", "q1");
asm volatile(
"0: \n"
"vld1.8 d4, [%[r0]] \n"
"vld1.8 d5, [%[r1]] \n"
"vld1.8 d6, [%[r2]] \n"
"vld1.8 d7, [%[r3]] \n"
"vld1.8 d8, [%[r4]] \n"
"vld1.8 d9, [%[r5]] \n"
"add %[r0], #1 \n"
"add %[r1], #1 \n"
"add %[r2], #1 \n"
"add %[r3], #1 \n"
"add %[r4], #1 \n"
"add %[r5], #1 \n"
"vext.8 d10, d0, d1, #5 \n"
"vext.8 d11, d1, d2, #2 \n"
"vext.8 d12, d1, d2, #7 \n"
"vext.8 d13, d2, d3, #4 \n"
"vmull.s8 q7, d4, d0 \n"
"vmull.s8 q8, d5, d10 \n"
"vmull.s8 q9, d6, d11 \n"
"vmlal.s8 q8, d7, d12 \n"
"vmlal.s8 q9, d8, d13 \n"
"vaddl.s16 q10, d14, d16 \n"
"vaddw.s16 q10, q10, d18 \n"
"vadd.s32 d4, d20, d21 \n"
"vaddl.s16 q10, d15, d17 \n"
"vaddw.s16 q10, q10, d19 \n"
"vdup.s32 d14, d4[0] \n"
"vdup.s32 d15, d4[1] \n"
"vadd.s32 d15, d15, d14 \n"
"vdup.s32 d14, d20[0] \n"
"vadd.s32 d15, d15, d14 \n"
"ldr r6, [%[output0]] \n"
"vdup.s32 d14, r6 \n"
"vadd.s32 d15, d15, d14 \n"
"vst1.32 d15[0], [%[output0]]! \n"
"vmull.s8 q7, d5, d0 \n"
"vmull.s8 q8, d6, d10 \n"
"vmull.s8 q9, d7, d11 \n"
"vmlal.s8 q8, d8, d12 \n"
"vmlal.s8 q9, d9, d13 \n"
"vaddl.s16 q10, d14, d16 \n"
"vaddw.s16 q10, q10, d18 \n"
"vadd.s32 d4, d20, d21 \n"
"vaddl.s16 q10, d15, d17 \n"
"vaddw.s16 q10, q10, d19 \n"
"vdup.s32 d14, d4[0] \n"
"vdup.s32 d15, d4[1] \n"
"vadd.s32 d15, d15, d14 \n"
"vdup.s32 d14, d20[0] \n"
"vadd.s32 d15, d15, d14 \n"
"ldr r6, [%[output1]] \n"
"vdup.s32 d14, r6 \n"
"vadd.s32 d15, d15, d14 \n"
"vst1.32 d15[0], [%[output1]]! \n"
"subs %[remain], #1 \n"
"bne 0b \n"
: [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), [r3] "+r"(r3),
[r4] "+r"(r4), [r5] "+r"(r5), [remain] "+r"(remain),
[output0] "+r"(output0), [output1] "+r"(output1)
:
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "r6");
}
output0 += output_w;
output1 += output_w;
}
// remain output height
for (; oh < output_h; ++oh) {
const int8_t* r0 = in_data + ic * image_size + oh * input_w;
const int8_t* r1 = r0 + input_w;
const int8_t* r2 = r1 + input_w;
const int8_t* r3 = r2 + input_w;
const int8_t* r4 = r3 + input_w;
int ow = output_w >> 3;
int remain = output_w & 0x7;
if (ow > 0) {
asm volatile("vld1.8 {d0-d3}, [%[kernel]] \n"
: [kernel] "+r"(kernel)
:
: "cc", "memory", "q0", "q1");
asm volatile(
"0: \n"
"vld1.8 {d4-d5}, [%[r0]] \n" // r0
"add %[r0], #8 \n"
"vext.8 d6, d4, d5, #1 \n"
"vext.8 d7, d4, d5, #2 \n"
"vext.8 d8, d4, d5, #3 \n"
"vext.8 d9, d4, d5, #4 \n"
"vdup.s8 d10, d0[0] \n"
"vdup.s8 d11, d0[1] \n"
"vdup.s8 d12, d0[2] \n"
"vdup.s8 d13, d0[3] \n"
"vdup.s8 d14, d0[4] \n"
"vmull.s8 q8, d4, d10 \n"
"vmull.s8 q9, d6, d11 \n"
"vmlal.s8 q8, d7, d12 \n"
"vmlal.s8 q9, d8, d13 \n"
"vaddl.s16 q14, d16, d18 \n"
"vaddl.s16 q15, d17, d19 \n"
"vmull.s8 q8, d9, d14 \n"
"vaddw.s16 q14, q14, d16 \n"
"vaddw.s16 q15, q15, d17 \n"
"vld1.8 {d4-d5}, [%[r1]] \n" // r1
"add %[r1], #8 \n"
"vext.8 d6, d4, d5, #1 \n"
"vext.8 d7, d4, d5, #2 \n"
"vext.8 d8, d4, d5, #3 \n"
"vext.8 d9, d4, d5, #4 \n"
"vdup.s8 d10, d0[5] \n"
"vdup.s8 d11, d0[6] \n"
"vdup.s8 d12, d0[7] \n"
"vdup.s8 d13, d1[0] \n"
"vdup.s8 d14, d1[1] \n"
"vmull.s8 q8, d4, d10 \n"
"vmull.s8 q9, d6, d11 \n"
"vmlal.s8 q8, d7, d12 \n"
"vmlal.s8 q9, d8, d13 \n"
"vaddl.s16 q12, d16, d18 \n"
"vaddl.s16 q13, d17, d19 \n"
"vmull.s8 q8, d9, d14 \n"
"vaddw.s16 q12, q12, d16 \n"
"vaddw.s16 q13, q13, d17 \n"
"vadd.s32 q14, q14, q12 \n"
"vadd.s32 q15, q15, q13 \n"
"vld1.8 {d4-d5}, [%[r2]] \n" // r2
"add %[r2], #8 \n"
"vext.8 d6, d4, d5, #1 \n"
"vext.8 d7, d4, d5, #2 \n"
"vext.8 d8, d4, d5, #3 \n"
"vext.8 d9, d4, d5, #4 \n"
"vdup.s8 d10, d1[2] \n"
"vdup.s8 d11, d1[3] \n"
"vdup.s8 d12, d1[4] \n"
"vdup.s8 d13, d1[5] \n"
"vdup.s8 d14, d1[6] \n"
"vmull.s8 q8, d4, d10 \n"
"vmull.s8 q9, d6, d11 \n"
"vmlal.s8 q8, d7, d12 \n"
"vmlal.s8 q9, d8, d13 \n"
"vaddl.s16 q12, d16, d18 \n"
"vaddl.s16 q13, d17, d19 \n"
"vmull.s8 q8, d9, d14 \n"
"vaddw.s16 q12, q12, d16 \n"
"vaddw.s16 q13, q13, d17 \n"
"vadd.s32 q14, q14, q12 \n"
"vadd.s32 q15, q15, q13 \n"
"vld1.8 {d4-d5}, [%[r3]] \n" // r3
"add %[r3], #8 \n"
"vext.8 d6, d4, d5, #1 \n"
"vext.8 d7, d4, d5, #2 \n"
"vext.8 d8, d4, d5, #3 \n"
"vext.8 d9, d4, d5, #4 \n"
"vdup.s8 d10, d1[7] \n"
"vdup.s8 d11, d2[0] \n"
"vdup.s8 d12, d2[1] \n"
"vdup.s8 d13, d2[2] \n"
"vdup.s8 d14, d2[3] \n"
"vmull.s8 q8, d4, d10 \n"
"vmull.s8 q9, d6, d11 \n"
"vmlal.s8 q8, d7, d12 \n"
"vmlal.s8 q9, d8, d13 \n"
"vaddl.s16 q12, d16, d18 \n"
"vaddl.s16 q13, d17, d19 \n"
"vmull.s8 q8, d9, d14 \n"
"vaddw.s16 q12, q12, d16 \n"
"vaddw.s16 q13, q13, d17 \n"
"vadd.s32 q14, q14, q12 \n"
"vadd.s32 q15, q15, q13 \n"
"vld1.8 {d4-d5}, [%[r4]] \n" // r4
"add %[r4], #8 \n"
"vext.8 d6, d4, d5, #1 \n"
"vext.8 d7, d4, d5, #2 \n"
"vext.8 d8, d4, d5, #3 \n"
"vext.8 d9, d4, d5, #4 \n"
"vdup.s8 d10, d2[4] \n"
"vdup.s8 d11, d2[5] \n"
"vdup.s8 d12, d2[6] \n"
"vdup.s8 d13, d2[7] \n"
"vdup.s8 d14, d3[0] \n"
"vmull.s8 q8, d4, d10 \n"
"vmull.s8 q9, d6, d11 \n"
"vmlal.s8 q8, d7, d12 \n"
"vmlal.s8 q9, d8, d13 \n"
"vaddl.s16 q12, d16, d18 \n"
"vaddl.s16 q13, d17, d19 \n"
"vmull.s8 q8, d9, d14 \n"
"vaddw.s16 q12, q12, d16 \n"
"vaddw.s16 q13, q13, d17 \n"
"vadd.s32 q14, q14, q12 \n"
"vadd.s32 q15, q15, q13 \n"
"vld1.32 {d24-d27}, [%[output0]] \n"
"vadd.s32 q12, q12, q14 \n"
"vadd.s32 q13, q13, q15 \n"
"vst1.32 {d24-d27}, [%[output0]]! \n"
"subs %[ow], #1 \n"
"bne 0b \n"
: [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), [r3] "+r"(r3),
[r4] "+r"(r4), [ow] "+r"(ow), [output0] "+r"(output0)
:
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
}
if (remain > 0) {
asm volatile("vld1.8 {d0-d3}, [%[kernel]] \n"
: [kernel] "+r"(kernel)
:
: "cc", "memory", "q0", "q1");
asm volatile(
"0: \n"
"vld1.8 d4, [%[r0]] \n"
"vld1.8 d5, [%[r1]] \n"
"vld1.8 d6, [%[r2]] \n"
"vld1.8 d7, [%[r3]] \n"
"vld1.8 d8, [%[r4]] \n"
"add %[r0], #1 \n"
"add %[r1], #1 \n"
"add %[r2], #1 \n"
"add %[r3], #1 \n"
"add %[r4], #1 \n"
"vext.8 d10, d0, d1, #5 \n"
"vext.8 d11, d1, d2, #2 \n"
"vext.8 d12, d1, d2, #7 \n"
"vext.8 d13, d2, d3, #4 \n"
"vmull.s8 q7, d4, d0 \n"
"vmull.s8 q8, d5, d10 \n"
"vmull.s8 q9, d6, d11 \n"
"vmlal.s8 q8, d7, d12 \n"
"vmlal.s8 q9, d8, d13 \n"
"vaddl.s16 q10, d14, d16 \n"
"vaddw.s16 q10, q10, d18 \n"
"vadd.s32 d4, d20, d21 \n"
"vaddl.s16 q10, d15, d17 \n"
"vaddw.s16 q10, q10, d19 \n"
"vdup.s32 d14, d4[0] \n"
"vdup.s32 d15, d4[1] \n"
"vadd.s32 d15, d15, d14 \n"
"vdup.s32 d14, d20[0] \n"
"vadd.s32 d15, d15, d14 \n"
"ldr r6, [%[output0]] \n"
"vdup.s32 d14, r6 \n"
"vadd.s32 d15, d15, d14 \n"
"vst1.32 d15[0], [%[output0]]! \n"
"subs %[remain], #1 \n"
"bne 0b \n"
: [r0] "+r"(r0), [r1] "+r"(r1), [r2] "+r"(r2), [r3] "+r"(r3),
[r4] "+r"(r4), [remain] "+r"(remain), [output0] "+r"(output0)
:
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "r6");
}
}
}
}
#endif
#else
// TODO(hjchen2)
#endif
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CONV_OP
#pragma once
#include "framework/tensor.h"
namespace paddle_mobile {
namespace operators {
void conv3x3s1_int8(const framework::Tensor& input,
const framework::Tensor& weight, framework::Tensor* output);
void conv3x3s1_int8_4c(const framework::Tensor& input,
const framework::Tensor& weight,
framework::Tensor* output);
void conv5x5s1_int8(const framework::Tensor& input,
const framework::Tensor& weight, framework::Tensor* output);
} // namespace operators
} // namespace paddle_mobile
#endif
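A hedged usage sketch for these declarations, assuming NCHW input with batch size 1, OIHW int8 weights, and an int32 output allocated to the valid-convolution shape (framework::make_ddim is assumed to be available, as in the rest of the framework):

framework::Tensor input, weight, output;
input.mutable_data<int8_t>(framework::make_ddim({1, 32, 56, 56}));
weight.mutable_data<int8_t>(framework::make_ddim({64, 32, 3, 3}));
output.mutable_data<int32_t>(framework::make_ddim({1, 64, 54, 54}));  // 56 - 3 + 1 = 54
conv3x3s1_int8(input, weight, &output);  // int8 inputs, int32 accumulators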
......@@ -28,91 +28,240 @@ namespace math {
* [input_channels, filter_height, filter_width, output_height,
* output_width]
*/
template <class T>
class Im2ColFunctor<ColFormat::kCFO, CPU, T> {
public:
void operator()(const framework::Tensor &im, const std::vector<int> &dilation,
const std::vector<int> &stride,
const std::vector<int> &padding, framework::Tensor *col) {
// PADDLE_ENFORCE(im.dims().size() == 3);
// PADDLE_ENFORCE(col->dims().size() == 5);
template <>
void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
const framework::Tensor &im, const std::vector<int> &dilation,
const std::vector<int> &stride, const std::vector<int> &padding,
framework::Tensor *col) {
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
int filter_height = col->dims()[1];
int filter_width = col->dims()[2];
int col_height = col->dims()[3];
int col_width = col->dims()[4];
int channels_col = im_channels * filter_height * filter_width;
const float *im_data = im.data<float>();
float *col_data = col->data<float>();
#if __ARM_NEON
const int osize = col_height;
const int isize = im_height;
bool pad1 = padding[0] > 0;
bool pad2 =
(pad1 && padding[1] &&
(((isize - 2 * padding[0] + filter_height) % stride[0] == 0) ? 1 : 0));
int fill = isize % 2;
if (stride[0] == 1 && filter_height == 3 && pad1 && pad2 &&
dilation[0] == 1 && im_height > 2) {
for (int c = 0; c < im_channels; ++c) {
int oosize = osize * osize;
int nk4 = osize / 4;
int mk4 = osize % 4;
float *col0 = col_data + 0 * oosize + 2 * osize + 2;
float *col1 = col_data + 1 * oosize + 2 * osize + 1;
float *col2 = col_data + 2 * oosize + 2 * osize;
float *col3 = col_data + 3 * oosize + osize + 2;
float *col4 = col_data + 4 * oosize + osize + 1;
float *col5 = col_data + 5 * oosize + osize;
float *col6 = col_data + 6 * oosize + 2;
float *col7 = col_data + 7 * oosize + 1;
float *col8 = col_data + 8 * oosize;
float32x4_t im1;
const float *im_tmp_data = im_data + osize + 1;
int rrsize = oosize - osize - 1;
int nr4 = rrsize / 4;
int mr4 = rrsize % 4;
for (int i = 0; i < nr4; ++i) {
im1 = vld1q_f32(im_tmp_data);
vst1q_f32(col0, im1);
vst1q_f32(col1, im1);
vst1q_f32(col2, im1);
vst1q_f32(col3, im1);
vst1q_f32(col4, im1);
vst1q_f32(col5, im1);
vst1q_f32(col6, im1);
vst1q_f32(col7, im1);
vst1q_f32(col8, im1);
col0 += 4;
col1 += 4;
col2 += 4;
col3 += 4;
col4 += 4;
col5 += 4;
col6 += 4;
col7 += 4;
col8 += 4;
im_tmp_data += 4;
}
for (int i = 0; i < mr4; ++i) {
*col0 = *im_tmp_data;
*col1 = *im_tmp_data;
*col2 = *im_tmp_data;
*col3 = *im_tmp_data;
*col4 = *im_tmp_data;
*col5 = *im_tmp_data;
*col6 = *im_tmp_data;
*col7 = *im_tmp_data;
*col8 = *im_tmp_data;
col0++;
col1++;
col2++;
col3++;
col4++;
col5++;
col6++;
col7++;
col8++;
im_tmp_data++;
}
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
int filter_height = col->dims()[1];
int filter_width = col->dims()[2];
int col_height = col->dims()[3];
int col_width = col->dims()[4];
// PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2]
// -
// ((dilation[0] * (filter_height - 1)
// + 1))) /
// stride[0] +
// 1,
// col_height,
// "Output_height and
// padding(padding_up, padding_down)
// are " "inconsistent.");
// PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3]
// -
// ((dilation[1] * (filter_width - 1)
// + 1))) /
// stride[1] +
// 1,
// col_width,
// "Output_height and
// padding(padding_up, padding_down)
// are " "inconsistent.");
im_tmp_data = im_data + 1;
col0 = col_data + 0 * oosize + osize + 2;
col1 = col_data + 1 * oosize + osize + 1;
col2 = col_data + 2 * oosize + osize;
col3 = col_data + 3 * oosize + 2;
col4 = col_data + 4 * oosize + 1;
col5 = col_data + 5 * oosize;
for (int i = 0; i < nk4; i++) {
im1 = vld1q_f32(im_tmp_data);
vst1q_f32(col0, im1);
vst1q_f32(col1, im1);
vst1q_f32(col2, im1);
vst1q_f32(col3, im1);
vst1q_f32(col4, im1);
vst1q_f32(col5, im1);
col0 += 4;
col1 += 4;
col2 += 4;
col3 += 4;
col4 += 4;
col5 += 4;
im_tmp_data += 4;
}
int channels_col = im_channels * filter_height * filter_width;
const T *im_data = im.data<T>();
T *col_data = col->data<T>();
#if __ARM_NEON
const int osize = col_height;
const int isize = im_height;
bool pad1 = padding[0] > 0;
bool pad2 =
(pad1 && padding[1] &&
(((isize - 2 * padding[0] + filter_height) % stride[0] == 0) ? 1 : 0));
int fill = isize % 2;
if (stride[0] == 1 && filter_height == 3 && pad1 && pad2 &&
dilation[0] == 1 && im_height > 2) {
for (int c = 0; c < im_channels; ++c) {
int oosize = osize * osize;
int nk4 = osize / 4;
int mk4 = osize % 4;
float *col0 = col_data + 0 * oosize + 2 * osize + 2;
float *col1 = col_data + 1 * oosize + 2 * osize + 1;
float *col2 = col_data + 2 * oosize + 2 * osize;
float *col3 = col_data + 3 * oosize + osize + 2;
float *col4 = col_data + 4 * oosize + osize + 1;
float *col5 = col_data + 5 * oosize + osize;
float *col6 = col_data + 6 * oosize + 2;
float *col7 = col_data + 7 * oosize + 1;
float *col8 = col_data + 8 * oosize;
float32x4_t im1;
const float *im_tmp_data = im_data + osize + 1;
int rrsize = oosize - osize - 1;
int nr4 = rrsize / 4;
int mr4 = rrsize % 4;
for (int i = 0; i < nr4; ++i) {
im1 = vld1q_f32(im_tmp_data);
vst1q_f32(col0, im1);
vst1q_f32(col1, im1);
vst1q_f32(col2, im1);
vst1q_f32(col3, im1);
vst1q_f32(col4, im1);
vst1q_f32(col5, im1);
vst1q_f32(col6, im1);
vst1q_f32(col7, im1);
vst1q_f32(col8, im1);
for (int i = 0; i < mk4; i++) {
*col0 = *im_tmp_data;
*col1 = *im_tmp_data;
*col2 = *im_tmp_data;
*col3 = *im_tmp_data;
*col4 = *im_tmp_data;
*col5 = *im_tmp_data;
col0++;
col1++;
col2++;
col3++;
col4++;
col5++;
im_tmp_data++;
}
// fill 0 1 11;
for (int i = 0; i < osize; ++i) {
col_data[0 * oosize + i * osize] = 0.0;
col_data[3 * oosize + i * osize] = 0.0;
col_data[6 * oosize + i * osize] = 0.0;
col_data[2 * oosize + osize - 1 + i * osize] = 0.0;
col_data[5 * oosize + osize - 1 + i * osize] = 0.0;
col_data[8 * oosize + osize - 1 + i * osize] = 0.0;
}
col_data[0 * oosize + osize + 1] = im_data[0];
col_data[3 * oosize + 1] = im_data[0];
col_data[6 * oosize + 1] = im_data[osize];
col_data[1 * oosize + osize] = im_data[0];
col_data[4 * oosize] = im_data[0];
col_data[7 * oosize] = im_data[osize];
float32x4_t zero4;
zero4 = vdupq_n_f32(0.0);
auto col_z0 = col_data;
auto col_z1 = col_data + oosize;
auto col_z2 = col_data + 2 * oosize;
auto col_z6 = col_data + 6 * oosize + osize * (osize - 1);
auto col_z7 = col_data + 7 * oosize + osize * (osize - 1);
auto col_z8 = col_data + 8 * oosize + osize * (osize - 1);
for (int i = 0; i < nk4; ++i) {
vst1q_f32(col_z0, zero4);
vst1q_f32(col_z1, zero4);
vst1q_f32(col_z2, zero4);
vst1q_f32(col_z6, zero4);
vst1q_f32(col_z7, zero4);
vst1q_f32(col_z8, zero4);
col_z0 += 4;
col_z1 += 4;
col_z2 += 4;
col_z6 += 4;
col_z7 += 4;
col_z8 += 4;
}
for (int i = 0; i < mk4; ++i) {
col_z0[i] = 0.0;
col_z1[i] = 0.0;
col_z2[i] = 0.0;
col_z6[i] = 0.0;
col_z7[i] = 0.0;
col_z8[i] = 0.0;
}
col_data += 9 * oosize;
im_data += isize * isize;
}
} else if (stride[0] == 2 && filter_height == 3 && pad1 && dilation[0] == 1 &&
im_height > 2) {
for (int c = 0; c < im_channels; ++c) {
int oosize = osize * osize;
int nk4 = osize / 4;
int mk4 = osize % 4;
// 3 2 3 1 0 1 3 2 3
float *col0 = col_data + 0 * oosize + osize + 1;
float *col1 = col_data + 1 * oosize + osize;
float *col2 = col_data + 2 * oosize + osize;
float *col3 = col_data + 3 * oosize + 1;
float *col4 = col_data + 4 * oosize;
float *col5 = col_data + 5 * oosize;
float *col6 = col_data + 6 * oosize + 1;
float *col7 = col_data + 7 * oosize;
float *col8 = col_data + 8 * oosize;
float32x4x2_t im01;
float32x4x2_t im23;
const float *im_tmp_data0 = im_data;
const float *im_tmp_data2 = im_data + isize;
for (int j = 0; j < osize; ++j) {
for (int i = 0; i < nk4; ++i) {
im01 = vld2q_f32(im_tmp_data0);
im23 = vld2q_f32(im_tmp_data2);
vst1q_f32(col0, im23.val[1]);
vst1q_f32(col1, im23.val[0]);
vst1q_f32(col2, im23.val[1]);
vst1q_f32(col3, im01.val[1]);
vst1q_f32(col4, im01.val[0]);
vst1q_f32(col5, im01.val[1]);
vst1q_f32(col6, im23.val[1]);
vst1q_f32(col7, im23.val[0]);
vst1q_f32(col8, im23.val[1]);
col0 += 4;
col1 += 4;
......@@ -124,18 +273,21 @@ class Im2ColFunctor<ColFormat::kCFO, CPU, T> {
col7 += 4;
col8 += 4;
im_tmp_data += 4;
im_tmp_data0 += 8;
im_tmp_data2 += 8;
}
for (int i = 0; i < mr4; ++i) {
*col0 = *im_tmp_data;
*col1 = *im_tmp_data;
*col2 = *im_tmp_data;
*col3 = *im_tmp_data;
*col4 = *im_tmp_data;
*col5 = *im_tmp_data;
*col6 = *im_tmp_data;
*col7 = *im_tmp_data;
*col8 = *im_tmp_data;
const float *im_tmp_data1 = im_tmp_data0 + 1;
const float *im_tmp_data3 = im_tmp_data2 + 1;
for (int i = 0; i < mk4; ++i) {
*col0 = *im_tmp_data3;
*col1 = *im_tmp_data2;
*col2 = *im_tmp_data3;
*col3 = *im_tmp_data1;
*col4 = *im_tmp_data0;
*col5 = *im_tmp_data1;
*col6 = *im_tmp_data3;
*col7 = *im_tmp_data2;
*col8 = *im_tmp_data3;
col0++;
col1++;
......@@ -146,271 +298,215 @@ class Im2ColFunctor<ColFormat::kCFO, CPU, T> {
col6++;
col7++;
col8++;
im_tmp_data++;
im_tmp_data0 += 2;
im_tmp_data1 += 2;
im_tmp_data2 += 2;
im_tmp_data3 += 2;
}
im_tmp_data = im_data + 1;
col0 = col_data + 0 * oosize + osize + 2;
col1 = col_data + 1 * oosize + osize + 1;
col2 = col_data + 2 * oosize + osize;
col3 = col_data + 3 * oosize + 2;
col4 = col_data + 4 * oosize + 1;
col5 = col_data + 5 * oosize;
for (int i = 0; i < nk4; i++) {
im1 = vld1q_f32(im_tmp_data);
vst1q_f32(col0, im1);
vst1q_f32(col1, im1);
vst1q_f32(col2, im1);
vst1q_f32(col3, im1);
vst1q_f32(col4, im1);
vst1q_f32(col5, im1);
col0 += 4;
col1 += 4;
col2 += 4;
col3 += 4;
col4 += 4;
col5 += 4;
im_tmp_data += 4;
}
for (int i = 0; i < mk4; i++) {
*col0 = *im_tmp_data;
*col1 = *im_tmp_data;
*col2 = *im_tmp_data;
*col3 = *im_tmp_data;
*col4 = *im_tmp_data;
*col5 = *im_tmp_data;
col0++;
col1++;
col2++;
col3++;
col4++;
col5++;
im_tmp_data++;
}
// fill 0 1 11;
for (int i = 0; i < osize; ++i) {
col_data[0 * oosize + i * osize] = 0.0;
col_data[3 * oosize + i * osize] = 0.0;
col_data[6 * oosize + i * osize] = 0.0;
im_tmp_data0 += (isize - fill);
im_tmp_data2 += (isize - fill);
}
for (int i = 0; i < osize; ++i) {
col_data[0 * oosize + i * osize] = 0.0;
col_data[3 * oosize + i * osize] = 0.0;
col_data[6 * oosize + i * osize] = 0.0;
if (pad2) {
col_data[2 * oosize + osize - 1 + i * osize] = 0.0;
col_data[5 * oosize + osize - 1 + i * osize] = 0.0;
col_data[8 * oosize + osize - 1 + i * osize] = 0.0;
}
col_data[0 * oosize + osize + 1] = im_data[0];
col_data[3 * oosize + 1] = im_data[0];
col_data[6 * oosize + 1] = im_data[osize];
col_data[1 * oosize + osize] = im_data[0];
col_data[4 * oosize] = im_data[0];
col_data[7 * oosize] = im_data[osize];
float32x4_t zero4;
zero4 = vdupq_n_f32(0.0);
auto col_z0 = col_data;
auto col_z1 = col_data + oosize;
auto col_z2 = col_data + 2 * oosize;
auto col_z6 = col_data + 6 * oosize + osize * (osize - 1);
auto col_z7 = col_data + 7 * oosize + osize * (osize - 1);
auto col_z8 = col_data + 8 * oosize + osize * (osize - 1);
for (int i = 0; i < nk4; ++i) {
vst1q_f32(col_z0, zero4);
vst1q_f32(col_z1, zero4);
vst1q_f32(col_z2, zero4);
}
float32x4_t zero4;
zero4 = vdupq_n_f32(0.0);
auto col_z0 = col_data;
auto col_z1 = col_data + oosize;
auto col_z2 = col_data + 2 * oosize;
auto col_z6 = col_data + 6 * oosize + osize * (osize - 1);
auto col_z7 = col_data + 7 * oosize + osize * (osize - 1);
auto col_z8 = col_data + 8 * oosize + osize * (osize - 1);
for (int i = 0; i < nk4; ++i) {
vst1q_f32(col_z0, zero4);
vst1q_f32(col_z1, zero4);
vst1q_f32(col_z2, zero4);
if (pad2) {
vst1q_f32(col_z6, zero4);
vst1q_f32(col_z7, zero4);
vst1q_f32(col_z8, zero4);
col_z0 += 4;
col_z1 += 4;
col_z2 += 4;
col_z6 += 4;
col_z7 += 4;
col_z8 += 4;
}
col_z0 += 4;
col_z1 += 4;
col_z2 += 4;
col_z6 += 4;
col_z7 += 4;
col_z8 += 4;
}
for (int i = 0; i < mk4; ++i) {
col_z0[i] = 0.0;
col_z1[i] = 0.0;
col_z2[i] = 0.0;
for (int i = 0; i < mk4; ++i) {
col_z0[i] = 0.0;
col_z1[i] = 0.0;
col_z2[i] = 0.0;
if (pad2) {
col_z6[i] = 0.0;
col_z7[i] = 0.0;
col_z8[i] = 0.0;
}
col_data += 9 * oosize;
im_data += isize * isize;
}
} else if (stride[0] == 2 && filter_height == 3 && pad1 &&
dilation[0] == 1 && im_height > 2) {
for (int c = 0; c < im_channels; ++c) {
int oosize = osize * osize;
int nk4 = osize / 4;
int mk4 = osize % 4;
// 3 2 3 1 0 1 3 2 3
float *col0 = col_data + 0 * oosize + osize + 1;
float *col1 = col_data + 1 * oosize + osize;
float *col2 = col_data + 2 * oosize + osize;
float *col3 = col_data + 3 * oosize + 1;
float *col4 = col_data + 4 * oosize;
float *col5 = col_data + 5 * oosize;
float *col6 = col_data + 6 * oosize + 1;
float *col7 = col_data + 7 * oosize;
float *col8 = col_data + 8 * oosize;
float32x4x2_t im01;
float32x4x2_t im23;
const float *im_tmp_data0 = im_data;
const float *im_tmp_data2 = im_data + isize;
for (int j = 0; j < osize; ++j) {
for (int i = 0; i < nk4; ++i) {
im01 = vld2q_f32(im_tmp_data0);
im23 = vld2q_f32(im_tmp_data2);
vst1q_f32(col0, im23.val[1]);
vst1q_f32(col1, im23.val[0]);
vst1q_f32(col2, im23.val[1]);
vst1q_f32(col3, im01.val[1]);
vst1q_f32(col4, im01.val[0]);
vst1q_f32(col5, im01.val[1]);
vst1q_f32(col6, im23.val[1]);
vst1q_f32(col7, im23.val[0]);
vst1q_f32(col8, im23.val[1]);
col0 += 4;
col1 += 4;
col2 += 4;
col3 += 4;
col4 += 4;
col5 += 4;
col6 += 4;
col7 += 4;
col8 += 4;
im_tmp_data0 += 8;
im_tmp_data2 += 8;
}
const float *im_tmp_data1 = im_tmp_data0 + 1;
const float *im_tmp_data3 = im_tmp_data2 + 1;
for (int i = 0; i < mk4; ++i) {
*col0 = *im_tmp_data3;
*col1 = *im_tmp_data2;
*col2 = *im_tmp_data3;
*col3 = *im_tmp_data1;
*col4 = *im_tmp_data0;
*col5 = *im_tmp_data1;
*col6 = *im_tmp_data3;
*col7 = *im_tmp_data2;
*col8 = *im_tmp_data3;
col0++;
col1++;
col2++;
col3++;
col4++;
col5++;
col6++;
col7++;
col8++;
im_tmp_data0 += 2;
im_tmp_data1 += 2;
im_tmp_data2 += 2;
im_tmp_data3 += 2;
}
im_tmp_data0 += (isize - fill);
im_tmp_data2 += (isize - fill);
}
for (int i = 0; i < osize; ++i) {
col_data[0 * oosize + i * osize] = 0.0;
col_data[3 * oosize + i * osize] = 0.0;
col_data[6 * oosize + i * osize] = 0.0;
if (pad2) {
col_data[2 * oosize + osize - 1 + i * osize] = 0.0;
col_data[5 * oosize + osize - 1 + i * osize] = 0.0;
col_data[8 * oosize + osize - 1 + i * osize] = 0.0;
}
}
float32x4_t zero4;
zero4 = vdupq_n_f32(0.0);
auto col_z0 = col_data;
auto col_z1 = col_data + oosize;
auto col_z2 = col_data + 2 * oosize;
auto col_z6 = col_data + 6 * oosize + osize * (osize - 1);
auto col_z7 = col_data + 7 * oosize + osize * (osize - 1);
auto col_z8 = col_data + 8 * oosize + osize * (osize - 1);
for (int i = 0; i < nk4; ++i) {
vst1q_f32(col_z0, zero4);
vst1q_f32(col_z1, zero4);
vst1q_f32(col_z2, zero4);
if (pad2) {
vst1q_f32(col_z6, zero4);
vst1q_f32(col_z7, zero4);
vst1q_f32(col_z8, zero4);
}
col_z0 += 4;
col_z1 += 4;
col_z2 += 4;
col_z6 += 4;
col_z7 += 4;
col_z8 += 4;
}
col_data[1 * oosize + osize] = im_data[isize];
for (int i = 1; i < osize; ++i) {
col_data[3 * oosize + i] = im_data[(i - 1) * stride[0] + 1];
}
col_data[4 * oosize] = im_data[0];
col_data[7 * oosize] = im_data[isize];
for (int i = 0; i < mk4; ++i) {
col_z0[i] = 0.0;
col_z1[i] = 0.0;
col_z2[i] = 0.0;
if (pad2) {
col_z6[i] = 0.0;
col_z7[i] = 0.0;
col_z8[i] = 0.0;
}
}
col_data += 9 * oosize;
im_data += isize * isize;
}
} else {
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int c_im = c / (filter_width * filter_height);
for (int h = 0; h < col_height; ++h) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < col_width; ++w) {
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
int col_idx = (c * col_height + h) * col_width + w;
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
col_data[1 * oosize + osize] = im_data[isize];
for (int i = 1; i < osize; ++i) {
col_data[3 * oosize + i] = im_data[(i - 1) * stride[0] + 1];
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
im_col_idx < 0 || im_col_idx >= im_width)
? static_cast<float>(0)
: im_data[im_idx];
}
col_data[4 * oosize] = im_data[0];
col_data[7 * oosize] = im_data[isize];
col_data += 9 * oosize;
im_data += isize * isize;
}
}
}
#else
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int c_im = c / (filter_width * filter_height);
for (int h = 0; h < col_height; ++h) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < col_width; ++w) {
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
int col_idx = (c * col_height + h) * col_width + w;
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
im_col_idx < 0 || im_col_idx >= im_width)
? static_cast<float>(0)
: im_data[im_idx];
}
}
}
#endif
}
void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height,
const int im_width, const int col_height, const int col_width,
const int padding_h, const int padding_w, const int stride_h,
const int stride_w, const int kh, const int kw) {
int h = padding_h - kh;
int w = padding_w - kw;
int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0;
int col_start_width = w > 0 ? (w + stride_w - 1) / stride_w : 0;
int start_height = kh + col_start_height * stride_h - padding_h;
int start_width = kw + col_start_width * stride_w - padding_w;
int end_height = (col_height - col_start_height) * stride_h + start_height;
end_height = end_height > im_height ? im_height : end_height;
int end_width = (col_width - col_start_width) * stride_w + start_width;
end_width = end_width > im_width ? im_width : end_width;
int extract = (end_width - start_width + stride_w - 1) / stride_w;
im_data += start_height * im_width + start_width;
col_data += col_start_height * col_width + col_start_width;
for (int i = start_height; i < end_height; i += stride_h) {
if (stride_w == 1) {
memcpy(col_data, im_data, extract * sizeof(int8_t));
} else if (stride_w == 2) {
int s = 0;
#if __ARM_NEON
for (; s < extract - 15; s += 16) {
int8x16x2_t img = vld2q_s8(im_data + s * 2);
vst1q_s8(col_data + s, img.val[0]);
}
#endif
for (; s < extract; ++s) {
col_data[s] = im_data[s * 2];
}
} else if (stride_w == 3) {
int s = 0;
#if __ARM_NEON
for (; s < extract - 15; s += 16) {
int8x16x3_t img = vld3q_s8(im_data + s * 3);
vst1q_s8(col_data + s, img.val[0]);
}
#endif
for (; s < extract; ++s) {
col_data[s] = im_data[s * 3];
}
} else if (stride_w == 4) {
int s = 0;
#if __ARM_NEON
for (; s < extract - 15; s += 16) {
int8x16x4_t img = vld4q_s8(im_data + s * 4);
vst1q_s8(col_data + s, img.val[0]);
}
#endif
for (; s < extract; ++s) {
col_data[s] = im_data[s * 4];
}
} else {
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int c_im = c / (filter_width * filter_height);
for (int h = 0; h < col_height; ++h) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < col_width; ++w) {
int im_col_idx =
w * stride[1] - padding[1] + w_offset * dilation[1];
int col_idx = (c * col_height + h) * col_width + w;
int im_idx =
(im_row_idx + c_im * im_height) * im_width + im_col_idx;
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
im_col_idx < 0 || im_col_idx >= im_width)
? static_cast<T>(0)
: im_data[im_idx];
}
PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1, 2, 3 and 4.");
}
im_data += im_width * stride_h;
col_data += col_width;
}
}
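A note on the NEON branches in ExtractToImg above: vld2q_s8 loads 32 consecutive bytes and de-interleaves them, so img.val[0] holds the 16 even-indexed bytes, i.e. 16 stride-2 samples per iteration; vld3q_s8 and vld4q_s8 do the same for strides 3 and 4. A scalar sketch of one stride-2 iteration follows, where p and q are hypothetical pointers standing in for im_data + s * 2 and col_data + s:

// Equivalent of: int8x16x2_t img = vld2q_s8(p); vst1q_s8(q, img.val[0]);
for (int lane = 0; lane < 16; ++lane) {
  q[lane] = p[2 * lane];  // keep every second input byte
}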
/*
* im = [input_channels, input_height, input_width]
* col =
* [input_channels, filter_height, filter_width, output_height,
* output_width]
*/
template <>
void Im2ColFunctor<ColFormat::kCFO, CPU, int8_t>::operator()(
const framework::Tensor &im, const std::vector<int> &dilation,
const std::vector<int> &stride, const std::vector<int> &padding,
framework::Tensor *col) {
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
int filter_height = col->dims()[1];
int filter_width = col->dims()[2];
int col_height = col->dims()[3];
int col_width = col->dims()[4];
int channels_col = im_channels * filter_height * filter_width;
const int8_t *im_data = im.data<int8_t>();
int8_t *col_data = col->data<int8_t>();
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
// pad 0
memset(col_data, 0, col->numel() * sizeof(int8_t));
for (int ic = 0; ic < im_channels; ++ic) {
for (int kh = 0; kh < filter_height; ++kh) {
for (int kw = 0; kw < filter_width; ++kw) {
ExtractToImg(im_data, col_data, im_height, im_width, col_height,
col_width, padding[0], padding[1], stride[0], stride[1],
kh, kw);
col_data += col_height * col_width;
}
}
im_data += im_height * im_width;
}
#else
} else {
#endif
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
......@@ -424,14 +520,15 @@ class Im2ColFunctor<ColFormat::kCFO, CPU, T> {
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
im_col_idx < 0 || im_col_idx >= im_width)
? static_cast<T>(0)
? static_cast<int8_t>(0)
: im_data[im_idx];
}
}
}
#endif
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
}
};
#endif
}
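Pulling the kCFO layout together: the naive branch above fills a col buffer of shape [channels * filter_h * filter_w, col_h, col_w], where each "channel" of col corresponds to one (input channel, kernel offset) pair. A compact standalone version of that mapping, with symmetric padding and no dilation (plain vectors, illustration only, not this codebase's Tensor API):

#include <vector>

// im:  [channels, im_h, im_w]
// col: [channels * k_h * k_w, col_h, col_w], where
//   col_h = (im_h + 2*pad - k_h) / stride + 1 and col_w likewise.
std::vector<float> im2col_cfo(const std::vector<float>& im, int channels,
                              int im_h, int im_w, int k_h, int k_w, int stride,
                              int pad) {
  int col_h = (im_h + 2 * pad - k_h) / stride + 1;
  int col_w = (im_w + 2 * pad - k_w) / stride + 1;
  int channels_col = channels * k_h * k_w;
  std::vector<float> col(channels_col * col_h * col_w, 0.f);  // zero = padding
  for (int c = 0; c < channels_col; ++c) {
    int w_offset = c % k_w;
    int h_offset = (c / k_w) % k_h;
    int c_im = c / (k_w * k_h);
    for (int h = 0; h < col_h; ++h) {
      int im_row = h * stride - pad + h_offset;
      for (int w = 0; w < col_w; ++w) {
        int im_col = w * stride - pad + w_offset;
        if (im_row >= 0 && im_row < im_h && im_col >= 0 && im_col < im_w) {
          col[(c * col_h + h) * col_w + w] =
              im[(c_im * im_h + im_row) * im_w + im_col];
        }
      }
    }
  }
  return col;
}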
/*
* im = [input_channels, input_height, input_width]
......@@ -456,27 +553,6 @@ class Col2ImFunctor<ColFormat::kCFO, CPU, T> {
int col_height = col.dims()[3];
int col_width = col.dims()[4];
// PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2]
// -
// ((dilation[0] * (filter_height - 1)
// + 1))) /
// stride[0] +
// 1,
// col_height,
// "Output_height and
// padding(padding_up, padding_down)
// are " "inconsistent.");
// PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3]
// -
// ((dilation[1] * (filter_width - 1)
// + 1))) /
// stride[1] +
// 1,
// col_width,
// "Output_height and
// padding(padding_up, padding_down)
// are " "inconsistent.");
int channels_col = im_channels * filter_height * filter_width;
T *im_data = im->data<T>();
......@@ -503,9 +579,9 @@ class Col2ImFunctor<ColFormat::kCFO, CPU, T> {
};
template class Im2ColFunctor<ColFormat::kCFO, CPU, float>;
// template class Im2ColFunctor<ColFormat::kCFO, CPU, double>;
template class Im2ColFunctor<ColFormat::kCFO, CPU, int8_t>;
template class Col2ImFunctor<ColFormat::kCFO, CPU, float>;
template class Col2ImFunctor<ColFormat::kCFO, CPU, double>;
template class Col2ImFunctor<ColFormat::kCFO, CPU, int8_t>;
/*
* im = [input_channels, input_height, input_width]
......@@ -519,8 +595,6 @@ class Im2ColFunctor<ColFormat::kOCF, CPU, T> {
void operator()(const framework::Tensor &im, const std::vector<int> &dilation,
const std::vector<int> &stride,
const std::vector<int> &padding, framework::Tensor *col) {
// PADDLE_ENFORCE(im.dims().size() == 3);
// PADDLE_ENFORCE(col->dims().size() == 5);
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
......@@ -529,19 +603,6 @@ class Im2ColFunctor<ColFormat::kOCF, CPU, T> {
int col_height = col->dims()[0];
int col_width = col->dims()[1];
// PADDLE_ENFORCE_EQ(
// (im_height + padding[0] + padding[2] -
// filter_height) / stride[0]
// + 1, col_height, "Output_height and
// padding(padding_up,
// padding_down) are " "inconsistent.");
// PADDLE_ENFORCE_EQ(
// (im_width + padding[1] + padding[3] -
// filter_width) / stride[1] +
// 1, col_width, "col_width and padding(padding_left,
// padding_right)
// are " "inconsistent.");
const T *im_data = im.data<T>();
T *col_data = col->data<T>();
......@@ -593,8 +654,6 @@ class Col2ImFunctor<ColFormat::kOCF, CPU, T> {
const std::vector<int> &dilation,
const std::vector<int> &stride,
const std::vector<int> &padding, framework::Tensor *im) {
// PADDLE_ENFORCE(im->dims().size() == 3);
// PADDLE_ENFORCE(col.dims().size() == 5);
int im_channels = im->dims()[0];
int im_height = im->dims()[1];
int im_width = im->dims()[2];
......@@ -603,19 +662,6 @@ class Col2ImFunctor<ColFormat::kOCF, CPU, T> {
int col_height = col.dims()[0];
int col_width = col.dims()[1];
// PADDLE_ENFORCE_EQ(
// (im_height + padding[0] + padding[2] -
// filter_height) / stride[0]
// + 1, col_height, "Output_height and
// padding(padding_up,
// padding_down) are " "inconsistent.");
// PADDLE_ENFORCE_EQ(
// (im_width + padding[1] + padding[3] -
// filter_width) / stride[1] +
// 1, col_width, "col_width and padding(padding_left,
// padding_right)
// are " "inconsistent.");
T *im_data = im->data<T>();
const T *col_data = col.data<T>();
......@@ -655,9 +701,7 @@ class Col2ImFunctor<ColFormat::kOCF, CPU, T> {
};
template class Im2ColFunctor<ColFormat::kOCF, CPU, float>;
template class Im2ColFunctor<ColFormat::kOCF, CPU, double>;
template class Col2ImFunctor<ColFormat::kOCF, CPU, float>;
template class Col2ImFunctor<ColFormat::kOCF, CPU, double>;
} // namespace math
} // namespace operators
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <cmath>
#include <string>
#include "framework/tensor.h"
namespace paddle_mobile {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/math/pad.h"
namespace paddle_mobile {
namespace operators {
namespace math {
template <typename T>
class PadFunctor<CPU, T> {
public:
void operator()(const framework::Tensor &input, const int pad_h,
const int pad_w, framework::Tensor *output) {
const T *in_data = input.data<T>();
T *out_data = output->mutable_data<T>();
const framework::DDim &input_shape = input.dims();
const framework::DDim &output_shape = output->dims();
// fill output with 0
memset(out_data, 0, sizeof(T) * output->numel());
// the caller should make sure the shape of output matches the input
for (int i = 0; i < input_shape[0]; ++i) {
for (int c = 0; c < input_shape[1]; ++c) {
out_data += pad_h * output_shape[3];
for (int h = 0; h < input_shape[2]; ++h) {
memcpy(out_data + pad_w, in_data, sizeof(T) * input_shape[3]);
out_data += output_shape[3];
in_data += input_shape[3];
}
out_data += pad_h * output_shape[3];
}
}
}
};
template class PadFunctor<CPU, float>;
template class PadFunctor<CPU, int8_t>;
} // namespace math
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "framework/tensor.h"
namespace paddle_mobile {
namespace operators {
namespace math {
template <typename DeviceType, typename T>
class PadFunctor {
public:
void operator()(const framework::Tensor &input, const int pad_h,
const int pad_w, framework::Tensor *output);
};
} // namespace math
} // namespace operators
} // namespace paddle_mobile
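A hedged usage sketch for the functor declared above, assuming the framework::Tensor, make_ddim, and mutable_data APIs used elsewhere in this diff; the tensor layout is [N, C, H, W] and the caller is expected to size the output to H + 2*pad_h by W + 2*pad_w. Input values are omitted here (the tests fill them via SetupTensor):

// Pad a 1x1x3x3 float tensor by one pixel on every side, giving 1x1x5x5.
paddle_mobile::framework::Tensor input, output;
input.mutable_data<float>(paddle_mobile::framework::make_ddim({1, 1, 3, 3}));
output.mutable_data<float>(paddle_mobile::framework::make_ddim({1, 1, 5, 5}));
paddle_mobile::operators::math::PadFunctor<paddle_mobile::CPU, float> pad;
pad(input, /*pad_h=*/1, /*pad_w=*/1, &output);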
......@@ -32,9 +32,6 @@ class Vol2ColFunctor<CPU, T> {
void operator()(const Tensor &vol, const std::vector<int> &dilations,
const std::vector<int> &strides,
const std::vector<int> &paddings, Tensor *col) const {
// PADDLE_ENFORCE(vol.dims().size() == 4);
// PADDLE_ENFORCE(col->dims().size() == 7);
int input_channels = vol.dims()[0];
int input_depth = vol.dims()[1];
int input_height = vol.dims()[2];
......@@ -48,32 +45,6 @@ class Vol2ColFunctor<CPU, T> {
int channels_col =
input_channels * filter_depth * filter_height * filter_width;
// PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
// ((dilations[0] * (filter_depth - 1)
// + 1))) /
// strides[0] +
// 1,
// output_depth,
// "input_depth and output_depth are "
// "mismatching.");
// PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
// ((dilations[1] * (filter_height -
// 1) + 1))) /
// strides[1] +
// 1,
// output_height,
// "input_height and output_height are
// "
// "mismatching.");
// PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
// ((dilations[2] * (filter_width - 1)
// + 1))) /
// strides[2] +
// 1,
// output_width,
// "input_width and output_width are "
// "mismatching.");
const T *vol_data = vol.data<T>();
T *col_data = col->data<T>();
......@@ -119,9 +90,6 @@ class Col2VolFunctor<CPU, T> {
void operator()(const Tensor &col, const std::vector<int> &dilations,
const std::vector<int> &strides,
const std::vector<int> &paddings, Tensor *vol) const {
// PADDLE_ENFORCE(vol->dims().size() == 4);
// PADDLE_ENFORCE(col.dims().size() == 7);
int input_channels = vol->dims()[0];
int input_depth = vol->dims()[1];
int input_height = vol->dims()[2];
......@@ -135,31 +103,6 @@ class Col2VolFunctor<CPU, T> {
int channels_col =
input_channels * filter_depth * filter_height * filter_width;
// PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
// ((dilations[0] * (filter_depth - 1)
// + 1))) /
// strides[0] +
// 1,
// output_depth,
// "input_depth and output_depth are "
// "mismatching.");
// PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
// ((dilations[1] * (filter_height -
// 1) + 1))) /
// strides[1] +
// 1,
// output_height,
// "input_height and output_height are
// "
// "mismatching.");
// PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
// ((dilations[2] * (filter_width - 1)
// + 1))) /
// strides[2] +
// 1,
// output_width,
// "input_width and output_width are "
// "mismatching.");
T *vol_data = vol->data<T>();
const T *col_data = col.data<T>();
......@@ -195,9 +138,9 @@ class Col2VolFunctor<CPU, T> {
};
template class Vol2ColFunctor<CPU, float>;
template class Vol2ColFunctor<CPU, double>;
template class Vol2ColFunctor<CPU, int8_t>;
template class Col2VolFunctor<CPU, float>;
template class Col2VolFunctor<CPU, double>;
template class Col2VolFunctor<CPU, int8_t>;
} // namespace math
} // namespace operators
......
......@@ -2330,6 +2330,7 @@ class ShapeParam : public OpParam {
};
#endif
#ifdef QUANT_OP
template <typename Dtype>
class QuantizeParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
......@@ -2340,14 +2341,12 @@ class QuantizeParam : public OpParam {
const AttributeMap &attrs, const Scope &scope) {
input_ = InputXFrom<GType>(inputs, scope);
out_ = OutFrom<GType>(outputs, scope);
if (HasAttr("is_static", attrs)) {
is_static_ = GetAttr<bool>("is_static", attrs);
}
// online
// scale = max(abs(x))
online_scale_ = GetVarValue<GType>("OutScale", outputs, scope);
// offline
if (HasAttr("static_scale", attrs)) {
is_static_ = true;
static_scale_ = GetAttr<float>("static_scale", attrs);
}
// x = round(scale * x)
......@@ -2369,9 +2368,11 @@ class QuantizeParam : public OpParam {
float static_scale_ = 1.0f;
// round method type
// nearest_zero and nearest_even are valid currently
RoundType round_type_ = ROUND_NEAREST_TO_EVEN;
RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO;
};
#endif
#ifdef DEQUANT_OP
template <typename Dtype>
class DequantizeParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
......@@ -2399,6 +2400,7 @@ class DequantizeParam : public OpParam {
RType *activation_scale_;
float weight_scale_;
};
#endif
} // namespace operators
} // namespace paddle_mobile
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef QUANT_OP
#include "operators/quantize_op.h"
#include <vector>
......@@ -33,3 +35,5 @@ namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(quantize, ops::QuantizeOp);
#endif
#endif
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef QUANT_OP
#pragma once
#include <string>
......@@ -40,3 +42,5 @@ class QuantizeOp : public framework::OperatorWithKernel<
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -26,7 +26,7 @@ void SumOp<Dtype, T>::InferShape() const {
auto inputs = this->param_.Inputs();
const size_t n = inputs.size();
std::vector<DDim> inputs_dims;
std::vector<framework::DDim> inputs_dims;
inputs_dims.reserve(n);
for (int i = 0; i < n; i++) {
inputs_dims.push_back(inputs[i]->dims());
......
......@@ -213,6 +213,10 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-dequantize-op operators/test_dequantize_op.cpp test_helper.h test_include.h)
target_link_libraries(test-dequantize-op paddle-mobile)
# test int8 conv op
ADD_EXECUTABLE(test-int8-conv-op operators/test_int8_conv_op.cpp test_helper.h test_include.h)
target_link_libraries(test-int8-conv-op paddle-mobile)
# gen test log
ADD_EXECUTABLE(test-log common/test_log.cpp)
target_link_libraries(test-log paddle-mobile)
......
......@@ -25,27 +25,31 @@ int main() {
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
#endif
paddle_mobile.SetThreadNum(4);
bool optimize = true;
paddle_mobile.SetThreadNum(1);
bool optimize = false;
auto time1 = time();
if (paddle_mobile.Load(g_googlenet, optimize)) {
auto time2 = time();
std::cout << "load cost :" << time_diff(time1, time2) << "ms" << std::endl;
std::vector<float> input;
std::vector<float> output;
std::vector<int64_t> dims{1, 3, 224, 224};
GetInput<float>(g_test_image_1x3x224x224, &input, dims);
// warm up ten times
for (int i = 0; i < 10; ++i) {
auto vec_result = paddle_mobile.Predict(input, dims);
}
// // warm up ten times
// for (int i = 0; i < 10; ++i) {
// output = paddle_mobile.Predict(input, dims);
// }
auto time3 = time();
for (int i = 0; i < 10; ++i) {
auto vec_result = paddle_mobile.Predict(input, dims);
output = paddle_mobile.Predict(input, dims);
}
auto time4 = time();
std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms"
<< std::endl;
for (int i = 0; i < output.size(); ++i) {
DLOG << "result[" << i << "] = " << output[i];
}
}
return 0;
}
......@@ -59,7 +59,7 @@ int TestDequqntizeOp() {
framework::Tensor output_cmp;
output_cmp.Resize(dim);
float dequant_scale = 1.f / (1.27 * 1.74);
float dequant_scale = 1.27 / 1.74;
dequantize(input, dequant_scale, &output_cmp);
const float* output_cmp_data = output_cmp.data<float>();
for (int i = 0; i < output->numel(); ++i) {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../test_helper.h"
#include "../test_include.h"
#include "operators/conv_op.h"
namespace paddle_mobile {
// Reference convolution for checking results:
// accumulate through explicit loops over input, output, and filters.
template <typename Itype, typename Otype>
void conv2d(const framework::Tensor *input, const framework::Tensor *filter,
const framework::AttributeMap &attrs, framework::Tensor *output) {
framework::AttrReader attr_reader(attrs);
std::vector<int> paddings = attr_reader.Get<std::vector<int>>("paddings");
std::vector<int> strides = attr_reader.Get<std::vector<int>>("strides");
std::vector<int> dilations = attr_reader.Get<std::vector<int>>("dilations");
int groups = attr_reader.Get<int>("groups");
int kernel_h = filter->dims()[2];
int kernel_w = filter->dims()[3];
int pad_h = paddings[0];
int pad_w = paddings[1];
int stride_h = strides[0];
int stride_w = strides[1];
int dilation_h = dilations[0];
int dilation_w = dilations[1];
auto in_shape = input->dims();
auto out_shape = output->dims();
const bool has_depth = 0;
int kernel_d, pad_d, stride_d, dilation_d;
if (has_depth) {
kernel_d = kernel_h;
stride_d = stride_h;
pad_d = pad_h;
dilation_d = dilation_h;
} else {
kernel_d = stride_d = dilation_d = 1;
pad_d = 0;
}
// Groups
int o_g = out_shape[1] / groups;
int k_g = in_shape[1] / groups;
int o_head, k_head;
// Convolution
vector<int> weight_offset(4 + has_depth);
vector<int> in_offset(4 + has_depth);
vector<int> out_offset(4 + has_depth);
auto offset = [](const framework::Tensor *input, const vector<int> &indics) {
framework::DDim shape = input->dims();
size_t count = 0;
for (int i = 0; i < indics.size(); ++i) {
count *= shape[i];
count += indics[i];
}
return count;
};
const Itype *in_data = input->data<Itype>();
const Itype *w_data = filter->data<Itype>();
Otype *out_data = output->mutable_data<Otype>();
memset(out_data, 0, output->numel() * sizeof(Otype));
for (int n = 0; n < out_shape[0]; n++) {
for (int g = 0; g < groups; g++) {
o_head = o_g * g;
k_head = k_g * g;
for (int o = 0; o < o_g; o++) {
for (int k = 0; k < k_g; k++) {
for (int z = 0; z < (has_depth ? out_shape[2] : 1); z++) {
for (int y = 0; y < out_shape[2 + has_depth]; y++) {
for (int x = 0; x < out_shape[3 + has_depth]; x++) {
for (int r = 0; r < kernel_d; r++) {
for (int p = 0; p < kernel_h; p++) {
for (int q = 0; q < kernel_w; q++) {
int in_z = z * stride_d - pad_d + r * dilation_d;
int in_y = y * stride_h - pad_h + p * dilation_h;
int in_x = x * stride_w - pad_w + q * dilation_w;
if (in_z >= 0 && in_z < (has_depth ? in_shape[2] : 1) &&
in_y >= 0 && in_y < in_shape[2 + has_depth] &&
in_x >= 0 && in_x < in_shape[3 + has_depth]) {
weight_offset[0] = o + o_head;
weight_offset[1] = k;
if (has_depth) {
weight_offset[2] = r;
}
weight_offset[2 + has_depth] = p;
weight_offset[3 + has_depth] = q;
in_offset[0] = n;
in_offset[1] = k + k_head;
if (has_depth) {
in_offset[2] = in_z;
}
in_offset[2 + has_depth] = in_y;
in_offset[3 + has_depth] = in_x;
out_offset[0] = n;
out_offset[1] = o + o_head;
if (has_depth) {
out_offset[2] = z;
}
out_offset[2 + has_depth] = y;
out_offset[3 + has_depth] = x;
out_data[offset(output, out_offset)] +=
in_data[offset(input, in_offset)] *
w_data[offset(filter, weight_offset)];
}
}
}
}
}
}
}
}
}
}
}
}
template <typename Itype, typename Otype, int Kernel, int Pad, int Stride>
int TestConvOp() {
int kernel_h = Kernel;
int kernel_w = Kernel;
int pad_h = Pad;
int pad_w = Pad;
int stride_h = Stride;
int stride_w = Stride;
int dilation_h = 1;
int dilation_w = 1;
int batch_size = 1;
int input_c = 3;
int input_h = 100;
int input_w = 100;
int output_c = 10;
framework::DDim input_shape =
framework::make_ddim({batch_size, input_c, input_h, input_w});
framework::DDim filter_shape =
framework::make_ddim({output_c, input_c, kernel_h, kernel_w});
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["Input"] = std::vector<std::string>({"input"});
inputs["Filter"] = std::vector<std::string>({"filter"});
outputs["Output"] = std::vector<std::string>({"output"});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<Itype>(input, input_shape, -20, 20);
auto filter_var = scope.get()->Var("filter");
auto filter = filter_var->template GetMutable<framework::LoDTensor>();
SetupTensor<Itype>(filter, filter_shape, -20, 20);
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
attrs["strides"].Set<vector<int>>(std::vector<int>({stride_h, stride_w}));
attrs["paddings"].Set<vector<int>>(std::vector<int>({pad_h, pad_w}));
attrs["dilations"].Set<vector<int>>(
std::vector<int>({dilation_h, dilation_w}));
attrs["groups"].Set<int>(1);
auto *op = new operators::ConvOp<CPU, float>("conv2d", inputs, outputs, attrs,
scope);
// struct timespec ts_begin, ts_end;
op->InferShape();
// warmup
// op->Run();
// clock_gettime(CLOCK_MONOTONIC, &ts_begin);
// for (int i = 0; i < 10; ++i) {
op->Run();
// }
// clock_gettime(CLOCK_MONOTONIC, &ts_end);
// uint64_t elapsed = (ts_end.tv_sec - ts_begin.tv_sec) * 1e3 +
// (ts_end.tv_nsec - ts_begin.tv_nsec) / 1e6;
// LOG(kLOG_INFO) << "elapsed: " << elapsed / 10.0 << " ms";
int kernel_extent_h = dilation_h * (kernel_h - 1) + 1;
int kernel_extent_w = dilation_w * (kernel_w - 1) + 1;
int output_h = (input_h + 2 * pad_h - kernel_extent_h) / stride_h + 1;
int output_w = (input_w + 2 * pad_w - kernel_extent_w) / stride_w + 1;
auto output_shape = framework::make_ddim(
std::vector<int>({batch_size, output_c, output_h, output_w}));
framework::Tensor output_cmp;
output_cmp.mutable_data<Otype>(output_shape);
conv2d<Itype, Otype>(input, filter, attrs, &output_cmp);
// compare results
auto output = output_var->template Get<framework::LoDTensor>();
const Otype *output_data = output->data<Otype>();
Otype *output_cmp_data = output_cmp.data<Otype>();
for (int i = 0; i < output->numel(); ++i) {
PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i],
"output[%d] = %d, output_cmp[%d] = %d", i,
output_data[i], i, output_cmp_data[i]);
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main() {
// kernel = 7, pad = 0, stride = 2
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=2";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 2>();
// kernel = 7, pad = 1, stride = 2
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=2";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 2>();
// kernel = 7, pad = 3, stride = 2
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=2";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 2>();
// kernel = 7, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 1>();
// kernel = 7, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 1>();
// kernel = 7, pad = 3, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 1>();
// kernel = 7, pad = 5, stride = 3
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=5, stride=3";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 5, 3>();
// kernel = 7, pad = 3, stride = 4
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=4";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 4>();
LOG(paddle_mobile::kLOG_INFO) << "\n";
// kernel = 3, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=0, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 3, 0, 1>();
// kernel = 3, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=0, stride=1";
paddle_mobile::TestConvOp<float, float, 3, 0, 1>();
LOG(paddle_mobile::kLOG_INFO) << "\n";
// kernel = 3, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=1, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 3, 1, 1>();
// kernel = 3, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=1, stride=1";
paddle_mobile::TestConvOp<float, float, 3, 1, 1>();
LOG(paddle_mobile::kLOG_INFO) << "\n";
// kernel = 5, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=0, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 5, 0, 1>();
// kernel = 5, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=0, stride=1";
paddle_mobile::TestConvOp<float, float, 5, 0, 1>();
LOG(paddle_mobile::kLOG_INFO) << "\n";
// kernel = 5, pad = 2, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=2, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 5, 2, 1>();
// kernel = 5, pad = 2, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=2, stride=1";
paddle_mobile::TestConvOp<float, float, 5, 2, 1>();
}
......@@ -18,14 +18,6 @@ limitations under the License. */
namespace paddle_mobile {
// static float g_test_data[50] = {
// -5.55, -5.5, -5.45, -5.0, -4.55, -4.5, -4.45, -4.0, -3.55, -3.5,
// -3.45, -3.01, -2.75, -2.5, -2.501, -2.49, -2.01, -1.75, -1.5, -1.25,
// -1.0, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75, 1.0, 1.25,
// 1.5, 1.75, 2.01, 2.49, 2.501, 2.5, 2.75, 3.01, 3.45, 3.5,
// 3.55, 4.0, 4.45, 4.5, 4.55, 5.0, 5.45, 5.5, 5.55, 6.0,
// };
static float find_abs_max(const Tensor *input) {
float max_abs = 0.f;
const float *x = input->data<const float>();
......@@ -60,6 +52,16 @@ static void quantize_round_to_even(const Tensor *input, const float scale,
}
}
static void quantize_round_to_nearest(const Tensor *input, const float scale,
Tensor *output) {
const float *x = input->data<const float>();
int8_t *y = output->mutable_data<int8_t>();
size_t size = input->numel();
for (size_t i = 0; i < size; ++i) {
y[i] = round(x[i] * scale);
}
}
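For context on the two helpers above: as exercised by this test, the quantize op records scale_out = max|x| and emits y = round(x * 127 / max|x|); the helpers differ only in how ties at .5 are broken, and the test switches to the away-from-zero variant to match the ROUND_NEAREST_AWAY_ZERO default set in op_param.h in this same patch. For example, with max|x| = 6.35 the scale is 127 / 6.35 = 20, so x = 0.25 maps to round(5) = 5 under either mode, while x = 0.125 maps to round(2.5), which is 2 under round-to-even and 3 under round-away-from-zero (std::round). A standalone comparison, added here for illustration only:

#include <cmath>
#include <cstdio>

// Round half to even, for comparison with std::round (half away from zero).
static float round_to_even(float x) {
  if (std::fabs(x - std::trunc(x)) == 0.5f) {
    float f = std::floor(x);
    long long n = static_cast<long long>(f);
    return (n % 2 == 0) ? f : f + 1.0f;  // pick the even neighbour
  }
  return std::round(x);
}

int main() {
  for (float v : {2.5f, 3.5f, -2.5f, 0.125f * 20.f}) {
    std::printf("%6.2f -> even: %3.0f  away-from-zero: %3.0f\n", v,
                round_to_even(v), std::round(v));
  }
  return 0;
}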
int TestQuqntizeOp() {
framework::DDim dim = framework::make_ddim({1, 3, 224, 224});
......@@ -88,15 +90,16 @@ int TestQuqntizeOp() {
auto output_scale = output_scale_var->template Get<framework::LoDTensor>();
const float *output_scale_data = output_scale->data<float>();
float max_abs = find_abs_max(input);
float output_scale_cmp = 127 / max_abs;
float output_scale_cmp = find_abs_max(input);
PADDLE_MOBILE_ENFORCE(output_scale_cmp == output_scale_data[0],
"output_scale = %.6f, output_scale_cmp = %.6f",
output_scale_cmp, output_scale_data[0]);
framework::Tensor output_cmp;
output_cmp.Resize(dim);
quantize_round_to_even(input, output_scale_cmp, &output_cmp);
float scale = 127 / output_scale_cmp;
// quantize_round_to_even(input, scale, &output_cmp);
quantize_round_to_nearest(input, scale, &output_cmp);
int8_t *output_cmp_data = output_cmp.data<int8_t>();
for (int i = 0; i < output->numel(); ++i) {
PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i],
......
......@@ -224,6 +224,8 @@ if(NOT FOUND_MATCH)
set(SHAPE_OP ON)
set(ELEMENTWISEMUL_OP ON)
set(SUM_OP ON)
set(QUANT_OP ON)
set(DEQUANT_OP ON)
endif()
# option(BATCHNORM_OP "" ON)
......@@ -411,3 +413,10 @@ if (SUM_OP)
add_definitions(-DSUM_OP)
endif()
if (QUANT_OP)
add_definitions(-DQUANT_OP)
endif()
if (DEQUANT_OP)
add_definitions(-DDEQUANT_OP)
endif()