/**
 * \file src/gopt/test/helper.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/test/helper.h"

#include "megbrain/opr/io.h"
#include "megbrain/gopt/framework.h"

namespace mgb {
    //! make an opr that reads \p x; only used for test
    SymbolVar opr_reader_for_test(SymbolVar x);

    /*!
     * \brief gtest fixture for checking the rewrite performed by a single
     *      graph optimization pass
     *
     * \tparam Pass the gopt pass under test; it is added to a fresh
     *      gopt::GraphOptimizer through add_pass<Pass>(args...) every time
     *      run_opt() is called
     */
    template<class Pass>
    class TestGoptBasicArithPass: public ::testing::Test {
        protected:
            //! generator of random host tensors used by mkvar() / mkcvar()
            HostTensorGenerator<> gen;
            //! computing graph that the vars made by this fixture live in
            std::shared_ptr<ComputingGraph> graph = ComputingGraph::make();

            /*!
             * \brief make an input var fed from a freshly generated host
             *      tensor through Host2DeviceCopy
             * \param name name given to the returned var
             * \param shp shape of the generated host tensor
             * \param cn comp node the var is placed on
             */
            SymbolVar mkvar(const char* name, const TensorShape& shp = {1},
                            CompNode cn = CompNode::load("xpu0")) {
                return opr::Host2DeviceCopy::make(*graph, gen(shp), cn)
                        .rename(name);
            }

            /*!
             * \brief make a var backed by a SharedDeviceTensor whose value
             *      is a freshly generated host tensor
             * \param name name given to the returned var
             * \param shp shape of the generated tensor
             * \param cn comp node the var is placed on
             */
            SymbolVar mkcvar(const char* name, const TensorShape& shp = {1},
                             CompNode cn = CompNode::load("xpu0")) {
                return opr::SharedDeviceTensor::make(
                        *graph, *gen(shp), cn).rename(name);
            }

            /*!
             * \brief run a GraphOptimizer that contains only \p Pass
             * \param inp endpoint vars of the graph to be optimized
             * \param args arguments forwarded to the constructor of Pass
             * \return endpoint vars of the optimized graph
             */
            template<typename ...Args>
            SymbolVarArray run_opt(
                    const SymbolVarArray &inp, Args&& ...args) {
                return gopt::GraphOptimizer{}.
                    add_pass<Pass>(std::forward<Args>(args)...).
                    apply({{inp}}).endpoint_vars();
            }

            /*!
             * \brief check that optimizing \p inp with \p Pass yields
             *      \p expect
             *
             * The pass is run twice: once on \p inp alone, and once on
             * `inp + opr_reader_for_test(inp)`. The second run also covers
             * the case where the input var has more than one reader.
             *
             * \tparam check_ne if true, assert that \p expect and \p inp are
             *      different vars; otherwise assert that they are equal, i.e.
             *      the pass is expected to leave \p inp unchanged
             * \param args arguments used to construct Pass; each of the two
             *      optimizer runs constructs its own Pass instance from them
             */
            template<bool check_ne=true, typename ...Args>
            void check(SymbolVar expect, SymbolVar inp, Args&& ...args) {
                if (check_ne) {
                    ASSERT_NE(expect.node(), inp.node());
                } else {
                    ASSERT_EQ(expect, inp);
                }
                SymbolVar get;
                // args are needed again below, so they are passed as lvalues
                // here; forwarding them now could move from an rvalue
                // argument and leave the second Pass a moved-from value
                unpack_vector(run_opt({inp}, args...), get);
                ASSERT_EQ(expect, get);

                // test multiple readers; this is the last use of args, so
                // forwarding is safe here
                unpack_vector(
                        run_opt({inp + opr_reader_for_test(inp)},
                                std::forward<Args>(args)...),
                        get);

                ASSERT_EQ(expect + opr_reader_for_test(expect), get);
            }
    };
}

/*!
 * \brief define the gtest case \p name for the gopt pass \p pass
 *
 * Declares the alias TestGopt<pass> for TestGoptBasicArithPass<gopt::pass>
 * and uses it as the TEST_F fixture. The names in the expansion are
 * unqualified. NOTE(review): presumably the macro is expanded where mgb
 * names are visible (inside namespace mgb or after `using namespace mgb`);
 * confirm against the test sources.
 */
#define TEST_PASS(pass, name)                                        \
    using TestGopt##pass = TestGoptBasicArithPass<gopt::pass>;       \
    TEST_F(TestGopt##pass, name)

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}