Commit 69761c04 authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Improve performance of ResizeBicubic:

- for each x value, cache the indexes and the 'advance' value.
- access input and output via direct pointers instead of through
  eigen_tensor(b,y,x,c).
- special-case the 3 channel case.
- switch channel/width loops in the general case so that a single
  float[4] can be used for the cache.

After caching the 'advance' value, the values used during iteration
could be converted to a plain float[4] instead of using the
CachedInterpolation object.

Removed the special cases in CachedInterpolation::Advance; the special
cases for speed are not needed when it's only called once per image.
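
To illustrate the reuse scheme, here is a minimal standalone sketch (not
the TensorFlow kernel itself; Window, ComputeAdvance, and the concrete
indices are illustrative, and the corner case of bound-clamped duplicate
indices is ignored): each output x needs four interpolated values, and when
consecutive source windows overlap, the last 'advance' values of the
previous window shift to the front and only the rest are recomputed.

#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for the per-x bookkeeping; not the kernel's types.
struct Window {
  int64_t index[4];  // source column of each of the 4 bicubic taps
  int advance;       // how many cached values carry over from the previous x
};

// Number of leading taps of 'next' already present at the tail of 'prev'.
static int ComputeAdvance(const Window& prev, const Window& next) {
  for (int shift = 0; shift < 4; ++shift) {
    bool match = true;
    for (int i = 0; i + shift < 4; ++i) {
      if (prev.index[i + shift] != next.index[i]) {
        match = false;
        break;
      }
    }
    if (match) return 4 - shift;
  }
  return 0;
}

int main() {
  // Two adjacent output columns whose 4-tap source windows overlap by 3.
  Window a = {{10, 11, 12, 13}, 0};
  Window b = {{11, 12, 13, 14}, 0};
  b.advance = ComputeAdvance(a, b);  // 3: only one fresh value is needed

  float cached_value[4] = {0.1f, 0.2f, 0.3f, 0.4f};  // values for window 'a'
  // Shift the reusable tail to the front, as the kernel's switch blocks do.
  for (int i = 0; i < b.advance; ++i) {
    cached_value[i] = cached_value[i + 4 - b.advance];
  }
  // cached_value[b.advance .. 3] would then be filled by fresh interpolation.
  std::printf("advance=%d\n", b.advance);
  return 0;
}

In the kernel below, this shift is the switch over x_wai.advance; the
3-channel fast path unrolls it across cached_value_0/1/2 so each channel
keeps its own float[4].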

Added more test cases and benchmark cases.
Change: 150077397
Parent 1c1d6f54
@@ -60,24 +60,33 @@ inline int64 Bound(int64 val, int64 limit) {
return std::min(limit - 1ll, std::max(0ll, val));
}
struct WeightsAndIndices {
float weight_0;
float weight_1;
float weight_2;
float weight_3;
int64 index_0;
int64 index_1;
int64 index_2;
int64 index_3;
int advance; // Number of values reusable from the previous point.
};
inline void GetWeightsAndIndices(const float scale, const int64 out_loc,
const int64 limit, float* weight_0,
float* weight_1, float* weight_2,
float* weight_3, int64* index_0,
int64* index_1, int64* index_2,
int64* index_3) {
const int64 limit, WeightsAndIndices* out) {
const int64 in_loc = scale * out_loc;
const float delta = scale * out_loc - in_loc;
const int64 offset = lrintf(delta * kTableSize);
const float* coeffs_table = GetCoeffsTable();
*weight_0 = coeffs_table[offset * 2 + 1];
*weight_1 = coeffs_table[offset * 2];
*weight_2 = coeffs_table[(kTableSize - offset) * 2];
*weight_3 = coeffs_table[(kTableSize - offset) * 2 + 1];
*index_0 = Bound(in_loc - 1, limit);
*index_1 = Bound(in_loc, limit);
*index_2 = Bound(in_loc + 1, limit);
*index_3 = Bound(in_loc + 2, limit);
out->weight_0 = coeffs_table[offset * 2 + 1];
out->weight_1 = coeffs_table[offset * 2];
out->weight_2 = coeffs_table[(kTableSize - offset) * 2];
out->weight_3 = coeffs_table[(kTableSize - offset) * 2 + 1];
out->index_0 = Bound(in_loc - 1, limit);
out->index_1 = Bound(in_loc, limit);
out->index_2 = Bound(in_loc + 1, limit);
out->index_3 = Bound(in_loc + 2, limit);
}
template <typename T>
@@ -91,43 +100,29 @@ inline float Interpolate1D(const float weight_0, const float weight_1,
static_cast<float>(value_3) * weight_3;
}
// Compute the 1D interpolation for a given X index using the y_weights
static float Compute(float values_[4], const float xw_0, const float xw_1,
const float xw_2, const float xw_3) {
return Interpolate1D(xw_0, xw_1, xw_2, xw_3, values_[0], values_[1],
values_[2], values_[3]);
}
// In order to compute a single output value, we look at a 4x4 patch in the
// source image. As we iterate increasing X across the image, the new 4x4 patch
// often overlaps with the previous 4x4 patch we just looked at.
//
// This class helps retain that intermediate computation work.
class CachedInterpolation {
// This class helps compute the number of values to copy from the previous
// point's values.
class CachedInterpolationCalculator {
public:
CachedInterpolation()
: values_({{std::make_pair(-1, -1), std::make_pair(-1, -1),
std::make_pair(-1, -1), std::make_pair(-1, -1)}}) {}
CachedInterpolationCalculator() : indexes_{-1, -1, -1, -1} {}
// Advances the buffer. Returns the number of valid values.
// Advances iteration. Returns the number of values that should be copied from
// the current point to the next point. The copying should always be done by
// copying the last <retval> values from the old point to the first <retval>
// values of the new point.
inline int Advance(const int64 x_0, const int64 x_1, const int64 x_2,
const int64 x_3) {
// Either we have started a new line, or we don't have any values yet.
if (x_0 < values_[0].first || values_[0].first == -1) {
// Zero cached values were valid, we must recompute everything.
return 0;
}
if (values_[0].first == x_0 && values_[3].first == x_3) {
// Everything's the same. Yay!
return 4;
}
if (values_[1].first != 0 && values_[2].first != values_[3].first) {
// Fast (normal) path
if (values_[1].first == x_0) {
CopyPoint(1, 0);
CopyPoint(2, 1);
CopyPoint(3, 2);
return 3;
}
if (values_[2].first == x_0) {
CopyPoint(2, 0);
CopyPoint(3, 1);
return 2;
}
}
// We use 2 hands and walk through, copying from one to another where
// we already have values.
// Invariant: new_indicies_hand <= cached_values_hand
@@ -135,10 +130,9 @@ class CachedInterpolation {
int cached_values_hand = 0;
int new_indicies_hand = 0;
while (cached_values_hand < 4) {
if (values_[cached_values_hand].first ==
new_x_indices[new_indicies_hand]) {
if (indexes_[cached_values_hand] == new_x_indices[new_indicies_hand]) {
if (new_indicies_hand < cached_values_hand) {
CopyPoint(cached_values_hand, new_indicies_hand);
indexes_[new_indicies_hand] = indexes_[cached_values_hand];
}
cached_values_hand++;
new_indicies_hand++;
@@ -146,111 +140,225 @@ class CachedInterpolation {
cached_values_hand++;
}
}
switch (new_indicies_hand) {
case 0:
indexes_[0] = x_0;
TF_FALLTHROUGH_INTENDED;
case 1:
indexes_[1] = x_1;
TF_FALLTHROUGH_INTENDED;
case 2:
indexes_[2] = x_2;
TF_FALLTHROUGH_INTENDED;
case 3:
indexes_[3] = x_3;
break;
}
return new_indicies_hand;
}
inline void SetPoint(const int index, const int64 x_index,
const float value) {
values_[index] = std::make_pair(x_index, value);
}
private:
int64 indexes_[4];
};
// Compute the 1D interpolation for a given X index using the y_weights
inline float Compute(const float xw_0, const float xw_1, const float xw_2,
const float xw_3) const {
return Interpolate1D(xw_0, xw_1, xw_2, xw_3, values_[0].second,
values_[1].second, values_[2].second,
values_[3].second);
static void ComputeXWeightsAndIndices(const ImageResizerState& resizer_state,
std::vector<WeightsAndIndices>* x_wais) {
CachedInterpolationCalculator calc;
for (int64 x = 0; x < resizer_state.out_width; ++x) {
GetWeightsAndIndices(resizer_state.width_scale, x, resizer_state.in_width,
&(*x_wais)[x]);
auto& x_wai = (*x_wais)[x];
x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2,
x_wai.index_3);
}
private:
inline void CopyPoint(const int source, const int dest) {
values_[dest] = values_[source];
// Scale the values so they can be used as offsets into buffers.
for (int x = 0; x < resizer_state.out_width; ++x) {
(*x_wais)[x].index_0 *= resizer_state.channels;
(*x_wais)[x].index_1 *= resizer_state.channels;
(*x_wais)[x].index_2 *= resizer_state.channels;
(*x_wais)[x].index_3 *= resizer_state.channels;
}
}
std::array<std::pair<int64, float>, 4> values_;
};
template <typename T>
static EIGEN_ALWAYS_INLINE float ComputeYInterpolation(
int which, int channel_num, const WeightsAndIndices& y_wai,
const T* y_ptr_0, const T* y_ptr_1, const T* y_ptr_2, const T* y_ptr_3,
const WeightsAndIndices& x_wai) {
int x_index;
switch (which) {
case 0:
x_index = x_wai.index_0;
break;
case 1:
x_index = x_wai.index_1;
break;
case 2:
x_index = x_wai.index_2;
break;
default:
x_index = x_wai.index_3;
break;
}
const int64 pt_index = x_index + channel_num;
return Interpolate1D<T>(y_wai.weight_0, y_wai.weight_1, y_wai.weight_2,
y_wai.weight_3, y_ptr_0[pt_index], y_ptr_1[pt_index],
y_ptr_2[pt_index], y_ptr_3[pt_index]);
}
template <typename T>
inline void interpolate_with_caching(
const typename TTypes<T, 4>::ConstTensor& input_data,
const ImageResizerState& resizer_state,
typename TTypes<float, 4>::Tensor output_data) {
std::vector<CachedInterpolation> cached_values(resizer_state.channels);
for (int64 b = 0; b < resizer_state.batch_size; ++b) {
for (int64 y = 0; y < resizer_state.out_height; ++y) {
float y_weight_0;
float y_weight_1;
float y_weight_2;
float y_weight_3;
int64 y_index_0;
int64 y_index_1;
int64 y_index_2;
int64 y_index_3;
std::vector<WeightsAndIndices> x_wais(resizer_state.out_width);
ComputeXWeightsAndIndices(resizer_state, &x_wais);
const auto num_channels = resizer_state.channels;
const int64 in_row_width = resizer_state.in_width * num_channels;
const int64 in_batch_width = resizer_state.in_height * in_row_width;
const T* input_b_ptr = input_data.data();
float* output_y_ptr = output_data.data();
for (int64 b = 0; b < resizer_state.batch_size;
++b, input_b_ptr += in_batch_width) {
for (int64 y = 0; y < resizer_state.out_height;
++y, output_y_ptr += resizer_state.out_width * num_channels) {
WeightsAndIndices y_wai;
GetWeightsAndIndices(resizer_state.height_scale, y,
resizer_state.in_height, &y_weight_0, &y_weight_1,
&y_weight_2, &y_weight_3, &y_index_0, &y_index_1,
&y_index_2, &y_index_3);
for (int64 x = 0; x < resizer_state.out_width; ++x) {
float xw_0;
float xw_1;
float xw_2;
float xw_3;
int64 x_index_0;
int64 x_index_1;
int64 x_index_2;
int64 x_index_3;
GetWeightsAndIndices(resizer_state.width_scale, x,
resizer_state.in_width, &xw_0, &xw_1, &xw_2, &xw_3,
&x_index_0, &x_index_1, &x_index_2, &x_index_3);
for (int64 c = 0; c < resizer_state.channels; ++c) {
const int advance = cached_values[c].Advance(x_index_0, x_index_1,
x_index_2, x_index_3);
switch (advance) {
resizer_state.in_height, &y_wai);
// Make pointers represent offsets of data in input_b_ptr.
const T* y_ptr_0 = input_b_ptr + y_wai.index_0 * in_row_width;
const T* y_ptr_1 = input_b_ptr + y_wai.index_1 * in_row_width;
const T* y_ptr_2 = input_b_ptr + y_wai.index_2 * in_row_width;
const T* y_ptr_3 = input_b_ptr + y_wai.index_3 * in_row_width;
if (num_channels == 3) {
// Manually unroll case of 3 channels.
float cached_value_0[4];
float cached_value_1[4];
float cached_value_2[4];
for (int64 x = 0; x < resizer_state.out_width; ++x) {
const WeightsAndIndices& x_wai = x_wais[x];
// Shift values in cached_value_* to fill first 'advance' values.
switch (x_wai.advance) {
case 3:
cached_value_0[0] = cached_value_0[1];
cached_value_0[1] = cached_value_0[2];
cached_value_0[2] = cached_value_0[3];
cached_value_1[0] = cached_value_1[1];
cached_value_1[1] = cached_value_1[2];
cached_value_1[2] = cached_value_1[3];
cached_value_2[0] = cached_value_2[1];
cached_value_2[1] = cached_value_2[2];
cached_value_2[2] = cached_value_2[3];
break;
case 2:
cached_value_0[0] = cached_value_0[2];
cached_value_0[1] = cached_value_0[3];
cached_value_1[0] = cached_value_1[2];
cached_value_1[1] = cached_value_1[3];
cached_value_2[0] = cached_value_2[2];
cached_value_2[1] = cached_value_2[3];
break;
case 1: {
cached_value_0[0] = cached_value_0[3];
cached_value_1[0] = cached_value_1[3];
cached_value_2[0] = cached_value_2[3];
break;
}
}
// Set the remaining '4-advance' values by computing.
switch (x_wai.advance) {
case 0:
cached_values[c].SetPoint(
0, x_index_0,
Interpolate1D<T>(y_weight_0, y_weight_1, y_weight_2,
y_weight_3,
input_data(b, y_index_0, x_index_0, c),
input_data(b, y_index_1, x_index_0, c),
input_data(b, y_index_2, x_index_0, c),
input_data(b, y_index_3, x_index_0, c)));
cached_value_0[0] = ComputeYInterpolation(
0, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
cached_value_1[0] = ComputeYInterpolation(
0, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
cached_value_2[0] = ComputeYInterpolation(
0, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
TF_FALLTHROUGH_INTENDED;
case 1:
cached_values[c].SetPoint(
1, x_index_1,
Interpolate1D<T>(y_weight_0, y_weight_1, y_weight_2,
y_weight_3,
input_data(b, y_index_0, x_index_1, c),
input_data(b, y_index_1, x_index_1, c),
input_data(b, y_index_2, x_index_1, c),
input_data(b, y_index_3, x_index_1, c)));
cached_value_0[1] = ComputeYInterpolation(
1, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
cached_value_1[1] = ComputeYInterpolation(
1, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
cached_value_2[1] = ComputeYInterpolation(
1, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
TF_FALLTHROUGH_INTENDED;
case 2:
cached_values[c].SetPoint(
2, x_index_2,
Interpolate1D<T>(y_weight_0, y_weight_1, y_weight_2,
y_weight_3,
input_data(b, y_index_0, x_index_2, c),
input_data(b, y_index_1, x_index_2, c),
input_data(b, y_index_2, x_index_2, c),
input_data(b, y_index_3, x_index_2, c)));
cached_value_0[2] = ComputeYInterpolation(
2, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
cached_value_1[2] = ComputeYInterpolation(
2, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
cached_value_2[2] = ComputeYInterpolation(
2, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
TF_FALLTHROUGH_INTENDED;
case 3:
cached_values[c].SetPoint(
3, x_index_3,
Interpolate1D<T>(y_weight_0, y_weight_1, y_weight_2,
y_weight_3,
input_data(b, y_index_0, x_index_3, c),
input_data(b, y_index_1, x_index_3, c),
input_data(b, y_index_2, x_index_3, c),
input_data(b, y_index_3, x_index_3, c)));
TF_FALLTHROUGH_INTENDED;
default:
output_data(b, y, x, c) =
cached_values[c].Compute(xw_0, xw_1, xw_2, xw_3);
cached_value_0[3] = ComputeYInterpolation(
3, 0, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
cached_value_1[3] = ComputeYInterpolation(
3, 1, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
cached_value_2[3] = ComputeYInterpolation(
3, 2, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
break;
}
output_y_ptr[x * num_channels + 0] =
Compute(cached_value_0, x_wai.weight_0, x_wai.weight_1,
x_wai.weight_2, x_wai.weight_3);
output_y_ptr[x * num_channels + 1] =
Compute(cached_value_1, x_wai.weight_0, x_wai.weight_1,
x_wai.weight_2, x_wai.weight_3);
output_y_ptr[x * num_channels + 2] =
Compute(cached_value_2, x_wai.weight_0, x_wai.weight_1,
x_wai.weight_2, x_wai.weight_3);
}
} else {
for (int64 c = 0; c < num_channels; ++c) {
float cached_value[4];
for (int64 x = 0; x < resizer_state.out_width; ++x) {
const WeightsAndIndices& x_wai = x_wais[x];
// Shift values in cached_value to fill first 'advance' values.
switch (x_wai.advance) {
case 3:
cached_value[0] = cached_value[1];
cached_value[1] = cached_value[2];
cached_value[2] = cached_value[3];
break;
case 2:
cached_value[0] = cached_value[2];
cached_value[1] = cached_value[3];
break;
case 1: {
cached_value[0] = cached_value[3];
break;
}
}
// Set the remaining '4-advance' values by computing.
switch (x_wai.advance) {
case 0:
cached_value[0] = ComputeYInterpolation(
0, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
TF_FALLTHROUGH_INTENDED;
case 1:
cached_value[1] = ComputeYInterpolation(
1, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
TF_FALLTHROUGH_INTENDED;
case 2:
cached_value[2] = ComputeYInterpolation(
2, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
TF_FALLTHROUGH_INTENDED;
case 3:
cached_value[3] = ComputeYInterpolation(
3, c, y_wai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, x_wai);
break;
}
output_y_ptr[x * num_channels + c] =
Compute(cached_value, x_wai.weight_0, x_wai.weight_1,
x_wai.weight_2, x_wai.weight_3);
}
}
}
}
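
For reference, the pointer arithmetic above assumes NHWC layout: a row
pointer is input_b_ptr + y_index * in_row_width, and the x indices are
pre-multiplied by the channel count so a tap is read as y_ptr[x_index + c].
A small sketch under those assumptions (made-up dimensions, illustrative
names, not the kernel's code):

#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  const int64_t batch = 2, height = 4, width = 5, channels = 3;
  const int64_t in_row_width = width * channels;         // floats per row
  const int64_t in_batch_width = height * in_row_width;  // floats per image
  std::vector<float> data(batch * in_batch_width);
  std::iota(data.begin(), data.end(), 0.0f);  // data[i] == i

  // Fetch (b=1, y=2, x=3, c=1) the way the kernel does: a row pointer plus
  // an x index pre-scaled by 'channels', plus the channel offset.
  const float* input_b_ptr = data.data() + 1 * in_batch_width;
  const float* y_ptr = input_b_ptr + 2 * in_row_width;
  const int64_t x_index = 3 * channels;  // pre-multiplied, as in the kernel
  std::printf("%.0f\n", y_ptr[x_index + 1]);  // 1*60 + 2*15 + 3*3 + 1 == 100
  return 0;
}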
@@ -34,9 +34,9 @@ class ResizeBicubicOpTest : public OpsTestBase {
TF_EXPECT_OK(InitOp());
}
const Tensor* AddRandomImageInput(const TensorShape& shape) {
CHECK_GT(input_types_.size(), inputs_.size())
<< "Adding more inputs than types; perhaps you need to call MakeOp";
const Tensor* SetRandomImageInput(const TensorShape& shape) {
inputs_.clear();
CHECK_EQ(shape.dims(), 4) << "All images must have 4 dimensions.";
bool is_ref = IsRefType(input_types_[inputs_.size()]);
Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()),
@@ -155,16 +155,22 @@ class ResizeBicubicOpTest : public OpsTestBase {
}
protected:
void RunRandomTest(const int64 in_height, const int64 in_width) {
const Tensor* input =
AddRandomImageInput(TensorShape({1, in_height, in_width, 1}));
AddInputFromArray<int32>(TensorShape({2}), {299, 299});
void RunRandomTest(const int batch_size, const int64 in_height,
const int64 in_width, const int target_height,
const int target_width, int channels) {
LOG(INFO) << "Running random test " << in_height << "x" << in_width << "x"
<< channels << " to " << target_height << "x" << target_width
<< "x" << channels;
const Tensor* input = SetRandomImageInput(
TensorShape({batch_size, in_height, in_width, channels}));
AddInputFromArray<int32>(TensorShape({2}), {target_height, target_width});
TF_ASSERT_OK(RunOpKernel());
std::unique_ptr<Tensor> expected(
new Tensor(device_->GetAllocator(AllocatorAttributes()),
DataTypeToEnum<float>::v(), TensorShape({1, 299, 299, 1})));
std::unique_ptr<Tensor> expected(new Tensor(
device_->GetAllocator(AllocatorAttributes()),
DataTypeToEnum<float>::v(),
TensorShape({batch_size, target_height, target_width, channels})));
ResizeBicubicBaseline(input->tensor<float, 4>(),
expected->tensor<float, 4>());
@@ -175,6 +181,21 @@ class ResizeBicubicOpTest : public OpsTestBase {
// 0.00001 of the previous implementation.
test::ExpectTensorNear<float>(*expected, *GetOutput(0), 0.00001);
}
void RunManyRandomTests(int channels) {
for (int batch_size : {1, 2, 5}) {
for (int in_w : {2, 4, 7, 20, 165}) {
for (int in_h : {1, 3, 5, 8, 100, 233}) {
for (int target_height : {1, 2, 3, 50, 113}) {
for (int target_width : {target_height, target_height / 2 + 1}) {
RunRandomTest(batch_size, in_h, in_w, target_height, target_width,
channels);
}
}
}
}
}
}
};
TEST_F(ResizeBicubicOpTest, TestBicubic2x2To1x1) {
@@ -204,15 +225,30 @@ TEST_F(ResizeBicubicOpTest, TestBicubic2x2To0x0) {
}
TEST_F(ResizeBicubicOpTest, TestBicubicRandom141x186) {
RunRandomTest(141, 186);
RunRandomTest(2, 141, 186, 299, 299, 1 /* channels */);
RunRandomTest(2, 141, 186, 299, 299, 3 /* channels */);
}
TEST_F(ResizeBicubicOpTest, TestBicubicRandom183x229) {
RunRandomTest(183, 229);
RunRandomTest(2, 183, 229, 299, 299, 1 /* channels */);
RunRandomTest(2, 183, 229, 299, 299, 3 /* channels */);
}
TEST_F(ResizeBicubicOpTest, TestBicubicRandom749x603) {
RunRandomTest(749, 603);
RunRandomTest(2, 749, 603, 299, 299, 1 /* channels */);
RunRandomTest(2, 749, 603, 299, 299, 3 /* channels */);
}
TEST_F(ResizeBicubicOpTest, TestAreaRandomDataSeveralInputsSizes1Channel) {
RunManyRandomTests(1);
}
TEST_F(ResizeBicubicOpTest, TestAreaRandomDataSeveralInputsSizes3Channels) {
RunManyRandomTests(3);
}
TEST_F(ResizeBicubicOpTest, TestAreaRandomDataSeveralInputsSizes4Channels) {
RunManyRandomTests(4);
}
static Graph* ResizeBicubic(int batch_size, int size, int channels) {
@@ -498,14 +498,16 @@ class ResizeBilinearBenchmark(test.Benchmark):
class ResizeBicubicBenchmark(test.Benchmark):
def _benchmarkResize(self, image_size):
# 4D float tensor (10 images per batch, 3 channels per image)
def _benchmarkResize(self, image_size, num_channels):
batch_size = 1
num_ops = 1000
img = variables.Variable(
random_ops.random_normal([10, image_size[0], image_size[1], 3]),
random_ops.random_normal(
[batch_size, image_size[0], image_size[1], num_channels]),
name='img')
deps = []
for _ in xrange(100):
for _ in xrange(num_ops):
with ops.control_dependencies(deps):
resize_op = image_ops.resize_bicubic(
img, [299, 299], align_corners=False)
@@ -514,20 +516,41 @@ class ResizeBicubicBenchmark(test.Benchmark):
with session.Session() as sess:
sess.run(variables.global_variables_initializer())
print('Variables initialized for resize_bicubic image size: %s.' %
(image_size,))
benchmark_values = self.run_op_benchmark(
sess, benchmark_op, name=('bicubic_%s_%s' % image_size))
print('Benchmark values:\n%s' % benchmark_values)
results = self.run_op_benchmark(
sess,
benchmark_op,
min_iters=20,
name=('resize_bicubic_%s_%s_%s' % (image_size[0], image_size[1],
num_channels)))
print('%s : %.2f ms/img' % (results['name'], 1000 * results['wall_time']
/ (batch_size * num_ops)))
def benchmarkSimilar3Channel(self):
self._benchmarkResize((183, 229), 3)
def benchmarkScaleUp3Channel(self):
self._benchmarkResize((141, 186), 3)
def benchmarkScaleDown3Channel(self):
self._benchmarkResize((749, 603), 3)
def benchmarkSimilar1Channel(self):
self._benchmarkResize((183, 229), 1)
def benchmarkScaleUp1Channel(self):
self._benchmarkResize((141, 186), 1)
def benchmarkScaleDown1Channel(self):
self._benchmarkResize((749, 603), 1)
def benchmarkSimilar(self):
self._benchmarkResize((183, 229))
def benchmarkSimilar4Channel(self):
self._benchmarkResize((183, 229), 4)
def benchmarkScaleUp(self):
self._benchmarkResize((141, 186))
def benchmarkScaleUp4Channel(self):
self._benchmarkResize((141, 186), 4)
def benchmarkScaleDown(self):
self._benchmarkResize((749, 603))
def benchmarkScaleDown4Channel(self):
self._benchmarkResize((749, 603), 4)
class ResizeAreaBenchmark(test.Benchmark):