未验证 提交 ae6d5d98 编写于 作者: H HappyAngel 提交者: GitHub

[Cherry- pick] fix con_3x3s1_dw compute error and add gcn ops (#4373)

上级 9cb4cb88
......@@ -127,8 +127,10 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
anchor_generator.cc
split_merge_lod_tenosr.cc
reduce_prod.cc
reduce_sum.cc
lstm.cc
clip.cc
pixel_shuffle.cc
scatter.cc
DEPS ${lite_kernel_deps} context tensor)
endif()
......@@ -620,8 +620,10 @@ void conv_depthwise_3x3_fp32(const void* din,
int pad = pad_w;
bool flag_bias = param.bias != nullptr;
bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
bool ch_four = ch_in <= 4 * w_in;
if (stride == 1) {
if (pads_less && (pad_h == pad_w) && (pad < 2)) { // support pad = [0, 1]
if (ch_four && pads_less && (pad_h == pad_w) &&
(pad < 2)) { // support pad = [0, 1]
conv_depthwise_3x3s1_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
......@@ -638,7 +640,6 @@ void conv_depthwise_3x3_fp32(const void* din,
act_param,
ctx);
} else {
#ifdef __aarch64__
conv_3x3s1_depthwise_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
......@@ -653,30 +654,10 @@ void conv_depthwise_3x3_fp32(const void* din,
param,
act_param,
ctx);
#else
#ifdef LITE_WITH_ARM_CLANG
LOG(FATAL) << "fp32 depthwise conv3x3s1px doesnot support in v7-clang, "
"this can run in basic";
#else
conv_3x3s1_depthwise_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
ch_out,
h_out,
w_out,
ch_in,
h_in,
w_in,
reinterpret_cast<const float*>(weights),
bias,
param,
act_param,
ctx);
#endif
#endif
}
} else if (stride == 2) {
if (pads_less && pad_h == pad_w && (pad < 2)) { // support pad = [0, 1]
if (ch_four && pads_less && pad_h == pad_w &&
(pad < 2)) { // support pad = [0, 1]
conv_depthwise_3x3s2_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
......
......@@ -53,7 +53,9 @@
#include "lite/backends/arm/math/reduce_max.h"
#include "lite/backends/arm/math/reduce_mean.h"
#include "lite/backends/arm/math/reduce_prod.h"
#include "lite/backends/arm/math/reduce_sum.h"
#include "lite/backends/arm/math/scale.h"
#include "lite/backends/arm/math/scatter.h"
#include "lite/backends/arm/math/sequence_expand.h"
#include "lite/backends/arm/math/sequence_pool.h"
#include "lite/backends/arm/math/sequence_pool_grad.h"
......@@ -357,6 +359,15 @@ inline float32x4_t pow_ps(float32x4_t a, float32x4_t b) {
return exp_ps(vmulq_f32(b, log_ps(a)));
}
inline float32x4_t vpaddq_f32(float32x4_t a, float32x4_t b) {
float32x4_t vrst;
vrst[0] = a[0] + a[1];
vrst[1] = a[2] + a[3];
vrst[2] = b[0] + b[1];
vrst[3] = b[2] + b[3];
return vrst;
}
template <typename T>
void fill_bias_fc(
T* tensor, const T* bias, int num, int channel, bool flag_relu);
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "lite/backends/arm/math/reduce_sum.h"
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <>
void reduce_sum_n<float>(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int chw_size = channel_in * height_in * width_in;
if (num_in == 1) {
memcpy(dst, src, sizeof(float) * chw_size);
} else {
int cnt_n = num_in >> 2;
int remain_n = num_in & 3;
int cnt_chw = chw_size >> 3;
int cnt_rem = chw_size & 7;
int stride = chw_size << 2;
int stride_c = 0;
for (int c = 0; c < cnt_chw; c++) {
float32x4_t vsum0 = vdupq_n_f32(0.f);
float32x4_t vsum1 = vdupq_n_f32(0.f);
const float* din_ptr0 = src + stride_c;
const float* din_ptr1 = din_ptr0 + chw_size;
const float* din_ptr2 = din_ptr1 + chw_size;
const float* din_ptr3 = din_ptr2 + chw_size;
for (int n = 0; n < cnt_n; n++) {
float32x4_t va0 = vld1q_f32(din_ptr0);
float32x4_t vb0 = vld1q_f32(din_ptr1);
float32x4_t va1 = vld1q_f32(din_ptr0 + 4);
float32x4_t vb1 = vld1q_f32(din_ptr1 + 4);
float32x4_t vc0 = vld1q_f32(din_ptr2);
float32x4_t vd0 = vld1q_f32(din_ptr3);
float32x4_t vs00 = vaddq_f32(va0, vb0);
float32x4_t vc1 = vld1q_f32(din_ptr2 + 4);
float32x4_t vs10 = vaddq_f32(va1, vb1);
float32x4_t vd1 = vld1q_f32(din_ptr3 + 4);
float32x4_t vs01 = vaddq_f32(vc0, vd0);
vsum0 = vaddq_f32(vsum0, vs00);
float32x4_t vs11 = vaddq_f32(vc1, vd1);
vsum1 = vaddq_f32(vsum1, vs10);
din_ptr0 += stride;
din_ptr1 += stride;
vsum0 = vaddq_f32(vsum0, vs01);
din_ptr2 += stride;
din_ptr3 += stride;
vsum1 = vaddq_f32(vsum1, vs11);
}
for (int n = 0; n < remain_n; n++) {
float32x4_t va0 = vld1q_f32(din_ptr0);
float32x4_t va1 = vld1q_f32(din_ptr0 + 4);
vsum0 = vaddq_f32(vsum0, va0);
din_ptr0 += chw_size;
vsum1 = vaddq_f32(vsum1, va1);
}
vst1q_f32(dst, vsum0);
dst += 4;
stride_c += 8;
vst1q_f32(dst, vsum1);
dst += 4;
}
if (cnt_rem > 3) {
float32x4_t vsum0 = vdupq_n_f32(0.f);
const float* din_ptr0 = src + stride_c;
const float* din_ptr1 = din_ptr0 + chw_size;
const float* din_ptr2 = din_ptr1 + chw_size;
const float* din_ptr3 = din_ptr2 + chw_size;
for (int n = 0; n < cnt_n; n++) {
float32x4_t va0 = vld1q_f32(din_ptr0);
float32x4_t vb0 = vld1q_f32(din_ptr1);
float32x4_t vc0 = vld1q_f32(din_ptr2);
float32x4_t vd0 = vld1q_f32(din_ptr3);
float32x4_t vs00 = vaddq_f32(va0, vb0);
float32x4_t vs01 = vaddq_f32(vc0, vd0);
vsum0 = vaddq_f32(vsum0, vs00);
din_ptr0 += stride;
din_ptr1 += stride;
vsum0 = vaddq_f32(vsum0, vs01);
din_ptr2 += stride;
din_ptr3 += stride;
}
for (int n = 0; n < remain_n; n++) {
float32x4_t va0 = vld1q_f32(din_ptr0);
din_ptr0 += chw_size;
vsum0 = vaddq_f32(vsum0, va0);
}
stride_c += 4;
vst1q_f32(dst, vsum0);
dst += 4;
cnt_rem -= 4;
}
for (int c = 0; c < cnt_rem; c++) {
const float* din_ptr0 = src + stride_c;
const float* din_ptr1 = din_ptr0 + chw_size;
const float* din_ptr2 = din_ptr1 + chw_size;
const float* din_ptr3 = din_ptr2 + chw_size;
float sum = 0.0;
for (int n = 0; n < cnt_n; n++) {
float tmp0 = din_ptr0[0] + din_ptr1[0];
float tmp1 = din_ptr2[0] + din_ptr3[0];
din_ptr0 += stride;
din_ptr1 += stride;
sum += tmp0;
din_ptr2 += stride;
din_ptr3 += stride;
sum += tmp1;
}
for (int n = 0; n < remain_n; n++) {
sum += din_ptr0[0];
din_ptr0 += chw_size;
}
stride_c++;
dst[0] = sum;
dst++;
}
}
}
template <>
void reduce_sum_c<float>(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int hw_size = height_in * width_in;
int chw_size = hw_size * channel_in;
for (int n = 0; n < num_in; ++n) {
reduce_sum_n<float>(src, dst, channel_in, 1, height_in, width_in);
src += chw_size;
dst += hw_size;
}
}
template <>
void reduce_sum_h<float>(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int nc_size = num_in * channel_in;
int hw_size = height_in * width_in;
for (int n = 0; n < nc_size; ++n) {
reduce_sum_n<float>(src, dst, height_in, 1, 1, width_in);
src += hw_size;
dst += width_in;
}
}
template <>
void reduce_sum_w<float>(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int nch_size = num_in * channel_in * height_in;
int cnt_w = width_in >> 3;
int cnt_n = nch_size >> 2;
int rem_w = width_in & 7;
int rem_n = nch_size & 3;
int stride = 0;
int stride_n = width_in << 2;
for (int n = 0; n < cnt_n; n++) {
const float* din_ptr0 = src + stride;
const float* din_ptr1 = din_ptr0 + width_in;
const float* din_ptr2 = din_ptr1 + width_in;
const float* din_ptr3 = din_ptr2 + width_in;
float32x4_t vsum = vdupq_n_f32(0.f);
int tmp = rem_w;
for (int w = 0; w < cnt_w; w++) {
float32x4_t va0 = vld1q_f32(din_ptr0);
float32x4_t va1 = vld1q_f32(din_ptr0 + 4);
float32x4_t vb0 = vld1q_f32(din_ptr1);
float32x4_t vb1 = vld1q_f32(din_ptr1 + 4);
float32x4_t vc0 = vld1q_f32(din_ptr2);
float32x4_t vc1 = vld1q_f32(din_ptr2 + 4);
float32x4_t vs0 = vaddq_f32(va0, va1);
float32x4_t vd0 = vld1q_f32(din_ptr3);
float32x4_t vs1 = vaddq_f32(vb0, vb1);
float32x4_t vd1 = vld1q_f32(din_ptr3 + 4);
float32x4_t vs2 = vaddq_f32(vc0, vc1);
din_ptr0 += 8;
float32x4_t vs3 = vaddq_f32(vd0, vd1);
din_ptr1 += 8;
float32x4_t vs00 = vpaddq_f32(vs0, vs1);
din_ptr2 += 8;
float32x4_t vs01 = vpaddq_f32(vs2, vs3);
din_ptr3 += 8;
float32x4_t vs = vpaddq_f32(vs00, vs01);
vsum = vaddq_f32(vs, vsum);
}
if (tmp > 3) {
float32x4_t va0 = vld1q_f32(din_ptr0);
float32x4_t vb0 = vld1q_f32(din_ptr1);
float32x4_t vc0 = vld1q_f32(din_ptr2);
float32x4_t vd0 = vld1q_f32(din_ptr3);
din_ptr0 += 4;
din_ptr1 += 4;
float32x4_t vs00 = vpaddq_f32(va0, vb0);
float32x4_t vs01 = vpaddq_f32(vc0, vd0);
din_ptr2 += 4;
din_ptr3 += 4;
float32x4_t vs = vpaddq_f32(vs00, vs01);
vsum = vaddq_f32(vs, vsum);
tmp -= 4;
}
for (int w = 0; w < tmp; w++) {
vsum[0] += *din_ptr0++;
vsum[1] += *din_ptr1++;
vsum[2] += *din_ptr2++;
vsum[3] += *din_ptr3++;
}
stride += stride_n;
vst1q_f32(dst, vsum);
dst += 4;
}
if (rem_n > 1) {
const float* din_ptr0 = src + stride;
const float* din_ptr1 = din_ptr0 + width_in;
float32x4_t vsum = vdupq_n_f32(0.f);
for (int w = 0; w < cnt_w; w++) {
float32x4_t va0 = vld1q_f32(din_ptr0);
din_ptr0 += 4;
float32x4_t vb0 = vld1q_f32(din_ptr1);
din_ptr1 += 4;
float32x4_t va1 = vld1q_f32(din_ptr0);
float32x4_t vb1 = vld1q_f32(din_ptr1);
float32x4_t vs0 = vpaddq_f32(va0, vb0);
din_ptr0 += 4;
float32x4_t vs1 = vpaddq_f32(va1, vb1);
din_ptr1 += 4;
float32x4_t vs00 = vpaddq_f32(vs0, vs1);
vsum = vaddq_f32(vs00, vsum);
}
int tmp = rem_w;
if (tmp > 3) {
float32x4_t va0 = vld1q_f32(din_ptr0);
float32x4_t vb0 = vld1q_f32(din_ptr1);
din_ptr0 += 4;
din_ptr1 += 4;
float32x4_t vs00 = vpaddq_f32(va0, vb0);
tmp -= 4;
vsum[0] += vs00[0];
vsum[2] += vs00[1];
vsum[1] += vs00[2];
vsum[3] += vs00[3];
}
vsum[0] += vsum[2];
vsum[1] += vsum[3];
for (int w = 0; w < tmp; w++) {
vsum[0] += *din_ptr0++;
vsum[1] += *din_ptr1++;
}
stride += width_in;
*dst++ = vsum[0];
stride += width_in;
*dst++ = vsum[1];
rem_n -= 2;
}
for (int n = 0; n < rem_n; n++) {
const float* din_ptr0 = src + stride;
float32x4_t vsum = vdupq_n_f32(0.f);
for (int w = 0; w < cnt_w; w++) {
float32x4_t va0 = vld1q_f32(din_ptr0);
float32x4_t va1 = vld1q_f32(din_ptr0 + 4);
float32x4_t vs0 = vaddq_f32(va0, va1);
din_ptr0 += 8;
vsum = vaddq_f32(vs0, vsum);
}
if (rem_w > 3) {
float32x4_t va0 = vld1q_f32(din_ptr0);
din_ptr0 += 4;
vsum = vaddq_f32(vsum, va0);
rem_w -= 4;
}
vsum[1] += vsum[2];
for (int w = 0; w < rem_w; w++) {
vsum[0] += *din_ptr0++;
}
vsum[1] += vsum[3];
vsum[0] += vsum[1];
*dst++ = vsum[0];
}
}
template <>
void reduce_sum_all<float>(const float* src, float* dst, int all_size) {
int cnt_n = all_size >> 4;
int rem_n = all_size & 15;
int cnt_rem = rem_n >> 2;
int rem_rem = rem_n & 3;
float32x4_t vsum = vdupq_n_f32(0.f);
for (int n = 0; n < cnt_n; n++) {
float32x4_t va0 = vld1q_f32(src);
float32x4_t va1 = vld1q_f32(src + 4);
float32x4_t va2 = vld1q_f32(src + 8);
float32x4_t va3 = vld1q_f32(src + 12);
src += 16;
float32x4_t vs0 = vaddq_f32(va0, va1);
float32x4_t vs1 = vaddq_f32(va2, va3);
float32x4_t vs = vpaddq_f32(vs0, vs1);
vsum = vaddq_f32(vsum, vs);
}
for (int n = 0; n < cnt_rem; n++) {
float32x4_t va0 = vld1q_f32(src);
src += 4;
vsum = vaddq_f32(vsum, va0);
}
vsum[1] += vsum[2];
for (int n = 0; n < rem_rem; n++) {
vsum[0] += *src++;
}
vsum[1] += vsum[3];
vsum[0] += vsum[1];
dst[0] = vsum[0];
}
template <>
void reduce_sum_nc<float>(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
// reduce nc.
int num = num_in * channel_in;
int size = height_in * width_in;
reduce_sum_n(src, dst, num, size, 1, 1);
}
template <>
void reduce_sum_ch<float>(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int ch_size = channel_in * height_in;
int chw_size = ch_size * width_in;
for (int n = 0; n < num_in; n++) {
reduce_sum_n<float>(src, dst, ch_size, 1, 1, width_in);
src += chw_size;
dst += width_in;
}
}
template <>
void reduce_sum_hw<float>(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int hw_size = height_in * width_in;
int nc_size = num_in * channel_in;
reduce_sum_w(src, dst, nc_size, 1, 1, hw_size);
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void reduce_sum_n(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in);
template <typename T>
void reduce_sum_c(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in);
template <typename T>
void reduce_sum_h(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in);
template <typename T>
void reduce_sum_w(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in);
template <typename T>
void reduce_sum_nc(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in);
template <typename T>
void reduce_sum_ch(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in);
template <typename T>
void reduce_sum_hw(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in);
template <typename T>
void reduce_sum_all(const T* src, T* dst, int all_size);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "lite/backends/arm/math/scatter.h"
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <>
void scatter<float>(const int64_t* indexs,
const float* src,
float* dst,
int index_size,
int num,
int size,
bool overwrite) {
for (int i = 0; i < num; i++) {
const float* din = src + indexs[i] * size;
memcpy(dst, din, sizeof(float) * size);
dst += size;
}
if (overwrite) {
for (int i = num; i < index_size; i++) {
const float* din = src + indexs[i] * size;
float* dout = dst + indexs[i] * size;
memcpy(dout, din, sizeof(float) * size);
}
} else {
int cnt = size >> 3;
int rem = size & 7;
for (int i = num; i < index_size; i++) {
const float* din = src + indexs[i] * size;
float* dout = dst + indexs[i] * size;
for (int j = 0; j < cnt; j++) {
float32x4_t va0 = vld1q_f32(din);
float32x4_t vb0 = vld1q_f32(dout);
float32x4_t va1 = vld1q_f32(din + 4);
float32x4_t vb1 = vld1q_f32(dout + 4);
vb0 = vaddq_f32(va0, vb0);
vb1 = vaddq_f32(va1, vb1);
din += 8;
vst1q_f32(dout, vb0);
vst1q_f32(dout + 4, vb0);
dout += 8;
}
for (int j = 0; j < rem; j++) {
dout[0] += *din++;
dout++;
}
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdint.h>
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void scatter(const int64_t* indexs,
const T* updates,
T* dst,
int index_size,
int num,
int size,
bool overwrite);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
......@@ -128,7 +128,7 @@ bool test_convert(bool cv_run,
for (int i = 0; i < test_iter; i++) {
clock_t begin = clock();
// resize default linear
image_preprocess.imageConvert(src, resize_lite);
image_preprocess.image_convert(src, resize_lite);
clock_t end = clock();
to_lite += (end - begin);
}
......@@ -226,7 +226,7 @@ bool test_flip(bool cv_run,
for (int i = 0; i < test_iter; i++) {
clock_t begin = clock();
// resize default linear
image_preprocess.imageFlip(src, resize_lite);
image_preprocess.image_flip(src, resize_lite);
clock_t end = clock();
to_lite += (end - begin);
}
......@@ -330,7 +330,7 @@ bool test_rotate(bool cv_run,
for (int i = 0; i < test_iter; i++) {
clock_t begin = clock();
// resize default linear
image_preprocess.imageRotate(src, resize_lite);
image_preprocess.image_rotate(src, resize_lite);
clock_t end = clock();
to_lite += (end - begin);
}
......@@ -426,7 +426,7 @@ bool test_resize(bool cv_run,
for (int i = 0; i < test_iter; i++) {
clock_t begin = clock();
// resize default linear
image_preprocess.imageResize(src, resize_lite);
image_preprocess.image_resize(src, resize_lite);
clock_t end = clock();
to_lite += (end - begin);
}
......@@ -526,7 +526,7 @@ bool test_crop(bool cv_run,
std::cout << "lite compute:" << std::endl;
for (int i = 0; i < test_iter; i++) {
clock_t begin = clock();
image_preprocess.imageCrop(
image_preprocess.image_crop(
src, resize_lite, dstFormat, srcw, srch, left_x, left_y, dstw, dsth);
clock_t end = clock();
to_lite += (end - begin);
......
......@@ -88,13 +88,13 @@ void pre_process(const cv::Mat& img, int width, int height, Tensor dstTensor) {
uint8_t* rgb_ptr = new uint8_t[img.cols * img.rows * 3];
uint8_t* resize_ptr = new uint8_t[width * height * 3];
// do convert bgr--rgb
img_process.imageConvert(img_ptr, rgb_ptr);
img_process.image_convert(img_ptr, rgb_ptr);
// do resize
img_process.imageResize(rgb_ptr, resize_ptr);
img_process.image_resize(rgb_ptr, resize_ptr);
// data--tensor and normalize
float means[3] = {103.94f, 116.78f, 123.68f};
float scales[3] = {0.017f, 0.017f, 0.017f};
img_process.image2Tensor(
img_process.image_to_tensor(
resize_ptr, &dstTensor, LayoutType::kNCHW, means, scales);
float* data = dstTensor.mutable_data<float>();
#else
......
......@@ -68,6 +68,7 @@ add_kernel(sequence_conv_compute_arm ARM extra SRCS sequence_conv_compute.cc DEP
add_kernel(layer_norm_compute_arm ARM extra SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(gather_compute_arm ARM extra SRCS gather_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(reduce_prod_compute_arm ARM extra SRCS reduce_prod_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(reduce_sum_compute_arm ARM extra SRCS reduce_sum_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(split_lod_tensor_compute_arm ARM extra SRCS split_lod_tensor_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(merge_lod_tensor_compute_arm ARM extra SRCS merge_lod_tensor_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(anchor_generator_compute_arm ARM extra SRCS anchor_generator_compute.cc DEPS ${lite_kernel_deps} math_arm)
......@@ -79,8 +80,10 @@ add_kernel(collect_fpn_proposals_compute_arm ARM extra SRCS collect_fpn_proposal
add_kernel(distribute_fpn_proposals_compute_arm ARM extra SRCS distribute_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(clip_compute_arm ARM extra SRCS clip_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(pixel_shuffle_compute_arm ARM extra SRCS pixel_shuffle_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(scatter_compute_arm ARM extra SRCS scatter_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sequence_expand_as_compute_arm ARM extra SRCS sequence_expand_as_compute.cc DEPS ${lite_kernel_deps} math_arm)
# for OCR specific
add_kernel(gru_unit_compute_arm ARM extra SRCS gru_unit_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(gru_compute_arm ARM extra SRCS gru_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
......@@ -59,12 +59,6 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
bool flag_dw_3x3 = (kw == 3) && (kh == 3) && (stride == 1 || stride == 2);
bool flag_dw_5x5 = (kw == 5) && (kh == 5) && (stride == 1 || stride == 2);
#ifdef __aarch64__
#else
bool flag =
(stride == 1 && (paddings[0] > 1 || paddings[2] > 1)) ? false : true;
flag_dw_3x3 = flag_dw_3x3 && flag;
#endif
bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
/// select conv impl
......
......@@ -28,11 +28,15 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
auto& ctx = this->ctx_->template As<ARMContext>();
auto w_dims = param.filter->dims();
auto kw = w_dims[3];
auto channel = w_dims[0];
auto hin = param.x->dims()[2];
auto win = param.x->dims()[3];
auto paddings = *param.paddings;
bool ch_four = channel <= 4 * win;
// select dw conv kernel
if (kw == 3) {
bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
if (pads_less && paddings[0] == paddings[2] &&
if (ch_four && pads_less && paddings[0] == paddings[2] &&
(paddings[0] == 0 || paddings[0] == 1)) {
flag_trans_weights_ = false;
} else {
......@@ -398,6 +402,14 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
w_scale_.data());
}
#ifdef LITE_WITH_PROFILE
template <>
void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::
SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
#endif
} // namespace arm
} // namespace kernels
} // namespace lite
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/reduce_sum_compute.h"
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void ReduceSumCompute::Run() {
auto& param = this->template Param<operators::ReduceParam>();
auto* input = param.x->template data<float>();
auto x_dims = param.x->dims();
int x_rank = x_dims.size();
auto* output = param.output->template mutable_data<float>();
std::vector<int> dim = param.dim;
bool keep_dim = param.keep_dim;
bool reduce_all = param.reduce_all;
if (!dim.empty()) {
for (int i = 0; i < dim.size(); i++) {
if (dim[i] < 0) {
dim[i] += x_rank;
}
}
}
if (reduce_all) {
lite::arm::math::reduce_sum_all(input, output, x_dims.production());
} else {
int n_in = 1;
int c_in = 1;
int h_in = 1;
int w_in = 1;
switch (x_dims.size()) {
case 4:
w_in = x_dims[3];
case 3:
h_in = x_dims[2];
case 2:
c_in = x_dims[1];
case 1:
n_in = x_dims[0];
break;
default:
LOG(FATAL) << "x_dims.size is " << x_dims.size()
<< ", which should not be over than 4.";
}
if (dim.size() == 1) {
switch (dim[0]) {
case 0:
lite::arm::math::reduce_sum_n(input, output, n_in, c_in, h_in, w_in);
break;
case 1:
lite::arm::math::reduce_sum_c(input, output, n_in, c_in, h_in, w_in);
break;
case 2:
lite::arm::math::reduce_sum_h(input, output, n_in, c_in, h_in, w_in);
break;
case 3:
lite::arm::math::reduce_sum_w(input, output, n_in, c_in, h_in, w_in);
break;
default:
LOG(FATAL) << "dim[0] is " << dim[0]
<< ", which should be less than 4.";
}
} else if (dim.size() == 2) {
if (dim[0] == 0 && dim[1] == 1) {
lite::arm::math::reduce_sum_nc(input, output, n_in, c_in, h_in, w_in);
} else if (dim[0] == 1 && dim[1] == 2) {
lite::arm::math::reduce_sum_ch(input, output, n_in, c_in, h_in, w_in);
} else if (dim[0] == 2 && dim[1] == 3) {
lite::arm::math::reduce_sum_hw(input, output, n_in, c_in, h_in, w_in);
} else {
LOG(FATAL)
<< "Only support the values of the dim are 0,1 1,2 or 2,3 for now.";
}
} else {
LOG(FATAL) << "dim's size: " << dim.size()
<< " over than 2, which is not supported now!!";
}
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(reduce_sum,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::ReduceSumCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include "lite/backends/arm/math/type_trans.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class ReduceSumCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~ReduceSumCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/scatter_compute.h"
#include "lite/backends/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void ScatterCompute::Run() {
auto& param = this->template Param<operators::ScatterParam>();
const float* updates_data = param.updates->template data<float>();
const int64_t* indexs_data = param.indexs->template data<int64_t>();
float* output_data = param.output->template mutable_data<float>();
bool overwrite = param.overwrite;
int index_size = param.indexs->dims()[0];
auto in_dims = param.x->dims();
int num = 1;
for (int i = 1; i < in_dims.size(); i++) {
num *= in_dims[i];
}
lite::arm::math::scatter(indexs_data,
updates_data,
output_data,
index_size,
in_dims[0],
num,
overwrite);
if (!param.x->lod().empty()) {
param.output->set_lod(param.x->lod());
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(scatter,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::ScatterCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("Updates",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class ScatterCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~ScatterCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -121,6 +121,7 @@ add_operator(max_pool_with_index_op extra SRCS max_pool_with_index_op.cc DEPS ${
add_operator(pixel_shuffle_op extra SRCS pixel_shuffle_op.cc DEPS ${op_DEPS})
add_operator(clip_op extra SRCS clip_op.cc DEPS ${op_DEPS})
add_operator(print_op extra SRCS print_op.cc DEPS ${op_DEPS})
add_operator(scatter extra SRCS scatter_op.cc DEPS ${op_DEPS})
# for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
......
......@@ -294,6 +294,16 @@ struct ScaleParam : ParamBase {
}
};
// For Scatter OP
struct ScatterParam : ParamBase {
lite::Tensor* x{};
lite::Tensor* indexs{};
lite::Tensor* updates{};
lite::Tensor* output{};
bool overwrite{true};
};
// For Softmax op
struct SoftmaxParam : ParamBase {
lite::Tensor* x{};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/scatter_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool ScatterOp::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
return true;
}
bool ScatterOp::InferShapeImpl() const {
auto index_dims = param_.indexs->dims();
auto update_dims = param_.updates->dims();
auto input_dims = param_.x->dims();
for (int i = 1; i < update_dims.size(); i++) {
CHECK_EQ_OR_FALSE(update_dims[i], input_dims[i]);
}
CHECK_EQ_OR_FALSE(index_dims.size(), 1L);
param_.output->Resize(input_dims);
return true;
}
bool ScatterOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
AttachParam(&param_);
auto x = op_desc.Input("X").front();
auto indexs = op_desc.Input("Ids").front();
auto updates = op_desc.Input("Updates").front();
auto output = op_desc.Output("Out").front();
if (op_desc.HasAttr("overwrite")) {
param_.overwrite = op_desc.GetAttr<bool>("overwrite");
} else {
param_.overwrite = true;
}
param_.x = scope->FindVar(x)->GetMutable<Tensor>();
param_.indexs = scope->FindVar(indexs)->GetMutable<Tensor>();
param_.updates = scope->FindVar(updates)->GetMutable<Tensor>();
param_.output = scope->FindMutableTensor(output);
CHECK(param_.x);
CHECK(param_.indexs);
CHECK(param_.updates);
CHECK(param_.output);
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(scatter, paddle::lite::operators::ScatterOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class ScatterOp : public OpLite {
public:
ScatterOp() {}
explicit ScatterOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "Scatter"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
ch->input_shape = ch->DimToStr(param_.x->dims());
ch->output_shape = ch->DimToStr(param_.output->dims());
ch->macs = param_.x->numel() * 1.f;
}
#endif
private:
mutable ScatterParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -30,7 +30,7 @@ void bgra_to_tensor_hwc(const uint8_t* bgr,
float b_scales = scales[2];
int dim8 = width >> 3;
int remain = wwidth - (dim8 << 3);
int remain = width - (dim8 << 3);
float32x4_t vrmean = vdupq_n_f32(r_means);
float32x4_t vgmean = vdupq_n_f32(g_means);
......
......@@ -293,53 +293,53 @@ void test_img(const std::vector<int>& cluster_id,
// LOG(INFO) << "image convert saber compute";
t_convert.Start();
// 方法一: image_preprocess.imageCovert(src, lite_dst);
image_preprocess.imageConvert(
// method1: image_preprocess.image_convert(src, lite_dst);
image_preprocess.image_convert(
src, lite_dst, (ImageFormat)srcFormat, (ImageFormat)dstFormat);
t_convert.Stop();
// LOG(INFO) << "image resize saber compute";
t_resize.Start();
// 方法一:image_preprocess.imageResize(lite_dst, resize_tmp);
image_preprocess.imageResize(lite_dst,
resize_tmp,
(ImageFormat)dstFormat,
srcw,
srch,
dstw,
dsth);
// method1:image_preprocess.image_resize(lite_dst, resize_tmp);
image_preprocess.image_resize(lite_dst,
resize_tmp,
(ImageFormat)dstFormat,
srcw,
srch,
dstw,
dsth);
t_resize.Stop();
// LOG(INFO) << "image rotate saber compute";
t_rotate.Start();
// 方法一: image_preprocess.imageRotate(resize_tmp, tv_out_ratote);
image_preprocess.imageRotate(resize_tmp,
tv_out_ratote,
(ImageFormat)dstFormat,
dstw,
dsth,
rotate);
// method1: image_preprocess.image_rotate(resize_tmp, tv_out_ratote);
image_preprocess.image_rotate(resize_tmp,
tv_out_ratote,
(ImageFormat)dstFormat,
dstw,
dsth,
rotate);
t_rotate.Stop();
// LOG(INFO) << "image flip saber compute";
t_flip.Start();
// 方法一: image_preprocess.imageFlip(resize_tmp, tv_out_flip);
image_preprocess.imageFlip(
// method1: image_preprocess.image_flip(resize_tmp, tv_out_flip);
image_preprocess.image_flip(
resize_tmp, tv_out_flip, (ImageFormat)dstFormat, dstw, dsth, flip);
t_flip.Stop();
// LOG(INFO) << "image to tensor compute";
t_tensor.Start();
// 方法一: image_preprocess.image2Tensor(
// method1: image_preprocess.image_to_tensor(
// resize_tmp, &dst_tensor, layout, means, scales);
image_preprocess.image2Tensor(resize_tmp,
&dst_tensor,
(ImageFormat)dstFormat,
dstw,
dsth,
layout,
means,
scales);
image_preprocess.image_to_tensor(resize_tmp,
&dst_tensor,
(ImageFormat)dstFormat,
dstw,
dsth,
layout,
means,
scales);
t_tensor.Stop();
t1.Stop();
}
......@@ -680,7 +680,7 @@ void test_rotate(const std::vector<int>& cluster_id,
for (int i = 0; i < test_iter; ++i) {
t_rotate.Start();
image_preprocess.imageRotate(src, lite_dst);
image_preprocess.image_rotate(src, lite_dst);
t_rotate.Stop();
}
LOG(INFO) << "image rotate avg time : " << t_rotate.LapTimes().Avg()
......@@ -847,7 +847,7 @@ void test_flip(const std::vector<int>& cluster_id,
for (int i = 0; i < test_iter; ++i) {
t_rotate.Start();
image_preprocess.imageFlip(src, lite_dst);
image_preprocess.image_flip(src, lite_dst);
t_rotate.Stop();
}
LOG(INFO) << "image flip avg time : " << t_rotate.LapTimes().Avg()
......@@ -1016,7 +1016,7 @@ void test_resize(const std::vector<int>& cluster_id,
for (int i = 0; i < test_iter; ++i) {
t_rotate.Start();
image_preprocess.imageResize(src, lite_dst);
image_preprocess.image_resize(src, lite_dst);
t_rotate.Stop();
}
LOG(INFO) << "image Resize avg time : " << t_rotate.LapTimes().Avg()
......@@ -1191,7 +1191,7 @@ void test_convert(const std::vector<int>& cluster_id,
for (int i = 0; i < test_iter; ++i) {
t_rotate.Start();
image_preprocess.imageConvert(src, lite_dst);
image_preprocess.image_convert(src, lite_dst);
t_rotate.Stop();
}
LOG(INFO) << "image Convert avg time : " << t_rotate.LapTimes().Avg()
......
......@@ -163,7 +163,7 @@ void test_convert(const std::vector<int>& cluster_id,
for (int i = 0; i < test_iter; ++i) {
t_lite.Start();
image_preprocess.imageConvert(src, lite_dst);
image_preprocess.image_convert(src, lite_dst);
t_lite.Stop();
}
LOG(INFO) << "image Convert avg time : " << t_lite.LapTimes().Avg()
......@@ -284,7 +284,7 @@ void test_resize(const std::vector<int>& cluster_id,
for (int i = 0; i < test_iter; ++i) {
t_rotate.Start();
image_preprocess.imageResize(src, lite_dst);
image_preprocess.image_resize(src, lite_dst);
t_rotate.Stop();
}
LOG(INFO) << "image Resize avg time : " << t_rotate.LapTimes().Avg()
......@@ -405,7 +405,7 @@ void test_flip(const std::vector<int>& cluster_id,
for (int i = 0; i < test_iter; ++i) {
t_lite.Start();
image_preprocess.imageFlip(src, lite_dst);
image_preprocess.image_flip(src, lite_dst);
t_lite.Stop();
}
LOG(INFO) << "image flip avg time : " << t_lite.LapTimes().Avg()
......@@ -523,7 +523,7 @@ void test_rotate(const std::vector<int>& cluster_id,
for (int i = 0; i < test_iter; ++i) {
t_lite.Start();
image_preprocess.imageRotate(src, lite_dst);
image_preprocess.image_rotate(src, lite_dst);
t_lite.Stop();
}
LOG(INFO) << "image rotate avg time : " << t_lite.LapTimes().Avg()
......@@ -667,14 +667,14 @@ void test_to_tensor(const std::vector<int>& cluster_id,
for (int i = 0; i < test_iter; ++i) {
t_lite.Start();
image_preprocess.image2Tensor(src,
&dst_tensor,
(ImageFormat)dstFormat,
dstw,
dsth,
layout,
means,
scales);
image_preprocess.image_to_tensor(src,
&dst_tensor,
(ImageFormat)dstFormat,
dstw,
dsth,
layout,
means,
scales);
t_lite.Stop();
}
LOG(INFO) << "image tensor avg time : " << t_lite.LapTimes().Avg()
......
......@@ -66,6 +66,7 @@ if(LITE_BUILD_EXTRA)
lite_cc_test(test_kernel_ctc_align_compute SRCS ctc_align_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_clip_compute SRCS clip_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_pixel_shuffle_compute SRCS pixel_shuffle_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_scatter_compute SRCS scatter_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_expand_as_compute SRCS sequence_expand_as_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
# for training kernel
......
......@@ -420,7 +420,6 @@ void TestInterpAlignMode(Place place, float abs_error = 2e-5) {
if (place == TARGET(kARM) && align_mode == 1 && !align_corners) {
continue;
}
// align_mode = 0 && align_corners = false NOT supported in Huawei
// Ascend NPU DDK
if (place == TARGET(kHuaweiAscendNPU) && align_mode == 0 &&
!align_corners) {
......
......@@ -340,10 +340,10 @@ TEST(ReduceSum, precision) {
Place place(TARGET(kX86));
test_reduce_sum(place);
#endif
// #ifdef LITE_WITH_ARM
// Place place(TARGET(kARM));
// test_reduce_sum(place);
// #endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_reduce_sum(place);
#endif
}
} // namespace lite
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
void scatter_basic(const int64_t* indexs,
const float* src,
float* dst,
int index_size,
int num,
int size,
bool overwrite) {
for (int i = 0; i < num; i++) {
const float* din = src + indexs[i] * size;
memcpy(dst, din, sizeof(float) * size);
dst += size;
}
if (overwrite) {
for (int i = num; i < index_size; i++) {
const float* din = src + indexs[i] * size;
float* dout = dst + indexs[i] * size;
memcpy(dout, din, sizeof(float) * size);
}
} else {
for (int i = num; i < index_size; i++) {
const float* din = src + indexs[i] * size;
float* dout = dst + indexs[i] * size;
for (int j = 0; j < size; j++) {
dout[j] += din[j];
}
}
}
}
class ScatterComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string input_ = "x";
std::string indexs_ = "indexs";
std::string updates_ = "updates";
std::string output_ = "out";
DDim up_dims_{{1}};
DDim id_dims_{{1}};
DDim x_dims_{{1}};
int index_size_ = 0;
bool overwrite_ = false;
public:
ScatterComputeTester(const Place& place,
const std::string& alias,
DDim up_dims,
DDim id_dims,
DDim x_dims,
bool overwrite,
int index_size)
: TestCase(place, alias),
up_dims_(up_dims),
id_dims_(id_dims),
x_dims_(x_dims),
index_size_(index_size),
overwrite_(overwrite) {}
void RunBaseline(Scope* scope) override {
auto* indexs_t = scope->FindMutableTensor(indexs_);
auto* updates_t = scope->FindMutableTensor(updates_);
const auto* indexs_data = indexs_t->data<int64_t>();
const auto* updates_data = updates_t->data<float>();
auto* out = scope->NewTensor(output_);
out->Resize(x_dims_);
auto* out_data = out->mutable_data<float>();
int in_n = x_dims_[0];
int in_c = x_dims_[1];
int in_h = x_dims_[2];
int in_w = x_dims_[3];
int size = in_c * in_h * in_w;
scatter_basic(indexs_data,
updates_data,
out_data,
index_size_,
in_n,
size,
overwrite_);
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("scatter");
op_desc->SetInput("X", {input_});
op_desc->SetInput("Ids", {indexs_});
op_desc->SetInput("Updates", {updates_});
op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("overwrite", overwrite_);
}
void PrepareData() override {
std::vector<float> data(x_dims_.production());
for (int i = 0; i < x_dims_.production(); i++) {
data[i] = i * 1.0;
}
SetCommonTensor(input_, x_dims_, data.data());
std::vector<float> update(up_dims_.production());
for (int i = 0; i < up_dims_.production(); i++) {
update[i] = i * 1.0;
}
SetCommonTensor(updates_, up_dims_, update.data());
std::vector<int64_t> index(id_dims_.production());
for (int i = 0; i < id_dims_.production(); i++) {
index[i] = i;
}
SetCommonTensor(indexs_, id_dims_, index.data());
}
};
void test_scatter(Place place) {
for (auto n : {1, 3}) {
for (auto c : {1, 2}) {
for (auto h : {1, 3}) {
for (auto w : {1, 3}) {
for (bool overwrite : {false, true}) {
auto x_dims = DDim(std::vector<int64_t>({n, c, h, w}));
auto up_dims = DDim(std::vector<int64_t>({n, c, h, w}));
auto id_dims = DDim(std::vector<int64_t>({n}));
std::unique_ptr<arena::TestCase> tester(new ScatterComputeTester(
place, "def", up_dims, id_dims, x_dims, overwrite, n));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
}
}
}
TEST(Scatter, precision) {
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_scatter(place);
#endif
}
} // namespace lite
} // namespace paddle
......@@ -39,7 +39,13 @@ DEFINE_int32(power_mode,
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
#ifdef LITE_WITH_ARM
DEFINE_bool(basic_test, true, "do all tests");
#else
DEFINE_bool(basic_test, false, "do all tests");
#endif
DEFINE_bool(check_result, true, "check the result");
DEFINE_int32(M, 512, "gemm: M");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册