opr_impl.cpp 1.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
#include "src/cuda/cross/opr_impl.h"

#include "src/cuda/cross/cross.cuh"
#include "src/cuda/utils.h"

#include <algorithm>
#include <numeric>

namespace megdnn {
namespace cuda {

void CrossImpl::exec(
        _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
        _megdnn_workspace workspace) {
    check_exec(A.layout, B.layout, C.layout, workspace.size);

    size_t a1, b1, c1, a2, b2, c2, a3, b3, c3;
    get_ABC(A.layout, a1, b1, c1, param().axisa);
    get_ABC(B.layout, a2, b2, c2, param().axisb);
    get_ABC(C.layout, a3, b3, c3, param().axisc);
#define cb(DType)                                                             \
    if (C.layout.dtype.enumv() == DTypeTrait<DType>::enumv) {                 \
        using ctype = typename DTypeTrait<DType>::ctype;                      \
        cross::exec_internal<ctype>(                                          \
                A.ptr<ctype>(), b1 * c1, c1, B.ptr<ctype>(), b2 * c2, c2,     \
                C.ptr<ctype>(), b3 * c3, c3, a1 * c1, cuda_stream(handle())); \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
}

}  // namespace cuda
}  // namespace megdnn
   // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}