#include "megbrain_build_config.h"

#if MGB_CUSTOM_OP

#include <iostream>
#include <vector>

#include "gtest/gtest.h"

#include "megbrain/comp_node.h"
#include "megbrain/custom/data_adaptor.h"
#include "megbrain/custom/op.h"
#include "megbrain/tensor.h"
#include "megbrain_build_config.h"

#define OP_TEST_LOG 0

using namespace mgb;

namespace custom {

// Verifies the builder-style metadata setters of CustomOp: description,
// named/anonymous inputs and outputs, and typed params are all recorded
// and reported back through the corresponding getters.
TEST(TestCustomOp, TestCustomOpInfoSetter) {
    CustomOp test("TestOp", CUSTOM_OP_VERSION);
    test.set_description("Test Op")
            .add_input("lhs", "lhs of test op", {"float32", "int32"}, 2)
            .add_inputs(2)
            .add_input("rhs", "rhs of test op", {"float32", "int32"}, 2)
            .add_outputs(1)
            .add_output("out", "out of test op", {"float32", "int32"}, 2)
            .add_outputs(3);

    ASSERT_TRUE(test.op_type() == "TestOp");
    ASSERT_TRUE(test.op_desc() == "Test Op");
    // 1 ("lhs") + 2 (anonymous) + 1 ("rhs") = 4 inputs
    ASSERT_TRUE(test.input_num() == 4);
    // 1 (anonymous) + 1 ("out") + 3 (anonymous) = 5 outputs
    ASSERT_TRUE(test.output_num() == 5);

#if OP_TEST_LOG
    for (auto input : test.inputs_info()) {
        std::cout << input.str() << std::endl;
    }
    for (auto output : test.outputs_info()) {
        std::cout << output.str() << std::endl;
    }
#endif

    // Params of several types: scalar float, float list, string,
    // string list, and int.
    test.add_param("param1", "param1 - float", 1.23f)
            .add_param("param2", "param2 - float list", {2.34f, 3.45f})
            .add_param("param3", "param3 - string", "test-string")
            .add_param("param4", {"test", "string", "list"})
            .add_param("param5", 1);

#if OP_TEST_LOG
    ParamInfo pinfo = test.param_info();
    for (auto kv : pinfo.meta()) {
        std::cout << kv.str() << std::endl;
    }
#endif
}

// Device-inference callback for TestCustomOpFuncSetter: swaps the two input
// devices, so the test can tell it apart from the default inference rule.
void device_infer(
        const std::vector<Device>& inputs, const Param& params,
        std::vector<Device>& outputs) {
    (void)params;  // unused: inference depends only on the input devices
    outputs[0] = inputs[1];
    outputs[1] = inputs[0];
}

// Shape-inference callback for TestCustomOpFuncSetter: swaps the two input
// shapes, distinguishing it from the default inference rule.
void shape_infer(
        const std::vector<Shape>& inputs, const Param& params,
        std::vector<Shape>& outputs) {
    (void)params;  // unused: inference depends only on the input shapes
    outputs[0] = inputs[1];
    outputs[1] = inputs[0];
}

// DType-inference callback for TestCustomOpFuncSetter: swaps the two input
// dtypes, distinguishing it from the default inference rule.
void dtype_infer(
        const std::vector<DType>& inputs, const Param& params,
        std::vector<DType>& outputs) {
    (void)params;  // unused: inference depends only on the input dtypes
    outputs[0] = inputs[1];
    outputs[1] = inputs[0];
}

// Format-inference callback for TestCustomOpFuncSetter: swaps the two input
// formats, distinguishing it from the default inference rule.
void format_infer(
        const std::vector<Format>& inputs, const Param& params,
        std::vector<Format>& outputs) {
    (void)params;  // unused: inference depends only on the input formats
    outputs[0] = inputs[1];
    outputs[1] = inputs[0];
}

// CPU compute callback: performs no real computation, only asserts that it
// was dispatched with the expected "device" param value.
void cpu_kernel(
        const std::vector<Tensor>& inputs, const Param& params,
        std::vector<Tensor>& outputs) {
    (void)inputs;   // tensors intentionally ignored: dispatch is under test
    (void)outputs;
#if OP_TEST_LOG
    std::cout << "Checking CPU Forward - " << params["device"].as<std::string>()
              << std::endl;
#endif
    ASSERT_TRUE(params["device"] == "x86");
}

// GPU compute callback: performs no real computation, only asserts that it
// was dispatched with the expected "device" param value.
void gpu_kernel(
        const std::vector<Tensor>& inputs, const Param& params,
        std::vector<Tensor>& outputs) {
    (void)inputs;   // tensors intentionally ignored: dispatch is under test
    (void)outputs;
#if OP_TEST_LOG
    std::cout << "Checking GPU Forward - " << params["device"].as<std::string>()
              << std::endl;
#endif
    ASSERT_TRUE(params["device"] == "cuda");
}

// CPU compute callback variant taking RuntimeArgs: like cpu_kernel, it only
// asserts the dispatched "device" param; tensors and args are ignored.
void cpu_kernel_with_runtime_args(
        const std::vector<Tensor>& inputs, const Param& params,
        std::vector<Tensor>& outputs, const RuntimeArgs& args) {
    (void)inputs;   // tensors/args intentionally ignored: dispatch is under test
    (void)outputs;
    (void)args;
#if OP_TEST_LOG
    std::cout << "Checking CPU Forward - " << params["device"].as<std::string>()
              << std::endl;
#endif
    ASSERT_TRUE(params["device"] == "x86");
}

// GPU compute callback variant taking RuntimeArgs: like gpu_kernel, it only
// asserts the dispatched "device" param; tensors and args are ignored.
void gpu_kernel_with_runtime_args(
        const std::vector<Tensor>& inputs, const Param& params,
        std::vector<Tensor>& outputs, const RuntimeArgs& args) {
    (void)inputs;   // tensors/args intentionally ignored: dispatch is under test
    (void)outputs;
    (void)args;
#if OP_TEST_LOG
    std::cout << "Checking GPU Forward - " << params["device"].as<std::string>()
              << std::endl;
#endif
    ASSERT_TRUE(params["device"] == "cuda");
}

// Verifies the infer/compute function setters:
//   1. with no user callbacks, inference copies input[0]'s metadata to every
//      output (checked against the asserted values below);
//   2. after set_*_infer, the user callbacks (which swap the two inputs'
//      metadata) take effect;
//   3. CPU and CUDA kernels are registered and dispatched, each asserting
//      the "device" param it was invoked with.
TEST(TestCustomOp, TestCustomOpFuncSetter) {
#if MGB_CUDA
    CustomOp test("TestOp", CUSTOM_OP_VERSION);
    test.set_description("Test Op Forward Backward Union")
            .add_input("lhs", "lhs of Test op", {"float32", "int32"}, 2)
            .add_input("rhs", "rhs of Test op", {"float32", "int32"}, 2)
            .add_output("outl", "outl of Test op", {"float32", "int32"}, 2)
            .add_output("outr", "outr of Test op", {"float32", "int32"}, 2)
            .add_param("smooth", "smooth", 0.f)
            .add_param("device", "using for judge device", "x86");

    std::vector<Device> idevices = {"x86", "cuda"};
    std::vector<Shape> ishapes = {{2, 3}, {3, 4}};
    std::vector<DType> idtypes = {"int32", "float32"};
    std::vector<Format> iformats = {"default", "default"};
    Param param(test.param_info());

    // Default inference: every output mirrors input[0].
    std::vector<Device> odevices = test.infer_output_device(idevices, param);
    std::vector<Shape> oshapes = test.infer_output_shape(ishapes, param);
    std::vector<DType> odtypes = test.infer_output_dtype(idtypes, param);
    std::vector<Format> oformats = test.infer_output_format(iformats, param);

    ASSERT_TRUE(odevices.size() == 2);
    ASSERT_TRUE(oshapes.size() == 2);
    ASSERT_TRUE(odtypes.size() == 2);
    ASSERT_TRUE(oformats.size() == 2);

    ASSERT_TRUE(odevices[0] == "x86");
    ASSERT_TRUE(odevices[1] == "x86");
    ASSERT_TRUE(oshapes[0] == Shape({2, 3}));
    ASSERT_TRUE(oshapes[1] == Shape({2, 3}));
    ASSERT_TRUE(odtypes[0] == "int32");
    ASSERT_TRUE(odtypes[1] == "int32");
    // BUGFIX: assert on the inferred outputs; the original checked
    // iformats here, which is tautologically true for the inputs built above.
    ASSERT_TRUE(oformats[0].is_default());
    ASSERT_TRUE(oformats[1].is_default());

    test.set_device_infer(device_infer)
            .set_shape_infer(shape_infer)
            .set_dtype_infer(dtype_infer)
            .set_format_infer(format_infer);

    // User-supplied inference: the callbacks swap the inputs' metadata.
    odevices = test.infer_output_device(idevices, param);
    oshapes = test.infer_output_shape(ishapes, param);
    odtypes = test.infer_output_dtype(idtypes, param);
    oformats = test.infer_output_format(iformats, param);

    ASSERT_TRUE(odevices.size() == 2);
    ASSERT_TRUE(oshapes.size() == 2);
    ASSERT_TRUE(odtypes.size() == 2);
    ASSERT_TRUE(oformats.size() == 2);

    ASSERT_TRUE(odevices[0] == "cuda");
    ASSERT_TRUE(odevices[1] == "x86");
    ASSERT_TRUE(oshapes[0] == Shape({3, 4}));
    ASSERT_TRUE(oshapes[1] == Shape({2, 3}));
    ASSERT_TRUE(odtypes[0] == "float32");
    ASSERT_TRUE(odtypes[1] == "int32");
    // BUGFIX: same as above — check oformats, not iformats.
    ASSERT_TRUE(oformats[0].is_default());
    ASSERT_TRUE(oformats[1].is_default());

    // Both kernel variants are registered for the CPU; presumably the
    // second registration takes precedence — either one passes, since both
    // assert the same "device" value. TODO confirm override semantics.
    test.set_compute(cpu_kernel_with_runtime_args);
    test.set_compute(cpu_kernel);
    DeviceTensorND cdev_itensor0(CompNode::load("cpux"), {3, 2}, dtype::Int32{});
    DeviceTensorND cdev_itensor1(CompNode::load("cpux"), {3, 2}, dtype::Float32{});
    DeviceTensorND cdev_otensor0(CompNode::load("cpux"), {3, 2}, dtype::Float32{});
    DeviceTensorND cdev_otensor1(CompNode::load("cpux"), {3, 2}, dtype::Int32{});

    std::vector<Tensor> cinputs = {
            to_custom_tensor(cdev_itensor0), to_custom_tensor(cdev_itensor1)};
    std::vector<Tensor> coutputs = {
            to_custom_tensor(cdev_otensor0), to_custom_tensor(cdev_otensor1)};
    param["device"] = "x86";
    test.compute(cinputs, param, coutputs);

    // Same double-registration for the CUDA backend.
    test.set_compute("cuda", gpu_kernel_with_runtime_args);
    test.set_compute("cuda", gpu_kernel);
    DeviceTensorND gdev_itensor0(CompNode::load("gpux"), {3, 2}, dtype::Int32{});
    DeviceTensorND gdev_itensor1(CompNode::load("gpux"), {3, 2}, dtype::Float32{});
    DeviceTensorND gdev_otensor0(CompNode::load("gpux"), {3, 2}, dtype::Float32{});
    DeviceTensorND gdev_otensor1(CompNode::load("gpux"), {3, 2}, dtype::Int32{});

    std::vector<Tensor> ginputs = {
            to_custom_tensor(gdev_itensor0), to_custom_tensor(gdev_itensor1)};
    std::vector<Tensor> goutputs = {
            to_custom_tensor(gdev_otensor0), to_custom_tensor(gdev_otensor1)};
    param["device"] = "cuda";
    test.compute(ginputs, param, goutputs);
#endif
}

}  // namespace custom

#endif