diff --git a/mace/kernels/neon/relu_neon.cc b/mace/kernels/neon/relu_neon.cc index b03b896079fd1cc7901856e9f00aae45839fd331..426d8c222da5c27fc79a158ba8dd4594c8ca63cb 100644 --- a/mace/kernels/neon/relu_neon.cc +++ b/mace/kernels/neon/relu_neon.cc @@ -12,26 +12,53 @@ template <> void ReluFunctor<DeviceType::NEON, float>::operator()(const float *input, float *output, index_t size) { -#pragma omp parallel for num_threads(1) // no significant performance improve - for (int64_t i = 0; i < size; i += kCostPerGroup) { - int64_t count = std::min(static_cast<int64_t>(kCostPerGroup), size - i); - int nn = count >> 2; - int remain = count - (nn << 2); - const float *inptr = input + i; - float *outptr = output + i; - float32x4_t _zero = vdupq_n_f32(0.f); - for (; nn > 0; --nn) { - float32x4_t _inptr = vld1q_f32(inptr); - float32x4_t _outptr = vmaxq_f32(_inptr, _zero); - vst1q_f32(outptr, _outptr); + if (max_limit_ < 0) { +#pragma omp parallel for num_threads(1) // no significant perf improve + for (int64_t i = 0; i < size; i += kCostPerGroup) { + int64_t count = std::min(static_cast<int64_t>(kCostPerGroup), size - i); + int block = count >> 2; + int remain = count - (block << 2); + const float *inptr = input + i; + float *outptr = output + i; + float32x4_t zero = vdupq_n_f32(0.f); + for (; block > 0; --block) { + float32x4_t in = vld1q_f32(inptr); + float32x4_t out = vmaxq_f32(in, zero); + vst1q_f32(outptr, out); - inptr += 4; - outptr += 4; + inptr += 4; + outptr += 4; + } + for (; remain > 0; --remain) { + *outptr = std::max(*inptr, 0.f); + ++inptr; + ++outptr; + } } - for (; remain > 0; --remain) { - *outptr = std::max(*inptr, 0.f); - ++inptr; - ++outptr; + } else { +#pragma omp parallel for num_threads(1) // no significant perf improve + for (int64_t i = 0; i < size; i += kCostPerGroup) { + int64_t count = std::min(static_cast<int64_t>(kCostPerGroup), size - i); + int block = count >> 2; + int remain = count - (block << 2); + const float *inptr = input + i; + float *outptr = output + i; + float32x4_t zero = 
vdupq_n_f32(0.f); + float32x4_t vmax = vdupq_n_f32(max_limit_); + for (; block > 0; --block) { + float32x4_t in = vld1q_f32(inptr); + float32x4_t out = vmaxq_f32(in, zero); + out = vminq_f32(out, vmax); + vst1q_f32(outptr, out); + + inptr += 4; + outptr += 4; + } + for (; remain > 0; --remain) { + *outptr = std::min(std::max(*inptr, 0.f), max_limit_); + ++inptr; + ++outptr; + } } } }; diff --git a/mace/kernels/relu.h b/mace/kernels/relu.h index 79788f0392972db8ff6682ffb4621b59180307c7..71cd07ab91b7327e41fd1ca6d12c5250989df36f 100644 --- a/mace/kernels/relu.h +++ b/mace/kernels/relu.h @@ -12,9 +12,17 @@ namespace kernels { template <DeviceType D, typename T> struct ReluFunctor { + T max_limit_; + void operator()(const T *input, T *output, index_t size) { - for (index_t i = 0; i < size; ++i) { - output[i] = std::max(input[i], static_cast<T>(0)); + if (max_limit_ < 0) { + for (index_t i = 0; i < size; ++i) { + output[i] = std::max(input[i], static_cast<T>(0)); + } + } else { + for (index_t i = 0; i < size; ++i) { + output[i] = std::min(std::max(input[i], static_cast<T>(0)), max_limit_); + } } } }; diff --git a/mace/ops/relu.h b/mace/ops/relu.h index c195c78f2a3769082ba8cc4c0e0b74434576597f..5f68cca99c8e6430b4ade79dce2de7a76a606bec 100644 --- a/mace/ops/relu.h +++ b/mace/ops/relu.h @@ -14,7 +14,10 @@ template <DeviceType D, class T> class ReluOp : public Operator<D, T> { public: ReluOp(const OperatorDef& operator_def, Workspace* ws) - : Operator<D, T>(operator_def, ws) {} + : Operator<D, T>(operator_def, ws) { + functor_.max_limit_ = + OperatorBase::GetSingleArgument<T>("max_limit", static_cast<T>(-1)); + } bool Run() override { const Tensor* input_tensor = this->inputs_[0]; Tensor* output_tensor = this->outputs_[0]; diff --git a/mace/ops/relu_test.cc b/mace/ops/relu_test.cc index 1277722c2183cafbc714b4caa5e8c6ac671bbcb4..bf4c810040076ee14ad0405138600a3af3579361 100644 --- a/mace/ops/relu_test.cc +++ b/mace/ops/relu_test.cc @@ -18,7 +18,7 @@ TEST_F(ReluOpTest, ReluOp) { .Finalize(net.operator_def()); // Add input data - net.AddRandomInput<float>("Input", {1, 
2, 3, 4}); + net.AddRandomInput<float>("Input", {1, 2, 3, 5}); // Run net.RunOp(); @@ -32,4 +32,29 @@ TEST_F(ReluOpTest, ReluOp) { ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.01); } +TEST_F(ReluOpTest, ReluOpWithMax) { + // Construct graph + auto& net = test_net(); + OpDefBuilder("Relu", "ReluTestWithMax") + .Input("Input") + .Output("Output") + .Finalize(net.operator_def()); + + // Add input data + net.AddRandomInput<float>("Input", {1, 2, 3, 5}); + net.AddFloatArg("max_limit", 0.5); + + // Run + net.RunOp(); + + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + // Check + net.RunOp(DeviceType::NEON); + + ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.01); +} + + } // namespace mace