diff --git a/mace/core/macros.h b/mace/core/macros.h
index ced106e58aee1ff5f249007ff36255d456b1fc7b..0c00af5df0381e3e51bc6c0663588273aa3b8d38 100644
--- a/mace/core/macros.h
+++ b/mace/core/macros.h
@@ -17,4 +17,6 @@
 #define MACE_PREDICT_TRUE(x) (x)
 #endif
 
+#define MACE_UNUSED(var) (void)(var)
+
 #endif  // MACE_CORE_MACROS_H_
diff --git a/mace/kernels/opencl/cl/space_to_batch.cl b/mace/kernels/opencl/cl/space_to_batch.cl
new file mode 100644
index 0000000000000000000000000000000000000000..fa432dd87e3e1229695192e137e2dce2fd487a7b
--- /dev/null
+++ b/mace/kernels/opencl/cl/space_to_batch.cl
@@ -0,0 +1,39 @@
+void kernel space_to_batch(global float* space_data_ptr,
+                           private const int space_batch,
+                           private const int space_channel,
+                           private const int space_height,
+                           private const int space_width,
+                           private const int block_height,
+                           private const int block_width,
+                           private const int b2s,
+                           global float* batch_data_ptr) {
+  int batch_idx = get_global_id(0);
+  int batch_channel_idx = get_global_id(1);
+  int batch_pixel_idx = get_global_id(2);
+
+  const int batch_height = space_height / block_height;
+  const int batch_width = space_width / block_width;
+  const int batch_pixel_height_idx = batch_pixel_idx / batch_width;
+  const int batch_pixel_width_idx = batch_pixel_idx % batch_width;
+
+  const int block_size = block_height * block_width;
+  const int space_idx = batch_idx / block_size;
+  const int remaining_batch_idx = batch_idx % block_size;
+  const int space_pixel_height_idx = (remaining_batch_idx / block_width) +
+                                     batch_pixel_height_idx * block_height;
+  const int space_pixel_width_idx = (remaining_batch_idx % block_width) +
+                                    batch_pixel_width_idx * block_width;
+  const int batch_data_offset = batch_idx * (space_channel * batch_height * batch_width) +
+                                (batch_channel_idx * batch_height * batch_width) +
+                                batch_pixel_height_idx * batch_width +
+                                batch_pixel_width_idx;
+  const int space_data_offset = space_idx * (space_channel * space_height * space_width) +
+                                (batch_channel_idx * space_height * space_width) +
+                                space_pixel_height_idx * space_width +
+                                space_pixel_width_idx;
+  if (b2s) {
+    *(space_data_ptr + space_data_offset) = *(batch_data_ptr + batch_data_offset);
+  } else {
+    *(batch_data_ptr + batch_data_offset) = *(space_data_ptr + space_data_offset);
+  }
+}
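The kernel above maps each (batch index, channel, pixel) coordinate of the batch-form tensor back to a coordinate of the space-form tensor, and copies in one direction or the other depending on the b2s flag. A minimal host-side C++ sketch of the same index arithmetic (names are illustrative, not part of the patch) is handy as a plain reference when checking the kernel:

    #include <cstddef>
    #include <vector>

    // space: [N, C, H, W] -> batch: [N * bh * bw, C, H / bh, W / bw]
    // Mirrors the offset arithmetic of space_to_batch.cl (b2s == 0 direction).
    void SpaceToBatchRef(const std::vector<float> &space, int N, int C, int H, int W,
                         int bh, int bw, std::vector<float> *batch) {
      const int BH = H / bh, BW = W / bw;          // per-batch spatial dims
      batch->resize(static_cast<size_t>(N) * bh * bw * C * BH * BW);
      for (int b = 0; b < N * bh * bw; ++b) {      // batch index
        const int n = b / (bh * bw);               // originating space batch
        const int r = b % (bh * bw);               // offset inside the block
        for (int c = 0; c < C; ++c) {
          for (int ph = 0; ph < BH; ++ph) {
            for (int pw = 0; pw < BW; ++pw) {
              const int sh = r / bw + ph * bh;     // space row, as in the kernel
              const int sw = r % bw + pw * bw;     // space column
              (*batch)[((b * C + c) * BH + ph) * BW + pw] =
                  space[((n * C + c) * H + sh) * W + sw];
            }
          }
        }
      }
    }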
diff --git a/mace/kernels/opencl/conv_2d_opencl.cc b/mace/kernels/opencl/conv_2d_opencl.cc
index fcdb3de208fa6da2997dc391d94f514107bc60d7..11099d9d1280a388840bc8a7a7f3e174cc207205 100644
--- a/mace/kernels/opencl/conv_2d_opencl.cc
+++ b/mace/kernels/opencl/conv_2d_opencl.cc
@@ -8,20 +8,24 @@ namespace mace {
 namespace kernels {
 
 extern void Conv2dOpenclK1x1S1(const Tensor *input, const Tensor *filter,
-                               const Tensor *bias, Tensor *output);
+                               const Tensor *bias, const int dilation_height,
+                               const int dilation_width, Tensor *output);
 
 extern void Conv2dOpenclK3x3S1(const Tensor *input, const Tensor *filter,
-                               const Tensor *bias, Tensor *output);
+                               const Tensor *bias, const int dilation_height,
+                               const int dilation_width, Tensor *output);
 
 extern void Conv2dOpenclK3x3S2(const Tensor *input, const Tensor *filter,
-                               const Tensor *bias, Tensor *output);
+                               const Tensor *bias, const int dilation_height,
+                               const int dilation_width, Tensor *output);
 
 template <>
 void Conv2dFunctor<DeviceType::OPENCL, float>::operator()(const Tensor *input,
                                                           const Tensor *filter,
                                                           const Tensor *bias,
                                                           Tensor *output) {
   typedef void (*Conv2dOpenclFunction)(const Tensor *input, const Tensor *filter,
-                                       const Tensor *bias, Tensor *output);
+                                       const Tensor *bias, const int dilation_height,
+                                       const int dilation_width, Tensor *output);
   // Selection matrix: kernel_size x stride_size
   static const Conv2dOpenclFunction selector[5][2] = {
       {Conv2dOpenclK1x1S1, nullptr},
@@ -33,8 +37,7 @@ void Conv2dFunctor<DeviceType::OPENCL, float>::operator()(const Tensor *input,
   index_t kernel_h = filter->shape()[2];
   index_t kernel_w = filter->shape()[3];
   if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] ||
-      strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 ||
-      selector[kernel_h - 1][strides_[0] - 1] == nullptr) {
+      strides_[0] > 2 || selector[kernel_h - 1][strides_[0] - 1] == nullptr) {
     LOG(WARNING) << "OpenCL conv2d kernel with "
                  << "filter" << kernel_h << "x" << kernel_w << ","
                  << " stride " << strides_[0] << "x" << strides_[1]
@@ -50,9 +53,9 @@ void Conv2dFunctor<DeviceType::OPENCL, float>::operator()(const Tensor *input,
     Tensor::MappingGuard input_mapper(input);
     ConstructInputWithPadding(input->data<float>(), input->shape().data(),
                               paddings_.data(), &padded_input);
-    conv2d_func(&padded_input, filter, bias, output);
+    conv2d_func(&padded_input, filter, bias, dilations_[0], dilations_[1], output);
   } else {
-    conv2d_func(input, filter, bias, output);
+    conv2d_func(input, filter, bias, dilations_[0], dilations_[1], output);
   }
 }
 
diff --git a/mace/kernels/opencl/conv_2d_opencl_1x1.cc b/mace/kernels/opencl/conv_2d_opencl_1x1.cc
index 130ca4b7ab166d238c36484a5452565c119dde94..424f2a95523770c0b5ace8321ff9e14414cd9905 100644
--- a/mace/kernels/opencl/conv_2d_opencl_1x1.cc
+++ b/mace/kernels/opencl/conv_2d_opencl_1x1.cc
@@ -7,6 +7,7 @@
 #include "mace/core/runtime/opencl/cl2_header.h"
 #include "mace/core/runtime/opencl/opencl_runtime.h"
 #include "mace/utils/utils.h"
+#include "mace/core/macros.h"
 
 namespace mace {
 namespace kernels {
@@ -173,7 +174,11 @@ void Conv1x1V3(const Tensor *input,
 extern void Conv2dOpenclK1x1S1(const Tensor *input,
                                const Tensor *filter,
                                const Tensor *bias,
+                               const int dilation_height,
+                               const int dilation_width,
                                Tensor *output) {
+  MACE_UNUSED(dilation_height);
+  MACE_UNUSED(dilation_width);
   const index_t batch = output->shape()[0];
   const index_t height = output->shape()[2];
   const index_t width = output->shape()[3];
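The 1x1 path can ignore dilation entirely: a 1x1 filter has a single tap, so its dilated extent is (1 - 1) * d + 1 = 1 for any rate d. MACE_UNUSED expands to a cast-to-void, the usual way to silence -Wunused-parameter for such arguments. An illustrative expansion (not part of the patch):

    // What MACE_UNUSED(dilation_height) expands to inside the function body.
    void Example(int dilation_height) {
      (void)(dilation_height);  // names the parameter, emits no code, no warning
    }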
diff --git a/mace/kernels/opencl/conv_2d_opencl_3x3.cc b/mace/kernels/opencl/conv_2d_opencl_3x3.cc
index 41dccf4c4ef9220ae7822df5f817705ed9ffcbd0..4738d165cdb5c455eda43053ecfd1924d14024e5 100644
--- a/mace/kernels/opencl/conv_2d_opencl_3x3.cc
+++ b/mace/kernels/opencl/conv_2d_opencl_3x3.cc
@@ -3,14 +3,19 @@
 //
 
 #include "mace/core/common.h"
+#include "mace/core/macros.h"
 #include "mace/core/runtime/opencl/opencl_runtime.h"
 #include "mace/kernels/conv_2d.h"
+#include "mace/kernels/opencl/space_to_batch.h"
 
 namespace mace {
 namespace kernels {
+
 static void InnerConv2dK3x3S12(const Tensor *input, const Tensor *filter,
-                               const Tensor *bias, const uint32_t stride, Tensor *output) {
+                               const Tensor *bias, const uint32_t stride,
+                               Tensor *output, const std::vector<cl::Event> *waiting_events,
+                               cl::Event *ret_event) {
   const index_t channels = output->shape()[1];
   const index_t height = output->shape()[2];
   const index_t width = output->shape()[3];
@@ -46,18 +51,75 @@ static void InnerConv2dK3x3S12(const Tensor *input, const Tensor *filter,
   cl_int error = runtime->command_queue().enqueueNDRangeKernel(
       conv_kernel, cl::NullRange,
       cl::NDRange(gws[0], gws[1], gws[2]),
-      cl::NDRange(lws[0], lws[1], lws[2]));
+      cl::NDRange(lws[0], lws[1], lws[2]),
+      waiting_events,
+      ret_event);
   MACE_CHECK(error == CL_SUCCESS);
 }
 
+static void CalOutputShape(const std::vector<index_t> &input_shape,
+                           const std::vector<index_t> &filter_shape,
+                           const int dilation_height,
+                           const int dilation_width,
+                           std::vector<index_t> &output_shape) {
+  index_t kernel_height = filter_shape[2];
+  index_t kernel_width = filter_shape[3];
+  index_t output_channels = filter_shape[0];
+
+  index_t k_extent_height = (kernel_height - 1) * dilation_height + 1;
+  index_t k_extent_width = (kernel_width - 1) * dilation_width + 1;
+  index_t output_height = input_shape[2] - k_extent_height + 1;
+  index_t output_width = input_shape[3] - k_extent_width + 1;
+  output_shape[0] = input_shape[0];
+  output_shape[1] = output_channels;
+  output_shape[2] = output_height;
+  output_shape[3] = output_width;
+}
+
+static void ResizeBatchTensor(const std::vector<index_t> &input_shape,
+                              const int dilation_height,
+                              const int dilation_width,
+                              Tensor *batch_tensor) {
+  LOG(INFO) << input_shape[2] << "\t" << input_shape[3] << "\t"
+            << dilation_height << "\t" << dilation_width;
+  batch_tensor->Resize({input_shape[0] * dilation_height * dilation_width,
+                        input_shape[1],
+                        input_shape[2] / dilation_height,
+                        input_shape[3] / dilation_width});
+  LOG(INFO) << batch_tensor->dim(2) << "\t" << batch_tensor->dim(3) << "\t"
+            << dilation_height << "\t" << dilation_width;
+}
+
 void Conv2dOpenclK3x3S1(const Tensor *input, const Tensor *filter,
-                        const Tensor *bias, Tensor *output) {
-  InnerConv2dK3x3S12(input, filter, bias, 1, output);
+                        const Tensor *bias, const int dilation_height,
+                        const int dilation_width, Tensor *output) {
+  if (dilation_height > 1 && dilation_width > 1) {
+    cl::Event events[2];
+
+    Tensor reshaped_input_tensor(GetDeviceAllocator(DeviceType::OPENCL), input->dtype());
+    ResizeBatchTensor(input->shape(), dilation_height, dilation_width, &reshaped_input_tensor);
+    SpaceToBatch<false>(const_cast<Tensor *>(input), dilation_height, dilation_width,
+                        &reshaped_input_tensor, nullptr, &events[0]);
+
+    Tensor reshaped_output_tensor(GetDeviceAllocator(DeviceType::OPENCL), input->dtype());
+    std::vector<index_t> reshaped_output_shape(4, 0);
+    // The batch-form conv runs with stride 1 and no dilation, so the
+    // batch-form output shape is computed with dilation 1.
+    CalOutputShape(reshaped_input_tensor.shape(), filter->shape(),
+                   1, 1, reshaped_output_shape);
+    reshaped_output_tensor.Resize(reshaped_output_shape);
+    std::vector<cl::Event> s2b_events(1, events[0]);
+    InnerConv2dK3x3S12(&reshaped_input_tensor, filter, bias, 1, &reshaped_output_tensor,
+                       &s2b_events, &events[1]);
+
+    std::vector<cl::Event> conv_events(1, events[1]);
+    SpaceToBatch<true>(output, dilation_height, dilation_width,
+                       &reshaped_output_tensor, &conv_events, nullptr);
+  } else {
+    InnerConv2dK3x3S12(input, filter, bias, 1, output, nullptr, nullptr);
+  }
 };
 
 void Conv2dOpenclK3x3S2(const Tensor *input, const Tensor *filter,
-                        const Tensor *bias, Tensor *output) {
-  InnerConv2dK3x3S12(input, filter, bias, 2, output);
+                        const Tensor *bias, const int dilation_height,
+                        const int dilation_width, Tensor *output) {
+  MACE_UNUSED(dilation_height);
+  MACE_UNUSED(dilation_width);
+  InnerConv2dK3x3S12(input, filter, bias, 2, output, nullptr, nullptr);
 };
 
 }  // namespace kernels
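The dilated path relies on the standard space-to-batch identity: a 3x3 convolution with dilation d over [N, C, H, W] equals space-to-batch with block d, a plain stride-1 3x3 convolution on the batch form, then the inverse transform. A minimal sketch of the shape bookkeeping, assuming VALID padding and H, W divisible by d (helper names are illustrative, not MACE APIs):

    #include <array>
    #include <cassert>

    // Direct dilated-conv output: input - ((k - 1) * d + 1) + 1 per spatial dim.
    std::array<int, 2> DilatedConvOutputHW(int H, int W, int k, int d) {
      return {H - (k - 1) * d, W - (k - 1) * d};
    }

    void CheckSpaceToBatchEquivalence(int H, int W, int d) {
      const int k = 3;
      assert(H % d == 0 && W % d == 0);
      // space_to_batch: [N, C, H, W] -> [N*d*d, C, H/d, W/d]
      const int bh = H / d, bw = W / d;
      // stride-1, dilation-1 conv on the batch form shrinks each dim by k - 1
      const int ch = bh - (k - 1), cw = bw - (k - 1);
      // batch_to_space scales the spatial dims back up by d
      const auto direct = DilatedConvOutputHW(H, W, k, d);
      assert(ch * d == direct[0] && cw * d == direct[1]);
    }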
diff --git a/mace/kernels/opencl/space_to_batch.h b/mace/kernels/opencl/space_to_batch.h
new file mode 100644
index 0000000000000000000000000000000000000000..0dd393b882bcbcb6a23f81cd6817bfa20034563a
--- /dev/null
+++ b/mace/kernels/opencl/space_to_batch.h
@@ -0,0 +1,54 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_KERNELS_OPENCL_SPACE_TO_BATCH_H_
+#define MACE_KERNELS_OPENCL_SPACE_TO_BATCH_H_
+
+#include "mace/core/common.h"
+#include "mace/core/runtime/opencl/opencl_runtime.h"
+#include "mace/core/tensor.h"
+
+namespace mace {
+namespace kernels {
+
+template <bool B2S>
+void SpaceToBatch(Tensor *space_tensor,
+                  const int block_height,
+                  const int block_width,
+                  Tensor *batch_tensor,
+                  const std::vector<cl::Event> *waiting_events,
+                  cl::Event *event) {
+  auto runtime = OpenCLRuntime::Get();
+  auto program = runtime->program();
+  auto s2b_kernel = cl::Kernel(program, "space_to_batch");
+
+  uint32_t idx = 0;
+  s2b_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(space_tensor->buffer())));
+  s2b_kernel.setArg(idx++, static_cast<int>(space_tensor->dim(0)));
+  s2b_kernel.setArg(idx++, static_cast<int>(space_tensor->dim(1)));
+  s2b_kernel.setArg(idx++, static_cast<int>(space_tensor->dim(2)));
+  s2b_kernel.setArg(idx++, static_cast<int>(space_tensor->dim(3)));
+  s2b_kernel.setArg(idx++, block_height);
+  s2b_kernel.setArg(idx++, block_width);
+  s2b_kernel.setArg(idx++, static_cast<int>(B2S));
+  s2b_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(batch_tensor->buffer())));
+
+  const uint32_t gws[3] = {static_cast<uint32_t>(batch_tensor->dim(0)),
+                           static_cast<uint32_t>(batch_tensor->dim(1)),
+                           static_cast<uint32_t>(batch_tensor->dim(2) * batch_tensor->dim(3))};
+  const uint32_t lws[3] = {static_cast<uint32_t>(1),
+                           static_cast<uint32_t>(8),
+                           static_cast<uint32_t>(128)};
+  cl_int error = runtime->command_queue().enqueueNDRangeKernel(
+      s2b_kernel, cl::NullRange,
+      cl::NDRange(gws[0], gws[1], gws[2]),
+      cl::NDRange(lws[0], lws[1], lws[2]),
+      waiting_events,
+      event);
+  MACE_CHECK(error == CL_SUCCESS);
+}
+
+}  // namespace kernels
+}  // namespace mace
+#endif  // MACE_KERNELS_OPENCL_SPACE_TO_BATCH_H_
diff --git a/mace/ops/BUILD b/mace/ops/BUILD
index e823136d965670cc198ed14c0f682cbf0b152d00..0f5bdbaae97c365268460b60774f7e5f56d28b6a 100644
--- a/mace/ops/BUILD
+++ b/mace/ops/BUILD
@@ -62,6 +62,35 @@ cc_test(
     ],
 )
 
+cc_test(
+    name = "space_to_batch_test",
+    testonly = 1,
+    srcs = glob(["space_to_batch_test.cc"]),
+    copts = ["-std=c++11"],
+    linkopts = if_android(["-pie"]),
+    linkstatic = 1,
+    deps = [
+        "//mace/kernels",
+        "//mace/core",
+        "//mace/ops:test",
+        "@gtest//:gtest_main",
+    ],
+)
+
+cc_test(
+    name = "conv_atrous_2d_test",
+    testonly = 1,
+    srcs = glob(["conv_atrous_2d_test.cc"]),
+    copts = ["-std=c++11"],
+    linkopts = ["-fopenmp"] + if_android(["-ldl"]),
+    linkstatic = 1,
+    deps = [
+        ":ops",
+        ":test",
+        "@gtest//:gtest_main",
+    ],
+)
+
 cc_test(
     name = "ops_benchmark",
     testonly = 1,
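Assuming a standard MACE checkout, the two new targets run with the usual Bazel invocation (Android cross-compilation flags omitted here):

    bazel test //mace/ops:space_to_batch_test
    bazel test //mace/ops:conv_atrous_2d_test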
diff --git a/mace/ops/conv_atrous_2d_test.cc b/mace/ops/conv_atrous_2d_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6d4a23a3adef1c5458e88b4445aa252924b4399d
--- /dev/null
+++ b/mace/ops/conv_atrous_2d_test.cc
@@ -0,0 +1,208 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/ops/ops_test_util.h"
+#include "mace/kernels/conv_pool_2d_util.h"
+
+using namespace mace;
+
+class AtrousConv2dOpTest : public OpsTestBase {};
+
+static void UpSampleFilter(const std::vector<index_t> &filter_shape,
+                           const std::vector<float> &filter_data,
+                           const int dilation_rate,
+                           std::vector<index_t> &upsampled_filter_shape,
+                           std::vector<float> &upsampled_filter_data) {
+  upsampled_filter_shape[0] = filter_shape[0];
+  upsampled_filter_shape[1] = filter_shape[1];
+  upsampled_filter_shape[2] = filter_shape[2] + (filter_shape[2] - 1) * (dilation_rate - 1);
+  upsampled_filter_shape[3] = filter_shape[3] + (filter_shape[3] - 1) * (dilation_rate - 1);
+  const index_t upsampled_filter_size = std::accumulate(upsampled_filter_shape.begin(),
+                                                        upsampled_filter_shape.end(),
+                                                        1, std::multiplies<index_t>());
+  upsampled_filter_data.resize(upsampled_filter_size, 0);
+  index_t filter_idx = 0;
+  index_t upsampled_filter_idx = 0;
+  for (index_t n = 0; n < filter_shape[0]; ++n) {
+    for (index_t c = 0; c < filter_shape[1]; ++c) {
+      for (index_t h = 0; h < filter_shape[2]; ++h) {
+        for (index_t w = 0; w < filter_shape[3]; ++w) {
+          upsampled_filter_data[upsampled_filter_idx] = filter_data[filter_idx];
+          filter_idx += 1;
+          upsampled_filter_idx += dilation_rate;
+        }
+        upsampled_filter_idx += 1 - dilation_rate + (dilation_rate - 1) * upsampled_filter_shape[3];
+      }
+      upsampled_filter_idx -= (dilation_rate - 1) * upsampled_filter_shape[3];
+    }
+  }
+}
+
+template <DeviceType D>
+static void RunConv2D(const std::vector<index_t> &input_shape,
+                      const std::vector<float> &input_data,
+                      const std::vector<index_t> &filter_shape,
+                      const std::vector<float> &filter_data,
+                      const std::vector<index_t> &bias_shape,
+                      const std::vector<float> &bias_data,
+                      const int dilation_h,
+                      const int dilation_w,
+                      Padding padding,
+                      Tensor *result) {
+  OpsTestNet net;
+  OpDefBuilder("Conv2D", "Conv2dTest")
+      .Input("Input")
+      .Input("Filter")
+      .Input("Bias")
+      .Output("Output")
+      .AddIntsArg("strides", {1, 1})
+      .AddIntArg("padding", padding)
+      .AddIntsArg("dilations", {dilation_h, dilation_w})
+      .Finalize(net.NewOperatorDef());
+
+  // Add input data
+  net.AddInputFromArray<D, float>("Input", input_shape, input_data);
+  net.AddInputFromArray<D, float>("Filter", filter_shape, filter_data);
+  net.AddInputFromArray<D, float>("Bias", bias_shape, bias_data);
+
+  // Run
+  net.RunOp(D);
+
+  // Check
+  result->Copy(*net.GetOutput("Output"));
+}
+
+template <DeviceType D, typename T>
+static void GenerateAndRunConv2D(const index_t batch,
+                                 const index_t input_channels,
+                                 const index_t height,
+                                 const index_t width,
+                                 const index_t output_channels,
+                                 const index_t kernel_h,
+                                 const index_t kernel_w,
+                                 Padding padding,
+                                 const int dilation_rate) {
+  srand(time(NULL));
+  // Add input data
+  std::vector<index_t> input_shape = {batch, input_channels, height, width};
+  std::vector<float> input_data;
+  GenerateRandomRealTypeData(input_shape, input_data);
+  std::vector<index_t> filter_shape = {output_channels, input_channels, kernel_h, kernel_w};
+  std::vector<float> filter_data;
+  GenerateRandomRealTypeData(filter_shape, filter_data);
+  std::vector<index_t> bias_shape = {output_channels};
+  std::vector<float> bias_data;
+  GenerateRandomRealTypeData(bias_shape, bias_data);
+
+  std::vector<index_t> upsampled_filter_shape(4, 0);
+  std::vector<float> upsampled_filter_data;
+  UpSampleFilter(filter_shape, filter_data, dilation_rate,
+                 upsampled_filter_shape, upsampled_filter_data);
+  Tensor expected_result;
+  // Run on cpu
+  RunConv2D<DeviceType::CPU>(input_shape, input_data,
+                             upsampled_filter_shape, upsampled_filter_data,
+                             bias_shape, bias_data,
+                             1, 1,
+                             padding, &expected_result);
+
+  Tensor device_result(GetDeviceAllocator(D), DataTypeToEnum<T>::v());
+  // Run on device
+  RunConv2D<D>(input_shape, input_data,
+               filter_shape, filter_data,
+               bias_shape, bias_data,
+               dilation_rate, dilation_rate,
+               padding, &device_result);
+  ExpectTensorNear<T>(expected_result, device_result, 0.001);
+}
+
+template <DeviceType D>
+static void TestSimple(const int kernel_h,
+                       const int kernel_w,
+                       Padding padding,
+                       const int dilation_rate) {
+  GenerateAndRunConv2D<D, float>(1, 3, 5, 5, 1, kernel_h, kernel_w, padding, dilation_rate);
+}
+
+TEST_F(AtrousConv2dOpTest, CPUSimple) {
+  for (int i = 2; i < 4; ++i) {
+    TestSimple<DeviceType::CPU>(3, 3, VALID, i);
+    TestSimple<DeviceType::CPU>(3, 3, SAME, i);
+  }
+}
+
+TEST_F(AtrousConv2dOpTest, OPENCLSimple) {
+  for (int i = 2; i < 3; ++i) {
+    TestSimple<DeviceType::OPENCL>(3, 3, VALID, i);
+  }
+}
+
+template <DeviceType D>
+static void TestAligned(const int kernel_h,
+                        const int kernel_w,
+                        Padding padding,
+                        const int dilation_rate) {
+  GenerateAndRunConv2D<D, float>(3, 64, 32, 32, 128, kernel_h, kernel_w, padding, dilation_rate);
+}
+
+template <DeviceType D>
+static void TestUnAligned(const int kernel_h,
+                          const int kernel_w,
+                          Padding padding,
+                          const int dilation_rate) {
+  srand(time(NULL));
+  // generate random input
+  index_t batch = 3 + rand() % 10;
+  index_t input_channels = 3 + rand() % 10;
+  index_t height = 107;
+  index_t width = 113;
+  index_t output_channels = 3 + rand() % 10;
+
+  GenerateAndRunConv2D<D, float>(batch, input_channels, height, width, output_channels,
+                                 kernel_h, kernel_w, padding, dilation_rate);
+}
+
+TEST_F(AtrousConv2dOpTest, UpSample) {
+  const int batch = 2;
+  const int channel = 2;
+  const int height = 3;
+  const int width = 3;
+  const int rate = 2;
+  std::vector<index_t> filter_shape = {batch, channel, height, width};
+  std::vector<float> filter_data(batch * channel * height * width, 1);
+  std::vector<index_t> upsampled_filter_shape(4, 0);
+  std::vector<float> upsampled_filter_data;
+  UpSampleFilter(filter_shape, filter_data, rate,
+                 upsampled_filter_shape, upsampled_filter_data);
+  int size = std::accumulate(upsampled_filter_shape.begin(),
+                             upsampled_filter_shape.end(),
+                             1, std::multiplies<index_t>());
+  const int expected_size = batch * channel *
+                            (height + (height - 1) * (rate - 1)) *
+                            (width + (width - 1) * (rate - 1));
+  EXPECT_EQ(expected_size, size);
+  EXPECT_EQ(expected_size, upsampled_filter_data.size());
+}
+
+TEST_F(AtrousConv2dOpTest, CPUAligned) {
+  for (int i = 2; i < 4; ++i) {
+    TestAligned<DeviceType::CPU>(3, 3, VALID, i);
+    TestAligned<DeviceType::CPU>(3, 3, SAME, i);
+  }
+}
+
+TEST_F(AtrousConv2dOpTest, OPENCLAligned) {
+  for (int i = 2; i < 4; ++i) {
+    TestAligned<DeviceType::OPENCL>(3, 3, VALID, i);
+    TestAligned<DeviceType::OPENCL>(3, 3, SAME, i);
+  }
+}
+
+TEST_F(AtrousConv2dOpTest, CPUUnAligned) {
+  for (int i = 2; i < 4; ++i) {
+    TestUnAligned<DeviceType::CPU>(3, 3, VALID, i);
+    TestUnAligned<DeviceType::CPU>(3, 3, SAME, i);
+  }
+}
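The UpSample test only checks the upsampled size; the identity it rests on is that a k x k filter with dilation rate r acts like a zero-stuffed filter of size k + (k - 1)(r - 1). A standalone illustration for k = 3, r = 2 (not part of the test suite):

    #include <cstdio>

    int main() {
      const int k = 3, rate = 2;
      const int uk = k + (k - 1) * (rate - 1);   // 5
      float up[5][5] = {};                       // zero-initialized
      for (int h = 0; h < k; ++h)
        for (int w = 0; w < k; ++w)
          up[h * rate][w * rate] = 1.0f;         // original taps land on a grid
      for (int h = 0; h < uk; ++h) {
        for (int w = 0; w < uk; ++w) std::printf("%.0f ", up[h][w]);
        std::printf("\n");                       // rows of taps alternate with zeros
      }
    }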
diff --git a/mace/ops/space_to_batch_test.cc b/mace/ops/space_to_batch_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7454a9ba84ab821f172060261e64fb22ac565ec3
--- /dev/null
+++ b/mace/ops/space_to_batch_test.cc
@@ -0,0 +1,108 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/kernels/opencl/space_to_batch.h"
+#include "gtest/gtest.h"
+#include "mace/ops/ops_test_util.h"
+
+using namespace mace;
+
+template <typename T>
+void TestBidirectionTransform(const std::vector<index_t> &space_shape,
+                              const std::vector<T> &space,
+                              const int block_height,
+                              const int block_width,
+                              const std::vector<index_t> &batch_shape,
+                              const std::vector<T> &batch) {
+  auto space_tensor = unique_ptr<Tensor>(new Tensor(GetDeviceAllocator(DeviceType::OPENCL),
+                                                    DataTypeToEnum<T>::v()));
+  space_tensor->Resize(space_shape);
+  {
+    Tensor::MappingGuard space_mapper(space_tensor.get());
+    T *space_data = space_tensor->mutable_data<T>();
+    MACE_CHECK(static_cast<size_t>(space_tensor->size()) == space.size())
+        << "Space tensor size:" << space_tensor->size()
+        << ", space data size:" << space.size();
+    memcpy(space_data, space.data(), space.size() * sizeof(T));
+  }
+
+  auto batch_tensor = unique_ptr<Tensor>(new Tensor(GetDeviceAllocator(DeviceType::OPENCL),
+                                                    DataTypeToEnum<T>::v()));
+  batch_tensor->Resize(batch_shape);
+  {
+    Tensor::MappingGuard batch_mapper(batch_tensor.get());
+    T *batch_data = batch_tensor->mutable_data<T>();
+    MACE_CHECK(static_cast<size_t>(batch_tensor->size()) == batch.size());
+    memcpy(batch_data, batch.data(), batch.size() * sizeof(T));
+  }
+
+  auto inner_batch_tensor = unique_ptr<Tensor>(new Tensor(GetDeviceAllocator(DeviceType::OPENCL),
+                                                          DataTypeToEnum<T>::v()));
+  inner_batch_tensor->Resize(batch_shape);
+  kernels::SpaceToBatch<false>(space_tensor.get(), block_height, block_width,
+                               inner_batch_tensor.get(), nullptr, nullptr);
+  ExpectTensorNear<T>(*batch_tensor, *inner_batch_tensor, 1e-8);
+
+  auto inner_space_tensor = unique_ptr<Tensor>(new Tensor(GetDeviceAllocator(DeviceType::OPENCL),
+                                                          DataTypeToEnum<T>::v()));
+  inner_space_tensor->Resize(space_shape);
+  kernels::SpaceToBatch<true>(inner_space_tensor.get(), block_height, block_width,
+                              batch_tensor.get(), nullptr, nullptr);
+  ExpectTensorNear<T>(*space_tensor, *inner_space_tensor, 1e-8);
+}
+
+TEST(SpaceToBatchTest, NoTransform) {
+  TestBidirectionTransform<float>({1, 1, 2, 2},
+                                  {1, 2, 3, 4},
+                                  1, 1,
+                                  {1, 1, 2, 2},
+                                  {1, 2, 3, 4});
+}
+
+TEST(SpaceToBatchTest, SmallData) {
+  TestBidirectionTransform<float>({1, 1, 2, 2},
+                                  {1, 2, 3, 4},
+                                  2, 2,
+                                  {4, 1, 1, 1},
+                                  {1, 2, 3, 4});
+}
+
+TEST(SpaceToBatchTest, MultiChannelData) {
+  TestBidirectionTransform<float>({1, 3, 2, 2},
+                                  {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
+                                  2, 2,
+                                  {4, 3, 1, 1},
+                                  {1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12});
+}
+
+TEST(SpaceToBatchTest, LargerMultiChannelData) {
+  TestBidirectionTransform<float>({1, 1, 4, 4},
+                                  {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
+                                  2, 2,
+                                  {4, 1, 2, 2},
+                                  {1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16});
+}
+
+TEST(SpaceToBatchTest, MultiBatchData) {
+  TestBidirectionTransform<float>({2, 1, 2, 4},
+                                  {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
+                                  2, 2,
+                                  {8, 1, 1, 2},
+                                  {1, 3, 2, 4, 5, 7, 6, 8, 9, 11, 10, 12, 13, 15, 14, 16});
+}
+
+TEST(SpaceToBatchTest, MultiBatchAndChannelData) {
+  TestBidirectionTransform<float>({2, 2, 2, 4},
+                                  {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+                                   17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32},
+                                  2, 2,
+                                  {8, 2, 1, 2},
+                                  {1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16,
+                                   17, 19, 25, 27, 18, 20, 26, 28, 21, 23, 29, 31, 22, 24, 30, 32});
+}
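The SmallData expectation can be worked out by hand: with a 2x2 block over a single 2x2 image, each output batch image is 1x1 and picks one block offset. A self-contained check mirroring the kernel's arithmetic (illustrative, not part of the patch):

    #include <cassert>

    int main() {
      const float space[4] = {1, 2, 3, 4};   // shape [1, 1, 2, 2], row-major
      float batch[4] = {};                   // shape [4, 1, 1, 1]
      const int bw = 2, W = 2;
      for (int b = 0; b < 4; ++b) {
        const int sh = b / bw, sw = b % bw;  // 1x1 batch image -> ph = pw = 0
        batch[b] = space[sh * W + sw];
      }
      // Matches the expected batch buffer {1, 2, 3, 4} in the SmallData test.
      assert(batch[0] == 1 && batch[1] == 2 && batch[2] == 3 && batch[3] == 4);
    }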