// non_zero.cpp
#include "test/common/non_zero.h"
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/cuda/fixture.h"

using namespace megdnn;
using namespace test;
// Runs every case produced by NonZeroTestcase::make() through a NonZero
// operator created on the CUDA handle and one created on the naive handle,
// and asserts that the two resulting tensors are equal.
TEST_F(CUDA, NONZERO) {
    std::vector<NonZeroTestcase> test_cases = NonZeroTestcase::make();
    auto opr_cuda = handle_cuda()->create_operator<NonZero>();
    auto opr_naive = handle_naive()->create_operator<NonZero>();
    for (NonZeroTestcase& test_case : test_cases) {
        // NOTE(review): the naive operator is also driven through run_cuda();
        // confirm that this helper is meant to accept a naive-handle operator.
        NonZeroTestcase::CUDAResult data = test_case.run_cuda(opr_cuda.get());
        NonZeroTestcase::CUDAResult data_naive = test_case.run_cuda(opr_naive.get());

        // The CUDA output is validated against the naive output only;
        // test_case.correct_answer is not consulted by this test.
        MEGDNN_ASSERT_TENSOR_EQ(*data, *data_naive);
    }
}