diff --git a/docs/user_guide/op_lists.rst b/docs/user_guide/op_lists.rst
index bcb0d31751e4372be69a9f82ee4de55f02a4483e..a3f96e5181d4e36539c0bf59391a08263b54b669 100644
--- a/docs/user_guide/op_lists.rst
+++ b/docs/user_guide/op_lists.rst
@@ -21,6 +21,7 @@ Operator lists
     "DEPTH_TO_SPACE","Y",""
     "DEQUANTIZE","Y","Model quantization will be supported later."
     "ELEMENT_WISE","Y","ADD/MUL/DIV/MIN/MAX/NEG/ABS/SQR_DIFF/POW/RSQRT/SQRT/EQUAL/FLOOR_DIV"
+    "ELU","Y",""
     "EMBEDDING_LOOKUP","Y",""
     "EXPANDDIMS","Y","Only CPU and TensorFlow is supported."
     "FILL","Y","Only CPU and TensorFlow is supported."
diff --git a/mace/ops/activation.cc b/mace/ops/activation.cc
index c542d98e467e74edf17198d713d5cca126e44bde..17a3a905d62542c89656b1322c78a543f3505486 100644
--- a/mace/ops/activation.cc
+++ b/mace/ops/activation.cc
@@ -54,7 +54,7 @@ class ActivationOp : public Operation {
     const Tensor *input = this->Input(0);
     Tensor *output = this->Output(0);
 
-    if (activation_type_ == PRELU) {
+    if (activation_type_ == PRELU || activation_type_ == ELU) {
       MACE_RETURN_IF_ERROR(output->ResizeLike(input));
       const T *input_ptr = input->data<T>();
       T *output_ptr = output->mutable_data<T>();
@@ -63,8 +63,8 @@ class ActivationOp : public Operation {
       const T *alpha_ptr = alpha->data<T>();
       const index_t outer_size = output->dim(0);
       const index_t inner_size = output->dim(2) * output->dim(3);
-      PReLUActivation(context, input_ptr, outer_size, input->dim(1), inner_size,
-                      alpha_ptr, output_ptr);
+      ActivationWithAlpha(context, input_ptr, outer_size, input->dim(1),
+                          inner_size, alpha_ptr, activation_type_, output_ptr);
     } else {
       activation_delegator_->Compute(context, input, output);
     }
@@ -96,7 +96,7 @@ class ActivationOp : public Operation {
     } else {
       MACE_NOT_IMPLEMENTED;
     }
-    if (type == ActivationType::PRELU) {
+    if (type == ActivationType::PRELU || type == ActivationType::ELU) {
       MACE_CHECK(TransformFilter(
           context, operator_def_.get(), 1, OpenCLBufferType::ARGUMENT,
           mem_type) == MaceStatus::MACE_SUCCESS);
diff --git a/mace/ops/activation.h b/mace/ops/activation.h
index 4003dd309331a59d64c2ff6ace5299e7cc9587a6..95f65777a9ba943993d7a4db7edc25e7f6d39106 100644
--- a/mace/ops/activation.h
+++ b/mace/ops/activation.h
@@ -42,6 +42,8 @@ inline ActivationType StringToActivationType(const std::string type) {
     return ActivationType::NOOP;
   } else if (type == "LEAKYRELU") {
     return ActivationType::LEAKYRELU;
+  } else if (type == "ELU") {
+    return ActivationType::ELU;
   } else {
     LOG(FATAL) << "Unknown activation type: " << type;
   }
@@ -49,13 +51,14 @@ inline ActivationType StringToActivationType(const std::string type) {
 }
 
 template <typename T>
-void PReLUActivation(const OpContext *context,
-                     const T *input_ptr,
-                     const index_t outer_size,
-                     const index_t input_chan,
-                     const index_t inner_size,
-                     const T *alpha_ptr,
-                     T *output_ptr) {
+void ActivationWithAlpha(const OpContext *context,
+                         const T *input_ptr,
+                         const index_t outer_size,
+                         const index_t input_chan,
+                         const index_t inner_size,
+                         const T *alpha_ptr,
+                         const index_t activation_type,
+                         T *output_ptr) {
   utils::ThreadPool &thread_pool =
       context->device()->cpu_runtime()->thread_pool();
 
@@ -66,7 +69,12 @@ void PReLUActivation(const OpContext *context,
         for (index_t j = 0; j < inner_size; ++j) {
           index_t idx = i * input_chan * inner_size + chan_idx * inner_size + j;
           if (input_ptr[idx] < 0) {
-            output_ptr[idx] = input_ptr[idx] * alpha_ptr[chan_idx];
+            if (activation_type == ActivationType::PRELU) {
+              output_ptr[idx] = input_ptr[idx] * alpha_ptr[chan_idx];
+            } else if (activation_type == ActivationType::ELU) {
+              output_ptr[idx] =
+                  (std::exp(input_ptr[idx]) - 1) * alpha_ptr[chan_idx];
+            }
           } else {
             output_ptr[idx] = input_ptr[idx];
           }
@@ -75,7 +83,6 @@ void PReLUActivation(const OpContext *context,
     }
   }, 0, outer_size, 1, 0, input_chan, 1);
 }
-
 }  // namespace ops
 }  // namespace mace
 
diff --git a/mace/ops/common/activation_type.h b/mace/ops/common/activation_type.h
index de8f6e8b7cef4697c61749edcb88039c0f788667..2d844dd4d59b85591604d91773877048ef354f43 100644
--- a/mace/ops/common/activation_type.h
+++ b/mace/ops/common/activation_type.h
@@ -26,6 +26,7 @@ enum ActivationType {
   TANH = 4,
   SIGMOID = 5,
   LEAKYRELU = 6,
+  ELU = 7,
 };
 
 }  // namespace ops
diff --git a/mace/ops/opencl/cl/activation.cl b/mace/ops/opencl/cl/activation.cl
index 5dbd9cd9da28852be3569d41e56a91fd22eb8834..8e825eceab33f0dfa542e5168b863871be6ef8d0 100644
--- a/mace/ops/opencl/cl/activation.cl
+++ b/mace/ops/opencl/cl/activation.cl
@@ -3,7 +3,7 @@
 __kernel void activation(OUT_OF_RANGE_PARAMS
                          GLOBAL_WORK_GROUP_SIZE_DIM3
                          __read_only image2d_t input,
-#ifdef USE_PRELU
+#if defined (USE_PRELU) || defined (USE_ELU)
                          __read_only image2d_t alpha,
 #endif
                          __private const float relux_max_limit,
@@ -23,9 +23,9 @@ __kernel void activation(OUT_OF_RANGE_PARAMS
   const int pos = mad24(ch_blk, width, w);
   DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
 
-#ifdef USE_PRELU
-  DATA_TYPE4 prelu_alpha = READ_IMAGET(alpha, SAMPLER, (int2)(ch_blk, 0));
-  DATA_TYPE4 out = do_activation(in, prelu_alpha, relux_max_limit, leakyrelu_coefficient);
+#if defined (USE_PRELU) || defined (USE_ELU)
+  DATA_TYPE4 activation_alpha = READ_IMAGET(alpha, SAMPLER, (int2)(ch_blk, 0));
+  DATA_TYPE4 out = do_activation(in, activation_alpha, relux_max_limit, leakyrelu_coefficient);
 #else
   DATA_TYPE4 out = do_activation(in, relux_max_limit, leakyrelu_coefficient);
 #endif
diff --git a/mace/ops/opencl/cl/common.h b/mace/ops/opencl/cl/common.h
index 630498ceff32460a9c6d8ecd74a1581a4e1e3c54..0bd97045b869ef40333eb47cf3f4a048ab289b18 100644
--- a/mace/ops/opencl/cl/common.h
+++ b/mace/ops/opencl/cl/common.h
@@ -83,8 +83,8 @@ inline float4 do_sigmoid(float4 in) {
 
 #ifdef DATA_TYPE
 inline DATA_TYPE4 do_activation(DATA_TYPE4 in,
-#ifdef USE_PRELU
-                                DATA_TYPE4 prelu_alpha,
+#if defined (USE_PRELU) || defined (USE_ELU)
+                                DATA_TYPE4 alpha,
 #endif
                                 __private const float relux_max_limit,
                                 __private const float leakyrelu_coefficient) {
@@ -96,7 +96,10 @@ inline DATA_TYPE4 do_activation(DATA_TYPE4 in,
   out = clamp(in, (DATA_TYPE4)0, relux_max_limit);
 #endif
 #ifdef USE_PRELU
-  out = select(prelu_alpha * in, in, in >= (DATA_TYPE)0);
+  out = select(alpha * in, in, in >= (DATA_TYPE)0);
+#endif
+#ifdef USE_ELU
+  out = select(alpha * (native_exp(in) - 1.0f), in, in >= (DATA_TYPE)0);
 #endif
 #ifdef USE_TANH
   out = tanh(in);
diff --git a/mace/ops/opencl/image/activation.cc b/mace/ops/opencl/image/activation.cc
index 3c8ed331820cb23801fb346d645ed0f7a138936d..f013c99071bb3347b5303ea9302f3099bd3878b7 100644
--- a/mace/ops/opencl/image/activation.cc
+++ b/mace/ops/opencl/image/activation.cc
@@ -58,6 +58,11 @@ MaceStatus ActivationKernel::Compute(
         built_options.emplace("-DUSE_PRELU");
         break;
       }
+      case ELU: {
+        tuning_key_prefix_ = "elu_opencl_kernel";
+        built_options.emplace("-DUSE_ELU");
+        break;
+      }
       case TANH: {
         tuning_key_prefix_ = "tanh_opencl_kernel";
         built_options.emplace("-DUSE_TANH");
@@ -94,7 +99,7 @@ MaceStatus ActivationKernel::Compute(
     MACE_OUT_OF_RANGE_SET_ARGS(kernel_);
     MACE_SET_3D_GWS_ARGS(kernel_, gws);
     kernel_.setArg(idx++, *(input->opencl_image()));
-    if (activation_ == PRELU) {
+    if (activation_ == PRELU || activation_ == ELU) {
       MACE_CHECK_NOTNULL(alpha);
       kernel_.setArg(idx++, *(alpha->opencl_image()));
     }
diff --git a/mace/ops/opencl/image/winograd_conv2d.cc b/mace/ops/opencl/image/winograd_conv2d.cc
index 539b4cf4f8604261dbc79d8536e84bcc3f9596d0..e5c2be9756f19aa70e1cdcfca7bb4c79faf4617b 100644
--- a/mace/ops/opencl/image/winograd_conv2d.cc
+++ b/mace/ops/opencl/image/winograd_conv2d.cc
@@ -161,6 +161,10 @@ MaceStatus WinogradOutputTransform(OpContext *context,
         built_options.emplace("-DUSE_PRELU");
         break;
       }
+      case ELU: {
+        built_options.emplace("-DUSE_ELU");
+        break;
+      }
       case TANH: {
         built_options.emplace("-DUSE_TANH");
         break;
diff --git a/test/ccbenchmark/mace/ops/activation_benchmark.cc b/test/ccbenchmark/mace/ops/activation_benchmark.cc
index c1e79288ecaa39ae06c11c6815b6558bee3e879a..ee92d352bc30da122589352078f2b0bce1e3235c 100644
--- a/test/ccbenchmark/mace/ops/activation_benchmark.cc
+++ b/test/ccbenchmark/mace/ops/activation_benchmark.cc
@@ -208,6 +208,70 @@ MACE_BM_PRELU(1, 3, 512, 512);
 MACE_BM_PRELU(1, 32, 112, 112);
 MACE_BM_PRELU(1, 64, 256, 256);
 
+namespace {
+template <DeviceType D, typename T>
+void EluBenchmark(int iters, int batch, int channels, int height, int width) {
+  mace::testing::StopTiming();
+
+  OpsTestNet net;
+
+  // Add input data
+  if (D == DeviceType::CPU) {
+    net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
+  } else if (D == DeviceType::GPU) {
+    net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
+  } else {
+    MACE_NOT_IMPLEMENTED;
+  }
+  net.AddRandomInput<D, float>("Alpha", {channels}, true);
+
+  OpDefBuilder("Activation", "EluBM")
+      .Input("Input")
+      .Input("Alpha")
+      .Output("Output")
+      .AddStringArg("activation", "ELU")
+      .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+      .Finalize(net.NewOperatorDef());
+
+  // Warm-up
+  for (int i = 0; i < 5; ++i) {
+    net.RunOp(D);
+  }
+  net.Sync();
+
+  mace::testing::StartTiming();
+  while (iters--) {
+    net.RunOp(D);
+  }
+  net.Sync();
+}
+}  // namespace
+
+#define MACE_BM_ELU_MACRO(N, C, H, W, TYPE, DEVICE)                  \
+  static void MACE_BM_ELU_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
+      int iters) {                                                   \
+    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
+    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));              \
+    EluBenchmark<DEVICE, TYPE>(iters, N, C, H, W);                   \
+  }                                                                  \
+  MACE_BENCHMARK(MACE_BM_ELU_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
+
+#ifdef MACE_ENABLE_OPENCL
+#define MACE_BM_ELU(N, C, H, W)              \
+  MACE_BM_ELU_MACRO(N, C, H, W, float, CPU); \
+  MACE_BM_ELU_MACRO(N, C, H, W, float, GPU); \
+  MACE_BM_ELU_MACRO(N, C, H, W, half, GPU)
+#else
+#define MACE_BM_ELU(N, C, H, W) \
+  MACE_BM_ELU_MACRO(N, C, H, W, float, CPU)
+#endif
+
+MACE_BM_ELU(1, 1, 512, 512);
+MACE_BM_ELU(1, 3, 128, 128);
+MACE_BM_ELU(1, 3, 512, 512);
+MACE_BM_ELU(1, 32, 112, 112);
+MACE_BM_ELU(1, 64, 256, 256);
+
 namespace {
 template <DeviceType D, typename T>
 void TanhBenchmark(int iters, int batch, int channels, int height, int width) {
diff --git a/test/ccunit/mace/ops/activation_test.cc b/test/ccunit/mace/ops/activation_test.cc
index 27cfefb7a861e717ade5b1808aa1e7b76fbe4891..dfa978e7625200fdde2cad922c185299e71c945e 100644
--- a/test/ccunit/mace/ops/activation_test.cc
+++ b/test/ccunit/mace/ops/activation_test.cc
@@ -235,6 +235,59 @@ TEST_F(ActivationOpTest, OPENCLSimplePrelu) {
   TestSimplePrelu<DeviceType::GPU>();
 }
 
+namespace {
+template <DeviceType D>
+void TestSimpleElu() {
+  OpsTestNet net;
+
+  // Add input data
+  net.AddInputFromArray<D, float>(
+      "Input", {2, 2, 2, 2},
+      {-7, 7, -6, 6, -5, -5, -4, -4, -3, 3, -2, 2, -1, -1, 0, 0});
+  net.AddInputFromArray<D, float>("Alpha", {2}, {2.0, 3.0}, true);
+
+  if (D == DeviceType::GPU) {
+    OpDefBuilder("Activation", "EluTest")
+        .Input("Input")
+        .Input("Alpha")
+        .Output("Output")
+        .AddStringArg("activation", "ELU")
+        .Finalize(net.NewOperatorDef());
+
+    // Run
+    net.RunOp(D);
+  } else {
+    net.TransformDataFormat<DeviceType::CPU, float>(
+        "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
+    OpDefBuilder("Activation", "EluTest")
+        .Input("InputNCHW")
+        .Input("Alpha")
+        .Output("OutputNCHW")
+        .AddStringArg("activation", "ELU")
+        .Finalize(net.NewOperatorDef());
+
+    // Run
+    net.RunOp(D);
+    net.TransformDataFormat<DeviceType::CPU, float>(
+        "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
+  }
+
+  auto expected = net.CreateTensor<float>(
+      {2, 2, 2, 2},
+      {-1.998176236068891, 7, -1.9950424956466672, 6, -1.986524106001829,
+       -2.9797861590027437, -1.9633687222225316, -2.9450530833337973,
+       -1.900425863264272, 3, -1.7293294335267746, 2, -1.2642411176571153,
+       -1.896361676485673, 0, 0});
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
+}
+}  // namespace
+
+TEST_F(ActivationOpTest, CPUSimpleElu) { TestSimpleElu<DeviceType::CPU>(); }
+
+TEST_F(ActivationOpTest, OPENCLSimpleElu) {
+  TestSimpleElu<DeviceType::GPU>();
+}
+
 namespace {
 template <DeviceType D>
 void TestSimpleTanh() {
diff --git a/tools/python/transform/base_converter.py b/tools/python/transform/base_converter.py
index 696a59551701d12274a0f26d0dc0d95c5820ec60..a4750d398c2a2f033beb5634c69d347210414626 100644
--- a/tools/python/transform/base_converter.py
+++ b/tools/python/transform/base_converter.py
@@ -49,6 +49,7 @@ class ActivationType(Enum):
     TANH = 4
     SIGMOID = 5
     LEAKYRELU = 6
+    ELU = 7
 
 
 class EltwiseType(Enum):
diff --git a/tools/python/transform/onnx_converter.py b/tools/python/transform/onnx_converter.py
index e4217a53339c162b66ad176569fff0c47f8549c7..ca384e59a18053968bd46344f92b4c3a815d110f 100644
--- a/tools/python/transform/onnx_converter.py
+++ b/tools/python/transform/onnx_converter.py
@@ -85,7 +85,7 @@ OnnxSupportedOps = [
     'Div',
     'Dropout',
     'DynamicLSTM',
-    # 'Elu',
+    'Elu',
     'Equal',
     # 'Exp',
     # 'Expand',
@@ -323,6 +323,7 @@ class OnnxConverter(base_converter.ConverterInterface):
             OnnxOpType.Relu.name: ActivationType.RELU,
             OnnxOpType.LeakyRelu.name: ActivationType.LEAKYRELU,
             OnnxOpType.PRelu.name: ActivationType.PRELU,
+            OnnxOpType.Elu.name: ActivationType.ELU,
             OnnxOpType.Tanh.name: ActivationType.TANH,
             OnnxOpType.Sigmoid.name: ActivationType.SIGMOID,
         }
@@ -348,6 +349,7 @@ class OnnxConverter(base_converter.ConverterInterface):
             OnnxOpType.Dropout.name: self.convert_dropout,
             OnnxOpType.DimRange.name: self.convert_dim_range,
             OnnxOpType.Div.name: self.convert_eltwise,
+            OnnxOpType.Elu.name: self.convert_activation,
             OnnxOpType.Equal.name: self.convert_eltwise,
             OnnxOpType.ExtractPooling.name: self.convert_extract_pooling,
             OnnxOpType.Flatten.name: self.convert_flatten,
@@ -627,7 +629,11 @@ class OnnxConverter(base_converter.ConverterInterface):
         type_arg.s = six.b(self.activation_type[node.op_type].name)
 
         if "alpha" in node.attrs:
-            alpha_value = node.attrs["alpha"]
+            alpha_tensor_name = node.name + '_alpha'
+            alpha_value = np.array([node.attrs["alpha"]])
+            self.add_tensor(alpha_tensor_name, alpha_value.reshape(-1).shape,
+                            mace_pb2.DT_FLOAT, alpha_value)
+            op.input.extend([alpha_tensor_name])
         else:
             if node.op_type == OnnxOpType.LeakyRelu.name:
                 alpha_value = 0.01
diff --git a/tools/python/transform/transformer.py b/tools/python/transform/transformer.py
index ab81aa65f5e54cd964d7e78ee029b7c500528be3..459299bc3d623a57527fefbd7e7df867e575a9c3 100644
--- a/tools/python/transform/transformer.py
+++ b/tools/python/transform/transformer.py
@@ -977,7 +977,10 @@ class Transformer(base_converter.ConverterInterface):
                                          [ActivationType.RELU.name,
                                           ActivationType.RELUX.name])
                    else:
-                        fold_consumer = (act_type != ActivationType.PRELU.name)
+                        fold_consumer = (
+                            act_type != ActivationType.PRELU.name
+                            and act_type != ActivationType.ELU.name
+                        )
                    # during quantization, only fold relu/relux
                    if (self._option.quantize_stat or self._option.quantize) \
                            and act_type not in [ActivationType.RELU.name,
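
A quick sanity check, not part of the patch: the standalone Python sketch below (assumed script, not MACE code) re-derives the expected tensor in TestSimpleElu from the formula the new CPU and OpenCL paths apply to negative inputs, alpha[c] * (exp(x) - 1), using the test's NHWC {2, 2, 2, 2} layout and per-channel alpha {2.0, 3.0}; non-negative inputs pass through unchanged.

    import math

    # Inputs, alpha, and expected values copied from TestSimpleElu.
    # NHWC layout with channels innermost, so the channel index is i % 2.
    inputs = [-7, 7, -6, 6, -5, -5, -4, -4, -3, 3, -2, 2, -1, -1, 0, 0]
    alpha = [2.0, 3.0]
    expected = [-1.998176236068891, 7, -1.9950424956466672, 6,
                -1.986524106001829, -2.9797861590027437,
                -1.9633687222225316, -2.9450530833337973,
                -1.900425863264272, 3, -1.7293294335267746, 2,
                -1.2642411176571153, -1.896361676485673, 0, 0]

    for i, x in enumerate(inputs):
        a = alpha[i % 2]
        # Same branch as ActivationWithAlpha: ELU only rewrites negative inputs.
        y = a * (math.exp(x) - 1) if x < 0 else x
        assert abs(y - expected[i]) < 1e-5, (i, x, y, expected[i])
    print("TestSimpleElu expected values match alpha * (exp(x) - 1)")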