提交 01c0868e 编写于 作者: H hong19860320 提交者: GitHub

enable softmax op and add unit test (#17703)

* enable softmax op and add unit test

* move softmax sub-functions to softmax.cc, and move basic math functions to funcs.h
上级 244a9e06
......@@ -72,6 +72,7 @@ USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);
// USE_LITE_KERNEL(feed, kARM, kAny, kAny, def);
// USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def);
#endif // LITE_WITH_ARM
......
cc_library(math_arm SRCS funcs.cc packed_sgemm.cc DEPS ${lite_kernel_deps} eigen3)
cc_library(math_arm SRCS funcs.cc packed_sgemm.cc softmax.cc DEPS ${lite_kernel_deps} eigen3)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/softmax.h"
#include <algorithm>
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <>
void softmax_basic<float>(const float* din, float* dout, const int axis_size,
const int inner_num, const int outer_num) {
int compute_size = inner_num * outer_num;
#pragma omp parallel for
for (int i = 0; i < compute_size; ++i) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
float max_data = din[real_index];
// get max
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
max_data = din[real_index] > max_data ? din[real_index] : max_data;
}
real_index = idx_outer * inner_num + idx_inner;
// sub, exp and sum
dout[real_index] = expf(din[real_index] - max_data);
float sum_data = dout[real_index];
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
dout[real_index] = expf(din[real_index] - max_data);
sum_data += dout[real_index];
}
float sum_inv = 1.f / sum_data;
real_index = idx_outer * inner_num + idx_inner;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
dout[real_index] *= sum_inv;
real_index += inner_num;
}
}
}
template <>
void softmax_inner8_axis4<float>(const float* din, float* dout,
const int axis_size, const int inner_num,
const int outer_num) {
int compute_size = inner_num * outer_num;
int cmp_cnt = compute_size >> 3;
int remain = compute_size % 8;
float32x4_t vone = vdupq_n_f32(1.0f);
#pragma omp parallel for
for (int c = 0; c < cmp_cnt; ++c) {
int i = c * 8;
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
// get max axis_size == 4
const float* din_ptr = din + real_index;
const float* din_ptr1 = din_ptr + inner_num;
const float* din_ptr2 = din_ptr1 + inner_num;
const float* din_ptr3 = din_ptr2 + inner_num;
float32x4_t vdata0 = vld1q_f32(din_ptr);
float32x4_t vdata1 = vld1q_f32(din_ptr1);
float32x4_t vdata2 = vld1q_f32(din_ptr2);
float32x4_t vdata3 = vld1q_f32(din_ptr3);
float32x4_t vdata01 = vld1q_f32(din_ptr + 4);
float32x4_t vdata11 = vld1q_f32(din_ptr1 + 4);
float32x4_t vdata21 = vld1q_f32(din_ptr2 + 4);
float32x4_t vdata31 = vld1q_f32(din_ptr3 + 4);
float* dout_ptr0 = dout + real_index;
float* dout_ptr1 = dout_ptr0 + inner_num;
float32x4_t vmax1 = vmaxq_f32(vdata0, vdata1);
float32x4_t vmax2 = vmaxq_f32(vdata2, vdata3);
float32x4_t vmax11 = vmaxq_f32(vdata01, vdata11);
float32x4_t vmax21 = vmaxq_f32(vdata21, vdata31);
float* dout_ptr2 = dout_ptr1 + inner_num;
float* dout_ptr3 = dout_ptr2 + inner_num;
float32x4_t vmax = vmaxq_f32(vmax1, vmax2);
float32x4_t vmax_1 = vmaxq_f32(vmax11, vmax21);
// sub, exp and sum
float32x4_t vsum0 = exp_ps(vsubq_f32(vdata0, vmax));
float32x4_t vsum1 = exp_ps(vsubq_f32(vdata1, vmax));
float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax));
float32x4_t vsum3 = exp_ps(vsubq_f32(vdata3, vmax));
float32x4_t vsum01 = exp_ps(vsubq_f32(vdata01, vmax_1));
float32x4_t vsum11 = exp_ps(vsubq_f32(vdata11, vmax_1));
float32x4_t vsum21 = exp_ps(vsubq_f32(vdata21, vmax_1));
float32x4_t vsum31 = exp_ps(vsubq_f32(vdata31, vmax_1));
float32x4_t vsum_1 = vaddq_f32(vsum0, vsum1);
float32x4_t vsum_2 = vaddq_f32(vsum2, vsum3);
float32x4_t vsum_11 = vaddq_f32(vsum01, vsum11);
float32x4_t vsum_21 = vaddq_f32(vsum21, vsum31);
float32x4_t vsum = vaddq_f32(vsum_1, vsum_2);
float32x4_t vsum111 = vaddq_f32(vsum_11, vsum_21);
float32x4_t vinf = div_ps(vone, vsum);
float32x4_t vinf1 = div_ps(vone, vsum111);
vsum0 = vmulq_f32(vsum0, vinf);
vsum1 = vmulq_f32(vsum1, vinf);
vsum2 = vmulq_f32(vsum2, vinf);
vsum3 = vmulq_f32(vsum3, vinf);
vsum01 = vmulq_f32(vsum01, vinf1);
vsum11 = vmulq_f32(vsum11, vinf1);
vsum21 = vmulq_f32(vsum21, vinf1);
vsum31 = vmulq_f32(vsum31, vinf1);
vst1q_f32(dout_ptr0, vsum0);
vst1q_f32(dout_ptr1, vsum1);
vst1q_f32(dout_ptr2, vsum2);
vst1q_f32(dout_ptr3, vsum3);
vst1q_f32(dout_ptr0 + 4, vsum01);
vst1q_f32(dout_ptr1 + 4, vsum11);
vst1q_f32(dout_ptr2 + 4, vsum21);
vst1q_f32(dout_ptr3 + 4, vsum31);
}
int i = cmp_cnt * 8;
if (remain > 4) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
// get max axis_size == 4
const float* din_ptr = din + real_index;
const float* din_ptr1 = din_ptr + inner_num;
const float* din_ptr2 = din_ptr1 + inner_num;
const float* din_ptr3 = din_ptr2 + inner_num;
float32x4_t vdata0 = vld1q_f32(din_ptr);
float32x4_t vdata1 = vld1q_f32(din_ptr1);
float32x4_t vdata2 = vld1q_f32(din_ptr2);
float32x4_t vdata3 = vld1q_f32(din_ptr3);
float* dout_ptr0 = dout + real_index;
float* dout_ptr1 = dout_ptr0 + inner_num;
float32x4_t vmax1 = vmaxq_f32(vdata0, vdata1);
float32x4_t vmax2 = vmaxq_f32(vdata2, vdata3);
float* dout_ptr2 = dout_ptr1 + inner_num;
float* dout_ptr3 = dout_ptr2 + inner_num;
float32x4_t vmax = vmaxq_f32(vmax1, vmax2);
// sub, exp and sum
float32x4_t vsum0 = exp_ps(vsubq_f32(vdata0, vmax));
float32x4_t vsum1 = exp_ps(vsubq_f32(vdata1, vmax));
float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax));
float32x4_t vsum3 = exp_ps(vsubq_f32(vdata3, vmax));
float32x4_t vsum_1 = vaddq_f32(vsum0, vsum1);
float32x4_t vsum_2 = vaddq_f32(vsum2, vsum3);
float32x4_t vsum = vaddq_f32(vsum_1, vsum_2);
float32x4_t vone = vdupq_n_f32(1.0f);
float32x4_t vinf = div_ps(vone, vsum);
vsum0 = vmulq_f32(vsum0, vinf);
vsum1 = vmulq_f32(vsum1, vinf);
vsum2 = vmulq_f32(vsum2, vinf);
vsum3 = vmulq_f32(vsum3, vinf);
vst1q_f32(dout_ptr0, vsum0);
vst1q_f32(dout_ptr1, vsum1);
vst1q_f32(dout_ptr2, vsum2);
vst1q_f32(dout_ptr3, vsum3);
i += 4;
}
for (; i < compute_size; i++) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
float max_data = din[real_index];
// get max
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
max_data = din[real_index] > max_data ? din[real_index] : max_data;
}
real_index = idx_outer * inner_num + idx_inner;
// sub, exp and sum
dout[real_index] = expf(din[real_index] - max_data);
float sum_data = dout[real_index];
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
dout[real_index] = expf(din[real_index] - max_data);
sum_data += dout[real_index];
}
float sum_inv = 1.f / sum_data;
real_index = idx_outer * inner_num + idx_inner;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
dout[real_index] *= sum_inv;
real_index += inner_num;
}
}
}
template <>
void softmax_inner4_axis4<float>(const float* din, float* dout,
const int axis_size, const int inner_num,
const int outer_num) {
int compute_size = inner_num * outer_num;
int cmp_cnt = compute_size >> 2;
int remain = compute_size % 4;
float32x4_t vone = vdupq_n_f32(1.0f);
#pragma omp parallel for
for (int c = 0; c < cmp_cnt; ++c) {
int i = c * 4;
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
// get max axis_size == 4
const float* din_ptr = din + real_index;
const float* din_ptr1 = din_ptr + inner_num;
const float* din_ptr2 = din_ptr1 + inner_num;
const float* din_ptr3 = din_ptr2 + inner_num;
float32x4_t vdata0 = vld1q_f32(din_ptr);
float32x4_t vdata1 = vld1q_f32(din_ptr1);
float32x4_t vdata2 = vld1q_f32(din_ptr2);
float32x4_t vdata3 = vld1q_f32(din_ptr3);
float* dout_ptr0 = dout + real_index;
float* dout_ptr1 = dout_ptr0 + inner_num;
float32x4_t vmax1 = vmaxq_f32(vdata0, vdata1);
float32x4_t vmax2 = vmaxq_f32(vdata2, vdata3);
float* dout_ptr2 = dout_ptr1 + inner_num;
float* dout_ptr3 = dout_ptr2 + inner_num;
float32x4_t vmax = vmaxq_f32(vmax1, vmax2);
// sub, exp and sum
float32x4_t vsum0 = exp_ps(vsubq_f32(vdata0, vmax));
float32x4_t vsum1 = exp_ps(vsubq_f32(vdata1, vmax));
float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax));
float32x4_t vsum3 = exp_ps(vsubq_f32(vdata3, vmax));
float32x4_t vsum_1 = vaddq_f32(vsum0, vsum1);
float32x4_t vsum_2 = vaddq_f32(vsum2, vsum3);
float32x4_t vsum = vaddq_f32(vsum_1, vsum_2);
float32x4_t vinf = div_ps(vone, vsum);
vsum0 = vmulq_f32(vsum0, vinf);
vsum1 = vmulq_f32(vsum1, vinf);
vsum2 = vmulq_f32(vsum2, vinf);
vsum3 = vmulq_f32(vsum3, vinf);
vst1q_f32(dout_ptr0, vsum0);
vst1q_f32(dout_ptr1, vsum1);
vst1q_f32(dout_ptr2, vsum2);
vst1q_f32(dout_ptr3, vsum3);
}
int i = cmp_cnt * 8;
for (; i < compute_size; i++) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
float max_data = din[real_index];
// get max
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
max_data = din[real_index] > max_data ? din[real_index] : max_data;
}
real_index = idx_outer * inner_num + idx_inner;
// sub, exp and sum
dout[real_index] = expf(din[real_index] - max_data);
float sum_data = dout[real_index];
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
dout[real_index] = expf(din[real_index] - max_data);
sum_data += dout[real_index];
}
float sum_inv = 1.f / sum_data;
real_index = idx_outer * inner_num + idx_inner;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
dout[real_index] *= sum_inv;
real_index += inner_num;
}
}
}
template <>
void softmax_inner8<float>(const float* din, float* dout, const int axis_size,
const int inner_num, const int outer_num) {
int compute_size = inner_num * outer_num;
int cmp_cnt = compute_size >> 3;
#pragma omp parallel for
for (int c = 0; c < cmp_cnt; ++c) {
int i = c * 8;
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
const float* din_ptr = din + real_index;
float32x4_t vmax = vld1q_f32(din_ptr);
float32x4_t vmax2 = vld1q_f32(din_ptr + 4);
// get max
for (int j = 1; j < axis_size; ++j) {
din_ptr += inner_num;
float32x4_t vdata = vld1q_f32(din_ptr);
float32x4_t vdata2 = vld1q_f32(din_ptr + 4);
vmax = vmaxq_f32(vmax, vdata);
vmax2 = vmaxq_f32(vmax2, vdata2);
}
// sub, exp and sum
din_ptr = din + real_index;
float* dout_ptr = dout + real_index;
float32x4_t vdata = vld1q_f32(din_ptr);
float32x4_t vdata2 = vld1q_f32(din_ptr + 4);
float32x4_t vsum = exp_ps(vsubq_f32(vdata, vmax));
float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax2));
din_ptr += inner_num;
vst1q_f32(dout_ptr, vsum);
vst1q_f32(dout_ptr + 4, vsum2);
dout_ptr += inner_num;
for (int j = 1; j < axis_size; ++j) {
float32x4_t vdata0 = vld1q_f32(din_ptr);
float32x4_t vdata1 = vld1q_f32(din_ptr + 4);
vdata0 = exp_ps(vsubq_f32(vdata0, vmax));
vdata1 = exp_ps(vsubq_f32(vdata1, vmax2));
din_ptr += inner_num;
vsum = vaddq_f32(vsum, vdata0);
vsum2 = vaddq_f32(vsum2, vdata1);
vst1q_f32(dout_ptr, vdata0);
vst1q_f32(dout_ptr + 4, vdata1);
dout_ptr += inner_num;
}
float32x4_t vone = vdupq_n_f32(1.0f);
float32x4_t vinf = div_ps(vone, vsum);
float32x4_t vinf2 = div_ps(vone, vsum2);
dout_ptr = dout + real_index;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
float32x4_t vdata0 = vld1q_f32(dout_ptr);
float32x4_t vdata1 = vld1q_f32(dout_ptr + 4);
vdata0 = vmulq_f32(vdata0, vinf);
vdata1 = vmulq_f32(vdata1, vinf2);
vst1q_f32(dout_ptr, vdata0);
vst1q_f32(dout_ptr + 4, vdata1);
dout_ptr += inner_num;
}
}
for (int i = cmp_cnt * 8; i < compute_size; i++) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
float max_data = din[real_index];
// get max
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
max_data = din[real_index] > max_data ? din[real_index] : max_data;
}
real_index = idx_outer * inner_num + idx_inner;
// sub, exp and sum
dout[real_index] = expf(din[real_index] - max_data);
float sum_data = dout[real_index];
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
dout[real_index] = expf(din[real_index] - max_data);
sum_data += dout[real_index];
}
float sum_inv = 1.f / sum_data;
real_index = idx_outer * inner_num + idx_inner;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
dout[real_index] *= sum_inv;
real_index += inner_num;
}
}
}
template <>
void softmax_inner4<float>(const float* din, float* dout, const int axis_size,
const int inner_num, const int outer_num) {
int compute_size = inner_num * outer_num;
int cmp_cnt = compute_size >> 2;
#pragma omp parallel for
for (int c = 0; c < cmp_cnt; ++c) {
int i = c * 4;
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
// float max_data = din[real_index];
const float* din_ptr = din + real_index;
float32x4_t vmax = vld1q_f32(din_ptr);
// get max
for (int j = 1; j < axis_size; ++j) {
din_ptr += inner_num;
float32x4_t vdata = vld1q_f32(din_ptr);
vmax = vmaxq_f32(vmax, vdata);
}
// sub, exp and sum
din_ptr = din + real_index;
float* dout_ptr = dout + real_index;
float32x4_t vdata = vld1q_f32(din_ptr);
float32x4_t vsum = exp_ps(vsubq_f32(vdata, vmax));
din_ptr += inner_num;
vst1q_f32(dout_ptr, vsum);
dout_ptr += inner_num;
for (int j = 1; j < axis_size; ++j) {
// real_index += inner_num;
float32x4_t vdata0 = vld1q_f32(din_ptr);
vdata0 = exp_ps(vsubq_f32(vdata0, vmax));
din_ptr += inner_num;
vsum = vaddq_f32(vsum, vdata0);
vst1q_f32(dout_ptr, vdata0);
dout_ptr += inner_num;
}
float32x4_t vone = vdupq_n_f32(1.0f);
float32x4_t vinf = div_ps(vone, vsum);
dout_ptr = dout + real_index;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
float32x4_t vdata0 = vld1q_f32(dout_ptr);
vdata0 = vmulq_f32(vdata0, vinf);
vst1q_f32(dout_ptr, vdata0);
dout_ptr += inner_num;
}
}
for (int i = cmp_cnt * 4; i < compute_size; i++) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
float max_data = din[real_index];
// get max
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
max_data = din[real_index] > max_data ? din[real_index] : max_data;
}
real_index = idx_outer * inner_num + idx_inner;
// sub, exp and sum
dout[real_index] = expf(din[real_index] - max_data);
float sum_data = dout[real_index];
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
dout[real_index] = expf(din[real_index] - max_data);
sum_data += dout[real_index];
}
float sum_inv = 1.f / sum_data;
real_index = idx_outer * inner_num + idx_inner;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
dout[real_index] *= sum_inv;
real_index += inner_num;
}
}
}
template <>
void softmax_inner1_large_axis<float>(const float* din, float* dout,
const int outer_size,
const int axis_size) {
#pragma omp parallel for
for (int i = 0; i < outer_size; ++i) {
const float* din_ptr = din + i * axis_size;
float* dout_ptr = dout + i * axis_size;
const float* din_max_ptr = din_ptr;
int nn = axis_size >> 2;
// get max
float32x4_t vmax = vld1q_f32(din_max_ptr);
din_max_ptr += 4;
int j = 1;
for (; j < nn; ++j) {
vmax = vmaxq_f32(vmax, vld1q_f32(din_max_ptr));
din_max_ptr += 4;
}
float32x2_t vhmax = vmax_f32(vget_high_f32(vmax), vget_low_f32(vmax));
float max_data = std::max(vget_lane_f32(vhmax, 0), vget_lane_f32(vhmax, 1));
for (j = 4 * j; j < axis_size; ++j) {
max_data = std::max(max_data, din_max_ptr[0]);
din_max_ptr++;
}
// sub, exp and sum
const float* din_sum_ptr = din_ptr;
float* dout_sum_ptr = dout_ptr;
vmax = vdupq_n_f32(max_data);
float32x4_t vsub_exp = exp_ps(vsubq_f32(vld1q_f32(din_sum_ptr), vmax));
float32x4_t vsum = vsub_exp;
vst1q_f32(dout_sum_ptr, vsub_exp);
din_sum_ptr += 4;
dout_sum_ptr += 4;
j = 1;
for (; j < nn; ++j) {
vsub_exp = exp_ps(vsubq_f32(vld1q_f32(din_sum_ptr), vmax));
vst1q_f32(dout_sum_ptr, vsub_exp);
vsum = vaddq_f32(vsum, vsub_exp);
din_sum_ptr += 4;
dout_sum_ptr += 4;
}
float32x2_t vhsum = vadd_f32(vget_high_f32(vsum), vget_low_f32(vsum));
float sum_data = vget_lane_f32(vhsum, 0) + vget_lane_f32(vhsum, 1);
for (j = 4 * j; j < axis_size; ++j) {
dout_sum_ptr[0] = expf(din_sum_ptr[0] - max_data);
sum_data += dout_sum_ptr[0];
din_sum_ptr++;
dout_sum_ptr++;
}
float sum_inv = 1.f / sum_data;
float* dout_res_ptr = dout_ptr;
float32x4_t vinv = vdupq_n_f32(sum_inv);
// get softmax result
j = 0;
for (; j < nn; ++j) {
float32x4_t vout = vld1q_f32(dout_res_ptr);
float32x4_t vres = vmulq_f32(vout, vinv);
vst1q_f32(dout_res_ptr, vres);
dout_res_ptr += 4;
}
for (j = nn * 4; j < axis_size; ++j) {
dout_ptr[j] *= sum_inv;
}
}
}
template <>
void softmax_inner1_small_axis<float>(const float* din, float* dout,
const int outer_size,
const int axis_size) {
#pragma omp parallel for
for (int i = 0; i < outer_size; ++i) {
const float* din_ptr = din + i * axis_size;
float* dout_ptr = dout + i * axis_size;
// get max
float max_data = din_ptr[0];
for (int j = 1; j < axis_size; ++j) {
max_data = std::max(max_data, din_ptr[j]);
}
// sub, exp and sum
float sum_data = 0.f;
for (int j = 0; j < axis_size; ++j) {
dout_ptr[j] = expf(din_ptr[j] - max_data);
sum_data += dout_ptr[j];
}
float sum_inv = 1.f / sum_data;
for (int j = 0; j < axis_size; ++j) {
dout_ptr[j] *= sum_inv;
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void softmax_basic(const T* din, T* dout, const int axis_size,
const int inner_num, const int outer_num);
template <typename T>
void softmax_inner8_axis4(const T* din, T* dout, const int axis_size,
const int inner_num, const int outer_num);
template <typename T>
void softmax_inner4_axis4(const T* din, T* dout, const int axis_size,
const int inner_num, const int outer_num);
template <typename T>
void softmax_inner8(const T* din, T* dout, const int axis_size,
const int inner_num, const int outer_num);
template <typename T>
void softmax_inner4(const T* din, T* dout, const int axis_size,
const int inner_num, const int outer_num);
template <typename T>
void softmax_inner1_large_axis(const T* din, T* dout, const int outer_size,
const int axis_size);
template <typename T>
void softmax_inner1_small_axis(const T* din, T* dout, const int outer_size,
const int axis_size);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
......@@ -8,13 +8,16 @@ cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps})
cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3)
cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} eigen3)
cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm)
lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm)
set(arm_kernels
fc_compute_arm
relu_compute_arm
mul_compute_arm
scale_compute_arm)
scale_compute_arm
softmax_compute_arm)
set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/softmax_compute.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void SoftmaxCompute::Run() {
auto& param = Param<operators::SoftmaxParam>();
const float* din = param.x->data<float>();
float* dout = param.output->mutable_data<float>();
auto dim_x = param.x->dims();
auto rank_x = dim_x.size();
int axis = param.axis;
if (axis < 0) {
axis += rank_x;
}
int outer_num = dim_x.Slice(0, axis).production();
int inner_num = dim_x.Slice(axis + 1, rank_x).production();
int axis_size = dim_x[axis];
if (inner_num == 1) {
if (axis_size >= 4) {
lite::arm::math::softmax_inner1_large_axis(din, dout, outer_num,
axis_size);
} else {
lite::arm::math::softmax_inner1_small_axis(din, dout, outer_num,
axis_size);
}
} else {
int compute_size = outer_num * inner_num;
if (axis_size == 4 && inner_num % 8 == 0) {
lite::arm::math::softmax_inner8_axis4(din, dout, axis_size, inner_num,
outer_num);
} else if (axis_size == 4 && inner_num % 4 == 0) {
lite::arm::math::softmax_inner4_axis4(din, dout, axis_size, inner_num,
outer_num);
} else {
if (inner_num % 8 == 0) {
lite::arm::math::softmax_inner8(din, dout, axis_size, inner_num,
outer_num);
} else if (inner_num % 4 == 0) {
lite::arm::math::softmax_inner4(din, dout, axis_size, inner_num,
outer_num);
} else {
lite::arm::math::softmax_basic(din, dout, axis_size, inner_num,
outer_num);
}
}
}
}
TargetType SoftmaxCompute::target() const { return TARGET(kARM); }
PrecisionType SoftmaxCompute::precision() const { return PRECISION(kFloat); }
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(softmax, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::SoftmaxCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class SoftmaxCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
TargetType target() const override;
PrecisionType precision() const override;
virtual ~SoftmaxCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/softmax_compute.h"
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <typename dtype>
void softmat_compute_ref(const operators::SoftmaxParam& param) {
const dtype* x_data = param.x->mutable_data<const dtype>();
dtype* output_data = param.output->mutable_data<dtype>();
DDim dim = param.x->dims();
ASSERT_EQ(dim.data(), param.output->dims().data());
auto rank = dim.size();
int axis = param.axis;
if (axis < 0) {
axis += rank;
}
int axis_size = dim[axis];
int outer_num = dim.Slice(0, axis).production();
int inner_num = dim.Slice(axis + 1, rank).production();
int compute_size = outer_num * inner_num;
for (int i = 0; i < compute_size; i++) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int start = idx_outer * inner_num + idx_inner;
int offset;
offset = start;
dtype max_data = std::numeric_limits<dtype>::lowest();
for (int j = 0; j < axis_size; j++) {
max_data = x_data[offset] > max_data ? x_data[offset] : max_data;
offset += inner_num;
}
offset = start;
dtype sum_data = (dtype)0;
for (int j = 0; j < axis_size; j++) {
output_data[offset] = exp(x_data[offset] - max_data);
sum_data += output_data[offset];
offset += inner_num;
}
offset = start;
for (int j = 0; j < axis_size; j++) {
output_data[offset] /= sum_data;
offset += inner_num;
}
}
}
TEST(softmax_arm, init) {
SoftmaxCompute softmax;
ASSERT_EQ(softmax.precision(), PRECISION(kFloat));
ASSERT_EQ(softmax.target(), TARGET(kARM));
}
TEST(softmax_arm, compute) {
SoftmaxCompute softmax;
operators::SoftmaxParam param;
lite::Tensor x;
lite::Tensor output;
lite::Tensor output_ref;
for (auto n : {1, 3, 4, 11}) {
for (auto c : {1, 3, 11, 4}) {
for (auto h : {3, 1, 11, 4}) {
for (auto w : {1, 3, 4, 12}) {
for (auto axis : {-4, -3, -2, -1, 0, 1, 2, 3}) {
x.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output_ref.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
auto* x_data = x.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
auto* output_ref_data = output_ref.mutable_data<float>();
for (int i = 0; i < x.dims().production(); i++) {
x_data[i] = i;
}
param.x = &x;
param.axis = axis;
param.output = &output;
softmax.SetParam(param);
softmax.Run();
param.output = &output_ref;
softmat_compute_ref<float>(param);
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
}
}
}
}
}
}
}
TEST(softmax, retrive_op) {
auto softmax =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"softmax");
ASSERT_FALSE(softmax.empty());
ASSERT_TRUE(softmax.front());
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);
......@@ -18,5 +18,6 @@
USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(feed, kARM, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def);
......@@ -4,6 +4,7 @@ cc_library(fc_op_lite SRCS fc_op.cc DEPS ${op_DEPS})
cc_library(relu_op_lite SRCS relu_op.cc DEPS ${op_DEPS})
cc_library(mul_op_lite SRCS mul_op.cc DEPS ${op_DEPS})
cc_library(scale_op_lite SRCS scale_op.cc DEPS ${op_DEPS})
cc_library(softmax_op_lite SRCS softmax_op.cc DEPS ${op_DEPS})
cc_library(feed_op_lite SRCS feed_op.cc DEPS ${op_DEPS})
cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS})
cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS})
......@@ -19,6 +20,7 @@ set(ops_lite
relu_op_lite
mul_op_lite
scale_op_lite
softmax_op_lite
feed_op_lite
fetch_op_lite
io_copy_op_lite
......@@ -28,3 +30,4 @@ set(ops_lite
PARENT_SCOPE)
lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite memory_lite)
lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite)
......@@ -93,6 +93,14 @@ struct ScaleParam {
bool bias_after_scale{true};
};
// For Softmax Op
struct SoftmaxParam {
lite::Tensor* x{};
lite::Tensor* output{};
int axis{-1};
};
/// ----------------------- element wise operators ----------------------
struct ElementwiseParam {
const lite::Tensor* X{};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/softmax_op.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SoftmaxOp::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
auto dim_x = param_.x->dims();
auto rank_x = dim_x.size();
CHECK_OR_FALSE(param_.axis >= -rank_x && param_.axis < rank_x);
return true;
}
bool SoftmaxOp::InferShape() const {
param_.output->Resize(param_.x->dims());
return true;
}
bool SoftmaxOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) {
param_.x = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
param_.output =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
param_.axis = GetAttr<int>(opdesc.GetAttr("axis"));
CHECK(param_.x);
CHECK(param_.output);
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(softmax, paddle::lite::operators::SoftmaxOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class SoftmaxOp : public OpLite {
public:
SoftmaxOp() {}
explicit SoftmaxOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "softmax"; }
private:
mutable SoftmaxParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/softmax_op.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
TEST(softmax_op_lite, test) {
// prepare variables
Scope scope;
auto* x = scope.Var("x")->GetMutable<Tensor>();
auto* output = scope.Var("output")->GetMutable<Tensor>();
x->Resize(DDim(std::vector<int64_t>({10, 20})));
output->Resize(DDim(std::vector<int64_t>{10, 20}));
// set data
for (int i = 0; i < 10 * 20; i++) {
x->mutable_data<float>()[i] = i;
}
for (int i = 0; i < 10 * 20; i++) {
output->mutable_data<float>()[i] = 0.;
}
// prepare op desc
lite::OpDesc desc;
desc.SetType("softmax");
desc.SetInput("X", {"x"});
desc.SetOutput("Out", {"output"});
desc.SetAttr("axis", static_cast<int>(-1));
SoftmaxOp softmax("softmax");
softmax.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}});
softmax.Attach(desc, &scope);
}
} // namespace operators
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册