From 2460a9463e608570d9f4281c0e1f6b381891b068 Mon Sep 17 00:00:00 2001 From: liuqi Date: Mon, 19 Mar 2018 15:06:09 +0800 Subject: [PATCH] Support arbitrary input size. --- mace/kernels/activation.h | 1 + mace/kernels/addn.h | 1 + mace/kernels/batch_norm.h | 1 + mace/kernels/bias_add.h | 1 + mace/kernels/channel_shuffle.h | 1 + mace/kernels/concat.h | 1 + mace/kernels/conv_2d.h | 1 + mace/kernels/depthwise_conv2d.h | 1 + mace/kernels/eltwise.h | 1 + mace/kernels/fully_connected.h | 1 + mace/kernels/opencl/activation_opencl.cc | 5 ++ mace/kernels/opencl/addn.cc | 24 +++++--- mace/kernels/opencl/batch_norm_opencl.cc | 5 +- mace/kernels/opencl/bias_add_opencl.cc | 3 + mace/kernels/opencl/channel_shuffle.cc | 12 ++-- mace/kernels/opencl/concat.cc | 6 +- mace/kernels/opencl/conv_2d_opencl.cc | 11 ++-- mace/kernels/opencl/conv_2d_opencl_1x1.cc | 5 ++ mace/kernels/opencl/conv_2d_opencl_3x3.cc | 6 +- mace/kernels/opencl/conv_2d_opencl_general.cc | 6 +- mace/kernels/opencl/depthwise_conv_opencl.cc | 31 +++++----- mace/kernels/opencl/eltwise_opencl.cc | 3 + mace/kernels/opencl/fully_connected_opencl.cc | 27 ++++++--- mace/kernels/opencl/helper.h | 7 +++ mace/kernels/opencl/matmul.cc | 19 +++--- mace/kernels/opencl/pooling_opencl.cc | 58 ++++++++++--------- mace/kernels/opencl/resize_bilinear_opencl.cc | 32 +++++----- mace/kernels/opencl/softmax_opencl.cc | 3 + mace/kernels/opencl/space_to_batch_opencl.cc | 4 ++ mace/kernels/opencl/winograd_transform.cc | 54 +++++++++-------- mace/kernels/pooling.h | 1 + mace/kernels/resize_bilinear.h | 1 + mace/kernels/softmax.h | 1 + mace/kernels/space_to_batch.h | 1 + mace/kernels/winograd_transform.h | 2 + tools/wino_conv.py | 8 +-- 36 files changed, 225 insertions(+), 120 deletions(-) diff --git a/mace/kernels/activation.h b/mace/kernels/activation.h index dd750a38..1e3601a4 100644 --- a/mace/kernels/activation.h +++ b/mace/kernels/activation.h @@ -152,6 +152,7 @@ class ActivationFunctor { T relux_max_limit_; cl::Kernel kernel_; std::string tuning_key_prefix_; + std::vector input_shape_; }; } // namespace kernels diff --git a/mace/kernels/addn.h b/mace/kernels/addn.h index 6e9ba2d4..3a5a45df 100644 --- a/mace/kernels/addn.h +++ b/mace/kernels/addn.h @@ -91,6 +91,7 @@ struct AddNFunctor { StatsFuture *future); cl::Kernel kernel_; + std::vector input_shape_; }; } // namespace kernels diff --git a/mace/kernels/batch_norm.h b/mace/kernels/batch_norm.h index 1e6a12bf..57f0f4d6 100644 --- a/mace/kernels/batch_norm.h +++ b/mace/kernels/batch_norm.h @@ -156,6 +156,7 @@ struct BatchNormFunctor : BatchNormFunctorBase { Tensor *output, StatsFuture *future); cl::Kernel kernel_; + std::vector input_shape_; }; } // namepsace kernels diff --git a/mace/kernels/bias_add.h b/mace/kernels/bias_add.h index 28adcf8d..d8e411ef 100644 --- a/mace/kernels/bias_add.h +++ b/mace/kernels/bias_add.h @@ -62,6 +62,7 @@ struct BiasAddFunctor { Tensor *output, StatsFuture *future); cl::Kernel kernel_; + std::vector input_shape_; }; } // namepsace kernels diff --git a/mace/kernels/channel_shuffle.h b/mace/kernels/channel_shuffle.h index e627121d..da2ce094 100644 --- a/mace/kernels/channel_shuffle.h +++ b/mace/kernels/channel_shuffle.h @@ -55,6 +55,7 @@ struct ChannelShuffleFunctor { cl::Kernel kernel_; const int groups_; + std::vector input_shape_; }; } // namespace kernels diff --git a/mace/kernels/concat.h b/mace/kernels/concat.h index 021b0f61..68705946 100644 --- a/mace/kernels/concat.h +++ b/mace/kernels/concat.h @@ -83,6 +83,7 @@ struct ConcatFunctor : ConcatFunctorBase { Tensor *output, StatsFuture *future); cl::Kernel kernel_; + std::vector input_shape_; }; } // namepsace kernels diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h index a4a24eed..b107d332 100644 --- a/mace/kernels/conv_2d.h +++ b/mace/kernels/conv_2d.h @@ -401,6 +401,7 @@ struct Conv2dFunctor : Conv2dFunctorBase { StatsFuture *future); cl::Kernel kernel_; + std::vector input_shape_; }; } // namespace kernels diff --git a/mace/kernels/depthwise_conv2d.h b/mace/kernels/depthwise_conv2d.h index c0a1719f..dc6b7370 100644 --- a/mace/kernels/depthwise_conv2d.h +++ b/mace/kernels/depthwise_conv2d.h @@ -439,6 +439,7 @@ struct DepthwiseConv2dFunctor StatsFuture *future); cl::Kernel kernel_; + std::vector input_shape_; }; } // namespace kernels diff --git a/mace/kernels/eltwise.h b/mace/kernels/eltwise.h index 263dfb80..1aa883d5 100644 --- a/mace/kernels/eltwise.h +++ b/mace/kernels/eltwise.h @@ -94,6 +94,7 @@ struct EltwiseFunctor : EltwiseFunctorBase { StatsFuture *future); cl::Kernel kernel_; + std::vector input_shape_; }; } // namespace kernels diff --git a/mace/kernels/fully_connected.h b/mace/kernels/fully_connected.h index b8a74021..5c527d45 100644 --- a/mace/kernels/fully_connected.h +++ b/mace/kernels/fully_connected.h @@ -90,6 +90,7 @@ struct FullyConnectedFunctor : FullyConnectedBase { cl::Kernel kernel_; std::vector gws_; std::vector lws_; + std::vector input_shape_; }; } // namespace kernels diff --git a/mace/kernels/opencl/activation_opencl.cc b/mace/kernels/opencl/activation_opencl.cc index 180e38ca..9792cae5 100644 --- a/mace/kernels/opencl/activation_opencl.cc +++ b/mace/kernels/opencl/activation_opencl.cc @@ -58,6 +58,9 @@ void ActivationFunctor::operator()(const Tensor *input, LOG(FATAL) << "Unknown activation type: " << activation_; } kernel_ = runtime->BuildKernel("activation", kernel_name, built_options); + } + + if (!IsVecEqual(input_shape_, input->shape())) { int idx = 0; kernel_.setArg(idx++, *(input->opencl_image())); if (activation_ == PRELU) { @@ -66,6 +69,8 @@ void ActivationFunctor::operator()(const Tensor *input, } kernel_.setArg(idx++, static_cast(relux_max_limit_)); kernel_.setArg(idx++, *(output->opencl_image())); + + input_shape_ = input->shape(); } const uint32_t gws[3] = {static_cast(channel_blocks), diff --git a/mace/kernels/opencl/addn.cc b/mace/kernels/opencl/addn.cc index a6863a59..9f9571d0 100644 --- a/mace/kernels/opencl/addn.cc +++ b/mace/kernels/opencl/addn.cc @@ -32,15 +32,6 @@ void AddNFunctor::operator()( MACE_CHECK(channels == input_tensors[i]->dim(3)); } - std::vector output_shape = input_tensors[0]->shape(); - std::vector output_image_shape; - CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape); - output_tensor->ResizeImage(output_shape, output_image_shape); - - const index_t channel_blocks = RoundUpDiv4(channels); - const index_t width_pixels = channel_blocks * width; - const index_t batch_height_pixels = batch * height; - if (kernel_.get() == nullptr) { if (input_tensors.size() > 4) { MACE_NOT_IMPLEMENTED; @@ -55,11 +46,26 @@ void AddNFunctor::operator()( built_options.emplace(MakeString("-DINPUT_NUM=", input_tensors.size())); kernel_ = runtime->BuildKernel("addn", kernel_name, built_options); + } + + std::vector output_shape = input_tensors[0]->shape(); + + const index_t channel_blocks = RoundUpDiv4(channels); + const index_t width_pixels = channel_blocks * width; + const index_t batch_height_pixels = batch * height; + + if (!IsVecEqual(input_shape_, input_tensors[0]->shape())) { + std::vector output_image_shape; + CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape); + output_tensor->ResizeImage(output_shape, output_image_shape); + uint32_t idx = 0; for (auto input : input_tensors) { kernel_.setArg(idx++, *(input->opencl_image())); } kernel_.setArg(idx++, *(output_tensor->opencl_image())); + + input_shape_ = input_tensors[0]->shape(); } const uint32_t gws[2] = {static_cast(width_pixels), diff --git a/mace/kernels/opencl/batch_norm_opencl.cc b/mace/kernels/opencl/batch_norm_opencl.cc index 8f14f34b..d9dfb825 100644 --- a/mace/kernels/opencl/batch_norm_opencl.cc +++ b/mace/kernels/opencl/batch_norm_opencl.cc @@ -61,7 +61,8 @@ void BatchNormFunctor::operator()(const Tensor *input, } kernel_ = runtime->BuildKernel("batch_norm", kernel_name, built_options); - + } + if (!IsVecEqual(input_shape_, input->shape())) { uint32_t idx = 0; kernel_.setArg(idx++, *(input->opencl_image())); kernel_.setArg(idx++, *(scale->opencl_image())); @@ -73,6 +74,8 @@ void BatchNormFunctor::operator()(const Tensor *input, } kernel_.setArg(idx++, *(output->opencl_image())); kernel_.setArg(idx++, relux_max_limit_); + + input_shape_ = input->shape(); } const uint32_t gws[3] = {static_cast(channel_blocks), diff --git a/mace/kernels/opencl/bias_add_opencl.cc b/mace/kernels/opencl/bias_add_opencl.cc index 613b633b..d2490000 100644 --- a/mace/kernels/opencl/bias_add_opencl.cc +++ b/mace/kernels/opencl/bias_add_opencl.cc @@ -33,10 +33,13 @@ void BiasAddFunctor::operator()(const Tensor *input, built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); kernel_ = runtime->BuildKernel("bias_add", kernel_name, built_options); + } + if (!IsVecEqual(input_shape_, input->shape())) { uint32_t idx = 0; kernel_.setArg(idx++, *(input->opencl_image())); kernel_.setArg(idx++, *(bias->opencl_image())); kernel_.setArg(idx++, *(output->opencl_image())); + input_shape_ = input->shape(); } const uint32_t gws[3] = {static_cast(channel_blocks), diff --git a/mace/kernels/opencl/channel_shuffle.cc b/mace/kernels/opencl/channel_shuffle.cc index 3325ff24..a88b3b05 100644 --- a/mace/kernels/opencl/channel_shuffle.cc +++ b/mace/kernels/opencl/channel_shuffle.cc @@ -13,9 +13,10 @@ namespace mace { namespace kernels { template -void ChannelShuffleFunctor::operator()(const Tensor *input, - Tensor *output, - StatsFuture *future) { +void ChannelShuffleFunctor::operator()( + const Tensor *input, + Tensor *output, + StatsFuture *future) { output->ResizeLike(input); const index_t batch = input->dim(0); @@ -39,12 +40,15 @@ void ChannelShuffleFunctor::operator()(const Tensor *inpu built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); kernel_ = runtime->BuildKernel("channel_shuffle", kernel_name, built_options); - + } + if (!IsVecEqual(input_shape_, input->shape())) { uint32_t idx = 0; kernel_.setArg(idx++, *(input->opencl_image())); kernel_.setArg(idx++, groups_); kernel_.setArg(idx++, static_cast(channels_per_group)); kernel_.setArg(idx++, *(output->opencl_image())); + + input_shape_ = input->shape(); } const uint32_t gws[3] = {static_cast(group_channel_blocks), static_cast(width), diff --git a/mace/kernels/opencl/concat.cc b/mace/kernels/opencl/concat.cc index 119ec7cd..e99ab060 100644 --- a/mace/kernels/opencl/concat.cc +++ b/mace/kernels/opencl/concat.cc @@ -15,6 +15,7 @@ static void Concat2(cl::Kernel *kernel, const Tensor *input0, const Tensor *input1, const DataType dt, + std::vector *prev_input_shape, Tensor *output, StatsFuture *future) { const index_t batch = output->dim(0); @@ -41,6 +42,8 @@ static void Concat2(cl::Kernel *kernel, } *kernel = runtime->BuildKernel("concat", kernel_name, built_options); + } + if (!IsVecEqual(*prev_input_shape, input0->shape())) { uint32_t idx = 0; kernel->setArg(idx++, *(static_cast(input0->opencl_image()))); @@ -49,6 +52,7 @@ static void Concat2(cl::Kernel *kernel, kernel->setArg(idx++, static_cast(input0->dim(3))); kernel->setArg(idx++, *(static_cast(output->opencl_image()))); + *prev_input_shape = input0->shape(); } const uint32_t gws[3] = { @@ -142,7 +146,7 @@ void ConcatFunctor::operator()( switch (inputs_count) { case 2: Concat2(&kernel_, input_list[0], input_list[1], DataTypeToEnum::value, - output, future); + &input_shape_, output, future); break; default: if (divisible_four) { diff --git a/mace/kernels/opencl/conv_2d_opencl.cc b/mace/kernels/opencl/conv_2d_opencl.cc index 3ed87e7c..46683fd1 100644 --- a/mace/kernels/opencl/conv_2d_opencl.cc +++ b/mace/kernels/opencl/conv_2d_opencl.cc @@ -18,6 +18,7 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel, const ActivationType activation, const float relux_max_limit, const DataType dt, + std::vector *prev_input_shape, Tensor *output, StatsFuture *future); @@ -31,6 +32,7 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel, const ActivationType activation, const float relux_max_limit, const DataType dt, + std::vector *prev_input_shape, Tensor *output, StatsFuture *future); @@ -44,6 +46,7 @@ extern void Conv2dOpencl(cl::Kernel *kernel, const ActivationType activation, const float relux_max_limit, const DataType dt, + std::vector *prev_input_shape, Tensor *output, StatsFuture *future); @@ -57,8 +60,8 @@ void Conv2dFunctor::operator()(const Tensor *input, cl::Kernel * kernel, const Tensor *input, const Tensor *filter, const Tensor *bias, const int stride, const int *padding, const int *dilations, const ActivationType activation, - const float relux_max_limit, const DataType dt, Tensor *output, - StatsFuture *future); + const float relux_max_limit, const DataType dt, + std::vector *input_shape, Tensor *output, StatsFuture *future); // Selection matrix: kernel_size x stride_size static const Conv2dOpenclFunction selector[5] = { Conv2dOpenclK1x1, nullptr, Conv2dOpenclK3x3, nullptr, nullptr}; @@ -97,11 +100,11 @@ void Conv2dFunctor::operator()(const Tensor *input, auto conv2d_func = selector[kernel_h - 1]; conv2d_func(&kernel_, input, filter, bias, strides_[0], paddings.data(), dilations_, activation_, relux_max_limit_, - DataTypeToEnum::value, output, future); + DataTypeToEnum::value, &input_shape_, output, future); } else { Conv2dOpencl(&kernel_, input, filter, bias, strides_[0], paddings.data(), dilations_, activation_, relux_max_limit_, - DataTypeToEnum::value, output, future); + DataTypeToEnum::value, &input_shape_, output, future); } } diff --git a/mace/kernels/opencl/conv_2d_opencl_1x1.cc b/mace/kernels/opencl/conv_2d_opencl_1x1.cc index 41eaad56..4109a979 100644 --- a/mace/kernels/opencl/conv_2d_opencl_1x1.cc +++ b/mace/kernels/opencl/conv_2d_opencl_1x1.cc @@ -20,6 +20,7 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel, const ActivationType activation, const float relux_max_limit, const DataType dt, + std::vector *prev_input_shape, Tensor *output, StatsFuture *future) { const index_t batch = output->dim(0); @@ -68,6 +69,8 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel, auto runtime = OpenCLRuntime::Global(); *kernel = runtime->BuildKernel("conv_2d_1x1", kernel_name, built_options); + } + if (!IsVecEqual(*prev_input_shape, input->shape())) { uint32_t idx = 0; kernel->setArg(idx++, *(input->opencl_image())); kernel->setArg(idx++, *(filter->opencl_image())); @@ -83,6 +86,8 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel, kernel->setArg(idx++, static_cast(height)); kernel->setArg(idx++, static_cast(width)); kernel->setArg(idx++, stride); + + *prev_input_shape = input->shape(); } const uint32_t gws[3] = {static_cast(channel_blocks), diff --git a/mace/kernels/opencl/conv_2d_opencl_3x3.cc b/mace/kernels/opencl/conv_2d_opencl_3x3.cc index df2672c9..ba047cdf 100644 --- a/mace/kernels/opencl/conv_2d_opencl_3x3.cc +++ b/mace/kernels/opencl/conv_2d_opencl_3x3.cc @@ -22,6 +22,7 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel, const ActivationType activation, const float relux_max_limit, const DataType dt, + std::vector *prev_input_shape, Tensor *output, StatsFuture *future) { const index_t batch = output->dim(0); @@ -62,7 +63,8 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel, auto runtime = OpenCLRuntime::Global(); *kernel = runtime->BuildKernel("conv_2d_3x3", kernel_name, built_options); - + } + if (!IsVecEqual(*prev_input_shape, input->shape())) { uint32_t idx = 0; kernel->setArg(idx++, *(input->opencl_image())); kernel->setArg(idx++, *(filter->opencl_image())); @@ -81,6 +83,8 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel, kernel->setArg(idx++, padding[1] / 2); kernel->setArg(idx++, dilations[0]); kernel->setArg(idx++, dilations[1]); + + *prev_input_shape = input->shape(); } const uint32_t gws[3] = {static_cast(channel_blocks), diff --git a/mace/kernels/opencl/conv_2d_opencl_general.cc b/mace/kernels/opencl/conv_2d_opencl_general.cc index c317aa8c..fd48605f 100644 --- a/mace/kernels/opencl/conv_2d_opencl_general.cc +++ b/mace/kernels/opencl/conv_2d_opencl_general.cc @@ -22,6 +22,7 @@ extern void Conv2dOpencl(cl::Kernel *kernel, const ActivationType activation, const float relux_max_limit, const DataType dt, + std::vector *prev_input_shape, Tensor *output, StatsFuture *future) { const index_t batch = output->dim(0); @@ -62,7 +63,8 @@ extern void Conv2dOpencl(cl::Kernel *kernel, auto runtime = OpenCLRuntime::Global(); *kernel = runtime->BuildKernel("conv_2d", kernel_name, built_options); - + } + if (!IsVecEqual(*prev_input_shape, input->shape())) { uint32_t idx = 0; kernel->setArg(idx++, *(input->opencl_image())); kernel->setArg(idx++, *(filter->opencl_image())); @@ -83,6 +85,8 @@ extern void Conv2dOpencl(cl::Kernel *kernel, kernel->setArg(idx++, padding[1] / 2); kernel->setArg(idx++, dilations[0]); kernel->setArg(idx++, dilations[1]); + + *prev_input_shape = input->shape(); } const uint32_t gws[3] = {static_cast(channel_blocks), diff --git a/mace/kernels/opencl/depthwise_conv_opencl.cc b/mace/kernels/opencl/depthwise_conv_opencl.cc index 1b99188b..37b587dc 100644 --- a/mace/kernels/opencl/depthwise_conv_opencl.cc +++ b/mace/kernels/opencl/depthwise_conv_opencl.cc @@ -21,6 +21,7 @@ void DepthwiseConv2d(cl::Kernel *kernel, const ActivationType activation, const float relux_max_limit, const DataType dt, + std::vector *prev_input_shape, Tensor *output, StatsFuture *future) { const index_t batch = output->dim(0); @@ -35,17 +36,6 @@ void DepthwiseConv2d(cl::Kernel *kernel, const index_t input_channel_blocks = RoundUpDiv4(input_channels); const index_t width_blocks = RoundUpDiv4(width); if (kernel->get() == nullptr) { - const index_t input_batch = input->dim(0); - const index_t input_height = input->dim(1); - const index_t input_width = input->dim(2); - - const index_t filter_height = filter->dim(0); - const index_t filter_width = filter->dim(1); - MACE_CHECK(multiplier == 1, "Multiplier > 1 not supported"); - MACE_CHECK(multiplier * input_channels == channels); - MACE_CHECK(filter->dim(2) == input_channels, filter->dim(2), "!=", - input_channels); - auto runtime = OpenCLRuntime::Global(); std::set built_options; std::string kernel_name = MACE_OBFUSCATE_SYMBOL("depthwise_conv2d"); @@ -80,6 +70,18 @@ void DepthwiseConv2d(cl::Kernel *kernel, *kernel = runtime->BuildKernel("depthwise_conv2d", kernel_name, built_options); + } + if (!IsVecEqual(*prev_input_shape, input->shape())) { + const index_t input_batch = input->dim(0); + const index_t input_height = input->dim(1); + const index_t input_width = input->dim(2); + + const index_t filter_height = filter->dim(0); + const index_t filter_width = filter->dim(1); + MACE_CHECK(multiplier == 1, "Multiplier > 1 not supported"); + MACE_CHECK(multiplier * input_channels == channels); + MACE_CHECK(filter->dim(2) == input_channels, filter->dim(2), "!=", + input_channels); uint32_t idx = 0; kernel->setArg(idx++, *(input->opencl_image())); @@ -102,6 +104,7 @@ void DepthwiseConv2d(cl::Kernel *kernel, kernel->setArg(idx++, static_cast(dilations[0])); kernel->setArg(idx++, static_cast(dilations[1])); } + *prev_input_shape = input->shape(); } const uint32_t gws[3] = {static_cast(channel_blocks), @@ -120,9 +123,7 @@ void DepthwiseConv2dFunctor::operator()( const Tensor *bias, Tensor *output, StatsFuture *future) { - typedef void (*Conv2dOpenclFunction)(const Tensor *input, - const Tensor *filter, const Tensor *bias, - Tensor *output, StatsFuture *future); + index_t kernel_h = filter->dim(2); index_t kernel_w = filter->dim(3); if (strides_[0] != strides_[1]) { @@ -163,7 +164,7 @@ void DepthwiseConv2dFunctor::operator()( DepthwiseConv2d(&kernel_, input, filter, bias, strides_[0], paddings.data(), dilations_, activation_, relux_max_limit_, - DataTypeToEnum::value, output, future); + DataTypeToEnum::value, &input_shape_, output, future); } template struct DepthwiseConv2dFunctor; diff --git a/mace/kernels/opencl/eltwise_opencl.cc b/mace/kernels/opencl/eltwise_opencl.cc index 82312c75..dde05b29 100644 --- a/mace/kernels/opencl/eltwise_opencl.cc +++ b/mace/kernels/opencl/eltwise_opencl.cc @@ -36,6 +36,8 @@ void EltwiseFunctor::operator()(const Tensor *input0, if (!coeff_.empty()) built_options.emplace("-DCOEFF_SUM"); kernel_ = runtime->BuildKernel("eltwise", kernel_name, built_options); + } + if (!IsVecEqual(input_shape_, input0->shape())) { uint32_t idx = 0; kernel_.setArg(idx++, *(input0->opencl_image())); kernel_.setArg(idx++, *(input1->opencl_image())); @@ -44,6 +46,7 @@ void EltwiseFunctor::operator()(const Tensor *input0, kernel_.setArg(idx++, coeff_[1]); } kernel_.setArg(idx++, *(output->opencl_image())); + input_shape_ = input0->shape(); } const uint32_t gws[2] = {static_cast(width_pixels), diff --git a/mace/kernels/opencl/fully_connected_opencl.cc b/mace/kernels/opencl/fully_connected_opencl.cc index abcbfe52..d5db5190 100644 --- a/mace/kernels/opencl/fully_connected_opencl.cc +++ b/mace/kernels/opencl/fully_connected_opencl.cc @@ -13,6 +13,7 @@ void FCWXKernel(cl::Kernel *kernel, const Tensor *input, const Tensor *weight, const Tensor *bias, + std::vector *prev_input_shape, Tensor *output, const ActivationType activation, std::vector &gws, @@ -67,6 +68,11 @@ void FCWXKernel(cl::Kernel *kernel, const uint32_t inter_local_blks = kwg_size / (gws[0] * gws[1]); lws = {gws[0], gws[1], inter_local_blks}; + } + if (!IsVecEqual(*prev_input_shape, input->shape())) { + const index_t batch = output->dim(0); + const index_t output_blocks = RoundUpDiv4(output->dim(3)); + uint32_t idx = 0; kernel->setArg(idx++, *(input->opencl_image())); kernel->setArg(idx++, *(weight->opencl_image())); @@ -80,6 +86,10 @@ void FCWXKernel(cl::Kernel *kernel, kernel->setArg(idx++, static_cast(RoundUpDiv4(input->dim(3)))); kernel->setArg(idx++, static_cast(output_blocks)); kernel->setArg(idx++, relux_max_limit); + + gws[2] = static_cast(batch * output_blocks); + + *prev_input_shape = input->shape(); } cl::Event event; cl_int error = runtime->command_queue().enqueueNDRangeKernel( @@ -103,6 +113,7 @@ void FCWTXKernel(cl::Kernel *kernel, const Tensor *input, const Tensor *weight, const Tensor *bias, + std::vector *prev_input_shape, Tensor *output, const ActivationType activation, std::vector &gws, @@ -141,6 +152,9 @@ void FCWTXKernel(cl::Kernel *kernel, *kernel = runtime->BuildKernel("fully_connected", kernel_name, built_options); + lws = {16, 64, 1}; + } + if (!IsVecEqual(*prev_input_shape, input->shape())) { uint32_t idx = 0; kernel->setArg(idx++, *(input->opencl_image())); kernel->setArg(idx++, *(weight->opencl_image())); @@ -155,14 +169,13 @@ void FCWTXKernel(cl::Kernel *kernel, kernel->setArg(idx++, relux_max_limit); const index_t batch = output->dim(0); - const index_t output_size = output->dim(3); - - const index_t output_blocks = RoundUpDiv4(output_size); + const index_t output_blocks = RoundUpDiv4(output->dim(3)); gws = { static_cast(batch), static_cast(output_blocks), }; - lws = {16, 64, 1}; + + *prev_input_shape = input->shape(); } std::stringstream ss; @@ -185,11 +198,11 @@ void FullyConnectedFunctor::operator()( output->ResizeImage(output_shape, output_image_shape); if (weight_type_ == BufferType::WEIGHT_HEIGHT) { - FCWTXKernel(&kernel_, input, weight, bias, output, + FCWTXKernel(&kernel_, input, weight, bias, &input_shape_, output, activation_, gws_, lws_, relux_max_limit_, future); } else { - FCWXKernel(&kernel_, input, weight, bias, output, - activation_, gws_, lws_, relux_max_limit_, future); + FCWXKernel(&kernel_, input, weight, bias, &input_shape_, output, + activation_, gws_, lws_, relux_max_limit_, future); } }; diff --git a/mace/kernels/opencl/helper.h b/mace/kernels/opencl/helper.h index 6513415a..56bf295e 100644 --- a/mace/kernels/opencl/helper.h +++ b/mace/kernels/opencl/helper.h @@ -71,6 +71,13 @@ inline bool LimitKernelTime() { return flag != nullptr && strlen(flag) == 1 && flag[0] == '1'; } +template +bool IsVecEqual(const std::vector &input0, + const std::vector &input1) { + return ((input0.size() == input1.size()) && + (std::equal(input0.begin(), input0.end(), input1.begin()))); +} + namespace { template void AppendToStream(std::stringstream *ss, const std::string &delimiter, T v) { diff --git a/mace/kernels/opencl/matmul.cc b/mace/kernels/opencl/matmul.cc index d453c293..4b61edb2 100644 --- a/mace/kernels/opencl/matmul.cc +++ b/mace/kernels/opencl/matmul.cc @@ -36,17 +36,16 @@ void MatMulFunctor::operator()(const Tensor *A, built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); kernel_ = runtime->BuildKernel("matmul", kernel_name, built_options); - - uint32_t idx = 0; - kernel_.setArg(idx++, *(A->opencl_image())); - kernel_.setArg(idx++, *(B->opencl_image())); - kernel_.setArg(idx++, *(C->opencl_image())); - kernel_.setArg(idx++, static_cast(height)); - kernel_.setArg(idx++, static_cast(width)); - kernel_.setArg(idx++, static_cast(A->dim(2))); - kernel_.setArg(idx++, static_cast(height_blocks)); - kernel_.setArg(idx++, static_cast(RoundUpDiv4(A->dim(2)))); } + uint32_t idx = 0; + kernel_.setArg(idx++, *(A->opencl_image())); + kernel_.setArg(idx++, *(B->opencl_image())); + kernel_.setArg(idx++, *(C->opencl_image())); + kernel_.setArg(idx++, static_cast(height)); + kernel_.setArg(idx++, static_cast(width)); + kernel_.setArg(idx++, static_cast(A->dim(2))); + kernel_.setArg(idx++, static_cast(height_blocks)); + kernel_.setArg(idx++, static_cast(RoundUpDiv4(A->dim(2)))); const uint32_t gws[2] = { static_cast(width_blocks), diff --git a/mace/kernels/opencl/pooling_opencl.cc b/mace/kernels/opencl/pooling_opencl.cc index d9256776..d8a6d675 100644 --- a/mace/kernels/opencl/pooling_opencl.cc +++ b/mace/kernels/opencl/pooling_opencl.cc @@ -17,31 +17,6 @@ void PoolingFunctor::operator()(const Tensor *input, StatsFuture *future) { MACE_CHECK(dilations_[0] == 1 && dilations_[1] == 1) << "Pooling opencl kernel not support dilation yet"; - std::vector output_shape(4); - std::vector filter_shape = {kernels_[0], kernels_[1], input->dim(3), - input->dim(3)}; - - std::vector paddings(2); - if (paddings_.empty()) { - kernels::CalcNHWCPaddingAndOutputSize( - input->shape().data(), filter_shape.data(), dilations_, strides_, - padding_type_, output_shape.data(), paddings.data()); - } else { - paddings = paddings_; - CalcOutputSize(input->shape().data(), filter_shape.data(), paddings_.data(), - dilations_, strides_, RoundType::CEIL, output_shape.data()); - } - - std::vector output_image_shape; - CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape); - output->ResizeImage(output_shape, output_image_shape); - - index_t batch = output->dim(0); - index_t out_height = output->dim(1); - index_t out_width = output->dim(2); - index_t channels = output->dim(3); - - index_t channel_blocks = (channels + 3) / 4; if (kernel_.get() == nullptr) { const DataType dt = DataTypeToEnum::value; @@ -62,18 +37,49 @@ void PoolingFunctor::operator()(const Tensor *input, } kernel_ = runtime->BuildKernel("pooling", kernel_name, built_options); + } + if (!IsVecEqual(input_shape_, input->shape())) { + std::vector output_shape(4); + std::vector filter_shape = {kernels_[0], kernels_[1], input->dim(3), + input->dim(3)}; + + std::vector paddings(2); + if (paddings_.empty()) { + kernels::CalcNHWCPaddingAndOutputSize( + input->shape().data(), filter_shape.data(), dilations_, strides_, + padding_type_, output_shape.data(), paddings.data()); + } else { + paddings = paddings_; + CalcOutputSize(input->shape().data(), filter_shape.data(), paddings_.data(), + dilations_, strides_, RoundType::CEIL, output_shape.data()); + } + + std::vector output_image_shape; + CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, output_image_shape); + output->ResizeImage(output_shape, output_image_shape); + uint32_t idx = 0; kernel_.setArg(idx++, *(input->opencl_image())); kernel_.setArg(idx++, static_cast(input->dim(1))); kernel_.setArg(idx++, static_cast(input->dim(2))); - kernel_.setArg(idx++, static_cast(out_height)); + kernel_.setArg(idx++, static_cast(output->dim(1))); kernel_.setArg(idx++, paddings[0] / 2); kernel_.setArg(idx++, paddings[1] / 2); kernel_.setArg(idx++, strides_[0]); kernel_.setArg(idx++, kernels_[0]); kernel_.setArg(idx++, *(output->opencl_image())); + + input_shape_ = input->shape(); } + index_t batch = output->dim(0); + index_t out_height = output->dim(1); + index_t out_width = output->dim(2); + index_t channels = output->dim(3); + + index_t channel_blocks = (channels + 3) / 4; + + const uint32_t gws[3] = { static_cast(channel_blocks), static_cast(out_width), static_cast(batch * out_height), diff --git a/mace/kernels/opencl/resize_bilinear_opencl.cc b/mace/kernels/opencl/resize_bilinear_opencl.cc index 86893859..a3bb2ee1 100644 --- a/mace/kernels/opencl/resize_bilinear_opencl.cc +++ b/mace/kernels/opencl/resize_bilinear_opencl.cc @@ -24,21 +24,7 @@ void ResizeBilinearFunctor::operator()( const index_t out_height = out_height_; const index_t out_width = out_width_; - MACE_CHECK(out_height > 0 && out_width > 0); - std::vector output_shape{batch, out_height, out_width, channels}; - - std::vector output_image_shape; - CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, - output_image_shape); - output->ResizeImage(output_shape, output_image_shape); - if (kernel_.get() == nullptr) { - - float height_scale = - CalculateResizeScale(in_height, out_height, align_corners_); - float width_scale = - CalculateResizeScale(in_width, out_width, align_corners_); - auto runtime = OpenCLRuntime::Global(); std::set built_options; std::string kernel_name = MACE_OBFUSCATE_SYMBOL("resize_bilinear_nocache"); @@ -49,6 +35,21 @@ void ResizeBilinearFunctor::operator()( kernel_ = runtime->BuildKernel("resize_bilinear", kernel_name, built_options); + } + if (!IsVecEqual(input_shape_, input->shape())) { + MACE_CHECK(out_height > 0 && out_width > 0); + std::vector output_shape{batch, out_height, out_width, channels}; + + std::vector output_image_shape; + CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, + output_image_shape); + output->ResizeImage(output_shape, output_image_shape); + + float height_scale = + CalculateResizeScale(in_height, out_height, align_corners_); + float width_scale = + CalculateResizeScale(in_width, out_width, align_corners_); + uint32_t idx = 0; kernel_.setArg(idx++, *(input->opencl_image())); kernel_.setArg(idx++, *(output->opencl_image())); @@ -57,6 +58,9 @@ void ResizeBilinearFunctor::operator()( kernel_.setArg(idx++, static_cast(in_height)); kernel_.setArg(idx++, static_cast(in_width)); kernel_.setArg(idx++, static_cast(out_height)); + + input_shape_ = input->shape(); + } const uint32_t gws[3] = {static_cast(channel_blocks), diff --git a/mace/kernels/opencl/softmax_opencl.cc b/mace/kernels/opencl/softmax_opencl.cc index 25e1c9e4..4aabe901 100644 --- a/mace/kernels/opencl/softmax_opencl.cc +++ b/mace/kernels/opencl/softmax_opencl.cc @@ -34,11 +34,14 @@ void SoftmaxFunctor::operator()(const Tensor *logits, built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); kernel_ = runtime->BuildKernel("softmax", kernel_name, built_options); + } + if (!IsVecEqual(input_shape_, logits->shape())) { uint32_t idx = 0; kernel_.setArg(idx++, *(logits->opencl_image())); kernel_.setArg(idx++, static_cast(channels)); kernel_.setArg(idx++, remain_channels); kernel_.setArg(idx++, *(output->opencl_image())); + input_shape_ = logits->shape(); } const uint32_t gws[3] = {static_cast(channel_blocks), static_cast(width), diff --git a/mace/kernels/opencl/space_to_batch_opencl.cc b/mace/kernels/opencl/space_to_batch_opencl.cc index 0cecb0a7..91f5564d 100644 --- a/mace/kernels/opencl/space_to_batch_opencl.cc +++ b/mace/kernels/opencl/space_to_batch_opencl.cc @@ -43,6 +43,8 @@ void SpaceToBatchFunctor::operator()( kernel_ = runtime->BuildKernel("space_to_batch", kernel_name, built_options); + } + if (!IsVecEqual(space_shape_, space_tensor->shape())) { uint32_t idx = 0; if (b2s_) { kernel_.setArg(idx++, *(batch_tensor->opencl_image())); @@ -59,6 +61,8 @@ void SpaceToBatchFunctor::operator()( kernel_.setArg(idx++, static_cast(space_tensor->dim(2))); kernel_.setArg(idx++, static_cast(batch_tensor->dim(1))); kernel_.setArg(idx++, static_cast(batch_tensor->dim(2))); + + space_shape_ = space_tensor->shape(); } const uint32_t chan_blk = RoundUpDiv4(batch_tensor->dim(3)); diff --git a/mace/kernels/opencl/winograd_transform.cc b/mace/kernels/opencl/winograd_transform.cc index aa67b20d..c07ccc99 100644 --- a/mace/kernels/opencl/winograd_transform.cc +++ b/mace/kernels/opencl/winograd_transform.cc @@ -14,6 +14,21 @@ namespace kernels { template void WinogradTransformFunctor::operator()( const Tensor *input_tensor, Tensor *output_tensor, StatsFuture *future) { + + if (kernel_.get() == nullptr) { + std::string obfuscated_kernel_name = + MACE_OBFUSCATE_SYMBOL("winograd_transform_2x2"); + std::set built_options; + built_options.emplace("-Dwinograd_transform_2x2=" + obfuscated_kernel_name); + built_options.emplace("-DDATA_TYPE=" + + DtToUpstreamCLDt(DataTypeToEnum::value)); + built_options.emplace("-DCMD_DATA_TYPE=" + + DtToUpstreamCLCMDDt(DataTypeToEnum::value)); + auto runtime = OpenCLRuntime::Global(); + kernel_ = runtime->BuildKernel("winograd_transform", obfuscated_kernel_name, + built_options); + + } std::vector output_shape(4); std::vector filter_shape = {3, 3, input_tensor->dim(3), 1}; std::vector paddings(2); @@ -27,28 +42,15 @@ void WinogradTransformFunctor::operator()( paddings_.data(), dilations_.data(), strides_.data(), RoundType::FLOOR, output_shape.data()); } - const index_t round_h = (output_shape[1] + 1) / 2; const index_t round_w = (output_shape[2] + 1) / 2; const index_t out_width = input_tensor->dim(0) * round_h * round_w; - output_shape = {16, input_tensor->dim(3), out_width, 1}; - std::vector image_shape; - CalImage2DShape(output_shape, BufferType::IN_OUT_HEIGHT, image_shape); - output_tensor->ResizeImage(output_shape, image_shape); - - if (kernel_.get() == nullptr) { - std::string obfuscated_kernel_name = - MACE_OBFUSCATE_SYMBOL("winograd_transform_2x2"); - std::set built_options; - built_options.emplace("-Dwinograd_transform_2x2=" + obfuscated_kernel_name); - built_options.emplace("-DDATA_TYPE=" + - DtToUpstreamCLDt(DataTypeToEnum::value)); - built_options.emplace("-DCMD_DATA_TYPE=" + - DtToUpstreamCLCMDDt(DataTypeToEnum::value)); - auto runtime = OpenCLRuntime::Global(); - kernel_ = runtime->BuildKernel("winograd_transform", obfuscated_kernel_name, - built_options); + if (!IsVecEqual(input_shape_, input_tensor->shape())) { + output_shape = {16, input_tensor->dim(3), out_width, 1}; + std::vector image_shape; + CalImage2DShape(output_shape, BufferType::IN_OUT_HEIGHT, image_shape); + output_tensor->ResizeImage(output_shape, image_shape); uint32_t idx = 0; kernel_.setArg(idx++, *(input_tensor->opencl_image())); @@ -60,6 +62,8 @@ void WinogradTransformFunctor::operator()( kernel_.setArg(idx++, static_cast(round_w)); kernel_.setArg(idx++, static_cast(paddings[0] / 2)); kernel_.setArg(idx++, static_cast(paddings[1] / 2)); + + input_shape_ = input_tensor->shape(); } const uint32_t gws[2] = { @@ -79,11 +83,6 @@ void WinogradInverseTransformFunctor::operator()( const Tensor *bias, Tensor *output_tensor, StatsFuture *future) { - std::vector output_shape = {batch_, height_, width_, - input_tensor->dim(1)}; - std::vector image_shape; - CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, image_shape); - output_tensor->ResizeImage(output_shape, image_shape); if (kernel_.get() == nullptr) { std::string obfuscated_kernel_name = @@ -121,6 +120,13 @@ void WinogradInverseTransformFunctor::operator()( auto runtime = OpenCLRuntime::Global(); kernel_ = runtime->BuildKernel("winograd_transform", obfuscated_kernel_name, built_options); + } + if (!IsVecEqual(input_shape_, input_tensor->shape())) { + std::vector output_shape = {batch_, height_, width_, + input_tensor->dim(1)}; + std::vector image_shape; + CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, image_shape); + output_tensor->ResizeImage(output_shape, image_shape); const uint32_t round_h = (height_ + 1) / 2; const uint32_t round_w = (width_ + 1) / 2; @@ -139,6 +145,8 @@ void WinogradInverseTransformFunctor::operator()( kernel_.setArg(idx++, static_cast(round_h * round_w)); kernel_.setArg(idx++, static_cast(round_w)); kernel_.setArg(idx++, relux_max_limit_); + + input_shape_ = input_tensor->shape(); } const uint32_t gws[2] = { diff --git a/mace/kernels/pooling.h b/mace/kernels/pooling.h index 6bd5d94e..bc9892e5 100644 --- a/mace/kernels/pooling.h +++ b/mace/kernels/pooling.h @@ -182,6 +182,7 @@ struct PoolingFunctor : PoolingFunctorBase { StatsFuture *future); cl::Kernel kernel_; + std::vector input_shape_; }; } // namespace kernels diff --git a/mace/kernels/resize_bilinear.h b/mace/kernels/resize_bilinear.h index bdd94192..52c1da10 100644 --- a/mace/kernels/resize_bilinear.h +++ b/mace/kernels/resize_bilinear.h @@ -172,6 +172,7 @@ struct ResizeBilinearFunctor void operator()(const Tensor *input, Tensor *output, StatsFuture *future); cl::Kernel kernel_; + std::vector input_shape_; }; } // namespace kernels diff --git a/mace/kernels/softmax.h b/mace/kernels/softmax.h index 7ff375d3..d5bc5717 100644 --- a/mace/kernels/softmax.h +++ b/mace/kernels/softmax.h @@ -57,6 +57,7 @@ struct SoftmaxFunctor { void operator()(const Tensor *logits, Tensor *output, StatsFuture *future); cl::Kernel kernel_; + std::vector input_shape_; }; } // namepsace kernels diff --git a/mace/kernels/space_to_batch.h b/mace/kernels/space_to_batch.h index 402bf97c..ef7467b5 100644 --- a/mace/kernels/space_to_batch.h +++ b/mace/kernels/space_to_batch.h @@ -54,6 +54,7 @@ struct SpaceToBatchFunctor : SpaceToBatchFunctorBase { StatsFuture *future); cl::Kernel kernel_; + std::vector space_shape_; }; } // namespace kernels diff --git a/mace/kernels/winograd_transform.h b/mace/kernels/winograd_transform.h index 464a59ce..f3b7f7d6 100644 --- a/mace/kernels/winograd_transform.h +++ b/mace/kernels/winograd_transform.h @@ -49,6 +49,7 @@ struct WinogradTransformFunctor void operator()(const Tensor *input, Tensor *output, StatsFuture *future); cl::Kernel kernel_; + std::vector input_shape_; }; struct WinogradInverseTransformFunctorBase { @@ -105,6 +106,7 @@ struct WinogradInverseTransformFunctor StatsFuture *future); cl::Kernel kernel_; + std::vector input_shape_; }; } // namespace kernels diff --git a/tools/wino_conv.py b/tools/wino_conv.py index 383def86..0dc3f8d6 100644 --- a/tools/wino_conv.py +++ b/tools/wino_conv.py @@ -96,7 +96,7 @@ def output_shape(input_shape, filter_shape): return out_shape -def winog_conv(m, r, input, filter): +def winograd_conv(m, r, input, filter): alpha = m + r - 1 print 'Winograd(m = %d, r = %d, tile size=%d' % (m, r, alpha) alpha_square = alpha * alpha @@ -194,14 +194,14 @@ def main(): # filter.tofile("filter_in") for i in [2, 4, 6]: print "==========f(%d,3)==========" % i - winog_out = winog_conv(i, 3, input, filter) - res = np.allclose(tf_out, winog_out) + winograd_out = winograd_conv(i, 3, input, filter) + res = np.allclose(tf_out, winograd_out) if res: print "=========Pass=========" else: print "=========Failed=======" print "TF: ", tf_out - print "Winograd: ", winog_out + print "Winograd: ", winograd_out if __name__ == '__main__': -- GitLab