#include "src/fallback/softmax/opr_impl.h"

#include <algorithm>
#include <cmath>
#include <cstring>
#include <limits>

#include "src/fallback/elemwise/gi_impl/gi_mathfun.h"
#include "src/naive/handle.h"

namespace megdnn {
namespace fallback {

//! Seed for running maxima: the most negative finite float.
//! NOTE: std::numeric_limits<float>::min() (used previously) is the smallest
//! *positive* normal float. Seeding with it pins every row maximum to >= ~0,
//! so when all inputs along the softmax axis are strongly negative,
//! exp(x - max) underflows to 0 for every element and the final division
//! produces NaN instead of valid probabilities.
static constexpr float kMaxSeed = std::numeric_limits<float>::lowest();

//! Number of float lanes in one GI SIMD register.
static constexpr size_t kStep = GI_SIMD_LEN_BYTE / sizeof(float);

//! Softmax along axis B of a row-major (A, B, C) float tensor with C != 1.
//! The reduced axis is strided by C, so per-(a, c) maxima and sums are kept in
//! two A*C float buffers carved out of \p workspace, and the tensor is swept
//! three times (max, exp+sum, normalize).
static void softmax_strided(
        const float* sptr, float* dptr, size_t A, size_t B, size_t C,
        _megdnn_workspace workspace) {
    // TODO: When C=2,3,4..., src_ptr span is relatively large, the performance
    // may be poor
    const size_t AC = A * C;
    WorkspaceBundle workspace_bundle{
            workspace.raw_ptr, {AC * sizeof(float), AC * sizeof(float)}};
    float* max_buf = workspace_bundle.get_workspace(0).raw_ptr->as<float>();
    float* sum_buf = workspace_bundle.get_workspace(1).raw_ptr->as<float>();

    std::fill(max_buf, max_buf + AC, kMaxSeed);
    std::fill(sum_buf, sum_buf + AC, 0.f);

    // pass 1: max_buf[a, c] = max over b of src[a, b, c]
    for (size_t a = 0; a < A; ++a) {
        float* const max_row = max_buf + a * C;
        for (size_t b = 0; b < B; ++b) {
            const float* const src_ptr = sptr + (a * B + b) * C;
            size_t c = 0;
            for (; c + kStep <= C; c += kStep) {
                GI_FLOAT32_t v_max = GiLoadFloat32(max_row + c);
                GI_FLOAT32_t v_p = GiLoadFloat32(src_ptr + c);
                GiStoreFloat32(max_row + c, GiMaximumFloat32(v_max, v_p));
            }
            for (; c < C; ++c) {
                max_row[c] = std::max(src_ptr[c], max_row[c]);
            }
        }
    }

    // pass 2: dst = exp(src - max), sum_buf[a, c] += dst[a, b, c]
    // (each src element is read before the same dst element is written, so
    // this is also safe when sptr == dptr)
    for (size_t a = 0; a < A; ++a) {
        const float* const max_row = max_buf + a * C;
        float* const sum_row = sum_buf + a * C;
        for (size_t b = 0; b < B; ++b) {
            const size_t offset = (a * B + b) * C;
            const float* const src_ptr = sptr + offset;
            float* const dst_ptr = dptr + offset;
            size_t c = 0;
            for (; c + kStep <= C; c += kStep) {
                GI_FLOAT32_t v_p = GiLoadFloat32(src_ptr + c);
                GI_FLOAT32_t v_max = GiLoadFloat32(max_row + c);
                GI_FLOAT32_t v_sum = GiLoadFloat32(sum_row + c);
                v_p = GiExpPsFloat32(GiSubtractFloat32(v_p, v_max));
                GiStoreFloat32(dst_ptr + c, v_p);
                GiStoreFloat32(sum_row + c, GiAddFloat32(v_p, v_sum));
            }
            for (; c < C; ++c) {
                // std::exp picks the float overload; unqualified exp may
                // resolve to the double-only C function
                dst_ptr[c] = std::exp(src_ptr[c] - max_row[c]);
                sum_row[c] += dst_ptr[c];
            }
        }
    }

    // pass 3: dst /= sum
    for (size_t a = 0; a < A; ++a) {
        const float* const sum_row = sum_buf + a * C;
        for (size_t b = 0; b < B; ++b) {
            float* const dst_ptr = dptr + (a * B + b) * C;
            size_t c = 0;
            for (; c + kStep <= C; c += kStep) {
                GI_FLOAT32_t v_p = GiLoadFloat32(dst_ptr + c);
                GI_FLOAT32_t v_sum = GiLoadFloat32(sum_row + c);
                GiStoreFloat32(dst_ptr + c, GiDivideFloat32(v_p, v_sum));
            }
            for (; c < C; ++c) {
                dst_ptr[c] = dst_ptr[c] / sum_row[c];
            }
        }
    }
}

//! Softmax over each contiguous length-B row of a row-major (A, B) float
//! tensor (the C == 1 case). Reductions stay in registers; no workspace used.
static void softmax_contiguous(
        const float* sptr, float* dptr, size_t A, size_t B) {
    for (size_t a = 0; a < A; ++a) {
        const float* const src_row = sptr + a * B;
        float* const dst_row = dptr + a * B;

        // row maximum, subtracted below so exp() cannot overflow
        float row_max = kMaxSeed;
        {
            GI_FLOAT32_t v_max = GiBroadcastFloat32(row_max);
            size_t b = 0;
            for (; b + kStep <= B; b += kStep) {
                v_max = GiMaximumFloat32(v_max, GiLoadFloat32(src_row + b));
            }
            row_max = std::max(row_max, GiReduceMaxNanFloat32(v_max));
            for (; b < B; ++b) {
                row_max = std::max(src_row[b], row_max);
            }
        }

        // dst = exp(src - max), accumulating the row sum
        float row_sum = 0.f;
        {
            GI_FLOAT32_t v_max = GiBroadcastFloat32(row_max);
            GI_FLOAT32_t v_sum = GiZeroFloat32();
            size_t b = 0;
            for (; b + kStep <= B; b += kStep) {
                GI_FLOAT32_t v_p = GiExpPsFloat32(
                        GiSubtractFloat32(GiLoadFloat32(src_row + b), v_max));
                GiStoreFloat32(dst_row + b, v_p);
                v_sum = GiAddFloat32(v_sum, v_p);
            }
            row_sum += GiReduceAddFloat32(v_sum);
            for (; b < B; ++b) {
                dst_row[b] = std::exp(src_row[b] - row_max);
                row_sum += dst_row[b];
            }
        }

        // normalize: multiply by the reciprocal instead of dividing per element
        const float inv_sum = 1.f / row_sum;
        GI_FLOAT32_t v_inv_sum = GiBroadcastFloat32(inv_sum);
        size_t b = 0;
        for (; b + kStep <= B; b += kStep) {
            GI_FLOAT32_t v_p = GiLoadFloat32(dst_row + b);
            GiStoreFloat32(dst_row + b, GiMultiplyFloat32(v_p, v_inv_sum));
        }
        for (; b < B; ++b) {
            dst_row[b] *= inv_sum;
        }
    }
}

//! Softmax along axis B of a row-major (A, B, C) float tensor.
//! \param sptr      input, A*B*C floats
//! \param dptr      output, A*B*C floats
//! \param workspace scratch space; only used when C != 1 (two A*C float buffers)
static void do_softmax(
        const float* sptr, float* dptr, size_t A, size_t B, size_t C,
        _megdnn_workspace workspace) {
    if (C != 1) {
        softmax_strided(sptr, dptr, A, B, C, workspace);
    } else {
        softmax_contiguous(sptr, dptr, A, B);
    }
}

//! Runs softmax along param().axis (negative axes count from the end).
//! Layouts rejected by usable() are delegated to the naive implementation;
//! otherwise the tensor is viewed as (A, B, C) around the axis and the float32
//! kernel is dispatched to the CPU.
void SoftmaxForwardImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    auto axis = param().axis;
    if (axis < 0)
        axis += src.layout.ndim;
    megdnn_assert(axis >= 0);
    check_exec(src.layout, dst.layout, workspace.size);

    if (!usable(src.layout)) {
        naive::SoftmaxForwardImpl::exec(src, dst, workspace);
        return;
    }

    typedef DTypeTrait<dtype::Float32>::ctype Float32;
    auto sptr = src.ptr<Float32>();
    auto dptr = dst.ptr<Float32>();

    size_t A, B, C;
    reduce::get_ABC(src.layout, A, B, C, axis);
    MEGDNN_DISPATCH_CPU_KERN_OPR(do_softmax(sptr, dptr, A, B, C, workspace));
}

}  // namespace fallback
}  // namespace megdnn

// vim: syntax=cpp.doxygen