提交 f8e4ab86 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4113 Add fused_activation function for Sub, Add, Mul and Div op

Merge pull request !4113 from wangminggui/master
......@@ -384,19 +384,19 @@ table Eltwise {
}
table Add {
activationType : ActivationType;
activationType: ActivationType = 0;
}
table Sub {
activationType : ActivationType;
activationType: ActivationType = 0;
}
table Mul {
activationType : ActivationType;
activationType: ActivationType = 0;
}
table Div {
activationType : ActivationType;
activationType: ActivationType = 0;
}
table AddGrad {
......
......@@ -510,6 +510,23 @@ OpParameter *PopulateArithmetic(const lite::Primitive *primitive) {
arithmetic_param->op_parameter_.type_ = primitive->Type();
arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting();
arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims();
switch (primitive->Type()) {
case schema::PrimitiveType_Add:
arithmetic_param->activation_type_ = primitive->Value()->value_as_Add()->activationType();
break;
case schema::PrimitiveType_Sub:
arithmetic_param->activation_type_ = primitive->Value()->value_as_Sub()->activationType();
break;
case schema::PrimitiveType_Mul:
arithmetic_param->activation_type_ = primitive->Value()->value_as_Mul()->activationType();
break;
case schema::PrimitiveType_Div:
arithmetic_param->activation_type_ = primitive->Value()->value_as_Div()->activationType();
break;
default:
arithmetic_param->activation_type_ = 0;
break;
}
auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0();
(void)memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->InShape1();
......
......@@ -56,29 +56,26 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
auto input1_data1 = reinterpret_cast<float *>(inputs_[1]->Data());
auto output_data = reinterpret_cast<float *>(outputs_[0]->Data());
auto element_num = outputs_[0]->ElementsNum();
MS_ASSERT(thread_count_ != 0);
int stride = UP_DIV(element_num, thread_count_);
int count = MSMIN(stride, element_num - stride * task_id);
if (arithmetic_run_ == nullptr) {
MS_LOG(ERROR) << "arithmetic_run function is nullptr!";
return RET_ERROR;
}
int error_code = RET_OK;
if (arithmeticParameter_->broadcasting_) {
if (arithmetic_broadcast_run_ == nullptr) {
MS_LOG(ERROR) << "broadcasting_run function is nullptr!";
return RET_ERROR;
}
MS_ASSERT(thread_count_ != 0);
int stride = UP_DIV(element_num, thread_count_);
int count = MSMIN(stride, element_num - stride * task_id);
int error_code = arithmetic_run_(tile_data0_ + stride * task_id, tile_data1_ + stride * task_id,
output_data + stride * task_id, count);
if (error_code != RET_OK) {
return RET_ERROR;
}
} else if (arithmetic_run_ != nullptr) {
int error_code = arithmetic_run_(input0_data, input1_data1, output_data, element_num);
if (error_code != RET_OK) {
return RET_ERROR;
}
error_code = arithmetic_run_(tile_data0_ + stride * task_id, tile_data1_ + stride * task_id,
output_data + stride * task_id, count);
} else {
MS_LOG(ERROR) << "arithmetic_run function is nullptr!";
error_code = arithmetic_run_(input0_data + stride * task_id, input1_data1 + stride * task_id,
output_data + stride * task_id, count);
}
if (error_code != RET_OK) {
return RET_ERROR;
}
return RET_OK;
......
......@@ -50,22 +50,59 @@ class ArithmeticCPUKernel : public LiteKernel {
ArithmeticCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
: LiteKernel(parameter, inputs, outputs), thread_count_(ctx->thread_num_) {
arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(parameter);
switch (parameter->type_) {
case PrimitiveType_Mul:
arithmetic_run_ = ElementMul;
arithmetic_broadcast_run_ = BroadcastMul;
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementMulRelu;
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = ElementMulRelu6;
break;
default:
arithmetic_run_ = ElementMul;
break;
}
break;
case PrimitiveType_Add:
arithmetic_run_ = ElementAdd;
arithmetic_broadcast_run_ = BroadcastAdd;
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementAddRelu;
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = ElementAddRelu6;
break;
default:
arithmetic_run_ = ElementAdd;
break;
}
break;
case PrimitiveType_Sub:
arithmetic_run_ = ElementSub;
arithmetic_broadcast_run_ = BroadcastSub;
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementSubRelu;
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = ElementSubRelu6;
break;
default:
arithmetic_run_ = ElementSub;
break;
}
break;
case PrimitiveType_Div:
arithmetic_run_ = ElementDiv;
arithmetic_broadcast_run_ = BroadcastDiv;
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementDivRelu;
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = ElementDivRelu6;
break;
default:
arithmetic_run_ = ElementDiv;
break;
}
break;
case PrimitiveType_LogicalAnd:
arithmetic_run_ = ElementLogicalAnd;
......@@ -125,7 +162,6 @@ class ArithmeticCPUKernel : public LiteKernel {
arithmetic_broadcast_run_ = nullptr;
break;
}
arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(parameter);
}
~ArithmeticCPUKernel() override;
......
......@@ -27,6 +27,7 @@ struct ArithmeticParameter {
OpParameter op_parameter_;
bool broadcasting_;
size_t ndim_;
int activation_type_;
int in_shape0_[5];
int in_shape1_[5];
int out_shape_[5];
......@@ -49,4 +50,3 @@ void TileDimensionsInt8(int8_t *data0, int8_t *data1, int8_t *tile_data0, int8_t
ArithmeticParameter *param);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARITHMETIC_COMMON_H_
......@@ -47,7 +47,7 @@ inline int Relu6(const float *src, int length, float *dst) {
inline int LRelu(const float *src, int length, float *dst, float alpha) {
for (int i = 0; i < length; ++i) {
dst[i] = src[i] > (src[i] * alpha) ? src[i] : (src[i] * alpha);
dst[i] = src[i] > 0 ? src[i] : (src[i] * alpha);
}
return NNACL_OK;
}
......
......@@ -21,7 +21,7 @@ int ElementMul(float *input0, float *input1, float *output, int element_size) {
int block_c4 = element_size - block_mod;
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef USE_NEON
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vmulq_f32(vin0, vin1);
......@@ -43,6 +43,73 @@ int ElementMul(float *input0, float *input1, float *output, int element_size) {
return NNACL_OK;
}
int ElementMulRelu(float *input0, float *input1, float *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
#ifdef ENABLE_NEON
float32x4_t zeros = {0, 0, 0, 0};
#endif
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vmulq_f32(vin0, vin1);
vout = vbslq_f32(vcgtq_f32(vout, zeros), vout, zeros);
vst1q_f32(output, vout);
#else
float res = input0[0] * input1[0];
output[0] = res > 0 ? res : 0;
res = input0[1] * input1[1];
output[1] = res > 0 ? res : 0;
res = input0[2] * input1[2];
output[2] = res > 0 ? res : 0;
res = input0[3] * input1[3];
output[3] = res > 0 ? res : 0;
#endif
input0 += C4NUM;
input1 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
float res = input0[index] * input1[index];
output[index] = res > 0 ? res : 0;
}
return NNACL_OK;
}
int ElementMulRelu6(float *input0, float *input1, float *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
#ifdef ENABLE_NEON
float32x4_t zeros = {0, 0, 0, 0};
float32x4_t bounds = {6, 6, 6, 6};
#endif
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vminq_f32(vmaxq_f32(vmulq_f32(vin0, vin1), zeros), bounds);
vst1q_f32(output, vout);
#else
output[0] = MSMIN(MSMAX(input0[0] * input1[0], 0), 6);
output[1] = MSMIN(MSMAX(input0[1] * input1[1], 0), 6);
output[2] = MSMIN(MSMAX(input0[2] * input1[2], 0), 6);
output[3] = MSMIN(MSMAX(input0[3] * input1[3], 0), 6);
#endif
input0 += C4NUM;
input1 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(input0[index] * input1[index], 0), 6);
}
return NNACL_OK;
}
int BroadcastMul(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
ArithmeticParameter *param) {
TileDimensions(input0, input1, tile_input0, tile_input1, param);
......@@ -54,7 +121,7 @@ int ElementAdd(float *input0, float *input1, float *output, int element_size) {
int block_c4 = element_size - block_mod;
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef USE_NEON
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vaddq_f32(vin0, vin1);
......@@ -75,6 +142,72 @@ int ElementAdd(float *input0, float *input1, float *output, int element_size) {
return NNACL_OK;
}
int ElementAddRelu(float *input0, float *input1, float *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
#ifdef ENABLE_NEON
float32x4_t zeros = {0, 0, 0, 0};
#endif
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vaddq_f32(vin0, vin1);
vout = vbslq_f32(vcgtq_f32(vout, zeros), vout, zeros);
vst1q_f32(output, vout);
#else
float res = input0[0] + input1[0];
output[0] = res > 0 ? res : 0;
res = input0[1] + input1[1];
output[1] = res > 0 ? res : 0;
res = input0[2] + input1[2];
output[2] = res > 0 ? res : 0;
res = input0[3] + input1[3];
output[3] = res > 0 ? res : 0;
#endif
input0 += C4NUM;
input1 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
float res = input0[index] + input1[index];
output[index] = res > 0 ? res : 0;
}
return NNACL_OK;
}
int ElementAddRelu6(float *input0, float *input1, float *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
#ifdef ENABLE_NEON
float32x4_t zeros = {0, 0, 0, 0};
float32x4_t bounds = {6, 6, 6, 6};
#endif
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vminq_f32(vmaxq_f32(vaddq_f32(vin0, vin1), zeros), bounds);
vst1q_f32(output, vout);
#else
output[0] = MSMIN(MSMAX(input0[0] + input1[0], 0), 6);
output[1] = MSMIN(MSMAX(input0[1] + input1[1], 0), 6);
output[2] = MSMIN(MSMAX(input0[2] + input1[2], 0), 6);
output[3] = MSMIN(MSMAX(input0[3] + input1[3], 0), 6);
#endif
input0 += C4NUM;
input1 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(input0[index] + input1[index], 0), 6);
}
return NNACL_OK;
}
int ElementAddInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size) {
for (int i = 0; i < element_size; i++) {
output[i] = input0[i] + input1[i];
......@@ -99,7 +232,7 @@ int ElementSub(float *input0, float *input1, float *output, int element_size) {
int block_c4 = element_size - block_mod;
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef USE_NEON
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vsubq_f32(vin0, vin1);
......@@ -120,6 +253,72 @@ int ElementSub(float *input0, float *input1, float *output, int element_size) {
return NNACL_OK;
}
int ElementSubRelu(float *input0, float *input1, float *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
#ifdef ENABLE_NEON
float32x4_t zeros = {0, 0, 0, 0};
#endif
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vsubq_f32(vin0, vin1);
vout = vbslq_f32(vcgtq_f32(vout, zeros), vout, zeros);
vst1q_f32(output, vout);
#else
float res = input0[0] - input1[0];
output[0] = res > 0 ? res : 0;
res = input0[1] - input1[1];
output[1] = res > 0 ? res : 0;
res = input0[2] - input1[2];
output[2] = res > 0 ? res : 0;
res = input0[3] - input1[3];
output[3] = res > 0 ? res : 0;
#endif
input0 += C4NUM;
input1 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
float res = input0[index] - input1[index];
output[index] = res > 0 ? res : 0;
}
return NNACL_OK;
}
int ElementSubRelu6(float *input0, float *input1, float *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
#ifdef ENABLE_NEON
float32x4_t zeros = {0, 0, 0, 0};
float32x4_t bounds = {6, 6, 6, 6};
#endif
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vminq_f32(vmaxq_f32(vsubq_f32(vin0, vin1), zeros), bounds);
vst1q_f32(output, vout);
#else
output[0] = MSMIN(MSMAX(input0[0] - input1[0], 0), 6);
output[1] = MSMIN(MSMAX(input0[1] - input1[1], 0), 6);
output[2] = MSMIN(MSMAX(input0[2] - input1[2], 0), 6);
output[3] = MSMIN(MSMAX(input0[3] - input1[3], 0), 6);
#endif
input0 += C4NUM;
input1 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(input0[index] - input1[index], 0), 6);
}
return NNACL_OK;
}
int BroadcastSub(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
ArithmeticParameter *param) {
TileDimensions(input0, input1, tile_input0, tile_input1, param);
......@@ -137,6 +336,27 @@ int ElementDiv(float *input0, float *input1, float *output, int element_size) {
return NNACL_OK;
}
int ElementDivRelu(float *input0, float *input1, float *output, int element_size) {
for (int i = 0; i < element_size; i++) {
if (input1[i] == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO;
}
float res = input0[i] / input1[i];
output[i] = res > 0 ? res : 0;
}
return NNACL_OK;
}
int ElementDivRelu6(float *input0, float *input1, float *output, int element_size) {
for (int i = 0; i < element_size; i++) {
if (input1[i] == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO;
}
output[i] = MSMIN(MSMAX(input0[i] / input1[i], 0), 6);
}
return NNACL_OK;
}
int BroadcastDiv(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
ArithmeticParameter *param) {
TileDimensions(input0, input1, tile_input0, tile_input1, param);
......@@ -179,11 +399,18 @@ int ElementLogicalAnd(float *input0, float *input1, float *output, int element_s
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
#ifdef ENABLE_NEON
float32x4_t vtrue = {1, 1, 1, 1};
float32x4_t vfalse = {0, 0, 0, 0};
uint32x4_t mask = vmovq_n_u32((uint32_t(1u << 31) - 1));
uint32x4_t zeros = {0, 0, 0, 0};
#endif
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef USE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vandq_f32(vin0, vin1);
#ifdef ENABLE_NEON
uint32x4_t vin0 = vandq_u32(vreinterpretq_s32_f32(vld1q_f32(input0)), mask);
uint32x4_t vin1 = vandq_u32(vreinterpretq_s32_f32(vld1q_f32(input1)), mask);
float32x4_t vout = vbslq_f32(vceqq_u32(vandq_u32(vin0, vin1), zeros), vfalse, vtrue);
vst1q_f32(output, vout);
#else
output[0] = (float)((bool)(input0[0]) & (bool)(input1[0]));
......@@ -222,11 +449,18 @@ int ElementLogicalOr(float *input0, float *input1, float *output, int element_si
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
#ifdef ENABLE_NEON
float32x4_t vtrue = {1, 1, 1, 1};
float32x4_t vfalse = {0, 0, 0, 0};
uint32x4_t mask = vmovq_n_u32((uint32_t(1u << 31) - 1));
uint32x4_t zeros = {0, 0, 0, 0};
#endif
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef USE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vorrq_f32(vin0, vin1);
#ifdef ENABLE_NEON
uint32x4_t vin0 = vandq_u32(vreinterpretq_s32_f32(vld1q_f32(input0)), mask);
uint32x4_t vin1 = vandq_u32(vreinterpretq_s32_f32(vld1q_f32(input1)), mask);
float32x4_t vout = vbslq_f32(vceqq_u32(vorrq_u32(vin0, vin1), zeros), vfalse, vtrue);
vst1q_f32(output, vout);
#else
output[0] = (float)((bool)(input0[0]) | (bool)(input1[0]));
......@@ -255,7 +489,7 @@ int ElementMaximum(float *input0, float *input1, float *output, int element_size
int block_c4 = element_size - block_mod;
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef USE_NEON
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vmaxq_f32(vin0, vin1);
......@@ -287,7 +521,7 @@ int ElementMinimum(float *input0, float *input1, float *output, int element_size
int block_c4 = element_size - block_mod;
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef USE_NEON
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vminq_f32(vin0, vin1);
......@@ -317,15 +551,15 @@ int BroadcastMinimum(float *input0, float *input1, float *tile_input0, float *ti
int ElementNotEqual(float *input0, float *input1, float *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
#ifdef USE_NEON
#ifdef ENABLE_NEON
float32x4_t vtrue = {1, 1, 1, 1};
float32x4_t vfalse = {0, 0, 0, 0};
#endif
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef USE_NEON
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vbslq_f32(vceqq_fp32(vin0, vin1), vfalse, vtrue);
float32x4_t vout = vbslq_f32(vceqq_f32(vin0, vin1), vfalse, vtrue);
vst1q_f32(output, vout);
#else
output[0] = (float)(input0[0] != input1[0]);
......@@ -352,15 +586,15 @@ int BroadcastNotEqual(float *input0, float *input1, float *tile_input0, float *t
int ElementEqual(float *input0, float *input1, float *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
#ifdef USE_NEON
#ifdef ENABLE_NEON
float32x4_t vtrue = {1, 1, 1, 1};
float32x4_t vfalse = {0, 0, 0, 0};
#endif
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef USE_NEON
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vbslq_f32(vceqq_fp32(vin0, vin1), vtrue, vfalse);
float32x4_t vout = vbslq_f32(vceqq_f32(vin0, vin1), vtrue, vfalse);
vst1q_f32(output, vout);
#else
output[0] = (float)(input0[0] == input1[0]);
......@@ -387,15 +621,15 @@ int BroadcastEqual(float *input0, float *input1, float *tile_input0, float *tile
int ElementLess(float *input0, float *input1, float *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
#ifdef USE_NEON
#ifdef ENABLE_NEON
float32x4_t vtrue = {1, 1, 1, 1};
float32x4_t vfalse = {0, 0, 0, 0};
#endif
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef USE_NEON
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vbslq_f32(vcltq_fp32(vin0, vin1), vtrue, vfalse);
float32x4_t vout = vbslq_f32(vcltq_f32(vin0, vin1), vtrue, vfalse);
vst1q_f32(output, vout);
#else
output[0] = (float)(input0[0] < input1[0]);
......@@ -422,15 +656,15 @@ int BroadcastLess(float *input0, float *input1, float *tile_input0, float *tile_
int ElementLessEqual(float *input0, float *input1, float *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
#ifdef USE_NEON
#ifdef ENABLE_NEON
float32x4_t vtrue = {1, 1, 1, 1};
float32x4_t vfalse = {0, 0, 0, 0};
#endif
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef USE_NEON
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vbslq_f32(vcleq_fp32(vin0, vin1), vtrue, vfalse);
float32x4_t vout = vbslq_f32(vcleq_f32(vin0, vin1), vtrue, vfalse);
vst1q_f32(output, vout);
#else
output[0] = (float)(input0[0] <= input1[0]);
......@@ -457,15 +691,15 @@ int BroadcastLessEqual(float *input0, float *input1, float *tile_input0, float *
int ElementGreater(float *input0, float *input1, float *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
#ifdef USE_NEON
#ifdef ENABLE_NEON
float32x4_t vtrue = {1, 1, 1, 1};
float32x4_t vfalse = {0, 0, 0, 0};
#endif
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef USE_NEON
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vbslq_f32(vcgtq_fp32(vin0, vin1), vtrue, vfalse);
float32x4_t vout = vbslq_f32(vcgtq_f32(vin0, vin1), vtrue, vfalse);
vst1q_f32(output, vout);
#else
output[0] = (float)(input0[0] > input1[0]);
......@@ -492,15 +726,15 @@ int BroadcastGreater(float *input0, float *input1, float *tile_input0, float *ti
int ElementGreaterEqual(float *input0, float *input1, float *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
#ifdef USE_NEON
#ifdef ENABLE_NEON
float32x4_t vtrue = {1, 1, 1, 1};
float32x4_t vfalse = {0, 0, 0, 0};
#endif
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef USE_NEON
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vbslq_f32(vcgeq_fp32(vin0, vin1), vtrue, vfalse);
float32x4_t vout = vbslq_f32(vcgeq_f32(vin0, vin1), vtrue, vfalse);
vst1q_f32(output, vout);
#else
output[0] = (float)(input0[0] >= input1[0]);
......@@ -523,4 +757,3 @@ int BroadcastGreaterEqual(float *input0, float *input1, float *tile_input0, floa
TileDimensions(input0, input1, tile_input0, tile_input1, param);
return ElementGreaterEqual(tile_input0, tile_input1, output, element_size);
}
......@@ -24,20 +24,28 @@
#include "src/runtime/kernel/arm/nnacl/errorcode.h"
int ElementMul(float *input0, float *input1, float *output, int element_size);
int ElementMulRelu(float *input0, float *input1, float *output, int element_size);
int ElementMulRelu6(float *input0, float *input1, float *output, int element_size);
int BroadcastMul(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
ArithmeticParameter *param);
int ElementAdd(float *input0, float *input1, float *output, int element_size);
int ElementAddRelu(float *input0, float *input1, float *output, int element_size);
int ElementAddRelu6(float *input0, float *input1, float *output, int element_size);
int BroadcastAdd(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
ArithmeticParameter *param);
int BroadcastAddInt8(int8_t *input0, int8_t *input1, int8_t *tile_input0, int8_t *tile_input1, int8_t *output,
int element_size, ArithmeticParameter *param);
int ElementSub(float *input0, float *input1, float *output, int element_size);
int ElementSubRelu(float *input0, float *input1, float *output, int element_size);
int ElementSubRelu6(float *input0, float *input1, float *output, int element_size);
int BroadcastSub(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
ArithmeticParameter *param);
int ElementDiv(float *input0, float *input1, float *output, int element_size);
int ElementDivRelu(float *input0, float *input1, float *output, int element_size);
int ElementDivRelu6(float *input0, float *input1, float *output, int element_size);
int BroadcastDiv(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
ArithmeticParameter *param);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册