Unverified commit e242f781 authored by H hong19860320 and committed by GitHub

enable softmax op and add unit test (#17703)

* enable softmax op and add unit test

* move softmax sub-functions to softmax.cc, and move basic math functions to funcs.h
Parent ca45ed53
@@ -72,6 +72,7 @@ USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);
// USE_LITE_KERNEL(feed, kARM, kAny, kAny, def);
// USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def);
#endif // LITE_WITH_ARM
......
cc_library(math_arm SRCS funcs.cc packed_sgemm.cc DEPS ${lite_kernel_deps} eigen3)
cc_library(math_arm SRCS funcs.cc packed_sgemm.cc softmax.cc DEPS ${lite_kernel_deps} eigen3)
@@ -18,12 +18,293 @@
#include <cmath>
#include "paddle/fluid/lite/arm/math/packed_sgemm.h"
#include "paddle/fluid/lite/arm/math/softmax.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
#define c_inv_mant_mask ~0x7f800000u
#define c_cephes_SQRTHF 0.707106781186547524
#define c_cephes_log_p0 7.0376836292E-2
#define c_cephes_log_p1 -1.1514610310E-1
#define c_cephes_log_p2 1.1676998740E-1
#define c_cephes_log_p3 -1.2420140846E-1
#define c_cephes_log_p4 +1.4249322787E-1
#define c_cephes_log_p5 -1.6668057665E-1
#define c_cephes_log_p6 +2.0000714765E-1
#define c_cephes_log_p7 -2.4999993993E-1
#define c_cephes_log_p8 +3.3333331174E-1
#define c_cephes_log_q1 -2.12194440e-4
#define c_cephes_log_q2 0.693359375
// natural logarithm computed for 4 floats at once
// returns NaN for x <= 0
inline float32x4_t log_ps(float32x4_t x) {
float32x4_t one = vdupq_n_f32(1);
x = vmaxq_f32(x, vdupq_n_f32(0)); // force flush to zero on denormal values
uint32x4_t invalid_mask = vcleq_f32(x, vdupq_n_f32(0));
int32x4_t ux = vreinterpretq_s32_f32(x);
int32x4_t emm0 = vshrq_n_s32(ux, 23);
// keep only the fractional part
ux = vandq_s32(ux, vdupq_n_s32(c_inv_mant_mask));
ux = vorrq_s32(ux, vreinterpretq_s32_f32(vdupq_n_f32(0.5f)));
x = vreinterpretq_f32_s32(ux);
emm0 = vsubq_s32(emm0, vdupq_n_s32(0x7f));
float32x4_t e = vcvtq_f32_s32(emm0);
e = vaddq_f32(e, one);
// part2:
// if( x < SQRTHF ) {
// e -= 1;
// x = x + x - 1.0;
// } else {
// x = x - 1.0;
// }
//
uint32x4_t mask = vcltq_f32(x, vdupq_n_f32(c_cephes_SQRTHF));
float32x4_t tmp =
vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask));
x = vsubq_f32(x, one);
e = vsubq_f32(
e, vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(one), mask)));
x = vaddq_f32(x, tmp);
float32x4_t z = vmulq_f32(x, x);
float32x4_t y = vdupq_n_f32(c_cephes_log_p0);
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p1));
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p2));
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p3));
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p4));
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p5));
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p6));
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p7));
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p8));
y = vmulq_f32(y, x);
y = vmulq_f32(y, z);
tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q1));
y = vaddq_f32(y, tmp);
tmp = vmulq_f32(z, vdupq_n_f32(0.5f));
y = vsubq_f32(y, tmp);
tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q2));
x = vaddq_f32(x, y);
x = vaddq_f32(x, tmp);
x = vreinterpretq_f32_u32(vorrq_u32(
vreinterpretq_u32_f32(x), invalid_mask)); // negative arg will be NAN
return x;
}
#define c_exp_hi 88.3762626647949f
#define c_exp_lo -88.3762626647949f
#define c_cephes_LOG2EF 1.44269504088896341
#define c_cephes_exp_C1 0.693359375
#define c_cephes_exp_C2 -2.12194440e-4
#define c_cephes_exp_p0 1.9875691500E-4
#define c_cephes_exp_p1 1.3981999507E-3
#define c_cephes_exp_p2 8.3334519073E-3
#define c_cephes_exp_p3 4.1665795894E-2
#define c_cephes_exp_p4 1.6666665459E-1
#define c_cephes_exp_p5 5.0000001201E-1
// exp() computed for 4 floats at once
inline float32x4_t exp_ps(float32x4_t x) {
float32x4_t tmp, fx;
float32x4_t one = vdupq_n_f32(1);
x = vminq_f32(x, vdupq_n_f32(c_exp_hi));
x = vmaxq_f32(x, vdupq_n_f32(c_exp_lo));
// express exp(x) as exp(g + n*log(2))
fx = vmlaq_f32(vdupq_n_f32(0.5f), x, vdupq_n_f32(c_cephes_LOG2EF));
// perform a floorf
tmp = vcvtq_f32_s32(vcvtq_s32_f32(fx));
  // if greater, subtract 1
uint32x4_t mask = vcgtq_f32(tmp, fx);
mask = vandq_u32(mask, vreinterpretq_u32_f32(one));
fx = vsubq_f32(tmp, vreinterpretq_f32_u32(mask));
tmp = vmulq_f32(fx, vdupq_n_f32(c_cephes_exp_C1));
float32x4_t z = vmulq_f32(fx, vdupq_n_f32(c_cephes_exp_C2));
x = vsubq_f32(x, tmp);
x = vsubq_f32(x, z);
static const float cephes_exp_p[6] = {c_cephes_exp_p0, c_cephes_exp_p1,
c_cephes_exp_p2, c_cephes_exp_p3,
c_cephes_exp_p4, c_cephes_exp_p5};
float32x4_t y = vld1q_dup_f32(cephes_exp_p + 0);
float32x4_t c1 = vld1q_dup_f32(cephes_exp_p + 1);
float32x4_t c2 = vld1q_dup_f32(cephes_exp_p + 2);
float32x4_t c3 = vld1q_dup_f32(cephes_exp_p + 3);
float32x4_t c4 = vld1q_dup_f32(cephes_exp_p + 4);
float32x4_t c5 = vld1q_dup_f32(cephes_exp_p + 5);
y = vmulq_f32(y, x);
z = vmulq_f32(x, x);
y = vaddq_f32(y, c1);
y = vmulq_f32(y, x);
y = vaddq_f32(y, c2);
y = vmulq_f32(y, x);
y = vaddq_f32(y, c3);
y = vmulq_f32(y, x);
y = vaddq_f32(y, c4);
y = vmulq_f32(y, x);
y = vaddq_f32(y, c5);
y = vmulq_f32(y, z);
y = vaddq_f32(y, x);
y = vaddq_f32(y, one);
// build 2^n
int32x4_t mm;
mm = vcvtq_s32_f32(fx);
mm = vaddq_s32(mm, vdupq_n_s32(0x7f));
mm = vshlq_n_s32(mm, 23);
float32x4_t pow2n = vreinterpretq_f32_s32(mm);
y = vmulq_f32(y, pow2n);
return y;
}
#define c_minus_cephes_DP1 -0.78515625
#define c_minus_cephes_DP2 -2.4187564849853515625e-4
#define c_minus_cephes_DP3 -3.77489497744594108e-8
#define c_sincof_p0 -1.9515295891E-4
#define c_sincof_p1 8.3321608736E-3
#define c_sincof_p2 -1.6666654611E-1
#define c_coscof_p0 2.443315711809948E-005
#define c_coscof_p1 -1.388731625493765E-003
#define c_coscof_p2 4.166664568298827E-002
#define c_cephes_FOPI 1.27323954473516 // 4 / M_PI
// evaluation of 4 sines & cosines at once.
//
// The code is the exact rewriting of the cephes sinf function.
// Precision is excellent as long as x < 8192 (I did not bother to
// take into account the special handling they have for greater values
// -- it does not return garbage for arguments over 8192, though, but
// the extra precision is missing).
//
// Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
// surprising but correct result.
//
// Note also that when you compute sin(x), cos(x) is available at
// almost no extra cost, so both sin_ps and cos_ps make use of
// sincos_ps.
//
inline void sincos_ps(float32x4_t x, float32x4_t* ysin, float32x4_t* ycos) {
// any x
float32x4_t xmm1, xmm2, xmm3, y;
uint32x4_t emm2;
uint32x4_t sign_mask_sin, sign_mask_cos;
sign_mask_sin = vcltq_f32(x, vdupq_n_f32(0));
x = vabsq_f32(x);
// scale by 4/Pi
y = vmulq_f32(x, vdupq_n_f32(c_cephes_FOPI));
  // store the integer part of y in emm2
emm2 = vcvtq_u32_f32(y);
// j=(j+1) & (~1) (see the cephes sources)
emm2 = vaddq_u32(emm2, vdupq_n_u32(1));
emm2 = vandq_u32(emm2, vdupq_n_u32(~1));
y = vcvtq_f32_u32(emm2);
  // get the polynomial selection mask
  // there is one polynomial for 0 <= x <= Pi/4
  // and another one for Pi/4 < x <= Pi/2
uint32x4_t poly_mask = vtstq_u32(emm2, vdupq_n_u32(2));
// the magic pass: "Extended precision modular arithmetic"
// x = ((x - y * DP1) - y * DP2) - y * DP3;
xmm1 = vmulq_n_f32(y, c_minus_cephes_DP1);
xmm2 = vmulq_n_f32(y, c_minus_cephes_DP2);
xmm3 = vmulq_n_f32(y, c_minus_cephes_DP3);
x = vaddq_f32(x, xmm1);
x = vaddq_f32(x, xmm2);
x = vaddq_f32(x, xmm3);
sign_mask_sin = veorq_u32(sign_mask_sin, vtstq_u32(emm2, vdupq_n_u32(4)));
sign_mask_cos = vtstq_u32(vsubq_u32(emm2, vdupq_n_u32(2)), vdupq_n_u32(4));
  // evaluate the first polynomial (0 <= x <= Pi/4) in y1,
  // and the second polynomial (Pi/4 <= x <= Pi/2) in y2
float32x4_t z = vmulq_f32(x, x);
float32x4_t y1, y2;
y1 = vmulq_n_f32(z, c_coscof_p0);
y2 = vmulq_n_f32(z, c_sincof_p0);
y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p1));
y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p1));
y1 = vmulq_f32(y1, z);
y2 = vmulq_f32(y2, z);
y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p2));
y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p2));
y1 = vmulq_f32(y1, z);
y2 = vmulq_f32(y2, z);
y1 = vmulq_f32(y1, z);
y2 = vmulq_f32(y2, x);
y1 = vsubq_f32(y1, vmulq_f32(z, vdupq_n_f32(0.5f)));
y2 = vaddq_f32(y2, x);
y1 = vaddq_f32(y1, vdupq_n_f32(1));
  // select the correct result from the two polynomials
float32x4_t ys = vbslq_f32(poly_mask, y1, y2);
float32x4_t yc = vbslq_f32(poly_mask, y2, y1);
*ysin = vbslq_f32(sign_mask_sin, vnegq_f32(ys), ys);
*ycos = vbslq_f32(sign_mask_cos, yc, vnegq_f32(yc));
}
inline float32x4_t sin_ps(float32x4_t x) {
float32x4_t ysin, ycos;
sincos_ps(x, &ysin, &ycos);
return ysin;
}
inline float32x4_t cos_ps(float32x4_t x) {
float32x4_t ysin, ycos;
sincos_ps(x, &ysin, &ycos);
return ycos;
}
inline float32x4_t div_ps(float32x4_t a, float32x4_t b) {
float32x4_t reciprocal = vrecpeq_f32(b);
reciprocal = vmulq_f32(vrecpsq_f32(b, reciprocal), reciprocal);
return vmulq_f32(a, reciprocal);
}
inline float32x4_t pow_ps(float32x4_t a, float32x4_t b) {
// pow(x, m) = exp(m * log(x))
return exp_ps(vmulq_f32(b, log_ps(a)));
}
template <typename T>
void fill_bias_fc(T* tensor, const T* bias, const int num, const int channel);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/softmax.h"
#include <algorithm>
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <>
void softmax_basic<float>(const float* din, float* dout, const int axis_size,
const int inner_num, const int outer_num) {
int compute_size = inner_num * outer_num;
#pragma omp parallel for
for (int i = 0; i < compute_size; ++i) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
float max_data = din[real_index];
// get max
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
max_data = din[real_index] > max_data ? din[real_index] : max_data;
}
real_index = idx_outer * inner_num + idx_inner;
// sub, exp and sum
dout[real_index] = expf(din[real_index] - max_data);
float sum_data = dout[real_index];
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
dout[real_index] = expf(din[real_index] - max_data);
sum_data += dout[real_index];
}
float sum_inv = 1.f / sum_data;
real_index = idx_outer * inner_num + idx_inner;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
dout[real_index] *= sum_inv;
real_index += inner_num;
}
}
}
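// Editor's note (not part of this patch): softmax_basic above, and every
// vectorized path below, uses the numerically stable formulation
//   softmax(x)[j] = exp(x[j] - max(x)) / sum_k exp(x[k] - max(x)),
// which equals the naive definition but keeps expf() from overflowing for
// large inputs. A minimal scalar sketch of that identity (softmax_stable is
// a hypothetical name, for illustration only):
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> softmax_stable(const std::vector<float>& x) {
  const float max_v = *std::max_element(x.begin(), x.end());
  std::vector<float> y(x.size());
  float sum = 0.f;
  for (size_t i = 0; i < x.size(); ++i) {
    y[i] = std::exp(x[i] - max_v);  // shifted exponent is always <= 0
    sum += y[i];
  }
  for (float& v : y) v /= sum;  // normalize so the outputs sum to 1
  return y;
}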
template <>
void softmax_inner8_axis4<float>(const float* din, float* dout,
const int axis_size, const int inner_num,
const int outer_num) {
int compute_size = inner_num * outer_num;
int cmp_cnt = compute_size >> 3;
int remain = compute_size % 8;
float32x4_t vone = vdupq_n_f32(1.0f);
#pragma omp parallel for
for (int c = 0; c < cmp_cnt; ++c) {
int i = c * 8;
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
// get max axis_size == 4
const float* din_ptr = din + real_index;
const float* din_ptr1 = din_ptr + inner_num;
const float* din_ptr2 = din_ptr1 + inner_num;
const float* din_ptr3 = din_ptr2 + inner_num;
float32x4_t vdata0 = vld1q_f32(din_ptr);
float32x4_t vdata1 = vld1q_f32(din_ptr1);
float32x4_t vdata2 = vld1q_f32(din_ptr2);
float32x4_t vdata3 = vld1q_f32(din_ptr3);
float32x4_t vdata01 = vld1q_f32(din_ptr + 4);
float32x4_t vdata11 = vld1q_f32(din_ptr1 + 4);
float32x4_t vdata21 = vld1q_f32(din_ptr2 + 4);
float32x4_t vdata31 = vld1q_f32(din_ptr3 + 4);
float* dout_ptr0 = dout + real_index;
float* dout_ptr1 = dout_ptr0 + inner_num;
float32x4_t vmax1 = vmaxq_f32(vdata0, vdata1);
float32x4_t vmax2 = vmaxq_f32(vdata2, vdata3);
float32x4_t vmax11 = vmaxq_f32(vdata01, vdata11);
float32x4_t vmax21 = vmaxq_f32(vdata21, vdata31);
float* dout_ptr2 = dout_ptr1 + inner_num;
float* dout_ptr3 = dout_ptr2 + inner_num;
float32x4_t vmax = vmaxq_f32(vmax1, vmax2);
float32x4_t vmax_1 = vmaxq_f32(vmax11, vmax21);
// sub, exp and sum
float32x4_t vsum0 = exp_ps(vsubq_f32(vdata0, vmax));
float32x4_t vsum1 = exp_ps(vsubq_f32(vdata1, vmax));
float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax));
float32x4_t vsum3 = exp_ps(vsubq_f32(vdata3, vmax));
float32x4_t vsum01 = exp_ps(vsubq_f32(vdata01, vmax_1));
float32x4_t vsum11 = exp_ps(vsubq_f32(vdata11, vmax_1));
float32x4_t vsum21 = exp_ps(vsubq_f32(vdata21, vmax_1));
float32x4_t vsum31 = exp_ps(vsubq_f32(vdata31, vmax_1));
float32x4_t vsum_1 = vaddq_f32(vsum0, vsum1);
float32x4_t vsum_2 = vaddq_f32(vsum2, vsum3);
float32x4_t vsum_11 = vaddq_f32(vsum01, vsum11);
float32x4_t vsum_21 = vaddq_f32(vsum21, vsum31);
float32x4_t vsum = vaddq_f32(vsum_1, vsum_2);
float32x4_t vsum111 = vaddq_f32(vsum_11, vsum_21);
float32x4_t vinf = div_ps(vone, vsum);
float32x4_t vinf1 = div_ps(vone, vsum111);
vsum0 = vmulq_f32(vsum0, vinf);
vsum1 = vmulq_f32(vsum1, vinf);
vsum2 = vmulq_f32(vsum2, vinf);
vsum3 = vmulq_f32(vsum3, vinf);
vsum01 = vmulq_f32(vsum01, vinf1);
vsum11 = vmulq_f32(vsum11, vinf1);
vsum21 = vmulq_f32(vsum21, vinf1);
vsum31 = vmulq_f32(vsum31, vinf1);
vst1q_f32(dout_ptr0, vsum0);
vst1q_f32(dout_ptr1, vsum1);
vst1q_f32(dout_ptr2, vsum2);
vst1q_f32(dout_ptr3, vsum3);
vst1q_f32(dout_ptr0 + 4, vsum01);
vst1q_f32(dout_ptr1 + 4, vsum11);
vst1q_f32(dout_ptr2 + 4, vsum21);
vst1q_f32(dout_ptr3 + 4, vsum31);
}
int i = cmp_cnt * 8;
if (remain > 4) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
// get max axis_size == 4
const float* din_ptr = din + real_index;
const float* din_ptr1 = din_ptr + inner_num;
const float* din_ptr2 = din_ptr1 + inner_num;
const float* din_ptr3 = din_ptr2 + inner_num;
float32x4_t vdata0 = vld1q_f32(din_ptr);
float32x4_t vdata1 = vld1q_f32(din_ptr1);
float32x4_t vdata2 = vld1q_f32(din_ptr2);
float32x4_t vdata3 = vld1q_f32(din_ptr3);
float* dout_ptr0 = dout + real_index;
float* dout_ptr1 = dout_ptr0 + inner_num;
float32x4_t vmax1 = vmaxq_f32(vdata0, vdata1);
float32x4_t vmax2 = vmaxq_f32(vdata2, vdata3);
float* dout_ptr2 = dout_ptr1 + inner_num;
float* dout_ptr3 = dout_ptr2 + inner_num;
float32x4_t vmax = vmaxq_f32(vmax1, vmax2);
// sub, exp and sum
float32x4_t vsum0 = exp_ps(vsubq_f32(vdata0, vmax));
float32x4_t vsum1 = exp_ps(vsubq_f32(vdata1, vmax));
float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax));
float32x4_t vsum3 = exp_ps(vsubq_f32(vdata3, vmax));
float32x4_t vsum_1 = vaddq_f32(vsum0, vsum1);
float32x4_t vsum_2 = vaddq_f32(vsum2, vsum3);
float32x4_t vsum = vaddq_f32(vsum_1, vsum_2);
float32x4_t vone = vdupq_n_f32(1.0f);
float32x4_t vinf = div_ps(vone, vsum);
vsum0 = vmulq_f32(vsum0, vinf);
vsum1 = vmulq_f32(vsum1, vinf);
vsum2 = vmulq_f32(vsum2, vinf);
vsum3 = vmulq_f32(vsum3, vinf);
vst1q_f32(dout_ptr0, vsum0);
vst1q_f32(dout_ptr1, vsum1);
vst1q_f32(dout_ptr2, vsum2);
vst1q_f32(dout_ptr3, vsum3);
i += 4;
}
for (; i < compute_size; i++) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
float max_data = din[real_index];
// get max
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
max_data = din[real_index] > max_data ? din[real_index] : max_data;
}
real_index = idx_outer * inner_num + idx_inner;
// sub, exp and sum
dout[real_index] = expf(din[real_index] - max_data);
float sum_data = dout[real_index];
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
dout[real_index] = expf(din[real_index] - max_data);
sum_data += dout[real_index];
}
float sum_inv = 1.f / sum_data;
real_index = idx_outer * inner_num + idx_inner;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
dout[real_index] *= sum_inv;
real_index += inner_num;
}
}
}
template <>
void softmax_inner4_axis4<float>(const float* din, float* dout,
const int axis_size, const int inner_num,
const int outer_num) {
int compute_size = inner_num * outer_num;
int cmp_cnt = compute_size >> 2;
int remain = compute_size % 4;
float32x4_t vone = vdupq_n_f32(1.0f);
#pragma omp parallel for
for (int c = 0; c < cmp_cnt; ++c) {
int i = c * 4;
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
// get max axis_size == 4
const float* din_ptr = din + real_index;
const float* din_ptr1 = din_ptr + inner_num;
const float* din_ptr2 = din_ptr1 + inner_num;
const float* din_ptr3 = din_ptr2 + inner_num;
float32x4_t vdata0 = vld1q_f32(din_ptr);
float32x4_t vdata1 = vld1q_f32(din_ptr1);
float32x4_t vdata2 = vld1q_f32(din_ptr2);
float32x4_t vdata3 = vld1q_f32(din_ptr3);
float* dout_ptr0 = dout + real_index;
float* dout_ptr1 = dout_ptr0 + inner_num;
float32x4_t vmax1 = vmaxq_f32(vdata0, vdata1);
float32x4_t vmax2 = vmaxq_f32(vdata2, vdata3);
float* dout_ptr2 = dout_ptr1 + inner_num;
float* dout_ptr3 = dout_ptr2 + inner_num;
float32x4_t vmax = vmaxq_f32(vmax1, vmax2);
// sub, exp and sum
float32x4_t vsum0 = exp_ps(vsubq_f32(vdata0, vmax));
float32x4_t vsum1 = exp_ps(vsubq_f32(vdata1, vmax));
float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax));
float32x4_t vsum3 = exp_ps(vsubq_f32(vdata3, vmax));
float32x4_t vsum_1 = vaddq_f32(vsum0, vsum1);
float32x4_t vsum_2 = vaddq_f32(vsum2, vsum3);
float32x4_t vsum = vaddq_f32(vsum_1, vsum_2);
float32x4_t vinf = div_ps(vone, vsum);
vsum0 = vmulq_f32(vsum0, vinf);
vsum1 = vmulq_f32(vsum1, vinf);
vsum2 = vmulq_f32(vsum2, vinf);
vsum3 = vmulq_f32(vsum3, vinf);
vst1q_f32(dout_ptr0, vsum0);
vst1q_f32(dout_ptr1, vsum1);
vst1q_f32(dout_ptr2, vsum2);
vst1q_f32(dout_ptr3, vsum3);
}
  int i = cmp_cnt * 4;
for (; i < compute_size; i++) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
float max_data = din[real_index];
// get max
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
max_data = din[real_index] > max_data ? din[real_index] : max_data;
}
real_index = idx_outer * inner_num + idx_inner;
// sub, exp and sum
dout[real_index] = expf(din[real_index] - max_data);
float sum_data = dout[real_index];
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
dout[real_index] = expf(din[real_index] - max_data);
sum_data += dout[real_index];
}
float sum_inv = 1.f / sum_data;
real_index = idx_outer * inner_num + idx_inner;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
dout[real_index] *= sum_inv;
real_index += inner_num;
}
}
}
template <>
void softmax_inner8<float>(const float* din, float* dout, const int axis_size,
const int inner_num, const int outer_num) {
int compute_size = inner_num * outer_num;
int cmp_cnt = compute_size >> 3;
#pragma omp parallel for
for (int c = 0; c < cmp_cnt; ++c) {
int i = c * 8;
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
const float* din_ptr = din + real_index;
float32x4_t vmax = vld1q_f32(din_ptr);
float32x4_t vmax2 = vld1q_f32(din_ptr + 4);
// get max
for (int j = 1; j < axis_size; ++j) {
din_ptr += inner_num;
float32x4_t vdata = vld1q_f32(din_ptr);
float32x4_t vdata2 = vld1q_f32(din_ptr + 4);
vmax = vmaxq_f32(vmax, vdata);
vmax2 = vmaxq_f32(vmax2, vdata2);
}
// sub, exp and sum
din_ptr = din + real_index;
float* dout_ptr = dout + real_index;
float32x4_t vdata = vld1q_f32(din_ptr);
float32x4_t vdata2 = vld1q_f32(din_ptr + 4);
float32x4_t vsum = exp_ps(vsubq_f32(vdata, vmax));
float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax2));
din_ptr += inner_num;
vst1q_f32(dout_ptr, vsum);
vst1q_f32(dout_ptr + 4, vsum2);
dout_ptr += inner_num;
for (int j = 1; j < axis_size; ++j) {
float32x4_t vdata0 = vld1q_f32(din_ptr);
float32x4_t vdata1 = vld1q_f32(din_ptr + 4);
vdata0 = exp_ps(vsubq_f32(vdata0, vmax));
vdata1 = exp_ps(vsubq_f32(vdata1, vmax2));
din_ptr += inner_num;
vsum = vaddq_f32(vsum, vdata0);
vsum2 = vaddq_f32(vsum2, vdata1);
vst1q_f32(dout_ptr, vdata0);
vst1q_f32(dout_ptr + 4, vdata1);
dout_ptr += inner_num;
}
float32x4_t vone = vdupq_n_f32(1.0f);
float32x4_t vinf = div_ps(vone, vsum);
float32x4_t vinf2 = div_ps(vone, vsum2);
dout_ptr = dout + real_index;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
float32x4_t vdata0 = vld1q_f32(dout_ptr);
float32x4_t vdata1 = vld1q_f32(dout_ptr + 4);
vdata0 = vmulq_f32(vdata0, vinf);
vdata1 = vmulq_f32(vdata1, vinf2);
vst1q_f32(dout_ptr, vdata0);
vst1q_f32(dout_ptr + 4, vdata1);
dout_ptr += inner_num;
}
}
for (int i = cmp_cnt * 8; i < compute_size; i++) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
float max_data = din[real_index];
// get max
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
max_data = din[real_index] > max_data ? din[real_index] : max_data;
}
real_index = idx_outer * inner_num + idx_inner;
// sub, exp and sum
dout[real_index] = expf(din[real_index] - max_data);
float sum_data = dout[real_index];
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
dout[real_index] = expf(din[real_index] - max_data);
sum_data += dout[real_index];
}
float sum_inv = 1.f / sum_data;
real_index = idx_outer * inner_num + idx_inner;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
dout[real_index] *= sum_inv;
real_index += inner_num;
}
}
}
template <>
void softmax_inner4<float>(const float* din, float* dout, const int axis_size,
const int inner_num, const int outer_num) {
int compute_size = inner_num * outer_num;
int cmp_cnt = compute_size >> 2;
#pragma omp parallel for
for (int c = 0; c < cmp_cnt; ++c) {
int i = c * 4;
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
// float max_data = din[real_index];
const float* din_ptr = din + real_index;
float32x4_t vmax = vld1q_f32(din_ptr);
// get max
for (int j = 1; j < axis_size; ++j) {
din_ptr += inner_num;
float32x4_t vdata = vld1q_f32(din_ptr);
vmax = vmaxq_f32(vmax, vdata);
}
// sub, exp and sum
din_ptr = din + real_index;
float* dout_ptr = dout + real_index;
float32x4_t vdata = vld1q_f32(din_ptr);
float32x4_t vsum = exp_ps(vsubq_f32(vdata, vmax));
din_ptr += inner_num;
vst1q_f32(dout_ptr, vsum);
dout_ptr += inner_num;
for (int j = 1; j < axis_size; ++j) {
// real_index += inner_num;
float32x4_t vdata0 = vld1q_f32(din_ptr);
vdata0 = exp_ps(vsubq_f32(vdata0, vmax));
din_ptr += inner_num;
vsum = vaddq_f32(vsum, vdata0);
vst1q_f32(dout_ptr, vdata0);
dout_ptr += inner_num;
}
float32x4_t vone = vdupq_n_f32(1.0f);
float32x4_t vinf = div_ps(vone, vsum);
dout_ptr = dout + real_index;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
float32x4_t vdata0 = vld1q_f32(dout_ptr);
vdata0 = vmulq_f32(vdata0, vinf);
vst1q_f32(dout_ptr, vdata0);
dout_ptr += inner_num;
}
}
for (int i = cmp_cnt * 4; i < compute_size; i++) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
float max_data = din[real_index];
// get max
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
max_data = din[real_index] > max_data ? din[real_index] : max_data;
}
real_index = idx_outer * inner_num + idx_inner;
// sub, exp and sum
dout[real_index] = expf(din[real_index] - max_data);
float sum_data = dout[real_index];
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
dout[real_index] = expf(din[real_index] - max_data);
sum_data += dout[real_index];
}
float sum_inv = 1.f / sum_data;
real_index = idx_outer * inner_num + idx_inner;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
dout[real_index] *= sum_inv;
real_index += inner_num;
}
}
}
template <>
void softmax_inner1_large_axis<float>(const float* din, float* dout,
const int outer_size,
const int axis_size) {
#pragma omp parallel for
for (int i = 0; i < outer_size; ++i) {
const float* din_ptr = din + i * axis_size;
float* dout_ptr = dout + i * axis_size;
const float* din_max_ptr = din_ptr;
int nn = axis_size >> 2;
// get max
float32x4_t vmax = vld1q_f32(din_max_ptr);
din_max_ptr += 4;
int j = 1;
for (; j < nn; ++j) {
vmax = vmaxq_f32(vmax, vld1q_f32(din_max_ptr));
din_max_ptr += 4;
}
float32x2_t vhmax = vmax_f32(vget_high_f32(vmax), vget_low_f32(vmax));
float max_data = std::max(vget_lane_f32(vhmax, 0), vget_lane_f32(vhmax, 1));
for (j = 4 * j; j < axis_size; ++j) {
max_data = std::max(max_data, din_max_ptr[0]);
din_max_ptr++;
}
// sub, exp and sum
const float* din_sum_ptr = din_ptr;
float* dout_sum_ptr = dout_ptr;
vmax = vdupq_n_f32(max_data);
float32x4_t vsub_exp = exp_ps(vsubq_f32(vld1q_f32(din_sum_ptr), vmax));
float32x4_t vsum = vsub_exp;
vst1q_f32(dout_sum_ptr, vsub_exp);
din_sum_ptr += 4;
dout_sum_ptr += 4;
j = 1;
for (; j < nn; ++j) {
vsub_exp = exp_ps(vsubq_f32(vld1q_f32(din_sum_ptr), vmax));
vst1q_f32(dout_sum_ptr, vsub_exp);
vsum = vaddq_f32(vsum, vsub_exp);
din_sum_ptr += 4;
dout_sum_ptr += 4;
}
float32x2_t vhsum = vadd_f32(vget_high_f32(vsum), vget_low_f32(vsum));
float sum_data = vget_lane_f32(vhsum, 0) + vget_lane_f32(vhsum, 1);
for (j = 4 * j; j < axis_size; ++j) {
dout_sum_ptr[0] = expf(din_sum_ptr[0] - max_data);
sum_data += dout_sum_ptr[0];
din_sum_ptr++;
dout_sum_ptr++;
}
float sum_inv = 1.f / sum_data;
float* dout_res_ptr = dout_ptr;
float32x4_t vinv = vdupq_n_f32(sum_inv);
// get softmax result
j = 0;
for (; j < nn; ++j) {
float32x4_t vout = vld1q_f32(dout_res_ptr);
float32x4_t vres = vmulq_f32(vout, vinv);
vst1q_f32(dout_res_ptr, vres);
dout_res_ptr += 4;
}
for (j = nn * 4; j < axis_size; ++j) {
dout_ptr[j] *= sum_inv;
}
}
}
template <>
void softmax_inner1_small_axis<float>(const float* din, float* dout,
const int outer_size,
const int axis_size) {
#pragma omp parallel for
for (int i = 0; i < outer_size; ++i) {
const float* din_ptr = din + i * axis_size;
float* dout_ptr = dout + i * axis_size;
// get max
float max_data = din_ptr[0];
for (int j = 1; j < axis_size; ++j) {
max_data = std::max(max_data, din_ptr[j]);
}
// sub, exp and sum
float sum_data = 0.f;
for (int j = 0; j < axis_size; ++j) {
dout_ptr[j] = expf(din_ptr[j] - max_data);
sum_data += dout_ptr[j];
}
float sum_inv = 1.f / sum_data;
for (int j = 0; j < axis_size; ++j) {
dout_ptr[j] *= sum_inv;
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void softmax_basic(const T* din, T* dout, const int axis_size,
const int inner_num, const int outer_num);
template <typename T>
void softmax_inner8_axis4(const T* din, T* dout, const int axis_size,
const int inner_num, const int outer_num);
template <typename T>
void softmax_inner4_axis4(const T* din, T* dout, const int axis_size,
const int inner_num, const int outer_num);
template <typename T>
void softmax_inner8(const T* din, T* dout, const int axis_size,
const int inner_num, const int outer_num);
template <typename T>
void softmax_inner4(const T* din, T* dout, const int axis_size,
const int inner_num, const int outer_num);
template <typename T>
void softmax_inner1_large_axis(const T* din, T* dout, const int outer_size,
const int axis_size);
template <typename T>
void softmax_inner1_small_axis(const T* din, T* dout, const int outer_size,
const int axis_size);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
@@ -8,13 +8,16 @@ cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps})
cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3)
cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} eigen3)
cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm)
lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm)
set(arm_kernels
fc_compute_arm
relu_compute_arm
mul_compute_arm
scale_compute_arm)
scale_compute_arm
softmax_compute_arm)
set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/softmax_compute.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void SoftmaxCompute::Run() {
auto& param = Param<operators::SoftmaxParam>();
const float* din = param.x->data<float>();
float* dout = param.output->mutable_data<float>();
auto dim_x = param.x->dims();
auto rank_x = dim_x.size();
int axis = param.axis;
if (axis < 0) {
axis += rank_x;
}
int outer_num = dim_x.Slice(0, axis).production();
int inner_num = dim_x.Slice(axis + 1, rank_x).production();
int axis_size = dim_x[axis];
if (inner_num == 1) {
if (axis_size >= 4) {
lite::arm::math::softmax_inner1_large_axis(din, dout, outer_num,
axis_size);
} else {
lite::arm::math::softmax_inner1_small_axis(din, dout, outer_num,
axis_size);
}
} else {
int compute_size = outer_num * inner_num;
if (axis_size == 4 && inner_num % 8 == 0) {
lite::arm::math::softmax_inner8_axis4(din, dout, axis_size, inner_num,
outer_num);
} else if (axis_size == 4 && inner_num % 4 == 0) {
lite::arm::math::softmax_inner4_axis4(din, dout, axis_size, inner_num,
outer_num);
} else {
if (inner_num % 8 == 0) {
lite::arm::math::softmax_inner8(din, dout, axis_size, inner_num,
outer_num);
} else if (inner_num % 4 == 0) {
lite::arm::math::softmax_inner4(din, dout, axis_size, inner_num,
outer_num);
} else {
lite::arm::math::softmax_basic(din, dout, axis_size, inner_num,
outer_num);
}
}
}
}
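// Editor's note (not part of this patch): Run() above splits the input shape
// around the softmax axis: outer_num is the product of the dimensions before
// the axis, axis_size is the extent of the axis itself, and inner_num is the
// product of the dimensions after it; the specialized kernels are then picked
// by axis_size and the alignment of inner_num. A small sketch of that split
// (softmax_extents is a hypothetical helper name; axis is assumed to be
// non-negative, i.e. already adjusted by += rank as in Run()):
#include <cstdint>
#include <vector>

inline void softmax_extents(const std::vector<int64_t>& dims, int axis,
                            int* outer_num, int* axis_size, int* inner_num) {
  *outer_num = 1;
  for (int i = 0; i < axis; ++i) *outer_num *= static_cast<int>(dims[i]);
  *axis_size = static_cast<int>(dims[axis]);
  *inner_num = 1;
  for (size_t i = axis + 1; i < dims.size(); ++i)
    *inner_num *= static_cast<int>(dims[i]);
}
// e.g. dims = {1, 1000, 1, 1} with axis = 1 gives outer_num = 1,
// axis_size = 1000, inner_num = 1, so Run() takes the
// softmax_inner1_large_axis path.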
TargetType SoftmaxCompute::target() const { return TARGET(kARM); }
PrecisionType SoftmaxCompute::precision() const { return PRECISION(kFloat); }
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(softmax, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::SoftmaxCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class SoftmaxCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
TargetType target() const override;
PrecisionType precision() const override;
virtual ~SoftmaxCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/softmax_compute.h"
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <typename dtype>
void softmax_compute_ref(const operators::SoftmaxParam& param) {
  const dtype* x_data = param.x->data<dtype>();
dtype* output_data = param.output->mutable_data<dtype>();
DDim dim = param.x->dims();
ASSERT_EQ(dim.data(), param.output->dims().data());
auto rank = dim.size();
int axis = param.axis;
if (axis < 0) {
axis += rank;
}
int axis_size = dim[axis];
int outer_num = dim.Slice(0, axis).production();
int inner_num = dim.Slice(axis + 1, rank).production();
int compute_size = outer_num * inner_num;
for (int i = 0; i < compute_size; i++) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int start = idx_outer * inner_num + idx_inner;
int offset;
offset = start;
dtype max_data = std::numeric_limits<dtype>::lowest();
for (int j = 0; j < axis_size; j++) {
max_data = x_data[offset] > max_data ? x_data[offset] : max_data;
offset += inner_num;
}
offset = start;
dtype sum_data = (dtype)0;
for (int j = 0; j < axis_size; j++) {
output_data[offset] = exp(x_data[offset] - max_data);
sum_data += output_data[offset];
offset += inner_num;
}
offset = start;
for (int j = 0; j < axis_size; j++) {
output_data[offset] /= sum_data;
offset += inner_num;
}
}
}
TEST(softmax_arm, init) {
SoftmaxCompute softmax;
ASSERT_EQ(softmax.precision(), PRECISION(kFloat));
ASSERT_EQ(softmax.target(), TARGET(kARM));
}
TEST(softmax_arm, compute) {
SoftmaxCompute softmax;
operators::SoftmaxParam param;
lite::Tensor x;
lite::Tensor output;
lite::Tensor output_ref;
for (auto n : {1, 3, 4, 11}) {
for (auto c : {1, 3, 11, 4}) {
for (auto h : {3, 1, 11, 4}) {
for (auto w : {1, 3, 4, 12}) {
for (auto axis : {-4, -3, -2, -1, 0, 1, 2, 3}) {
x.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output_ref.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
auto* x_data = x.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
auto* output_ref_data = output_ref.mutable_data<float>();
for (int i = 0; i < x.dims().production(); i++) {
x_data[i] = i;
}
param.x = &x;
param.axis = axis;
param.output = &output;
softmax.SetParam(param);
softmax.Run();
param.output = &output_ref;
            softmax_compute_ref<float>(param);
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
}
}
}
}
}
}
}
TEST(softmax, retrieve_op) {
auto softmax =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"softmax");
ASSERT_FALSE(softmax.empty());
ASSERT_TRUE(softmax.front());
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);
@@ -18,5 +18,6 @@
USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(feed, kARM, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def);
@@ -4,6 +4,7 @@ cc_library(fc_op_lite SRCS fc_op.cc DEPS ${op_DEPS})
cc_library(relu_op_lite SRCS relu_op.cc DEPS ${op_DEPS})
cc_library(mul_op_lite SRCS mul_op.cc DEPS ${op_DEPS})
cc_library(scale_op_lite SRCS scale_op.cc DEPS ${op_DEPS})
cc_library(softmax_op_lite SRCS softmax_op.cc DEPS ${op_DEPS})
cc_library(feed_op_lite SRCS feed_op.cc DEPS ${op_DEPS})
cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS})
cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS})
@@ -19,6 +20,7 @@ set(ops_lite
relu_op_lite
mul_op_lite
scale_op_lite
softmax_op_lite
feed_op_lite
fetch_op_lite
io_copy_op_lite
@@ -28,3 +30,4 @@ set(ops_lite
PARENT_SCOPE)
lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite memory_lite)
lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite)
@@ -93,6 +93,14 @@ struct ScaleParam {
bool bias_after_scale{true};
};
// For Softmax Op
struct SoftmaxParam {
lite::Tensor* x{};
lite::Tensor* output{};
int axis{-1};
};
/// ----------------------- element wise operators ----------------------
struct ElementwiseParam {
const lite::Tensor* X{};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/softmax_op.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SoftmaxOp::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
  auto dim_x = param_.x->dims();
  int rank_x = static_cast<int>(dim_x.size());
  CHECK_OR_FALSE(param_.axis >= -rank_x && param_.axis < rank_x);
return true;
}
bool SoftmaxOp::InferShape() const {
param_.output->Resize(param_.x->dims());
return true;
}
bool SoftmaxOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) {
param_.x = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
param_.output =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
param_.axis = GetAttr<int>(opdesc.GetAttr("axis"));
CHECK(param_.x);
CHECK(param_.output);
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(softmax, paddle::lite::operators::SoftmaxOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class SoftmaxOp : public OpLite {
public:
SoftmaxOp() {}
explicit SoftmaxOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "softmax"; }
private:
mutable SoftmaxParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/softmax_op.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
TEST(softmax_op_lite, test) {
// prepare variables
Scope scope;
auto* x = scope.Var("x")->GetMutable<Tensor>();
auto* output = scope.Var("output")->GetMutable<Tensor>();
x->Resize(DDim(std::vector<int64_t>({10, 20})));
output->Resize(DDim(std::vector<int64_t>{10, 20}));
// set data
for (int i = 0; i < 10 * 20; i++) {
x->mutable_data<float>()[i] = i;
}
for (int i = 0; i < 10 * 20; i++) {
output->mutable_data<float>()[i] = 0.;
}
// prepare op desc
lite::OpDesc desc;
desc.SetType("softmax");
desc.SetInput("X", {"x"});
desc.SetOutput("Out", {"output"});
desc.SetAttr("axis", static_cast<int>(-1));
SoftmaxOp softmax("softmax");
softmax.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}});
softmax.Attach(desc, &scope);
}
} // namespace operators
} // namespace lite
} // namespace paddle