diff --git a/mace/core/allocator.h b/mace/core/allocator.h
index ceb8c534251b359e1f2fc09842cb164c25c8aab2..0cde9c61faab5a494340f323ec438c0a5bf08c1e 100644
--- a/mace/core/allocator.h
+++ b/mace/core/allocator.h
@@ -31,7 +31,7 @@ class Allocator {
   template <typename T>
   T* New(size_t num_elements) {
     if (num_elements > (std::numeric_limits<size_t>::max() / sizeof(T))) {
-      return NULL;
+      return nullptr;
     }
     void* p = New(sizeof(T) * num_elements);
     T* typed_p = reinterpret_cast<T*>(p);
diff --git a/mace/core/logging.h b/mace/core/logging.h
index 0787af3383d91074ae60c214f096923b8fc891d9..be31a70afef51aa8c32db15f96a1211815482e1a 100644
--- a/mace/core/logging.h
+++ b/mace/core/logging.h
@@ -106,16 +106,27 @@ class LogMessageFatal : public LogMessage {
   if (VLOG_IS_ON(lvl)) \
     ::mace::internal::LogMessage(__FILE__, __LINE__, mace::INFO)
 
-// MACE_CHECK dies with a fatal error if condition is not true. It is *not*
-// controlled by NDEBUG, so the check will be executed regardless of
-// compilation mode. Therefore, it is safe to do things like:
+// MACE_CHECK/MACE_ASSERT die with a fatal error if the condition is not true.
+// MACE_ASSERT is controlled by NDEBUG ('-c opt' for bazel), while MACE_CHECK
+// is executed regardless of compilation mode.
+// It is therefore safe to do things like:
 //    MACE_CHECK(fp->Write(x) == 4)
 //    MACE_CHECK(fp->Write(x) == 4, "Write failed")
+// which would not be correct for MACE_ASSERT.
 #define MACE_CHECK(condition, ...) \
   if (!(condition)) \
   LOG(FATAL) << "Check failed: " #condition " " \
              << ::mace::internal::MakeString(__VA_ARGS__)
 
+#ifndef NDEBUG
+#define MACE_ASSERT(condition, ...) \
+  if (!(condition)) \
+  LOG(FATAL) << "Assert failed: " #condition " " \
+             << ::mace::internal::MakeString(__VA_ARGS__)
+#else
+#define MACE_ASSERT(condition, ...) ((void)0)
+#endif
+
 template <typename T>
 T&& CheckNotNull(const char* file, int line, const char* exprtext, T&& t) {
   if (t == nullptr) {
diff --git a/mace/core/testing/test_benchmark.cc b/mace/core/testing/test_benchmark.cc
index 885a9a63f70956428008291f29dc293245c7d37a..41608a9530d744e58efaa22a1faa2bb612bd45aa 100644
--- a/mace/core/testing/test_benchmark.cc
+++ b/mace/core/testing/test_benchmark.cc
@@ -6,7 +6,9 @@
 
 #include <cstdio>
 #include <cstdlib>
+#include <regex>
 #include <vector>
+
 #include "mace/core/logging.h"
 #include "mace/core/testing/env_time.h"
 #include "mace/core/testing/test_benchmark.h"
@@ -52,12 +54,23 @@ Benchmark* Benchmark::ArgPair(int x, int y) {
 
 // Run all benchmarks
 void Benchmark::Run() {
+  Run("all");
+}
+
+void Benchmark::Run(const char* pattern) {
   if (!all_benchmarks) return;
 
+  if (std::string(pattern) == "all") {
+    pattern = ".*";
+  }
+  std::regex regex(pattern);
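+  // Note: std::regex_match must match the whole benchmark name (default
+  // ECMAScript grammar), so use e.g. ".*Conv2d.*" to select benchmarks
+  // by substring.
+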
   // Compute name width.
   int width = 10;
   char name[100];
+  std::smatch match;
   for (auto b : *all_benchmarks) {
+    if (!std::regex_match(b->name_, match, regex)) continue;
     for (auto arg : b->args_) {
       strcpy(name, b->name_.c_str());
       if (arg.first >= 0) {
@@ -74,7 +87,7 @@ void Benchmark::Run() {
   printf("%-*s %10s %10s\n", width, "Benchmark", "Time(ns)", "Iterations");
   printf("%s\n", string(width + 22, '-').c_str());
   for (auto b : *all_benchmarks) {
-
+    if (!std::regex_match(b->name_, match, regex)) continue;
     for (auto arg : b->args_) {
       strcpy(name, b->name_.c_str());
       if (arg.first >= 0) {
diff --git a/mace/core/testing/test_benchmark.h b/mace/core/testing/test_benchmark.h
index 5800f5edb0912899b09fc95ebebb8a741e2a48e1..6f96411b5bee94ddfb530c393207646bd7638394 100644
--- a/mace/core/testing/test_benchmark.h
+++ b/mace/core/testing/test_benchmark.h
@@ -28,6 +28,7 @@ class Benchmark {
   Benchmark* ArgPair(int x, int y);
 
   static void Run();
+  static void Run(const char* pattern);
 
  private:
   string name_;
diff --git a/mace/core/testing/test_benchmark_main.cc b/mace/core/testing/test_benchmark_main.cc
index 7c184210453b958596faf9da0cb84ff67160b135..dfa8767229d9fc6fd21d0ff72b66f636c55eac70 100644
--- a/mace/core/testing/test_benchmark_main.cc
+++ b/mace/core/testing/test_benchmark_main.cc
@@ -9,7 +9,12 @@
 int main(int argc, char** argv) {
   std::cout << "Running main() from test_main.cc\n";
 
-  mace::testing::Benchmark::Run();
+  // TODO: use gflags
+  if (argc == 2) {
+    mace::testing::Benchmark::Run(argv[1]);
+  } else {
+    mace::testing::Benchmark::Run("all");
+  }
 
   return 0;
 }
diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h
index a102ccef4075ec133158725b42a83ceb3b5a4411..34044b4e5afb603ba197a43421ccee5df0ad76d3 100644
--- a/mace/kernels/conv_2d.h
+++ b/mace/kernels/conv_2d.h
@@ -108,14 +108,14 @@ class Conv2dFunctor {
   const int* dilations_; // [dilation_h, dilation_w]
 };
 
-template<>
-void Conv2dFunctor<DeviceType::CPU, float>::operator()(const float* input,  // NCHW
-    const index_t* input_shape,
-    const float* filter,  // c_out, c_in, kernel_h, kernel_w
-    const index_t* filter_shape,
-    const float* bias,  // c_out
-    float* output,  // NCHW
-    const index_t* output_shape);
+template <>
+void Conv2dFunctor<DeviceType::CPU, float>::operator()(const float* input,
+                                                       const index_t* input_shape,
+                                                       const float* filter,
+                                                       const index_t* filter_shape,
+                                                       const float* bias,
+                                                       float* output,
+                                                       const index_t* output_shape);
 
 }  // namespace kernels
 }  // namespace mace
diff --git a/mace/kernels/neon/conv_2d_neon.cc b/mace/kernels/neon/conv_2d_neon.cc
index 8d45861a1b17ab2e1c59b217723aa6d30d962d63..bcc1002e4d2116262c3298a163a6859c6b588608 100644
--- a/mace/kernels/neon/conv_2d_neon.cc
+++ b/mace/kernels/neon/conv_2d_neon.cc
@@ -9,39 +9,48 @@
 namespace mace {
 namespace kernels {
 
-static inline void ConstructInputWithPadding(const float* input, const index_t* input_shape,
+static inline void ConstructInputWithPadding(const float* input,
+                                             const index_t* input_shape,
                                              const int* paddings,
-                                             Tensor& output_tensor,
-                                             std::vector<index_t>& output_shape) {
+                                             Tensor* output_tensor) {
   index_t batch = input_shape[0];
   index_t channels = input_shape[1];
   index_t height = input_shape[2];
   index_t width = input_shape[3];
 
-  output_shape[0] = batch;
-  output_shape[1] = channels;
-  output_shape[2] = paddings[0] + height;
-  output_shape[3] = paddings[1] + width;
-  index_t output_width = output_shape[3];
-  int padded_left = paddings[1] / 2;
+  std::vector<index_t> output_shape({batch,
+                                     channels,
+                                     paddings[0] + height,
+                                     paddings[1] + width});
 
-  output_tensor.Resize(output_shape);
-  float* output_ptr = output_tensor.mutable_data<float>();
-  memset(output_ptr, 0, output_tensor.size() * sizeof(float));
-  output_ptr += paddings[0] / 2 * output_width;
+  const index_t output_width = output_shape[3];
+  const int padded_top = paddings[0] / 2;
+  const int padded_left = paddings[1] / 2;
+  output_tensor->Resize(output_shape);
+  float* output_ptr = output_tensor->mutable_data<float>();
+  memset(output_ptr, 0, output_tensor->size() * sizeof(float));
+
+  // Skip the padded top rows
+  output_ptr += padded_top * output_width;
   for (; batch > 0; --batch) {
     for (; channels > 0; --channels) {
-      for(; height > 0; --height) {
-        memcpy(output_ptr+padded_left, input, width*sizeof(float));
+      for (; height > 0; --height) {
+        memcpy(output_ptr + padded_left, input, width * sizeof(float));
         input += width;
         output_ptr += output_width;
       }
+      // Skip the padded bottom in this channel and top in the next channel
       output_ptr += paddings[0] * output_width;
     }
   }
 }
 
+extern void Conv2dNeonK1x1S1(const float* input, const index_t* input_shape,
+                             const float* filter, const float* bias,
+                             float* output, const index_t* output_shape);
+
 template<>
 void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float* input,  // NCHW
                                                         const index_t* input_shape,
@@ -57,9 +66,10 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float* input,  // NCHW
                              const float* bias,    // c_out
                              float* output,        // NCHW
                              const index_t* output_shape);
+  // Selection matrix: kernel_size x stride_size
   static const Conv2dNeonFunction selector[5][2] = {
       {
-          nullptr,
+          Conv2dNeonK1x1S1,
           nullptr
       },
       {
@@ -80,10 +90,13 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float* input,  // NCHW
       }
   };
   // not implemented yet
-  if (paddings_[0] != paddings_[1] || paddings_[0] > 5 ||
-      strides_[0] != strides_[1] || strides_[0] > 4 ||
-      dilations_[0] != 1 || dilations_[1] != 1 ||
-      selector[paddings_[0]-1][strides_[0]-1] == nullptr) {
+  index_t kernel_h = filter_shape[2];
+  index_t kernel_w = filter_shape[3];
+  if (kernel_h != kernel_w || kernel_h > 5 ||
+      strides_[0] != strides_[1] || strides_[0] > 2 ||
+      dilations_[0] != 1 || dilations_[1] != 1 ||
+      selector[kernel_h - 1][strides_[0] - 1] == nullptr) {
+    LOG(WARNING) << "NEON conv2d kernel not implemented, using the slow version";
     Conv2dFunctor<DeviceType::CPU, float>(strides_, paddings_, dilations_)(
         input,
         input_shape,
@@ -94,19 +107,22 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float* input,  // NCHW
         output_shape
     );
+    return;  // do not fall through into the NEON path below
   }
+
+  // Keep this alive during kernel execution
   Tensor padded_input;
-  std::vector<index_t> padded_input_shape(4);
-  ConstructInputWithPadding(input, input_shape, paddings_, padded_input, padded_input_shape);
-  auto conv2d_neon_func = selector[paddings_[0] - 1][strides_[0] - 1];
-  conv2d_neon_func(
-      padded_input.data<float>(),
-      padded_input_shape.data(),
-      filter,
-      bias,
-      output,
-      output_shape
-  );
+  if (paddings_[0] > 0 || paddings_[1] > 0) {
+    ConstructInputWithPadding(input, input_shape, paddings_, &padded_input);
+    input = padded_input.data<float>();
+    input_shape = padded_input.shape().data();
+  }
+  auto conv2d_neon_func = selector[kernel_h - 1][strides_[0] - 1];
+  conv2d_neon_func(input,
+                   input_shape,
+                   filter,
+                   bias,
+                   output,
+                   output_shape);
 }
 
 }  // namespace kernels
-}  // namespace mace
\ No newline at end of file
+}  // namespace mace
diff --git a/mace/kernels/neon/conv_2d_neon_1x1.cc b/mace/kernels/neon/conv_2d_neon_1x1.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b3556c98a5c38ff17e27ce087c47f9b78375d375
--- /dev/null
+++ b/mace/kernels/neon/conv_2d_neon_1x1.cc
@@ -0,0 +1,187 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
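+//
+// 1x1 convolution with stride 1: each output pixel of a channel is a
+// weighted sum of the same pixel across all input channels, so the kernel
+// vectorizes over pixels (8 per iteration) and unrolls input channels 4-way.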
+//
+
+#include <arm_neon.h>
+#include "mace/kernels/conv_2d.h"
+
+namespace mace {
+namespace kernels {
+
+void Conv2dNeonK1x1S1(const float* input,  // NCHW
+                      const index_t* input_shape,
+                      const float* filter,  // c_out, c_in, kernel_h, kernel_w
+                      const float* bias,    // c_out
+                      float* output,        // NCHW
+                      const index_t* output_shape) {
+  const index_t batch = output_shape[0];
+  const index_t channels = output_shape[1];
+  const index_t height = output_shape[2];
+  const index_t width = output_shape[3];
+
+  const index_t input_batch = input_shape[0];
+  const index_t input_channels = input_shape[1];
+  const index_t input_height = input_shape[2];
+  const index_t input_width = input_shape[3];
+
+  MACE_CHECK(input_batch == batch &&
+             input_height == height &&
+             input_width == width);
+
+  const index_t total_pixels = height * width;
+  // Process 4 * 2 = 8 pixels for each innermost loop
+  // TODO: does a 64-bit vs. 32-bit index make a difference? Needs benchmarking.
+  const index_t total_loops = total_pixels >> 3;
+  const index_t loop_remaining = total_pixels & 7;
+
+  // TODO: benchmark omp collapse(2)
+  for (index_t n = 0; n < batch; ++n) {
+    const float* filter_ptr = filter;
+#pragma omp parallel for
+    for (index_t c = 0; c < channels; ++c) {
+      // TODO: will GCC optimize these out?
+      float* channel_output_start =
+          output + n * channels * height * width + c * height * width;
+      const float* input_ptr =
+          input + n * input_channels * input_height * input_width;
+
+      // Fill with bias
+      float* output_ptr = channel_output_start;
+      for (index_t ptr = 0; ptr < total_pixels; ++ptr) {
+        output_ptr[ptr] = bias[c];  // TODO: can we avoid this?
+      }
+
+      index_t inc = 0;
+      // Process 4 input channels in batch
+      for (; inc + 3 < input_channels; inc += 4) {
+        float* output_ptr = channel_output_start;
+        // The beginning of each input feature map channel
+        MACE_ASSERT(input_ptr == input + n * input_channels *
+                                 input_height * input_width +
+                                 inc * input_height * input_width);
+
+        const float* input_ptr1 = input_ptr + total_pixels;
+        const float* input_ptr2 = input_ptr1 + total_pixels;
+        const float* input_ptr3 = input_ptr2 + total_pixels;
+
+        // filter is in c_out, c_in, 1, 1 order
+        MACE_ASSERT(filter_ptr == filter + c * input_channels + inc);
+        const float k0 = filter_ptr[0];
+        const float k1 = filter_ptr[1];
+        const float k2 = filter_ptr[2];
+        const float k3 = filter_ptr[3];
+        filter_ptr += 4;
+
+        const float32x4_t vk0 = vdupq_n_f32(k0);
+        const float32x4_t vk1 = vdupq_n_f32(k1);
+        const float32x4_t vk2 = vdupq_n_f32(k2);
+        const float32x4_t vk3 = vdupq_n_f32(k3);
+
+        index_t loop_itr = total_loops;
+        for (; loop_itr > 0; --loop_itr) {
+          // Process 2 groups of 4 floats
+          float32x4_t out0 = vld1q_f32(output_ptr);
+          float32x4_t out4 = vld1q_f32(output_ptr + 4);
+
+          const float32x4_t in00 = vld1q_f32(input_ptr);
+          const float32x4_t in04 = vld1q_f32(input_ptr + 4);
+
+          out0 = vfmaq_f32(out0, in00, vk0);
+          out4 = vfmaq_f32(out4, in04, vk0);
+
+          const float32x4_t in10 = vld1q_f32(input_ptr1);
+          const float32x4_t in14 = vld1q_f32(input_ptr1 + 4);
+
+          out0 = vfmaq_f32(out0, in10, vk1);
+          out4 = vfmaq_f32(out4, in14, vk1);
+
+          const float32x4_t in20 = vld1q_f32(input_ptr2);
+          const float32x4_t in24 = vld1q_f32(input_ptr2 + 4);
+
+          out0 = vfmaq_f32(out0, in20, vk2);
+          out4 = vfmaq_f32(out4, in24, vk2);
+
+          const float32x4_t in30 = vld1q_f32(input_ptr3);
+          const float32x4_t in34 = vld1q_f32(input_ptr3 + 4);
+
+          out0 = vfmaq_f32(out0, in30, vk3);
+          out4 = vfmaq_f32(out4, in34, vk3);
+
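+          // out0/out4 now hold 8 output pixels accumulated over the 4
+          // unrolled input channels; store them before advancing.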
+          vst1q_f32(output_ptr, out0);
+          vst1q_f32(output_ptr + 4, out4);
+
+          output_ptr += 8;
+          input_ptr += 8;
+          input_ptr1 += 8;
+          input_ptr2 += 8;
+          input_ptr3 += 8;
+        }
+        // Process the remaining pixels
+        index_t remaining_pixels = loop_remaining;
+        for (; remaining_pixels > 0; --remaining_pixels) {
+          const float mul = *input_ptr * k0;
+          const float mul1 = *input_ptr1 * k1;
+          const float mul2 = *input_ptr2 * k2;
+          const float mul3 = *input_ptr3 * k3;
+
+          *output_ptr += mul + mul1 + mul2 + mul3;
+
+          ++output_ptr;
+          ++input_ptr;
+          ++input_ptr1;
+          ++input_ptr2;
+          ++input_ptr3;
+        }
+        // Skip these 4 feature maps
+        input_ptr += 3 * total_pixels;
+      }
+      // Process the remaining channels
+      for (; inc < input_channels; ++inc) {
+        float* output_ptr = channel_output_start;
+        MACE_ASSERT(input_ptr == input + n * input_channels *
+                                 input_height * input_width +
+                                 inc * input_height * input_width);
+        MACE_ASSERT(filter_ptr == filter + c * input_channels + inc);
+
+        const float k0 = filter_ptr[0];
+        ++filter_ptr;
+        const float32x4_t vk0 = vdupq_n_f32(k0);
+
+        index_t loop_itr = total_loops;
+        for (; loop_itr > 0; --loop_itr) {
+          float32x4_t out0 = vld1q_f32(output_ptr);
+          float32x4_t out4 = vld1q_f32(output_ptr + 4);
+
+          const float32x4_t in0 = vld1q_f32(input_ptr);
+          const float32x4_t in4 = vld1q_f32(input_ptr + 4);
+
+          out0 = vfmaq_f32(out0, in0, vk0);
+          out4 = vfmaq_f32(out4, in4, vk0);
+
+          // Save output
+          vst1q_f32(output_ptr, out0);
+          vst1q_f32(output_ptr + 4, out4);
+
+          output_ptr += 8;
+          input_ptr += 8;
+        }
+        // Process the remaining pixels
+        index_t remaining_pixels = loop_remaining;
+        for (; remaining_pixels > 0; --remaining_pixels) {
+          const float mul = *input_ptr * k0;
+
+          *output_ptr += mul;
+
+          ++output_ptr;
+          ++input_ptr;
+        }
+      }
+    }
+  }
+}
+
+}  // namespace kernels
+}  // namespace mace
diff --git a/mace/ops/BUILD b/mace/ops/BUILD
index 03930e9240f6c89dd7b897db0fb5005df917001f..a35e9cc3da4b9ac2a10cea7d971213e76b3dc9c7 100644
--- a/mace/ops/BUILD
+++ b/mace/ops/BUILD
@@ -25,7 +25,7 @@ cc_library(
     name = "ops",
     srcs = glob(
         ["*.cc"],
-        exclude = ["*_test.cc"],
+        exclude = ["*_test.cc", "*_benchmark.cc"],
     ),
     hdrs = glob(
         ["*.h"],
@@ -46,11 +46,6 @@ cc_test(
         ["*_test.cc"],
     ),
     copts = ["-std=c++11"],
-    linkopts = if_android([
-        "-pie",
-        "-llog",
-        "-latomic",
-    ]),
     linkstatic = 1,
     deps = [
         ":ops",
@@ -58,3 +53,16 @@ cc_test(
         "@gtest//:gtest_main",
     ],
 )
+
+cc_test(
+    name = "ops_benchmark",
+    srcs = glob(["*_benchmark.cc"]),
+    deps = [
+        ":ops",
+        "//mace/core:core",
+        "//mace/core:test_benchmark_main",
+    ],
+    copts = ["-std=c++11"],
+    linkstatic = 1,
+    testonly = 1,
+)
diff --git a/mace/ops/conv_2d_benchmark.cc b/mace/ops/conv_2d_benchmark.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6c347c274a66f679cd8af051fb95ac19f3979329
--- /dev/null
+++ b/mace/ops/conv_2d_benchmark.cc
@@ -0,0 +1,37 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/core/testing/test_benchmark.h"
+#include "mace/ops/conv_2d.h"
+
+namespace mace {
+
+template <DeviceType D>
+static void Conv2d(int iters, int batch, int channels, int height, int width,
+                   int kernel_h, int kernel_w, int stride,
+                   Padding padding, int output_channels) {
+  mace::testing::StopTiming();
+
+  // TODO: set up and run the Conv2d op under test
+  mace::testing::StartTiming();
+  while (iters--) {
+  }
+}
+
+#define BM_CONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, P, OC, TYPE, DEVICE)     \
+  static void BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE( \
+      int iters) {                                                            \
+    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W;          \
+    mace::testing::ItemsProcessed(tot);                                       \
+    mace::testing::BytesProcessed(tot * (sizeof(TYPE)));                      \
+    Conv2d<DEVICE>(iters, N, C, H, W, KH, KW, STRIDE, mace::Padding::P, OC);  \
+  }                                                                           \
+  BENCHMARK(BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE)
+
+#define BM_CONV_2D(N, C, H, W, KH, KW, S, P, OC, TYPE)       \
+  BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, CPU); \
+  BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, NEON);
+
+BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128, float);
+
+}  // namespace mace
diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc
index c95f3dc6e4317c3596034bf2b65d9210258c1481..797075f2c4dcd9f9a3c4572e252694d04a70820c 100644
--- a/mace/ops/conv_2d_test.cc
+++ b/mace/ops/conv_2d_test.cc
@@ -136,4 +136,55 @@ TEST_F(Conv2dOpTest, Combined) {
   ExpectTensorNear<float>(expected, *GetOutput("Output"), 0.001);
 }
 
+TEST_F(Conv2dOpTest, Conv1x1) {
+  // Construct graph
+  OpDefBuilder("Conv2d", "Conv2dTest")
+      .Input("Input")
+      .Input("Filter")
+      .Input("Bias")
+      .Output("Output")
+      .Finalize(operator_def());
+
+  // Add args
+  AddIntsArg("strides", {1, 1});
+  AddIntArg("padding", Padding::VALID);
+  AddIntsArg("dilations", {1, 1});
+
+  // Add input data
+  AddInputFromArray<float>("Input", {1, 5, 3, 10},
+                           {1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
+  AddInputFromArray<float>("Filter", {2, 5, 1, 1},
+                           {1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
+                            2.0f, 2.0f, 2.0f, 2.0f, 2.0f});
+  AddInputFromArray<float>("Bias", {2}, {0.1f, 0.2f});
+
+  // Run
+  RunOp(DeviceType::NEON);
+
+  // Check
+  Tensor expected = CreateTensor<float>(
+      {1, 2, 3, 10},
+      {5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f,
+       5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f,
+       5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f,
+       10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f,
+       10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f,
+       10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f});
+
+  ExpectTensorNear<float>(expected, *GetOutput("Output"), 0.001);
+}
+
 // TODO we need more tests
diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h
index f61d4b19af319c154d2d99f5240138678372fce9..9dd44c2521219684ca3760bed4fe07ffe8586b89 100644
--- a/mace/ops/ops_test_util.h
+++ b/mace/ops/ops_test_util.h
@@ -9,8 +9,8 @@
 
 #include "gtest/gtest.h"
 #include "mace/core/common.h"
-#include "mace/core/tensor.h"
 #include "mace/core/net.h"
+#include "mace/core/tensor.h"
 
 namespace mace {
 
@@ -29,7 +29,7 @@ class OpDefBuilder {
     return *this;
   }
 
   void Finalize(OperatorDef* op_def) const {
-    MACE_CHECK(op_def != NULL, "input should not be null.");
+    MACE_CHECK(op_def != nullptr, "input should not be null.");
     *op_def = op_def_;
   }
 
   OperatorDef op_def_;
@@ -49,6 +49,7 @@ class OpsTestBase : public ::testing::Test {
     Tensor* input = ws_.CreateTensor(name, cpu_allocator(),
                                      DataTypeToEnum<T>::v());
     input->Resize(shape);
     float* input_data = input->mutable_data<float>();
+    // TODO: check the dims
     memcpy(input_data, data.data(), data.size() * sizeof(T));
   }
 
@@ -96,14 +97,18 @@ class OpsTestBase : public ::testing::Test {
   OperatorDef* operator_def() { return &op_def_; }
 
-  bool RunOp() {
+  bool RunOp(DeviceType device) {
     NetDef net_def;
     net_def.add_op()->CopyFrom(op_def_);
     VLOG(0) << net_def.DebugString();
-    auto net = CreateNet(net_def, &ws_, DeviceType::CPU);
+    auto net = CreateNet(net_def, &ws_, device);
     return net->Run();
   }
 
+  bool RunOp() {
+    return RunOp(DeviceType::CPU);
+  }
+
   Tensor* GetOutput(const char* output_name) {
     return ws_.GetTensor(output_name);
   }
@@ -209,6 +214,6 @@
 void ExpectTensorNear(const Tensor& x, const Tensor& y, const double abs_err) {
   Expector<T>::Near(x, y, abs_err);
 }
 
-} // namespace mace
+}  // namespace mace
 
 #endif // MACE_OPS_TEST_UTIL_H_
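Note: the registration BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128, float) above defines one CPU and one NEON benchmark; the CPU variant expands roughly to the following (an illustrative sketch, not part of the patch):

static void BM_CONV_2D_1_64_32_32_K1x1S1_VALID_128_float_CPU(int iters) {
  const int64_t tot = static_cast<int64_t>(iters) * 1 * 64 * 32 * 32;
  mace::testing::ItemsProcessed(tot);
  mace::testing::BytesProcessed(tot * (sizeof(float)));
  Conv2d<CPU>(iters, 1, 64, 32, 32, 1, 1, 1, mace::Padding::VALID, 128);
}
BENCHMARK(BM_CONV_2D_1_64_32_32_K1x1S1_VALID_128_float_CPU);

Such generated names are what Benchmark::Run(pattern) matches its std::regex against, so invoking the ops_benchmark binary with the argument "BM_CONV_2D.*NEON" runs only the NEON variant.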