提交 8b183f2c 编写于 作者: M Megvii Engine Team

test(dnn/testcase): fix a testcase bug

GitOrigin-RevId: f6b9e5631888f71d4fc6e0fc4de523ae55ea0001
上级 5c224c71
...@@ -6,75 +6,110 @@ ...@@ -6,75 +6,110 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "megdnn/basic_types.h" #include "megdnn/basic_types.h"
#include "test/common/utils.h"
#include "test/common/timer.h" #include "test/common/timer.h"
#include "test/common/utils.h"
#include <random>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <random>
using namespace megdnn; using namespace megdnn;
namespace { namespace {
bool eq_shape0(const TensorShape &a, const TensorShape &b) { bool eq_shape0(const TensorShape& a, const TensorShape& b) {
if (a.ndim != b.ndim) if (a.ndim != b.ndim)
return false; return false;
return std::equal(a.shape, a.shape + a.ndim, b.shape); return std::equal(a.shape, a.shape + a.ndim, b.shape);
} }
bool eq_shape1(const TensorShape &a, const TensorShape &b) { bool eq_shape1(const TensorShape& a, const TensorShape& b) {
if (a.ndim == b.ndim) { if (a.ndim == b.ndim) {
size_t eq = 0; size_t eq = 0;
switch (a.ndim) { switch (a.ndim) {
case 6: eq += a.shape[5] == b.shape[5]; MEGDNN_FALLTHRU case 7:
case 5: eq += a.shape[4] == b.shape[4]; MEGDNN_FALLTHRU eq += a.shape[6] == b.shape[6];
case 4: eq += a.shape[3] == b.shape[3]; MEGDNN_FALLTHRU MEGDNN_FALLTHRU
case 3: eq += a.shape[2] == b.shape[2]; MEGDNN_FALLTHRU case 6:
case 2: eq += a.shape[1] == b.shape[1]; MEGDNN_FALLTHRU eq += a.shape[5] == b.shape[5];
case 1: eq += a.shape[0] == b.shape[0]; MEGDNN_FALLTHRU
case 5:
eq += a.shape[4] == b.shape[4];
MEGDNN_FALLTHRU
case 4:
eq += a.shape[3] == b.shape[3];
MEGDNN_FALLTHRU
case 3:
eq += a.shape[2] == b.shape[2];
MEGDNN_FALLTHRU
case 2:
eq += a.shape[1] == b.shape[1];
MEGDNN_FALLTHRU
case 1:
eq += a.shape[0] == b.shape[0];
} }
return eq == a.ndim; return eq == a.ndim;
} }
return false; return false;
} }
bool eq_layout0(const TensorLayout &a, const TensorLayout &b) { bool eq_layout0(const TensorLayout& a, const TensorLayout& b) {
if (!eq_shape0(a, b)) if (!eq_shape0(a, b))
return false; return false;
return std::equal(a.stride, a.stride + a.ndim, b.stride); return std::equal(a.stride, a.stride + a.ndim, b.stride);
} }
bool eq_layout1(const TensorLayout &a, const TensorLayout &b) { bool eq_layout1(const TensorLayout& a, const TensorLayout& b) {
auto ax = [](size_t shape0, size_t shape1, auto ax = [](size_t shape0, size_t shape1, ptrdiff_t stride0,
ptrdiff_t stride0, ptrdiff_t stride1) { ptrdiff_t stride1) {
return (shape0 == shape1) & ((shape0 == 1) | (stride0 == stride1)); return (shape0 == shape1) & ((shape0 == 1) | (stride0 == stride1));
}; };
if (a.ndim == b.ndim) { if (a.ndim == b.ndim) {
size_t eq = 0; size_t eq = 0;
switch (a.ndim) { switch (a.ndim) {
case 6: eq += ax(a.shape[5], b.shape[5], a.stride[5], b.stride[5]); MEGDNN_FALLTHRU case 7:
case 5: eq += ax(a.shape[4], b.shape[4], a.stride[4], b.stride[4]); MEGDNN_FALLTHRU eq += ax(a.shape[6], b.shape[6], a.stride[6], b.stride[6]);
case 4: eq += ax(a.shape[3], b.shape[3], a.stride[3], b.stride[3]); MEGDNN_FALLTHRU MEGDNN_FALLTHRU
case 3: eq += ax(a.shape[2], b.shape[2], a.stride[2], b.stride[2]); MEGDNN_FALLTHRU case 6:
case 2: eq += ax(a.shape[1], b.shape[1], a.stride[1], b.stride[1]); MEGDNN_FALLTHRU eq += ax(a.shape[5], b.shape[5], a.stride[5], b.stride[5]);
case 1: eq += ax(a.shape[0], b.shape[0], a.stride[0], b.stride[0]); MEGDNN_FALLTHRU
case 5:
eq += ax(a.shape[4], b.shape[4], a.stride[4], b.stride[4]);
MEGDNN_FALLTHRU
case 4:
eq += ax(a.shape[3], b.shape[3], a.stride[3], b.stride[3]);
MEGDNN_FALLTHRU
case 3:
eq += ax(a.shape[2], b.shape[2], a.stride[2], b.stride[2]);
MEGDNN_FALLTHRU
case 2:
eq += ax(a.shape[1], b.shape[1], a.stride[1], b.stride[1]);
MEGDNN_FALLTHRU
case 1:
eq += ax(a.shape[0], b.shape[0], a.stride[0], b.stride[0]);
} }
return eq == a.ndim; return eq == a.ndim;
} }
return false; return false;
} }
} // anonymous namespace } // anonymous namespace
// config NR_TEST at small memory device, eg, EV300 etc // config NR_TEST at small memory device, eg, EV300 etc
static constexpr size_t NR_TEST = 10000; static constexpr size_t NR_TEST = 10000;
TEST(BENCHMARK_BASIC_TYPES, EQ_SHAPE) { TEST(BENCHMARK_BASIC_TYPES, EQ_SHAPE) {
std::mt19937_64 rng; std::mt19937_64 rng;
static TensorShape s0, s1[NR_TEST]; static TensorShape s0, s1[NR_TEST];
auto init = [&rng](TensorShape& ts) {
for (size_t i = 0; i < ts.ndim; ++i)
ts.shape[i] = rng();
};
s0.ndim = rng() % TensorShape::MAX_NDIM + 1;
init(s0);
auto gen = [&](int type) { auto gen = [&](int type) {
if (type == 0) { if (type == 0) {
return s0; return s0;
...@@ -84,39 +119,45 @@ TEST(BENCHMARK_BASIC_TYPES, EQ_SHAPE) { ...@@ -84,39 +119,45 @@ TEST(BENCHMARK_BASIC_TYPES, EQ_SHAPE) {
ret.ndim = s0.ndim; ret.ndim = s0.ndim;
else else
ret.ndim = rng() % TensorShape::MAX_NDIM + 1; ret.ndim = rng() % TensorShape::MAX_NDIM + 1;
for (size_t i = 0; i < ret.ndim; ++ i) init(ret);
ret.shape[i] = rng();
return ret; return ret;
} }
}; };
s0 = gen(false); s0 = gen(false);
for (size_t i = 0; i < NR_TEST; ++ i) { for (size_t i = 0; i < NR_TEST; ++i) {
s1[i] = gen(rng() % 3); s1[i] = gen(rng() % 3);
} }
int nr_diff = 0; int nr_diff = 0;
test::Timer timer; test::Timer timer;
timer.start(); timer.start();
for (size_t i = 0; i < NR_TEST; ++ i) for (size_t i = 0; i < NR_TEST; ++i)
nr_diff += eq_shape0(s1[i], s0); nr_diff += eq_shape0(s1[i], s0);
timer.stop(); timer.stop();
auto time0 = timer.get_time_in_us(); auto time0 = timer.get_time_in_us();
timer.reset(); timer.reset();
timer.start(); timer.start();
for (size_t i = 0; i < NR_TEST; ++ i) for (size_t i = 0; i < NR_TEST; ++i)
nr_diff -= eq_shape1(s1[i], s0); nr_diff -= eq_shape1(s1[i], s0);
timer.stop(); timer.stop();
auto time1 = timer.get_time_in_us(); auto time1 = timer.get_time_in_us();
printf("time per eq_shape: %.3fus vs %.3fus; diff=%d\n", printf("time per eq_shape: %.3fus vs %.3fus; diff=%d\n",
time0 / double(NR_TEST), time1 / double(NR_TEST), time0 / double(NR_TEST), time1 / double(NR_TEST), nr_diff);
nr_diff);
} }
TEST(BENCHMARK_BASIC_TYPES, EQ_LAYOUT) { TEST(BENCHMARK_BASIC_TYPES, EQ_LAYOUT) {
std::mt19937_64 rng; std::mt19937_64 rng;
static TensorLayout s0, s1[NR_TEST]; static TensorLayout s0, s1[NR_TEST];
auto init = [&rng](TensorLayout& tl) {
for (size_t i = 0; i < tl.ndim; ++i) {
tl.shape[i] = rng();
tl.stride[i] = rng();
}
};
s0.ndim = rng() % TensorShape::MAX_NDIM + 1;
init(s0);
auto gen = [&](int type) { auto gen = [&](int type) {
if (type == 0) { if (type == 0) {
return s0; return s0;
...@@ -126,35 +167,31 @@ TEST(BENCHMARK_BASIC_TYPES, EQ_LAYOUT) { ...@@ -126,35 +167,31 @@ TEST(BENCHMARK_BASIC_TYPES, EQ_LAYOUT) {
ret.ndim = s0.ndim; ret.ndim = s0.ndim;
else else
ret.ndim = rng() % TensorShape::MAX_NDIM + 1; ret.ndim = rng() % TensorShape::MAX_NDIM + 1;
for (size_t i = 0; i < ret.ndim; ++ i) { init(ret);
ret.shape[i] = rng();
ret.stride[i] = rng();
}
return ret; return ret;
} }
}; };
s0 = gen(false); s0 = gen(false);
for (size_t i = 0; i < NR_TEST; ++ i) { for (size_t i = 0; i < NR_TEST; ++i) {
s1[i] = gen(rng() % 3); s1[i] = gen(rng() % 3);
} }
int nr_diff = 0; int nr_diff = 0;
test::Timer timer; test::Timer timer;
timer.start(); timer.start();
for (size_t i = 0; i < NR_TEST; ++ i) for (size_t i = 0; i < NR_TEST; ++i)
nr_diff += eq_layout0(s1[i], s0); nr_diff += eq_layout0(s1[i], s0);
timer.stop(); timer.stop();
auto time0 = timer.get_time_in_us(); auto time0 = timer.get_time_in_us();
timer.reset(); timer.reset();
timer.start(); timer.start();
for (size_t i = 0; i < NR_TEST; ++ i) for (size_t i = 0; i < NR_TEST; ++i)
nr_diff -= eq_layout1(s1[i], s0); nr_diff -= eq_layout1(s1[i], s0);
timer.stop(); timer.stop();
auto time1 = timer.get_time_in_us(); auto time1 = timer.get_time_in_us();
printf("time per eq_layout: %.3fus vs %.3fus; diff=%d\n", printf("time per eq_layout: %.3fus vs %.3fus; diff=%d\n",
time0 / double(NR_TEST), time1 / double(NR_TEST), time0 / double(NR_TEST), time1 / double(NR_TEST), nr_diff);
nr_diff);
} }
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册