diff --git a/mace/core/macros.h b/mace/core/macros.h
index ced106e58aee1ff5f249007ff36255d456b1fc7b..0c00af5df0381e3e51bc6c0663588273aa3b8d38 100644
--- a/mace/core/macros.h
+++ b/mace/core/macros.h
@@ -17,4 +17,6 @@
 #define MACE_PREDICT_TRUE(x) (x)
 #endif
 
+#define MACE_UNUSED(var) (void)(var)
+
 #endif  // MACE_CORE_MACROS_H_
diff --git a/mace/kernels/opencl/cl/space_to_batch.cl b/mace/kernels/opencl/cl/space_to_batch.cl
new file mode 100644
index 0000000000000000000000000000000000000000..921a3bf82e402439198081884a9bfefb164b4161
--- /dev/null
+++ b/mace/kernels/opencl/cl/space_to_batch.cl
@@ -0,0 +1,55 @@
+void kernel space_to_batch(global float *space_data_ptr,
+                           global const int *block_shape_ptr,
+                           global const int *paddings_ptr,
+                           private const int space_batch,
+                           private const int space_channel,
+                           private const int space_height,
+                           private const int space_width,
+                           private const int batch_height,
+                           private const int batch_width,
+                           private const int b2s,
+                           global float *batch_data_ptr) {
+  int batch_idx = get_global_id(0);
+  int batch_channel_idx = get_global_id(1);
+  int batch_pixel_idx = get_global_id(2);
+
+  const int block_height = block_shape_ptr[0];
+  const int block_width = block_shape_ptr[1];
+  const int padding_height_start = paddings_ptr[0];
+  const int padding_width_start = paddings_ptr[2];
+
+  const int batch_pixel_height_idx = batch_pixel_idx / batch_width;
+  const int batch_pixel_width_idx = batch_pixel_idx % batch_width;
+
+  const int block_size = block_height * block_width;
+  const int space_idx = batch_idx / block_size;
+  const int remaining_batch_idx = batch_idx % block_size;
+  int space_pixel_height_idx = (remaining_batch_idx / block_width) +
+                               batch_pixel_height_idx * block_height;
+  int space_pixel_width_idx = (remaining_batch_idx % block_width) +
+                              batch_pixel_width_idx * block_width;
+
+  const int batch_data_offset = batch_idx * (space_channel * batch_height * batch_width) +
+                                (batch_channel_idx * batch_height * batch_width) +
+                                batch_pixel_height_idx * batch_width +
+                                batch_pixel_width_idx;
+
+  space_pixel_height_idx -= padding_height_start;
+  space_pixel_width_idx -= padding_width_start;
+  const int space_data_offset = space_idx * (space_channel * space_height * space_width) +
+                                (batch_channel_idx * space_height * space_width) +
+                                space_pixel_height_idx * space_width +
+                                space_pixel_width_idx;
+  if (space_pixel_height_idx < 0 || space_pixel_height_idx >= space_height ||
+      space_pixel_width_idx < 0 || space_pixel_width_idx >= space_width) {
+    if (!b2s) {
+      *(batch_data_ptr + batch_data_offset) = 0;
+    }
+  } else {
+    if (b2s) {
+      *(space_data_ptr + space_data_offset) = *(batch_data_ptr + batch_data_offset);
+    } else {
+      *(batch_data_ptr + batch_data_offset) = *(space_data_ptr + space_data_offset);
+    }
+  }
+}
diff --git a/mace/kernels/opencl/space_to_batch_opecl.cc b/mace/kernels/opencl/space_to_batch_opecl.cc
new file mode 100644
index 0000000000000000000000000000000000000000..84601492256ff1364cd581ec6b969b82f6489408
--- /dev/null
+++ b/mace/kernels/opencl/space_to_batch_opecl.cc
@@ -0,0 +1,53 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_KERNELS_OPENCL_SPACE_TO_BATCH_H_
+#define MACE_KERNELS_OPENCL_SPACE_TO_BATCH_H_
+
+#include "mace/core/common.h"
+#include "mace/core/runtime/opencl/opencl_runtime.h"
+#include "mace/kernels/space_to_batch.h"
+
+namespace mace {
+namespace kernels {
+
+template <>
+void SpaceToBatchFunctor<DeviceType::OPENCL, float>::operator()(Tensor *space_tensor,
+                                                                const Tensor *block_shape_tensor,
+                                                                const Tensor *paddings_tensor,
+                                                                Tensor *batch_tensor) {
+  auto runtime = OpenCLRuntime::Get();
+  auto program = runtime->program();
+  auto s2b_kernel = cl::Kernel(program, "space_to_batch");
+
+
+  uint32_t idx = 0;
+  s2b_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(space_tensor->buffer())));
+  s2b_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(block_shape_tensor->buffer())));
+  s2b_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(paddings_tensor->buffer())));
+  s2b_kernel.setArg(idx++, static_cast<int32_t>(space_tensor->dim(0)));
+  s2b_kernel.setArg(idx++, static_cast<int32_t>(space_tensor->dim(1)));
+  s2b_kernel.setArg(idx++, static_cast<int32_t>(space_tensor->dim(2)));
+  s2b_kernel.setArg(idx++, static_cast<int32_t>(space_tensor->dim(3)));
+  s2b_kernel.setArg(idx++, static_cast<int32_t>(batch_tensor->dim(2)));
+  s2b_kernel.setArg(idx++, static_cast<int32_t>(batch_tensor->dim(3)));
+  s2b_kernel.setArg(idx++, static_cast<int32_t>(b2s_));
+  s2b_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(batch_tensor->buffer())));
+
+  const uint32_t gws[3] = {static_cast<uint32_t>(batch_tensor->dim(0)),
+                           static_cast<uint32_t>(batch_tensor->dim(1)),
+                           static_cast<uint32_t>(batch_tensor->dim(2) * batch_tensor->dim(3))};
+  const uint32_t lws[3] = {static_cast<uint32_t>(1),
+                           static_cast<uint32_t>(8),
+                           static_cast<uint32_t>(128)};
+  cl_int error = runtime->command_queue().enqueueNDRangeKernel(
+      s2b_kernel, cl::NullRange,
+      cl::NDRange(gws[0], gws[1], gws[2]),
+      cl::NDRange(lws[0], lws[1], lws[2]));
+  MACE_CHECK(error == CL_SUCCESS);
+}
+
+}  // namespace kernels
+}  // namespace mace
+#endif  // MACE_KERNELS_OPENCL_SPACE_TO_BATCH_H_
diff --git a/mace/kernels/space_to_batch.h b/mace/kernels/space_to_batch.h
new file mode 100644
index 0000000000000000000000000000000000000000..ebf5994b6c84819ec2e08fb7fb45b2eecf7f072b
--- /dev/null
+++ b/mace/kernels/space_to_batch.h
@@ -0,0 +1,37 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_KERNELS_SPACE_TO_BATCH_H_
+#define MACE_KERNELS_SPACE_TO_BATCH_H_
+
+#include "mace/core/tensor.h"
+#include "mace/proto/mace.pb.h"
+
+namespace mace {
+namespace kernels {
+
+template <DeviceType D, typename T>
+struct SpaceToBatchFunctor {
+  SpaceToBatchFunctor(const bool b2s = false) : b2s_(b2s) {}
+
+  void operator()(Tensor *input_tensor,
+                  const Tensor *block_shape_tensor,
+                  const Tensor *paddings_tensor,
+                  Tensor *output_tensor) {
+    MACE_NOT_IMPLEMENTED;
+  }
+
+  bool b2s_;
+};
+
+template <>
+void SpaceToBatchFunctor<DeviceType::OPENCL, float>::operator()(Tensor *input_tensor,
+                                                                const Tensor *block_shape_tensor,
+                                                                const Tensor *paddings_tensor,
+                                                                Tensor *output);
+
+}  // namespace kernels
+}  // namespace mace
+
+#endif  // MACE_KERNELS_SPACE_TO_BATCH_H_
diff --git a/mace/ops/batch_to_space.cc b/mace/ops/batch_to_space.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fa5db7cd470683d97147ee5baf52fb98f3f4753c
--- /dev/null
+++ b/mace/ops/batch_to_space.cc
@@ -0,0 +1,11 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/ops/batch_to_space.h"
+
+namespace mace {
+
+REGISTER_OPENCL_OPERATOR(BatchToSpaceND, BatchToSpaceNDOp<float>);
+
+}  // namespace mace
diff --git a/mace/ops/batch_to_space.h b/mace/ops/batch_to_space.h
new file mode 100644
index 0000000000000000000000000000000000000000..14b6444553860935d4fe4add7e353727b0d74c96
--- /dev/null
+++ b/mace/ops/batch_to_space.h
@@ -0,0 +1,77 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_OPS_BATCH_TO_SPACE_H_
+#define MACE_OPS_BATCH_TO_SPACE_H_
+
+#include <vector>
+
+#include "mace/core/operator.h"
+#include "mace/kernels/space_to_batch.h"
+
+namespace mace {
+
+static void BatchToSpaceHelper(const Tensor *input_tensor,
+                               const Tensor *block_shape_tensor,
+                               const Tensor *cropped_tensor,
+                               Tensor *output) {
+  MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D");
+  MACE_CHECK(block_shape_tensor->dim_size() == 1, "Block's shape should be 1D");
+  MACE_CHECK(cropped_tensor->dim_size() == 2, "Crops' shape should be 2D");
+
+  const index_t block_dims = block_shape_tensor->dim(0);
+  MACE_CHECK(block_dims == cropped_tensor->dim(0) && 2 == cropped_tensor->dim(1));
+  // TODO: change tensor to attribute if needed based on the benchmark
+  Tensor::MappingGuard block_shape_tensor_mapper(block_shape_tensor);
+  Tensor::MappingGuard cropped_tensor_mapper(cropped_tensor);
+  const int *block_shape_ptr = block_shape_tensor->data<int>();
+  const int *cropped_ptr = cropped_tensor->data<int>();
+  std::vector<index_t> output_shape(4, 0);
+  index_t block_shape_product = 1;
+  for (uint32_t block_dim = 0; block_dim < block_dims; ++block_dim) {
+    MACE_CHECK(block_shape_ptr[block_dim] > 1, "block_shape's value should be greater than 1");
+    const index_t block_shape_value = block_shape_ptr[block_dim];
+    const index_t cropped_input_size = input_tensor->dim(block_dim + 2) * block_shape_value
+                                       - *cropped_ptr
+                                       - *(cropped_ptr + 1);
+    MACE_CHECK(cropped_input_size >= 0,
+               "cropped size must be non-negative");
+    block_shape_product *= block_shape_value;
+    output_shape[block_dim + 2] = cropped_input_size;
+    cropped_ptr += 2;
+  }
+  output_shape[0] = input_tensor->dim(0) / block_shape_product;
+  output_shape[1] = input_tensor->dim(1);
+
+  output->Resize(output_shape);
+}
+
+template <DeviceType D, typename T>
+class BatchToSpaceNDOp : public Operator<D, T> {
+ public:
+  BatchToSpaceNDOp(const OperatorDef &op_def, Workspace *ws)
+      : Operator<D, T>(op_def, ws), functor_(true) {}
+
+  bool Run() override {
+    const Tensor *input_tensor = this->Input(INPUT);
+    const Tensor *block_shape_tensor = this->Input(BLOCK_SHAPE);
+    const Tensor *cropped_tensor = this->Input(CROPS);
+    Tensor *output = this->Output(OUTPUT);
+
+    BatchToSpaceHelper(input_tensor, block_shape_tensor, cropped_tensor, output);
+    functor_(output, block_shape_tensor, cropped_tensor, const_cast<Tensor *>(input_tensor));
+    return true;
+  }
+
+ private:
+  kernels::SpaceToBatchFunctor<D, T> functor_;
+
+ protected:
+  OP_INPUT_TAGS(INPUT, BLOCK_SHAPE, CROPS);
+  OP_OUTPUT_TAGS(OUTPUT);
+};
+
+}  // namespace mace
+
+#endif  // MACE_OPS_BATCH_TO_SPACE_H_
diff --git a/mace/ops/batch_to_space_benchmark.cc b/mace/ops/batch_to_space_benchmark.cc
new file mode 100644
index 0000000000000000000000000000000000000000..89e100f1ebfe3ac61db1a34e4b5b4d446ec7d4d9
--- /dev/null
+++ b/mace/ops/batch_to_space_benchmark.cc
@@ -0,0 +1,56 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/core/operator.h"
+#include "mace/core/testing/test_benchmark.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+template <DeviceType D, typename T>
+static void BMBatchToSpace(
+    int iters, int batch, int channels, int height, int width) {
+  mace::testing::StopTiming();
+
+  OpsTestNet net;
+  OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest")
+      .Input("Input")
+      .Input("BlockShape")
+      .Input("Crops")
+      .Output("Output")
+      .Finalize(net.NewOperatorDef());
+
+  // Add input data
+  net.AddRandomInput<D, T>("Input", {batch, channels, height, width});
+  net.AddInputFromArray<D, int>(
+      "BlockShape", {2}, {2, 2});
+  net.AddInputFromArray<D, int>("Crops", {2, 2}, {0, 1, 0, 1});
+
+  // Warm-up
+  for (int i = 0; i < 5; ++i) {
+    net.RunOp(D);
+  }
+  net.Sync();
+
+  mace::testing::StartTiming();
+  while (iters--) {
+    net.RunOp(D);
+  }
+  net.Sync();
+}
+
+#define BM_BATCH_TO_SPACE_MACRO(N, C, H, W, TYPE, DEVICE)                  \
+  static void BM_BATCH_TO_SPACE_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
+      int iters) {                                                         \
+    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W;       \
+    mace::testing::ItemsProcessed(tot);                                    \
+    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));                    \
+    BMBatchToSpace<DEVICE, TYPE>(iters, N, C, H, W);                       \
+  }                                                                        \
+  BENCHMARK(BM_BATCH_TO_SPACE_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
+
+#define BM_BATCH_TO_SPACE(N, C, H, W, TYPE) \
+  BM_BATCH_TO_SPACE_MACRO(N, C, H, W, TYPE, OPENCL);
+
+BM_BATCH_TO_SPACE(128, 128, 8, 8, float);
+}  // namespace mace
\ No newline at end of file
diff --git a/mace/ops/space_to_batch.cc b/mace/ops/space_to_batch.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8a7af417768038f6cb66048a375bb6e5ff8fa402
--- /dev/null
+++ b/mace/ops/space_to_batch.cc
@@ -0,0 +1,11 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/ops/space_to_batch.h"
+
+namespace mace {
+
+REGISTER_OPENCL_OPERATOR(SpaceToBatchND, SpaceToBatchNDOp<float>);
+
+}  // namespace mace
diff --git a/mace/ops/space_to_batch.h b/mace/ops/space_to_batch.h
new file mode 100644
index 0000000000000000000000000000000000000000..079697d495a163912ca4443823fdca15f4d1bda2
--- /dev/null
+++ b/mace/ops/space_to_batch.h
@@ -0,0 +1,76 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_OPS_SPACE_TO_BATCH_H_
+#define MACE_OPS_SPACE_TO_BATCH_H_
+
+#include <vector>
+
+#include "mace/core/operator.h"
+#include "mace/kernels/space_to_batch.h"
+
+namespace mace {
+
+static void SpaceToBatchHelper(const Tensor *input_tensor,
+                               const Tensor *block_shape_tensor,
+                               const Tensor *paddings_tensor,
+                               Tensor *output) {
+  MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D");
+  MACE_CHECK(block_shape_tensor->dim_size() == 1, "Block's shape should be 1D");
+  MACE_CHECK(paddings_tensor->dim_size() == 2, "Paddings' shape should be 2D");
+
+  const index_t block_dims = block_shape_tensor->dim(0);
+  MACE_CHECK(block_dims == paddings_tensor->dim(0) && 2 == paddings_tensor->dim(1));
+  Tensor::MappingGuard block_shape_tensor_mapper(block_shape_tensor);
+  Tensor::MappingGuard padding_tensor_mapper(paddings_tensor);
+  const int *block_shape_ptr = block_shape_tensor->data<int>();
+  const int *paddings_ptr = paddings_tensor->data<int>();
+  std::vector<index_t> output_shape(4, 0);
+  index_t block_shape_product = 1;
+  for (uint32_t block_dim = 0; block_dim < block_dims; ++block_dim) {
+    MACE_CHECK(block_shape_ptr[block_dim] > 1, "block_shape's value should be greater than 1");
+    const index_t block_shape_value = block_shape_ptr[block_dim];
+    const index_t padded_input_size = input_tensor->dim(block_dim + 2)
+                                      + *paddings_ptr
+                                      + *(paddings_ptr + 1);
+    MACE_CHECK(padded_input_size % block_shape_value == 0,
+               "padded input is not divisible by block_shape");
+    block_shape_product *= block_shape_value;
+    output_shape[block_dim + 2] = padded_input_size / block_shape_value;
+    paddings_ptr += 2;
+  }
+  output_shape[0] = input_tensor->dim(0) * block_shape_product;
+  output_shape[1] = input_tensor->dim(1);
+
+  output->Resize(output_shape);
+}
+
+template <DeviceType D, typename T>
+class SpaceToBatchNDOp : public Operator<D, T> {
+ public:
+  SpaceToBatchNDOp(const OperatorDef &op_def, Workspace *ws)
+      : Operator<D, T>(op_def, ws) {}
+
+  bool Run() override {
+    const Tensor *input_tensor = this->Input(INPUT);
+    const Tensor *block_shape_tensor = this->Input(BLOCK_SHAPE);
+    const Tensor *paddings_tensor = this->Input(PADDINGS);
+    Tensor *output = this->Output(OUTPUT);
+
+    SpaceToBatchHelper(input_tensor, block_shape_tensor, paddings_tensor, output);
+    functor_(const_cast<Tensor *>(input_tensor), block_shape_tensor, paddings_tensor, output);
+    return true;
+  }
+
+ private:
+  kernels::SpaceToBatchFunctor<D, T> functor_;
+
+ protected:
+  OP_INPUT_TAGS(INPUT, BLOCK_SHAPE, PADDINGS);
+  OP_OUTPUT_TAGS(OUTPUT);
+};
+
+}  // namespace mace
+
+#endif  // MACE_OPS_SPACE_TO_BATCH_H_
diff --git a/mace/ops/space_to_batch_benchmark.cc b/mace/ops/space_to_batch_benchmark.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5e119a041fd33603d38ad99c0f5084575c25d20d
--- /dev/null
+++ b/mace/ops/space_to_batch_benchmark.cc
@@ -0,0 +1,56 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/core/operator.h"
+#include "mace/core/testing/test_benchmark.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+template <DeviceType D, typename T>
+static void BMSpaceToBatch(
+    int iters, int batch, int channels, int height, int width) {
+  mace::testing::StopTiming();
+
+  OpsTestNet net;
+  OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest")
+      .Input("Input")
+      .Input("BlockShape")
+      .Input("Padding")
+      .Output("Output")
+      .Finalize(net.NewOperatorDef());
+
+  // Add input data
+  net.AddRandomInput<D, T>("Input", {batch, channels, height, width});
+  net.AddInputFromArray<D, int>(
+      "BlockShape", {2}, {2, 2});
+  net.AddInputFromArray<D, int>("Padding", {2, 2}, {2, 3, 2, 3});
+
+  // Warm-up
+  for (int i = 0; i < 5; ++i) {
+    net.RunOp(D);
+  }
+  net.Sync();
+
+  mace::testing::StartTiming();
+  while (iters--) {
+    net.RunOp(D);
+  }
+  net.Sync();
+}
+
+#define BM_SPACE_TO_BATCH_MACRO(N, C, H, W, TYPE, DEVICE)                  \
+  static void BM_SPACE_TO_BATCH_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
+      int iters) {                                                         \
+    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W;       \
+    mace::testing::ItemsProcessed(tot);                                    \
+    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));                    \
+    BMSpaceToBatch<DEVICE, TYPE>(iters, N, C, H, W);                       \
+  }                                                                        \
+  BENCHMARK(BM_SPACE_TO_BATCH_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
+
+#define BM_SPACE_TO_BATCH(N, C, H, W, TYPE) \
+  BM_SPACE_TO_BATCH_MACRO(N, C, H, W, TYPE, OPENCL);
+
+BM_SPACE_TO_BATCH(128, 128, 15, 15, float);
+}  // namespace mace
\ No newline at end of file
diff --git a/mace/ops/space_to_batch_test.cc b/mace/ops/space_to_batch_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..78e933df1aaa63e6d338b7d2822d3250e26889db
--- /dev/null
+++ b/mace/ops/space_to_batch_test.cc
@@ -0,0 +1,199 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "gtest/gtest.h"
+#include "mace/ops/ops_test_util.h"
+
+using namespace mace;
+
+template <DeviceType D>
+void RunSpaceToBatch(const std::vector<index_t> &input_shape,
+                     const std::vector<float> &input_data,
+                     const std::vector<index_t> &block_shape_shape,
+                     const std::vector<int> &block_shape_data,
+                     const std::vector<index_t> &padding_shape,
+                     const std::vector<int> &padding_data,
+                     const Tensor *expected) {
+  OpsTestNet net;
+  OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest")
+      .Input("Input")
+      .Input("BlockShape")
+      .Input("Padding")
+      .Output("Output")
+      .Finalize(net.NewOperatorDef());
+
+  // Add input data
+  net.AddInputFromArray<D, float>(
+      "Input", input_shape, input_data);
+  net.AddInputFromArray<D, int>(
+      "BlockShape", block_shape_shape, block_shape_data);
+  net.AddInputFromArray<D, int>("Padding", padding_shape, padding_data);
+
+  // Run
+  net.RunOp(D);
+
+  // Check
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-8);
+
+}
+
+template <DeviceType D>
+void RunBatchToSpace(const std::vector<index_t> &input_shape,
+                     const std::vector<float> &input_data,
+                     const std::vector<index_t> &block_shape_shape,
+                     const std::vector<int> &block_shape_data,
+                     const std::vector<index_t> &crops_shape,
+                     const std::vector<int> &crops_data,
+                     const Tensor *expected) {
+  OpsTestNet net;
+  OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest")
+      .Input("Input")
+      .Input("BlockShape")
+      .Input("Crops")
+      .Output("Output")
+      .Finalize(net.NewOperatorDef());
+
+  // Add input data
+  net.AddInputFromArray<D, float>(
+      "Input", input_shape, input_data);
+  net.AddInputFromArray<D, int>(
+      "BlockShape", block_shape_shape, block_shape_data);
+  net.AddInputFromArray<D, int>("Crops", crops_shape, crops_data);
+
+  // Run
+  net.RunOp(D);
+
+  // Check
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-8);
+}
+
+template <typename T>
+void TestBidirectionTransform(const std::vector<index_t> &space_shape,
+                              const std::vector<T> &space_data,
+                              const std::vector<index_t> &block_shape,
+                              const std::vector<int> &block_data,
+                              const std::vector<index_t> &padding_shape,
+                              const std::vector<int> &padding_data,
+                              const std::vector<index_t> &batch_shape,
+                              const std::vector<T> &batch_data) {
+
+  auto space_tensor = unique_ptr<Tensor>(new Tensor(GetDeviceAllocator(DeviceType::OPENCL),
+                                                    DataTypeToEnum<T>::v()));
+  space_tensor->Resize(space_shape);
+  {
+    Tensor::MappingGuard space_mapper(space_tensor.get());
+    T *space_ptr = space_tensor->mutable_data<T>();
+    MACE_CHECK(static_cast<size_t>(space_tensor->size()) == space_data.size())
+        << "Space tensor size:" << space_tensor->size()
+        << ", space data size:" << space_data.size();
+    memcpy(space_ptr, space_data.data(), space_data.size() * sizeof(T));
+  }
+
+  auto batch_tensor = unique_ptr<Tensor>(new Tensor(GetDeviceAllocator(DeviceType::OPENCL),
+                                                    DataTypeToEnum<T>::v()));
+  batch_tensor->Resize(batch_shape);
+  {
+    Tensor::MappingGuard batch_mapper(batch_tensor.get());
+    T *batch_ptr = batch_tensor->mutable_data<T>();
+    MACE_CHECK(static_cast<size_t>(batch_tensor->size()) == batch_data.size());
+    memcpy(batch_ptr, batch_data.data(), batch_data.size() * sizeof(T));
+  }
+
+  RunSpaceToBatch<DeviceType::OPENCL>(space_shape, space_data,
+                                      block_shape, block_data,
+                                      padding_shape, padding_data,
+                                      batch_tensor.get());
+
+  RunBatchToSpace<DeviceType::OPENCL>(batch_shape, batch_data,
+                                      block_shape, block_data,
+                                      padding_shape, padding_data,
+                                      space_tensor.get());
+}
+
+TEST(SpaceToBatchTest, SmallData) {
+  TestBidirectionTransform<float>({1, 1, 2, 2},
+                                  {1, 2, 3, 4},
+                                  {2},
+                                  {2, 2},
+                                  {2, 2},
+                                  {0, 0, 0, 0},
+                                  {4, 1, 1, 1},
+                                  {1, 2, 3, 4}
+  );
+}
+
+TEST(SpaceToBatchTest, SmallDataWithOnePadding) {
+  TestBidirectionTransform<float>({1, 1, 2, 2},
+                                  {1, 2, 3, 4},
+                                  {2},
+                                  {3, 3},
+                                  {2, 2},
+                                  {1, 0, 1, 0},
+                                  {9, 1, 1, 1},
+                                  {0, 0, 0, 0, 1, 2, 0, 3, 4}
+  );
+}
+
+TEST(SpaceToBatchTest, SmallDataWithTwoPadding) {
+  TestBidirectionTransform<float>({1, 1, 2, 2},
+                                  {1, 2, 3, 4},
+                                  {2},
+                                  {2, 2},
+                                  {2, 2},
+                                  {1, 1, 1, 1},
+                                  {4, 1, 2, 2},
+                                  {0, 0, 0, 4, 0, 0, 3, 0, 0, 2, 0, 0, 1, 0, 0, 0}
+  );
+}
+
+TEST(SpaceToBatchTest, MultiChannelData) {
+  TestBidirectionTransform<float>({1, 3, 2, 2},
+                                  {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
+                                  {2},
+                                  {2, 2},
+                                  {2, 2},
+                                  {0, 0, 0, 0},
+                                  {4, 3, 1, 1},
+                                  {1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12}
+  );
+}
+
+TEST(SpaceToBatchTest, LargerMultiChannelData) {
+  TestBidirectionTransform<float>({1, 1, 4, 4},
+                                  {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
+                                  {2},
+                                  {2, 2},
+                                  {2, 2},
+                                  {0, 0, 0, 0},
+                                  {4, 1, 2, 2},
+                                  {1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16}
+  );
+}
+
+TEST(SpaceToBatchTest, MultiBatchData) {
+  TestBidirectionTransform<float>({2, 1, 2, 4},
+                                  {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
+                                  {2},
+                                  {2, 2},
+                                  {2, 2},
+                                  {0, 0, 0, 0},
+                                  {8, 1, 1, 2},
+                                  {1, 3, 2, 4, 5, 7, 6, 8, 9, 11, 10, 12, 13, 15, 14, 16}
+  );
+}
+
+TEST(SpaceToBatchTest, MultiBatchAndChannelData) {
+  TestBidirectionTransform<float>({2, 2, 2, 4},
+                                  {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+                                   17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32},
+                                  {2},
+                                  {2, 2},
+                                  {2, 2},
+                                  {0, 0, 0, 0},
+                                  {8, 2, 1, 2},
+                                  {1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16,
+                                   17, 19, 25, 27, 18, 20, 26, 28, 21, 23, 29, 31, 22, 24, 30, 32}
+  );
+}
+