diff --git a/tensorflow/core/kernels/resize_bicubic_op.cc b/tensorflow/core/kernels/resize_bicubic_op.cc index c5c805c44f8784d838bdc168ce255224b61a8fbf..d7b063e0c184ac5a23a993c868be0d7b0928b7e3 100644 --- a/tensorflow/core/kernels/resize_bicubic_op.cc +++ b/tensorflow/core/kernels/resize_bicubic_op.cc @@ -60,24 +60,33 @@ inline int64 Bound(int64 val, int64 limit) { return std::min(limit - 1ll, std::max(0ll, val)); } +struct WeightsAndIndices { + float weight_0; + float weight_1; + float weight_2; + float weight_3; + int64 index_0; + int64 index_1; + int64 index_2; + int64 index_3; + + int advance; // advance value. +}; + inline void GetWeightsAndIndices(const float scale, const int64 out_loc, - const int64 limit, float* weight_0, - float* weight_1, float* weight_2, - float* weight_3, int64* index_0, - int64* index_1, int64* index_2, - int64* index_3) { + const int64 limit, WeightsAndIndices* out) { const int64 in_loc = scale * out_loc; const float delta = scale * out_loc - in_loc; const int64 offset = lrintf(delta * kTableSize); const float* coeffs_table = GetCoeffsTable(); - *weight_0 = coeffs_table[offset * 2 + 1]; - *weight_1 = coeffs_table[offset * 2]; - *weight_2 = coeffs_table[(kTableSize - offset) * 2]; - *weight_3 = coeffs_table[(kTableSize - offset) * 2 + 1]; - *index_0 = Bound(in_loc - 1, limit); - *index_1 = Bound(in_loc, limit); - *index_2 = Bound(in_loc + 1, limit); - *index_3 = Bound(in_loc + 2, limit); + out->weight_0 = coeffs_table[offset * 2 + 1]; + out->weight_1 = coeffs_table[offset * 2]; + out->weight_2 = coeffs_table[(kTableSize - offset) * 2]; + out->weight_3 = coeffs_table[(kTableSize - offset) * 2 + 1]; + out->index_0 = Bound(in_loc - 1, limit); + out->index_1 = Bound(in_loc, limit); + out->index_2 = Bound(in_loc + 1, limit); + out->index_3 = Bound(in_loc + 2, limit); } template @@ -91,43 +100,29 @@ inline float Interpolate1D(const float weight_0, const float weight_1, static_cast(value_3) * weight_3; } +// Compute the 1D interpolation for a given X index using the y_weights +static float Compute(float values_[4], const float xw_0, const float xw_1, + const float xw_2, const float xw_3) { + return Interpolate1D(xw_0, xw_1, xw_2, xw_3, values_[0], values_[1], + values_[2], values_[3]); +} + // In order to compute a single output value, we look at a 4x4 patch in the // source image. As we iterate increasing X across the image, the new 4x4 patch // often overlaps with the previous 4x4 patch we just looked at. // -// This class helps retain that intermediate computation work. -class CachedInterpolation { +// This class helps compute the number of values to copy from the previous +// point's values. +class CachedInterpolationCalculator { public: - CachedInterpolation() - : values_({{std::make_pair(-1, -1), std::make_pair(-1, -1), - std::make_pair(-1, -1), std::make_pair(-1, -1)}}) {} + CachedInterpolationCalculator() : indexes_{-1, -1, -1, -1} {} - // Advances the buffer. Returns the number of valid values. + // Advances iteration. Returns the number of values that should be copied from + // the current point to the next point. The copying should always be done by + // copying the last values from the old point to the first + // values of the new point. inline int Advance(const int64 x_0, const int64 x_1, const int64 x_2, const int64 x_3) { - // Either we have started a new line, or we don't have any values yet. - if (x_0 < values_[0].first || values_[0].first == -1) { - // Zero cached values were valid, we must recompute everything. - return 0; - } - if (values_[0].first == x_0 && values_[3].first == x_3) { - // Everything's the same. Yay! - return 4; - } - if (values_[1].first != 0 && values_[2].first != values_[3].first) { - // Fast (normal) path - if (values_[1].first == x_0) { - CopyPoint(1, 0); - CopyPoint(2, 1); - CopyPoint(3, 2); - return 3; - } - if (values_[2].first == x_0) { - CopyPoint(2, 0); - CopyPoint(3, 1); - return 2; - } - } // We use 2 hands and walk through, copying from one to another where // we already have values. // Invariant, new_indicies_hand <= cached_values_hand @@ -135,10 +130,9 @@ class CachedInterpolation { int cached_values_hand = 0; int new_indicies_hand = 0; while (cached_values_hand < 4) { - if (values_[cached_values_hand].first == - new_x_indices[new_indicies_hand]) { + if (indexes_[cached_values_hand] == new_x_indices[new_indicies_hand]) { if (new_indicies_hand < cached_values_hand) { - CopyPoint(cached_values_hand, new_indicies_hand); + indexes_[new_indicies_hand] = indexes_[cached_values_hand]; } cached_values_hand++; new_indicies_hand++; @@ -146,111 +140,225 @@ class CachedInterpolation { cached_values_hand++; } } + switch (new_indicies_hand) { + case 0: + indexes_[0] = x_0; + TF_FALLTHROUGH_INTENDED; + case 1: + indexes_[1] = x_1; + TF_FALLTHROUGH_INTENDED; + case 2: + indexes_[2] = x_2; + TF_FALLTHROUGH_INTENDED; + case 3: + indexes_[3] = x_3; + break; + } return new_indicies_hand; } - inline void SetPoint(const int index, const int64 x_index, - const float value) { - values_[index] = std::make_pair(x_index, value); - } + private: + int64 indexes_[4]; +}; - // Compute the 1D interpolation for a given X index using the y_weights - inline float Compute(const float xw_0, const float xw_1, const float xw_2, - const float xw_3) const { - return Interpolate1D(xw_0, xw_1, xw_2, xw_3, values_[0].second, - values_[1].second, values_[2].second, - values_[3].second); +static void ComputeXWeightsAndIndices(const ImageResizerState& resizer_state, + std::vector* x_wais) { + CachedInterpolationCalculator calc; + for (int64 x = 0; x < resizer_state.out_width; ++x) { + GetWeightsAndIndices(resizer_state.width_scale, x, resizer_state.in_width, + &(*x_wais)[x]); + auto& x_wai = (*x_wais)[x]; + x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2, + x_wai.index_3); } - - private: - inline void CopyPoint(const int source, const int dest) { - values_[dest] = values_[source]; + // Scale the values so they can be used as offsets into buffers. + for (int x = 0; x < resizer_state.out_width; ++x) { + (*x_wais)[x].index_0 *= resizer_state.channels; + (*x_wais)[x].index_1 *= resizer_state.channels; + (*x_wais)[x].index_2 *= resizer_state.channels; + (*x_wais)[x].index_3 *= resizer_state.channels; } +} - std::array, 4> values_; -}; +template +static EIGEN_ALWAYS_INLINE float ComputeYInterpolation( + int which, int channel_num, const WeightsAndIndices& y_wai, + const T* y_ptr_0, const T* y_ptr_1, const T* y_ptr_2, const T* y_ptr_3, + const WeightsAndIndices& x_wai) { + int x_index; + switch (which) { + case 0: + x_index = x_wai.index_0; + break; + case 1: + x_index = x_wai.index_1; + break; + case 2: + x_index = x_wai.index_2; + break; + default: + x_index = x_wai.index_3; + break; + } + const int64 pt_index = x_index + channel_num; + return Interpolate1D(y_wai.weight_0, y_wai.weight_1, y_wai.weight_2, + y_wai.weight_3, y_ptr_0[pt_index], y_ptr_1[pt_index], + y_ptr_2[pt_index], y_ptr_3[pt_index]); +} template inline void interpolate_with_caching( const typename TTypes::ConstTensor& input_data, const ImageResizerState& resizer_state, typename TTypes::Tensor output_data) { - std::vector cached_values(resizer_state.channels); - for (int64 b = 0; b < resizer_state.batch_size; ++b) { - for (int64 y = 0; y < resizer_state.out_height; ++y) { - float y_weight_0; - float y_weight_1; - float y_weight_2; - float y_weight_3; - int64 y_index_0; - int64 y_index_1; - int64 y_index_2; - int64 y_index_3; + std::vector x_wais(resizer_state.out_width); + ComputeXWeightsAndIndices(resizer_state, &x_wais); + + const auto num_channels = resizer_state.channels; + const int64 in_row_width = resizer_state.in_width * num_channels; + const int64 in_batch_width = resizer_state.in_height * in_row_width; + + const T* input_b_ptr = input_data.data(); + float* output_y_ptr = output_data.data(); + + for (int64 b = 0; b < resizer_state.batch_size; + ++b, input_b_ptr += in_batch_width) { + for (int64 y = 0; y < resizer_state.out_height; + ++y, output_y_ptr += resizer_state.out_width * num_channels) { + WeightsAndIndices y_wai; GetWeightsAndIndices(resizer_state.height_scale, y, - resizer_state.in_height, &y_weight_0, &y_weight_1, - &y_weight_2, &y_weight_3, &y_index_0, &y_index_1, - &y_index_2, &y_index_3); - for (int64 x = 0; x < resizer_state.out_width; ++x) { - float xw_0; - float xw_1; - float xw_2; - float xw_3; - int64 x_index_0; - int64 x_index_1; - int64 x_index_2; - int64 x_index_3; - GetWeightsAndIndices(resizer_state.width_scale, x, - resizer_state.in_width, &xw_0, &xw_1, &xw_2, &xw_3, - &x_index_0, &x_index_1, &x_index_2, &x_index_3); - for (int64 c = 0; c < resizer_state.channels; ++c) { - const int advance = cached_values[c].Advance(x_index_0, x_index_1, - x_index_2, x_index_3); - switch (advance) { + resizer_state.in_height, &y_wai); + // Make pointers represent offsets of data in input_b_ptr. + const T* y_ptr_0 = input_b_ptr + y_wai.index_0 * in_row_width; + const T* y_ptr_1 = input_b_ptr + y_wai.index_1 * in_row_width; + const T* y_ptr_2 = input_b_ptr + y_wai.index_2 * in_row_width; + const T* y_ptr_3 = input_b_ptr + y_wai.index_3 * in_row_width; + if (num_channels == 3) { + // Manually unroll case of 3 channels. + float cached_value_0[4]; + float cached_value_1[4]; + float cached_value_2[4]; + for (int64 x = 0; x < resizer_state.out_width; ++x) { + const WeightsAndIndices& x_wai = x_wais[x]; + // Shift values in cached_value_* to fill first 'advance' values. + switch (x_wai.advance) { + case 3: + cached_value_0[0] = cached_value_0[1]; + cached_value_0[1] = cached_value_0[2]; + cached_value_0[2] = cached_value_0[3]; + cached_value_1[0] = cached_value_1[1]; + cached_value_1[1] = cached_value_1[2]; + cached_value_1[2] = cached_value_1[3]; + cached_value_2[0] = cached_value_2[1]; + cached_value_2[1] = cached_value_2[2]; + cached_value_2[2] = cached_value_2[3]; + break; + case 2: + cached_value_0[0] = cached_value_0[2]; + cached_value_0[1] = cached_value_0[3]; + cached_value_1[0] = cached_value_1[2]; + cached_value_1[1] = cached_value_1[3]; + cached_value_2[0] = cached_value_2[2]; + cached_value_2[1] = cached_value_2[3]; + break; + case 1: { + cached_value_0[0] = cached_value_0[3]; + cached_value_1[0] = cached_value_1[3]; + cached_value_2[0] = cached_value_2[3]; + break; + } + } + + // Set the remaining '4-advance' values by computing. + switch (x_wai.advance) { case 0: - cached_values[c].SetPoint( - 0, x_index_0, - Interpolate1D(y_weight_0, y_weight_1, y_weight_2, - y_weight_3, - input_data(b, y_index_0, x_index_0, c), - input_data(b, y_index_1, x_index_0, c), - input_data(b, y_index_2, x_index_0, c), - input_data(b, y_index_3, x_index_0, c))); + cached_value_0[0] = ComputeYInterpolation( + 0, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai); + cached_value_1[0] = ComputeYInterpolation( + 0, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai); + cached_value_2[0] = ComputeYInterpolation( + 0, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai); TF_FALLTHROUGH_INTENDED; case 1: - cached_values[c].SetPoint( - 1, x_index_1, - Interpolate1D(y_weight_0, y_weight_1, y_weight_2, - y_weight_3, - input_data(b, y_index_0, x_index_1, c), - input_data(b, y_index_1, x_index_1, c), - input_data(b, y_index_2, x_index_1, c), - input_data(b, y_index_3, x_index_1, c))); + cached_value_0[1] = ComputeYInterpolation( + 1, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai); + cached_value_1[1] = ComputeYInterpolation( + 1, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai); + cached_value_2[1] = ComputeYInterpolation( + 1, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai); TF_FALLTHROUGH_INTENDED; case 2: - cached_values[c].SetPoint( - 2, x_index_2, - Interpolate1D(y_weight_0, y_weight_1, y_weight_2, - y_weight_3, - input_data(b, y_index_0, x_index_2, c), - input_data(b, y_index_1, x_index_2, c), - input_data(b, y_index_2, x_index_2, c), - input_data(b, y_index_3, x_index_2, c))); + cached_value_0[2] = ComputeYInterpolation( + 2, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai); + cached_value_1[2] = ComputeYInterpolation( + 2, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai); + cached_value_2[2] = ComputeYInterpolation( + 2, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai); TF_FALLTHROUGH_INTENDED; case 3: - cached_values[c].SetPoint( - 3, x_index_3, - Interpolate1D(y_weight_0, y_weight_1, y_weight_2, - y_weight_3, - input_data(b, y_index_0, x_index_3, c), - input_data(b, y_index_1, x_index_3, c), - input_data(b, y_index_2, x_index_3, c), - input_data(b, y_index_3, x_index_3, c))); - TF_FALLTHROUGH_INTENDED; - default: - output_data(b, y, x, c) = - cached_values[c].Compute(xw_0, xw_1, xw_2, xw_3); + cached_value_0[3] = ComputeYInterpolation( + 3, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai); + cached_value_1[3] = ComputeYInterpolation( + 3, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai); + cached_value_2[3] = ComputeYInterpolation( + 3, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai); break; } + output_y_ptr[x * num_channels + 0] = + Compute(cached_value_0, x_wai.weight_0, x_wai.weight_1, + x_wai.weight_2, x_wai.weight_3); + output_y_ptr[x * num_channels + 1] = + Compute(cached_value_1, x_wai.weight_0, x_wai.weight_1, + x_wai.weight_2, x_wai.weight_3); + output_y_ptr[x * num_channels + 2] = + Compute(cached_value_2, x_wai.weight_0, x_wai.weight_1, + x_wai.weight_2, x_wai.weight_3); + } + } else { + for (int64 c = 0; c < num_channels; ++c) { + float cached_value[4]; + for (int64 x = 0; x < resizer_state.out_width; ++x) { + const WeightsAndIndices& x_wai = x_wais[x]; + // Shift values in cached_value to fill first 'advance' values. + switch (x_wai.advance) { + case 3: + cached_value[0] = cached_value[1]; + cached_value[1] = cached_value[2]; + cached_value[2] = cached_value[3]; + break; + case 2: + cached_value[0] = cached_value[2]; + cached_value[1] = cached_value[3]; + break; + case 1: { + cached_value[0] = cached_value[3]; + break; + } + } + + // Set the remaining '4-advance' values by computing. + switch (x_wai.advance) { + case 0: + cached_value[0] = ComputeYInterpolation( + 0, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai); + TF_FALLTHROUGH_INTENDED; + case 1: + cached_value[1] = ComputeYInterpolation( + 1, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai); + TF_FALLTHROUGH_INTENDED; + case 2: + cached_value[2] = ComputeYInterpolation( + 2, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai); + TF_FALLTHROUGH_INTENDED; + case 3: + cached_value[3] = ComputeYInterpolation( + 3, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai); + break; + } + output_y_ptr[x * num_channels + c] = + Compute(cached_value, x_wai.weight_0, x_wai.weight_1, + x_wai.weight_2, x_wai.weight_3); + } } } } diff --git a/tensorflow/core/kernels/resize_bicubic_op_test.cc b/tensorflow/core/kernels/resize_bicubic_op_test.cc index 814102dda8e6091d65f8e81bbb95db84621fcdfc..ae14d2804e2bfbb61d71ce0ab4026a2b19293beb 100644 --- a/tensorflow/core/kernels/resize_bicubic_op_test.cc +++ b/tensorflow/core/kernels/resize_bicubic_op_test.cc @@ -34,9 +34,9 @@ class ResizeBicubicOpTest : public OpsTestBase { TF_EXPECT_OK(InitOp()); } - const Tensor* AddRandomImageInput(const TensorShape& shape) { - CHECK_GT(input_types_.size(), inputs_.size()) - << "Adding more inputs than types; perhaps you need to call MakeOp"; + const Tensor* SetRandomImageInput(const TensorShape& shape) { + inputs_.clear(); + CHECK_EQ(shape.dims(), 4) << "All images must have 4 dimensions."; bool is_ref = IsRefType(input_types_[inputs_.size()]); Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()), @@ -155,16 +155,22 @@ class ResizeBicubicOpTest : public OpsTestBase { } protected: - void RunRandomTest(const int64 in_height, const int64 in_width) { - const Tensor* input = - AddRandomImageInput(TensorShape({1, in_height, in_width, 1})); - AddInputFromArray(TensorShape({2}), {299, 299}); + void RunRandomTest(const int batch_size, const int64 in_height, + const int64 in_width, const int target_height, + const int target_width, int channels) { + LOG(INFO) << "Running random test " << in_height << "x" << in_width << "x" + << channels << " to " << target_height << "x" << target_width + << "x" << channels; + const Tensor* input = SetRandomImageInput( + TensorShape({batch_size, in_height, in_width, channels})); + AddInputFromArray(TensorShape({2}), {target_height, target_width}); TF_ASSERT_OK(RunOpKernel()); - std::unique_ptr expected( - new Tensor(device_->GetAllocator(AllocatorAttributes()), - DataTypeToEnum::v(), TensorShape({1, 299, 299, 1}))); + std::unique_ptr expected(new Tensor( + device_->GetAllocator(AllocatorAttributes()), + DataTypeToEnum::v(), + TensorShape({batch_size, target_height, target_width, channels}))); ResizeBicubicBaseline(input->tensor(), expected->tensor()); @@ -175,6 +181,21 @@ class ResizeBicubicOpTest : public OpsTestBase { // 0.00001 of the previous implementation. test::ExpectTensorNear(*expected, *GetOutput(0), 0.00001); } + + void RunManyRandomTests(int channels) { + for (int batch_size : {1, 2, 5}) { + for (int in_w : {2, 4, 7, 20, 165}) { + for (int in_h : {1, 3, 5, 8, 100, 233}) { + for (int target_height : {1, 2, 3, 50, 113}) { + for (int target_width : {target_height, target_height / 2 + 1}) { + RunRandomTest(batch_size, in_h, in_w, target_height, target_width, + channels); + } + } + } + } + } + } }; TEST_F(ResizeBicubicOpTest, TestBicubic2x2To1x1) { @@ -204,15 +225,30 @@ TEST_F(ResizeBicubicOpTest, TestBicubic2x2To0x0) { } TEST_F(ResizeBicubicOpTest, TestBicubicRandom141x186) { - RunRandomTest(141, 186); + RunRandomTest(2, 141, 186, 299, 299, 1 /* channels */); + RunRandomTest(2, 141, 186, 299, 299, 3 /* channels */); } TEST_F(ResizeBicubicOpTest, TestBicubicRandom183x229) { - RunRandomTest(183, 229); + RunRandomTest(2, 183, 229, 299, 299, 1 /* channels */); + RunRandomTest(2, 183, 229, 299, 299, 3 /* channels */); } TEST_F(ResizeBicubicOpTest, TestBicubicRandom749x603) { - RunRandomTest(749, 603); + RunRandomTest(2, 749, 603, 299, 299, 1 /* channels */); + RunRandomTest(2, 749, 603, 299, 299, 3 /* channels */); +} + +TEST_F(ResizeBicubicOpTest, TestAreaRandomDataSeveralInputsSizes1Channel) { + RunManyRandomTests(1); +} + +TEST_F(ResizeBicubicOpTest, TestAreaRandomDataSeveralInputsSizes3Channels) { + RunManyRandomTests(3); +} + +TEST_F(ResizeBicubicOpTest, TestAreaRandomDataSeveralInputsSizes4Channels) { + RunManyRandomTests(4); } static Graph* ResizeBicubic(int batch_size, int size, int channels) { diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 0c1e116c03b71d776eb15e22f8b4066e0713de85..6ae36f0d84eb7846e529511d689754d9926f2f3f 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -498,14 +498,16 @@ class ResizeBilinearBenchmark(test.Benchmark): class ResizeBicubicBenchmark(test.Benchmark): - def _benchmarkResize(self, image_size): - # 4D float tensor (10 images per batch, 3 channels per image) + def _benchmarkResize(self, image_size, num_channels): + batch_size = 1 + num_ops = 1000 img = variables.Variable( - random_ops.random_normal([10, image_size[0], image_size[1], 3]), + random_ops.random_normal( + [batch_size, image_size[0], image_size[1], num_channels]), name='img') deps = [] - for _ in xrange(100): + for _ in xrange(num_ops): with ops.control_dependencies(deps): resize_op = image_ops.resize_bicubic( img, [299, 299], align_corners=False) @@ -514,20 +516,41 @@ class ResizeBicubicBenchmark(test.Benchmark): with session.Session() as sess: sess.run(variables.global_variables_initializer()) - print('Variables initalized for resize_bicubic image size: %s.' % - (image_size,)) - benchmark_values = self.run_op_benchmark( - sess, benchmark_op, name=('bicubic_%s_%s' % image_size)) - print('Benchmark values:\n%s' % benchmark_values) + results = self.run_op_benchmark( + sess, + benchmark_op, + min_iters=20, + name=('resize_bicubic_%s_%s_%s' % (image_size[0], image_size[1], + num_channels))) + print('%s : %.2f ms/img' % (results['name'], 1000 * results['wall_time'] + / (batch_size * num_ops))) + + def benchmarkSimilar3Channel(self): + self._benchmarkResize((183, 229), 3) + + def benchmarkScaleUp3Channel(self): + self._benchmarkResize((141, 186), 3) + + def benchmarkScaleDown3Channel(self): + self._benchmarkResize((749, 603), 3) + + def benchmarkSimilar1Channel(self): + self._benchmarkResize((183, 229), 1) + + def benchmarkScaleUp1Channel(self): + self._benchmarkResize((141, 186), 1) + + def benchmarkScaleDown1Channel(self): + self._benchmarkResize((749, 603), 1) - def benchmarkSimilar(self): - self._benchmarkResize((183, 229)) + def benchmarkSimilar4Channel(self): + self._benchmarkResize((183, 229), 4) - def benchmarkScaleUp(self): - self._benchmarkResize((141, 186)) + def benchmarkScaleUp4Channel(self): + self._benchmarkResize((141, 186), 4) - def benchmarkScaleDown(self): - self._benchmarkResize((749, 603)) + def benchmarkScaleDown4Channel(self): + self._benchmarkResize((749, 603), 4) class ResizeAreaBenchmark(test.Benchmark):