#include "src/cuda/cross/opr_impl.h" #include "src/cuda/cross/cross.cuh" #include "src/cuda/utils.h" #include #include namespace megdnn { namespace cuda { void CrossImpl::exec( _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, _megdnn_workspace workspace) { check_exec(A.layout, B.layout, C.layout, workspace.size); size_t a1, b1, c1, a2, b2, c2, a3, b3, c3; get_ABC(A.layout, a1, b1, c1, param().axisa); get_ABC(B.layout, a2, b2, c2, param().axisb); get_ABC(C.layout, a3, b3, c3, param().axisc); #define cb(DType) \ if (C.layout.dtype.enumv() == DTypeTrait::enumv) { \ using ctype = typename DTypeTrait::ctype; \ cross::exec_internal( \ A.ptr(), b1 * c1, c1, B.ptr(), b2 * c2, c2, \ C.ptr(), b3 * c3, c3, a1 * c1, cuda_stream(handle())); \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) #undef cb } } // namespace cuda } // namespace megdnn // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}