diff --git a/mace/kernels/activation.h b/mace/kernels/activation.h index 83acf4fb70311c710bcde1d7b08c1e6c4630879f..b768eb28c93a52502af8dd51f0e6aa13a9a145a8 100644 --- a/mace/kernels/activation.h +++ b/mace/kernels/activation.h @@ -46,8 +46,7 @@ void DoActivation(const T *input_ptr, T *output_ptr, const index_t size, const ActivationType type, - const float relux_max_limit, - const float prelu_alpha) { + const float relux_max_limit) { MACE_CHECK(DataTypeToEnum::value != DataType::DT_HALF); switch (type) { @@ -66,17 +65,6 @@ void DoActivation(const T *input_ptr, static_cast(relux_max_limit)); } break; - case PRELU: -#pragma omp parallel for - for (index_t i = 0; i < size; ++i) { - T in = input_ptr[i]; - if (in < 0) { - output_ptr[i] = in * prelu_alpha; - } else { - output_ptr[i] = in; - } - } - break; case TANH: #pragma omp parallel for for (index_t i = 0; i < size; ++i) { @@ -95,45 +83,70 @@ void DoActivation(const T *input_ptr, } } +template +void PReLUActivation(const T *input_ptr, + const index_t size, + const index_t input_chan, + const T *alpha_ptr, + T *output_ptr) { +#pragma omp parallel for + for (index_t i = 0; i < size; ++i) { + const index_t chan_idx = i % input_chan; + T in = input_ptr[i]; + if (in < 0) { + output_ptr[i] = in * alpha_ptr[chan_idx]; + } else { + output_ptr[i] = in; + } + } + +} + template class ActivationFunctor { public: - ActivationFunctor(ActivationType type, T relux_max_limit, T prelu_alpha) + ActivationFunctor(ActivationType type, T relux_max_limit) : activation_(type), - relux_max_limit_(relux_max_limit), - prelu_alpha_(prelu_alpha) {} + relux_max_limit_(relux_max_limit){} - void operator()(const Tensor *input, Tensor *output, StatsFuture *future) { + void operator()(const Tensor *input, + const Tensor *alpha, + Tensor *output, + StatsFuture *future) { const T *input_ptr = input->data(); T *output_ptr = output->mutable_data(); - DoActivation(input_ptr, output_ptr, output->size(), activation_, relux_max_limit_, - prelu_alpha_); + if (activation_ == PRELU) { + const T *alpha_ptr = alpha == nullptr ? 
nullptr : alpha->data(); + PReLUActivation(input_ptr, output->size(), input->dim(3), alpha_ptr, output_ptr); + } else { + DoActivation(input_ptr, output_ptr, output->size(), activation_, relux_max_limit_); + } } private: ActivationType activation_; T relux_max_limit_; - T prelu_alpha_; }; template <> void ActivationFunctor::operator()( - const Tensor *input, Tensor *output, StatsFuture *future); + const Tensor *input, const Tensor *alpha, Tensor *output, StatsFuture *future); template class ActivationFunctor { public: - ActivationFunctor(ActivationType type, T relux_max_limit, T prelu_alpha) + ActivationFunctor(ActivationType type, T relux_max_limit) : activation_(type), - relux_max_limit_(relux_max_limit), - prelu_alpha_(prelu_alpha) {} + relux_max_limit_(relux_max_limit){} - void operator()(const Tensor *input, Tensor *output, StatsFuture *future); + void operator()(const Tensor *input, + const Tensor *alpha, + Tensor *output, + StatsFuture *future); private: ActivationType activation_; T relux_max_limit_; - T prelu_alpha_; cl::Kernel kernel_; std::string tuning_key_prefix_; }; diff --git a/mace/kernels/batch_norm.h b/mace/kernels/batch_norm.h index bf5035ded2d67a0469347a0a001e6bc66f1d437f..0d489f40af6398240772cbabd8ff99f5b72be96c 100644 --- a/mace/kernels/batch_norm.h +++ b/mace/kernels/batch_norm.h @@ -21,27 +21,23 @@ namespace kernels { struct BatchNormFunctorBase { BatchNormFunctorBase(bool folded_constant, const ActivationType activation, - const float relux_max_limit, - const float prelu_alpha) + const float relux_max_limit) : folded_constant_(folded_constant), activation_(activation), - relux_max_limit_(relux_max_limit), - prelu_alpha_(prelu_alpha) {} + relux_max_limit_(relux_max_limit){} const bool folded_constant_; const ActivationType activation_; const float relux_max_limit_; - const float prelu_alpha_; }; template struct BatchNormFunctor : BatchNormFunctorBase { BatchNormFunctor(const bool folded_constant, const ActivationType activation, - const float relux_max_limit, - const float prelu_alpha) + const float relux_max_limit) : BatchNormFunctorBase( - folded_constant, activation, relux_max_limit, prelu_alpha) {} + folded_constant, activation, relux_max_limit) {} void operator()(const Tensor *input, const Tensor *scale, @@ -132,7 +128,7 @@ struct BatchNormFunctor : BatchNormFunctorBase { } } DoActivation(output_ptr, output_ptr, output->NumElements(), activation_, - relux_max_limit_, prelu_alpha_); + relux_max_limit_); } }; @@ -150,10 +146,9 @@ template struct BatchNormFunctor : BatchNormFunctorBase { BatchNormFunctor(const bool folded_constant, const ActivationType activation, - const float relux_max_limit, - const float prelu_alpha) + const float relux_max_limit) : BatchNormFunctorBase( - folded_constant, activation, relux_max_limit, prelu_alpha) {} + folded_constant, activation, relux_max_limit) {} void operator()(const Tensor *input, const Tensor *scale, const Tensor *offset, diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h index b9bf8a983bd4bbafca9e16a52e576bcbf378924c..e6c22cd9f927364c80733f7b98f46535582eaaf7 100644 --- a/mace/kernels/conv_2d.h +++ b/mace/kernels/conv_2d.h @@ -182,15 +182,13 @@ struct Conv2dFunctorBase { const std::vector &paddings, const int *dilations, const ActivationType activation, - const float relux_max_limit, - const float prelu_alpha) + const float relux_max_limit) : strides_(strides), padding_type_(padding_type), paddings_(paddings), dilations_(dilations), activation_(activation), - relux_max_limit_(relux_max_limit), - 
prelu_alpha_(prelu_alpha) {} + relux_max_limit_(relux_max_limit){} const int *strides_; // [stride_h, stride_w] const Padding padding_type_; @@ -198,7 +196,6 @@ struct Conv2dFunctorBase { const int *dilations_; // [dilation_h, dilation_w] const ActivationType activation_; const float relux_max_limit_; - const float prelu_alpha_; }; template @@ -208,15 +205,13 @@ struct Conv2dFunctor : Conv2dFunctorBase { const std::vector &paddings, const int *dilations, const ActivationType activation, - const float relux_max_limit, - const float prelu_alpha) + const float relux_max_limit) : Conv2dFunctorBase(strides, padding_type, paddings, dilations, activation, - relux_max_limit, - prelu_alpha) {} + relux_max_limit) {} void operator()(const Tensor *input, // NHWC const Tensor *filter, // HWOI @@ -229,11 +224,14 @@ struct Conv2dFunctor : Conv2dFunctorBase { std::vector output_shape(4); std::vector paddings(2); - kernels::CalcNHWCPaddingAndOutputSize( - input->shape().data(), filter->shape().data(), dilations_, strides_, - padding_type_, output_shape.data(), paddings.data()); - if (!paddings_.empty()) { + if (paddings_.empty()) { + kernels::CalcNHWCPaddingAndOutputSize( + input->shape().data(), filter->shape().data(), dilations_, strides_, + padding_type_, output_shape.data(), paddings.data()); + } else { paddings = paddings_; + CalcOutputSize(input->shape().data(), filter->shape().data(), paddings_.data(), + dilations_, strides_, RoundType::FLOOR, output_shape.data()); } output->Resize(output_shape); @@ -619,7 +617,7 @@ struct Conv2dFunctor : Conv2dFunctorBase { } } DoActivation(output_data, output_data, output->NumElements(), activation_, - relux_max_limit_, prelu_alpha_); + relux_max_limit_); } }; @@ -637,15 +635,13 @@ struct Conv2dFunctor : Conv2dFunctorBase { const std::vector &paddings, const int *dilations, const ActivationType activation, - const float relux_max_limit, - const float prelu_alpha) + const float relux_max_limit) : Conv2dFunctorBase(strides, padding_type, paddings, dilations, activation, - relux_max_limit, - prelu_alpha) {} + relux_max_limit) {} void operator()(const Tensor *input, const Tensor *filter, diff --git a/mace/kernels/conv_pool_2d_util.cc b/mace/kernels/conv_pool_2d_util.cc index fb009d790b3fa079030ed064502131d2ea6eac87..9b7160a7363df6b0883de821e25cf9fbc29ec33c 100644 --- a/mace/kernels/conv_pool_2d_util.cc +++ b/mace/kernels/conv_pool_2d_util.cc @@ -135,6 +135,44 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC output_shape[3] = output_channels; } +void CalcOutputSize(const index_t *input_shape, // NHWC + const index_t *filter_shape, // HWOI + const int *padding_size, + const int *dilations, + const int *strides, + const RoundType round_type, + index_t *output_shape) { + MACE_CHECK(dilations[0] > 0 && dilations[1] > 0, + "Invalid dilations, must >= 1"); + MACE_CHECK((dilations[0] == 1 || strides[0] == 1) && + (dilations[1] == 1 || strides[1] == 1), + "If dilations > 1, strides should be 1"); + MACE_CHECK_NOTNULL(output_shape); + MACE_CHECK_NOTNULL(padding_size); + /* + * Convlution arithmetic: + * o = floor((i + 2 * p - k - (k - 1) * (d - 1)) / s) + 1 + * Pooling arithmetic: + * o = ceil((i + 2 * p - k - (k - 1) * (d - 1)) / s) + 1 + * For details, see https://arxiv.org/pdf/1603.07285.pdf or + * http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html + */ + output_shape[0] = input_shape[0]; + if (round_type == FLOOR) { + output_shape[1] = static_cast(std::floor(1.0 * (input_shape[1] + padding_size[0] + - filter_shape[0] - 
(filter_shape[0] - 1) * (dilations[0] - 1)) / strides[0]) + 1); + output_shape[2] = static_cast(std::floor(1.0 * (input_shape[2] + padding_size[1] + - filter_shape[1] - (filter_shape[1] - 1) * (dilations[1] - 1)) / strides[1]) + 1); + } else { + output_shape[1] = static_cast(std::ceil(1.0 * (input_shape[1] + padding_size[0] + - filter_shape[0] - (filter_shape[0] - 1) * (dilations[0] - 1)) / strides[0]) + 1); + output_shape[2] = static_cast(std::ceil(1.0 * (input_shape[2] + padding_size[1] + - filter_shape[1] - (filter_shape[1] - 1) * (dilations[1] - 1)) / strides[1]) + 1); + } + output_shape[3] = filter_shape[2]; + +} + void CalPaddingSize(const index_t *input_shape, // NCHW const index_t *filter_shape, // OIHW const int *dilations, diff --git a/mace/kernels/conv_pool_2d_util.h b/mace/kernels/conv_pool_2d_util.h index de5410c1c779acb0088a2604714ddb1e46f8d2d2..24097e814de51f4286b89e059046cd27b9a22122 100644 --- a/mace/kernels/conv_pool_2d_util.h +++ b/mace/kernels/conv_pool_2d_util.h @@ -15,6 +15,11 @@ enum Padding { FULL = 2, // Pads with one less than the filter size on both sides }; +enum RoundType{ + FLOOR = 0, + CEIL = 1, +}; + namespace kernels { void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW @@ -33,6 +38,14 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NCHW index_t *output_shape, int *padding_size); +void CalcOutputSize(const index_t *input_shape, // NHWC + const index_t *filter_shape, // HWOI + const int *padding_size, + const int *dilations, + const int *strides, + const RoundType round_type, + index_t *output_shape); + void CalPaddingSize(const index_t *input_shape, // NCHW const index_t *filter_shape, // OIHW const int *dilations, diff --git a/mace/kernels/depthwise_conv2d.h b/mace/kernels/depthwise_conv2d.h index c72a4a6d59ff68e5a94a539c6c85c782f4aa9d1f..141119a9ba8a02002d5e37db46d74654973add11 100644 --- a/mace/kernels/depthwise_conv2d.h +++ b/mace/kernels/depthwise_conv2d.h @@ -241,15 +241,13 @@ struct DepthwiseConv2dFunctorBase { const std::vector &paddings, const int *dilations, const ActivationType activation, - const float relux_max_limit, - const float prelu_alpha) + const float relux_max_limit) : strides_(strides), padding_type_(padding_type), paddings_(paddings), dilations_(dilations), activation_(activation), - relux_max_limit_(relux_max_limit), - prelu_alpha_(prelu_alpha) {} + relux_max_limit_(relux_max_limit){} const int *strides_; // [stride_h, stride_w] const Padding padding_type_; @@ -257,7 +255,6 @@ struct DepthwiseConv2dFunctorBase { const int *dilations_; // [dilation_h, dilation_w] const ActivationType activation_; const float relux_max_limit_; - const float prelu_alpha_; }; template @@ -267,15 +264,13 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase { const std::vector &paddings, const int *dilations, const ActivationType activation, - const float relux_max_limit, - const float prelu_alpha) + const float relux_max_limit) : DepthwiseConv2dFunctorBase(strides, padding_type, paddings, dilations, activation, - relux_max_limit, - prelu_alpha) {} + relux_max_limit) {} void operator()(const Tensor *input, // NHWC const Tensor *filter, // HWIM @@ -295,11 +290,14 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase { std::vector output_shape(4); std::vector paddings(2); - kernels::CalcNHWCPaddingAndOutputSize( - input->shape().data(), fake_filter_shape.data(), dilations_, strides_, - padding_type_, output_shape.data(), paddings.data()); - if (!paddings_.empty()) { + if (paddings_.empty()) { + 
kernels::CalcNHWCPaddingAndOutputSize( + input->shape().data(), fake_filter_shape.data(), dilations_, strides_, + padding_type_, output_shape.data(), paddings.data()); + } else { paddings = paddings_; + CalcOutputSize(input->shape().data(), fake_filter_shape.data(), paddings_.data(), + dilations_, strides_, RoundType::FLOOR, output_shape.data()); } auto input_shape = fake_filter_shape; output->Resize(output_shape); @@ -405,7 +403,7 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase { output_ptr = output->mutable_data(); DoActivation(output_ptr, output_ptr, output->NumElements(), activation_, - relux_max_limit_, prelu_alpha_); + relux_max_limit_); } }; @@ -425,15 +423,13 @@ struct DepthwiseConv2dFunctor const std::vector &paddings, const int *dilations, const ActivationType activation, - const float relux_max_limit, - const float prelu_alpha) + const float relux_max_limit) : DepthwiseConv2dFunctorBase(strides, padding_type, paddings, dilations, activation, - relux_max_limit, - prelu_alpha) {} + relux_max_limit) {} void operator()(const Tensor *input, const Tensor *filter, diff --git a/mace/kernels/fully_connected.h b/mace/kernels/fully_connected.h index f6d7c8e589d7352c78e6cdd891ffa2dc69275bb8..4919e63aaeca3f29d902910dcb18ce459c255dcc 100644 --- a/mace/kernels/fully_connected.h +++ b/mace/kernels/fully_connected.h @@ -15,23 +15,19 @@ namespace kernels { struct FullyConnectedBase { FullyConnectedBase(const ActivationType activation, - const float relux_max_limit, - const float prelu_alpha) + const float relux_max_limit) : activation_(activation), - relux_max_limit_(relux_max_limit), - prelu_alpha_(prelu_alpha) {} + relux_max_limit_(relux_max_limit){} const ActivationType activation_; const float relux_max_limit_; - const float prelu_alpha_; }; template struct FullyConnectedFunctor : FullyConnectedBase { FullyConnectedFunctor(const ActivationType activation, - const float relux_max_limit, - const float prelu_alpha) : - FullyConnectedBase(activation, relux_max_limit, prelu_alpha) {} + const float relux_max_limit) : + FullyConnectedBase(activation, relux_max_limit) {} void operator()(const Tensor *input, const Tensor *weight, @@ -70,16 +66,15 @@ struct FullyConnectedFunctor : FullyConnectedBase { } DoActivation(output_ptr, output_ptr, output->NumElements(), activation_, - relux_max_limit_, prelu_alpha_); + relux_max_limit_); } }; template struct FullyConnectedFunctor : FullyConnectedBase { FullyConnectedFunctor(const ActivationType activation, - const float relux_max_limit, - const float prelu_alpha) : - FullyConnectedBase(activation, relux_max_limit, prelu_alpha) {} + const float relux_max_limit) : + FullyConnectedBase(activation, relux_max_limit) {} void operator()(const Tensor *input, const Tensor *weight, diff --git a/mace/kernels/opencl/activation_opencl.cc b/mace/kernels/opencl/activation_opencl.cc index 934ce8da666b8e9910815b9f56de250a3511fd68..dee010875e853ba192cf05b90abfcdf5ee2cb48f 100644 --- a/mace/kernels/opencl/activation_opencl.cc +++ b/mace/kernels/opencl/activation_opencl.cc @@ -14,6 +14,7 @@ namespace kernels { template void ActivationFunctor::operator()(const Tensor *input, + const Tensor *alpha, Tensor *output, StatsFuture *future) { const index_t batch = input->dim(0); @@ -60,8 +61,10 @@ void ActivationFunctor::operator()(const Tensor *input, runtime->BuildKernel("activation", kernel_name, built_options); int idx = 0; kernel_.setArg(idx++, *(static_cast(input->buffer()))); + if (activation_ == PRELU) { + kernel_.setArg(idx++, 
*(static_cast(alpha->buffer()))); + } kernel_.setArg(idx++, static_cast(relux_max_limit_)); - kernel_.setArg(idx++, static_cast(prelu_alpha_)); kernel_.setArg(idx++, *(static_cast(output->buffer()))); } diff --git a/mace/kernels/opencl/batch_norm_opencl.cc b/mace/kernels/opencl/batch_norm_opencl.cc index d88fed51cee468073dcf5521c40a8187f892dd5c..7696e875e538b0f06aefb4e30c4032fcba56a538 100644 --- a/mace/kernels/opencl/batch_norm_opencl.cc +++ b/mace/kernels/opencl/batch_norm_opencl.cc @@ -50,9 +50,6 @@ void BatchNormFunctor::operator()(const Tensor *input, case RELUX: built_options.emplace("-DUSE_RELUX"); break; - case PRELU: - built_options.emplace("-DUSE_PRELU"); - break; case TANH: built_options.emplace("-DUSE_TANH"); break; @@ -79,7 +76,6 @@ void BatchNormFunctor::operator()(const Tensor *input, } kernel_.setArg(idx++, *(static_cast(output->buffer()))); kernel_.setArg(idx++, relux_max_limit_); - kernel_.setArg(idx++, prelu_alpha_); } const uint32_t gws[3] = {static_cast(channel_blocks), diff --git a/mace/kernels/opencl/cl/activation.cl b/mace/kernels/opencl/cl/activation.cl index fe8619e20b7f567bc36dfa4bc5d6c53bd5f792fb..bee0b0e35313b4129fe6741cd9575f88b60e1431 100644 --- a/mace/kernels/opencl/cl/activation.cl +++ b/mace/kernels/opencl/cl/activation.cl @@ -1,8 +1,10 @@ #include __kernel void activation(__read_only image2d_t input, +#ifdef USE_PRELU + __read_only image2d_t alpha, +#endif __private const float relux_max_limit, - __private const float prelu_alpha, __write_only image2d_t output) { const int ch_blk = get_global_id(0); const int w = get_global_id(1); @@ -11,7 +13,12 @@ __kernel void activation(__read_only image2d_t input, const int pos = mad24(ch_blk, width, w); DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb)); - DATA_TYPE4 out = do_activation(in, relux_max_limit, prelu_alpha); +#ifdef USE_PRELU + DATA_TYPE4 prelu_alpha = READ_IMAGET(alpha, SAMPLER, (int2)(ch_blk, 0)); + DATA_TYPE4 out = do_activation(in, prelu_alpha, relux_max_limit); +#else + DATA_TYPE4 out = do_activation(in, relux_max_limit); +#endif WRITE_IMAGET(output, (int2)(pos, hb), out); } diff --git a/mace/kernels/opencl/cl/batch_norm.cl b/mace/kernels/opencl/cl/batch_norm.cl index 99c00fabe4d7a988d8388c0d0fa61d45301baf09..773b59c44e0021ab68a4d621514056d5327b5427 100644 --- a/mace/kernels/opencl/cl/batch_norm.cl +++ b/mace/kernels/opencl/cl/batch_norm.cl @@ -9,8 +9,7 @@ __kernel void batch_norm(__read_only image2d_t input, __private const float epsilon, #endif __write_only image2d_t output, - __private const float relux_max_limit, - __private const float prelu_alpha) { + __private const float relux_max_limit) { const int ch_blk = get_global_id(0); const int w = get_global_id(1); const int hb = get_global_id(2); @@ -35,8 +34,8 @@ __kernel void batch_norm(__read_only image2d_t input, DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb)); DATA_TYPE4 out = mad(in, bn_scale, bn_offset); -#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID) - out = do_activation(out, relux_max_limit, prelu_alpha); +#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) + out = do_activation(out, relux_max_limit); #endif WRITE_IMAGET(output, (int2)(pos, hb), out); diff --git a/mace/kernels/opencl/cl/common.h b/mace/kernels/opencl/cl/common.h index 13b20e05c07e5ebddb78bc94c8decd7694bbbc25..28b9addd3dfe67f2caac81be7890ecd6624f0e90 100644 --- a/mace/kernels/opencl/cl/common.h +++ b/mace/kernels/opencl/cl/common.h @@ -22,8 +22,10 @@ 
__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | inline DATA_TYPE4 do_activation(DATA_TYPE4 in, - __private const float relux_max_limit, - __private const float prelu_alpha) { +#ifdef USE_PRELU + DATA_TYPE4 prelu_alpha, +#endif + __private const float relux_max_limit) { DATA_TYPE4 out; #ifdef USE_RELU out = fmax(in, 0); diff --git a/mace/kernels/opencl/cl/conv_2d.cl b/mace/kernels/opencl/cl/conv_2d.cl index 35e17da8ff7b3337fb7cc30f73dcd92a793fc647..11253d69f02485c9c712c7abd7e56dbae2ad9414 100644 --- a/mace/kernels/opencl/cl/conv_2d.cl +++ b/mace/kernels/opencl/cl/conv_2d.cl @@ -7,7 +7,6 @@ __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ #endif __write_only image2d_t output, __private const float relux_max_limit, - __private const float prelu_alpha, __private const int in_height, __private const int in_width, __private const int in_ch_blks, @@ -112,11 +111,11 @@ __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ } } -#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID) - out0 = do_activation(out0, relux_max_limit, prelu_alpha); - out1 = do_activation(out1, relux_max_limit, prelu_alpha); - out2 = do_activation(out2, relux_max_limit, prelu_alpha); - out3 = do_activation(out3, relux_max_limit, prelu_alpha); +#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) + out0 = do_activation(out0, relux_max_limit); + out1 = do_activation(out1, relux_max_limit); + out2 = do_activation(out2, relux_max_limit); + out3 = do_activation(out3, relux_max_limit); #endif const int out_x_base = mul24(out_ch_blk, out_width); diff --git a/mace/kernels/opencl/cl/conv_2d_1x1.cl b/mace/kernels/opencl/cl/conv_2d_1x1.cl index 0eecdb19757fddc00b3f9cb4855fd76f835a79b6..b695165e1c3398ad333f2e52f307cd91e3eb4f59 100644 --- a/mace/kernels/opencl/cl/conv_2d_1x1.cl +++ b/mace/kernels/opencl/cl/conv_2d_1x1.cl @@ -7,12 +7,12 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] #endif __write_only image2d_t output, __private const float relux_max_limit, - __private const float prelu_alpha, __private const int in_height, __private const int in_width, __private const int in_ch_blks, __private const int height, - __private const int width) { + __private const int width, + __private const int stride) { const int out_ch_blk = get_global_id(0); const int out_w_blk = get_global_id(1); const int out_w_blks = get_global_size(1); @@ -31,19 +31,12 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] #endif int4 w; -#if STRIDE == 1 - w.x = out_w_blk; - w.y = w.x + out_w_blks; - w.z = w.y + out_w_blks; - w.w = w.z + out_w_blks; - int out_hb_idx = (out_hb % height); -#elif STRIDE == 2 - w.x = out_w_blk << 1; - w.y = (out_w_blk + out_w_blks) << 1; - w.z = (out_w_blk + (out_w_blks << 1)) << 1; - w.w = (out_w_blk + (out_w_blks << 1) + out_w_blks) << 1; - int out_hb_idx = (out_hb % height) << 1; -#endif + int in_width_stride = mul24(out_w_blks, stride); + w.x = mul24(out_w_blk, stride); + w.y = w.x + in_width_stride; + w.z = w.y + in_width_stride; + w.w = w.z + in_width_stride; + int out_hb_idx = mul24((out_hb % height), stride); w.x = select(w.x, INT_MIN, w.x >= in_width); w.y = select(w.y, INT_MIN, w.y >= in_width); @@ -92,11 +85,11 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] filter_x_base += 4; } -#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) 
|| defined(USE_TANH) || defined(USE_SIGMOID) - out0 = do_activation(out0, relux_max_limit, prelu_alpha); - out1 = do_activation(out1, relux_max_limit, prelu_alpha); - out2 = do_activation(out2, relux_max_limit, prelu_alpha); - out3 = do_activation(out3, relux_max_limit, prelu_alpha); +#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) + out0 = do_activation(out0, relux_max_limit); + out1 = do_activation(out1, relux_max_limit); + out2 = do_activation(out2, relux_max_limit); + out3 = do_activation(out3, relux_max_limit); #endif const int out_x_base = mul24(out_ch_blk, width); diff --git a/mace/kernels/opencl/cl/conv_2d_3x3.cl b/mace/kernels/opencl/cl/conv_2d_3x3.cl index d37ec7f1e1fc8599c8fd00d9644d0f1251b3d16a..fad561aaca4aa8f6fe862f314177221214264053 100644 --- a/mace/kernels/opencl/cl/conv_2d_3x3.cl +++ b/mace/kernels/opencl/cl/conv_2d_3x3.cl @@ -7,12 +7,12 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] #endif __write_only image2d_t output, __private const float relux_max_limit, - __private const float prelu_alpha, __private const int in_height, __private const int in_width, __private const int in_ch_blks, __private const int out_height, __private const int out_width, + __private const int stride, __private const int padding_top, __private const int padding_left, __private const int dilation_h, @@ -38,21 +38,13 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] DATA_TYPE4 out4 = 0; #endif -#if STRIDE == 1 - int in_width0 = out_w_blk - padding_left; - int in_width1 = in_width0 + out_w_blks; - int in_width2 = in_width1 + out_w_blks; - int in_width3 = in_width2 + out_w_blks; - int in_width4 = in_width3 + out_w_blks; - const int height_idx = (out_hb % out_height) - padding_top; -#elif STRIDE == 2 - int in_width0 = (out_w_blk << 1) - padding_left; - int in_width1 = ((out_w_blk + out_w_blks) << 1) - padding_left; - int in_width2 = ((out_w_blk + (out_w_blks << 1)) << 1) - padding_left; - int in_width3 = ((out_w_blk + (out_w_blks << 1) + out_w_blks) << 1) - padding_left; - int in_width4 = ((out_w_blk + (out_w_blks << 2)) << 1) - padding_left; - const int height_idx = ((out_hb % out_height) << 1) - padding_top; -#endif + int in_width_stride = mul24(out_w_blks, stride); + int in_width0 = mad24(out_w_blk, stride, -padding_left); + int in_width1 = in_width0 + in_width_stride; + int in_width2 = in_width1 + in_width_stride; + int in_width3 = in_width2 + in_width_stride; + int in_width4 = in_width3 + in_width_stride; + const int height_idx = mad24((out_hb % out_height), stride, -padding_top); const int batch_idx = mul24((out_hb / out_height), in_height); const int rounded_in_ch_x_3 = (rounded_in_ch << 1) + rounded_in_ch; @@ -127,12 +119,12 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] } } -#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID) - out0 = do_activation(out0, relux_max_limit, prelu_alpha); - out1 = do_activation(out1, relux_max_limit, prelu_alpha); - out2 = do_activation(out2, relux_max_limit, prelu_alpha); - out3 = do_activation(out3, relux_max_limit, prelu_alpha); - out4 = do_activation(out4, relux_max_limit, prelu_alpha); +#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) + out0 = do_activation(out0, relux_max_limit); + out1 = do_activation(out1, relux_max_limit); + out2 = do_activation(out2, relux_max_limit); + out3 = do_activation(out3, 
relux_max_limit); + out4 = do_activation(out4, relux_max_limit); #endif const int out_x_base = mul24(out_ch_blk, out_width); diff --git a/mace/kernels/opencl/cl/depthwise_conv2d.cl b/mace/kernels/opencl/cl/depthwise_conv2d.cl index 5ba07d73dfe0c73617879a884a8a7310536cca32..792c0934a4f7af5774b3065ecd349300a1f18854 100644 --- a/mace/kernels/opencl/cl/depthwise_conv2d.cl +++ b/mace/kernels/opencl/cl/depthwise_conv2d.cl @@ -8,7 +8,6 @@ __kernel void depthwise_conv2d(__read_only image2d_t input, /* [c%4 * w * c/4, h #endif __write_only image2d_t output, __private const float relux_max_limit, - __private const float prelu_alpha, __private const short in_height, __private const short in_width, __private const short in_ch_blks, @@ -103,11 +102,11 @@ __kernel void depthwise_conv2d(__read_only image2d_t input, /* [c%4 * w * c/4, h in_hb_idx += dilation_h; } -#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID) - out0 = do_activation(out0, relux_max_limit, prelu_alpha); - out1 = do_activation(out1, relux_max_limit, prelu_alpha); - out2 = do_activation(out2, relux_max_limit, prelu_alpha); - out3 = do_activation(out3, relux_max_limit, prelu_alpha); +#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) + out0 = do_activation(out0, relux_max_limit); + out1 = do_activation(out1, relux_max_limit); + out2 = do_activation(out2, relux_max_limit); + out3 = do_activation(out3, relux_max_limit); #endif const short out_x_base = mul24(out_ch_blk, out_width); @@ -134,7 +133,6 @@ __kernel void depthwise_conv2d_s1(__read_only image2d_t input, /* [c%4 * w * c/4 #endif __write_only image2d_t output, __private const DATA_TYPE relux_max_limit, - __private const DATA_TYPE prelu_alpha, __private const short in_height, __private const short in_width, __private const short in_ch_blks, @@ -220,11 +218,11 @@ __kernel void depthwise_conv2d_s1(__read_only image2d_t input, /* [c%4 * w * c/4 in_hb_idx += 1; } -#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID) - out0 = do_activation(out0, relux_max_limit, prelu_alpha); - out1 = do_activation(out1, relux_max_limit, prelu_alpha); - out2 = do_activation(out2, relux_max_limit, prelu_alpha); - out3 = do_activation(out3, relux_max_limit, prelu_alpha); +#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) + out0 = do_activation(out0, relux_max_limit); + out1 = do_activation(out1, relux_max_limit); + out2 = do_activation(out2, relux_max_limit); + out3 = do_activation(out3, relux_max_limit); #endif const short out_x_base = mul24(out_ch_blk, out_width); diff --git a/mace/kernels/opencl/cl/fully_connected.cl b/mace/kernels/opencl/cl/fully_connected.cl index 021012ffe39eb5f9a3744267bdb088001123f1aa..89264d82b890ca0effd9de53ea935a8821506f19 100644 --- a/mace/kernels/opencl/cl/fully_connected.cl +++ b/mace/kernels/opencl/cl/fully_connected.cl @@ -10,8 +10,7 @@ __kernel void fully_connected(__read_only image2d_t input, __private const int input_height, __private const int input_width, __private const int input_channel, - __private const float relux_max_limit, - __private const float prelu_alpha) { + __private const float relux_max_limit) { const int batch_idx = get_global_id(0); const int out_blk_idx = get_global_id(1); const int input_chan_blk = (input_channel + 3) >> 2; @@ -51,8 +50,8 @@ __kernel void fully_connected(__read_only image2d_t input, input_coord.y++; } -#if defined(USE_RELU) || 
defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID) - result = do_activation(result, relux_max_limit, prelu_alpha); +#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) + result = do_activation(result, relux_max_limit); #endif WRITE_IMAGET(output, (int2)(out_blk_idx, batch_idx), result); } diff --git a/mace/kernels/opencl/cl/winograd_transform.cl b/mace/kernels/opencl/cl/winograd_transform.cl index e4b315984ed040d5ea088e22237665b14abe4271..cbcd3b193a92e8e135a55014ad5e62b5545ed57e 100644 --- a/mace/kernels/opencl/cl/winograd_transform.cl +++ b/mace/kernels/opencl/cl/winograd_transform.cl @@ -115,8 +115,7 @@ __kernel void winograd_inverse_transform_2x2(__read_only image2d_t input, __private const int out_width, __private const int round_hw, __private const int round_w, - __private const float relux_max_limit, - __private const float prelu_alpha) { + __private const float relux_max_limit) { const int width_idx = get_global_id(0); const int height_idx = get_global_id(1); const int out_channel = get_global_size(1); @@ -183,11 +182,11 @@ __kernel void winograd_inverse_transform_2x2(__read_only image2d_t input, #endif -#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID) - in0[0] = do_activation(in0[0], relux_max_limit, prelu_alpha); - in0[1] = do_activation(in0[1], relux_max_limit, prelu_alpha); - in1[0] = do_activation(in1[0], relux_max_limit, prelu_alpha); - in1[1] = do_activation(in1[1], relux_max_limit, prelu_alpha); +#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) + in0[0] = do_activation(in0[0], relux_max_limit); + in0[1] = do_activation(in0[1], relux_max_limit); + in1[0] = do_activation(in1[0], relux_max_limit); + in1[1] = do_activation(in1[1], relux_max_limit); #endif WRITE_IMAGET(output, (int2)(coord_x, coord_y), in0[0]); @@ -205,6 +204,4 @@ __kernel void winograd_inverse_transform_2x2(__read_only image2d_t input, WRITE_IMAGET(output, (int2)(coord_x + 1, coord_y + 1), in1[1]); } - - } diff --git a/mace/kernels/opencl/conv_2d_opencl.cc b/mace/kernels/opencl/conv_2d_opencl.cc index 8c0733b341697b965f7b804c625d035b51dec6f4..94aa783886c8eab5214f5a1edbf75ec6b42c6501 100644 --- a/mace/kernels/opencl/conv_2d_opencl.cc +++ b/mace/kernels/opencl/conv_2d_opencl.cc @@ -17,7 +17,6 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel, const int *dilations, const ActivationType activation, const float relux_max_limit, - const float prelu_alpha, const DataType dt, Tensor *output, StatsFuture *future); @@ -31,7 +30,6 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel, const int *dilations, const ActivationType activation, const float relux_max_limit, - const float prelu_alpha, const DataType dt, Tensor *output, StatsFuture *future); @@ -45,7 +43,6 @@ extern void Conv2dOpencl(cl::Kernel *kernel, const int *dilations, const ActivationType activation, const float relux_max_limit, - const float prelu_alpha, const DataType dt, Tensor *output, StatsFuture *future); @@ -60,7 +57,7 @@ void Conv2dFunctor::operator()(const Tensor *input, cl::Kernel *kernel, const Tensor *input, const Tensor *filter, const Tensor *bias, const int stride, const int *padding, const int *dilations, const ActivationType activation, - const float relux_max_limit, const float prelu_alpha, const DataType dt, + const float relux_max_limit, const DataType dt, Tensor *output, StatsFuture *future); // Selection matrix: kernel_size x stride_size static const 
Conv2dOpenclFunction selector[5] = @@ -69,7 +66,6 @@ void Conv2dFunctor::operator()(const Tensor *input, index_t kernel_h = filter->dim(0); index_t kernel_w = filter->dim(1); if (!input->is_image() || strides_[0] != strides_[1] || - ((kernel_h == 1 || kernel_h == 3) && strides_[0] > 2) || (dilations_[0] > 1 && (strides_[0] > 1 || kernel_h == 1))) { LOG(WARNING) << "OpenCL conv2d kernel with " << "filter" << kernel_h << "x" << kernel_w << "," @@ -82,11 +78,14 @@ void Conv2dFunctor::operator()(const Tensor *input, std::vector output_shape(4); std::vector paddings(2); - kernels::CalcNHWCPaddingAndOutputSize( - input->shape().data(), filter->shape().data(), dilations_, strides_, - padding_type_, output_shape.data(), paddings.data()); - if (!paddings_.empty()) { + if (paddings_.empty()) { + kernels::CalcNHWCPaddingAndOutputSize( + input->shape().data(), filter->shape().data(), dilations_, strides_, + padding_type_, output_shape.data(), paddings.data()); + } else { paddings = paddings_; + CalcOutputSize(input->shape().data(), filter->shape().data(), paddings_.data(), + dilations_, strides_, RoundType::FLOOR, output_shape.data()); } std::vector output_image_shape; @@ -94,16 +93,13 @@ void Conv2dFunctor::operator()(const Tensor *input, output->ResizeImage(output_shape, output_image_shape); if (kernel_h == kernel_w && kernel_h <= 5 && - selector[kernel_h - 1] != nullptr && - 0 < strides_[0] && strides_[0] < 3 ) { + selector[kernel_h - 1] != nullptr) { auto conv2d_func = selector[kernel_h - 1]; conv2d_func(&kernel_, input, filter, bias, strides_[0], paddings.data(), dilations_, activation_, - relux_max_limit_, prelu_alpha_, DataTypeToEnum::value, - output, future); + relux_max_limit_, DataTypeToEnum::value, output, future); } else { Conv2dOpencl(&kernel_, input, filter, bias, strides_[0], paddings.data(), dilations_, - activation_, relux_max_limit_, prelu_alpha_, - DataTypeToEnum::value, output, future); + activation_, relux_max_limit_, DataTypeToEnum::value, output, future); } } diff --git a/mace/kernels/opencl/conv_2d_opencl_1x1.cc b/mace/kernels/opencl/conv_2d_opencl_1x1.cc index aa4bbc6bfa47407339dd67f432471c34867e5110..bee0e12a8826c2cd4d7bbe28ba3e70c9fe42f259 100644 --- a/mace/kernels/opencl/conv_2d_opencl_1x1.cc +++ b/mace/kernels/opencl/conv_2d_opencl_1x1.cc @@ -19,7 +19,6 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel, const int *dilations, const ActivationType activation, const float relux_max_limit, - const float prelu_alpha, const DataType dt, Tensor *output, StatsFuture *future) { @@ -44,7 +43,6 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel, built_options.emplace("-Dconv_2d_1x1=" + kernel_name); built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); - built_options.emplace(MakeString("-DSTRIDE=", stride)); if (bias != nullptr) { built_options.emplace("-DBIAS"); } @@ -57,9 +55,6 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel, case RELUX: built_options.emplace("-DUSE_RELUX"); break; - case PRELU: - built_options.emplace("-DUSE_PRELU"); - break; case TANH: built_options.emplace("-DUSE_TANH"); break; @@ -87,12 +82,12 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel, *(static_cast(output->buffer()))); // FIXME handle flexable data type: half not supported kernel->setArg(idx++, relux_max_limit); - kernel->setArg(idx++, prelu_alpha); kernel->setArg(idx++, static_cast(input_height)); kernel->setArg(idx++, static_cast(input_width)); kernel->setArg(idx++, static_cast(input_channel_blocks)); 
kernel->setArg(idx++, static_cast(height)); kernel->setArg(idx++, static_cast(width)); + kernel->setArg(idx++, stride); } const uint32_t gws[3] = {static_cast(channel_blocks), diff --git a/mace/kernels/opencl/conv_2d_opencl_3x3.cc b/mace/kernels/opencl/conv_2d_opencl_3x3.cc index 3a185faf998fed974576fed9aa30587a4a3d0d4b..bb67717791fd05f820ad92f734af545fe2b99e1e 100644 --- a/mace/kernels/opencl/conv_2d_opencl_3x3.cc +++ b/mace/kernels/opencl/conv_2d_opencl_3x3.cc @@ -21,7 +21,6 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel, const int *dilations, const ActivationType activation, const float relux_max_limit, - const float prelu_alpha, const DataType dt, Tensor *output, StatsFuture *future) { @@ -42,7 +41,6 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel, built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); built_options.emplace(bias != nullptr ? "-DBIAS" : ""); - built_options.emplace(MakeString("-DSTRIDE=", stride)); switch (activation) { case NOOP: break; @@ -52,9 +50,6 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel, case RELUX: built_options.emplace("-DUSE_RELUX"); break; - case PRELU: - built_options.emplace("-DUSE_PRELU"); - break; case TANH: built_options.emplace("-DUSE_TANH"); break; @@ -81,12 +76,12 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel, kernel->setArg(idx++, *(static_cast(output->buffer()))); kernel->setArg(idx++, relux_max_limit); - kernel->setArg(idx++, prelu_alpha); kernel->setArg(idx++, static_cast(input->dim(1))); kernel->setArg(idx++, static_cast(input->dim(2))); kernel->setArg(idx++, static_cast(input_channel_blocks)); kernel->setArg(idx++, static_cast(height)); kernel->setArg(idx++, static_cast(width)); + kernel->setArg(idx++, stride); kernel->setArg(idx++, padding[0] / 2); kernel->setArg(idx++, padding[1] / 2); kernel->setArg(idx++, dilations[0]); diff --git a/mace/kernels/opencl/conv_2d_opencl_general.cc b/mace/kernels/opencl/conv_2d_opencl_general.cc index 30a1a75171bcd6805cfdad1d69233888d3922444..af344c284fe04836d1d2ac23b4014ffdf76ac22b 100644 --- a/mace/kernels/opencl/conv_2d_opencl_general.cc +++ b/mace/kernels/opencl/conv_2d_opencl_general.cc @@ -21,7 +21,6 @@ extern void Conv2dOpencl(cl::Kernel *kernel, const int *dilations, const ActivationType activation, const float relux_max_limit, - const float prelu_alpha, const DataType dt, Tensor *output, StatsFuture *future) { @@ -42,7 +41,6 @@ extern void Conv2dOpencl(cl::Kernel *kernel, built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); built_options.emplace(bias != nullptr ? 
"-DBIAS" : ""); - built_options.emplace(MakeString("-DSTRIDE=", stride)); switch (activation) { case NOOP: break; @@ -52,9 +50,6 @@ extern void Conv2dOpencl(cl::Kernel *kernel, case RELUX: built_options.emplace("-DUSE_RELUX"); break; - case PRELU: - built_options.emplace("-DUSE_PRELU"); - break; case TANH: built_options.emplace("-DUSE_TANH"); break; @@ -81,7 +76,6 @@ extern void Conv2dOpencl(cl::Kernel *kernel, kernel->setArg(idx++, *(static_cast(output->buffer()))); kernel->setArg(idx++, relux_max_limit); - kernel->setArg(idx++, prelu_alpha); kernel->setArg(idx++, static_cast(input->dim(1))); kernel->setArg(idx++, static_cast(input->dim(2))); kernel->setArg(idx++, static_cast(input_channel_blocks)); diff --git a/mace/kernels/opencl/depthwise_conv_opencl.cc b/mace/kernels/opencl/depthwise_conv_opencl.cc index 67304bd896bd5e5df14c273c2a839dccfea28390..2942c5d060b9a240c0e9c3aa47cf6e2a82a6fdfd 100644 --- a/mace/kernels/opencl/depthwise_conv_opencl.cc +++ b/mace/kernels/opencl/depthwise_conv_opencl.cc @@ -20,7 +20,6 @@ void DepthwiseConv2d(cl::Kernel *kernel, const int *dilations, const ActivationType activation, const float relux_max_limit, - const float prelu_alpha, const DataType dt, Tensor *output, StatsFuture *future) { @@ -69,9 +68,6 @@ void DepthwiseConv2d(cl::Kernel *kernel, case RELUX: built_options.emplace("-DUSE_RELUX"); break; - case PRELU: - built_options.emplace("-DUSE_PRELU"); - break; case TANH: built_options.emplace("-DUSE_TANH"); break; @@ -96,7 +92,6 @@ void DepthwiseConv2d(cl::Kernel *kernel, kernel->setArg( idx++, *(static_cast(output->buffer()))); kernel->setArg(idx++, relux_max_limit); - kernel->setArg(idx++, prelu_alpha); kernel->setArg(idx++, static_cast(input_height)); kernel->setArg(idx++, static_cast(input_width)); kernel->setArg(idx++, static_cast(input_channel_blocks)); @@ -140,8 +135,8 @@ void DepthwiseConv2dFunctor::operator()( << " is not implemented yet, using slow version"; // TODO(heliangliang) The CPU/NEON kernel should map the buffer DepthwiseConv2dFunctor( - strides_, padding_type_, paddings_, dilations_, activation_, relux_max_limit_, - prelu_alpha_)(input, filter, bias, output, future); + strides_, padding_type_, paddings_, dilations_, activation_, + relux_max_limit_)(input, filter, bias, output, future); return; } @@ -154,11 +149,14 @@ void DepthwiseConv2dFunctor::operator()( std::vector output_shape(4); std::vector paddings(2); - kernels::CalcNHWCPaddingAndOutputSize( - input->shape().data(), fake_filter_shape.data(), dilations_, strides_, - padding_type_, output_shape.data(), paddings.data()); - if (!paddings_.empty()) { + if (paddings_.empty()) { + kernels::CalcNHWCPaddingAndOutputSize( + input->shape().data(), fake_filter_shape.data(), dilations_, strides_, + padding_type_, output_shape.data(), paddings.data()); + } else { paddings = paddings_; + CalcOutputSize(input->shape().data(), fake_filter_shape.data(), paddings_.data(), + dilations_, strides_, RoundType::FLOOR, output_shape.data()); } std::vector output_image_shape; @@ -166,7 +164,7 @@ void DepthwiseConv2dFunctor::operator()( output->ResizeImage(output_shape, output_image_shape); DepthwiseConv2d(&kernel_, input, filter, bias, strides_[0], paddings.data(), dilations_, - activation_, relux_max_limit_, prelu_alpha_, + activation_, relux_max_limit_, DataTypeToEnum::value, output, future); } diff --git a/mace/kernels/opencl/fully_connected_opencl.cc b/mace/kernels/opencl/fully_connected_opencl.cc index 589a4e3368fe65a17c39ab5d5d32be26298435ee..33e26ecab3a668cacf77ae7a16bcd61f13d87aa9 100644 --- 
a/mace/kernels/opencl/fully_connected_opencl.cc +++ b/mace/kernels/opencl/fully_connected_opencl.cc @@ -48,9 +48,6 @@ void FullyConnectedFunctor::operator()( case RELUX: built_options.emplace("-DUSE_RELUX"); break; - case PRELU: - built_options.emplace("-DUSE_PRELU"); - break; case TANH: built_options.emplace("-DUSE_TANH"); break; @@ -78,7 +75,6 @@ void FullyConnectedFunctor::operator()( kernel_.setArg(idx++, static_cast(input->dim(3))); // FIXME handle flexable data type: half not supported kernel_.setArg(idx++, relux_max_limit_); - kernel_.setArg(idx++, prelu_alpha_); } const uint32_t gws[2] = { diff --git a/mace/kernels/opencl/pooling_opencl.cc b/mace/kernels/opencl/pooling_opencl.cc index 9b612e48a558599751b7bde26df063689ea54c6a..2ec0e0845982ac32cb041203454342411f846e9f 100644 --- a/mace/kernels/opencl/pooling_opencl.cc +++ b/mace/kernels/opencl/pooling_opencl.cc @@ -24,12 +24,14 @@ void PoolingFunctor::operator()(const Tensor *input, }; std::vector paddings(2); - kernels::CalcNHWCPaddingAndOutputSize( - input->shape().data(), filter_shape.data(), - dilations_, strides_, this->padding_type_, - output_shape.data(), paddings.data()); - if (!paddings_.empty()) { + if (paddings_.empty()) { + kernels::CalcNHWCPaddingAndOutputSize( + input->shape().data(), filter_shape.data(), dilations_, strides_, + padding_type_, output_shape.data(), paddings.data()); + } else { paddings = paddings_; + CalcOutputSize(input->shape().data(), filter_shape.data(), paddings_.data(), + dilations_, strides_, RoundType::CEIL, output_shape.data()); } std::vector output_image_shape; diff --git a/mace/kernels/opencl/winograd_transform.cc b/mace/kernels/opencl/winograd_transform.cc index 31ca09f482d75999eb79aa59fca27f5bd0e9929d..54511220fdc4ce1cec32f8e2a38f0fbf38b35519 100644 --- a/mace/kernels/opencl/winograd_transform.cc +++ b/mace/kernels/opencl/winograd_transform.cc @@ -18,11 +18,14 @@ void WinogradTransformFunctor::operator()(const Tensor *i std::vector output_shape(4); std::vector filter_shape = {3, 3, input_tensor->dim(3), 1}; std::vector paddings(2); - kernels::CalcNHWCPaddingAndOutputSize( - input_tensor->shape().data(), filter_shape.data(), dilations_.data(), - strides_.data(), padding_type_, output_shape.data(), paddings.data()); - if (!paddings_.empty()) { + if (paddings_.empty()) { + kernels::CalcNHWCPaddingAndOutputSize( + input_tensor->shape().data(), filter_shape.data(), dilations_.data(), strides_.data(), + padding_type_, output_shape.data(), paddings.data()); + } else { paddings = paddings_; + CalcOutputSize(input_tensor->shape().data(), filter_shape.data(), paddings_.data(), + dilations_.data(), strides_.data(), RoundType::FLOOR, output_shape.data()); } const index_t round_h = (output_shape[1] + 1) / 2; @@ -126,7 +129,6 @@ void WinogradInverseTransformFunctor::operator()(const Te kernel_.setArg(idx++, static_cast(round_h * round_w)); kernel_.setArg(idx++, static_cast(round_w)); kernel_.setArg(idx++, relux_max_limit_); - kernel_.setArg(idx++, prelu_alpha_); } const uint32_t gws[2] = {static_cast(input_tensor->dim(2)), diff --git a/mace/kernels/pooling.h b/mace/kernels/pooling.h index 7b13b2172a6acb390e5a48cb9b98df566f23b8ae..dbbfaefce9a60457c2fb973cbea88646bb9fb830 100644 --- a/mace/kernels/pooling.h +++ b/mace/kernels/pooling.h @@ -65,12 +65,14 @@ struct PoolingFunctor : PoolingFunctorBase { }; std::vector paddings(2); - kernels::CalcNHWCPaddingAndOutputSize( - input_tensor->shape().data(), filter_shape.data(), - dilations_, strides_, this->padding_type_, - output_shape.data(), paddings.data()); - 
if (!paddings_.empty()) { + if (paddings_.empty()) { + kernels::CalcNHWCPaddingAndOutputSize( + input_tensor->shape().data(), filter_shape.data(), dilations_, strides_, + padding_type_, output_shape.data(), paddings.data()); + } else { paddings = paddings_; + CalcOutputSize(input_tensor->shape().data(), filter_shape.data(), paddings_.data(), + dilations_, strides_, RoundType::CEIL, output_shape.data()); } output_tensor->Resize(output_shape); diff --git a/mace/kernels/winograd_transform.h b/mace/kernels/winograd_transform.h index 639138ee31485c540eb0b066fdf7ec9bdc8fc662..fdab5c7c8d42ee815ca69b37fab34775a60047e2 100644 --- a/mace/kernels/winograd_transform.h +++ b/mace/kernels/winograd_transform.h @@ -58,21 +58,18 @@ struct WinogradInverseTransformFunctorBase { const int height, const int width, const ActivationType activation, - const float relux_max_limit, - const float prelu_alpha) + const float relux_max_limit) : batch_(batch), height_(height), width_(width), activation_(activation), - relux_max_limit_(relux_max_limit), - prelu_alpha_(prelu_alpha) {} + relux_max_limit_(relux_max_limit) {} const int batch_; const int height_; const int width_; const ActivationType activation_; const float relux_max_limit_; - const float prelu_alpha_; }; template @@ -81,9 +78,8 @@ struct WinogradInverseTransformFunctor : WinogradInverseTransformFunctorBase { const int height, const int width, const ActivationType activation, - const float relux_max_limit, - const float prelu_alpha) - : WinogradInverseTransformFunctorBase(batch, height, width, activation, relux_max_limit, prelu_alpha) {} + const float relux_max_limit) + : WinogradInverseTransformFunctorBase(batch, height, width, activation, relux_max_limit) {} void operator()(const Tensor *input, const Tensor *bias, @@ -100,9 +96,8 @@ struct WinogradInverseTransformFunctor : WinogradInverseT const int height, const int width, const ActivationType activation, - const float relux_max_limit, - const float prelu_alpha) - : WinogradInverseTransformFunctorBase(batch, height, width, activation, relux_max_limit, prelu_alpha) {} + const float relux_max_limit) + : WinogradInverseTransformFunctorBase(batch, height, width, activation, relux_max_limit) {} void operator()(const Tensor *input, const Tensor *bias, diff --git a/mace/ops/activation.h b/mace/ops/activation.h index 04ca0249c18a19fb2c47529d97535f4cd8663073..a55dfe1a4dea9ae7bff92c43ee0133def76d7c64 100644 --- a/mace/ops/activation.h +++ b/mace/ops/activation.h @@ -18,15 +18,15 @@ class ActivationOp : public Operator { functor_(kernels::StringToActivationType( OperatorBase::GetSingleArgument("activation", "NOOP")), - OperatorBase::GetSingleArgument("max_limit", 0.0f), - OperatorBase::GetSingleArgument("alpha", 0.0f)) {} + OperatorBase::GetSingleArgument("max_limit", 0.0f)) {} bool Run(StatsFuture *future) override { - const Tensor *input_tensor = this->inputs_[0]; + const Tensor *input_tensor = this->Input(0); + const Tensor *alpha_tensor = this->InputSize() >= 2 ? 
this->Input(1) : nullptr; Tensor *output_tensor = this->outputs_[0]; output_tensor->ResizeLike(input_tensor); - functor_(input_tensor, output_tensor, future); + functor_(input_tensor, alpha_tensor, output_tensor, future); return true; } diff --git a/mace/ops/activation_test.cc b/mace/ops/activation_test.cc index 02e16108eaacc8f91608457717bfeb7e55260dac..ce5ddd4598d154b760849bee078d944973c60ac8 100644 --- a/mace/ops/activation_test.cc +++ b/mace/ops/activation_test.cc @@ -213,17 +213,22 @@ void TestSimplePrelu() { // Add input data net.AddInputFromArray( "Input", {2, 2, 2, 2}, - {-7, 7, -6, 6, -5, 5, -4, 4, -3, 3, -2, 2, -1, 1, 0, 0}); + {-7, 7, -6, 6, -5, -5, -4, -4, -3, 3, -2, 2, -1, -1, 0, 0}); + net.AddInputFromArray( + "Alpha", {2}, + {2.0, 3.0}); if (D == DeviceType::OPENCL) { BufferToImage(net, "Input", "InputImage", kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(net, "Alpha", "AlphaImage", + kernels::BufferType::ARGUMENT); OpDefBuilder("Activation", "PreluTest") .Input("InputImage") + .Input("AlphaImage") .Output("OutputImage") .AddStringArg("activation", "PRELU") - .AddFloatArg("alpha", 2.0) .Finalize(net.NewOperatorDef()); // Run @@ -235,9 +240,9 @@ void TestSimplePrelu() { } else { OpDefBuilder("Activation", "PreluTest") .Input("Input") + .Input("Alpha") .Output("Output") .AddStringArg("activation", "PRELU") - .AddFloatArg("alpha", 2.0) .Finalize(net.NewOperatorDef()); // Run @@ -245,7 +250,7 @@ void TestSimplePrelu() { } auto expected = CreateTensor( - {2, 2, 2, 2}, {-14, 7, -12, 6, -10, 5, -8, 4, -6, 3, -4, 2, -2, 1, 0, 0}); + {2, 2, 2, 2}, {-14, 7, -12, 6, -10, -15, -8, -12, -6, 3, -4, 2, -2, -3, 0, 0}); ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); } diff --git a/mace/ops/batch_norm.h b/mace/ops/batch_norm.h index 52235410432704c604bd0eeea7fae3d1eb905608..96b1af133b2532f1fbf9166219547d079d2637ed 100644 --- a/mace/ops/batch_norm.h +++ b/mace/ops/batch_norm.h @@ -16,7 +16,7 @@ class BatchNormOp : public Operator { public: BatchNormOp(const OperatorDef &operator_def, Workspace *ws) : Operator(operator_def, ws), - functor_(false, kernels::ActivationType::NOOP, 0.0f, 0.0f) { + functor_(false, kernels::ActivationType::NOOP, 0.0f) { epsilon_ = OperatorBase::GetSingleArgument("epsilon", static_cast(1e-4)); } diff --git a/mace/ops/conv_2d.h b/mace/ops/conv_2d.h index b7347b5eeb1f222c9fb04f127c1cf68669b058cb..c441b0b45b0b00619fe0a36554dcef307227b2d9 100644 --- a/mace/ops/conv_2d.h +++ b/mace/ops/conv_2d.h @@ -23,7 +23,6 @@ class Conv2dOp : public ConvPool2dOpBase { this->paddings_, this->dilations_.data(), kernels::ActivationType::NOOP, - 0.0f, 0.0f) {} bool Run(StatsFuture *future) override { diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc index fb93504e5d946d832edc9bbed37a829a10985d8f..184772c47c7a7183f40db324b6a519fa4d820e34 100644 --- a/mace/ops/conv_2d_test.cc +++ b/mace/ops/conv_2d_test.cc @@ -342,7 +342,7 @@ TEST_F(Conv2dOpTest, CPUConv1x1) { TestConv1x1(); } TEST_F(Conv2dOpTest, OPENCLConv1x1) { TestConv1x1(); } template -static void TestComplexConvNxNS12(const std::vector &shape) { +static void TestComplexConvNxNS12(const std::vector &shape, const int stride) { testing::internal::LogToStderr(); auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w, Padding type) { @@ -405,20 +405,31 @@ static void TestComplexConvNxNS12(const std::vector &shape) { ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.001); }; - for (int kernel_size : {1, 3}) { - for (int stride : {1, 2}) { - func(kernel_size, kernel_size, stride, 
stride, VALID); - func(kernel_size, kernel_size, stride, stride, SAME); - } + for (int kernel_size : {1, 3, 7}) { + func(kernel_size, kernel_size, stride, stride, VALID); + func(kernel_size, kernel_size, stride, stride, SAME); } } TEST_F(Conv2dOpTest, OPENCLAlignedConvNxNS12) { - TestComplexConvNxNS12({32, 32, 32, 64}); + TestComplexConvNxNS12({32, 16, 16, 32}, + 1); + TestComplexConvNxNS12({32, 16, 16, 32}, + 2); } TEST_F(Conv2dOpTest, OPENCLUnalignedConvNxNS12) { - TestComplexConvNxNS12({107, 113, 5, 7}); + TestComplexConvNxNS12({17, 113, 5, 7}, + 1); + TestComplexConvNxNS12({17, 113, 5, 7}, + 2); +} + +TEST_F(Conv2dOpTest, OPENCLUnalignedConvNxNS34) { + TestComplexConvNxNS12({31, 113, 13, 17}, + 3); + TestComplexConvNxNS12({32, 32, 13, 17}, + 4); } template @@ -650,3 +661,81 @@ TEST_F(Conv2dOpTest, OPENCLUnalignedDilation4) { 4); } +template +static void TestArbitraryPadConvNxN(const std::vector &shape, const std::vector &paddings) { + testing::internal::LogToStderr(); + auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w) { + srand(time(NULL)); + + // generate random input + index_t batch = 1; + index_t height = shape[0]; + index_t width = shape[1]; + index_t input_channels = shape[2]; + index_t output_channels = shape[3]; + // Construct graph + OpsTestNet net; + OpDefBuilder("Conv2D", "Conv2dTest") + .Input("Input") + .Input("Filter") + .Input("Bias") + .Output("Output") + .AddIntsArg("strides", {stride_h, stride_w}) + .AddIntsArg("padding_values", paddings) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddRandomInput("Input", {batch, height, width, input_channels}); + net.AddRandomInput( + "Filter", {kernel_h, kernel_w, output_channels, input_channels}); + net.AddRandomInput("Bias", {output_channels}); + + // run on cpu + net.RunOp(); + // Check + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + // run on gpu + BufferToImage(net, "Input", "InputImage", kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER); + BufferToImage(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT); + + OpDefBuilder("Conv2D", "Conv2dTest") + .Input("InputImage") + .Input("FilterImage") + .Input("BiasImage") + .Output("OutputImage") + .AddIntsArg("strides", {stride_h, stride_w}) + .AddIntsArg("padding_values", paddings) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + // Run on device + net.RunOp(D); + + ImageToBuffer(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT_CHANNEL); + ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.001); + }; + + for (int kernel_size : {3, 5}) { + for (int stride : {2, 3}) { + func(kernel_size, kernel_size, stride, stride); + } + } +} + +TEST_F(Conv2dOpTest, OPENCLAlignedPad1) { + TestArbitraryPadConvNxN({32, 32, 32, 64}, + {1, 1}); +} + +TEST_F(Conv2dOpTest, OPENCLAlignedPad2) { + TestArbitraryPadConvNxN({128, 128, 16, 16}, + {2, 2}); +} + +TEST_F(Conv2dOpTest, OPENCLUnalignedPad4) { + TestArbitraryPadConvNxN({107, 113, 5, 7}, + {4, 4}); +} diff --git a/mace/ops/depthwise_conv2d.h b/mace/ops/depthwise_conv2d.h index 82d6e899649d1a04192725fe62a56c8e776ee8f3..0678ba0757d17a50fa1b2ef2cdcc5c4b2635a46d 100644 --- a/mace/ops/depthwise_conv2d.h +++ b/mace/ops/depthwise_conv2d.h @@ -26,8 +26,7 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase { kernels::StringToActivationType( OperatorBase::GetSingleArgument("activation", "NOOP")), - 
OperatorBase::GetSingleArgument("max_limit", 0.0f), - OperatorBase::GetSingleArgument("alpha", 0.0f)) {} + OperatorBase::GetSingleArgument("max_limit", 0.0f)) {} bool Run(StatsFuture *future) override { const Tensor *input = this->Input(INPUT); diff --git a/mace/ops/folded_batch_norm.h b/mace/ops/folded_batch_norm.h index b4f00776321a52dade57680989e46f62b7b9e844..28f7f99a8b20e26b553165258885222aca860483 100644 --- a/mace/ops/folded_batch_norm.h +++ b/mace/ops/folded_batch_norm.h @@ -19,8 +19,7 @@ class FoldedBatchNormOp : public Operator { kernels::StringToActivationType( OperatorBase::GetSingleArgument("activation", "NOOP")), - OperatorBase::GetSingleArgument("max_limit", 0.0f), - OperatorBase::GetSingleArgument("alpha", 0.0f)) {} + OperatorBase::GetSingleArgument("max_limit", 0.0f)) {} bool Run(StatsFuture *future) override { const Tensor *input = this->Input(INPUT); diff --git a/mace/ops/fully_connected.h b/mace/ops/fully_connected.h index 0ee90e2b3fd86d20d77d0f20d10b75c214cbb46e..c65947af7148d1ace34e9fe8b899a1044b3c3265 100644 --- a/mace/ops/fully_connected.h +++ b/mace/ops/fully_connected.h @@ -19,8 +19,7 @@ class FullyConnectedOp : public Operator { kernels::StringToActivationType( OperatorBase::GetSingleArgument("activation", "NOOP")), - OperatorBase::GetSingleArgument("max_limit", 0.0f), - OperatorBase::GetSingleArgument("alpha", 0.0f)) {} + OperatorBase::GetSingleArgument("max_limit", 0.0f)) {} bool Run(StatsFuture *future) override { const Tensor *input = this->Input(INPUT); diff --git a/mace/ops/fused_conv_2d.h b/mace/ops/fused_conv_2d.h index 5c5957f1c1975212f8e51d9cddd82cf36171cf30..0184f92f80700fd686e20a8c5918c3ce3c6ecb28 100644 --- a/mace/ops/fused_conv_2d.h +++ b/mace/ops/fused_conv_2d.h @@ -25,8 +25,7 @@ class FusedConv2dOp : public ConvPool2dOpBase { kernels::StringToActivationType( OperatorBase::GetSingleArgument("activation", "NOOP")), - OperatorBase::GetSingleArgument("max_limit", 0.0f), - OperatorBase::GetSingleArgument("alpha", 0.0f)) {} + OperatorBase::GetSingleArgument("max_limit", 0.0f)) {} bool Run(StatsFuture *future) override { const Tensor *input = this->Input(INPUT); diff --git a/mace/ops/winograd_inverse_transform.h b/mace/ops/winograd_inverse_transform.h index 4c20769f1fd461f393c1c57e58bc5f089197ed7c..aef374731408885aa19f4c88f5ffc15fe389bd2c 100644 --- a/mace/ops/winograd_inverse_transform.h +++ b/mace/ops/winograd_inverse_transform.h @@ -24,8 +24,7 @@ class WinogradInverseTransformOp : public Operator { kernels::StringToActivationType( OperatorBase::GetSingleArgument("activation", "NOOP")), - OperatorBase::GetSingleArgument("max_limit", 0.0f), - OperatorBase::GetSingleArgument("alpha", 0.0f)) {} + OperatorBase::GetSingleArgument("max_limit", 0.0f)) {} bool Run(StatsFuture *future) override { const Tensor *input_tensor = this->Input(INPUT); diff --git a/mace/python/tools/caffe_ops_stats.py b/mace/python/tools/caffe_ops_stats.py index 7c3bb7c45e44bb5973f910127a89b7b1963143f7..4eba5b664de816722d370c61757117ef0ffd25fe 100644 --- a/mace/python/tools/caffe_ops_stats.py +++ b/mace/python/tools/caffe_ops_stats.py @@ -5,10 +5,14 @@ import functools import argparse import sys import six +import os.path FLAGS = None def main(unused_args): + if not os.path.isfile(FLAGS.input): + print 'input model file not exist' + return -1 net = caffe_pb2.NetParameter() with open(FLAGS.input) as f: google.protobuf.text_format.Merge(str(f.read()), net)
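
Note: the sketch below is illustrative only and not part of the patch. The CPU path replaces the scalar prelu_alpha argument with a per-channel alpha tensor; because the data layout is NHWC, channels are the innermost dimension, so the channel of flat index i is simply i % channels, which is what the new PReLUActivation relies on. A minimal standalone C++ version of the same computation (hypothetical names, plain std::vector instead of Tensor) could look like this:

#include <cassert>
#include <cstddef>
#include <vector>

// Per-channel PReLU over an NHWC buffer:
//   out[i] = in[i]                        if in[i] >= 0
//   out[i] = alpha[i % channels] * in[i]  otherwise
void PreluNHWC(const std::vector<float> &in,
               const std::vector<float> &alpha,  // one slope per channel
               std::size_t channels,
               std::vector<float> *out) {
  assert(channels > 0 && in.size() % channels == 0 && alpha.size() == channels);
  out->resize(in.size());
  for (std::size_t i = 0; i < in.size(); ++i) {
    const float x = in[i];
    (*out)[i] = x < 0.0f ? x * alpha[i % channels] : x;
  }
}

For example, with channels = 2 and alpha = {2, 3}, an input value of -7 in channel 0 maps to -14 and -5 in channel 1 maps to -15, consistent with the expected values in TestSimplePrelu above.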
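
Note: also illustrative only, not part of the patch. CalcOutputSize encodes the output-size arithmetic quoted in its comment: per spatial dimension, o = round((i + p - k_eff) / s) + 1 with k_eff = k + (k - 1)(d - 1), where p is the padding term passed in padding_size; the rounding is floor for convolution (RoundType::FLOOR) and ceil for pooling (RoundType::CEIL). A small self-contained sketch with hypothetical names:

#include <cmath>
#include <cstdio>

// Output extent along one spatial dimension.
//   in : input extent    pad : padding added along this dimension
//   k  : kernel extent   d   : dilation    s : stride
// ceil_mode == false -> convolution (floor), true -> pooling (ceil).
int OutputExtent(int in, int pad, int k, int d, int s, bool ceil_mode) {
  const int k_eff = k + (k - 1) * (d - 1);  // dilated kernel extent
  const double q = static_cast<double>(in + pad - k_eff) / s;
  return static_cast<int>(ceil_mode ? std::ceil(q) : std::floor(q)) + 1;
}

int main() {
  // 112-wide input, 3x3 kernel, stride 2, dilation 1, no padding:
  std::printf("conv: %d\n", OutputExtent(112, 0, 3, 1, 2, false));  // 55
  std::printf("pool: %d\n", OutputExtent(112, 0, 3, 1, 2, true));   // 56
  return 0;
}

The one-unit difference between the two rounding modes is why the pooling callers in the patch pass RoundType::CEIL while the conv2d, depthwise and winograd callers pass RoundType::FLOOR.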