From 2a7274f473c101831acf6e99c66b09f3f1b63d0d Mon Sep 17 00:00:00 2001 From: Liangliang He Date: Mon, 23 Oct 2017 09:59:19 +0800 Subject: [PATCH] Add conv_2d 1x1 basic opencl implementation --- mace/core/tensor.h | 10 ++-- mace/kernels/conv_2d.h | 5 ++ mace/kernels/opencl/cl/assign_f32.cl | 17 +++++++ mace/kernels/opencl/cl/conv_2d_1x1.cl | 61 +++++++++++++++++++++++ mace/kernels/opencl/conv_2d_opencl.cc | 2 +- mace/kernels/opencl/conv_2d_opencl_1x1.cc | 60 +++++++++++++++++++++- mace/ops/addn_benchmark.cc | 2 +- mace/ops/addn_test.cc | 6 +-- mace/ops/batch_norm_benchmark.cc | 12 ++--- mace/ops/batch_norm_test.cc | 24 ++++----- mace/ops/channel_shuffle_benchmark.cc | 2 +- mace/ops/channel_shuffle_test.cc | 2 +- mace/ops/concat_benchmark.cc | 6 +-- mace/ops/concat_test.cc | 16 +++--- mace/ops/conv_2d_benchmark.cc | 6 +-- mace/ops/conv_2d_test.cc | 52 ++++++++++--------- mace/ops/depthwise_conv2d_test.cc | 12 ++--- mace/ops/depthwise_conv_2d_benchmark.cc | 6 +-- mace/ops/global_avg_pooling_benchmark.cc | 2 +- mace/ops/global_avg_pooling_test.cc | 4 +- mace/ops/ops_test_util.h | 19 ++++--- mace/ops/pooling_benchmark.cc | 2 +- mace/ops/pooling_test.cc | 14 +++--- mace/ops/relu_benchmark.cc | 2 +- mace/ops/relu_test.cc | 4 +- mace/ops/resize_bilinear_test.cc | 8 +-- 26 files changed, 255 insertions(+), 101 deletions(-) create mode 100644 mace/kernels/opencl/cl/assign_f32.cl create mode 100644 mace/kernels/opencl/cl/conv_2d_1x1.cl diff --git a/mace/core/tensor.h b/mace/core/tensor.h index 5862b84f..5a7bd9d9 100644 --- a/mace/core/tensor.h +++ b/mace/core/tensor.h @@ -82,6 +82,11 @@ class Tensor { inline index_t size() const { return size_; } + inline int64_t NumElements() const { + return std::accumulate(shape_.begin(), shape_.end(), 1, + std::multiplies()); + } + inline const bool OnHost() const { return alloc_->OnHost(); } /* @@ -215,11 +220,6 @@ class Tensor { }; private: - inline int64_t NumElements() const { - return std::accumulate(shape_.begin(), shape_.end(), 1, - std::multiplies()); - } - inline void *MappedBuffer() const { if (OnHost()) { return buffer_; diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h index 8a5cff2a..c960f285 100644 --- a/mace/kernels/conv_2d.h +++ b/mace/kernels/conv_2d.h @@ -117,6 +117,11 @@ void Conv2dFunctor::operator()(const Tensor *input, const Tensor *filter, const Tensor *bias, Tensor *output); +template <> +void Conv2dFunctor::operator()(const Tensor *input, + const Tensor *filter, + const Tensor *bias, + Tensor *output); } // namespace kernels } // namespace mace diff --git a/mace/kernels/opencl/cl/assign_f32.cl b/mace/kernels/opencl/cl/assign_f32.cl new file mode 100644 index 00000000..c73036ee --- /dev/null +++ b/mace/kernels/opencl/cl/assign_f32.cl @@ -0,0 +1,17 @@ +void kernel assign_f32(global float *vec, private const float value) { + int idx = get_global_id(0); + vec[idx] = value; +} + +void kernel assign_vec_f32(global float *vec, + global float *values, + private int pixels) { + int batch = get_global_id(0); + int channel = get_global_id(1); + int channels = get_global_size(1); + float value = values[channel]; + float *ptr = vec + (batch * channels + channel) * pixels; + for (int i = 0; i < pixels; ++i) { + ptr[i] = value; + } +} diff --git a/mace/kernels/opencl/cl/conv_2d_1x1.cl b/mace/kernels/opencl/cl/conv_2d_1x1.cl new file mode 100644 index 00000000..abc91aa9 --- /dev/null +++ b/mace/kernels/opencl/cl/conv_2d_1x1.cl @@ -0,0 +1,61 @@ +/* + * Split work item along output channels and pixels + */ +void kernel conv_2d_1x1_naive(global const float *input, /* n, c, h, w */ + global const float *filter, /* o, i, kh, kw */ + global float *output, /* n, c, h, w */ + private const int in_offset, + private const int out_offset, + private const int pixel_num, + private const int in_chan_num, + private const int out_chan_num) { + int out_chan_blk = get_global_id(0); + int out_pixel_blk = get_global_id(1); + + const int out_chan_begin = out_chan_blk << 2; + const int out_chan_end = min(out_chan_begin + 4, out_chan_num); + const int out_pixel_begin = out_pixel_blk << 3; + const int out_pixel_end = min(out_pixel_begin + 8, pixel_num); + + const float *input_base = input + in_offset + out_pixel_begin; + float *output_base = output + out_offset + out_pixel_begin; + int pixels = out_pixel_end - out_pixel_begin; + + for (int in_chan = 0; in_chan < in_chan_num; ++in_chan) { + const float *input_ptr = input_base + in_chan * pixel_num; + if (pixels == 8) { + /* TODO fix '#pragma unroll' build error */ + for (int out_chan = out_chan_begin; out_chan < out_chan_end; ++out_chan) { + float weights = filter[out_chan * in_chan_num + in_chan]; + float *output_ptr = output_base + out_chan * pixel_num; + /* TODO fix vload/vstore */ + /* + for (int p = 0; p < 2; ++p) { + float4 in = vload4(p * 4, input_ptr); + float4 out = vload4(p * 4, output_ptr); + out += in * weights; + vstore4(out, p * 4, output_ptr); + } + */ + for (int p = 0; p < 8; ++p) { + float in = input_ptr[p]; + float out = output_ptr[p]; + out += in * weights; + output_ptr[p] = out; + } + } + } else { + for (int out_chan = out_chan_begin; out_chan < out_chan_end; ++out_chan) { + float weights = filter[out_chan * in_chan_num + in_chan]; + float *output_ptr = output_base + out_chan * pixel_num; + + for (int p = 0; p < pixels; ++p) { + float in = input_ptr[p]; + float out = output_ptr[p]; + out += in * weights; + output_ptr[p] = out; + } + } + } + } +} diff --git a/mace/kernels/opencl/conv_2d_opencl.cc b/mace/kernels/opencl/conv_2d_opencl.cc index 3aca41d0..9eeec07a 100644 --- a/mace/kernels/opencl/conv_2d_opencl.cc +++ b/mace/kernels/opencl/conv_2d_opencl.cc @@ -41,7 +41,7 @@ void Conv2dFunctor::operator()(const Tensor *input, return; } - MACE_CHECK(paddings_[0] == 1 && paddings_[1] == 1, "Padding not supported"); + MACE_CHECK(paddings_[0] == 0 && paddings_[1] == 0, "Padding not supported"); auto conv2d_func = selector[kernel_h - 1][strides_[0] - 1]; conv2d_func(input, filter, bias, output); diff --git a/mace/kernels/opencl/conv_2d_opencl_1x1.cc b/mace/kernels/opencl/conv_2d_opencl_1x1.cc index af75a259..0b648076 100644 --- a/mace/kernels/opencl/conv_2d_opencl_1x1.cc +++ b/mace/kernels/opencl/conv_2d_opencl_1x1.cc @@ -3,6 +3,7 @@ // #include "mace/core/common.h" +#include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/kernels/conv_2d.h" #include "mace/utils/utils.h" @@ -12,6 +13,43 @@ namespace kernels { static constexpr index_t kInputChannelBlockSize = 2; static constexpr index_t kOutputChannelBlockSize = 4; +// TODO(heliangliang) fix bad performance +void AssignBias(Tensor *output, const Tensor *bias) { + + auto runtime = OpenCLRuntime::Get(); + auto program = runtime->program(); + if (bias == nullptr) { + auto assign_bias = + cl::KernelFunctor(program, "assign_f32"); + int global_size = output->NumElements(); + cl_int error; + assign_bias(cl::EnqueueArgs(runtime->command_queue(), + cl::NDRange(global_size), + cl::NullRange), + *(static_cast(output->buffer())), + 0.0f, error); + MACE_CHECK(error == CL_SUCCESS); + } else { + auto output_shape = output->shape(); + index_t batch = output_shape[0]; + index_t channels = output_shape[1]; + index_t pixels = output_shape[2] * output_shape[3]; + MACE_CHECK(channels == bias->shape()[0], "Channels mismatch"); + + auto assign_bias = + cl::KernelFunctor(program, "assign_vec_f32"); + cl_int error; + assign_bias(cl::EnqueueArgs(runtime->command_queue(), + cl::NDRange(batch, channels), + cl::NullRange), + *(static_cast(output->buffer())), + *(static_cast(bias->buffer())), + static_cast(pixels), + error); + MACE_CHECK(error == CL_SUCCESS); + } +} + extern void Conv2dOpenclK1x1S1(const Tensor *input, const Tensor *filter, const Tensor *bias, Tensor *output) { const index_t batch = output->shape()[0]; @@ -27,9 +65,29 @@ extern void Conv2dOpenclK1x1S1(const Tensor *input, const Tensor *filter, MACE_CHECK(input_batch == batch && input_height == height && input_width == width); + AssignBias(output, bias); + + auto runtime = OpenCLRuntime::Get(); + auto program = runtime->program(); + auto conv_2d = cl::KernelFunctor(program, "conv_2d_1x1_naive"); const index_t total_pixels = height * width; - const index_t round_up_channels = RoundUp(channels, kOutputChannelBlockSize); + for (int b = 0; b < batch; ++b) { + int input_offset = b * input_channels * total_pixels; + int output_offset = b * channels * total_pixels; + int chan_blk_num = (channels + 3) >> 2; // each 4 output channels + int pixel_blk_num = (total_pixels + 7) >> 3; // each 8 pixels + cl_int error; + conv_2d(cl::EnqueueArgs(runtime->command_queue(), + cl::NDRange(chan_blk_num, pixel_blk_num), + cl::NullRange), + *(static_cast(input->buffer())), + *(static_cast(filter->buffer())), + *(static_cast(output->buffer())), + input_offset, output_offset, total_pixels, input_channels, channels, error); + MACE_CHECK(error == CL_SUCCESS); + } }; } // namespace kernels diff --git a/mace/ops/addn_benchmark.cc b/mace/ops/addn_benchmark.cc index 4893c850..801e9426 100644 --- a/mace/ops/addn_benchmark.cc +++ b/mace/ops/addn_benchmark.cc @@ -21,7 +21,7 @@ static void AddNBenchmark(int iters, int n, int size) { // Add input data for (int i = 0; i < n; ++i) { - net.AddRandomInput(internal::MakeString("Input", i).c_str(), {size}); + net.AddRandomInput(internal::MakeString("Input", i).c_str(), {size}); } // Warm-up diff --git a/mace/ops/addn_test.cc b/mace/ops/addn_test.cc index 8e6497f2..76a46355 100644 --- a/mace/ops/addn_test.cc +++ b/mace/ops/addn_test.cc @@ -20,9 +20,9 @@ TEST_F(AddnOpTest, AddnOp) { .Finalize(net.operator_def()); // Add input data - net.AddRandomInput("Input1", {1, 2, 3, 4}); - net.AddRandomInput("Input2", {1, 2, 3, 4}); - net.AddRandomInput("Input3", {1, 2, 3, 4}); + net.AddRandomInput("Input1", {1, 2, 3, 4}); + net.AddRandomInput("Input2", {1, 2, 3, 4}); + net.AddRandomInput("Input3", {1, 2, 3, 4}); // Run net.RunOp(); diff --git a/mace/ops/batch_norm_benchmark.cc b/mace/ops/batch_norm_benchmark.cc index 6607695a..8fc24797 100644 --- a/mace/ops/batch_norm_benchmark.cc +++ b/mace/ops/batch_norm_benchmark.cc @@ -24,12 +24,12 @@ static void BatchNorm( .Finalize(net.operator_def()); // Add input data - net.AddRandomInput("Input", {batch, channels, height, width}); - net.AddRandomInput("Scale", {channels}); - net.AddRandomInput("Offset", {channels}); - net.AddRandomInput("Mean", {channels}); - net.AddRandomInput("Var", {channels}, true); - net.AddInputFromArray("Epsilon", {}, {1e-3}); + net.AddRandomInput("Input", {batch, channels, height, width}); + net.AddRandomInput("Scale", {channels}); + net.AddRandomInput("Offset", {channels}); + net.AddRandomInput("Mean", {channels}); + net.AddRandomInput("Var", {channels}, true); + net.AddInputFromArray("Epsilon", {}, {1e-3}); // Warm-up for (int i = 0; i < 5; ++i) { diff --git a/mace/ops/batch_norm_test.cc b/mace/ops/batch_norm_test.cc index 2e931782..99338778 100644 --- a/mace/ops/batch_norm_test.cc +++ b/mace/ops/batch_norm_test.cc @@ -23,13 +23,13 @@ TEST_F(BatchNormOpTest, SimpleCPU) { .Finalize(net.operator_def()); // Add input data - net.AddInputFromArray("Input", {1, 1, 6, 2}, + net.AddInputFromArray("Input", {1, 1, 6, 2}, {5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15}); - net.AddInputFromArray("Scale", {1}, {4.0f}); - net.AddInputFromArray("Offset", {1}, {2.0}); - net.AddInputFromArray("Mean", {1}, {10}); - net.AddInputFromArray("Var", {1}, {11.67f}); - net.AddInputFromArray("Epsilon", {}, {1e-3}); + net.AddInputFromArray("Scale", {1}, {4.0f}); + net.AddInputFromArray("Offset", {1}, {2.0}); + net.AddInputFromArray("Mean", {1}, {10}); + net.AddInputFromArray("Var", {1}, {11.67f}); + net.AddInputFromArray("Epsilon", {}, {1e-3}); // Run net.RunOp(); @@ -63,12 +63,12 @@ TEST_F(BatchNormOpTest, SimpleNeon) { .Finalize(net.operator_def()); // Add input data - net.AddRandomInput("Input", {batch, channels, height, width}); - net.AddRandomInput("Scale", {channels}); - net.AddRandomInput("Offset", {channels}); - net.AddRandomInput("Mean", {channels}); - net.AddRandomInput("Var", {channels}, true); - net.AddInputFromArray("Epsilon", {}, {1e-3}); + net.AddRandomInput("Input", {batch, channels, height, width}); + net.AddRandomInput("Scale", {channels}); + net.AddRandomInput("Offset", {channels}); + net.AddRandomInput("Mean", {channels}); + net.AddRandomInput("Var", {channels}, true); + net.AddInputFromArray("Epsilon", {}, {1e-3}); // run cpu net.RunOp(); diff --git a/mace/ops/channel_shuffle_benchmark.cc b/mace/ops/channel_shuffle_benchmark.cc index ecbc3610..112e5fef 100644 --- a/mace/ops/channel_shuffle_benchmark.cc +++ b/mace/ops/channel_shuffle_benchmark.cc @@ -23,7 +23,7 @@ static void ChannelShuffle( // Add input data net.AddIntArg("group", group); - net.AddRandomInput("Input", {batch, channels, height, width}); + net.AddRandomInput("Input", {batch, channels, height, width}); // Warm-up for (int i = 0; i < 5; ++i) { diff --git a/mace/ops/channel_shuffle_test.cc b/mace/ops/channel_shuffle_test.cc index dcf0a21e..c862e516 100644 --- a/mace/ops/channel_shuffle_test.cc +++ b/mace/ops/channel_shuffle_test.cc @@ -19,7 +19,7 @@ TEST_F(ChannelShuffleOpTest, C8G4) { net.AddIntArg("group", 4); // Add input data - net.AddInputFromArray( + net.AddInputFromArray( "Input", {1, 8, 1, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); diff --git a/mace/ops/concat_benchmark.cc b/mace/ops/concat_benchmark.cc index bd56c495..c871c20d 100644 --- a/mace/ops/concat_benchmark.cc +++ b/mace/ops/concat_benchmark.cc @@ -21,9 +21,9 @@ static void ConcatHelper(int iters, int concat_dim, int dim1) { // Add input data const int kDim0 = 100; - net.AddRandomInput("Input0", {kDim0, dim1}); - net.AddRandomInput("Input1", {kDim0, dim1}); - net.AddInputFromArray("Axis", {}, {concat_dim}); + net.AddRandomInput("Input0", {kDim0, dim1}); + net.AddRandomInput("Input1", {kDim0, dim1}); + net.AddInputFromArray("Axis", {}, {concat_dim}); // Warm-up for (int i = 0; i < 5; ++i) { diff --git a/mace/ops/concat_test.cc b/mace/ops/concat_test.cc index f537e385..7e910d21 100644 --- a/mace/ops/concat_test.cc +++ b/mace/ops/concat_test.cc @@ -26,9 +26,9 @@ TEST_F(ConcatOpTest, Simple_Horizon) { std::vector input1; GenerateRandomRealTypeData(input_shape, input1); // Add inputs - net.AddInputFromArray("Input0", input_shape, input0); - net.AddInputFromArray("Input1", input_shape, input1); - net.AddInputFromArray("Axis", {}, {0}); + net.AddInputFromArray("Input0", input_shape, input0); + net.AddInputFromArray("Input1", input_shape, input1); + net.AddInputFromArray("Axis", {}, {0}); // Run net.RunOp(); @@ -64,9 +64,9 @@ TEST_F(ConcatOpTest, Simple_Vertical) { std::vector input1; GenerateRandomRealTypeData(input_shape, input1); // Add inputs - net.AddInputFromArray("Input0", input_shape, input0); - net.AddInputFromArray("Input1", input_shape, input1); - net.AddInputFromArray("Axis", {}, {1}); + net.AddInputFromArray("Input0", input_shape, input0); + net.AddInputFromArray("Input1", input_shape, input1); + net.AddInputFromArray("Axis", {}, {1}); // Run net.RunOp(); @@ -112,10 +112,10 @@ TEST_F(ConcatOpTest, Random) { concat_axis_size += input_shapes[i][axis]; GenerateRandomRealTypeData(input_shapes[i], inputs[i]); input_ptrs[i] = inputs[i].data(); - net.AddInputFromArray(("Input" + ToString(i)).c_str(), + net.AddInputFromArray(("Input" + ToString(i)).c_str(), input_shapes[i], inputs[i]); } - net.AddInputFromArray("Axis", {}, {axis}); + net.AddInputFromArray("Axis", {}, {axis}); // Run net.RunOp(); diff --git a/mace/ops/conv_2d_benchmark.cc b/mace/ops/conv_2d_benchmark.cc index 8a78041c..5ee5f1ce 100644 --- a/mace/ops/conv_2d_benchmark.cc +++ b/mace/ops/conv_2d_benchmark.cc @@ -38,10 +38,10 @@ static void Conv2d(int iters, net.AddIntsArg("dilations", {1, 1}); // Add input data - net.AddRandomInput("Input", {batch, channels, height, width}); - net.AddRandomInput("Filter", + net.AddRandomInput("Input", {batch, channels, height, width}); + net.AddRandomInput("Filter", {output_channels, channels, kernel_h, kernel_w}); - net.AddRandomInput("Bias", {output_channels}); + net.AddRandomInput("Bias", {output_channels}); // Warm-up for (int i = 0; i < 5; ++i) { diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc index 2202caf2..1575d2f2 100644 --- a/mace/ops/conv_2d_test.cc +++ b/mace/ops/conv_2d_test.cc @@ -26,14 +26,14 @@ TEST_F(Conv2dOpTest, Simple_VALID) { net.AddIntsArg("dilations", {1, 1}); // Add input data - net.AddInputFromArray( + net.AddInputFromArray( "Input", {1, 2, 3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - net.AddInputFromArray( + net.AddInputFromArray( "Filter", {1, 2, 3, 3}, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}); - net.AddInputFromArray("Bias", {1}, {0.1f}); + net.AddInputFromArray("Bias", {1}, {0.1f}); // Run net.RunOp(); @@ -60,14 +60,14 @@ TEST_F(Conv2dOpTest, Simple_SAME) { net.AddIntsArg("dilations", {1, 1}); // Add input data - net.AddInputFromArray( + net.AddInputFromArray( "Input", {1, 2, 3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - net.AddInputFromArray( + net.AddInputFromArray( "Filter", {1, 2, 3, 3}, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}); - net.AddInputFromArray("Bias", {1}, {0.1f}); + net.AddInputFromArray("Bias", {1}, {0.1f}); // Run net.RunOp(); @@ -96,16 +96,16 @@ TEST_F(Conv2dOpTest, Combined) { net.AddIntsArg("dilations", {1, 1}); // Add input data - net.AddInputFromArray( + net.AddInputFromArray( "Input", {1, 2, 5, 5}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - net.AddInputFromArray( + net.AddInputFromArray( "Filter", {2, 2, 3, 3}, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}); - net.AddInputFromArray("Bias", {2}, {0.1f, 0.2f}); + net.AddInputFromArray("Bias", {2}, {0.1f, 0.2f}); // Run net.RunOp(); @@ -118,9 +118,10 @@ TEST_F(Conv2dOpTest, Combined) { ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); } -TEST_F(Conv2dOpTest, Conv1x1) { +template +void TestConv1x1() { // Construct graph - auto &net = test_net(); + OpsTestNet net; OpDefBuilder("Conv2D", "Conv2DTest") .Input("Input") .Input("Filter") @@ -134,7 +135,7 @@ TEST_F(Conv2dOpTest, Conv1x1) { net.AddIntsArg("dilations", {1, 1}); // Add input data - net.AddInputFromArray( + net.AddInputFromArray( "Input", {1, 5, 3, 10}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, @@ -143,13 +144,13 @@ TEST_F(Conv2dOpTest, Conv1x1) { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - net.AddInputFromArray( + net.AddInputFromArray( "Filter", {2, 5, 1, 1}, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f}); - net.AddInputFromArray("Bias", {2}, {0.1f, 0.2f}); + net.AddInputFromArray("Bias", {2}, {0.1f, 0.2f}); // Run - net.RunOp(); + net.RunOp(D); // Check auto expected = CreateTensor( @@ -164,6 +165,11 @@ TEST_F(Conv2dOpTest, Conv1x1) { ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); } +TEST_F(Conv2dOpTest, Conv1x1) { + TestConv1x1(); + TestConv1x1(); +} + // TODO we need more tests TEST_F(Conv2dOpTest, IdleConvNxNS12) { testing::internal::LogToStderr(); @@ -192,10 +198,10 @@ TEST_F(Conv2dOpTest, IdleConvNxNS12) { net.AddIntsArg("dilations", {1, 1}); // Add input data - net.AddRandomInput("Input", {batch, input_channels, height, width}); - net.AddRandomInput( + net.AddRandomInput("Input", {batch, input_channels, height, width}); + net.AddRandomInput( "Filter", {output_channels, input_channels, kernel_h, kernel_w}); - net.AddRandomInput("Bias", {output_channels}); + net.AddRandomInput("Bias", {output_channels}); // run cpu net.RunOp(); @@ -208,8 +214,8 @@ TEST_F(Conv2dOpTest, IdleConvNxNS12) { ExpectTensorNear(expected, *net.GetOutput("Output"), 0.001); }; - for (int kernel_size : {1}) { - for (int stride : {1}) { + for (int kernel_size : {1, 3, 5}) { + for (int stride : {1, 2}) { func(kernel_size, kernel_size, stride, stride, VALID); func(kernel_size, kernel_size, stride, stride, SAME); } @@ -243,10 +249,10 @@ TEST_F(Conv2dOpTest, DisgustConvNxNS12) { net.AddIntsArg("dilations", {1, 1}); // Add input data - net.AddRandomInput("Input", {batch, input_channels, height, width}); - net.AddRandomInput( + net.AddRandomInput("Input", {batch, input_channels, height, width}); + net.AddRandomInput( "Filter", {output_channels, input_channels, kernel_h, kernel_w}); - net.AddRandomInput("Bias", {output_channels}); + net.AddRandomInput("Bias", {output_channels}); // run cpu net.RunOp(); diff --git a/mace/ops/depthwise_conv2d_test.cc b/mace/ops/depthwise_conv2d_test.cc index 6868e8c3..34ecb048 100644 --- a/mace/ops/depthwise_conv2d_test.cc +++ b/mace/ops/depthwise_conv2d_test.cc @@ -26,13 +26,13 @@ TEST_F(DepthwiseConv2dOpTest, Simple_VALID) { net.AddIntsArg("dilations", {1, 1}); // Add input data - net.AddInputFromArray("Input", {1, 2, 2, 3}, + net.AddInputFromArray("Input", {1, 2, 2, 3}, {1, 3, 5, 7, 9, 11, 2, 4, 6, 8, 10, 12}); - net.AddInputFromArray( + net.AddInputFromArray( "Filter", {2, 2, 2, 2}, {1.0f, 5.0f, 9.0f, 13.0f, 2.0f, 6.0f, 10.0f, 14.0f, 3.0f, 7.0f, 11.0f, 15.0f, 4.0f, 8.0f, 12.0f, 16.0f}); - net.AddInputFromArray("Bias", {4}, {.1f, .2f, .3f, .4f}); + net.AddInputFromArray("Bias", {4}, {.1f, .2f, .3f, .4f}); // Run net.RunOp(); @@ -71,10 +71,10 @@ TEST_F(DepthwiseConv2dOpTest, ConvNxNS12) { net.AddIntsArg("dilations", {1, 1}); // Add input data - net.AddRandomInput("Input", {batch, input_channels, height, width}); - net.AddRandomInput("Filter", + net.AddRandomInput("Input", {batch, input_channels, height, width}); + net.AddRandomInput("Filter", {multiplier, input_channels, kernel_h, kernel_w}); - net.AddRandomInput("Bias", {multiplier * input_channels}); + net.AddRandomInput("Bias", {multiplier * input_channels}); // run cpu net.RunOp(); diff --git a/mace/ops/depthwise_conv_2d_benchmark.cc b/mace/ops/depthwise_conv_2d_benchmark.cc index 9ba7001d..f801c075 100644 --- a/mace/ops/depthwise_conv_2d_benchmark.cc +++ b/mace/ops/depthwise_conv_2d_benchmark.cc @@ -38,10 +38,10 @@ static void DepthwiseConv2d(int iters, net.AddIntsArg("dilations", {1, 1}); // Add input data - net.AddRandomInput("Input", {batch, channels, height, width}); - net.AddRandomInput("Filter", + net.AddRandomInput("Input", {batch, channels, height, width}); + net.AddRandomInput("Filter", {output_channels, channels, kernel_h, kernel_w}); - net.AddRandomInput("Bias", {output_channels}); + net.AddRandomInput("Bias", {output_channels}); // Warm-up for (int i = 0; i < 5; ++i) { diff --git a/mace/ops/global_avg_pooling_benchmark.cc b/mace/ops/global_avg_pooling_benchmark.cc index d2521e7c..a0063496 100644 --- a/mace/ops/global_avg_pooling_benchmark.cc +++ b/mace/ops/global_avg_pooling_benchmark.cc @@ -22,7 +22,7 @@ static void GlobalAvgPooling( .Finalize(net.operator_def()); // Add input data - net.AddRandomInput("Input", {batch, channels, height, width}); + net.AddRandomInput("Input", {batch, channels, height, width}); // Warm-up for (int i = 0; i < 5; ++i) { diff --git a/mace/ops/global_avg_pooling_test.cc b/mace/ops/global_avg_pooling_test.cc index bf9e4269..540f874a 100644 --- a/mace/ops/global_avg_pooling_test.cc +++ b/mace/ops/global_avg_pooling_test.cc @@ -21,7 +21,7 @@ TEST_F(GlobalAvgPoolingOpTest, 3x7x7_CPU) { for (int i = 0; i < 147; ++i) { input[i] = i / 49 + 1; } - net.AddInputFromArray("Input", {1, 3, 7, 7}, input); + net.AddInputFromArray("Input", {1, 3, 7, 7}, input); // Run net.RunOp(); @@ -45,7 +45,7 @@ TEST_F(GlobalAvgPoolingOpTest, 3x7x7_NEON) { for (int i = 0; i < 147; ++i) { input[i] = i / 49 + 1; } - net.AddInputFromArray("Input", {1, 3, 7, 7}, input); + net.AddInputFromArray("Input", {1, 3, 7, 7}, input); // Run net.RunOp(DeviceType::NEON); diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h index 252cb5d6..3b2ddfe0 100644 --- a/mace/ops/ops_test_util.h +++ b/mace/ops/ops_test_util.h @@ -43,36 +43,39 @@ class OpsTestNet { public: OpsTestNet() {} - template + template void AddInputFromArray(const char *name, const std::vector &shape, const std::vector &data) { Tensor *input = - ws_.CreateTensor(name, GetDeviceAllocator(DeviceType::CPU), DataTypeToEnum::v()); + ws_.CreateTensor(name, GetDeviceAllocator(D), DataTypeToEnum::v()); input->Resize(shape); + Tensor::MappingGuard input_mapper(input); T *input_data = input->mutable_data(); MACE_CHECK(static_cast(input->size()) == data.size()); memcpy(input_data, data.data(), data.size() * sizeof(T)); } - template + template void AddRepeatedInput(const char *name, const std::vector &shape, const T data) { Tensor *input = - ws_.CreateTensor(name, GetDeviceAllocator(DeviceType::CPU), DataTypeToEnum::v()); + ws_.CreateTensor(name, GetDeviceAllocator(D), DataTypeToEnum::v()); input->Resize(shape); + Tensor::MappingGuard input_mapper(input); T *input_data = input->mutable_data(); std::fill(input_data, input_data + input->size(), data); } - template + template void AddRandomInput(const char *name, const std::vector &shape, bool positive = false) { Tensor *input = - ws_.CreateTensor(name, GetDeviceAllocator(DeviceType::CPU), DataTypeToEnum::v()); + ws_.CreateTensor(name, GetDeviceAllocator(D), DataTypeToEnum::v()); input->Resize(shape); + Tensor::MappingGuard input_mapper(input); float *input_data = input->mutable_data(); std::random_device rd; @@ -274,6 +277,8 @@ struct Expector { static void Equal(const Tensor &x, const Tensor &y) { ASSERT_EQ(x.dtype(), DataTypeToEnum::v()); AssertSameTypeDims(x, y); + Tensor::MappingGuard x_mapper(&x); + Tensor::MappingGuard y_mapper(&y); auto a = x.data(); auto b = y.data(); for (int i = 0; i < x.size(); ++i) { @@ -284,6 +289,8 @@ struct Expector { static void Near(const Tensor &x, const Tensor &y, const double abs_err) { ASSERT_EQ(x.dtype(), DataTypeToEnum::v()); AssertSameTypeDims(x, y); + Tensor::MappingGuard x_mapper(&x); + Tensor::MappingGuard y_mapper(&y); auto a = x.data(); auto b = y.data(); for (int i = 0; i < x.size(); ++i) { diff --git a/mace/ops/pooling_benchmark.cc b/mace/ops/pooling_benchmark.cc index bae9bc2e..5282bff7 100644 --- a/mace/ops/pooling_benchmark.cc +++ b/mace/ops/pooling_benchmark.cc @@ -37,7 +37,7 @@ static void Pooling(int iters, net.AddIntsArg("dilations", {1, 1}); // Add input data - net.AddRandomInput("Input", {batch, channels, height, width}); + net.AddRandomInput("Input", {batch, channels, height, width}); // Warm-up for (int i = 0; i < 5; ++i) { diff --git a/mace/ops/pooling_test.cc b/mace/ops/pooling_test.cc index 6c977d59..75096f5d 100644 --- a/mace/ops/pooling_test.cc +++ b/mace/ops/pooling_test.cc @@ -29,7 +29,7 @@ TEST_F(PoolingOpTest, MAX_VALID) { net.AddIntArg("pooling_type", PoolingType::MAX); // Add input data - net.AddInputFromArray( + net.AddInputFromArray( "Input", {1, 2, 4, 4}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}); @@ -60,7 +60,7 @@ TEST_F(PoolingOpTest, AVG_VALID) { net.AddIntArg("pooling_type", PoolingType::AVG); // Add input data - net.AddInputFromArray( + net.AddInputFromArray( "Input", {1, 2, 4, 4}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}); @@ -91,7 +91,7 @@ TEST_F(PoolingOpTest, MAX_SAME) { net.AddIntArg("pooling_type", PoolingType::MAX); // Add input data - net.AddInputFromArray("Input", {1, 1, 3, 3}, + net.AddInputFromArray("Input", {1, 1, 3, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8}); // Run @@ -119,7 +119,7 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) { net.AddIntArg("pooling_type", PoolingType::MAX); // Add input data - net.AddInputFromArray( + net.AddInputFromArray( "Input", {1, 1, 4, 4}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); @@ -148,7 +148,7 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) { net.AddIntsArg("dilations", {1, 1}); // Add input data - net.AddInputFromArray( + net.AddInputFromArray( "Input", {1, 1, 2, 9}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}); // Run @@ -176,7 +176,7 @@ TEST_F(PoolingOpTest, MAX_k3x3s2x2) { net.AddIntsArg("dilations", {1, 1}); // Add input data - net.AddInputFromArray( + net.AddInputFromArray( "Input", {1, 1, 3, 9}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26}); @@ -205,7 +205,7 @@ TEST_F(PoolingOpTest, AVG_k2x2s2x2) { net.AddIntsArg("dilations", {1, 1}); // Add input data - net.AddInputFromArray( + net.AddInputFromArray( "Input", {1, 1, 2, 8}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); // Run diff --git a/mace/ops/relu_benchmark.cc b/mace/ops/relu_benchmark.cc index e25b0b8f..1a2be2ca 100644 --- a/mace/ops/relu_benchmark.cc +++ b/mace/ops/relu_benchmark.cc @@ -19,7 +19,7 @@ static void ReluBenchmark(int iters, int size) { .Finalize(net.operator_def()); // Add input data - net.AddRandomInput("Input", {size}); + net.AddRandomInput("Input", {size}); // Warm-up for (int i = 0; i < 5; ++i) { diff --git a/mace/ops/relu_test.cc b/mace/ops/relu_test.cc index d930444e..91964b72 100644 --- a/mace/ops/relu_test.cc +++ b/mace/ops/relu_test.cc @@ -18,7 +18,7 @@ TEST_F(ReluOpTest, ReluOp) { .Finalize(net.operator_def()); // Add input data - net.AddRandomInput("Input", {1, 2, 3, 5}); + net.AddRandomInput("Input", {1, 2, 3, 5}); // Run net.RunOp(); @@ -41,7 +41,7 @@ TEST_F(ReluOpTest, ReluOpWithMax) { .Finalize(net.operator_def()); // Add input data - net.AddRandomInput("Input", {1, 2, 3, 5}); + net.AddRandomInput("Input", {1, 2, 3, 5}); net.AddFloatArg("max_limit", 0.5); // Run diff --git a/mace/ops/resize_bilinear_test.cc b/mace/ops/resize_bilinear_test.cc index 1690e8d0..dc05c5ef 100644 --- a/mace/ops/resize_bilinear_test.cc +++ b/mace/ops/resize_bilinear_test.cc @@ -23,8 +23,8 @@ TEST_F(ResizeBilinearTest, ResizeBilinearWOAlignCorners) { // Add input data vector input(24); std::iota(begin(input), end(input), 0); - net.AddInputFromArray("Input", {1, 3, 2, 4}, input); - net.AddInputFromArray("OutSize", {2}, {1, 2}); + net.AddInputFromArray("Input", {1, 3, 2, 4}, input); + net.AddInputFromArray("OutSize", {2}, {1, 2}); // Run net.RunOp(); @@ -50,8 +50,8 @@ TEST_F(ResizeBilinearTest, ResizeBilinearWAlignCorners) { // Add input data vector input(24); std::iota(begin(input), end(input), 0); - net.AddInputFromArray("Input", {1, 3, 2, 4}, input); - net.AddInputFromArray("OutSize", {2}, {1, 2}); + net.AddInputFromArray("Input", {1, 3, 2, 4}, input); + net.AddInputFromArray("OutSize", {2}, {1, 2}); // Run net.RunOp(); -- GitLab