diff --git a/mobile/src/operators/kernel/central-arm-func/softmax_arm_func.h b/mobile/src/operators/kernel/central-arm-func/softmax_arm_func.h
index a94c8299c514bc9e2937daf57b1a845d7be56b16..29d63937ba59debf75da6ac5c5d31d50ab6abfa7 100644
--- a/mobile/src/operators/kernel/central-arm-func/softmax_arm_func.h
+++ b/mobile/src/operators/kernel/central-arm-func/softmax_arm_func.h
@@ -18,6 +18,44 @@ limitations under the License. */
 #include "operators/op_param.h"
 namespace paddle_mobile {
 namespace operators {
+
+void softmax_basic_axis_float(const float *din, float *dout,
+                              const int axis_size, const int inner_num,
+                              const int outer_num) {
+  int compute_size = inner_num * outer_num;
+#pragma omp parallel for
+  for (int i = 0; i < compute_size; ++i) {
+    int idx_inner = i % inner_num;
+    int idx_outer = (i / inner_num) * axis_size;
+    int real_index = idx_outer * inner_num + idx_inner;
+
+    float max_data = din[real_index];
+    // get max
+    for (int j = 1; j < axis_size; ++j) {
+      real_index += inner_num;
+      max_data = din[real_index] > max_data ? din[real_index] : max_data;
+    }
+
+    real_index = idx_outer * inner_num + idx_inner;
+    // sub, exp and sum
+    dout[real_index] = expf(din[real_index] - max_data);
+    float sum_data = dout[real_index];
+    for (int j = 1; j < axis_size; ++j) {
+      real_index += inner_num;
+      dout[real_index] = expf(din[real_index] - max_data);
+      sum_data += dout[real_index];
+    }
+
+    float sum_inv = 1.f / sum_data;
+    real_index = idx_outer * inner_num + idx_inner;
+    // get softmax result
+    for (int j = 0; j < axis_size; ++j) {
+      dout[real_index] *= sum_inv;
+      real_index += inner_num;
+    }
+  }
+}
+
 template <typename P>
 void SoftmaxCompute(const SoftmaxParam<CPU> &param) {
   const Tensor *in_x = param.InputX();
@@ -25,7 +63,29 @@ void SoftmaxCompute(const SoftmaxParam<CPU> &param) {
   auto x_dims = in_x->dims();
   out->Resize(x_dims);
   out->mutable_data<float>();
-  math::SoftmaxFuntor<CPU, float>()(in_x, out);
+  if (param.has_axis_) {
+    int axis = param.axis_;
+    int axis_size = x_dims[axis];
+    auto x_rank = x_dims.size();
+    DLOG << "x_rank :" << x_rank;
+
+    if (axis < 0) {
+      axis += x_rank;
+    }
+
+    DLOG << "axis :" << axis;
+
+    int outer_num = framework::product(framework::slice_ddim(x_dims, 0, axis));
+    DLOG << "outer_num :" << outer_num;
+    int inner_num =
+        framework::product(framework::slice_ddim(x_dims, axis + 1, x_rank));
+    DLOG << "inner_num :" << inner_num;
+
+    softmax_basic_axis_float(in_x->data<float>(), out->data<float>(), axis_size,
+                             inner_num, outer_num);
+  } else {
+    math::SoftmaxFuntor<CPU, float>()(in_x, out);
+  }
 }
 }  // namespace operators
 }  // namespace paddle_mobile
diff --git a/mobile/src/operators/op_param.h b/mobile/src/operators/op_param.h
index b49f3fce320ce6fe0a350bed2f223c2a74e56047..4cce4f4914a95661485ef320ae013b13548f30e2 100644
--- a/mobile/src/operators/op_param.h
+++ b/mobile/src/operators/op_param.h
@@ -1180,10 +1180,17 @@ class SoftmaxParam : public OpParam {
       : OpParam(inputs, outputs, attrs, scope) {
     input_x_ = InputXFrom<GType>(inputs, *scope);
     out_ = OutFrom<GType>(outputs, *scope);
+    if (HasAttr("axis", attrs)) {
+      axis_ = GetAttr<int>("axis", attrs);
+      has_axis_ = true;
+    }
   }
   const GType *InputX() const { return input_x_; }
   GType *Out() const { return out_; }
+  int axis_ = -1;
+  bool has_axis_ = false;
+
  private:
   GType *input_x_;
   GType *out_;
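
Reviewer note, not part of the patch: the heart of the change is the outer/axis/inner decomposition. For a row-major tensor, element (outer o, axis j, inner i) sits at flat index (o * axis_size + j) * inner_num + i, so walking along the softmax axis means stepping by inner_num. The standalone sketch below re-derives that indexing and the same three-pass scheme (max, exp-and-sum, normalize) on a made-up shape {2, 3, 4} with axis = 1; it is single-threaded, uses std::exp instead of expf, and all names in it are illustrative only.

// sketch.cpp -- verify the stride-based softmax indexing from the patch
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Shape {2, 3, 4}, softmax over axis 1:
  // outer_num = 2 (dims before the axis), axis_size = 3, inner_num = 4.
  const int outer_num = 2, axis_size = 3, inner_num = 4;
  std::vector<float> din(outer_num * axis_size * inner_num);
  for (std::size_t k = 0; k < din.size(); ++k) din[k] = 0.1f * k;
  std::vector<float> dout(din.size());

  for (int o = 0; o < outer_num; ++o) {
    for (int i = 0; i < inner_num; ++i) {
      // base is the patch's real_index at j = 0 for this (o, i) slice.
      const int base = o * axis_size * inner_num + i;
      // Pass 1: max along the axis (numerical stability).
      float max_data = din[base];
      for (int j = 1; j < axis_size; ++j)
        max_data = std::max(max_data, din[base + j * inner_num]);
      // Pass 2: subtract max, exponentiate, accumulate the sum.
      float sum_data = 0.f;
      for (int j = 0; j < axis_size; ++j) {
        dout[base + j * inner_num] = std::exp(din[base + j * inner_num] - max_data);
        sum_data += dout[base + j * inner_num];
      }
      // Pass 3: normalize so each slice along the axis sums to 1.
      for (int j = 0; j < axis_size; ++j) dout[base + j * inner_num] /= sum_data;
    }
  }

  // Sanity check: the slice at (o = 0, i = 0) should sum to 1.
  float s = 0.f;
  for (int j = 0; j < axis_size; ++j) s += dout[j * inner_num];
  std::printf("slice sum = %f (expect 1.0)\n", s);
  return 0;
}

Two details in the patch itself follow the same reasoning: subtracting the per-slice maximum before expf bounds every exponent at zero, preventing overflow for large activations, and because each of the inner_num * outer_num slices is independent, the #pragma omp parallel for needs no synchronization between iterations.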