From 10ba3061443736df1509bc07ef789826e65bc716 Mon Sep 17 00:00:00 2001
From: yejianwu
Date: Wed, 19 Dec 2018 10:33:10 +0800
Subject: [PATCH] support leaky relu

Add a leakyrelu_coefficient argument to the activation-related ops
(activation, batch_norm, conv_2d, deconv_2d, depthwise_conv2d,
depthwise_deconv2d, fully_connected) on both the CPU and OpenCL paths,
so that LEAKYRELU no longer reuses max_limit as its slope:

    LeakyRelu(x) = x                          if x >= 0
                 = leakyrelu_coefficient * x  if x <  0

---
 mace/ops/activation.cc                        | 14 +++--
 mace/ops/activation.h                         | 10 ++--
 mace/ops/activation_test.cc                   | 36 +++++++++++++
 mace/ops/batch_norm.cc                        | 13 +++--
 mace/ops/batch_norm_test.cc                   |  8 ++-
 mace/ops/conv_2d.cc                           | 15 ++++--
 mace/ops/conv_2d_test.cc                      |  9 +++-
 mace/ops/deconv_2d.cc                         |  6 ++-
 mace/ops/deconv_2d.h                          |  5 +-
 mace/ops/deconv_2d_test.cc                    |  9 +++-
 mace/ops/depthwise_conv2d.cc                  |  9 ++--
 mace/ops/depthwise_conv2d_test.cc             | 12 ++---
 mace/ops/depthwise_deconv2d.cc                |  4 +-
 mace/ops/depthwise_deconv2d_test.cc           |  9 +++-
 mace/ops/fully_connected.cc                   | 10 ++--
 mace/ops/fully_connected_test.cc              | 11 ++--
 mace/ops/opencl/buffer/conv_2d.h              |  8 ++-
 mace/ops/opencl/buffer/conv_2d_1x1.cc         |  5 ++
 mace/ops/opencl/buffer/conv_2d_general.cc     |  5 ++
 mace/ops/opencl/buffer/depthwise_conv2d.cc    |  5 ++
 mace/ops/opencl/buffer/depthwise_conv2d.h     |  5 +-
 mace/ops/opencl/cl/activation.cl              |  5 +-
 mace/ops/opencl/cl/batch_norm.cl              |  7 +--
 mace/ops/opencl/cl/common.h                   |  5 +-
 mace/ops/opencl/cl/conv_2d.cl                 | 11 ++--
 mace/ops/opencl/cl/conv_2d_1x1.cl             | 11 ++--
 mace/ops/opencl/cl/conv_2d_1x1_buffer.cl      |  7 +--
 mace/ops/opencl/cl/conv_2d_3x3.cl             | 13 ++---
 mace/ops/opencl/cl/conv_2d_buffer.cl          | 11 ++--
 mace/ops/opencl/cl/deconv_2d.cl               | 13 ++---
 mace/ops/opencl/cl/depthwise_conv2d.cl        | 22 ++++----
 mace/ops/opencl/cl/depthwise_conv2d_buffer.cl | 11 ++--
 mace/ops/opencl/cl/depthwise_deconv2d.cl      | 15 +++---
 mace/ops/opencl/cl/fully_connected.cl         | 14 ++---
 mace/ops/opencl/cl/winograd_transform.cl      | 52 ++++++++++---------
 mace/ops/opencl/conv_2d.h                     |  1 +
 mace/ops/opencl/deconv_2d.h                   |  1 +
 mace/ops/opencl/depthwise_conv2d.h            |  1 +
 mace/ops/opencl/depthwise_deconv2d.h          |  1 +
 mace/ops/opencl/fully_connected.h             |  1 +
 mace/ops/opencl/image/activation.h            |  8 ++-
 mace/ops/opencl/image/batch_norm.h            | 14 +++--
 mace/ops/opencl/image/conv_2d.h               | 10 ++++
 mace/ops/opencl/image/conv_2d_1x1.cc          |  5 ++
 mace/ops/opencl/image/conv_2d_3x3.cc          |  5 ++
 mace/ops/opencl/image/conv_2d_general.cc      |  5 ++
 mace/ops/opencl/image/deconv_2d.h             |  6 +++
 mace/ops/opencl/image/depthwise_conv2d.cc     |  5 ++
 mace/ops/opencl/image/depthwise_conv2d.h      |  7 ++-
 mace/ops/opencl/image/depthwise_deconv2d.h    |  6 +++
 mace/ops/opencl/image/fully_connected.h       |  6 +++
 mace/ops/opencl/image/winograd_conv2d.cc      |  9 +++-
 .../tools/converter_tool/base_converter.py    |  1 +
 .../tools/converter_tool/caffe_converter.py   |  8 +++
 .../tools/converter_tool/onnx_converter.py    |  2 +-
 .../tools/converter_tool/transformer.py       |  4 +-
 56 files changed, 366 insertions(+), 145 deletions(-)

diff --git a/mace/ops/activation.cc b/mace/ops/activation.cc
index fe8862bb..80bddab6 100644
--- a/mace/ops/activation.cc
+++ b/mace/ops/activation.cc
@@ -36,9 +36,11 @@ class ActivationOp<DeviceType::CPU, T> : public Operation {
       : Operation(context),
         activation_(ops::StringToActivationType(
             Operation::GetOptionalArg<std::string>("activation",
-                                                  "NOOP"))),
+                                                   "NOOP"))),
         relux_max_limit_(Operation::GetOptionalArg<float>("max_limit",
-                                                          0.0f)) {}
+                                                          0.0f)),
+        leakyrelu_coefficient_(Operation::GetOptionalArg<float>(
+            "leakyrelu_coefficient", 0.0f)) {}
 
   MaceStatus Run(OpContext *context) override {
     MACE_UNUSED(context);
@@ -58,7 +60,7 @@ class ActivationOp<DeviceType::CPU, T> : public Operation {
                    alpha_ptr, output_ptr);
     } else {
       DoActivation(input_ptr, output_ptr, output->size(), activation_,
-                   relux_max_limit_);
+                   relux_max_limit_, leakyrelu_coefficient_);
     }
     return MaceStatus::MACE_SUCCESS;
   }
@@ -66,6 +68,7 @@ class ActivationOp<DeviceType::CPU, T> : public Operation {
  private:
   ActivationType activation_;
   float relux_max_limit_;
+  float leakyrelu_coefficient_;
 };
 
@@ -80,11 +83,14 @@ class ActivationOp<DeviceType::GPU, T> : public Operation {
                                                "NOOP"));
     auto relux_max_limit = static_cast<T>(
         Operation::GetOptionalArg<float>("max_limit", 0.0f));
+    auto leakyrelu_coefficient = static_cast<T>(
+        Operation::GetOptionalArg<float>("leakyrelu_coefficient", 0.0f));
     MemoryType mem_type;
     if (context->device()->gpu_runtime()->UseImageMemory()) {
       mem_type = MemoryType::GPU_IMAGE;
       kernel_.reset(
-          new opencl::image::ActivationKernel<T>(type, relux_max_limit));
+          new opencl::image::ActivationKernel<T>(type, relux_max_limit,
+                                                 leakyrelu_coefficient));
     } else {
       MACE_NOT_IMPLEMENTED;
     }
diff --git a/mace/ops/activation.h b/mace/ops/activation.h
index 24377d80..8ddcaea6 100644
--- a/mace/ops/activation.h
+++ b/mace/ops/activation.h
@@ -62,7 +62,8 @@ void DoActivation(const T *input_ptr,
                   T *output_ptr,
                   const index_t size,
                   const ActivationType type,
-                  const float relux_max_limit) {
+                  const float relux_max_limit,
+                  const float leakyrelu_coefficient) {
   MACE_CHECK(DataTypeToEnum<T>::value != DataType::DT_HALF);
 
   switch (type) {
@@ -97,7 +98,7 @@ void DoActivation(const T *input_ptr,
 #pragma omp parallel for schedule(runtime)
       for (index_t i = 0; i < size; ++i) {
         output_ptr[i] = std::max(input_ptr[i], static_cast<T>(0))
-            + std::min(input_ptr[i], static_cast<T>(0)) * relux_max_limit;
+            + leakyrelu_coefficient * std::min(input_ptr[i], static_cast<T>(0));
       }
       break;
     default:
@@ -110,7 +111,8 @@ inline void DoActivation(const float *input_ptr,
                          float *output_ptr,
                          const index_t size,
                          const ActivationType type,
-                         const float relux_max_limit) {
+                         const float relux_max_limit,
+                         const float leakyrelu_coefficient) {
   switch (type) {
     case NOOP:
       break;
@@ -133,7 +135,7 @@ inline void DoActivation(const float *input_ptr,
       }
       break;
     case LEAKYRELU:
-      LeakyReluNeon(input_ptr, relux_max_limit, size, output_ptr);
+      LeakyReluNeon(input_ptr, leakyrelu_coefficient, size, output_ptr);
       break;
     default:
       LOG(FATAL) << "Unknown activation type: " << type;
diff --git a/mace/ops/activation_test.cc b/mace/ops/activation_test.cc
index 4cd63ab6..9932f167 100644
--- a/mace/ops/activation_test.cc
+++ b/mace/ops/activation_test.cc
@@ -52,6 +52,42 @@ TEST_F(ActivationOpTest, OPENCLSimpleRelu) {
   TestSimpleRelu<DeviceType::GPU>();
 }
 
+namespace {
+template <DeviceType D>
+void TestSimpleLeakyRelu() {
+  OpsTestNet net;
+
+  // Add input data
+  net.AddInputFromArray<D, float>(
+      "Input", {2, 2, 2, 2},
+      {-7, 7, -6, 6, -5, 5, -4, 4, -3, 3, -2, 2, -1, 1, 0, 0});
+
+  OpDefBuilder("Activation", "ReluTest")
+      .Input("Input")
+      .Output("Output")
+      .AddStringArg("activation", "LEAKYRELU")
+      .AddFloatArg("leakyrelu_coefficient", 0.1)
+      .Finalize(net.NewOperatorDef());
+
+  // Run
+  net.RunOp(D);
+
+  auto expected = net.CreateTensor<float>(
+      {2, 2, 2, 2},
+      {-0.7, 7, -0.6, 6, -0.5, 5, -0.4, 4, -0.3, 3, -0.2, 2, -0.1, 1, 0, 0});
+
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
+}
+}  // namespace
+
+TEST_F(ActivationOpTest, CPUSimpleLeakyRelu) {
+  TestSimpleLeakyRelu<DeviceType::CPU>();
+}
+
+TEST_F(ActivationOpTest, OPENCLSimpleLeakyRelu) {
+  TestSimpleLeakyRelu<DeviceType::GPU>();
+}
+
 namespace {
 template <DeviceType D>
 void TestUnalignedSimpleRelu() {
diff --git a/mace/ops/batch_norm.cc b/mace/ops/batch_norm.cc
index 3ca5592a..dae8f35e 100644
--- a/mace/ops/batch_norm.cc
+++ b/mace/ops/batch_norm.cc
@@ -35,10 +35,12 @@ class BatchNormOp<DeviceType::CPU, float> : public Operation {
   explicit BatchNormOp(OpConstructContext *context)
       : Operation(context),
         epsilon_(Operation::GetOptionalArg<float>("epsilon",
-                                                  static_cast<float>(1e-4))),
+                                                   static_cast<float>(1e-4))),
         activation_(ops::StringToActivationType(
             Operation::GetOptionalArg<std::string>("activation", "NOOP"))),
-        relux_max_limit_(Operation::GetOptionalArg<float>("max_limit", 0.0f)) {}
+        relux_max_limit_(Operation::GetOptionalArg<float>("max_limit", 0.0f)),
+        leakyrelu_coefficient_(Operation::GetOptionalArg<float>(
+            "leakyrelu_coefficient", 0.0f)) {}
 
   MaceStatus Run(OpContext *context) override {
     MACE_UNUSED(context);
@@ -121,7 +123,7 @@ class BatchNormOp<DeviceType::CPU, float> : public Operation {
       }
     }
     DoActivation(output_ptr, output_ptr, output->size(), activation_,
-                 relux_max_limit_);
+                 relux_max_limit_, leakyrelu_coefficient_);
 
     return MaceStatus::MACE_SUCCESS;
   }
@@ -130,6 +132,7 @@ class BatchNormOp<DeviceType::CPU, float> : public Operation {
   float epsilon_;
   const ActivationType activation_;
   const float relux_max_limit_;
+  const float leakyrelu_coefficient_;
 
  protected:
   MACE_OP_INPUT_TAGS(INPUT, SCALE, OFFSET, MEAN, VAR);
@@ -148,11 +151,13 @@ class BatchNormOp<DeviceType::GPU, T> : public Operation {
     ActivationType activation = ops::StringToActivationType(
         Operation::GetOptionalArg<std::string>("activation", "NOOP"));
     float relux_max_limit = Operation::GetOptionalArg<float>("max_limit", 0.0f);
+    float leakyrelu_coefficient = Operation::GetOptionalArg<float>(
+        "leakyrelu_coefficient", 0.0f);
     MemoryType mem_type;
     if (context->device()->gpu_runtime()->UseImageMemory()) {
       mem_type = MemoryType::GPU_IMAGE;
       kernel_.reset(new opencl::image::BatchNormKernel<T>(
-          epsilon, activation, relux_max_limit));
+          epsilon, activation, relux_max_limit, leakyrelu_coefficient));
     } else {
       MACE_NOT_IMPLEMENTED;
     }
diff --git a/mace/ops/batch_norm_test.cc b/mace/ops/batch_norm_test.cc
index d7c4903e..36398904 100644
--- a/mace/ops/batch_norm_test.cc
+++ b/mace/ops/batch_norm_test.cc
@@ -88,8 +88,8 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) {
   // Add input data
   net.AddRandomInput<DeviceType::GPU, float>("Input",
                                              {batch, height, width, channels});
-  net.AddRandomInput<DeviceType::GPU, float>("Scale", {channels}, true);
-  net.AddRandomInput<DeviceType::GPU, float>("Offset", {channels}, true);
+  net.AddRandomInput<DeviceType::GPU, float>("Scale", {channels}, true, false);
+  net.AddRandomInput<DeviceType::GPU, float>("Offset", {channels}, true, false);
   net.AddRandomInput<DeviceType::GPU, float>("Mean", {channels}, true);
   net.AddRandomInput<DeviceType::GPU, float>("Var", {channels}, true);
 
@@ -105,6 +105,8 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) {
       .Input("Var")
       .AddFloatArg("epsilon", 1e-3)
       .Output("OutputNCHW")
+      .AddStringArg("activation", "LEAKYRELU")
+      .AddFloatArg("leakyrelu_coefficient", 0.1)
       .Finalize(net.NewOperatorDef());
 
   // run cpu
@@ -126,6 +128,8 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) {
       .Input("Var")
       .AddFloatArg("epsilon", 1e-3)
       .Output("Output")
+      .AddStringArg("activation", "LEAKYRELU")
+      .AddFloatArg("leakyrelu_coefficient", 0.1)
       .Finalize(net.NewOperatorDef());
 
   // Tuning
diff --git a/mace/ops/conv_2d.cc b/mace/ops/conv_2d.cc
index 0a0d3bb5..2cc6f36e 100644
--- a/mace/ops/conv_2d.cc
+++ b/mace/ops/conv_2d.cc
@@ -58,6 +58,8 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
             Operation::GetOptionalArg<std::string>("activation",
                                                    "NOOP"))),
         relux_max_limit_(Operation::GetOptionalArg<float>("max_limit", 0.0f)),
+        leakyrelu_coefficient_(Operation::GetOptionalArg<float>(
+            "leakyrelu_coefficient", 0.0f)),
         is_filter_transformed_(false) {}
 
   MaceStatus Run(OpContext *context) override {
@@ -520,7 +522,7 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
     }
 
     DoActivation(output_data, output_data, output->size(), activation_,
-                 relux_max_limit_);
+                 relux_max_limit_, leakyrelu_coefficient_);
 
     return MaceStatus::MACE_SUCCESS;
   }
@@ -703,6 +705,7 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
  private:
   const ActivationType activation_;
   const float relux_max_limit_;
+  const float leakyrelu_coefficient_;
   bool is_filter_transformed_;
   SGemm sgemm_;
 
@@ -721,7 +724,9 @@ class Conv2dOp<DeviceType::CPU, uint8_t> : public ConvPool2dOpBase {
         activation_(ops::StringToActivationType(
            Operation::GetOptionalArg<std::string>("activation",
                                                   "NOOP"))),
-        relux_max_limit_(Operation::GetOptionalArg<float>("max_limit", 0.0f)) {}
+        relux_max_limit_(Operation::GetOptionalArg<float>("max_limit", 0.0f)),
+        leakyrelu_coefficient_(Operation::GetOptionalArg<float>(
+            "leakyrelu_coefficient", 0.0f)) {}
 
   MaceStatus Run(OpContext *context) override {
     const Tensor *input = this->Input(INPUT);
@@ -944,6 +949,7 @@ class Conv2dOp<DeviceType::CPU, uint8_t> : public ConvPool2dOpBase {
  private:
   const ActivationType activation_;
   const float relux_max_limit_;
+  const float leakyrelu_coefficient_;
 
  private:
   MACE_OP_INPUT_TAGS(INPUT, FILTER, BIAS);
@@ -961,6 +967,8 @@ class Conv2dOp<DeviceType::GPU, T> : public ConvPool2dOpBase {
             Operation::GetOptionalArg<std::string>("activation",
                                                    "NOOP"))),
         relux_max_limit_(Operation::GetOptionalArg<float>("max_limit", 0.0f)),
+        leakyrelu_coefficient_(Operation::GetOptionalArg<float>(
+            "leakyrelu_coefficient", 0.0f)),
         wino_block_size_(Operation::GetOptionalArg<int>("wino_block_size", 0)) {
     MemoryType mem_type;
     if (context->device()->gpu_runtime()->UseImageMemory()) {
@@ -1007,12 +1015,13 @@ class Conv2dOp<DeviceType::GPU, T> : public ConvPool2dOpBase {
     return kernel_->Compute(context, input, filter, bias,
                             strides_.data(), padding_type_, paddings_,
                             dilations_.data(), activation_, relux_max_limit_,
-                            wino_block_size_, output);
+                            leakyrelu_coefficient_, wino_block_size_, output);
   }
 
  private:
   const ActivationType activation_;
   const float relux_max_limit_;
+  const float leakyrelu_coefficient_;
   std::unique_ptr<OpenCLConv2dKernel> kernel_;
   int wino_block_size_;
 
diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc
index d94b208d..335dd1a3 100644
--- a/mace/ops/conv_2d_test.cc
+++ b/mace/ops/conv_2d_test.cc
@@ -527,8 +527,9 @@ void TestComplexConvNxNS12(const std::vector<index_t> &shape,
   // Add input data
   net.AddRandomInput<D, float>("Input",
                                {batch, height, width, input_channels});
   net.AddRandomInput<D, float>(
-      "Filter", {output_channels, input_channels, kernel_h, kernel_w}, true);
-  net.AddRandomInput<D, float>("Bias", {output_channels}, true);
+      "Filter", {output_channels, input_channels, kernel_h, kernel_w}, true,
+      false);
+  net.AddRandomInput<D, float>("Bias", {output_channels}, true, false);
 
   net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
                                                   NCHW);
@@ -541,6 +542,8 @@ void TestComplexConvNxNS12(const std::vector<index_t> &shape,
       .AddIntsArg("strides", {stride_h, stride_w})
       .AddIntArg("padding", type)
       .AddIntsArg("dilations", {1, 1})
+      .AddStringArg("activation", "LEAKYRELU")
+      .AddFloatArg("leakyrelu_coefficient", 0.1)
       .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
       .Finalize(net.NewOperatorDef());
 
@@ -564,6 +567,8 @@ void TestComplexConvNxNS12(const std::vector<index_t> &shape,
       .AddIntsArg("strides", {stride_h, stride_w})
       .AddIntArg("padding", type)
       .AddIntsArg("dilations", {1, 1})
+      .AddStringArg("activation", "LEAKYRELU")
+      .AddFloatArg("leakyrelu_coefficient", 0.1)
       .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
       .AddIntArg("wino_block_size", wino_blk_size)
       .Finalize(net.NewOperatorDef());
diff --git a/mace/ops/deconv_2d.cc b/mace/ops/deconv_2d.cc
index 575e81ad..687aa445 100644
--- a/mace/ops/deconv_2d.cc
+++ b/mace/ops/deconv_2d.cc
@@ -292,7 +292,8 @@ class Deconv2dOp<DeviceType::CPU, float> : public Deconv2dOpBase {
                  output_data,
                  output->size(),
                  activation_,
-                 relux_max_limit_);
+                 relux_max_limit_,
+                 leakyrelu_coefficient_);
 
     return MaceStatus::MACE_SUCCESS;
   }
@@ -443,7 +444,8 @@ class Deconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase {
     return kernel_->Compute(context, input, filter, bias,
                             strides_.data(), in_paddings.data(), activation_,
-                            relux_max_limit_, out_shape, output);
+                            relux_max_limit_, leakyrelu_coefficient_,
+                            out_shape, output);
   }
 
  private:
diff --git a/mace/ops/deconv_2d.h b/mace/ops/deconv_2d.h
index f6a4200c..c2e2c759 100644
--- a/mace/ops/deconv_2d.h
+++ b/mace/ops/deconv_2d.h
@@ -47,7 +47,9 @@ class Deconv2dOpBase : public Operation {
             Operation::GetOptionalArg<std::string>("activation",
                                                    "NOOP"))),
         relux_max_limit_(
-            Operation::GetOptionalArg<float>("max_limit", 0.0f)) {}
+            Operation::GetOptionalArg<float>("max_limit", 0.0f)),
+        leakyrelu_coefficient_(
+            Operation::GetOptionalArg<float>("leakyrelu_coefficient", 0.0f)) {}
 
   static void CalcDeconvShape_Caffe(
       const index_t *input_shape,  // NHWC
@@ -191,6 +193,7 @@ class Deconv2dOpBase : public Operation {
   const FrameworkType model_type_;
   const ActivationType activation_;
   const float relux_max_limit_;
+  const float leakyrelu_coefficient_;
 };
 
 template <DeviceType D, class T>
diff --git a/mace/ops/deconv_2d_test.cc b/mace/ops/deconv_2d_test.cc
index 1847c943..19f8fd8e 100644
--- a/mace/ops/deconv_2d_test.cc
+++ b/mace/ops/deconv_2d_test.cc
@@ -377,8 +377,9 @@ void TestComplexDeconvNxN(const int batch,
   // Add input data
   net.AddRandomInput<D, float>("Input",
                                {batch, height, width, input_channels});
   net.AddRandomInput<D, float>(
-      "Filter", {output_channels, input_channels, kernel_h, kernel_w}, true);
-  net.AddRandomInput<D, float>("Bias", {output_channels}, true);
+      "Filter", {output_channels, input_channels, kernel_h, kernel_w}, true,
+      false);
+  net.AddRandomInput<D, float>("Bias", {output_channels}, true, false);
   net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
                                                   NCHW);
 
   int out_h = 0;
@@ -418,6 +419,8 @@ void TestComplexDeconvNxN(const int batch,
           .AddIntsArg("strides", {stride_h, stride_w})
           .AddIntsArg("padding_values", paddings)
           .AddIntArg("framework_type", model_type)
+          .AddStringArg("activation", "LEAKYRELU")
+          .AddFloatArg("leakyrelu_coefficient", 0.1)
           .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
           .Finalize(net.NewOperatorDef());
     } else {
@@ -454,6 +457,8 @@ void TestComplexDeconvNxN(const int batch,
           .AddIntsArg("strides", {stride_h, stride_w})
          .AddIntsArg("padding_values", paddings)
           .AddIntArg("framework_type", model_type)
+          .AddStringArg("activation", "LEAKYRELU")
+          .AddFloatArg("leakyrelu_coefficient", 0.1)
           .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
           .Finalize(net.NewOperatorDef());
     } else {
diff --git a/mace/ops/depthwise_conv2d.cc b/mace/ops/depthwise_conv2d.cc
index 2f849ef7..b74dcbb3 100644
--- a/mace/ops/depthwise_conv2d.cc
+++ b/mace/ops/depthwise_conv2d.cc
@@ -49,10 +49,13 @@ class DepthwiseConv2dOpBase : public ConvPool2dOpBase {
         activation_(ops::StringToActivationType(
             Operation::GetOptionalArg<std::string>("activation",
                                                    "NOOP"))),
-        relux_max_limit_(Operation::GetOptionalArg<float>("max_limit", 0.0f)) {}
+        relux_max_limit_(Operation::GetOptionalArg<float>("max_limit", 0.0f)),
+        leakyrelu_coefficient_(Operation::GetOptionalArg<float>(
+            "leakyrelu_coefficient", 0.0f)) {}
 
  protected:
   const ActivationType activation_;
   const float relux_max_limit_;
+  const float leakyrelu_coefficient_;
 };
 
 template <DeviceType D, class T>
@@ -218,7 +221,7 @@ class DepthwiseConv2dOp<DeviceType::CPU, float> : public DepthwiseConv2dOpBase {
     }
 
     DoActivation(output_data, output_data, output->size(), activation_,
-                 relux_max_limit_);
+                 relux_max_limit_, leakyrelu_coefficient_);
 
     return MaceStatus::MACE_SUCCESS;
   }
@@ -524,7 +527,7 @@ class DepthwiseConv2dOp<DeviceType::GPU, T> : public DepthwiseConv2dOpBase {
     return kernel_->Compute(context, input, filter, bias,
                             strides_.data(), padding_type_, paddings_,
                             dilations_.data(), activation_, relux_max_limit_,
-                            output);
+                            leakyrelu_coefficient_, output);
   }
 
  private:
diff --git a/mace/ops/depthwise_conv2d_test.cc b/mace/ops/depthwise_conv2d_test.cc
index d9965658..770d36b1 100644
--- a/mace/ops/depthwise_conv2d_test.cc
+++ b/mace/ops/depthwise_conv2d_test.cc
@@ -244,10 +244,10 @@ void TestNxNS12(const index_t height, const index_t width) {
   net.AddRandomInput<DeviceType::GPU, float>(
       "Input", {batch, height, width, channel});
   net.AddRandomInput<DeviceType::GPU, float>(
-      "Filter", {multiplier, channel, kernel_h, kernel_w}, true);
+      "Filter", {multiplier, channel, kernel_h, kernel_w}, true, false);
   net.AddRandomInput<DeviceType::GPU, float>("Bias",
                                              {multiplier * channel},
-                                             true);
+                                             true, false);
 
   net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
                                                   NCHW);
@@ -260,8 +260,8 @@ void TestNxNS12(const index_t height, const index_t width) {
       .AddIntArg("padding", type)
       .AddIntsArg("dilations", {1, 1})
       .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
-      .AddStringArg("activation", "RELUX")
-      .AddFloatArg("max_limit", 6.0)
+      .AddStringArg("activation", "LEAKYRELU")
+      .AddFloatArg("leakyrelu_coefficient", 0.1)
       .Finalize(net.NewOperatorDef());
 
   // Run on cpu
@@ -283,8 +283,8 @@ void TestNxNS12(const index_t height, const index_t width) {
       .AddIntArg("padding", type)
       .AddIntsArg("dilations", {1, 1})
       .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
-      .AddStringArg("activation", "RELUX")
-      .AddFloatArg("max_limit", 6.0)
+      .AddStringArg("activation", "LEAKYRELU")
+      .AddFloatArg("leakyrelu_coefficient", 0.1)
       .Finalize(net.NewOperatorDef());
 
   net.RunOp(DeviceType::GPU);
diff --git a/mace/ops/depthwise_deconv2d.cc b/mace/ops/depthwise_deconv2d.cc
index a4e7148e..1614ce76 100644
--- a/mace/ops/depthwise_deconv2d.cc
+++ b/mace/ops/depthwise_deconv2d.cc
@@ -281,7 +281,8 @@ class DepthwiseDeconv2dOp<DeviceType::CPU, float>
                  output_data,
                  output->size(),
                  activation_,
-                 relux_max_limit_);
+                 relux_max_limit_,
+                 leakyrelu_coefficient_);
 
     return MaceStatus::MACE_SUCCESS;
   }
@@ -458,6 +459,7 @@ class DepthwiseDeconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase {
                             group_,
                             activation_,
                             relux_max_limit_,
+                            leakyrelu_coefficient_,
                             out_shape,
                             output);
   }
diff --git a/mace/ops/depthwise_deconv2d_test.cc b/mace/ops/depthwise_deconv2d_test.cc
index fe3b0b18..737fa502 100644
--- a/mace/ops/depthwise_deconv2d_test.cc
+++ b/mace/ops/depthwise_deconv2d_test.cc
@@ -185,12 +185,13 @@ void RandomTest(index_t batch,
   GenerateRandomRealTypeData({multiplier, channel, kernel, kernel},
                              &filter_data);
   net.AddInputFromArray<DeviceType::GPU, float>(
-      "Filter", {multiplier, channel, kernel, kernel}, filter_data, true);
+      "Filter", {multiplier, channel, kernel, kernel}, filter_data, true,
+      false);
   std::vector<float> bias_data(channel * multiplier);
   GenerateRandomRealTypeData({channel * multiplier}, &bias_data);
   net.AddInputFromArray<DeviceType::GPU, float>("Bias", {channel * multiplier},
-                                                bias_data, true);
+                                                bias_data, true, false);
 
   net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
                                                   NCHW);
@@ -203,6 +204,8 @@ void RandomTest(index_t batch,
       .AddIntsArg("padding_values", {padding, padding})
       .AddIntArg("group", channel)
       .AddIntsArg("dilations", {1, 1})
+      .AddStringArg("activation", "LEAKYRELU")
+      .AddFloatArg("leakyrelu_coefficient", 0.1f)
       .AddIntArg("T", static_cast<int>(DataTypeToEnum<float>::value))
       .Finalize(net.NewOperatorDef());
   // Run
@@ -224,6 +227,8 @@ void RandomTest(index_t batch,
       .AddIntsArg("strides", {stride, stride})
       .AddIntsArg("padding_values", {padding, padding})
       .AddIntArg("group", channel)
+      .AddStringArg("activation", "LEAKYRELU")
+      .AddFloatArg("leakyrelu_coefficient", 0.1f)
       .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
       .Finalize(net.NewOperatorDef());
 
diff --git a/mace/ops/fully_connected.cc b/mace/ops/fully_connected.cc
index 31b1fb05..eb4e4b06 100644
--- a/mace/ops/fully_connected.cc
+++ b/mace/ops/fully_connected.cc
@@ -41,10 +41,13 @@ class FullyConnectedOpBase : public Operation {
         activation_(ops::StringToActivationType(
             Operation::GetOptionalArg<std::string>("activation",
                                                    "NOOP"))),
-        relux_max_limit_(Operation::GetOptionalArg<float>("max_limit", 0.0f)) {}
+        relux_max_limit_(Operation::GetOptionalArg<float>("max_limit", 0.0f)),
+        leakyrelu_coefficient_(Operation::GetOptionalArg<float>(
+            "leakyrelu_coefficient", 0.0f)) {}
 
  protected:
   const ActivationType activation_;
   const float relux_max_limit_;
+  const float leakyrelu_coefficient_;
 
   MACE_OP_INPUT_TAGS(INPUT, WEIGHT, BIAS);
   MACE_OP_OUTPUT_TAGS(OUTPUT);
@@ -104,7 +107,7 @@ class FullyConnectedOp<DeviceType::CPU, float> : public FullyConnectedOpBase {
     }
 
     DoActivation(output_ptr, output_ptr, output->size(), activation_,
-                 relux_max_limit_);
+                 relux_max_limit_, leakyrelu_coefficient_);
 
     return MaceStatus::MACE_SUCCESS;
   }
@@ -226,7 +229,8 @@ class FullyConnectedOp<DeviceType::GPU, T> : public FullyConnectedOpBase {
         "The shape of Weight: ", MakeString(weight->shape()),
         " don't match.");
     return kernel_->Compute(
-        context, input, weight, bias, activation_, relux_max_limit_, output);
+        context, input, weight, bias, activation_, relux_max_limit_,
+        leakyrelu_coefficient_, output);
   }
 
  private:
diff --git a/mace/ops/fully_connected_test.cc b/mace/ops/fully_connected_test.cc
index 26134bb5..266b057c 100644
--- a/mace/ops/fully_connected_test.cc
+++ b/mace/ops/fully_connected_test.cc
@@ -123,10 +123,11 @@ void Random(const index_t batch,
 
   // Add input data
   net.AddRandomInput<DeviceType::GPU, float>("Input",
-                                             {batch, height, width, channels});
+                                             {batch, height, width, channels},
+                                             false, false);
   net.AddRandomInput<DeviceType::GPU, float>(
-      "Weight", {out_channel, channels, height, width}, true);
+      "Weight", {out_channel, channels, height, width}, true, false);
-  net.AddRandomInput<DeviceType::GPU, float>("Bias", {out_channel}, true);
+  net.AddRandomInput<DeviceType::GPU, float>("Bias", {out_channel}, true,
+                                             false);
 
   net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
                                                   NCHW);
@@ -135,6 +136,8 @@ void Random(const index_t batch,
       .Input("Weight")
       .Input("Bias")
       .Output("OutputNCHW")
+      .AddStringArg("activation", "LEAKYRELU")
+      .AddFloatArg("leakyrelu_coefficient", 0.1f)
       .Finalize(net.NewOperatorDef());
 
   // run cpu
@@ -152,6 +155,8 @@ void Random(const index_t batch,
       .Input("Weight")
       .Input("Bias")
       .Output("Output")
+      .AddStringArg("activation", "LEAKYRELU")
+      .AddFloatArg("leakyrelu_coefficient", 0.1f)
       .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
       .Finalize(net.NewOperatorDef());
 
diff --git a/mace/ops/opencl/buffer/conv_2d.h b/mace/ops/opencl/buffer/conv_2d.h
index dca57404..cd5c7464 100644
--- a/mace/ops/opencl/buffer/conv_2d.h
+++ b/mace/ops/opencl/buffer/conv_2d.h
@@ -38,6 +38,7 @@ extern MaceStatus Conv2d1x1(OpContext *context,
                             const DataType dt,
                             const ActivationType activation,
                             const float relux_max_limit,
+                            const float leakyrelu_coefficient,
                             const bool input_changed,
                             Tensor *output,
                             StatsFuture *future);
@@ -52,6 +53,7 @@ extern MaceStatus Conv2dGeneral(OpContext *context,
                                 const DataType dt,
                                 const ActivationType activation,
                                 const float relux_max_limit,
+                                const float leakyrelu_coefficient,
                                 const bool input_changed,
                                 Tensor *output,
                                 StatsFuture *future);
@@ -81,6 +83,7 @@ class Conv2dKernel : public OpenCLConv2dKernel {
                      const int *dilations,
                      const ActivationType activation,
                      const float relux_max_limit,
+                     const float leakyrelu_coefficient,
                      const int winograd_blk_size,
                      Tensor *output) override;
 
@@ -120,6 +123,7 @@ MaceStatus Conv2dKernel<T>::Compute(
     const int *dilations,
     const ActivationType activation,
     const float relux_max_limit,
+    const float leakyrelu_coefficient,
     const int winograd_blk_size,
     Tensor *output) {
   MACE_UNUSED(winograd_blk_size);
@@ -221,14 +225,14 @@ MaceStatus Conv2dKernel<T>::Compute(
       return conv2d::Conv2d1x1(
           context, &kernels_[1], pad_input, filter, bias, strides,
           DataTypeToEnum<T>::v(), activation, relux_max_limit,
-          input_changed, output, &conv_future);
+          leakyrelu_coefficient, input_changed, output, &conv_future);
     };
   } else {
     conv_func = [&](const Tensor *pad_input, Tensor *output) -> MaceStatus {
       return conv2d::Conv2dGeneral(
           context, &kernels_[1], pad_input, filter, bias, strides, dilations,
           DataTypeToEnum<T>::v(), activation, relux_max_limit,
-          input_changed, output, &conv_future);
+          leakyrelu_coefficient, input_changed, output, &conv_future);
     };
   }
   MACE_RETURN_IF_ERROR(conv_func(padded_input_ptr, output));
diff --git a/mace/ops/opencl/buffer/conv_2d_1x1.cc b/mace/ops/opencl/buffer/conv_2d_1x1.cc
index abe7d93b..c32dc9bd 100644
--- a/mace/ops/opencl/buffer/conv_2d_1x1.cc
+++ b/mace/ops/opencl/buffer/conv_2d_1x1.cc
@@ -32,6 +32,7 @@ MaceStatus Conv2d1x1(OpContext *context,
                      const DataType dt,
                      const ActivationType activation,
                      const float relux_max_limit,
+                     const float leakyrelu_coefficient,
                      const bool input_changed,
                      Tensor *output,
                      StatsFuture *future) {
@@ -71,6 +72,9 @@ MaceStatus Conv2d1x1(OpContext *context,
       case SIGMOID:
         built_options.emplace("-DUSE_SIGMOID");
         break;
+      case LEAKYRELU:
+        built_options.emplace("-DUSE_LEAKYRELU");
+        break;
       default:
         LOG(FATAL) << "Unknown activation type: " << activation;
     }
@@ -106,6 +110,7 @@ MaceStatus Conv2d1x1(OpContext *context,
     kernel->setArg(idx++, strides[0]);
     kernel->setArg(idx++, strides[1]);
     kernel->setArg(idx++, relux_max_limit);
+    kernel->setArg(idx++, leakyrelu_coefficient);
     kernel->setArg(idx++, *(output->opencl_buffer()));
   }
 
diff --git a/mace/ops/opencl/buffer/conv_2d_general.cc b/mace/ops/opencl/buffer/conv_2d_general.cc
index e8ac509c..5d1dbb4b 100644
--- a/mace/ops/opencl/buffer/conv_2d_general.cc
+++ b/mace/ops/opencl/buffer/conv_2d_general.cc
@@ -33,6 +33,7 @@ MaceStatus Conv2dGeneral(OpContext *context,
                          const DataType dt,
                          const ActivationType activation,
                          const float relux_max_limit,
+                         const float leakyrelu_coefficient,
                          const bool input_changed,
                          Tensor *output,
                          StatsFuture *future) {
@@ -76,6 +77,9 @@ MaceStatus Conv2dGeneral(OpContext *context,
       case SIGMOID:
         built_options.emplace("-DUSE_SIGMOID");
         break;
+      case LEAKYRELU:
+        built_options.emplace("-DUSE_LEAKYRELU");
+        break;
       default:
         LOG(FATAL) << "Unknown activation type: " << activation;
     }
@@ -120,6 +124,7 @@ MaceStatus Conv2dGeneral(OpContext *context,
     kernel->setArg(idx++, static_cast<int32_t>(
         dilations[1] * in_channel));
     kernel->setArg(idx++, relux_max_limit);
+    kernel->setArg(idx++, leakyrelu_coefficient);
     kernel->setArg(idx++, *(output->opencl_buffer()));
   }
 
diff --git a/mace/ops/opencl/buffer/depthwise_conv2d.cc b/mace/ops/opencl/buffer/depthwise_conv2d.cc
index d2c33599..08b74514 100644
--- a/mace/ops/opencl/buffer/depthwise_conv2d.cc
+++ b/mace/ops/opencl/buffer/depthwise_conv2d.cc
@@ -33,6 +33,7 @@ MaceStatus DepthwiseConv2d(OpContext *context,
                            const DataType dt,
                            const ActivationType activation,
                            const float relux_max_limit,
+                           const float leakyrelu_coefficient,
                            const bool input_changed,
                            Tensor *output,
                            StatsFuture *future) {
@@ -76,6 +77,9 @@ MaceStatus DepthwiseConv2d(OpContext *context,
      case SIGMOID:
        built_options.emplace("-DUSE_SIGMOID");
        break;
+      case LEAKYRELU:
+        built_options.emplace("-DUSE_LEAKYRELU");
+        break;
       default:
         LOG(FATAL) << "Unknown activation type: " << activation;
     }
@@ -116,6 +120,7 @@ MaceStatus DepthwiseConv2d(OpContext *context,
     kernel->setArg(idx++, static_cast<int32_t>(
         dilations[1] * in_channel));
     kernel->setArg(idx++, relux_max_limit);
+    kernel->setArg(idx++, leakyrelu_coefficient);
     kernel->setArg(idx++, *(output->opencl_buffer()));
   }
 
diff --git a/mace/ops/opencl/buffer/depthwise_conv2d.h b/mace/ops/opencl/buffer/depthwise_conv2d.h
index 2d6ce0c8..6efc8f6e 100644
--- a/mace/ops/opencl/buffer/depthwise_conv2d.h
+++ b/mace/ops/opencl/buffer/depthwise_conv2d.h
@@ -39,6 +39,7 @@ MaceStatus DepthwiseConv2d(OpContext *context,
                            const DataType dt,
                            const ActivationType activation,
                            const float relux_max_limit,
+                           const float leakyrelu_coefficient,
                            const bool input_changed,
                            Tensor *output,
                            StatsFuture *future);
@@ -60,6 +61,7 @@ class DepthwiseConv2dKernel : public OpenCLDepthwiseConv2dKernel {
                      const int *dilations,
                      const ActivationType activation,
                      const float relux_max_limit,
+                     const float leakyrelu_coefficient,
                      Tensor *output) override;
 
  private:
@@ -81,6 +83,7 @@ MaceStatus DepthwiseConv2dKernel<T>::Compute(
     const int *dilations,
     const ActivationType activation,
     const float relux_max_limit,
+    const float leakyrelu_coefficient,
     Tensor *output) {
   StatsFuture pad_future, dw_conv_future;
   index_t filter_w = filter->dim(3);
@@ -175,7 +178,7 @@ MaceStatus DepthwiseConv2dKernel<T>::Compute(
   depthwise::DepthwiseConv2d(
       context, &kernels_[1], padded_input_ptr, filter, bias, strides,
       dilations, DataTypeToEnum<T>::v(), activation, relux_max_limit,
-      input_changed, output, &dw_conv_future));
+      leakyrelu_coefficient, input_changed, output, &dw_conv_future));
   MergeMultipleFutureWaitFn({pad_future, dw_conv_future}, context->future());
   return MaceStatus::MACE_SUCCESS;
 }
diff --git a/mace/ops/opencl/cl/activation.cl b/mace/ops/opencl/cl/activation.cl
index 62978d88..5dbd9cd9 100644
--- a/mace/ops/opencl/cl/activation.cl
+++ b/mace/ops/opencl/cl/activation.cl
@@ -7,6 +7,7 @@ __kernel void activation(OUT_OF_RANGE_PARAMS
                          __read_only image2d_t alpha,
 #endif
                          __private const float relux_max_limit,
+                         __private const float leakyrelu_coefficient,
                          __write_only image2d_t output) {
   const int ch_blk = get_global_id(0);
   const int w = get_global_id(1);
@@ -24,9 +25,9 @@ __kernel void activation(OUT_OF_RANGE_PARAMS
   DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
 #ifdef USE_PRELU
   DATA_TYPE4 prelu_alpha = READ_IMAGET(alpha, SAMPLER, (int2)(ch_blk, 0));
-  DATA_TYPE4 out = do_activation(in, prelu_alpha, relux_max_limit);
+  DATA_TYPE4 out = do_activation(in, prelu_alpha, relux_max_limit, leakyrelu_coefficient);
 #else
-  DATA_TYPE4 out = do_activation(in, relux_max_limit);
+  DATA_TYPE4 out = do_activation(in, relux_max_limit, leakyrelu_coefficient);
 #endif
 
   WRITE_IMAGET(output, (int2)(pos, hb), out);
diff --git a/mace/ops/opencl/cl/batch_norm.cl b/mace/ops/opencl/cl/batch_norm.cl
index cf1f18c7..87da37d0 100644
--- a/mace/ops/opencl/cl/batch_norm.cl
+++ b/mace/ops/opencl/cl/batch_norm.cl
@@ -11,7 +11,8 @@ __kernel void batch_norm(OUT_OF_RANGE_PARAMS
                          __private const float epsilon,
 #endif
                          __write_only image2d_t output,
-                         __private const float relux_max_limit) {
+                         __private const float relux_max_limit,
+                         __private const float leakyrelu_coefficient) {
   const int ch_blk = get_global_id(0);
   const int w = get_global_id(1);
   const int hb = get_global_id(2);
@@ -43,8 +44,8 @@ __kernel void batch_norm(OUT_OF_RANGE_PARAMS
   DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
   DATA_TYPE4 out = mad(in, bn_scale, bn_offset);
 
-#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
-  out = do_activation(out, relux_max_limit);
+#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
+  out = do_activation(out, relux_max_limit, leakyrelu_coefficient);
 #endif
 
   WRITE_IMAGET(output, (int2)(pos, hb), out);
diff --git a/mace/ops/opencl/cl/common.h b/mace/ops/opencl/cl/common.h
index 29054ad3..3c4b3a3d 100644
--- a/mace/ops/opencl/cl/common.h
+++ b/mace/ops/opencl/cl/common.h
@@ -86,7 +86,8 @@ inline DATA_TYPE4 do_activation(DATA_TYPE4 in,
 #ifdef USE_PRELU
                                 DATA_TYPE4 prelu_alpha,
 #endif
-                                __private const float relux_max_limit) {
+                                __private const float relux_max_limit,
+                                __private const float leakyrelu_coefficient) {
   DATA_TYPE4 out;
 #ifdef USE_RELU
   out = fmax(in, (DATA_TYPE)0);
@@ -104,7 +105,7 @@ inline DATA_TYPE4 do_activation(DATA_TYPE4 in,
   out = do_sigmoid(in);
 #endif
 #ifdef USE_LEAKYRELU
-  out = fmax(in, (DATA_TYPE)0) * relux_max_limit;
+  out = select(leakyrelu_coefficient * in, in, in >= (DATA_TYPE)0);
 #endif
   return out;
 }
diff --git a/mace/ops/opencl/cl/conv_2d.cl b/mace/ops/opencl/cl/conv_2d.cl
index b5ec1b15..4a3d9e0d 100644
--- a/mace/ops/opencl/cl/conv_2d.cl
+++ b/mace/ops/opencl/cl/conv_2d.cl
@@ -9,6 +9,7 @@ __kernel void conv_2d(OUT_OF_RANGE_PARAMS
 #endif
                       __write_only image2d_t output,
                       __private const float relux_max_limit,
+                      __private const float leakyrelu_coefficient,
                       __private const int in_height,
                       __private const int in_width,
                       __private const int in_ch_blks,
@@ -123,11 +124,11 @@ __kernel void conv_2d(OUT_OF_RANGE_PARAMS
     }
   }
 
-#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
-  out0 = do_activation(out0, relux_max_limit);
-  out1 = do_activation(out1, relux_max_limit);
-  out2 = do_activation(out2, relux_max_limit);
-  out3 = do_activation(out3, relux_max_limit);
+#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
+  out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient);
+  out1 = do_activation(out1, relux_max_limit, leakyrelu_coefficient);
+  out2 = do_activation(out2, relux_max_limit, leakyrelu_coefficient);
+  out3 = do_activation(out3, relux_max_limit, leakyrelu_coefficient);
 #endif
 
   const int out_x_base = mul24(out_ch_blk, out_width);
diff --git a/mace/ops/opencl/cl/conv_2d_1x1.cl b/mace/ops/opencl/cl/conv_2d_1x1.cl
index db508ac9..d0dc2e15 100644
--- a/mace/ops/opencl/cl/conv_2d_1x1.cl
+++ b/mace/ops/opencl/cl/conv_2d_1x1.cl
@@ -9,6 +9,7 @@ __kernel void conv_2d_1x1(OUT_OF_RANGE_PARAMS
 #endif
                           __write_only image2d_t output,
                           __private const float relux_max_limit,
+                          __private const float leakyrelu_coefficient,
                           __private const int in_height,
                           __private const int in_width,
                           __private const int in_ch_blks,
@@ -96,11 +97,11 @@ __kernel void conv_2d_1x1(OUT_OF_RANGE_PARAMS
     filter_x_base += 4;
   }
 
-#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
-  out0 = do_activation(out0, relux_max_limit);
-  out1 = do_activation(out1, relux_max_limit);
-  out2 = do_activation(out2, relux_max_limit);
-  out3 = do_activation(out3, relux_max_limit);
+#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
+  out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient);
+  out1 = do_activation(out1, relux_max_limit, leakyrelu_coefficient);
+  out2 = do_activation(out2, relux_max_limit, leakyrelu_coefficient);
+  out3 = do_activation(out3, relux_max_limit, leakyrelu_coefficient);
 #endif
 
   const int out_x_base = mul24(out_ch_blk, width);
diff --git a/mace/ops/opencl/cl/conv_2d_1x1_buffer.cl b/mace/ops/opencl/cl/conv_2d_1x1_buffer.cl
index 15cf5b59..be9f36e1 100644
--- a/mace/ops/opencl/cl/conv_2d_1x1_buffer.cl +++ b/mace/ops/opencl/cl/conv_2d_1x1_buffer.cl @@ -17,6 +17,7 @@ __kernel void conv2d(BUFFER_OUT_OF_RANGE_PARAMS __private const int stride_h, __private const int stride_w, __private const float relux_max_limit, + __private const float leakyrelu_coefficient, __global OUT_DATA_TYPE *output) { const int out_wc_blk_idx = get_global_id(0); const int out_hb_idx = get_global_id(1); @@ -79,9 +80,9 @@ __kernel void conv2d(BUFFER_OUT_OF_RANGE_PARAMS in_offset += 4; } -#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) - out0 = do_activation(out0, relux_max_limit); - out1 = do_activation(out1, relux_max_limit); +#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) + out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient); + out1 = do_activation(out1, relux_max_limit, leakyrelu_coefficient); #endif int out_offset = mad24(mad24(mad24(batch_idx, out_height, out_height_idx), diff --git a/mace/ops/opencl/cl/conv_2d_3x3.cl b/mace/ops/opencl/cl/conv_2d_3x3.cl index f4172e59..aeb85332 100644 --- a/mace/ops/opencl/cl/conv_2d_3x3.cl +++ b/mace/ops/opencl/cl/conv_2d_3x3.cl @@ -9,6 +9,7 @@ __kernel void conv_2d_3x3(OUT_OF_RANGE_PARAMS #endif __write_only image2d_t output, __private const float relux_max_limit, + __private const float leakyrelu_coefficient, __private const int in_height, __private const int in_width, __private const int in_ch_blks, @@ -128,12 +129,12 @@ __kernel void conv_2d_3x3(OUT_OF_RANGE_PARAMS } } -#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) - out0 = do_activation(out0, relux_max_limit); - out1 = do_activation(out1, relux_max_limit); - out2 = do_activation(out2, relux_max_limit); - out3 = do_activation(out3, relux_max_limit); - out4 = do_activation(out4, relux_max_limit); +#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) + out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient); + out1 = do_activation(out1, relux_max_limit, leakyrelu_coefficient); + out2 = do_activation(out2, relux_max_limit, leakyrelu_coefficient); + out3 = do_activation(out3, relux_max_limit, leakyrelu_coefficient); + out4 = do_activation(out4, relux_max_limit, leakyrelu_coefficient); #endif const int out_x_base = mul24(out_ch_blk, out_width); diff --git a/mace/ops/opencl/cl/conv_2d_buffer.cl b/mace/ops/opencl/cl/conv_2d_buffer.cl index 225e3a3b..41efc13a 100644 --- a/mace/ops/opencl/cl/conv_2d_buffer.cl +++ b/mace/ops/opencl/cl/conv_2d_buffer.cl @@ -22,6 +22,7 @@ __kernel void conv2d(BUFFER_OUT_OF_RANGE_PARAMS __private const int dilated_h_offset, __private const int dilated_w_offset, __private const float relux_max_limit, + __private const float leakyrelu_coefficient, __global OUT_DATA_TYPE *output) { const int out_wc_blk_idx = get_global_id(0); const int out_hb_idx = get_global_id(1); @@ -107,11 +108,11 @@ __kernel void conv2d(BUFFER_OUT_OF_RANGE_PARAMS } } -#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) - out0 = do_activation(out0, relux_max_limit); - out1 = do_activation(out1, relux_max_limit); - out2 = do_activation(out2, relux_max_limit); - out3 = do_activation(out3, relux_max_limit); +#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) + out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient); + out1 = do_activation(out1, 
relux_max_limit, leakyrelu_coefficient); + out2 = do_activation(out2, relux_max_limit, leakyrelu_coefficient); + out3 = do_activation(out3, relux_max_limit, leakyrelu_coefficient); #endif int out_offset = mad24(mad24(mad24(batch_idx, out_height, out_height_idx), diff --git a/mace/ops/opencl/cl/deconv_2d.cl b/mace/ops/opencl/cl/deconv_2d.cl index d39d3fe3..2837e5c7 100644 --- a/mace/ops/opencl/cl/deconv_2d.cl +++ b/mace/ops/opencl/cl/deconv_2d.cl @@ -9,6 +9,7 @@ __kernel void deconv_2d(OUT_OF_RANGE_PARAMS #endif __write_only image2d_t output, __private const float relux_max_limit, + __private const float leakyrelu_coefficient, __private const int in_height, __private const int in_width, __private const int in_channels, @@ -127,12 +128,12 @@ __kernel void deconv_2d(OUT_OF_RANGE_PARAMS } } -#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) - out0 = do_activation(out0, relux_max_limit); - out1 = do_activation(out1, relux_max_limit); - out2 = do_activation(out2, relux_max_limit); - out3 = do_activation(out3, relux_max_limit); - out4 = do_activation(out4, relux_max_limit); +#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) + out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient); + out1 = do_activation(out1, relux_max_limit, leakyrelu_coefficient); + out2 = do_activation(out2, relux_max_limit, leakyrelu_coefficient); + out3 = do_activation(out3, relux_max_limit, leakyrelu_coefficient); + out4 = do_activation(out4, relux_max_limit, leakyrelu_coefficient); #endif int2 out_pos; diff --git a/mace/ops/opencl/cl/depthwise_conv2d.cl b/mace/ops/opencl/cl/depthwise_conv2d.cl index 59761ee5..5a611968 100644 --- a/mace/ops/opencl/cl/depthwise_conv2d.cl +++ b/mace/ops/opencl/cl/depthwise_conv2d.cl @@ -10,6 +10,7 @@ __kernel void depthwise_conv2d(OUT_OF_RANGE_PARAMS #endif __write_only image2d_t output, __private const float relux_max_limit, + __private const float leakyrelu_coefficient, __private const short in_height, __private const short in_width, __private const short in_ch_blks, @@ -112,11 +113,11 @@ __kernel void depthwise_conv2d(OUT_OF_RANGE_PARAMS in_hb_idx += dilation_h; } -#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) - out0 = do_activation(out0, relux_max_limit); - out1 = do_activation(out1, relux_max_limit); - out2 = do_activation(out2, relux_max_limit); - out3 = do_activation(out3, relux_max_limit); +#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) + out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient); + out1 = do_activation(out1, relux_max_limit, leakyrelu_coefficient); + out2 = do_activation(out2, relux_max_limit, leakyrelu_coefficient); + out3 = do_activation(out3, relux_max_limit, leakyrelu_coefficient); #endif const short out_x_base = mul24(out_ch_blk, out_width); @@ -145,6 +146,7 @@ __kernel void depthwise_conv2d_s1(OUT_OF_RANGE_PARAMS #endif __write_only image2d_t output, __private const DATA_TYPE relux_max_limit, + __private const DATA_TYPE leakyrelu_coefficient, __private const short in_height, __private const short in_width, __private const short in_ch_blks, @@ -238,11 +240,11 @@ __kernel void depthwise_conv2d_s1(OUT_OF_RANGE_PARAMS in_hb_idx += 1; } -#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) - out0 = do_activation(out0, relux_max_limit); - out1 = do_activation(out1, relux_max_limit); - out2 = 
do_activation(out2, relux_max_limit); - out3 = do_activation(out3, relux_max_limit); +#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) + out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient); + out1 = do_activation(out1, relux_max_limit, leakyrelu_coefficient); + out2 = do_activation(out2, relux_max_limit, leakyrelu_coefficient); + out3 = do_activation(out3, relux_max_limit, leakyrelu_coefficient); #endif const short out_x_base = mul24(out_ch_blk, out_width); diff --git a/mace/ops/opencl/cl/depthwise_conv2d_buffer.cl b/mace/ops/opencl/cl/depthwise_conv2d_buffer.cl index efbd75c7..6c42c11f 100644 --- a/mace/ops/opencl/cl/depthwise_conv2d_buffer.cl +++ b/mace/ops/opencl/cl/depthwise_conv2d_buffer.cl @@ -22,6 +22,7 @@ __kernel void depthwise_conv2d(BUFFER_OUT_OF_RANGE_PARAMS __private const int dilated_h_offset, __private const int dilated_w_offset, __private const float relux_max_limit, + __private const float leakyrelu_coefficient, __global OUT_DATA_TYPE *output) { const int out_wc_blk_idx = get_global_id(0); const int out_hb_idx = get_global_id(1); @@ -85,11 +86,11 @@ __kernel void depthwise_conv2d(BUFFER_OUT_OF_RANGE_PARAMS } } -#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) - out0 = do_activation(out0, relux_max_limit); - out1 = do_activation(out1, relux_max_limit); - out2 = do_activation(out2, relux_max_limit); - out3 = do_activation(out3, relux_max_limit); +#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) + out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient); + out1 = do_activation(out1, relux_max_limit, leakyrelu_coefficient); + out2 = do_activation(out2, relux_max_limit, leakyrelu_coefficient); + out3 = do_activation(out3, relux_max_limit, leakyrelu_coefficient); #endif int out_offset = mad24(mad24(mad24(batch_idx, out_height, out_height_idx), diff --git a/mace/ops/opencl/cl/depthwise_deconv2d.cl b/mace/ops/opencl/cl/depthwise_deconv2d.cl index 9d648b65..b86bc96d 100644 --- a/mace/ops/opencl/cl/depthwise_deconv2d.cl +++ b/mace/ops/opencl/cl/depthwise_deconv2d.cl @@ -9,6 +9,7 @@ __kernel void depthwise_deconv2d(OUT_OF_RANGE_PARAMS #endif __write_only image2d_t output, __private const float relux_max_limit, + __private const float leakyrelu_coefficient, __private const int in_height, __private const int in_width, __private const int out_height, @@ -108,12 +109,12 @@ __kernel void depthwise_deconv2d(OUT_OF_RANGE_PARAMS } } -#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) - out0 = do_activation(out0, relux_max_limit); - out1 = do_activation(out1, relux_max_limit); - out2 = do_activation(out2, relux_max_limit); - out3 = do_activation(out3, relux_max_limit); - out4 = do_activation(out4, relux_max_limit); +#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) + out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient); + out1 = do_activation(out1, relux_max_limit, leakyrelu_coefficient); + out2 = do_activation(out2, relux_max_limit, leakyrelu_coefficient); + out3 = do_activation(out3, relux_max_limit, leakyrelu_coefficient); + out4 = do_activation(out4, relux_max_limit, leakyrelu_coefficient); #endif @@ -146,4 +147,4 @@ __kernel void depthwise_deconv2d(OUT_OF_RANGE_PARAMS out_pos.x += stride_w; WRITE_IMAGET(output, out_pos, out4); } -} \ No newline at end of file +} diff --git 
a/mace/ops/opencl/cl/fully_connected.cl b/mace/ops/opencl/cl/fully_connected.cl index 14e3ee64..f7f4bc48 100644 --- a/mace/ops/opencl/cl/fully_connected.cl +++ b/mace/ops/opencl/cl/fully_connected.cl @@ -12,7 +12,8 @@ __kernel void fully_connected(OUT_OF_RANGE_PARAMS __private const int input_height, __private const int input_width, __private const int input_channel, - __private const float relux_max_limit) { + __private const float relux_max_limit, + __private const float leakyrelu_coefficient) { const int batch_idx = get_global_id(0); const int out_blk_idx = get_global_id(1); const int input_chan_blk = (input_channel + 3) >> 2; @@ -56,8 +57,8 @@ __kernel void fully_connected(OUT_OF_RANGE_PARAMS input_coord.y++; } -#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) - result = do_activation(result, relux_max_limit); +#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) + result = do_activation(result, relux_max_limit, leakyrelu_coefficient); #endif WRITE_IMAGET(output, (int2)(out_blk_idx, batch_idx), result); @@ -77,7 +78,8 @@ __kernel void fully_connected_width(OUT_OF_RANGE_PARAMS __private const int input_width, __private const int in_chan_blks, __private const int out_blks, - __private const float relux_max_limit) { + __private const float relux_max_limit, + __private const float leakyrelu_coefficient) { const int inter_out_idx = get_global_id(0); const int width_blk_idx = get_global_id(1); const int width_blk_count = global_size_dim1; @@ -147,8 +149,8 @@ __kernel void fully_connected_width(OUT_OF_RANGE_PARAMS inter_idx += 4; } -#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) - result = do_activation(result, relux_max_limit); +#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) + result = do_activation(result, relux_max_limit, leakyrelu_coefficient); #endif WRITE_IMAGET(output, (int2)(out_blk_idx, batch_idx), result); diff --git a/mace/ops/opencl/cl/winograd_transform.cl b/mace/ops/opencl/cl/winograd_transform.cl index c6f9b149..d30427b5 100644 --- a/mace/ops/opencl/cl/winograd_transform.cl +++ b/mace/ops/opencl/cl/winograd_transform.cl @@ -127,7 +127,8 @@ __kernel void winograd_inverse_transform_2x2(OUT_OF_RANGE_PARAMS __private const int out_width, __private const int round_hw, __private const int round_w, - __private const float relux_max_limit) { + __private const float relux_max_limit, + __private const float leakyrelu_coefficient) { const int width_idx = get_global_id(0); const int height_idx = get_global_id(1); @@ -203,11 +204,11 @@ __kernel void winograd_inverse_transform_2x2(OUT_OF_RANGE_PARAMS #endif -#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) - in0[0] = do_activation(in0[0], relux_max_limit); - in0[1] = do_activation(in0[1], relux_max_limit); - in1[0] = do_activation(in1[0], relux_max_limit); - in1[1] = do_activation(in1[1], relux_max_limit); +#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) + in0[0] = do_activation(in0[0], relux_max_limit, leakyrelu_coefficient); + in0[1] = do_activation(in0[1], relux_max_limit, leakyrelu_coefficient); + in1[0] = do_activation(in1[0], relux_max_limit, leakyrelu_coefficient); + in1[1] = do_activation(in1[1], relux_max_limit, leakyrelu_coefficient); #endif WRITE_IMAGET(output, (int2)(coord_x, coord_y), in0[0]); @@ -395,7 +396,8 @@ 
__kernel void winograd_inverse_transform_4x4(OUT_OF_RANGE_PARAMS __private const int out_width, __private const int round_hw, __private const int round_w, - __private const float relux_max_limit) { + __private const float relux_max_limit, + __private const float leakyrelu_coefficient) { const int width_idx = get_global_id(0); const int height_idx = get_global_id(1); @@ -515,23 +517,23 @@ __kernel void winograd_inverse_transform_4x4(OUT_OF_RANGE_PARAMS out3[3] += bias_value; #endif -#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) - out0[0] = do_activation(out0[0], relux_max_limit); - out0[1] = do_activation(out0[1], relux_max_limit); - out0[2] = do_activation(out0[2], relux_max_limit); - out0[3] = do_activation(out0[3], relux_max_limit); - out1[0] = do_activation(out1[0], relux_max_limit); - out1[1] = do_activation(out1[1], relux_max_limit); - out1[2] = do_activation(out1[2], relux_max_limit); - out1[3] = do_activation(out1[3], relux_max_limit); - out2[0] = do_activation(out2[0], relux_max_limit); - out2[1] = do_activation(out2[1], relux_max_limit); - out2[2] = do_activation(out2[2], relux_max_limit); - out2[3] = do_activation(out2[3], relux_max_limit); - out3[0] = do_activation(out3[0], relux_max_limit); - out3[1] = do_activation(out3[1], relux_max_limit); - out3[2] = do_activation(out3[2], relux_max_limit); - out3[3] = do_activation(out3[3], relux_max_limit); +#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) + out0[0] = do_activation(out0[0], relux_max_limit, leakyrelu_coefficient); + out0[1] = do_activation(out0[1], relux_max_limit, leakyrelu_coefficient); + out0[2] = do_activation(out0[2], relux_max_limit, leakyrelu_coefficient); + out0[3] = do_activation(out0[3], relux_max_limit, leakyrelu_coefficient); + out1[0] = do_activation(out1[0], relux_max_limit, leakyrelu_coefficient); + out1[1] = do_activation(out1[1], relux_max_limit, leakyrelu_coefficient); + out1[2] = do_activation(out1[2], relux_max_limit, leakyrelu_coefficient); + out1[3] = do_activation(out1[3], relux_max_limit, leakyrelu_coefficient); + out2[0] = do_activation(out2[0], relux_max_limit, leakyrelu_coefficient); + out2[1] = do_activation(out2[1], relux_max_limit, leakyrelu_coefficient); + out2[2] = do_activation(out2[2], relux_max_limit, leakyrelu_coefficient); + out2[3] = do_activation(out2[3], relux_max_limit, leakyrelu_coefficient); + out3[0] = do_activation(out3[0], relux_max_limit, leakyrelu_coefficient); + out3[1] = do_activation(out3[1], relux_max_limit, leakyrelu_coefficient); + out3[2] = do_activation(out3[2], relux_max_limit, leakyrelu_coefficient); + out3[3] = do_activation(out3[3], relux_max_limit, leakyrelu_coefficient); #endif const int num = min(4, out_width - out_width_idx); @@ -556,4 +558,4 @@ __kernel void winograd_inverse_transform_4x4(OUT_OF_RANGE_PARAMS for (int i = 0; i < num; ++i) { WRITE_IMAGET(output, (int2)(coord_x + i, coord_y + 3), out3[i]); } -} \ No newline at end of file +} diff --git a/mace/ops/opencl/conv_2d.h b/mace/ops/opencl/conv_2d.h index 03f2cd49..8d24e103 100644 --- a/mace/ops/opencl/conv_2d.h +++ b/mace/ops/opencl/conv_2d.h @@ -45,6 +45,7 @@ class OpenCLConv2dKernel { const int *dilations, const ActivationType activation, const float relux_max_limit, + const float leakyrelu_coefficient, const int winograd_blk_size, Tensor *output) = 0; MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLConv2dKernel); diff --git a/mace/ops/opencl/deconv_2d.h b/mace/ops/opencl/deconv_2d.h index 
69bc6f97..1240f5d2 100644 --- a/mace/ops/opencl/deconv_2d.h +++ b/mace/ops/opencl/deconv_2d.h @@ -36,6 +36,7 @@ class OpenCLDeconv2dKernel { const int *padding_data, const ActivationType activation, const float relux_max_limit, + const float leakyrelu_coefficient, const std::vector &output_shape, Tensor *output) = 0; MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLDeconv2dKernel); diff --git a/mace/ops/opencl/depthwise_conv2d.h b/mace/ops/opencl/depthwise_conv2d.h index b993e120..cc6246b7 100644 --- a/mace/ops/opencl/depthwise_conv2d.h +++ b/mace/ops/opencl/depthwise_conv2d.h @@ -38,6 +38,7 @@ class OpenCLDepthwiseConv2dKernel { const int *dilations, const ActivationType activation, const float relux_max_limit, + const float leakyrelu_coefficient, Tensor *output) = 0; MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLDepthwiseConv2dKernel); }; diff --git a/mace/ops/opencl/depthwise_deconv2d.h b/mace/ops/opencl/depthwise_deconv2d.h index 4238f0d2..1abde096 100644 --- a/mace/ops/opencl/depthwise_deconv2d.h +++ b/mace/ops/opencl/depthwise_deconv2d.h @@ -39,6 +39,7 @@ class OpenCLDepthwiseDeconv2dKernel { const int group, const ActivationType activation, const float relux_max_limit, + const float leakyrelu_coefficient, const std::vector &output_shape, Tensor *output) = 0; MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLDepthwiseDeconv2dKernel); diff --git a/mace/ops/opencl/fully_connected.h b/mace/ops/opencl/fully_connected.h index 952c5b9c..5f5c14cb 100644 --- a/mace/ops/opencl/fully_connected.h +++ b/mace/ops/opencl/fully_connected.h @@ -35,6 +35,7 @@ class OpenCLFullyConnectedKernel { const Tensor *bias, const ActivationType activation, const float relux_max_limit, + const float leakyrelu_coefficient, Tensor *output) = 0; MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLFullyConnectedKernel); }; diff --git a/mace/ops/opencl/image/activation.h b/mace/ops/opencl/image/activation.h index e8448fe0..69e3bed8 100644 --- a/mace/ops/opencl/image/activation.h +++ b/mace/ops/opencl/image/activation.h @@ -35,8 +35,10 @@ template class ActivationKernel : public OpenCLActivationKernel { public: ActivationKernel(ActivationType type, - T relux_max_limit) - : activation_(type), relux_max_limit_(relux_max_limit) {} + T relux_max_limit, + T leakyrelu_coefficient) + : activation_(type), relux_max_limit_(relux_max_limit), + leakyrelu_coefficient_(leakyrelu_coefficient) {} MaceStatus Compute( OpContext *context, @@ -47,6 +49,7 @@ class ActivationKernel : public OpenCLActivationKernel { private: ActivationType activation_; T relux_max_limit_; + T leakyrelu_coefficient_; cl::Kernel kernel_; uint32_t kwg_size_; std::vector input_shape_; @@ -128,6 +131,7 @@ MaceStatus ActivationKernel::Compute( kernel_.setArg(idx++, *(alpha->opencl_image())); } kernel_.setArg(idx++, static_cast(relux_max_limit_)); + kernel_.setArg(idx++, static_cast(leakyrelu_coefficient_)); kernel_.setArg(idx++, *(output->opencl_image())); input_shape_ = input->shape(); diff --git a/mace/ops/opencl/image/batch_norm.h b/mace/ops/opencl/image/batch_norm.h index 68908830..7589ce23 100644 --- a/mace/ops/opencl/image/batch_norm.h +++ b/mace/ops/opencl/image/batch_norm.h @@ -37,7 +37,8 @@ class BatchNormKernel : public OpenCLBatchNormKernel { BatchNormKernel( const float epsilon, const ActivationType activation, - const float relux_max_limit); + const float relux_max_limit, + const float leakyrelu_coefficient); MaceStatus Compute(OpContext *context, const Tensor *input, const Tensor *scale, @@ -50,6 +51,7 @@ class BatchNormKernel : public OpenCLBatchNormKernel { const float epsilon_; const 
ActivationType activation_; const float relux_max_limit_; + const float leakyrelu_coefficient_; cl::Kernel kernel_; uint32_t kwg_size_; std::vector input_shape_; @@ -58,10 +60,12 @@ class BatchNormKernel : public OpenCLBatchNormKernel { template BatchNormKernel::BatchNormKernel(const float epsilon, const ActivationType activation, - const float relux_max_limit) + const float relux_max_limit, + const float leakyrelu_coefficient) : epsilon_(epsilon), activation_(activation), - relux_max_limit_(relux_max_limit) {} + relux_max_limit_(relux_max_limit), + leakyrelu_coefficient_(leakyrelu_coefficient) {} template MaceStatus BatchNormKernel::Compute( @@ -115,6 +119,9 @@ MaceStatus BatchNormKernel::Compute( case SIGMOID: built_options.emplace("-DUSE_SIGMOID"); break; + case LEAKYRELU: + built_options.emplace("-DUSE_LEAKYRELU"); + break; default: LOG(FATAL) << "Unknown activation type: " << activation_; } @@ -140,6 +147,7 @@ MaceStatus BatchNormKernel::Compute( } kernel_.setArg(idx++, *(output->opencl_image())); kernel_.setArg(idx++, relux_max_limit_); + kernel_.setArg(idx++, leakyrelu_coefficient_); input_shape_ = input->shape(); } diff --git a/mace/ops/opencl/image/conv_2d.h b/mace/ops/opencl/image/conv_2d.h index 51c9d1df..2175bd14 100644 --- a/mace/ops/opencl/image/conv_2d.h +++ b/mace/ops/opencl/image/conv_2d.h @@ -38,6 +38,7 @@ extern MaceStatus Conv2dK1x1(OpContext *context, const int *dilations, const ActivationType activation, const float relux_max_limit, + const float leakyrelu_coefficient, const DataType dt, std::vector *prev_input_shape, Tensor *output, @@ -53,6 +54,7 @@ extern MaceStatus Conv2dK3x3(OpContext *context, const int *dilations, const ActivationType activation, const float relux_max_limit, + const float leakyrelu_coefficient, const DataType dt, std::vector *prev_input_shape, Tensor *output, @@ -68,6 +70,7 @@ extern MaceStatus Conv2d(OpContext *context, const int *dilations, const ActivationType activation, const float relux_max_limit, + const float leakyrelu_coefficient, const DataType dt, std::vector *prev_input_shape, Tensor *output, @@ -81,6 +84,7 @@ extern MaceStatus WinogradConv2dK3x3S1(OpContext *context, const int *padding, const ActivationType activation, const float relux_max_limit, + const float leakyrelu_coefficient, const DataType dt, const int wino_blk_size, std::vector *prev_input_shape, @@ -109,6 +113,7 @@ class Conv2dKernel : public OpenCLConv2dKernel { const int *dilations, const ActivationType activation, const float relux_max_limit, + const float leakyrelu_coefficient, const int wino_blk_size, Tensor *output) override; @@ -169,6 +174,7 @@ MaceStatus Conv2dKernel::Compute( const int *dilations, const ActivationType activation, const float relux_max_limit, + const float leakyrelu_coefficient, const int wino_blk_size, Tensor *output) { index_t kernel_h = filter->dim(2); @@ -217,6 +223,7 @@ MaceStatus Conv2dKernel::Compute( paddings.data(), activation, relux_max_limit, + leakyrelu_coefficient, DataTypeToEnum::value, wino_blk_size, &input_shape_, @@ -235,6 +242,7 @@ MaceStatus Conv2dKernel::Compute( dilations, activation, relux_max_limit, + leakyrelu_coefficient, DataTypeToEnum::value, &input_shape_, output, @@ -252,6 +260,7 @@ MaceStatus Conv2dKernel::Compute( dilations, activation, relux_max_limit, + leakyrelu_coefficient, DataTypeToEnum::value, &input_shape_, output, @@ -269,6 +278,7 @@ MaceStatus Conv2dKernel::Compute( dilations, activation, relux_max_limit, + leakyrelu_coefficient, DataTypeToEnum::value, &input_shape_, output, diff --git 
diff --git a/mace/ops/opencl/image/conv_2d_1x1.cc b/mace/ops/opencl/image/conv_2d_1x1.cc
index 57be0750..fe154461 100644
--- a/mace/ops/opencl/image/conv_2d_1x1.cc
+++ b/mace/ops/opencl/image/conv_2d_1x1.cc
@@ -76,6 +76,7 @@ extern MaceStatus Conv2dK1x1(OpContext *context,
                              const int *dilations,
                              const ActivationType activation,
                              const float relux_max_limit,
+                             const float leakyrelu_coefficient,
                              const DataType dt,
                              std::vector<index_t> *prev_input_shape,
                              Tensor *output,
@@ -125,6 +126,9 @@ extern MaceStatus Conv2dK1x1(OpContext *context,
       case SIGMOID:
         built_options.emplace("-DUSE_SIGMOID");
         break;
+      case LEAKYRELU:
+        built_options.emplace("-DUSE_LEAKYRELU");
+        break;
       default:
         LOG(FATAL) << "Unknown activation type: " << activation;
     }
@@ -154,6 +158,7 @@ extern MaceStatus Conv2dK1x1(OpContext *context,
     kernel->setArg(idx++, *(output->opencl_image()));
     // FIXME handle flexable data type: half not supported
     kernel->setArg(idx++, relux_max_limit);
+    kernel->setArg(idx++, leakyrelu_coefficient);
     kernel->setArg(idx++, static_cast<int>(input_height));
     kernel->setArg(idx++, static_cast<int>(input_width));
     kernel->setArg(idx++, static_cast<int>(input_channel_blocks));
diff --git a/mace/ops/opencl/image/conv_2d_3x3.cc b/mace/ops/opencl/image/conv_2d_3x3.cc
index f7905a0c..8c1abf3c 100644
--- a/mace/ops/opencl/image/conv_2d_3x3.cc
+++ b/mace/ops/opencl/image/conv_2d_3x3.cc
@@ -69,6 +69,7 @@ extern MaceStatus Conv2dK3x3(OpContext *context,
                              const int *dilations,
                              const ActivationType activation,
                              const float relux_max_limit,
+                             const float leakyrelu_coefficient,
                              const DataType dt,
                              std::vector<index_t> *prev_input_shape,
                              Tensor *output,
@@ -110,6 +111,9 @@ extern MaceStatus Conv2dK3x3(OpContext *context,
       case SIGMOID:
         built_options.emplace("-DUSE_SIGMOID");
         break;
+      case LEAKYRELU:
+        built_options.emplace("-DUSE_LEAKYRELU");
+        break;
       default:
         LOG(FATAL) << "Unknown activation type: " << activation;
     }
@@ -138,6 +142,7 @@ extern MaceStatus Conv2dK3x3(OpContext *context,
     }
     kernel->setArg(idx++, *(output->opencl_image()));
     kernel->setArg(idx++, relux_max_limit);
+    kernel->setArg(idx++, leakyrelu_coefficient);
     kernel->setArg(idx++, static_cast<int>(input->dim(1)));
     kernel->setArg(idx++, static_cast<int>(input->dim(2)));
     kernel->setArg(idx++, static_cast<int>(input_channel_blocks));
diff --git a/mace/ops/opencl/image/conv_2d_general.cc b/mace/ops/opencl/image/conv_2d_general.cc
index 28bdea6c..0d514132 100644
--- a/mace/ops/opencl/image/conv_2d_general.cc
+++ b/mace/ops/opencl/image/conv_2d_general.cc
@@ -77,6 +77,7 @@ extern MaceStatus Conv2d(OpContext *context,
                          const int *dilations,
                          const ActivationType activation,
                          const float relux_max_limit,
+                         const float leakyrelu_coefficient,
                          const DataType dt,
                          std::vector<index_t> *prev_input_shape,
                          Tensor *output,
@@ -118,6 +119,9 @@ extern MaceStatus Conv2d(OpContext *context,
       case SIGMOID:
         built_options.emplace("-DUSE_SIGMOID");
         break;
+      case LEAKYRELU:
+        built_options.emplace("-DUSE_LEAKYRELU");
+        break;
       default:
         LOG(FATAL) << "Unknown activation type: " << activation;
     }
@@ -146,6 +150,7 @@ extern MaceStatus Conv2d(OpContext *context,
     }
     kernel->setArg(idx++, *(output->opencl_image()));
     kernel->setArg(idx++, relux_max_limit);
+    kernel->setArg(idx++, leakyrelu_coefficient);
     kernel->setArg(idx++, static_cast<int>(input->dim(1)));
     kernel->setArg(idx++, static_cast<int>(input->dim(2)));
     kernel->setArg(idx++, static_cast<int>(input_channel_blocks));
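All three conv variants register -DUSE_LEAKYRELU as an OpenCL build option, so the activation is selected when the kernel source is compiled, not per element at run time; each compiled variant reads at most one of the two float arguments, which is why relux_max_limit and leakyrelu_coefficient can always both be passed. A C++ analogue of that mechanism (a hypothetical sketch; the real selection happens in the .cl sources, which this excerpt does not show):

    // epilogue.cc: build one variant per activation, e.g.
    //   g++ -DUSE_LEAKYRELU -c epilogue.cc
    #include <algorithm>

    float do_activation(float in, float relux_max_limit,
                        float leakyrelu_coefficient) {
    #if defined(USE_RELU)
      (void)relux_max_limit; (void)leakyrelu_coefficient;
      return std::max(in, 0.0f);
    #elif defined(USE_RELUX)
      (void)leakyrelu_coefficient;
      return std::min(std::max(in, 0.0f), relux_max_limit);
    #elif defined(USE_LEAKYRELU)
      (void)relux_max_limit;
      return in > 0.0f ? in : leakyrelu_coefficient * in;
    #else
      (void)relux_max_limit; (void)leakyrelu_coefficient;
      return in;  // NOOP build: pass through
    #endif
    }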
diff --git a/mace/ops/opencl/image/deconv_2d.h b/mace/ops/opencl/image/deconv_2d.h
index a8dd9c26..ad5d198b 100644
--- a/mace/ops/opencl/image/deconv_2d.h
+++ b/mace/ops/opencl/image/deconv_2d.h
@@ -42,6 +42,7 @@ class Deconv2dKernel : public OpenCLDeconv2dKernel {
       const int *padding_data,
       const ActivationType activation,
       const float relux_max_limit,
+      const float leakyrelu_coefficient,
       const std::vector<index_t> &output_shape,
       Tensor *output) override;

@@ -61,6 +62,7 @@ MaceStatus Deconv2dKernel<T>::Compute(
     const int *padding_data,
     const ActivationType activation,
     const float relux_max_limit,
+    const float leakyrelu_coefficient,
     const std::vector<index_t> &output_shape,
     Tensor *output) {
   std::vector<size_t> output_image_shape;
@@ -119,6 +121,9 @@ MaceStatus Deconv2dKernel<T>::Compute(
       case SIGMOID:
         built_options.emplace("-DUSE_SIGMOID");
         break;
+      case LEAKYRELU:
+        built_options.emplace("-DUSE_LEAKYRELU");
+        break;
       default:
         LOG(FATAL) << "Unknown activation type: " << activation;
     }
@@ -146,6 +151,7 @@ MaceStatus Deconv2dKernel<T>::Compute(
     }
     kernel_.setArg(idx++, *(output->opencl_image()));
     kernel_.setArg(idx++, relux_max_limit);
+    kernel_.setArg(idx++, leakyrelu_coefficient);
     kernel_.setArg(idx++, static_cast<int>(input->dim(1)));
     kernel_.setArg(idx++, static_cast<int>(input->dim(2)));
     kernel_.setArg(idx++, static_cast<int>(input->dim(3)));
diff --git a/mace/ops/opencl/image/depthwise_conv2d.cc b/mace/ops/opencl/image/depthwise_conv2d.cc
index 57a4415e..428b773b 100644
--- a/mace/ops/opencl/image/depthwise_conv2d.cc
+++ b/mace/ops/opencl/image/depthwise_conv2d.cc
@@ -73,6 +73,7 @@ MaceStatus DepthwiseConv2d(OpContext *context,
                            const int *dilations,
                            const ActivationType activation,
                            const float relux_max_limit,
+                           const float leakyrelu_coefficient,
                            const DataType dt,
                            std::vector<index_t> *prev_input_shape,
                            Tensor *output,
@@ -126,6 +127,9 @@ MaceStatus DepthwiseConv2d(OpContext *context,
       case SIGMOID:
         built_options.emplace("-DUSE_SIGMOID");
         break;
+      case LEAKYRELU:
+        built_options.emplace("-DUSE_LEAKYRELU");
+        break;
       default:
         LOG(FATAL) << "Unknown activation type: " << activation;
     }
@@ -159,6 +163,7 @@ MaceStatus DepthwiseConv2d(OpContext *context,
     }
     kernel->setArg(idx++, *(output->opencl_image()));
     kernel->setArg(idx++, relux_max_limit);
+    kernel->setArg(idx++, leakyrelu_coefficient);
     kernel->setArg(idx++, static_cast<int>(input_height));
     kernel->setArg(idx++, static_cast<int>(input_width));
     kernel->setArg(idx++, static_cast<int>(input_channel_blocks));
diff --git a/mace/ops/opencl/image/depthwise_conv2d.h b/mace/ops/opencl/image/depthwise_conv2d.h
index c4ee3cb7..af99c483 100644
--- a/mace/ops/opencl/image/depthwise_conv2d.h
+++ b/mace/ops/opencl/image/depthwise_conv2d.h
@@ -39,6 +39,7 @@ MaceStatus DepthwiseConv2d(OpContext *context,
                            const int *dilations,
                            const ActivationType activation,
                            const float relux_max_limit,
+                           const float leakyrelu_coefficient,
                            const DataType dt,
                            std::vector<index_t> *prev_input_shape,
                            Tensor *output,
@@ -60,6 +61,7 @@ class DepthwiseConv2dKernel : public OpenCLDepthwiseConv2dKernel {
                      const int *dilations,
                      const ActivationType activation,
                      const float relux_max_limit,
+                     const float leakyrelu_coefficient,
                      Tensor *output) override;

 private:
@@ -80,6 +82,7 @@ MaceStatus DepthwiseConv2dKernel<T>::Compute(
     const int *dilations,
     const ActivationType activation,
     const float relux_max_limit,
+    const float leakyrelu_coefficient,
     Tensor *output) {
   index_t kernel_h = filter->dim(2);
   index_t kernel_w = filter->dim(3);
@@ -118,8 +121,8 @@ MaceStatus DepthwiseConv2dKernel<T>::Compute(

   return depthwise::DepthwiseConv2d(
       context, &kernel_, input, filter, bias, strides[0], paddings.data(),
-      dilations, activation, relux_max_limit, DataTypeToEnum<T>::value,
-      &input_shape_, output, &kwg_size_);
+      dilations, activation, relux_max_limit, leakyrelu_coefficient,
+      DataTypeToEnum<T>::value, &input_shape_, output, &kwg_size_);
 }

 }  // namespace image
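The rewrapped call above is the one place where the new argument changes existing lines rather than adding one: leakyrelu_coefficient must sit between relux_max_limit and DataTypeToEnum<T>::value to match the new parameter order of depthwise::DepthwiseConv2d. The same positional contract applies to every kernel_.setArg(idx++, ...) sequence in this patch, since OpenCL kernel arguments bind by index alone. A toy illustration of why the insertion point matters (MockKernel is a hypothetical stand-in for cl::Kernel):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Arguments bind by position; a misplaced setArg silently hands the
    // kernel the wrong value, so host order must mirror the .cl signature.
    struct MockKernel {
      std::vector<float> args;
      void setArg(uint32_t idx, float v) {
        if (args.size() <= idx) args.resize(idx + 1);
        args[idx] = v;
      }
    };

    int main() {
      MockKernel kernel;
      uint32_t idx = 0;
      kernel.setArg(idx++, 6.0f);  // relux_max_limit
      kernel.setArg(idx++, 0.1f);  // leakyrelu_coefficient (new argument)
      assert(kernel.args[1] == 0.1f);
      return 0;
    }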
diff --git a/mace/ops/opencl/image/depthwise_deconv2d.h b/mace/ops/opencl/image/depthwise_deconv2d.h
index d07a1649..65fe4129 100644
--- a/mace/ops/opencl/image/depthwise_deconv2d.h
+++ b/mace/ops/opencl/image/depthwise_deconv2d.h
@@ -43,6 +43,7 @@ class DepthwiseDeconv2dKernel : public OpenCLDepthwiseDeconv2dKernel {
       const int group,
       const ActivationType activation,
       const float relux_max_limit,
+      const float leakyrelu_coefficient,
       const std::vector<index_t> &output_shape,
       Tensor *output) override;

@@ -63,6 +64,7 @@ MaceStatus DepthwiseDeconv2dKernel<T>::Compute(
     const int group,
     const ActivationType activation,
     const float relux_max_limit,
+    const float leakyrelu_coefficient,
     const std::vector<index_t> &output_shape,
     Tensor *output) {
   const index_t batch = output_shape[0];
@@ -125,6 +127,9 @@ MaceStatus DepthwiseDeconv2dKernel<T>::Compute(
       case SIGMOID:
         built_options.emplace("-DUSE_SIGMOID");
         break;
+      case LEAKYRELU:
+        built_options.emplace("-DUSE_LEAKYRELU");
+        break;
       default:
         LOG(FATAL) << "Unknown activation type: " << activation;
     }
@@ -152,6 +157,7 @@ MaceStatus DepthwiseDeconv2dKernel<T>::Compute(
     }
     kernel_.setArg(idx++, *(output->opencl_image()));
     kernel_.setArg(idx++, relux_max_limit);
+    kernel_.setArg(idx++, leakyrelu_coefficient);
     kernel_.setArg(idx++, static_cast<int>(input->dim(1)));
     kernel_.setArg(idx++, static_cast<int>(input->dim(2)));
     kernel_.setArg(idx++, static_cast<int>(height));
diff --git a/mace/ops/opencl/image/fully_connected.h b/mace/ops/opencl/image/fully_connected.h
index d52e927f..b4a915ae 100644
--- a/mace/ops/opencl/image/fully_connected.h
+++ b/mace/ops/opencl/image/fully_connected.h
@@ -40,6 +40,7 @@ class FullyConnectedKernel : public OpenCLFullyConnectedKernel {
       const Tensor *bias,
       const ActivationType activation,
       const float relux_max_limit,
+      const float leakyrelu_coefficient,
       Tensor *output) override;

 private:
@@ -57,6 +58,7 @@ MaceStatus FullyConnectedKernel<T>::Compute(
     const Tensor *bias,
     const ActivationType activation,
     const float relux_max_limit,
+    const float leakyrelu_coefficient,
     Tensor *output) {
   std::vector<index_t> output_shape = {input->dim(0), 1, 1, weight->dim(0)};
   std::vector<size_t> output_image_shape;
@@ -98,6 +100,9 @@ MaceStatus FullyConnectedKernel<T>::Compute(
       case SIGMOID:
         built_options.emplace("-DUSE_SIGMOID");
         break;
+      case LEAKYRELU:
+        built_options.emplace("-DUSE_LEAKYRELU");
+        break;
       default:
         LOG(FATAL) << "Unknown activation type: " << activation;
     }
@@ -148,6 +153,7 @@ MaceStatus FullyConnectedKernel<T>::Compute(
     kernel_.setArg(idx++, static_cast<int>(RoundUpDiv4(input->dim(3))));
     kernel_.setArg(idx++, static_cast<int>(output_blocks));
     kernel_.setArg(idx++, relux_max_limit);
+    kernel_.setArg(idx++, leakyrelu_coefficient);

     input_shape_ = input->shape();
   }
diff --git a/mace/ops/opencl/image/winograd_conv2d.cc b/mace/ops/opencl/image/winograd_conv2d.cc
index 8d684e59..6d90b6d5 100644
--- a/mace/ops/opencl/image/winograd_conv2d.cc
+++ b/mace/ops/opencl/image/winograd_conv2d.cc
@@ -115,6 +115,7 @@ MaceStatus WinogradOutputTransform(OpContext *context,
                                    const int wino_blk_size,
                                    const ActivationType activation,
                                    const float relux_max_limit,
+                                   const float leakyrelu_coefficient,
                                    const bool input_changed,
                                    Tensor *output_tensor,
                                    uint32_t *kwg_size,
@@ -164,6 +165,9 @@ MaceStatus WinogradOutputTransform(OpContext *context,
       case SIGMOID:
         built_options.emplace("-DUSE_SIGMOID");
         break;
+      case LEAKYRELU:
+        built_options.emplace("-DUSE_LEAKYRELU");
+        break;
       default:
         LOG(FATAL) << "Unknown activation type: " << activation;
     }
@@ -199,6 +203,7 @@ MaceStatus WinogradOutputTransform(OpContext *context,
     kernel->setArg(idx++, static_cast<int>(round_h * round_w));
     kernel->setArg(idx++, static_cast<int>(round_w));
     kernel->setArg(idx++, relux_max_limit);
+    kernel->setArg(idx++, leakyrelu_coefficient);
   }
   const std::vector<uint32_t> lws = {*kwg_size / 8, 8, 0};
   std::string tuning_key =
@@ -222,6 +227,7 @@ extern MaceStatus WinogradConv2dK3x3S1(OpContext *context,
                                        const int *paddings,
                                        const ActivationType activation,
                                        const float relux_max_limit,
+                                       const float leakyrelu_coefficient,
                                        const DataType dt,
                                        const int wino_blk_size,
                                        std::vector<index_t> *prev_input_shape,
@@ -338,7 +344,8 @@ extern MaceStatus WinogradConv2dK3x3S1(OpContext *context,
   MACE_RETURN_IF_ERROR(WinogradOutputTransform(
       context, kernels[2], mm_output.get(), bias, dt, round_h, round_w,
       wino_blk_size, activation, relux_max_limit,
-      input_changed, output, kwg_size[2], &t_output_future))
+      leakyrelu_coefficient, input_changed, output, kwg_size[2],
+      &t_output_future))

   MergeMultipleFutureWaitFn({t_input_future, mm_future, t_output_future},
                            context->future());
diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py
index aeb626a6..968c76f7 100644
--- a/mace/python/tools/converter_tool/base_converter.py
+++ b/mace/python/tools/converter_tool/base_converter.py
@@ -168,6 +168,7 @@ class MaceKeyword(object):
     mace_element_type_str = 'type'
     mace_activation_type_str = 'activation'
     mace_activation_max_limit_str = 'max_limit'
+    mace_activation_leakyrelu_coefficient_str = 'leakyrelu_coefficient'
     mace_resize_size_str = 'size'
     mace_batch_to_space_crops_str = 'crops'
     mace_paddings_str = 'paddings'
diff --git a/mace/python/tools/converter_tool/caffe_converter.py b/mace/python/tools/converter_tool/caffe_converter.py
index c1fea314..a9b19bac 100644
--- a/mace/python/tools/converter_tool/caffe_converter.py
+++ b/mace/python/tools/converter_tool/caffe_converter.py
@@ -493,6 +493,14 @@ class CaffeConverter(base_converter.ConverterInterface):
                                              mace_pb2.DT_FLOAT, alpha_data)
             op.input.extend([alpha_tensor_name])

+        negative_slope = caffe_op.layer.relu_param.negative_slope
+        if caffe_op.type == 'ReLU' and negative_slope != 0:
+            param_arg = op.arg.add()
+            param_arg.name = MaceKeyword.mace_activation_leakyrelu_coefficient_str  # noqa
+            param_arg.f = caffe_op.layer.relu_param.negative_slope
+
+            type_arg.s = six.b(ActivationType.LEAKYRELU.name)
+
     def convert_folded_batchnorm(self, caffe_op):
         op = self.convert_general_op(caffe_op)
         op.type = MaceOp.BatchNorm.name
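The Caffe branch above only rewrites a plain ReLU whose negative_slope is nonzero, so ordinary ReLU layers are untouched; the slope value itself becomes the leakyrelu_coefficient arg. Caffe defines that layer as y = max(x, 0) + negative_slope * min(x, 0), which matches what the kernels compute as x > 0 ? x : coefficient * x. A quick self-contained check of the equivalence (example values only):

    #include <algorithm>
    #include <cassert>
    #include <cmath>

    int main() {
      const float coefficient = 0.1f;  // stands in for negative_slope
      for (float x : {-7.0f, -0.5f, 0.0f, 3.0f}) {
        float caffe_form = std::max(x, 0.0f) + coefficient * std::min(x, 0.0f);
        float kernel_form = x > 0.0f ? x : coefficient * x;
        assert(std::fabs(caffe_form - kernel_form) < 1e-6f);
      }
      return 0;
    }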
diff --git a/mace/python/tools/converter_tool/onnx_converter.py b/mace/python/tools/converter_tool/onnx_converter.py
index e2e2d5de..731d54e8 100644
--- a/mace/python/tools/converter_tool/onnx_converter.py
+++ b/mace/python/tools/converter_tool/onnx_converter.py
@@ -286,10 +286,10 @@ class OnnxConverter(base_converter.ConverterInterface):

     activation_type = {
         OnnxOpType.Relu.name: ActivationType.RELU,
+        OnnxOpType.LeakyRelu.name: ActivationType.LEAKYRELU,
         OnnxOpType.PRelu.name: ActivationType.PRELU,
         OnnxOpType.Tanh.name: ActivationType.TANH,
         OnnxOpType.Sigmoid.name: ActivationType.SIGMOID,
-        OnnxOpType.LeakyRelu.name: ActivationType.LEAKYRELU,
     }

     def __init__(self, option, src_model_file):
diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py
index 5e564fa4..2e9e35ab 100644
--- a/mace/python/tools/converter_tool/transformer.py
+++ b/mace/python/tools/converter_tool/transformer.py
@@ -893,7 +893,9 @@ class Transformer(base_converter.ConverterInterface):
                 op.output[0] = consumer_op.output[0]
                 for arg in consumer_op.arg:
                     if arg.name == MaceKeyword.mace_activation_type_str \
-                            or arg.name == MaceKeyword.mace_activation_max_limit_str:  # noqa
+                            or arg.name == \
+                            MaceKeyword.mace_activation_max_limit_str \
+                            or arg.name == MaceKeyword.mace_activation_leakyrelu_coefficient_str:  # noqa
                         op.arg.extend([arg])

                 self.replace_quantize_info(op, consumer_op)
--
GitLab