提交 8cdeea3b 编写于 作者: H hjchen2

Fix elementwise add bug for genet model

上级 dab2445d
......@@ -38,20 +38,31 @@ void ElementwiseAddCompute(const ElementwiseAddParam<CPU> &param) {
Out->mutable_data<float>();
int axis = param.Axis();
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
const auto &x_dims = input_x->dims();
const auto &y_dims = input_y->dims();
/// axis = -1 represent the last dimensions.
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
size_t batch = 1;
size_t channels = 1;
size_t elementwise_num = 1;
for (int i = 0; i < axis; ++i) {
batch *= input_x->dims()[i];
batch *= x_dims[i];
}
for (int i = axis + 1; i < input_x->dims().size(); ++i) {
elementwise_num *= input_x->dims()[i];
for (int i = 0; i < y_dims.size(); ++i) {
channels *= y_dims[i];
}
for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) {
elementwise_num *= x_dims[i];
}
const float *bias_data = input_y->data<float>();
const float *input_data = input_x->data<float>();
float *output_data = Out->mutable_data<float>();
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < input_x->dims()[axis]; ++j) {
size_t offset = (i * input_x->dims()[axis] + j) * elementwise_num;
const float *input = input_x->data<float>() + offset;
const float *bias = input_y->data<float>() + j;
float *output = Out->mutable_data<float>() + offset;
for (int j = 0; j < channels; ++j) {
size_t offset = (i * channels + j) * elementwise_num;
const float *input = input_data + offset;
const float *bias = bias_data + j;
float *output = output_data + offset;
int loop = elementwise_num >> 0x4;
int remain = elementwise_num & 0xF;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册