diff --git a/paddle/fluid/lite/api/cxx_api_bin.cc b/paddle/fluid/lite/api/cxx_api_bin.cc
index 4e111097e380ab79e33faf12374d69c56809964c..f53f6105d1bf8abdce928ad8fb8fc36ac79935c6 100644
--- a/paddle/fluid/lite/api/cxx_api_bin.cc
+++ b/paddle/fluid/lite/api/cxx_api_bin.cc
@@ -72,6 +72,7 @@ USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
 USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def);
 USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
 USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def);
+USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);
 // USE_LITE_KERNEL(feed, kARM, kAny, kAny, def);
 // USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def);
 #endif  // LITE_WITH_ARM
diff --git a/paddle/fluid/lite/arm/math/CMakeLists.txt b/paddle/fluid/lite/arm/math/CMakeLists.txt
index 278cb54a418d8e546a1ebb26c5664412ca692590..8f29f6542d40bd16b77cf127ec4f8151b0d10994 100644
--- a/paddle/fluid/lite/arm/math/CMakeLists.txt
+++ b/paddle/fluid/lite/arm/math/CMakeLists.txt
@@ -1,2 +1,2 @@
-cc_library(math_arm SRCS funcs.cc packed_sgemm.cc DEPS ${lite_kernel_deps} eigen3)
+cc_library(math_arm SRCS funcs.cc packed_sgemm.cc softmax.cc DEPS ${lite_kernel_deps} eigen3)
diff --git a/paddle/fluid/lite/arm/math/funcs.h b/paddle/fluid/lite/arm/math/funcs.h
index dd3ba2db509971dafd482303129ae4e24479dbdb..2b34ac33c973367e5a5c4191d13134fa7bb9241a 100644
--- a/paddle/fluid/lite/arm/math/funcs.h
+++ b/paddle/fluid/lite/arm/math/funcs.h
@@ -18,12 +18,293 @@
 #include <arm_neon.h>
 #include "paddle/fluid/lite/arm/math/packed_sgemm.h"
+#include "paddle/fluid/lite/arm/math/softmax.h"
 
 namespace paddle {
 namespace lite {
 namespace arm {
 namespace math {
 
+#define c_inv_mant_mask ~0x7f800000u
+#define c_cephes_SQRTHF 0.707106781186547524
+#define c_cephes_log_p0 7.0376836292E-2
+#define c_cephes_log_p1 -1.1514610310E-1
+#define c_cephes_log_p2 1.1676998740E-1
+#define c_cephes_log_p3 -1.2420140846E-1
+#define c_cephes_log_p4 +1.4249322787E-1
+#define c_cephes_log_p5 -1.6668057665E-1
+#define c_cephes_log_p6 +2.0000714765E-1
+#define c_cephes_log_p7 -2.4999993993E-1
+#define c_cephes_log_p8 +3.3333331174E-1
+#define c_cephes_log_q1 -2.12194440e-4
+#define c_cephes_log_q2 0.693359375
+
+// natural logarithm computed for 4 simultaneous floats;
+// returns NaN for x <= 0
+inline float32x4_t log_ps(float32x4_t x) {
+  float32x4_t one = vdupq_n_f32(1);
+
+  x = vmaxq_f32(x, vdupq_n_f32(0));  // force flush to zero on denormal values
+  uint32x4_t invalid_mask = vcleq_f32(x, vdupq_n_f32(0));
+
+  int32x4_t ux = vreinterpretq_s32_f32(x);
+
+  int32x4_t emm0 = vshrq_n_s32(ux, 23);
+
+  // keep only the fractional part
+  ux = vandq_s32(ux, vdupq_n_s32(c_inv_mant_mask));
+  ux = vorrq_s32(ux, vreinterpretq_s32_f32(vdupq_n_f32(0.5f)));
+  x = vreinterpretq_f32_s32(ux);
+
+  emm0 = vsubq_s32(emm0, vdupq_n_s32(0x7f));
+  float32x4_t e = vcvtq_f32_s32(emm0);
+
+  e = vaddq_f32(e, one);
+
+  // part2:
+  //   if( x < SQRTHF ) {
+  //     e -= 1;
+  //     x = x + x - 1.0;
+  //   } else {
+  //     x = x - 1.0;
+  //   }
+  uint32x4_t mask = vcltq_f32(x, vdupq_n_f32(c_cephes_SQRTHF));
+  float32x4_t tmp =
+      vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask));
+  x = vsubq_f32(x, one);
+  e = vsubq_f32(
+      e, vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(one), mask)));
+  x = vaddq_f32(x, tmp);
+
+  float32x4_t z = vmulq_f32(x, x);
+
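+  // Horner evaluation of the cephes degree-8 minimax polynomial in x.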
+  float32x4_t y = vdupq_n_f32(c_cephes_log_p0);
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p1));
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p2));
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p3));
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p4));
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p5));
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p6));
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p7));
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p8));
+  y = vmulq_f32(y, x);
+
+  y = vmulq_f32(y, z);
+
+  tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q1));
+  y = vaddq_f32(y, tmp);
+
+  tmp = vmulq_f32(z, vdupq_n_f32(0.5f));
+  y = vsubq_f32(y, tmp);
+
+  tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q2));
+  x = vaddq_f32(x, y);
+  x = vaddq_f32(x, tmp);
+  x = vreinterpretq_f32_u32(vorrq_u32(
+      vreinterpretq_u32_f32(x), invalid_mask));  // negative arg will be NAN
+  return x;
+}
+
+#define c_exp_hi 88.3762626647949f
+#define c_exp_lo -88.3762626647949f
+
+#define c_cephes_LOG2EF 1.44269504088896341
+#define c_cephes_exp_C1 0.693359375
+#define c_cephes_exp_C2 -2.12194440e-4
+
+#define c_cephes_exp_p0 1.9875691500E-4
+#define c_cephes_exp_p1 1.3981999507E-3
+#define c_cephes_exp_p2 8.3334519073E-3
+#define c_cephes_exp_p3 4.1665795894E-2
+#define c_cephes_exp_p4 1.6666665459E-1
+#define c_cephes_exp_p5 5.0000001201E-1
+
+// exp() computed for 4 floats at once
+inline float32x4_t exp_ps(float32x4_t x) {
+  float32x4_t tmp, fx;
+
+  float32x4_t one = vdupq_n_f32(1);
+  x = vminq_f32(x, vdupq_n_f32(c_exp_hi));
+  x = vmaxq_f32(x, vdupq_n_f32(c_exp_lo));
+
+  // express exp(x) as exp(g + n*log(2))
+  fx = vmlaq_f32(vdupq_n_f32(0.5f), x, vdupq_n_f32(c_cephes_LOG2EF));
+
+  // perform a floorf
+  tmp = vcvtq_f32_s32(vcvtq_s32_f32(fx));
+
+  // if greater, subtract 1
+  uint32x4_t mask = vcgtq_f32(tmp, fx);
+  mask = vandq_u32(mask, vreinterpretq_u32_f32(one));
+
+  fx = vsubq_f32(tmp, vreinterpretq_f32_u32(mask));
+
+  tmp = vmulq_f32(fx, vdupq_n_f32(c_cephes_exp_C1));
+  float32x4_t z = vmulq_f32(fx, vdupq_n_f32(c_cephes_exp_C2));
+  x = vsubq_f32(x, tmp);
+  x = vsubq_f32(x, z);
+
+  static const float cephes_exp_p[6] = {c_cephes_exp_p0, c_cephes_exp_p1,
+                                        c_cephes_exp_p2, c_cephes_exp_p3,
+                                        c_cephes_exp_p4, c_cephes_exp_p5};
+  float32x4_t y = vld1q_dup_f32(cephes_exp_p + 0);
+  float32x4_t c1 = vld1q_dup_f32(cephes_exp_p + 1);
+  float32x4_t c2 = vld1q_dup_f32(cephes_exp_p + 2);
+  float32x4_t c3 = vld1q_dup_f32(cephes_exp_p + 3);
+  float32x4_t c4 = vld1q_dup_f32(cephes_exp_p + 4);
+  float32x4_t c5 = vld1q_dup_f32(cephes_exp_p + 5);
+
+  y = vmulq_f32(y, x);
+  z = vmulq_f32(x, x);
+
+  y = vaddq_f32(y, c1);
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, c2);
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, c3);
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, c4);
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, c5);
+
+  y = vmulq_f32(y, z);
+  y = vaddq_f32(y, x);
+  y = vaddq_f32(y, one);
+
+  // build 2^n
+  int32x4_t mm;
+  mm = vcvtq_s32_f32(fx);
+  mm = vaddq_s32(mm, vdupq_n_s32(0x7f));
+  mm = vshlq_n_s32(mm, 23);
+  float32x4_t pow2n = vreinterpretq_f32_s32(mm);
+
+  y = vmulq_f32(y, pow2n);
+  return y;
+}
+
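+// The three DP constants sum to -Pi/4 split across three floats, so the
+// argument reduction x - y * Pi/4 below runs in extended precision.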
+#define c_minus_cephes_DP1 -0.78515625
+#define c_minus_cephes_DP2 -2.4187564849853515625e-4
+#define c_minus_cephes_DP3 -3.77489497744594108e-8
+#define c_sincof_p0 -1.9515295891E-4
+#define c_sincof_p1 8.3321608736E-3
+#define c_sincof_p2 -1.6666654611E-1
+#define c_coscof_p0 2.443315711809948E-005
+#define c_coscof_p1 -1.388731625493765E-003
+#define c_coscof_p2 4.166664568298827E-002
+#define c_cephes_FOPI 1.27323954473516  // 4 / M_PI
+
+// evaluation of 4 sines & cosines at once.
+//
+// The code is the exact rewriting of the cephes sinf function.
+// Precision is excellent as long as x < 8192 (I did not bother to
+// take into account the special handling they have for greater values
+// -- it does not return garbage for arguments over 8192, though, but
+// the extra precision is missing).
+//
+// Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
+// surprising but correct result.
+//
+// Note also that when you compute sin(x), cos(x) is available at
+// almost no extra price so both sin_ps and cos_ps make use of
+// sincos_ps.
+inline void sincos_ps(float32x4_t x, float32x4_t* ysin, float32x4_t* ycos) {
+  // any x
+  float32x4_t xmm1, xmm2, xmm3, y;
+
+  uint32x4_t emm2;
+
+  uint32x4_t sign_mask_sin, sign_mask_cos;
+  sign_mask_sin = vcltq_f32(x, vdupq_n_f32(0));
+  x = vabsq_f32(x);
+
+  // scale by 4/Pi
+  y = vmulq_f32(x, vdupq_n_f32(c_cephes_FOPI));
+
+  // store the integer part of y in mm0
+  emm2 = vcvtq_u32_f32(y);
+  // j=(j+1) & (~1) (see the cephes sources)
+  emm2 = vaddq_u32(emm2, vdupq_n_u32(1));
+  emm2 = vandq_u32(emm2, vdupq_n_u32(~1));
+  y = vcvtq_f32_u32(emm2);
+
+  // get the polynomial selection mask:
+  // there is one polynomial for 0 <= x <= Pi/4
+  // and another one for Pi/4 < x <= Pi/2.
+  // Both branches will be computed.
+  uint32x4_t poly_mask = vtstq_u32(emm2, vdupq_n_u32(2));
+
+  // The magic pass: "Extended precision modular arithmetic"
+  //   x = ((x - y * DP1) - y * DP2) - y * DP3
+  xmm1 = vmulq_n_f32(y, c_minus_cephes_DP1);
+  xmm2 = vmulq_n_f32(y, c_minus_cephes_DP2);
+  xmm3 = vmulq_n_f32(y, c_minus_cephes_DP3);
+  x = vaddq_f32(x, xmm1);
+  x = vaddq_f32(x, xmm2);
+  x = vaddq_f32(x, xmm3);
+
+  sign_mask_sin = veorq_u32(sign_mask_sin, vtstq_u32(emm2, vdupq_n_u32(4)));
+  sign_mask_cos = vtstq_u32(vsubq_u32(emm2, vdupq_n_u32(2)), vdupq_n_u32(4));
+
+  // Evaluate the first polynomial (0 <= x <= Pi/4) in y1,
+  // and the second polynomial (Pi/4 <= x <= Pi/2) in y2
+  float32x4_t z = vmulq_f32(x, x);
+  float32x4_t y1, y2;
+
+  y1 = vmulq_n_f32(z, c_coscof_p0);
+  y2 = vmulq_n_f32(z, c_sincof_p0);
+  y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p1));
+  y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p1));
+  y1 = vmulq_f32(y1, z);
+  y2 = vmulq_f32(y2, z);
+  y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p2));
+  y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p2));
+  y1 = vmulq_f32(y1, z);
+  y2 = vmulq_f32(y2, z);
+  y1 = vmulq_f32(y1, z);
+  y2 = vmulq_f32(y2, z);
+  y1 = vsubq_f32(y1, vmulq_f32(z, vdupq_n_f32(0.5f)));
+  y2 = vaddq_f32(y2, x);
+  y1 = vaddq_f32(y1, vdupq_n_f32(1));
+
+  // select the correct result from the two polynomials
+  float32x4_t ys = vbslq_f32(poly_mask, y1, y2);
+  float32x4_t yc = vbslq_f32(poly_mask, y2, y1);
+  *ysin = vbslq_f32(sign_mask_sin, vnegq_f32(ys), ys);
+  *ycos = vbslq_f32(sign_mask_cos, yc, vnegq_f32(yc));
+}
+
+inline float32x4_t sin_ps(float32x4_t x) {
+  float32x4_t ysin, ycos;
+  sincos_ps(x, &ysin, &ycos);
+  return ysin;
+}
+
+inline float32x4_t cos_ps(float32x4_t x) {
+  float32x4_t ysin, ycos;
+  sincos_ps(x, &ysin, &ycos);
+  return ycos;
+}
+
+// a / b via reciprocal estimate plus one Newton-Raphson refinement step
+inline float32x4_t div_ps(float32x4_t a, float32x4_t b) {
+  float32x4_t reciprocal = vrecpeq_f32(b);
+  reciprocal = vmulq_f32(vrecpsq_f32(b, reciprocal), reciprocal);
+  return vmulq_f32(a, reciprocal);
+}
+
 template <typename T>
 void fill_bias_fc(T* tensor, const T* bias, const int num, const int channel);
diff --git a/paddle/fluid/lite/arm/math/softmax.cc b/paddle/fluid/lite/arm/math/softmax.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2a081eaf4899665eb5e8bf118e2634aed22d084e
--- /dev/null
+++ b/paddle/fluid/lite/arm/math/softmax.cc
@@ -0,0 +1,601 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/arm/math/softmax.h"
+#include <algorithm>
+#include "paddle/fluid/lite/arm/math/funcs.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
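+// Scalar fallback used when the inner stride is not a multiple of 4: for
+// every (outer, inner) pair, walk the softmax axis with stride inner_num
+// using the numerically stable three-pass scheme (max, exp-and-sum,
+// normalize).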
+template <>
+void softmax_basic<float>(const float* din, float* dout, const int axis_size,
+                          const int inner_num, const int outer_num) {
+  int compute_size = inner_num * outer_num;
+#pragma omp parallel for
+  for (int i = 0; i < compute_size; ++i) {
+    int idx_inner = i % inner_num;
+    int idx_outer = (i / inner_num) * axis_size;
+    int real_index = idx_outer * inner_num + idx_inner;
+
+    float max_data = din[real_index];
+    // get max
+    for (int j = 1; j < axis_size; ++j) {
+      real_index += inner_num;
+      max_data = din[real_index] > max_data ? din[real_index] : max_data;
+    }
+
+    real_index = idx_outer * inner_num + idx_inner;
+    // sub, exp and sum
+    dout[real_index] = expf(din[real_index] - max_data);
+    float sum_data = dout[real_index];
+    for (int j = 1; j < axis_size; ++j) {
+      real_index += inner_num;
+      dout[real_index] = expf(din[real_index] - max_data);
+      sum_data += dout[real_index];
+    }
+
+    float sum_inv = 1.f / sum_data;
+    real_index = idx_outer * inner_num + idx_inner;
+    // get softmax result
+    for (int j = 0; j < axis_size; ++j) {
+      dout[real_index] *= sum_inv;
+      real_index += inner_num;
+    }
+  }
+}
+
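+// Specialization for axis_size == 4 with inner_num a multiple of 8: the four
+// axis entries of eight consecutive inner positions are held in two quad
+// registers each, so max/exp/normalize never leave NEON registers.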
+template <>
+void softmax_inner8_axis4<float>(const float* din, float* dout,
+                                 const int axis_size, const int inner_num,
+                                 const int outer_num) {
+  int compute_size = inner_num * outer_num;
+  int cmp_cnt = compute_size >> 3;
+  int remain = compute_size % 8;
+  float32x4_t vone = vdupq_n_f32(1.0f);
+
+#pragma omp parallel for
+  for (int c = 0; c < cmp_cnt; ++c) {
+    int i = c * 8;
+    int idx_inner = i % inner_num;
+    int idx_outer = (i / inner_num) * axis_size;
+    int real_index = idx_outer * inner_num + idx_inner;
+
+    // get max (axis_size == 4)
+    const float* din_ptr = din + real_index;
+    const float* din_ptr1 = din_ptr + inner_num;
+    const float* din_ptr2 = din_ptr1 + inner_num;
+    const float* din_ptr3 = din_ptr2 + inner_num;
+    float32x4_t vdata0 = vld1q_f32(din_ptr);
+    float32x4_t vdata1 = vld1q_f32(din_ptr1);
+    float32x4_t vdata2 = vld1q_f32(din_ptr2);
+    float32x4_t vdata3 = vld1q_f32(din_ptr3);
+
+    float32x4_t vdata01 = vld1q_f32(din_ptr + 4);
+    float32x4_t vdata11 = vld1q_f32(din_ptr1 + 4);
+    float32x4_t vdata21 = vld1q_f32(din_ptr2 + 4);
+    float32x4_t vdata31 = vld1q_f32(din_ptr3 + 4);
+
+    float* dout_ptr0 = dout + real_index;
+    float* dout_ptr1 = dout_ptr0 + inner_num;
+    float32x4_t vmax1 = vmaxq_f32(vdata0, vdata1);
+    float32x4_t vmax2 = vmaxq_f32(vdata2, vdata3);
+    float32x4_t vmax11 = vmaxq_f32(vdata01, vdata11);
+    float32x4_t vmax21 = vmaxq_f32(vdata21, vdata31);
+    float* dout_ptr2 = dout_ptr1 + inner_num;
+    float* dout_ptr3 = dout_ptr2 + inner_num;
+    float32x4_t vmax = vmaxq_f32(vmax1, vmax2);
+    float32x4_t vmax_1 = vmaxq_f32(vmax11, vmax21);
+
+    // sub, exp and sum
+    float32x4_t vsum0 = exp_ps(vsubq_f32(vdata0, vmax));
+    float32x4_t vsum1 = exp_ps(vsubq_f32(vdata1, vmax));
+    float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax));
+    float32x4_t vsum3 = exp_ps(vsubq_f32(vdata3, vmax));
+
+    float32x4_t vsum01 = exp_ps(vsubq_f32(vdata01, vmax_1));
+    float32x4_t vsum11 = exp_ps(vsubq_f32(vdata11, vmax_1));
+    float32x4_t vsum21 = exp_ps(vsubq_f32(vdata21, vmax_1));
+    float32x4_t vsum31 = exp_ps(vsubq_f32(vdata31, vmax_1));
+
+    float32x4_t vsum_1 = vaddq_f32(vsum0, vsum1);
+    float32x4_t vsum_2 = vaddq_f32(vsum2, vsum3);
+    float32x4_t vsum_11 = vaddq_f32(vsum01, vsum11);
+    float32x4_t vsum_21 = vaddq_f32(vsum21, vsum31);
+
+    float32x4_t vsum = vaddq_f32(vsum_1, vsum_2);
+    float32x4_t vsum111 = vaddq_f32(vsum_11, vsum_21);
+
+    float32x4_t vinf = div_ps(vone, vsum);
+    float32x4_t vinf1 = div_ps(vone, vsum111);
+
+    vsum0 = vmulq_f32(vsum0, vinf);
+    vsum1 = vmulq_f32(vsum1, vinf);
+    vsum2 = vmulq_f32(vsum2, vinf);
+    vsum3 = vmulq_f32(vsum3, vinf);
+
+    vsum01 = vmulq_f32(vsum01, vinf1);
+    vsum11 = vmulq_f32(vsum11, vinf1);
+    vsum21 = vmulq_f32(vsum21, vinf1);
+    vsum31 = vmulq_f32(vsum31, vinf1);
+
+    vst1q_f32(dout_ptr0, vsum0);
+    vst1q_f32(dout_ptr1, vsum1);
+    vst1q_f32(dout_ptr2, vsum2);
+    vst1q_f32(dout_ptr3, vsum3);
+
+    vst1q_f32(dout_ptr0 + 4, vsum01);
+    vst1q_f32(dout_ptr1 + 4, vsum11);
+    vst1q_f32(dout_ptr2 + 4, vsum21);
+    vst1q_f32(dout_ptr3 + 4, vsum31);
+  }
+
+  int i = cmp_cnt * 8;
+
+  if (remain > 4) {
+    int idx_inner = i % inner_num;
+    int idx_outer = (i / inner_num) * axis_size;
+    int real_index = idx_outer * inner_num + idx_inner;
+    // get max (axis_size == 4)
+    const float* din_ptr = din + real_index;
+    const float* din_ptr1 = din_ptr + inner_num;
+    const float* din_ptr2 = din_ptr1 + inner_num;
+    const float* din_ptr3 = din_ptr2 + inner_num;
+    float32x4_t vdata0 = vld1q_f32(din_ptr);
+    float32x4_t vdata1 = vld1q_f32(din_ptr1);
+    float32x4_t vdata2 = vld1q_f32(din_ptr2);
+    float32x4_t vdata3 = vld1q_f32(din_ptr3);
+
+    float* dout_ptr0 = dout + real_index;
+    float* dout_ptr1 = dout_ptr0 + inner_num;
+    float32x4_t vmax1 = vmaxq_f32(vdata0, vdata1);
+    float32x4_t vmax2 = vmaxq_f32(vdata2, vdata3);
+    float* dout_ptr2 = dout_ptr1 + inner_num;
+    float* dout_ptr3 = dout_ptr2 + inner_num;
+    float32x4_t vmax = vmaxq_f32(vmax1, vmax2);
+
+    // sub, exp and sum
+    float32x4_t vsum0 = exp_ps(vsubq_f32(vdata0, vmax));
+    float32x4_t vsum1 = exp_ps(vsubq_f32(vdata1, vmax));
+    float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax));
+    float32x4_t vsum3 = exp_ps(vsubq_f32(vdata3, vmax));
+
+    float32x4_t vsum_1 = vaddq_f32(vsum0, vsum1);
+    float32x4_t vsum_2 = vaddq_f32(vsum2, vsum3);
+
+    float32x4_t vsum = vaddq_f32(vsum_1, vsum_2);
+
+    float32x4_t vone = vdupq_n_f32(1.0f);
+    float32x4_t vinf = div_ps(vone, vsum);
+
+    vsum0 = vmulq_f32(vsum0, vinf);
+    vsum1 = vmulq_f32(vsum1, vinf);
+    vsum2 = vmulq_f32(vsum2, vinf);
+    vsum3 = vmulq_f32(vsum3, vinf);
+
+    vst1q_f32(dout_ptr0, vsum0);
+    vst1q_f32(dout_ptr1, vsum1);
+    vst1q_f32(dout_ptr2, vsum2);
+    vst1q_f32(dout_ptr3, vsum3);
+
+    i += 4;
+  }
+  for (; i < compute_size; i++) {
+    int idx_inner = i % inner_num;
+    int idx_outer = (i / inner_num) * axis_size;
+    int real_index = idx_outer * inner_num + idx_inner;
+
+    float max_data = din[real_index];
+    // get max
+    for (int j = 1; j < axis_size; ++j) {
+      real_index += inner_num;
+      max_data = din[real_index] > max_data ? din[real_index] : max_data;
+    }
+
+    real_index = idx_outer * inner_num + idx_inner;
+    // sub, exp and sum
+    dout[real_index] = expf(din[real_index] - max_data);
+    float sum_data = dout[real_index];
+    for (int j = 1; j < axis_size; ++j) {
+      real_index += inner_num;
+      dout[real_index] = expf(din[real_index] - max_data);
+      sum_data += dout[real_index];
+    }
+
+    float sum_inv = 1.f / sum_data;
+    real_index = idx_outer * inner_num + idx_inner;
+    // get softmax result
+    for (int j = 0; j < axis_size; ++j) {
+      dout[real_index] *= sum_inv;
+      real_index += inner_num;
+    }
+  }
+}
+
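+// Same idea as softmax_inner8_axis4, but processes four inner positions per
+// iteration (inner_num is only a multiple of 4).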
+template <>
+void softmax_inner4_axis4<float>(const float* din, float* dout,
+                                 const int axis_size, const int inner_num,
+                                 const int outer_num) {
+  int compute_size = inner_num * outer_num;
+  int cmp_cnt = compute_size >> 2;
+  int remain = compute_size % 4;
+  float32x4_t vone = vdupq_n_f32(1.0f);
+
+#pragma omp parallel for
+  for (int c = 0; c < cmp_cnt; ++c) {
+    int i = c * 4;
+    int idx_inner = i % inner_num;
+    int idx_outer = (i / inner_num) * axis_size;
+    int real_index = idx_outer * inner_num + idx_inner;
+
+    // get max (axis_size == 4)
+    const float* din_ptr = din + real_index;
+    const float* din_ptr1 = din_ptr + inner_num;
+    const float* din_ptr2 = din_ptr1 + inner_num;
+    const float* din_ptr3 = din_ptr2 + inner_num;
+    float32x4_t vdata0 = vld1q_f32(din_ptr);
+    float32x4_t vdata1 = vld1q_f32(din_ptr1);
+    float32x4_t vdata2 = vld1q_f32(din_ptr2);
+    float32x4_t vdata3 = vld1q_f32(din_ptr3);
+
+    float* dout_ptr0 = dout + real_index;
+    float* dout_ptr1 = dout_ptr0 + inner_num;
+    float32x4_t vmax1 = vmaxq_f32(vdata0, vdata1);
+    float32x4_t vmax2 = vmaxq_f32(vdata2, vdata3);
+    float* dout_ptr2 = dout_ptr1 + inner_num;
+    float* dout_ptr3 = dout_ptr2 + inner_num;
+    float32x4_t vmax = vmaxq_f32(vmax1, vmax2);
+
+    // sub, exp and sum
+    float32x4_t vsum0 = exp_ps(vsubq_f32(vdata0, vmax));
+    float32x4_t vsum1 = exp_ps(vsubq_f32(vdata1, vmax));
+    float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax));
+    float32x4_t vsum3 = exp_ps(vsubq_f32(vdata3, vmax));
+
+    float32x4_t vsum_1 = vaddq_f32(vsum0, vsum1);
+    float32x4_t vsum_2 = vaddq_f32(vsum2, vsum3);
+
+    float32x4_t vsum = vaddq_f32(vsum_1, vsum_2);
+
+    float32x4_t vinf = div_ps(vone, vsum);
+
+    vsum0 = vmulq_f32(vsum0, vinf);
+    vsum1 = vmulq_f32(vsum1, vinf);
+    vsum2 = vmulq_f32(vsum2, vinf);
+    vsum3 = vmulq_f32(vsum3, vinf);
+
+    vst1q_f32(dout_ptr0, vsum0);
+    vst1q_f32(dout_ptr1, vsum1);
+    vst1q_f32(dout_ptr2, vsum2);
+    vst1q_f32(dout_ptr3, vsum3);
+  }
+
+  // note: the tail must start at cmp_cnt * 4 here (the vector loop consumes
+  // four positions per iteration), otherwise the remainder is skipped
+  int i = cmp_cnt * 4;
+  for (; i < compute_size; i++) {
+    int idx_inner = i % inner_num;
+    int idx_outer = (i / inner_num) * axis_size;
+    int real_index = idx_outer * inner_num + idx_inner;
+
+    float max_data = din[real_index];
+    // get max
+    for (int j = 1; j < axis_size; ++j) {
+      real_index += inner_num;
+      max_data = din[real_index] > max_data ? din[real_index] : max_data;
+    }
+
+    real_index = idx_outer * inner_num + idx_inner;
+    // sub, exp and sum
+    dout[real_index] = expf(din[real_index] - max_data);
+    float sum_data = dout[real_index];
+    for (int j = 1; j < axis_size; ++j) {
+      real_index += inner_num;
+      dout[real_index] = expf(din[real_index] - max_data);
+      sum_data += dout[real_index];
+    }
+
+    float sum_inv = 1.f / sum_data;
+    real_index = idx_outer * inner_num + idx_inner;
+    // get softmax result
+    for (int j = 0; j < axis_size; ++j) {
+      dout[real_index] *= sum_inv;
+      real_index += inner_num;
+    }
+  }
+}
+
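+// General axis_size, inner_num a multiple of 8: vectorize across eight inner
+// positions and loop over the axis with stride inner_num.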
+template <>
+void softmax_inner8<float>(const float* din, float* dout, const int axis_size,
+                           const int inner_num, const int outer_num) {
+  int compute_size = inner_num * outer_num;
+  int cmp_cnt = compute_size >> 3;
+#pragma omp parallel for
+  for (int c = 0; c < cmp_cnt; ++c) {
+    int i = c * 8;
+    int idx_inner = i % inner_num;
+    int idx_outer = (i / inner_num) * axis_size;
+    int real_index = idx_outer * inner_num + idx_inner;
+
+    const float* din_ptr = din + real_index;
+    float32x4_t vmax = vld1q_f32(din_ptr);
+    float32x4_t vmax2 = vld1q_f32(din_ptr + 4);
+    // get max
+    for (int j = 1; j < axis_size; ++j) {
+      din_ptr += inner_num;
+      float32x4_t vdata = vld1q_f32(din_ptr);
+      float32x4_t vdata2 = vld1q_f32(din_ptr + 4);
+      vmax = vmaxq_f32(vmax, vdata);
+      vmax2 = vmaxq_f32(vmax2, vdata2);
+    }
+
+    // sub, exp and sum
+    din_ptr = din + real_index;
+    float* dout_ptr = dout + real_index;
+    float32x4_t vdata = vld1q_f32(din_ptr);
+    float32x4_t vdata2 = vld1q_f32(din_ptr + 4);
+    float32x4_t vsum = exp_ps(vsubq_f32(vdata, vmax));
+    float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax2));
+    din_ptr += inner_num;
+    vst1q_f32(dout_ptr, vsum);
+    vst1q_f32(dout_ptr + 4, vsum2);
+    dout_ptr += inner_num;
+    for (int j = 1; j < axis_size; ++j) {
+      float32x4_t vdata0 = vld1q_f32(din_ptr);
+      float32x4_t vdata1 = vld1q_f32(din_ptr + 4);
+      vdata0 = exp_ps(vsubq_f32(vdata0, vmax));
+      vdata1 = exp_ps(vsubq_f32(vdata1, vmax2));
+      din_ptr += inner_num;
+      vsum = vaddq_f32(vsum, vdata0);
+      vsum2 = vaddq_f32(vsum2, vdata1);
+      vst1q_f32(dout_ptr, vdata0);
+      vst1q_f32(dout_ptr + 4, vdata1);
+      dout_ptr += inner_num;
+    }
+
+    float32x4_t vone = vdupq_n_f32(1.0f);
+    float32x4_t vinf = div_ps(vone, vsum);
+    float32x4_t vinf2 = div_ps(vone, vsum2);
+    dout_ptr = dout + real_index;
+    // get softmax result
+    for (int j = 0; j < axis_size; ++j) {
+      float32x4_t vdata0 = vld1q_f32(dout_ptr);
+      float32x4_t vdata1 = vld1q_f32(dout_ptr + 4);
+      vdata0 = vmulq_f32(vdata0, vinf);
+      vdata1 = vmulq_f32(vdata1, vinf2);
+      vst1q_f32(dout_ptr, vdata0);
+      vst1q_f32(dout_ptr + 4, vdata1);
+      dout_ptr += inner_num;
+    }
+  }
+
+  for (int i = cmp_cnt * 8; i < compute_size; i++) {
+    int idx_inner = i % inner_num;
+    int idx_outer = (i / inner_num) * axis_size;
+    int real_index = idx_outer * inner_num + idx_inner;
+
+    float max_data = din[real_index];
+    // get max
+    for (int j = 1; j < axis_size; ++j) {
+      real_index += inner_num;
+      max_data = din[real_index] > max_data ? din[real_index] : max_data;
+    }
+
+    real_index = idx_outer * inner_num + idx_inner;
+    // sub, exp and sum
+    dout[real_index] = expf(din[real_index] - max_data);
+    float sum_data = dout[real_index];
+    for (int j = 1; j < axis_size; ++j) {
+      real_index += inner_num;
+      dout[real_index] = expf(din[real_index] - max_data);
+      sum_data += dout[real_index];
+    }
+
+    float sum_inv = 1.f / sum_data;
+    real_index = idx_outer * inner_num + idx_inner;
+    // get softmax result
+    for (int j = 0; j < axis_size; ++j) {
+      dout[real_index] *= sum_inv;
+      real_index += inner_num;
+    }
+  }
+}
+
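+// As softmax_inner8, but four inner positions (one quad register) at a time.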
+template <>
+void softmax_inner4<float>(const float* din, float* dout, const int axis_size,
+                           const int inner_num, const int outer_num) {
+  int compute_size = inner_num * outer_num;
+  int cmp_cnt = compute_size >> 2;
+#pragma omp parallel for
+  for (int c = 0; c < cmp_cnt; ++c) {
+    int i = c * 4;
+    int idx_inner = i % inner_num;
+    int idx_outer = (i / inner_num) * axis_size;
+    int real_index = idx_outer * inner_num + idx_inner;
+
+    // float max_data = din[real_index];
+    const float* din_ptr = din + real_index;
+    float32x4_t vmax = vld1q_f32(din_ptr);
+    // get max
+    for (int j = 1; j < axis_size; ++j) {
+      din_ptr += inner_num;
+      float32x4_t vdata = vld1q_f32(din_ptr);
+      vmax = vmaxq_f32(vmax, vdata);
+    }
+    // sub, exp and sum
+    din_ptr = din + real_index;
+    float* dout_ptr = dout + real_index;
+    float32x4_t vdata = vld1q_f32(din_ptr);
+    float32x4_t vsum = exp_ps(vsubq_f32(vdata, vmax));
+    din_ptr += inner_num;
+    vst1q_f32(dout_ptr, vsum);
+    dout_ptr += inner_num;
+    for (int j = 1; j < axis_size; ++j) {
+      // real_index += inner_num;
+      float32x4_t vdata0 = vld1q_f32(din_ptr);
+      vdata0 = exp_ps(vsubq_f32(vdata0, vmax));
+      din_ptr += inner_num;
+      vsum = vaddq_f32(vsum, vdata0);
+      vst1q_f32(dout_ptr, vdata0);
+      dout_ptr += inner_num;
+    }
+
+    float32x4_t vone = vdupq_n_f32(1.0f);
+    float32x4_t vinf = div_ps(vone, vsum);
+    dout_ptr = dout + real_index;
+    // get softmax result
+    for (int j = 0; j < axis_size; ++j) {
+      float32x4_t vdata0 = vld1q_f32(dout_ptr);
+      vdata0 = vmulq_f32(vdata0, vinf);
+      vst1q_f32(dout_ptr, vdata0);
+      dout_ptr += inner_num;
+    }
+  }
+
+  for (int i = cmp_cnt * 4; i < compute_size; i++) {
+    int idx_inner = i % inner_num;
+    int idx_outer = (i / inner_num) * axis_size;
+    int real_index = idx_outer * inner_num + idx_inner;
+
+    float max_data = din[real_index];
+    // get max
+    for (int j = 1; j < axis_size; ++j) {
+      real_index += inner_num;
+      max_data = din[real_index] > max_data ? din[real_index] : max_data;
+    }
+
+    real_index = idx_outer * inner_num + idx_inner;
+    // sub, exp and sum
+    dout[real_index] = expf(din[real_index] - max_data);
+    float sum_data = dout[real_index];
+    for (int j = 1; j < axis_size; ++j) {
+      real_index += inner_num;
+      dout[real_index] = expf(din[real_index] - max_data);
+      sum_data += dout[real_index];
+    }
+
+    float sum_inv = 1.f / sum_data;
+    real_index = idx_outer * inner_num + idx_inner;
+    // get softmax result
+    for (int j = 0; j < axis_size; ++j) {
+      dout[real_index] *= sum_inv;
+      real_index += inner_num;
+    }
+  }
+}
+
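+// inner_num == 1: the softmax axis is contiguous in memory, so vectorize
+// along the axis itself and reduce max/sum horizontally across lanes.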
+template <>
+void softmax_inner1_large_axis<float>(const float* din, float* dout,
+                                      const int outer_size,
+                                      const int axis_size) {
+#pragma omp parallel for
+  for (int i = 0; i < outer_size; ++i) {
+    const float* din_ptr = din + i * axis_size;
+    float* dout_ptr = dout + i * axis_size;
+
+    const float* din_max_ptr = din_ptr;
+    int nn = axis_size >> 2;
+
+    // get max
+    float32x4_t vmax = vld1q_f32(din_max_ptr);
+    din_max_ptr += 4;
+    int j = 1;
+    for (; j < nn; ++j) {
+      vmax = vmaxq_f32(vmax, vld1q_f32(din_max_ptr));
+      din_max_ptr += 4;
+    }
+    float32x2_t vhmax = vmax_f32(vget_high_f32(vmax), vget_low_f32(vmax));
+    float max_data = std::max(vget_lane_f32(vhmax, 0), vget_lane_f32(vhmax, 1));
+    for (j = 4 * j; j < axis_size; ++j) {
+      max_data = std::max(max_data, din_max_ptr[0]);
+      din_max_ptr++;
+    }
+
+    // sub, exp and sum
+    const float* din_sum_ptr = din_ptr;
+    float* dout_sum_ptr = dout_ptr;
+    vmax = vdupq_n_f32(max_data);
+    float32x4_t vsub_exp = exp_ps(vsubq_f32(vld1q_f32(din_sum_ptr), vmax));
+    float32x4_t vsum = vsub_exp;
+    vst1q_f32(dout_sum_ptr, vsub_exp);
+    din_sum_ptr += 4;
+    dout_sum_ptr += 4;
+
+    j = 1;
+    for (; j < nn; ++j) {
+      vsub_exp = exp_ps(vsubq_f32(vld1q_f32(din_sum_ptr), vmax));
+      vst1q_f32(dout_sum_ptr, vsub_exp);
+      vsum = vaddq_f32(vsum, vsub_exp);
+      din_sum_ptr += 4;
+      dout_sum_ptr += 4;
+    }
+    float32x2_t vhsum = vadd_f32(vget_high_f32(vsum), vget_low_f32(vsum));
+    float sum_data = vget_lane_f32(vhsum, 0) + vget_lane_f32(vhsum, 1);
+
+    for (j = 4 * j; j < axis_size; ++j) {
+      dout_sum_ptr[0] = expf(din_sum_ptr[0] - max_data);
+      sum_data += dout_sum_ptr[0];
+      din_sum_ptr++;
+      dout_sum_ptr++;
+    }
+
+    float sum_inv = 1.f / sum_data;
+    float* dout_res_ptr = dout_ptr;
+    float32x4_t vinv = vdupq_n_f32(sum_inv);
+    // get softmax result
+    j = 0;
+    for (; j < nn; ++j) {
+      float32x4_t vout = vld1q_f32(dout_res_ptr);
+      float32x4_t vres = vmulq_f32(vout, vinv);
+      vst1q_f32(dout_res_ptr, vres);
+      dout_res_ptr += 4;
+    }
+    for (j = nn * 4; j < axis_size; ++j) {
+      dout_ptr[j] *= sum_inv;
+    }
+  }
+}
+
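+// inner_num == 1 with a short axis (< 4): plain scalar loops, since each row
+// has too little work to amortize the vector setup.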
+template <>
+void softmax_inner1_small_axis<float>(const float* din, float* dout,
+                                      const int outer_size,
+                                      const int axis_size) {
+#pragma omp parallel for
+  for (int i = 0; i < outer_size; ++i) {
+    const float* din_ptr = din + i * axis_size;
+    float* dout_ptr = dout + i * axis_size;
+    // get max
+    float max_data = din_ptr[0];
+    for (int j = 1; j < axis_size; ++j) {
+      max_data = std::max(max_data, din_ptr[j]);
+    }
+
+    // sub, exp and sum
+    float sum_data = 0.f;
+    for (int j = 0; j < axis_size; ++j) {
+      dout_ptr[j] = expf(din_ptr[j] - max_data);
+      sum_data += dout_ptr[j];
+    }
+
+    float sum_inv = 1.f / sum_data;
+    for (int j = 0; j < axis_size; ++j) {
+      dout_ptr[j] *= sum_inv;
+    }
+  }
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/arm/math/softmax.h b/paddle/fluid/lite/arm/math/softmax.h
new file mode 100644
index 0000000000000000000000000000000000000000..c0109ffd12f60addd820740a5662a80e4af85317
--- /dev/null
+++ b/paddle/fluid/lite/arm/math/softmax.h
@@ -0,0 +1,52 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+template <typename T>
+void softmax_basic(const T* din, T* dout, const int axis_size,
+                   const int inner_num, const int outer_num);
+
+template <typename T>
+void softmax_inner8_axis4(const T* din, T* dout, const int axis_size,
+                          const int inner_num, const int outer_num);
+
+template <typename T>
+void softmax_inner4_axis4(const T* din, T* dout, const int axis_size,
+                          const int inner_num, const int outer_num);
+
+template <typename T>
+void softmax_inner8(const T* din, T* dout, const int axis_size,
+                    const int inner_num, const int outer_num);
+
+template <typename T>
+void softmax_inner4(const T* din, T* dout, const int axis_size,
+                    const int inner_num, const int outer_num);
+
+template <typename T>
+void softmax_inner1_large_axis(const T* din, T* dout, const int outer_size,
+                               const int axis_size);
+
+template <typename T>
+void softmax_inner1_small_axis(const T* din, T* dout, const int outer_size,
+                               const int axis_size);
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/kernels/arm/CMakeLists.txt b/paddle/fluid/lite/kernels/arm/CMakeLists.txt
index ebdd42443e0813ace83e7888ede5a45194270adc..82b1a07810ffc265510b2e6d7dac1d2cbadcbbb5 100644
--- a/paddle/fluid/lite/kernels/arm/CMakeLists.txt
+++ b/paddle/fluid/lite/kernels/arm/CMakeLists.txt
@@ -8,13 +8,16 @@ cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm)
 cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps})
 cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3)
 cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} eigen3)
+cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
 
 lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm)
+lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm)
 
 set(arm_kernels
     fc_compute_arm
     relu_compute_arm
     mul_compute_arm
-    scale_compute_arm)
+    scale_compute_arm
+    softmax_compute_arm)
 
 set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels")
diff --git a/paddle/fluid/lite/kernels/arm/softmax_compute.cc b/paddle/fluid/lite/kernels/arm/softmax_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ceb061c901ff6c13c3e2160cd6a648ff36f4429e
--- /dev/null
+++ b/paddle/fluid/lite/kernels/arm/softmax_compute.cc
@@ -0,0 +1,80 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/kernels/arm/softmax_compute.h"
+#include "paddle/fluid/lite/arm/math/funcs.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+void SoftmaxCompute::Run() {
+  auto& param = Param<operators::SoftmaxParam>();
+  const float* din = param.x->data<float>();
+  float* dout = param.output->mutable_data<float>();
+  auto dim_x = param.x->dims();
+  auto rank_x = dim_x.size();
+  int axis = param.axis;
+  if (axis < 0) {
+    axis += rank_x;
+  }
+  int outer_num = dim_x.Slice(0, axis).production();
+  int inner_num = dim_x.Slice(axis + 1, rank_x).production();
+  int axis_size = dim_x[axis];
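+  // Dispatch on layout: a contiguous softmax axis (inner_num == 1) uses the
+  // axis-vectorized kernels; otherwise pick the widest kernel the inner
+  // stride allows and fall back to the scalar implementation.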
+  if (inner_num == 1) {
+    if (axis_size >= 4) {
+      lite::arm::math::softmax_inner1_large_axis(din, dout, outer_num,
+                                                 axis_size);
+    } else {
+      lite::arm::math::softmax_inner1_small_axis(din, dout, outer_num,
+                                                 axis_size);
+    }
+  } else {
+    int compute_size = outer_num * inner_num;
+    if (axis_size == 4 && inner_num % 8 == 0) {
+      lite::arm::math::softmax_inner8_axis4(din, dout, axis_size, inner_num,
+                                            outer_num);
+    } else if (axis_size == 4 && inner_num % 4 == 0) {
+      lite::arm::math::softmax_inner4_axis4(din, dout, axis_size, inner_num,
+                                            outer_num);
+    } else {
+      if (inner_num % 8 == 0) {
+        lite::arm::math::softmax_inner8(din, dout, axis_size, inner_num,
+                                        outer_num);
+      } else if (inner_num % 4 == 0) {
+        lite::arm::math::softmax_inner4(din, dout, axis_size, inner_num,
+                                        outer_num);
+      } else {
+        lite::arm::math::softmax_basic(din, dout, axis_size, inner_num,
+                                       outer_num);
+      }
+    }
+  }
+}
+
+TargetType SoftmaxCompute::target() const { return TARGET(kARM); }
+
+PrecisionType SoftmaxCompute::precision() const { return PRECISION(kFloat); }
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(softmax, kARM, kFloat, kNCHW,
+                     paddle::lite::kernels::arm::SoftmaxCompute, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
diff --git a/paddle/fluid/lite/kernels/arm/softmax_compute.h b/paddle/fluid/lite/kernels/arm/softmax_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..2daec0f9ee4167772fa7eeb5c0059a810f5db9ca
--- /dev/null
+++ b/paddle/fluid/lite/kernels/arm/softmax_compute.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <algorithm>
+#include "paddle/fluid/lite/core/kernel.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+class SoftmaxCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+ public:
+  void Run() override;
+
+  TargetType target() const override;
+  PrecisionType precision() const override;
+
+  virtual ~SoftmaxCompute() = default;
+};
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/kernels/arm/softmax_compute_test.cc b/paddle/fluid/lite/kernels/arm/softmax_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d24868f2c5679e4f9b4bf0b5ad1bfbf62f3cbad5
--- /dev/null
+++ b/paddle/fluid/lite/kernels/arm/softmax_compute_test.cc
@@ -0,0 +1,127 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/kernels/arm/softmax_compute.h"
+#include <gtest/gtest.h>
+#include <vector>
+#include "paddle/fluid/lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
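+// Unvectorized reference softmax; the ARM kernel outputs are checked
+// against it element-wise below.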
+template <typename dtype>
+void softmax_compute_ref(const operators::SoftmaxParam& param) {
+  const dtype* x_data = param.x->mutable_data<dtype>();
+  dtype* output_data = param.output->mutable_data<dtype>();
+  DDim dim = param.x->dims();
+  ASSERT_EQ(dim.data(), param.output->dims().data());
+  auto rank = dim.size();
+  int axis = param.axis;
+  if (axis < 0) {
+    axis += rank;
+  }
+  int axis_size = dim[axis];
+  int outer_num = dim.Slice(0, axis).production();
+  int inner_num = dim.Slice(axis + 1, rank).production();
+  int compute_size = outer_num * inner_num;
+  for (int i = 0; i < compute_size; i++) {
+    int idx_inner = i % inner_num;
+    int idx_outer = (i / inner_num) * axis_size;
+    int start = idx_outer * inner_num + idx_inner;
+    int offset;
+
+    offset = start;
+    dtype max_data = std::numeric_limits<dtype>::lowest();
+    for (int j = 0; j < axis_size; j++) {
+      max_data = x_data[offset] > max_data ? x_data[offset] : max_data;
+      offset += inner_num;
+    }
+
+    offset = start;
+    dtype sum_data = (dtype)0;
+    for (int j = 0; j < axis_size; j++) {
+      output_data[offset] = exp(x_data[offset] - max_data);
+      sum_data += output_data[offset];
+      offset += inner_num;
+    }
+
+    offset = start;
+    for (int j = 0; j < axis_size; j++) {
+      output_data[offset] /= sum_data;
+      offset += inner_num;
+    }
+  }
+}
+
+TEST(softmax_arm, init) {
+  SoftmaxCompute softmax;
+  ASSERT_EQ(softmax.precision(), PRECISION(kFloat));
+  ASSERT_EQ(softmax.target(), TARGET(kARM));
+}
+
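+// Sweeps shapes and every softmax axis so each dispatch path in
+// SoftmaxCompute::Run (inner8/inner4/basic and the inner_num == 1 cases,
+// including their scalar tails) gets exercised.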
+TEST(softmax_arm, compute) {
+  SoftmaxCompute softmax;
+  operators::SoftmaxParam param;
+
+  lite::Tensor x;
+  lite::Tensor output;
+  lite::Tensor output_ref;
+
+  for (auto n : {1, 3, 4, 11}) {
+    for (auto c : {1, 3, 11, 4}) {
+      for (auto h : {3, 1, 11, 4}) {
+        for (auto w : {1, 3, 4, 12}) {
+          for (auto axis : {-4, -3, -2, -1, 0, 1, 2, 3}) {
+            x.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
+            output.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
+            output_ref.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
+            auto* x_data = x.mutable_data<float>();
+            auto* output_data = output.mutable_data<float>();
+            auto* output_ref_data = output_ref.mutable_data<float>();
+            for (int i = 0; i < x.dims().production(); i++) {
+              x_data[i] = i;
+            }
+            param.x = &x;
+            param.axis = axis;
+            param.output = &output;
+            softmax.SetParam(param);
+            softmax.Run();
+            param.output = &output_ref;
+            softmax_compute_ref<float>(param);
+            for (int i = 0; i < output.dims().production(); i++) {
+              EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+TEST(softmax, retrieve_op) {
+  auto softmax =
+      KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
+          "softmax");
+  ASSERT_FALSE(softmax.empty());
+  ASSERT_TRUE(softmax.front());
+}
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);
diff --git a/paddle/fluid/lite/kernels/arm/use_kernels.h b/paddle/fluid/lite/kernels/arm/use_kernels.h
index af437bf8e4ad5a56d4f9575d609f4f47e16edb2a..d856950f3a177d08cdc950c259abf3d1a194ee25 100644
--- a/paddle/fluid/lite/kernels/arm/use_kernels.h
+++ b/paddle/fluid/lite/kernels/arm/use_kernels.h
@@ -18,5 +18,6 @@
 USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def);
 USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
 USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def);
+USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);
 USE_LITE_KERNEL(feed, kARM, kAny, kAny, def);
 USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def);
diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt
index 40782010881127ba2cda238a0112dbabb080e189..4190e6037cf8b99a5e175fb3070e36eb89832fed 100644
--- a/paddle/fluid/lite/operators/CMakeLists.txt
+++ b/paddle/fluid/lite/operators/CMakeLists.txt
@@ -4,6 +4,7 @@ cc_library(fc_op_lite SRCS fc_op.cc DEPS ${op_DEPS})
 cc_library(relu_op_lite SRCS relu_op.cc DEPS ${op_DEPS})
 cc_library(mul_op_lite SRCS mul_op.cc DEPS ${op_DEPS})
 cc_library(scale_op_lite SRCS scale_op.cc DEPS ${op_DEPS})
+cc_library(softmax_op_lite SRCS softmax_op.cc DEPS ${op_DEPS})
 cc_library(feed_op_lite SRCS feed_op.cc DEPS ${op_DEPS})
 cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS})
 cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS})
@@ -19,6 +20,7 @@ set(ops_lite
     relu_op_lite
     mul_op_lite
     scale_op_lite
+    softmax_op_lite
     feed_op_lite
     fetch_op_lite
     io_copy_op_lite
@@ -28,3 +30,4 @@ set(ops_lite
     PARENT_SCOPE)
 
 lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite memory_lite)
+lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite)
diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h
index c970ac2d87379bd4fdb9b99ab9b9cd8f13210d3a..77d78b481a734f999bf5c4d8dc91bf7b6840d4f7 100644
--- a/paddle/fluid/lite/operators/op_params.h
+++ b/paddle/fluid/lite/operators/op_params.h
@@ -93,6 +93,14 @@ struct ScaleParam {
   bool bias_after_scale{true};
 };
 
+// For Softmax Op
+struct SoftmaxParam {
+  lite::Tensor* x{};
+  lite::Tensor* output{};
+
+  int axis{-1};
+};
+
 /// ----------------------- element wise operators ----------------------
 struct ElementwiseParam {
   const lite::Tensor* X{};
diff --git a/paddle/fluid/lite/operators/softmax_op.cc b/paddle/fluid/lite/operators/softmax_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3e41a5ffffe9ab1d045bdc5da3a819cb645408d5
--- /dev/null
+++ b/paddle/fluid/lite/operators/softmax_op.cc
@@ -0,0 +1,51 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/operators/softmax_op.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
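+// A negative axis counts from the trailing dimension, so any axis in
+// [-rank, rank) is legal here; the kernel normalizes it before dispatch.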
+bool SoftmaxOp::CheckShape() const {
+  CHECK_OR_FALSE(param_.x);
+  CHECK_OR_FALSE(param_.output);
+  auto dim_x = param_.x->dims();
+  int rank_x = static_cast<int>(dim_x.size());
+  CHECK_OR_FALSE(param_.axis >= -rank_x && param_.axis < rank_x);
+  return true;
+}
+
+bool SoftmaxOp::InferShape() const {
+  param_.output->Resize(param_.x->dims());
+  return true;
+}
+
+bool SoftmaxOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) {
+  param_.x = const_cast<lite::Tensor *>(
+      &scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
+  param_.output =
+      scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
+  param_.axis = GetAttr<int>(opdesc.GetAttr("axis"));
+  CHECK(param_.x);
+  CHECK(param_.output);
+  return true;
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(softmax, paddle::lite::operators::SoftmaxOp);
diff --git a/paddle/fluid/lite/operators/softmax_op.h b/paddle/fluid/lite/operators/softmax_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..062f707c6e0581d9fcc0ec9e083439a0ca9e656d
--- /dev/null
+++ b/paddle/fluid/lite/operators/softmax_op.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+#include "paddle/fluid/lite/core/op_lite.h"
+#include "paddle/fluid/lite/core/scope.h"
+#include "paddle/fluid/lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class SoftmaxOp : public OpLite {
+ public:
+  SoftmaxOp() {}
+  explicit SoftmaxOp(const std::string &op_type) : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShape() const override;
+
+  bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+  std::string DebugString() const override { return "softmax"; }
+
+ private:
+  mutable SoftmaxParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/operators/softmax_op_test.cc b/paddle/fluid/lite/operators/softmax_op_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f999564541a869fa9794a88bab9f299bb4df0f19
--- /dev/null
+++ b/paddle/fluid/lite/operators/softmax_op_test.cc
@@ -0,0 +1,54 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/operators/softmax_op.h"
+#include <gtest/gtest.h>
+#include "paddle/fluid/lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
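+// Smoke test: builds an OpDesc with axis = -1 and checks that SoftmaxOp
+// attaches its parameters from the scope without error.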
+TEST(softmax_op_lite, test) {
+  // prepare variables
+  Scope scope;
+  auto* x = scope.Var("x")->GetMutable<lite::Tensor>();
+  auto* output = scope.Var("output")->GetMutable<lite::Tensor>();
+  x->Resize(DDim(std::vector<int64_t>({10, 20})));
+  output->Resize(DDim(std::vector<int64_t>{10, 20}));
+
+  // set data
+  for (int i = 0; i < 10 * 20; i++) {
+    x->mutable_data<float>()[i] = i;
+  }
+  for (int i = 0; i < 10 * 20; i++) {
+    output->mutable_data<float>()[i] = 0.;
+  }
+
+  // prepare op desc
+  lite::OpDesc desc;
+  desc.SetType("softmax");
+  desc.SetInput("X", {"x"});
+  desc.SetOutput("Out", {"output"});
+  desc.SetAttr("axis", static_cast<int>(-1));
+
+  SoftmaxOp softmax("softmax");
+
+  softmax.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}});
+  softmax.Attach(desc, &scope);
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle