diff --git a/dnn/include/megdnn/oprs/linalg.h b/dnn/include/megdnn/oprs/linalg.h index 65b7175fb7f2b01cb0bcd077082f860f79069f55..de877c52516d6499f129449e67ca68479a07beb3 100644 --- a/dnn/include/megdnn/oprs/linalg.h +++ b/dnn/include/megdnn/oprs/linalg.h @@ -208,6 +208,31 @@ protected: using SVD = SVDForward; +//! return the cross product of two (arrays of) vectors. +class Cross : public OperatorBase { + DEF_OPR_IMPL(Cross, OperatorBase, 2, 1); + DEF_OPR_PARAM(Cross); + +public: + /** + * \see https://numpy.org/doc/stable/reference/generated/numpy.cross.html + */ + virtual void exec( + _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, + _megdnn_workspace) = 0; + MGE_WIN_DECLSPEC_FUC void deduce_layout( + const TensorLayout& A, const TensorLayout& B, TensorLayout& C); + virtual size_t get_workspace_in_bytes( + const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; + void get_ABC( + const TensorShape& shape, size_t& A, size_t& B, size_t& C, int32_t axis); + +protected: + void check_exec( + const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, + size_t workspace_in_bytes); +}; + } // namespace megdnn #include "megdnn/internal/opr_header_epilogue.h" diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 382dea28b810293e8ad165d39856a384d3f97d21..3e2c4fec4ac615ae2b1a82d7def4e594af68d579 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -803,6 +803,23 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0) 'negative value to a lower diagonal.'), 0)) +(pdef('Cross'). + add_fields( + 'int32', + Doc('axisa', 'axis of a that defines the vector(s). By default, the last axis.'), + '-1'). + add_fields( + 'int32', + Doc('axisb', 'axis of b that defines the vector(s). By default, the last axis.'), + '-1'). + add_fields( + 'int32', + Doc('axisc', 'axis of c containing the cross product vector(s). Ignored if both ' + 'input vectors have dimension 2, as the return is scalar. By default, the ' + 'last axis.'), + '-1') +) + (pdef('UniformRNG', version=0, is_legacy=True). add_fields('uint64', 'seed', 0)) diff --git a/dnn/src/common/cross.cpp b/dnn/src/common/cross.cpp new file mode 100644 index 0000000000000000000000000000000000000000..557ad4193186f64104a4cd7b73c7d168aec1fbf5 --- /dev/null +++ b/dnn/src/common/cross.cpp @@ -0,0 +1,83 @@ +#include "megdnn/oprs.h" +#include "src/common/utils.h" + +#include +#include + +namespace megdnn { + +void Cross::deduce_layout( + const TensorLayout& A, const TensorLayout& B, TensorLayout& C) { + auto calibrated_axis = [](int ndim, int axis) { + return axis < 0 ? 
(axis + ndim) : axis; + }; + + int axis_a = calibrated_axis(A.ndim, param().axisa); + int axis_b = calibrated_axis(B.ndim, param().axisb); + int axis_c = calibrated_axis(A.ndim, param().axisc); + + megdnn_assert( + A[axis_a] == 3 && B[axis_b] == 3, + "incompatible dimensions for cross product (dimension must be 3)"); + + bool matched = true; + TensorShape shp; + if (A.ndim != B.ndim) { + matched = false; + } else { + for (int i = 0, j = 0, k = 0; i < static_cast(A.ndim); i++) { + if (i == axis_a) + continue; + if (j == axis_b) + ++j; + if (A[i] != B[j]) { + matched = false; + break; + } + if (k == axis_c) + ++k; + shp[k++] = A[i]; + ++j; + } + } + + megdnn_assert( + matched, "cross op shape mismatch: %s vs %s", A.to_string().c_str(), + B.to_string().c_str()); + + shp.ndim = A.ndim; + shp[axis_c] = A[axis_a]; + C = TensorLayout{shp, A.dtype}; +} + +void Cross::check_exec( + const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, + size_t workspace_in_bytes) { + megdnn_assert_eq_dtype(A, B); + megdnn_assert_eq_dtype(B, C); + TensorLayout c_expected; + deduce_layout(A, B, c_expected); + megdnn_assert_eq_layout(c_expected, C); + + megdnn_assert_contiguous(A); + megdnn_assert_contiguous(B); + megdnn_assert_contiguous(C); + auto required_workspace_in_bytes = get_workspace_in_bytes(A, B, C); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); +} + +void Cross::get_ABC( + const TensorShape& shape, size_t& A, size_t& B, size_t& C, int32_t axis) { + auto shape_arr = shape.shape; + auto ndim = shape.ndim; + if (axis < 0) + axis += ndim; + A = std::accumulate(shape_arr, shape_arr + axis, 1_z, SafeMultiplies()); + B = shape_arr[axis]; + C = std::accumulate( + shape_arr + (axis + 1), shape_arr + ndim, 1_z, SafeMultiplies()); +} + +} // namespace megdnn + +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index b89dd8b35822fc8faabdc86307734811801083f7..88bbdbc56b191000d4b9b7fae1b5fb049702213b 100644 --- a/dnn/src/common/handle_impl.h +++ b/dnn/src/common/handle_impl.h @@ -224,7 +224,8 @@ private: cb(GroupNormBackward) \ cb(MaskedFill) \ cb(MultiHeadAttnForward)\ - cb(MultiHeadAttnBackward) + cb(MultiHeadAttnBackward) \ + cb(Cross) // clang-format on /*! 
diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h index 7e941a196dccb54f5b857188ca84f95cdc7c0125..b4fd26b65d08ffac455b09968245764fa8101a7e 100644 --- a/dnn/src/common/opr_trait.h +++ b/dnn/src/common/opr_trait.h @@ -80,6 +80,7 @@ DEF(IndexingRemapBackward, 3, true, false); DEF(Linspace, 1, true, false); DEF(Eye, 1, true, false); DEF(Diag, 2, true, true); +DEF(Cross, 3, true, true); DEF(Flip, 2, true, true); DEF(ROICopy, 2, true, true); DEF(Rotate, 2, true, true); diff --git a/dnn/src/cuda/cross/cross.cu b/dnn/src/cuda/cross/cross.cu new file mode 100644 index 0000000000000000000000000000000000000000..d53296c50206e0e90184933c60f124a70669166a --- /dev/null +++ b/dnn/src/cuda/cross/cross.cu @@ -0,0 +1,54 @@ +#include "megdnn/dtype.h" +#include "src/cuda/cross/cross.cuh" +#include "src/cuda/utils.cuh" + +namespace { + +template +__global__ void cross_kernel( + T* A, size_t stride_a0, size_t stride_a1, T* B, size_t stride_b0, + size_t stride_b1, T* C, size_t stride_c0, size_t stride_c1, size_t N) { + size_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i < N) { + size_t ida = (i / stride_a1) * stride_a0 + i % stride_a1; + size_t idb = (i / stride_b1) * stride_b0 + i % stride_b1; + size_t idc = (i / stride_c1) * stride_c0 + i % stride_c1; + C[idc] = A[ida + stride_a1] * B[idb + 2 * stride_b1] - + A[ida + 2 * stride_a1] * B[idb + stride_b1]; + C[idc + stride_c1] = + A[ida + 2 * stride_a1] * B[idb] - A[ida] * B[idb + 2 * stride_b1]; + C[idc + 2 * stride_c1] = + A[ida] * B[idb + stride_b1] - A[ida + stride_a1] * B[idb]; + } +} + +} // anonymous namespace + +namespace megdnn { +namespace cuda { +namespace cross { + +template +void exec_internal( + T* A, size_t stride_a0, size_t stride_a1, T* B, size_t stride_b0, + size_t stride_b1, T* C, size_t stride_c0, size_t stride_c1, size_t N, + cudaStream_t stream) { + cross_kernel<<>>( + A, stride_a0, stride_a1, B, stride_b0, stride_b1, C, stride_c0, stride_c1, + N); + after_kernel_launch(); +} + +#define INST(T) \ + template void exec_internal( \ + T*, size_t, size_t, T*, size_t, size_t, T*, size_t, size_t, size_t, \ + cudaStream_t); +#define cb(DType) INST(typename DTypeTrait::ctype) +MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef INST +#undef cb + +} // namespace cross +} // namespace cuda +} // namespace megdnn + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} \ No newline at end of file diff --git a/dnn/src/cuda/cross/cross.cuh b/dnn/src/cuda/cross/cross.cuh new file mode 100644 index 0000000000000000000000000000000000000000..964c008ed5e752df9c95b7d77b18e7a5e0453837 --- /dev/null +++ b/dnn/src/cuda/cross/cross.cuh @@ -0,0 +1,18 @@ +#pragma once +#include +#include + +namespace megdnn { +namespace cuda { +namespace cross { + +template +void exec_internal( + T* A, size_t stride_a0, size_t stride_a1, T* B, size_t stride_b0, + size_t stride_b1, T* C, size_t stride_c0, size_t stride_c1, size_t N, + cudaStream_t stream); + +} // namespace cross +} // namespace cuda +} // namespace megdnn + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} \ No newline at end of file diff --git a/dnn/src/cuda/cross/opr_impl.cpp b/dnn/src/cuda/cross/opr_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5e7c78de10ecbbfffdec7b9ec1ab4572845790d4 --- /dev/null +++ b/dnn/src/cuda/cross/opr_impl.cpp @@ -0,0 +1,34 @@ +#include "src/cuda/cross/opr_impl.h" + +#include "src/cuda/cross/cross.cuh" +#include "src/cuda/utils.h" + +#include +#include + +namespace megdnn { +namespace cuda { + +void CrossImpl::exec( 
+ _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, + _megdnn_workspace workspace) { + check_exec(A.layout, B.layout, C.layout, workspace.size); + + size_t a1, b1, c1, a2, b2, c2, a3, b3, c3; + get_ABC(A.layout, a1, b1, c1, param().axisa); + get_ABC(B.layout, a2, b2, c2, param().axisb); + get_ABC(C.layout, a3, b3, c3, param().axisc); +#define cb(DType) \ + if (C.layout.dtype.enumv() == DTypeTrait::enumv) { \ + using ctype = typename DTypeTrait::ctype; \ + cross::exec_internal( \ + A.ptr(), b1 * c1, c1, B.ptr(), b2 * c2, c2, \ + C.ptr(), b3 * c3, c3, a1 * c1, cuda_stream(handle())); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb +} + +} // namespace cuda +} // namespace megdnn + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} \ No newline at end of file diff --git a/dnn/src/cuda/cross/opr_impl.h b/dnn/src/cuda/cross/opr_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..baca7a5c88c8ae5a57b39109610a6d9df5367de1 --- /dev/null +++ b/dnn/src/cuda/cross/opr_impl.h @@ -0,0 +1,22 @@ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace cuda { + +class CrossImpl final : public Cross { +public: + using Cross::Cross; + void exec( + _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& A, const TensorLayout& B, + const TensorLayout& C) override { + return 0; + } +}; + +} // namespace cuda +} // namespace megdnn + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} \ No newline at end of file diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp index 3c859c8fe84300abddf6625d0257c050ba6b8702..cf1c79b75a5d29842f39d48ff69f4173ef0091b2 100644 --- a/dnn/src/cuda/handle_create.cpp +++ b/dnn/src/cuda/handle_create.cpp @@ -16,6 +16,7 @@ #include "src/cuda/convolution3d/opr_impl.h" #include "src/cuda/convpooling/opr_impl.h" #include "src/cuda/correlation/opr_impl.h" +#include "src/cuda/cross/opr_impl.h" #include "src/cuda/cumsum/opr_impl.h" #include "src/cuda/cvt_color/opr_impl.h" #include "src/cuda/dct/opr_impl.h" @@ -234,6 +235,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(RegionRestrictedConvolutionBackwardData); MEGDNN_SPECIALIZE_CREATE_OPERATOR(RegionRestrictedConvolutionBackwardFilter); MEGDNN_SPECIALIZE_CREATE_OPERATOR(MultiHeadAttnForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(MultiHeadAttnBackward); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(Cross); template std::unique_ptr HandleImpl::create_operator() { diff --git a/dnn/src/naive/cross/opr_impl.cpp b/dnn/src/naive/cross/opr_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..93ba5a3d0d6e7121bd4bebfe909aeccc8a0e895c --- /dev/null +++ b/dnn/src/naive/cross/opr_impl.cpp @@ -0,0 +1,50 @@ +#include "src/naive/cross/opr_impl.h" +#include +#include +#include "src/common/utils.h" +#include "src/naive/handle.h" + +namespace megdnn { +namespace naive { + +template +void CrossImpl::exec_internal( + ctype* A, size_t a1, size_t b1, size_t c1, ctype* B, size_t a2, size_t b2, + size_t c2, ctype* C, size_t a3, size_t b3, size_t c3) { + (void)a2; + (void)a3; + + size_t N = a1 * c1; + for (size_t i = 0; i < N; ++i) { + size_t ida = (i / c1) * b1 * c1 + i % c1; + size_t idb = (i / c2) * b2 * c2 + i % c2; + size_t idc = (i / c3) * b3 * c3 + i % c3; + C[idc] = A[ida + c1] * B[idb + 2 * c2] - A[ida + 2 * c1] * B[idb + c2]; + C[idc + c3] = A[ida + 2 * c1] * B[idb] - A[ida] * B[idb + 2 * c2]; + C[idc + 2 * c3] = A[ida] * B[idb + c2] - A[ida + 
c1] * B[idb]; + } +} + +void CrossImpl::exec( + _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, + _megdnn_workspace workspace) { + check_exec(A.layout, B.layout, C.layout, workspace.size); + size_t a1, b1, c1, a2, b2, c2, a3, b3, c3; + get_ABC(A.layout, a1, b1, c1, param().axisa); + get_ABC(B.layout, a2, b2, c2, param().axisb); + get_ABC(C.layout, a3, b3, c3, param().axisc); +#define cb(DType) \ + if (A.layout.dtype == DType()) { \ + using ctype = typename DTypeTrait::ctype; \ + MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal( \ + A.ptr(), a1, b1, c1, B.ptr(), a2, b2, c2, \ + C.ptr(), a3, b3, c3)); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb +} + +} // namespace naive +} // namespace megdnn + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} \ No newline at end of file diff --git a/dnn/src/naive/cross/opr_impl.h b/dnn/src/naive/cross/opr_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..1d8d602827ee07eaa901e88c7bc8718bbc7ce3af --- /dev/null +++ b/dnn/src/naive/cross/opr_impl.h @@ -0,0 +1,29 @@ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace naive { + +class CrossImpl : public Cross { +public: + using Cross::Cross; + + void exec( + _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&) override { + return 0; + } + +private: + template + void exec_internal( + ctype* A, size_t a1, size_t b1, size_t c1, ctype* B, size_t a2, size_t b2, + size_t c2, ctype* C, size_t a3, size_t b3, size_t c3); +}; + +} // namespace naive +} // namespace megdnn + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} \ No newline at end of file diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp index dc741f937669be0a9b88465d2f6a1660dcae4fa4..7b4141895d95b0cfc3e0169fbd27d5d7f8353a02 100644 --- a/dnn/src/naive/handle.cpp +++ b/dnn/src/naive/handle.cpp @@ -18,6 +18,7 @@ #include "src/naive/convolution3d/opr_impl.h" #include "src/naive/convpooling/opr_impl.h" #include "src/naive/correlation/opr_impl.h" +#include "src/naive/cross/opr_impl.h" #include "src/naive/cumsum/opr_impl.h" #include "src/naive/cvt_color/opr_impl.h" #include "src/naive/dct/opr_impl.h" diff --git a/dnn/test/cuda/cross.cpp b/dnn/test/cuda/cross.cpp new file mode 100644 index 0000000000000000000000000000000000000000..19e7cf85536d3901e4189071016e0e8ffa1df925 --- /dev/null +++ b/dnn/test/cuda/cross.cpp @@ -0,0 +1,29 @@ +#include "test/cuda/fixture.h" + +#include "megdnn/oprs.h" +#include "test/common/checker.h" + +namespace megdnn { +namespace test { + +TEST_F(CUDA, CROSS) { + Checker checker(handle_cuda()); + for (DType dtype : + std::vector{dtype::Float16(), dtype::Int32(), dtype::Float32()}) { + checker.set_param({-2, 1, -1}) + .set_dtype(0, dtype) + .set_dtype(1, dtype) + .set_dtype(2, dtype); + checker.exec(TensorShapeArray{{2, 3, 4}, {2, 3, 4}, {2, 4, 3}}); + + checker.set_param({0, -1, 2}) + .set_dtype(0, dtype) + .set_dtype(1, dtype) + .set_dtype(2, dtype); + checker.exec(TensorShapeArray{{3, 2, 3, 4}, {2, 3, 4, 3}, {2, 3, 3, 4}}); + } +} + +} // namespace test +} // namespace megdnn + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} \ No newline at end of file diff --git a/dnn/test/naive/cross.cpp b/dnn/test/naive/cross.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d118b20c26465f2d8432020a2cc1c46370bcc747 --- /dev/null +++ 
b/dnn/test/naive/cross.cpp @@ -0,0 +1,44 @@ +#include "test/naive/fixture.h" + +#include "megdnn/oprs/linalg.h" +#include "test/common/checker.h" + +using namespace megdnn; +using namespace test; + +TEST_F(NAIVE, CROSS) { + Checker checker(handle(), /* check_dispatch */ false); + + Cross::Param param{2, 0, -3}; + + TensorND A = TensorValue( + {2, 2, 3, 3}, dtype::Float32(), + {0.14156187, 0.91729394, 0.74985146, 0.06110339, 0.67560272, 0.38569672, + 0.63086542, 0.96403808, 0.44215527, 0.04618851, 0.17031024, 0.83291294, + 0.15374447, 0.57640597, 0.90293485, 0.55324231, 0.75830697, 0.43003672, + 0.18093289, 0.73350109, 0.38130576, 0.29138499, 0.77249111, 0.0878262, + 0.11699259, 0.73904961, 0.52122415, 0.10095503, 0.15948783, 0.88822725, + 0.12915866, 0.59800798, 0.17838123, 0.37576929, 0.9885241, 0.25812663}); + + TensorND B = TensorValue( + {3, 2, 2, 3}, dtype::Float32(), + {0.87391774, 0.65328165, 0.44478127, 0.94795241, 0.84836271, 0.8054875, + 0.44598355, 0.50247015, 0.36997338, 0.77503215, 0.87235448, 0.21397323, + 0.85153319, 0.16009368, 0.77536431, 0.04486964, 0.37626156, 0.81259061, + 0.08839276, 0.50477703, 0.91626721, 0.44417563, 0.03038398, 0.21980224, + 0.71472471, 0.85791626, 0.40345219, 0.16018583, 0.38511319, 0.23809232, + 0.58711753, 0.92031592, 0.65426633, 0.14788665, 0.73273333, 0.3182309}); + + TensorND C = TensorValue( + {2, 3, 2, 3}, dtype::Float32(), + {-0.49353074, 0.42527416, -0.18722123, -0.0001961, -0.06334022, + -0.13446194, 0.45014671, -0.157173, -0.10586683, 0.51704864, + 0.57773064, 0.14807902, 0.0671453, -0.29450591, 0.40985739, + -0.14366998, -0.42492013, -0.05048549, 0.16073594, 0.3378806, + -0.42011888, -0.14780672, 0.40814509, 0.00002961, -0.0540521, + -0.30370236, -0.05663646, 0.27630338, 0.74548138, -0.22742917, + -0.11395976, -0.01789922, 0.31688461, -0.05526035, -0.51682907, + 0.15706553}); + + checker.set_param(param).exect(Testcase{A, B, {}}, Testcase{{}, {}, C}); +} diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index f90153d99d2d62eb0bc1879760697a9c99cf81c4..ee509024005a47268e66391eca1386fb4eb527a2 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -13,12 +13,13 @@ from ..core.tensor.utils import _normalize_axis from ..tensor import Tensor from ..utils.deprecation import deprecated_kwargs_default from .elemwise import _elemwise_multi_type, clip -from .tensor import broadcast_to, expand_dims, squeeze +from .tensor import broadcast_to, concat, expand_dims, squeeze, zeros __all__ = [ "argmax", "argmin", "argsort", + "cross", "dot", "isinf", "isnan", @@ -720,6 +721,50 @@ def matinv(inp: Tensor) -> Tensor: return result +def cross( + a: Tensor, + b: Tensor, + axisa: int = -1, + axisb: int = -1, + axisc: int = -1, + axis: int = None, +) -> Tensor: + r"""Return the cross product of two (arrays of) vectors. + + The cross product of ``a`` and ``b`` in `R^3` is a vector perpendicular to both ``a`` and ``b``. If ``a`` and ``b`` are arrays of vectors, the vectors are defined by the last axis of ``a`` and ``b`` by default, and these axes can have dimensions 2 or 3. + Where the dimension of either ``a`` or ``b`` is 2, the third component of the input vector is assumed to be zero and the cross product calculated accordingly. + + Args: + a: components of the first vector(s). + b: components of the first vector(s). + axisa: axis of a that defines the vector(s). By default, the last axis. + axisb: axis of b that defines the vector(s). 
By default, the last axis. + axisc: axis of c containing the cross product vector(s). By default, the last axis. + axis: if defined, the axis of a, b and c that defines the vector(s) and cross product(s). Overrides axisa, axisb and axisc. + + Returns: + vector cross product(s). + + Examples: + >>> a = Tensor([1.0, 2.0, 3.0]) + >>> b = Tensor([4.0, 5.0, 6.0]) + >>> out = F.cross(a, b) + >>> out.numpy() + array([-3., 6., -3.], dtype=float32) + """ + if axis is not None: + axisa = axisb = axisc = axis + + if a.ndim == 1 and len(a) == 2: + a = Tensor(np.append(a, 0)) + if b.ndim == 1 and len(b) == 2: + b = Tensor(np.append(b, 0)) + + op = builtin.Cross(axisa, axisb, axisc) + (result,) = apply(op, a, b) + return result + + def matmul( inp1: Tensor, inp2: Tensor, diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index cd99a896dee2c518c464f632caabbdfae6de893c..621366a60efd2e5eb3339445f3289f14025ea5ce 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -1680,3 +1680,49 @@ def test_softmax(): actual = F.softmax(data) np.testing.assert_allclose(actual.numpy(), desired, rtol=1e-5) + + +def test_cross(): + shape1 = 3 + shape2 = 3 + shape3 = (4, 2, 3) + shape4 = (4, 2, 3) + data1 = np.random.random(shape1).astype("float32") + data2 = np.random.random(shape2).astype("float32") + data3 = np.random.random(shape3).astype("float32") + data4 = np.random.random(shape4).astype("float32") + + cases = [ + {"input": [data1, data2]}, + {"input": [data3, data4]}, + ] + opr_test( + cases, + F.cross, + compare_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=1e-4), + ref_fn=np.cross, + ) + + data5 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + data6 = np.array([[7, 8, 9], [4, 5, 6], [1, 2, 3]]) + res = F.cross(mge.tensor(data5), mge.tensor(data6), axisa=0, axisb=0) + dst = np.cross(data5, data6, axisa=0, axisb=0) + np.testing.assert_allclose(res.numpy(), dst, rtol=1e-4) + + data7 = np.array([1, 2, 3]) + data8 = np.array([4, 5, 6]) + res = F.cross(mge.tensor(data7), mge.tensor(data8), axisa=0, axisb=0) + dst = np.cross(data7, data8, axisa=0, axisb=0) + np.testing.assert_allclose(res.numpy(), dst, rtol=1e-4) + + data9 = np.array([[6, 7, 4], [15, 2, 8], [9, 40, 39]]) + data10 = np.array([[5, 8, 9], [14, 21, 17], [10, 3, 47]]) + res = F.cross(mge.tensor(data9), mge.tensor(data10), axisa=1, axisb=-1) + dst = np.cross(data9, data10, axisa=1, axisb=-1) + np.testing.assert_allclose(res.numpy(), dst, rtol=1e-4) + + data11 = np.array([1, 2]) + data12 = np.array([4, 5, 6]) + res = F.cross(mge.tensor(data11), mge.tensor(data12)) + dst = np.cross(data11, data12) + np.testing.assert_allclose(res.numpy(), dst, rtol=1e-4) diff --git a/imperative/src/impl/ops/cross.cpp b/imperative/src/impl/ops/cross.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d94e516c4a019650e4d9a4d13377415be6cfa7fa --- /dev/null +++ b/imperative/src/impl/ops/cross.cpp @@ -0,0 +1,67 @@ +#include "megbrain/imperative/ops/autogen.h" +#include "megbrain/opr/blas.h" +#include "megdnn/oprs.h" + +#include "../dnn_op_helper.h" +#include "../op_trait.h" + +namespace mgb { +namespace imperative { +namespace cross { + +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto&& op = def.cast_final_safe(); + mgb_assert(inputs.size() == 2); + cg::OperatorNodeConfig config{op.make_name()}; + return opr::Cross::make(inputs[0], inputs[1], op.param(), config); +} 
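
Note on the indexing scheme shared by the naive and CUDA kernels earlier in this patch: get_ABC splits each tensor shape at its vector axis into an outer size, the axis length (asserted to be 3) and an inner size, so the i-th of the outer*inner vectors starts at flat offset (i / inner) * 3 * inner + i % inner and its three components sit inner elements apart. The following standalone C++ sketch is illustration only, not part of the patch (the name cross_ref is made up here); it reproduces that scheme for contiguous buffers, which is the layout check_exec enforces.

#include <cstddef>

// Reference cross product over layouts already decomposed by get_ABC():
// b1/b2/b3 are the vector-axis lengths (all 3), c1/c2/c3 the inner sizes,
// a1 the shared outer size. All buffers are assumed contiguous.
template <typename T>
void cross_ref(
        const T* A, size_t b1, size_t c1,   // input a
        const T* B, size_t b2, size_t c2,   // input b
        T* C, size_t b3, size_t c3,         // output c
        size_t a1) {
    size_t N = a1 * c1;                     // number of 3-vectors
    for (size_t i = 0; i < N; ++i) {
        size_t ia = (i / c1) * b1 * c1 + i % c1;
        size_t ib = (i / c2) * b2 * c2 + i % c2;
        size_t ic = (i / c3) * b3 * c3 + i % c3;
        // c = a x b, components strided by the respective inner sizes
        C[ic]          = A[ia + c1] * B[ib + 2 * c2] - A[ia + 2 * c1] * B[ib + c2];
        C[ic + c3]     = A[ia + 2 * c1] * B[ib]      - A[ia]          * B[ib + 2 * c2];
        C[ic + 2 * c3] = A[ia]          * B[ib + c2] - A[ia + c1]     * B[ib];
    }
}

For a contiguous (N, 3) tensor with the default axis = -1 this reduces to outer = N and inner = 1, so the three components of each vector are adjacent; for a (2, 3, 4) tensor with axis = 1 they are 4 elements apart, which is exactly the stride pair (b * c, c) the CUDA opr_impl passes to exec_internal.
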
+ +std::tuple, bool> infer_output_attrs_fallible( + const OpDef& def, const SmallVector& inputs) { + mgb_assert(inputs.size() == 2, "Cross expects two inputs"); + auto&& op_def = def.cast_final_safe(); + auto comp_node = inputs[0].comp_node; + + if (!inputs[0].layout.ndim) { + return {{{inputs[0].layout, comp_node}}, false}; + } + if (!inputs[1].layout.ndim) { + return {{{inputs[1].layout, comp_node}}, false}; + } + + DnnOprHelper dnn_op(op_def.param()); + auto oup_layout = dnn_op.deduce_layout(inputs[0].layout, inputs[1].layout); + return {{{oup_layout, comp_node}}, true}; +} + +SmallVector apply_on_physical_tensor( + const OpDef& def, const SmallVector& inputs, + SmallVector& output_descs, const bool& validated) { + auto comp_node = inputs[0]->comp_node(); + auto&& op_def = def.cast_final_safe(); + DnnOprCaller dnn_op(comp_node, op_def.param()); + auto dst = [&] { + if (validated) { + return output_descs[0].layout; + } else { + return dnn_op.deduce_layout(inputs[0]->layout(), inputs[1]->layout()); + } + }(); + auto out = Tensor::make(dst, comp_node); + if (!inputs[0]->layout().is_empty() && !inputs[1]->layout().is_empty()) { + dnn_op.exec_with_ws(inputs[0], inputs[1], out); + } + return {out}; +} + +OP_TRAIT_REG(Cross, Cross) + .apply_on_var_node(apply_on_var_node) + .infer_output_attrs_fallible(infer_output_attrs_fallible) + .apply_on_physical_tensor(apply_on_physical_tensor) + .fallback(); + +} // namespace cross +} // namespace imperative +} // namespace mgb + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/imperative/tablegen/generated/hash.txt b/imperative/tablegen/generated/hash.txt index 2ced0439dcc2275b6f452af8defd9c4080dcbca3..f301093af061d48f23b2ee84a1db2876b1f89fb0 100644 --- a/imperative/tablegen/generated/hash.txt +++ b/imperative/tablegen/generated/hash.txt @@ -1,7 +1,7 @@ -29b2127eb4034bf24e473945d70ead4a ../../dnn/scripts/opr_param_defs.py -639ff50d64fcb78374de266c88942c2c ../../src/core/include/megbrain/ir/ops.td -16654743e01160eeee879107cc4cac41 generated/opdef.h.inl -97c541ed45b0be98f1ac2700f5b4d8a6 generated/opdef.cpp.inl -6f9c6a7a1d71cca195c1e30743a1f542 generated/opdef.py.inl -806c5ceb34f571fc5c9d98d2ca8cad63 generated/opdef.cpy.inl +e4035bfefce3a2cc0e8cc6ec7fcac227 ../../dnn/scripts/opr_param_defs.py +13ab898fce3749ebbcabf7c145876147 ../../src/core/include/megbrain/ir/ops.td +9dda6e2db75279373ec6809b297a2370 generated/opdef.h.inl +aabc2d8146742faacabf56e376177e7b generated/opdef.cpp.inl +8a5dffac1df3286178b3fd304c39b5da generated/opdef.py.inl +04322b642bba8f684034fcc5dc27efcf generated/opdef.cpy.inl 911001ef0dd771024919f7a1a3a009db generated/enum_macro.h diff --git a/imperative/tablegen/generated/opdef.cpp.inl b/imperative/tablegen/generated/opdef.cpp.inl index bb00684f216cd594ea9fbb8f75906049c6238f88..d20aaa57d80f5ab5cfcd98c82b97e187c9d7b55d 100644 --- a/imperative/tablegen/generated/opdef.cpp.inl +++ b/imperative/tablegen/generated/opdef.cpp.inl @@ -2303,6 +2303,49 @@ OP_TRAIT_REG(Correlation, Correlation) .props(Correlation_props_impl) .make_name(Correlation_make_name_impl); +MGB_DYN_TYPE_OBJ_FINAL_IMPL(Cross); + +namespace { +size_t Cross_hash_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + size_t val = mgb::hash(op_.dyn_typeinfo()); + val = mgb::hash_pair_combine(val, mgb::hash(op_.axisa)); + val = mgb::hash_pair_combine(val, mgb::hash(op_.axisb)); + val = mgb::hash_pair_combine(val, mgb::hash(op_.axisc)); + return val; +} +bool Cross_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) { 
+ auto &&a_ = lhs_.cast_final_safe(), + &&b_ = rhs_.cast_final_safe(); + static_cast(a_); + static_cast(b_); + if (a_.axisa != b_.axisa) return false; + if (a_.axisb != b_.axisb) return false; + if (a_.axisc != b_.axisc) return false; + return true; +} +std::vector> Cross_props_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + std::vector> props_; + props_.emplace_back("axisa", std::to_string(op_.axisa)); + props_.emplace_back("axisb", std::to_string(op_.axisb)); + props_.emplace_back("axisc", std::to_string(op_.axisc)); + return props_; +} +std::string Cross_make_name_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + return "Cross"; +} +} // anonymous namespace +OP_TRAIT_REG(Cross, Cross) + .hash(Cross_hash_impl) + .is_same_st(Cross_is_same_st_impl) + .props(Cross_props_impl) + .make_name(Cross_make_name_impl); + MGB_DYN_TYPE_OBJ_FINAL_IMPL(Cumsum); namespace { diff --git a/imperative/tablegen/generated/opdef.cpy.inl b/imperative/tablegen/generated/opdef.cpy.inl index dca84d298f40fff9088ecf18d7e18e33661bcbe6..3889dcb8b2bf21080c977edacab8ca30a57de47d 100644 --- a/imperative/tablegen/generated/opdef.cpy.inl +++ b/imperative/tablegen/generated/opdef.cpy.inl @@ -7274,6 +7274,151 @@ void _init_py_Correlation(py::module m) { mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Correlation::typeinfo(), &py_type).second); } +PyOpDefBegin(Cross) // { + static PyGetSetDef py_getsetters[]; + static PyMethodDef tp_methods[]; + + static PyObject* getstate(PyObject* self, PyObject*) { + auto& opdef = reinterpret_cast(self)->inst(); + static_cast(opdef); + std::unordered_map state { + + {"axisa", serialization::dump(opdef.axisa)}, + {"axisb", serialization::dump(opdef.axisb)}, + {"axisc", serialization::dump(opdef.axisc)} + }; + return py::cast(state).release().ptr(); + } + static PyObject* setstate(PyObject* self, PyObject* args) { + PyObject* dict = PyTuple_GetItem(args, 0); + if (!dict) return NULL; + auto state = py::cast>(dict); + auto& opdef = reinterpret_cast(self)->inst(); + static_cast(opdef); + + { + auto&& iter = state.find("axisa"); + if (iter != state.end()) { + opdef.axisa = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("axisb"); + if (iter != state.end()) { + opdef.axisb = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("axisc"); + if (iter != state.end()) { + opdef.axisc = serialization::load(iter->second); + } + } + Py_RETURN_NONE; + } + static int py_init(PyObject *self, PyObject *args, PyObject *kwds); + static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds); + static PyMethodDef py_init_methoddef; +// }; +PyOpDefEnd(Cross) + +int PyOp(Cross)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { + static const char* kwlist[] = {"axisa", "axisb", "axisc", "scope", NULL}; + PyObject *axisa = NULL, *axisb = NULL, *axisc = NULL, *scope = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOO", const_cast(kwlist), &axisa, &axisb, &axisc, &scope)) + return -1; + + if (axisa) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().axisa = + py::cast(py::handle(axisa)); + } CATCH_ALL(-1) + } + + if (axisb) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().axisb = + py::cast(py::handle(axisb)); + } CATCH_ALL(-1) + } + + if 
(axisc) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().axisc = + py::cast(py::handle(axisc)); + } CATCH_ALL(-1) + } + + if (scope) { + try { + reinterpret_cast(self)->op + ->set_scope(py::cast(py::handle(scope))); + } CATCH_ALL(-1) + } + + return 0; +} + +PyGetSetDef PyOp(Cross)::py_getsetters[] = { + {const_cast("axisa"), py_get_generic(Cross, axisa), py_set_generic(Cross, axisa), const_cast("axisa"), NULL}, + {const_cast("axisb"), py_get_generic(Cross, axisb), py_set_generic(Cross, axisb), const_cast("axisb"), NULL}, + {const_cast("axisc"), py_get_generic(Cross, axisc), py_set_generic(Cross, axisc), const_cast("axisc"), NULL}, + {NULL} /* Sentinel */ +}; + + PyMethodDef PyOp(Cross)::tp_methods[] = { + {const_cast("__getstate__"), PyOp(Cross)::getstate, METH_NOARGS, "Cross getstate"}, + {const_cast("__setstate__"), PyOp(Cross)::setstate, METH_VARARGS, "Cross setstate"}, + {NULL} /* Sentinel */ + }; + +PyObject *PyOp(Cross)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) { + if (PyOp(Cross)::py_init(self, args, kwds) < 0) { + return NULL; + } + Py_RETURN_NONE; +} + +PyMethodDef PyOp(Cross)::py_init_methoddef = { + "__init__", + (PyCFunction)PyOp(Cross)::py_init_proxy, + METH_VARARGS | METH_KEYWORDS, + "__init__(self, axisa: int = ..., axisb: int = ..., axisc: int = ...) -> None\n" +}; + +void _init_py_Cross(py::module m) { + using py_op = PyOp(Cross); + auto& py_type = PyOpType(Cross); + py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; + py_type.tp_name = "megengine.core._imperative_rt.ops.Cross"; + py_type.tp_basicsize = sizeof(PyOp(Cross)); + py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + py_type.tp_doc = "Cross"; + py_type.tp_base = &PyOpType(OpDef); + py_type.tp_dealloc = py_dealloc_generic; + py_type.tp_new = py_new_generic; + py_type.tp_init = py_op::py_init; + py_type.tp_methods = py_op::tp_methods; + py_type.tp_getset = py_op::py_getsetters; + + py_type.tp_dict = PyDict_New(); + PyObject* descr = PyDescr_NewMethod(&PyOpType(Cross), &PyOp(Cross)::py_init_methoddef); + PyDict_SetItemString(py_type.tp_dict, "__init__", descr); + mgb_assert(PyType_Ready(&py_type) >= 0); + + PyType_Modified(&py_type); + m.add_object("Cross", reinterpret_cast(&py_type)); + mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Cross::typeinfo(), &py_type).second); +} + PyOpDefBegin(Cumsum) // { static PyGetSetDef py_getsetters[]; static PyMethodDef tp_methods[]; @@ -23420,6 +23565,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) { _init_py_ConvolutionBackwardData(m); \ _init_py_Copy(m); \ _init_py_Correlation(m); \ + _init_py_Cross(m); \ _init_py_Cumsum(m); \ _init_py_CvtColor(m); \ _init_py_DeformableConv(m); \ diff --git a/imperative/tablegen/generated/opdef.h.inl b/imperative/tablegen/generated/opdef.h.inl index fd4103d4c8f593cd9f93643f9ea90bce235621e1..53885921f66d1b5ead6e5a88c371aee0b421419f 100644 --- a/imperative/tablegen/generated/opdef.h.inl +++ b/imperative/tablegen/generated/opdef.h.inl @@ -567,6 +567,21 @@ public: } }; +class Cross : public OpDefImplBase { + MGB_DYN_TYPE_OBJ_FINAL_DECL; + +public: + int32_t axisa = -1; + int32_t axisb = -1; + int32_t axisc = -1; + Cross() = default; + Cross(int32_t axisa_, int32_t axisb_, int32_t axisc_, std::string scope_ = {}): axisa(axisa_), axisb(axisb_), axisc(axisc_) { set_scope(scope_); } + Cross(::megdnn::param::Cross packed_param_0): axisa(packed_param_0.axisa), axisb(packed_param_0.axisb), 
axisc(packed_param_0.axisc) {} + ::megdnn::param::Cross param() const { + return {axisa, axisb, axisc}; + } +}; + class Cumsum : public OpDefImplBase { MGB_DYN_TYPE_OBJ_FINAL_DECL; diff --git a/imperative/tablegen/generated/opdef.py.inl b/imperative/tablegen/generated/opdef.py.inl index 6a142298ca4ac24541a5cdefdba422824e438a13..39fc0e8dd3a1ba23cd9d93fceef0fada2bd6c19e 100644 --- a/imperative/tablegen/generated/opdef.py.inl +++ b/imperative/tablegen/generated/opdef.py.inl @@ -678,6 +678,14 @@ CorrelationInst .def_readwrite("pad_size", &Correlation::pad_size) .def_readwrite("is_multiply", &Correlation::is_multiply); +py::class_, OpDef> CrossInst(m, "Cross"); + +CrossInst + .def(py::init(), py::arg("axisa") = -1, py::arg("axisb") = -1, py::arg("axisc") = -1, py::arg("scope") = {}) + .def_readwrite("axisa", &Cross::axisa) + .def_readwrite("axisb", &Cross::axisb) + .def_readwrite("axisc", &Cross::axisc); + py::class_, OpDef> CumsumInst(m, "Cumsum"); CumsumInst diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index aada14f6d637887811a443ef3a76b63af6576ee9..d32e3537f2ad45929f2893aac0d2c3aa3d03b70b 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -548,7 +548,6 @@ def Dropout: MgbHashableOp<"Dropout", [DropoutParam]> { ); }]; let cmpFunction = [{return $0.handle == $1.handle && $0.drop_prob == $1.drop_prob;}]; - } def MeshGrid: MgbHashableOp<"MeshGrid"> { let extraArguments = (ins @@ -639,4 +638,6 @@ def MultiHeadAttn: MgbHashableOp<"MultiHeadAttn", [MultiHeadAttnParam]> { } + +def Cross: MgbHashableOp<"Cross", [CrossParam]>; #endif // MGB_OPS diff --git a/src/opr/impl/blas.cpp b/src/opr/impl/blas.cpp index d1cca6f5334c65ea56b63ba0405727e3fe96d0ce..569de20b5969853dfc03e1f80954b1ed9625e34e 100644 --- a/src/opr/impl/blas.cpp +++ b/src/opr/impl/blas.cpp @@ -736,4 +736,29 @@ SymbolVarArray SVD::make( return ret; } +/* ================= Cross ================= */ + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(Cross); +MEGDNN_OPR_INIT2(Cross, "cross") + +void Cross::add_input_layout_constraint() { + input(0)->add_layout_constraint_contiguous(); + input(1)->add_layout_constraint_contiguous(); +} + +#if MGB_ENABLE_GRAD +MGB_IMPL_OPR_GRAD(Cross) { + SymbolVar grad, i0{opr.input(0)}, i1{opr.input(1)}, og{out_grad[0]}; + if (wrt_idx == 0) { + grad = Cross::make( + i1, og, {opr.param().axisb, opr.param().axisc, opr.param().axisa}); + } else { + mgb_assert(wrt_idx == 1); + grad = Cross::make( + og, i0, {opr.param().axisc, opr.param().axisa, opr.param().axisb}); + } + return grad.node(); +} +#endif + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/blas.oprdecl b/src/opr/impl/blas.oprdecl index bc3a95f20ed0afb980baf7e4ac8d09a418f76356..032df99bc258c2d8ed53a643c29088b92ce2f830 100644 --- a/src/opr/impl/blas.oprdecl +++ b/src/opr/impl/blas.oprdecl @@ -48,4 +48,9 @@ decl_opr('SVD', desc='Computes the singular value decompositions of matrices. 
' 'The input must has shape ``[..., M, N]``.') +decl_opr('Cross', + inputs=['A', 'B'], + params='Cross', + desc='computes the cross product of two (arrays of) vectors.') + # vim: ft=python diff --git a/src/opr/impl/blas.sereg.h b/src/opr/impl/blas.sereg.h index 49a2d1ffa97c27ca7bf2dcb2803dfda6a660cdef..c4a39995f956f337d2ffeecfc741bc5c340dab6b 100644 --- a/src/opr/impl/blas.sereg.h +++ b/src/opr/impl/blas.sereg.h @@ -92,7 +92,7 @@ MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(BatchedMatrixMulV2, 2, opr_shallow_copy_matmu MGB_SEREG_OPR(Dot, 2); MGB_SEREG_OPR(MatrixInverse, 1); MGB_SEREG_OPR(SVD, 1); - +MGB_SEREG_OPR(Cross, 2); } // namespace opr } // namespace mgb diff --git a/src/opr/include/megbrain/opr/blas.h b/src/opr/include/megbrain/opr/blas.h index 36d4bba6666c42f7811291aa658191222a5aab73..3d6ef72795b90a9a3e3fd27a4a6c78303e88078a 100644 --- a/src/opr/include/megbrain/opr/blas.h +++ b/src/opr/include/megbrain/opr/blas.h @@ -122,6 +122,19 @@ public: const OperatorNodeConfig& config = {}); }; +MGB_DEFINE_OPR_CLASS(Cross, intl::MegDNNOprWrapperFwd) // { +public: + MGE_WIN_DECLSPEC_FUC Cross( + VarNode* A, VarNode* B, const Param& param, + const OperatorNodeConfig& config); + MGE_WIN_DECLSPEC_FUC static SymbolVar make( + SymbolVar A, SymbolVar B, const Param& param, + const OperatorNodeConfig& config = {}); + +private: + void add_input_layout_constraint() override; +}; + } // namespace opr } // namespace mgb diff --git a/src/opr/test/blas.cpp b/src/opr/test/blas.cpp index 9a61bea7efe0e9b7d9d51f6d6d39442c689baebe..7789db75f446b16310f9b1442b5b70f7d13ee6ca 100644 --- a/src/opr/test/blas.cpp +++ b/src/opr/test/blas.cpp @@ -634,6 +634,28 @@ TEST(TestOprBlas, Dot) { .run({TensorShape{0}, TensorShape{0}}); } +TEST(TestOprBlas, Cross) { + using Checker = AutoOprChecker<2, 1>; + auto nopr = megdnn_naive_handle()->create_operator(); + nopr->param() = {-1, -1, -1}; + + auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { + return {opr::Cross::make(inputs[0], inputs[1], {-1, -1, -1})}; + }; + + auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { + auto &&a = *inp[0], &&b = *inp[1]; + dest[0].resize(a.shape()); + nopr->exec(a.as_megdnn(), b.as_megdnn(), dest[0].as_megdnn(), {}); + }; + + Checker(make_graph, fwd) + .run({TensorShape{3}, TensorShape{3}}) + .run({TensorShape{6, 3}, TensorShape{6, 3}}) + .run({TensorShape{2, 4, 3}, TensorShape{2, 4, 3}}) + .run({TensorShape{2, 5, 2, 3}, TensorShape{2, 5, 2, 3}}); +} + TEST(TestOprBlas, TransMatMul) { run_trans_inp_test(); } diff --git a/src/serialization/impl/schema.fbs b/src/serialization/impl/schema.fbs index fbba7bfc011a1facda62b4914e9e2feecb1b59ca..8e9cd1eeb164cd7510bc397ff04c75ec8a92639a 100644 --- a/src/serialization/impl/schema.fbs +++ b/src/serialization/impl/schema.fbs @@ -128,6 +128,7 @@ union OperatorParam { param.GeneralNorm=94, param.MultiHeadAttn=95, param.Resize3D = 96, + param.Cross = 97, } table Operator { diff --git a/src/serialization/impl/schema_v2.fbs b/src/serialization/impl/schema_v2.fbs index fed5987778b231e3020c4b183a9fda39c022a70c..f40be3f470d223f4b0967a94e2cf6c44ab0eaff3 100644 --- a/src/serialization/impl/schema_v2.fbs +++ b/src/serialization/impl/schema_v2.fbs @@ -145,6 +145,7 @@ union OperatorParam { param.GeneralNorm=94, param.MultiHeadAttn=95, param.Resize3D = 96, + param.Cross = 97, } table Operator {
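
For reference, the backward rule registered by MGB_IMPL_OPR_GRAD(Cross) in src/opr/impl/blas.cpp above follows from bilinearity of the cross product: with c = a x b and output gradient g, <g, da x b> = <b x g, da> and <g, a x db> = <g x a, db>, hence grad_a = b x g (the Cross::make(i1, og, ...) call) and grad_b = g x a (the Cross::make(og, i0, ...) call); the permuted axis parameters only re-route which axis carries the 3-vectors in each operand. A minimal standalone check of those scalar-triple-product identities, illustration only and not part of the patch:

#include <array>
#include <cassert>
#include <cmath>
#include <cstdio>

using Vec3 = std::array<double, 3>;

static Vec3 cross(const Vec3& a, const Vec3& b) {
    return {a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0]};
}

static double dot(const Vec3& a, const Vec3& b) {
    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2];
}

int main() {
    Vec3 a{0.3, -1.2, 2.0}, b{1.5, 0.4, -0.7}, g{0.9, 0.1, -0.5};
    // <g, a x b> equals <b x g, a> and <g x a, b>, i.e. the two VJP forms
    // used by the gradient registration.
    double lhs = dot(g, cross(a, b));
    assert(std::abs(lhs - dot(cross(b, g), a)) < 1e-12);
    assert(std::abs(lhs - dot(cross(g, a), b)) < 1e-12);
    std::printf("cross-product VJP identities hold: %g\n", lhs);
    return 0;
}
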