Unverified commit 807454fd, authored by xiebaiyuan, committed by GitHub

[MOBILE] Fix mobile softmax axis (#3456)

* [mobile] wrap_shape in mobile loader.

* [mobile] fix mobile bilinear compatibility, test=mobile

* [mobile] fix softmax axis, test=develop
Parent commit: b88e45a1
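The kernel added in this commit computes softmax along an arbitrary axis by splitting the tensor into `outer_num * axis_size * inner_num` elements and walking the axis with a stride of `inner_num`. The standalone sketch below is not part of the diff; the shape `{2, 3, 4}` and axis value 1 are made-up values chosen only to illustrate how those sizes and the flat offsets visited by the kernel are derived:

```cpp
// Illustration only; the shape {2, 3, 4} and the axis value are hypothetical.
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> dims = {2, 3, 4};
  int axis = 1;
  int axis_size = dims[axis];
  int outer_num = 1, inner_num = 1;
  for (int d = 0; d < axis; ++d) outer_num *= dims[d];
  for (int d = axis + 1; d < static_cast<int>(dims.size()); ++d)
    inner_num *= dims[d];
  std::printf("outer_num=%d axis_size=%d inner_num=%d\n", outer_num, axis_size,
              inner_num);  // 2, 3, 4

  // Flat offsets the kernel visits for one (outer, inner) pair, stepping by
  // inner_num between consecutive elements along the softmax axis.
  int idx_inner = 2;              // pick inner index 2
  int idx_outer = 1 * axis_size;  // pick outer index 1
  int real_index = idx_outer * inner_num + idx_inner;
  for (int j = 0; j < axis_size; ++j) {
    std::printf("j=%d -> flat offset %d\n", j, real_index);  // 14, 18, 22
    real_index += inner_num;
  }
  return 0;
}
```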
@@ -18,6 +18,44 @@ limitations under the License. */
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
void softmax_basic_axis_float(const float *din, float *dout,
                              const int axis_size, const int inner_num,
                              const int outer_num) {
  int compute_size = inner_num * outer_num;
#pragma omp parallel for
  for (int i = 0; i < compute_size; ++i) {
    int idx_inner = i % inner_num;
    int idx_outer = (i / inner_num) * axis_size;
    int real_index = idx_outer * inner_num + idx_inner;
    float max_data = din[real_index];
    // get max
    for (int j = 1; j < axis_size; ++j) {
      real_index += inner_num;
      max_data = din[real_index] > max_data ? din[real_index] : max_data;
    }
    real_index = idx_outer * inner_num + idx_inner;
    // sub, exp and sum
    dout[real_index] = expf(din[real_index] - max_data);
    float sum_data = dout[real_index];
    for (int j = 1; j < axis_size; ++j) {
      real_index += inner_num;
      dout[real_index] = expf(din[real_index] - max_data);
      sum_data += dout[real_index];
    }
    float sum_inv = 1.f / sum_data;
    real_index = idx_outer * inner_num + idx_inner;
    // get softmax result
    for (int j = 0; j < axis_size; ++j) {
      dout[real_index] *= sum_inv;
      real_index += inner_num;
    }
  }
}

template <typename P>
void SoftmaxCompute(const SoftmaxParam<CPU> &param) {
  const Tensor *in_x = param.InputX();
@@ -25,7 +63,29 @@ void SoftmaxCompute(const SoftmaxParam<CPU> &param) {
  auto x_dims = in_x->dims();
  out->Resize(x_dims);
  out->mutable_data<float>();
  math::SoftmaxFuntor<CPU, float>()(in_x, out);
  if (param.has_axis_) {
    int axis = param.axis_;
    auto x_rank = x_dims.size();
    DLOG << "x_rank :" << x_rank;
    // normalize a negative axis before indexing into x_dims
    if (axis < 0) {
      axis += x_rank;
    }
    DLOG << "axis :" << axis;
    int axis_size = x_dims[axis];
    int outer_num = framework::product(framework::slice_ddim(x_dims, 0, axis));
    DLOG << "outer_num :" << outer_num;
    int inner_num =
        framework::product(framework::slice_ddim(x_dims, axis + 1, x_rank));
    DLOG << "inner_num :" << inner_num;
    softmax_basic_axis_float(in_x->data<float>(), out->data<float>(),
                             axis_size, inner_num, outer_num);
  } else {
    math::SoftmaxFuntor<CPU, float>()(in_x, out);
  }
}
}  // namespace operators
}  // namespace paddle_mobile
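For reference, a minimal usage sketch of the kernel added above, assuming it is compiled and linked together with the softmax.cpp change in this diff; the shape and input values are invented for illustration. Every (outer, inner) slice of the output should sum to 1:

```cpp
// Minimal sketch; assumes linking against the kernel added in softmax.cpp.
#include <cstdio>
#include <vector>

// Matches the signature introduced in this commit.
void softmax_basic_axis_float(const float *din, float *dout,
                              const int axis_size, const int inner_num,
                              const int outer_num);

int main() {
  // Shape {2, 3, 4}, softmax over axis 1:
  // outer_num = 2, axis_size = 3, inner_num = 4.
  const int outer_num = 2, axis_size = 3, inner_num = 4;
  std::vector<float> in(outer_num * axis_size * inner_num);
  for (size_t k = 0; k < in.size(); ++k) in[k] = 0.1f * static_cast<float>(k);
  std::vector<float> out(in.size());

  softmax_basic_axis_float(in.data(), out.data(), axis_size, inner_num,
                           outer_num);

  // Each (outer, inner) slice along the axis should now sum to 1.
  for (int o = 0; o < outer_num; ++o) {
    for (int i = 0; i < inner_num; ++i) {
      float sum = 0.f;
      for (int j = 0; j < axis_size; ++j) {
        sum += out[(o * axis_size + j) * inner_num + i];
      }
      std::printf("slice (%d, %d) sums to %f\n", o, i, sum);
    }
  }
  return 0;
}
```

Because the kernel parallelizes over `compute_size = inner_num * outer_num` with OpenMP, each (outer, inner) slice is processed independently, so the normalization check above holds regardless of thread count.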
@@ -1180,10 +1180,17 @@ class SoftmaxParam : public OpParam {
      : OpParam(inputs, outputs, attrs, scope) {
    input_x_ = InputXFrom<GType>(inputs, *scope);
    out_ = OutFrom<GType>(outputs, *scope);
    if (HasAttr("axis", attrs)) {
      axis_ = GetAttr<int>("axis", attrs);
      has_axis_ = true;
    }
  }
  const GType *InputX() const { return input_x_; }
  GType *Out() const { return out_; }
  int axis_ = -1;
  bool has_axis_ = false;

 private:
  GType *input_x_;
  GType *out_;
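The new axis_ / has_axis_ fields are what SoftmaxCompute checks before choosing between the axis kernel and the original SoftmaxFuntor path. The simplified stand-in below is illustrative only: FakeSoftmaxParam and the plain loops that replace framework::product / framework::slice_ddim are not the real types. It shows how a negative axis such as -1 normalizes to the last dimension, where inner_num becomes 1 and the axis kernel coincides with a last-axis softmax:

```cpp
// Simplified stand-in for the dispatch in SoftmaxCompute; FakeSoftmaxParam and
// the plain loops are hypothetical replacements for the framework helpers.
#include <cstdio>
#include <vector>

struct FakeSoftmaxParam {  // mirrors only the fields added by this commit
  int axis_ = -1;
  bool has_axis_ = false;
};

int main() {
  std::vector<int> x_dims = {8, 10, 6, 6};  // example rank-4 shape
  FakeSoftmaxParam param;
  param.axis_ = -1;  // as read from the "axis" attribute
  param.has_axis_ = true;

  if (param.has_axis_) {
    int x_rank = static_cast<int>(x_dims.size());
    int axis = param.axis_;
    if (axis < 0) axis += x_rank;  // -1 normalizes to 3 for a rank-4 tensor
    int axis_size = x_dims[axis];
    int outer_num = 1, inner_num = 1;
    for (int d = 0; d < axis; ++d) outer_num *= x_dims[d];
    for (int d = axis + 1; d < x_rank; ++d) inner_num *= x_dims[d];
    // axis == 3 yields inner_num == 1, so the axis kernel reduces over the
    // innermost dimension, matching the default SoftmaxFuntor behaviour.
    std::printf("axis=%d axis_size=%d outer_num=%d inner_num=%d\n", axis,
                axis_size, outer_num, inner_num);
  } else {
    std::printf("no axis attribute: fall back to SoftmaxFuntor\n");
  }
  return 0;
}
```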