Commit b0470e73 authored by Megvii Engine Team

feat(dnn,megbrain,imperative): add cross opr

GitOrigin-RevId: 2323a94753acef03bab685e40d2f5e0219450e09
Parent 6bb9772c
......@@ -208,6 +208,31 @@ protected:
using SVD = SVDForward;
//! return the cross product of two (arrays of) vectors.
class Cross : public OperatorBase {
DEF_OPR_IMPL(Cross, OperatorBase, 2, 1);
DEF_OPR_PARAM(Cross);
public:
/**
* \see https://numpy.org/doc/stable/reference/generated/numpy.cross.html
*/
virtual void exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace) = 0;
MGE_WIN_DECLSPEC_FUC void deduce_layout(
const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
virtual size_t get_workspace_in_bytes(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;
void get_ABC(
const TensorShape& shape, size_t& A, size_t& B, size_t& C, int32_t axis);
protected:
void check_exec(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_in_bytes);
};
} // namespace megdnn
#include "megdnn/internal/opr_header_epilogue.h"
......
......@@ -803,6 +803,23 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
'negative value to a lower diagonal.'),
0))
(pdef('Cross').
add_fields(
'int32',
Doc('axisa', 'axis of a that defines the vector(s). By default, the last axis.'),
'-1').
add_fields(
'int32',
Doc('axisb', 'axis of b that defines the vector(s). By default, the last axis.'),
'-1').
add_fields(
'int32',
Doc('axisc', 'axis of c containing the cross product vector(s). Ignored if both '
'input vectors have dimension 2, as the return is scalar. By default, the '
'last axis.'),
'-1')
)
(pdef('UniformRNG', version=0, is_legacy=True).
add_fields('uint64', 'seed', 0))
......
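The three axis fields of the Cross param above mirror the corresponding arguments of numpy.cross (the megdnn header links to its documentation). A small NumPy sketch of what axisa/axisb/axisc select, purely illustrative and not part of the commit:

import numpy as np

a = np.random.rand(2, 3, 4)   # the length-3 vectors live on axis 1
b = np.random.rand(2, 3, 4)
c = np.cross(a, b, axisa=1, axisb=1, axisc=1)   # keep the output vectors on axis 1
assert c.shape == (2, 3, 4)
c_last = np.cross(a, b, axisa=1, axisb=1)       # axisc defaults to the last axis
assert c_last.shape == (2, 4, 3)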
#include "megdnn/oprs.h"
#include "src/common/utils.h"
#include <algorithm>
#include <numeric>
namespace megdnn {
void Cross::deduce_layout(
const TensorLayout& A, const TensorLayout& B, TensorLayout& C) {
auto calibrated_axis = [](int ndim, int axis) {
return axis < 0 ? (axis + ndim) : axis;
};
int axis_a = calibrated_axis(A.ndim, param().axisa);
int axis_b = calibrated_axis(B.ndim, param().axisb);
int axis_c = calibrated_axis(A.ndim, param().axisc);
megdnn_assert(
A[axis_a] == 3 && B[axis_b] == 3,
"incompatible dimensions for cross product (dimension must be 3)");
bool matched = true;
TensorShape shp;
if (A.ndim != B.ndim) {
matched = false;
} else {
for (int i = 0, j = 0, k = 0; i < static_cast<int>(A.ndim); i++) {
if (i == axis_a)
continue;
if (j == axis_b)
++j;
if (A[i] != B[j]) {
matched = false;
break;
}
if (k == axis_c)
++k;
shp[k++] = A[i];
++j;
}
}
megdnn_assert(
matched, "cross op shape mismatch: %s vs %s", A.to_string().c_str(),
B.to_string().c_str());
shp.ndim = A.ndim;
shp[axis_c] = A[axis_a];
C = TensorLayout{shp, A.dtype};
}
void Cross::check_exec(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_in_bytes) {
megdnn_assert_eq_dtype(A, B);
megdnn_assert_eq_dtype(B, C);
TensorLayout c_expected;
deduce_layout(A, B, c_expected);
megdnn_assert_eq_layout(c_expected, C);
megdnn_assert_contiguous(A);
megdnn_assert_contiguous(B);
megdnn_assert_contiguous(C);
auto required_workspace_in_bytes = get_workspace_in_bytes(A, B, C);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void Cross::get_ABC(
const TensorShape& shape, size_t& A, size_t& B, size_t& C, int32_t axis) {
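// A = product of dims before `axis`, B = shape[axis] (the vector length),
// C = product of dims after `axis`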
auto shape_arr = shape.shape;
auto ndim = shape.ndim;
if (axis < 0)
axis += ndim;
A = std::accumulate(shape_arr, shape_arr + axis, 1_z, SafeMultiplies<size_t>());
B = shape_arr[axis];
C = std::accumulate(
shape_arr + (axis + 1), shape_arr + ndim, 1_z, SafeMultiplies<size_t>());
}
} // namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
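For reference, the shape rule implemented by deduce_layout above can be written in a few lines of NumPy-flavoured Python (a sketch only, not part of the commit): both vector axes must have size 3, the remaining dims of A and B must match exactly (no broadcasting), and the output gets a length-3 axis at axisc.

def deduce_cross_shape(shape_a, shape_b, axisa, axisb, axisc):
    axisa %= len(shape_a)
    axisb %= len(shape_b)
    axisc %= len(shape_a)                      # output ndim equals A's ndim
    assert shape_a[axisa] == 3 and shape_b[axisb] == 3
    rest_a = [d for i, d in enumerate(shape_a) if i != axisa]
    rest_b = [d for i, d in enumerate(shape_b) if i != axisb]
    assert rest_a == rest_b, "cross op shape mismatch"
    out = rest_a
    out.insert(axisc, 3)                       # cross product vectors live on axisc
    return tuple(out)

assert deduce_cross_shape((2, 3, 4), (2, 3, 4), -2, 1, -1) == (2, 4, 3)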
......@@ -224,7 +224,8 @@ private:
cb(GroupNormBackward) \
cb(MaskedFill) \
cb(MultiHeadAttnForward)\
cb(MultiHeadAttnBackward)
cb(MultiHeadAttnBackward) \
cb(Cross)
// clang-format on
/*!
......
......@@ -80,6 +80,7 @@ DEF(IndexingRemapBackward, 3, true, false);
DEF(Linspace, 1, true, false);
DEF(Eye, 1, true, false);
DEF(Diag, 2, true, true);
DEF(Cross, 3, true, true);
DEF(Flip, 2, true, true);
DEF(ROICopy, 2, true, true);
DEF(Rotate, 2, true, true);
......
#include "megdnn/dtype.h"
#include "src/cuda/cross/cross.cuh"
#include "src/cuda/utils.cuh"
namespace {
template <typename T>
__global__ void cross_kernel(
T* A, size_t stride_a0, size_t stride_a1, T* B, size_t stride_b0,
size_t stride_b1, T* C, size_t stride_c0, size_t stride_c1, size_t N) {
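// one thread per output vector: i / stride_*1 walks the dims before the vector
// axis, i % stride_*1 the dims after it; a vector's 3 components are spaced
// stride_*1 apart, and stride_*0 (= 3 * stride_*1) jumps to the next pre-axis group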
size_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i < N) {
size_t ida = (i / stride_a1) * stride_a0 + i % stride_a1;
size_t idb = (i / stride_b1) * stride_b0 + i % stride_b1;
size_t idc = (i / stride_c1) * stride_c0 + i % stride_c1;
C[idc] = A[ida + stride_a1] * B[idb + 2 * stride_b1] -
A[ida + 2 * stride_a1] * B[idb + stride_b1];
C[idc + stride_c1] =
A[ida + 2 * stride_a1] * B[idb] - A[ida] * B[idb + 2 * stride_b1];
C[idc + 2 * stride_c1] =
A[ida] * B[idb + stride_b1] - A[ida + stride_a1] * B[idb];
}
}
} // anonymous namespace
namespace megdnn {
namespace cuda {
namespace cross {
template <typename T>
void exec_internal(
T* A, size_t stride_a0, size_t stride_a1, T* B, size_t stride_b0,
size_t stride_b1, T* C, size_t stride_c0, size_t stride_c1, size_t N,
cudaStream_t stream) {
cross_kernel<T><<<DIVUP(N, NR_THREADS), NR_THREADS, 0, stream>>>(
A, stride_a0, stride_a1, B, stride_b0, stride_b1, C, stride_c0, stride_c1,
N);
after_kernel_launch();
}
#define INST(T) \
template void exec_internal<T>( \
T*, size_t, size_t, T*, size_t, size_t, T*, size_t, size_t, size_t, \
cudaStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef INST
#undef cb
} // namespace cross
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
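Both the CUDA kernel above and the naive kernel below address the tensors through the decomposition produced by Cross::get_ABC (dims before the vector axis, the axis itself, dims after it). A Python sketch of the same flat addressing, illustrative only, with a hypothetical helper name and assuming the same axis for a, b and the output:

import numpy as np

def cross_flat(a, b, axis):
    A = int(np.prod(a.shape[:axis]))        # dims before the vector axis
    C = int(np.prod(a.shape[axis + 1:]))    # dims after the vector axis
    fa, fb = a.ravel(), b.ravel()
    out = np.empty_like(fa)
    for i in range(A * C):                  # one iteration per vector
        base = (i // C) * 3 * C + i % C     # offset of component 0; components are C apart
        sel = slice(base, base + 3 * C, C)
        out[sel] = np.cross(fa[sel], fb[sel])
    return out.reshape(a.shape)

x = np.random.rand(2, 3, 4)
y = np.random.rand(2, 3, 4)
np.testing.assert_allclose(cross_flat(x, y, 1), np.cross(x, y, axisa=1, axisb=1, axisc=1))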
#pragma once
#include <cuda_runtime_api.h>
#include <stdint.h>
namespace megdnn {
namespace cuda {
namespace cross {
template <typename T>
void exec_internal(
T* A, size_t stride_a0, size_t stride_a1, T* B, size_t stride_b0,
size_t stride_b1, T* C, size_t stride_c0, size_t stride_c1, size_t N,
cudaStream_t stream);
} // namespace cross
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
#include "src/cuda/cross/opr_impl.h"
#include "src/cuda/cross/cross.cuh"
#include "src/cuda/utils.h"
#include <algorithm>
#include <numeric>
namespace megdnn {
namespace cuda {
void CrossImpl::exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) {
check_exec(A.layout, B.layout, C.layout, workspace.size);
size_t a1, b1, c1, a2, b2, c2, a3, b3, c3;
get_ABC(A.layout, a1, b1, c1, param().axisa);
get_ABC(B.layout, a2, b2, c2, param().axisb);
get_ABC(C.layout, a3, b3, c3, param().axisc);
#define cb(DType) \
if (C.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
using ctype = typename DTypeTrait<DType>::ctype; \
cross::exec_internal<ctype>( \
A.ptr<ctype>(), b1 * c1, c1, B.ptr<ctype>(), b2 * c2, c2, \
C.ptr<ctype>(), b3 * c3, c3, a1 * c1, cuda_stream(handle())); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
}
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace cuda {
class CrossImpl final : public Cross {
public:
using Cross::Cross;
void exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& A, const TensorLayout& B,
const TensorLayout& C) override {
return 0;
}
};
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
......@@ -16,6 +16,7 @@
#include "src/cuda/convolution3d/opr_impl.h"
#include "src/cuda/convpooling/opr_impl.h"
#include "src/cuda/correlation/opr_impl.h"
#include "src/cuda/cross/opr_impl.h"
#include "src/cuda/cumsum/opr_impl.h"
#include "src/cuda/cvt_color/opr_impl.h"
#include "src/cuda/dct/opr_impl.h"
......@@ -234,6 +235,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(RegionRestrictedConvolutionBackwardData);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RegionRestrictedConvolutionBackwardFilter);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MultiHeadAttnForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MultiHeadAttnBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Cross);
template <typename Opr>
std::unique_ptr<Opr> HandleImpl::create_operator() {
......
#include "src/naive/cross/opr_impl.h"
#include <algorithm>
#include <numeric>
#include "src/common/utils.h"
#include "src/naive/handle.h"
namespace megdnn {
namespace naive {
template <typename ctype>
void CrossImpl::exec_internal(
ctype* A, size_t a1, size_t b1, size_t c1, ctype* B, size_t a2, size_t b2,
size_t c2, ctype* C, size_t a3, size_t b3, size_t c3) {
(void)a2;
(void)a3;
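// a* = product of dims before the vector axis, b* = 3 (the axis itself),
// c* = product of dims after it; each vector's components are spaced c* apart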
size_t N = a1 * c1;
for (size_t i = 0; i < N; ++i) {
size_t ida = (i / c1) * b1 * c1 + i % c1;
size_t idb = (i / c2) * b2 * c2 + i % c2;
size_t idc = (i / c3) * b3 * c3 + i % c3;
C[idc] = A[ida + c1] * B[idb + 2 * c2] - A[ida + 2 * c1] * B[idb + c2];
C[idc + c3] = A[ida + 2 * c1] * B[idb] - A[ida] * B[idb + 2 * c2];
C[idc + 2 * c3] = A[ida] * B[idb + c2] - A[ida + c1] * B[idb];
}
}
void CrossImpl::exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) {
check_exec(A.layout, B.layout, C.layout, workspace.size);
size_t a1, b1, c1, a2, b2, c2, a3, b3, c3;
get_ABC(A.layout, a1, b1, c1, param().axisa);
get_ABC(B.layout, a2, b2, c2, param().axisb);
get_ABC(C.layout, a3, b3, c3, param().axisc);
#define cb(DType) \
if (A.layout.dtype == DType()) { \
using ctype = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<ctype>( \
A.ptr<ctype>(), a1, b1, c1, B.ptr<ctype>(), a2, b2, c2, \
C.ptr<ctype>(), a3, b3, c3)); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
}
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace naive {
class CrossImpl : public Cross {
public:
using Cross::Cross;
void exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&) override {
return 0;
}
private:
template <typename ctype>
void exec_internal(
ctype* A, size_t a1, size_t b1, size_t c1, ctype* B, size_t a2, size_t b2,
size_t c2, ctype* C, size_t a3, size_t b3, size_t c3);
};
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
......@@ -18,6 +18,7 @@
#include "src/naive/convolution3d/opr_impl.h"
#include "src/naive/convpooling/opr_impl.h"
#include "src/naive/correlation/opr_impl.h"
#include "src/naive/cross/opr_impl.h"
#include "src/naive/cumsum/opr_impl.h"
#include "src/naive/cvt_color/opr_impl.h"
#include "src/naive/dct/opr_impl.h"
......
#include "test/cuda/fixture.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
namespace megdnn {
namespace test {
TEST_F(CUDA, CROSS) {
Checker<Cross> checker(handle_cuda());
for (DType dtype :
std::vector<DType>{dtype::Float16(), dtype::Int32(), dtype::Float32()}) {
checker.set_param({-2, 1, -1})
.set_dtype(0, dtype)
.set_dtype(1, dtype)
.set_dtype(2, dtype);
checker.exec(TensorShapeArray{{2, 3, 4}, {2, 3, 4}, {2, 4, 3}});
checker.set_param({0, -1, 2})
.set_dtype(0, dtype)
.set_dtype(1, dtype)
.set_dtype(2, dtype);
checker.exec(TensorShapeArray{{3, 2, 3, 4}, {2, 3, 4, 3}, {2, 3, 3, 4}});
}
}
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
#include "test/naive/fixture.h"
#include "megdnn/oprs/linalg.h"
#include "test/common/checker.h"
using namespace megdnn;
using namespace test;
TEST_F(NAIVE, CROSS) {
Checker<Cross> checker(handle(), /* check_dispatch */ false);
Cross::Param param{2, 0, -3};
TensorND A = TensorValue(
{2, 2, 3, 3}, dtype::Float32(),
{0.14156187, 0.91729394, 0.74985146, 0.06110339, 0.67560272, 0.38569672,
0.63086542, 0.96403808, 0.44215527, 0.04618851, 0.17031024, 0.83291294,
0.15374447, 0.57640597, 0.90293485, 0.55324231, 0.75830697, 0.43003672,
0.18093289, 0.73350109, 0.38130576, 0.29138499, 0.77249111, 0.0878262,
0.11699259, 0.73904961, 0.52122415, 0.10095503, 0.15948783, 0.88822725,
0.12915866, 0.59800798, 0.17838123, 0.37576929, 0.9885241, 0.25812663});
TensorND B = TensorValue(
{3, 2, 2, 3}, dtype::Float32(),
{0.87391774, 0.65328165, 0.44478127, 0.94795241, 0.84836271, 0.8054875,
0.44598355, 0.50247015, 0.36997338, 0.77503215, 0.87235448, 0.21397323,
0.85153319, 0.16009368, 0.77536431, 0.04486964, 0.37626156, 0.81259061,
0.08839276, 0.50477703, 0.91626721, 0.44417563, 0.03038398, 0.21980224,
0.71472471, 0.85791626, 0.40345219, 0.16018583, 0.38511319, 0.23809232,
0.58711753, 0.92031592, 0.65426633, 0.14788665, 0.73273333, 0.3182309});
TensorND C = TensorValue(
{2, 3, 2, 3}, dtype::Float32(),
{-0.49353074, 0.42527416, -0.18722123, -0.0001961, -0.06334022,
-0.13446194, 0.45014671, -0.157173, -0.10586683, 0.51704864,
0.57773064, 0.14807902, 0.0671453, -0.29450591, 0.40985739,
-0.14366998, -0.42492013, -0.05048549, 0.16073594, 0.3378806,
-0.42011888, -0.14780672, 0.40814509, 0.00002961, -0.0540521,
-0.30370236, -0.05663646, 0.27630338, 0.74548138, -0.22742917,
-0.11395976, -0.01789922, 0.31688461, -0.05526035, -0.51682907,
0.15706553});
checker.set_param(param).exect(Testcase{A, B, {}}, Testcase{{}, {}, C});
}
......@@ -13,12 +13,13 @@ from ..core.tensor.utils import _normalize_axis
from ..tensor import Tensor
from ..utils.deprecation import deprecated_kwargs_default
from .elemwise import _elemwise_multi_type, clip
from .tensor import broadcast_to, expand_dims, squeeze
from .tensor import broadcast_to, concat, expand_dims, squeeze, zeros
__all__ = [
"argmax",
"argmin",
"argsort",
"cross",
"dot",
"isinf",
"isnan",
......@@ -720,6 +721,50 @@ def matinv(inp: Tensor) -> Tensor:
return result
def cross(
a: Tensor,
b: Tensor,
axisa: int = -1,
axisb: int = -1,
axisc: int = -1,
axis: int = None,
) -> Tensor:
r"""Return the cross product of two (arrays of) vectors.
The cross product of ``a`` and ``b`` in `R^3` is a vector perpendicular to both ``a`` and ``b``. If ``a`` and ``b`` are arrays of vectors, the vectors are defined by the last axis of ``a`` and ``b`` by default, and these axes can have dimensions 2 or 3.
Where the dimension of either ``a`` or ``b`` is 2, the third component of the input vector is assumed to be zero and the cross product calculated accordingly.
Args:
a: components of the first vector(s).
b: components of the second vector(s).
axisa: axis of a that defines the vector(s). By default, the last axis.
axisb: axis of b that defines the vector(s). By default, the last axis.
axisc: axis of c containing the cross product vector(s). By default, the last axis.
axis: if defined, the axis of a, b and c that defines the vector(s) and cross product(s). Overrides axisa, axisb and axisc.
Returns:
vector cross product(s).
Examples:
>>> a = Tensor([1.0, 2.0, 3.0])
>>> b = Tensor([4.0, 5.0, 6.0])
>>> out = F.cross(a, b)
>>> out.numpy()
array([-3., 6., -3.], dtype=float32)
"""
if axis is not None:
axisa = axisb = axisc = axis
if a.ndim == 1 and len(a) == 2:
a = Tensor(np.append(a, 0))
if b.ndim == 1 and len(b) == 2:
b = Tensor(np.append(b, 0))
op = builtin.Cross(axisa, axisb, axisc)
(result,) = apply(op, a, b)
return result
def matmul(
inp1: Tensor,
inp2: Tensor,
......
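Beyond the docstring example, the axis arguments and the 2-D vector padding in the cross() helper above behave as follows (a usage sketch mirroring what test_cross below exercises):

import numpy as np
import megengine.functional as F
from megengine import Tensor

a = Tensor([1.0, 2.0])                # 2-D vector: the third component is taken as 0
b = Tensor([4.0, 5.0, 6.0])
print(F.cross(a, b).numpy())          # [12. -6. -3.]

x = Tensor(np.random.rand(4, 3, 2).astype("float32"))
y = Tensor(np.random.rand(4, 3, 2).astype("float32"))
print(F.cross(x, y, axis=1).shape)    # vectors on axis 1 everywhere -> (4, 3, 2)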
......@@ -1680,3 +1680,49 @@ def test_softmax():
actual = F.softmax(data)
np.testing.assert_allclose(actual.numpy(), desired, rtol=1e-5)
def test_cross():
shape1 = 3
shape2 = 3
shape3 = (4, 2, 3)
shape4 = (4, 2, 3)
data1 = np.random.random(shape1).astype("float32")
data2 = np.random.random(shape2).astype("float32")
data3 = np.random.random(shape3).astype("float32")
data4 = np.random.random(shape4).astype("float32")
cases = [
{"input": [data1, data2]},
{"input": [data3, data4]},
]
opr_test(
cases,
F.cross,
compare_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=1e-4),
ref_fn=np.cross,
)
data5 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
data6 = np.array([[7, 8, 9], [4, 5, 6], [1, 2, 3]])
res = F.cross(mge.tensor(data5), mge.tensor(data6), axisa=0, axisb=0)
dst = np.cross(data5, data6, axisa=0, axisb=0)
np.testing.assert_allclose(res.numpy(), dst, rtol=1e-4)
data7 = np.array([1, 2, 3])
data8 = np.array([4, 5, 6])
res = F.cross(mge.tensor(data7), mge.tensor(data8), axisa=0, axisb=0)
dst = np.cross(data7, data8, axisa=0, axisb=0)
np.testing.assert_allclose(res.numpy(), dst, rtol=1e-4)
data9 = np.array([[6, 7, 4], [15, 2, 8], [9, 40, 39]])
data10 = np.array([[5, 8, 9], [14, 21, 17], [10, 3, 47]])
res = F.cross(mge.tensor(data9), mge.tensor(data10), axisa=1, axisb=-1)
dst = np.cross(data9, data10, axisa=1, axisb=-1)
np.testing.assert_allclose(res.numpy(), dst, rtol=1e-4)
data11 = np.array([1, 2])
data12 = np.array([4, 5, 6])
res = F.cross(mge.tensor(data11), mge.tensor(data12))
dst = np.cross(data11, data12)
np.testing.assert_allclose(res.numpy(), dst, rtol=1e-4)
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/blas.h"
#include "megdnn/oprs.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
namespace mgb {
namespace imperative {
namespace cross {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = def.cast_final_safe<Cross>();
mgb_assert(inputs.size() == 2);
cg::OperatorNodeConfig config{op.make_name()};
return opr::Cross::make(inputs[0], inputs[1], op.param(), config);
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
mgb_assert(inputs.size() == 2, "Cross expects two inputs");
auto&& op_def = def.cast_final_safe<Cross>();
auto comp_node = inputs[0].comp_node;
if (!inputs[0].layout.ndim) {
return {{{inputs[0].layout, comp_node}}, false};
}
if (!inputs[1].layout.ndim) {
return {{{inputs[1].layout, comp_node}}, false};
}
DnnOprHelper<megdnn::Cross> dnn_op(op_def.param());
auto oup_layout = dnn_op.deduce_layout(inputs[0].layout, inputs[1].layout);
return {{{oup_layout, comp_node}}, true};
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto comp_node = inputs[0]->comp_node();
auto&& op_def = def.cast_final_safe<Cross>();
DnnOprCaller<megdnn::Cross> dnn_op(comp_node, op_def.param());
auto dst = [&] {
if (validated) {
return output_descs[0].layout;
} else {
return dnn_op.deduce_layout(inputs[0]->layout(), inputs[1]->layout());
}
}();
auto out = Tensor::make(dst, comp_node);
if (!inputs[0]->layout().is_empty() && !inputs[1]->layout().is_empty()) {
dnn_op.exec_with_ws(inputs[0], inputs[1], out);
}
return {out};
}
OP_TRAIT_REG(Cross, Cross)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace cross
} // namespace imperative
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
29b2127eb4034bf24e473945d70ead4a ../../dnn/scripts/opr_param_defs.py
639ff50d64fcb78374de266c88942c2c ../../src/core/include/megbrain/ir/ops.td
16654743e01160eeee879107cc4cac41 generated/opdef.h.inl
97c541ed45b0be98f1ac2700f5b4d8a6 generated/opdef.cpp.inl
6f9c6a7a1d71cca195c1e30743a1f542 generated/opdef.py.inl
806c5ceb34f571fc5c9d98d2ca8cad63 generated/opdef.cpy.inl
e4035bfefce3a2cc0e8cc6ec7fcac227 ../../dnn/scripts/opr_param_defs.py
13ab898fce3749ebbcabf7c145876147 ../../src/core/include/megbrain/ir/ops.td
9dda6e2db75279373ec6809b297a2370 generated/opdef.h.inl
aabc2d8146742faacabf56e376177e7b generated/opdef.cpp.inl
8a5dffac1df3286178b3fd304c39b5da generated/opdef.py.inl
04322b642bba8f684034fcc5dc27efcf generated/opdef.cpy.inl
911001ef0dd771024919f7a1a3a009db generated/enum_macro.h
......@@ -2303,6 +2303,49 @@ OP_TRAIT_REG(Correlation, Correlation)
.props(Correlation_props_impl)
.make_name(Correlation_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Cross);
namespace {
size_t Cross_hash_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<Cross>();
static_cast<void>(op_);
size_t val = mgb::hash(op_.dyn_typeinfo());
val = mgb::hash_pair_combine(val, mgb::hash(op_.axisa));
val = mgb::hash_pair_combine(val, mgb::hash(op_.axisb));
val = mgb::hash_pair_combine(val, mgb::hash(op_.axisc));
return val;
}
bool Cross_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
auto &&a_ = lhs_.cast_final_safe<Cross>(),
&&b_ = rhs_.cast_final_safe<Cross>();
static_cast<void>(a_);
static_cast<void>(b_);
if (a_.axisa != b_.axisa) return false;
if (a_.axisb != b_.axisb) return false;
if (a_.axisc != b_.axisc) return false;
return true;
}
std::vector<std::pair<const char*, std::string>> Cross_props_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<Cross>();
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
props_.emplace_back("axisa", std::to_string(op_.axisa));
props_.emplace_back("axisb", std::to_string(op_.axisb));
props_.emplace_back("axisc", std::to_string(op_.axisc));
return props_;
}
std::string Cross_make_name_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<Cross>();
static_cast<void>(op_);
return "Cross";
}
} // anonymous namespace
OP_TRAIT_REG(Cross, Cross)
.hash(Cross_hash_impl)
.is_same_st(Cross_is_same_st_impl)
.props(Cross_props_impl)
.make_name(Cross_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Cumsum);
namespace {
......
......@@ -7274,6 +7274,151 @@ void _init_py_Correlation(py::module m) {
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Correlation::typeinfo(), &py_type).second);
}
PyOpDefBegin(Cross) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
static PyObject* getstate(PyObject* self, PyObject*) {
auto& opdef = reinterpret_cast<PyOp(Cross)*>(self)->inst();
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {
{"axisa", serialization<decltype(opdef.axisa)>::dump(opdef.axisa)},
{"axisb", serialization<decltype(opdef.axisb)>::dump(opdef.axisb)},
{"axisc", serialization<decltype(opdef.axisc)>::dump(opdef.axisc)}
};
return py::cast(state).release().ptr();
}
static PyObject* setstate(PyObject* self, PyObject* args) {
PyObject* dict = PyTuple_GetItem(args, 0);
if (!dict) return NULL;
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
auto& opdef = reinterpret_cast<PyOp(Cross)*>(self)->inst();
static_cast<void>(opdef);
{
auto&& iter = state.find("axisa");
if (iter != state.end()) {
opdef.axisa = serialization<decltype(opdef.axisa)>::load(iter->second);
}
}
{
auto&& iter = state.find("axisb");
if (iter != state.end()) {
opdef.axisb = serialization<decltype(opdef.axisb)>::load(iter->second);
}
}
{
auto&& iter = state.find("axisc");
if (iter != state.end()) {
opdef.axisc = serialization<decltype(opdef.axisc)>::load(iter->second);
}
}
Py_RETURN_NONE;
}
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds);
static PyMethodDef py_init_methoddef;
// };
PyOpDefEnd(Cross)
int PyOp(Cross)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
static const char* kwlist[] = {"axisa", "axisb", "axisc", "scope", NULL};
PyObject *axisa = NULL, *axisb = NULL, *axisc = NULL, *scope = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOO", const_cast<char**>(kwlist), &axisa, &axisb, &axisc, &scope))
return -1;
if (axisa) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(Cross)*>(self)->inst().axisa =
py::cast<decltype(Cross::axisa)>(py::handle(axisa));
} CATCH_ALL(-1)
}
if (axisb) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(Cross)*>(self)->inst().axisb =
py::cast<decltype(Cross::axisb)>(py::handle(axisb));
} CATCH_ALL(-1)
}
if (axisc) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(Cross)*>(self)->inst().axisc =
py::cast<decltype(Cross::axisc)>(py::handle(axisc));
} CATCH_ALL(-1)
}
if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(py::cast<std::string>(py::handle(scope)));
} CATCH_ALL(-1)
}
return 0;
}
PyGetSetDef PyOp(Cross)::py_getsetters[] = {
{const_cast<char*>("axisa"), py_get_generic(Cross, axisa), py_set_generic(Cross, axisa), const_cast<char*>("axisa"), NULL},
{const_cast<char*>("axisb"), py_get_generic(Cross, axisb), py_set_generic(Cross, axisb), const_cast<char*>("axisb"), NULL},
{const_cast<char*>("axisc"), py_get_generic(Cross, axisc), py_set_generic(Cross, axisc), const_cast<char*>("axisc"), NULL},
{NULL} /* Sentinel */
};
PyMethodDef PyOp(Cross)::tp_methods[] = {
{const_cast<char*>("__getstate__"), PyOp(Cross)::getstate, METH_NOARGS, "Cross getstate"},
{const_cast<char*>("__setstate__"), PyOp(Cross)::setstate, METH_VARARGS, "Cross setstate"},
{NULL} /* Sentinel */
};
PyObject *PyOp(Cross)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) {
if (PyOp(Cross)::py_init(self, args, kwds) < 0) {
return NULL;
}
Py_RETURN_NONE;
}
PyMethodDef PyOp(Cross)::py_init_methoddef = {
"__init__",
(PyCFunction)PyOp(Cross)::py_init_proxy,
METH_VARARGS | METH_KEYWORDS,
"__init__(self, axisa: int = ..., axisb: int = ..., axisc: int = ...) -> None\n"
};
void _init_py_Cross(py::module m) {
using py_op = PyOp(Cross);
auto& py_type = PyOpType(Cross);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.Cross";
py_type.tp_basicsize = sizeof(PyOp(Cross));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "Cross";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
py_type.tp_dict = PyDict_New();
PyObject* descr = PyDescr_NewMethod(&PyOpType(Cross), &PyOp(Cross)::py_init_methoddef);
PyDict_SetItemString(py_type.tp_dict, "__init__", descr);
mgb_assert(PyType_Ready(&py_type) >= 0);
PyType_Modified(&py_type);
m.add_object("Cross", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Cross::typeinfo(), &py_type).second);
}
PyOpDefBegin(Cumsum) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
......@@ -23420,6 +23565,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) {
_init_py_ConvolutionBackwardData(m); \
_init_py_Copy(m); \
_init_py_Correlation(m); \
_init_py_Cross(m); \
_init_py_Cumsum(m); \
_init_py_CvtColor(m); \
_init_py_DeformableConv(m); \
......
......@@ -567,6 +567,21 @@ public:
}
};
class Cross : public OpDefImplBase<Cross> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
int32_t axisa = -1;
int32_t axisb = -1;
int32_t axisc = -1;
Cross() = default;
Cross(int32_t axisa_, int32_t axisb_, int32_t axisc_, std::string scope_ = {}): axisa(axisa_), axisb(axisb_), axisc(axisc_) { set_scope(scope_); }
Cross(::megdnn::param::Cross packed_param_0): axisa(packed_param_0.axisa), axisb(packed_param_0.axisb), axisc(packed_param_0.axisc) {}
::megdnn::param::Cross param() const {
return {axisa, axisb, axisc};
}
};
class Cumsum : public OpDefImplBase<Cumsum> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
......
......@@ -678,6 +678,14 @@ CorrelationInst
.def_readwrite("pad_size", &Correlation::pad_size)
.def_readwrite("is_multiply", &Correlation::is_multiply);
py::class_<Cross, std::shared_ptr<Cross>, OpDef> CrossInst(m, "Cross");
CrossInst
.def(py::init<int32_t, int32_t, int32_t, std::string>(), py::arg("axisa") = -1, py::arg("axisb") = -1, py::arg("axisc") = -1, py::arg("scope") = {})
.def_readwrite("axisa", &Cross::axisa)
.def_readwrite("axisb", &Cross::axisb)
.def_readwrite("axisc", &Cross::axisc);
py::class_<Cumsum, std::shared_ptr<Cumsum>, OpDef> CumsumInst(m, "Cumsum");
CumsumInst
......
......@@ -548,7 +548,6 @@ def Dropout: MgbHashableOp<"Dropout", [DropoutParam]> {
);
}];
let cmpFunction = [{return $0.handle == $1.handle && $0.drop_prob == $1.drop_prob;}];
}
def MeshGrid: MgbHashableOp<"MeshGrid"> {
let extraArguments = (ins
......@@ -639,4 +638,6 @@ def MultiHeadAttn: MgbHashableOp<"MultiHeadAttn", [MultiHeadAttnParam]> {
}
def Cross: MgbHashableOp<"Cross", [CrossParam]>;
#endif // MGB_OPS
......@@ -736,4 +736,29 @@ SymbolVarArray SVD::make(
return ret;
}
/* ================= Cross ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Cross);
MEGDNN_OPR_INIT2(Cross, "cross")
void Cross::add_input_layout_constraint() {
input(0)->add_layout_constraint_contiguous();
input(1)->add_layout_constraint_contiguous();
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Cross) {
SymbolVar grad, i0{opr.input(0)}, i1{opr.input(1)}, og{out_grad[0]};
if (wrt_idx == 0) {
grad = Cross::make(
i1, og, {opr.param().axisb, opr.param().axisc, opr.param().axisa});
} else {
mgb_assert(wrt_idx == 1);
grad = Cross::make(
og, i0, {opr.param().axisc, opr.param().axisa, opr.param().axisb});
}
return grad.node();
}
#endif
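The gradient rule above is the standard one for the cross product: with c = a × b and upstream gradient g = dL/dc, the scalar triple product identities g · (δa × b) = δa · (b × g) and g · (a × δb) = δb · (g × a) give dL/da = b × g and dL/db = g × a, which is exactly what the two Cross::make calls compute; the permuted axis parameters keep the vector axes of b, the incoming gradient and the produced gradient aligned with the original axisb/axisc/axisa.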
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -48,4 +48,9 @@ decl_opr('SVD',
desc='Computes the singular value decompositions of matrices. '
'The input must have shape ``[..., M, N]``.')
decl_opr('Cross',
inputs=['A', 'B'],
params='Cross',
desc='Computes the cross product of two (arrays of) vectors.')
# vim: ft=python
......@@ -92,7 +92,7 @@ MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(BatchedMatrixMulV2, 2, opr_shallow_copy_matmu
MGB_SEREG_OPR(Dot, 2);
MGB_SEREG_OPR(MatrixInverse, 1);
MGB_SEREG_OPR(SVD, 1);
MGB_SEREG_OPR(Cross, 2);
} // namespace opr
} // namespace mgb
......
......@@ -122,6 +122,19 @@ public:
const OperatorNodeConfig& config = {});
};
MGB_DEFINE_OPR_CLASS(Cross, intl::MegDNNOprWrapperFwd<megdnn::Cross>) // {
public:
MGE_WIN_DECLSPEC_FUC Cross(
VarNode* A, VarNode* B, const Param& param,
const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar A, SymbolVar B, const Param& param,
const OperatorNodeConfig& config = {});
private:
void add_input_layout_constraint() override;
};
} // namespace opr
} // namespace mgb
......
......@@ -634,6 +634,28 @@ TEST(TestOprBlas, Dot) {
.run({TensorShape{0}, TensorShape{0}});
}
TEST(TestOprBlas, Cross) {
using Checker = AutoOprChecker<2, 1>;
auto nopr = megdnn_naive_handle()->create_operator<megdnn::Cross>();
nopr->param() = {-1, -1, -1};
auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
return {opr::Cross::make(inputs[0], inputs[1], {-1, -1, -1})};
};
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
auto &&a = *inp[0], &&b = *inp[1];
dest[0].resize(a.shape());
nopr->exec(a.as_megdnn(), b.as_megdnn(), dest[0].as_megdnn(), {});
};
Checker(make_graph, fwd)
.run({TensorShape{3}, TensorShape{3}})
.run({TensorShape{6, 3}, TensorShape{6, 3}})
.run({TensorShape{2, 4, 3}, TensorShape{2, 4, 3}})
.run({TensorShape{2, 5, 2, 3}, TensorShape{2, 5, 2, 3}});
}
TEST(TestOprBlas, TransMatMul) {
run_trans_inp_test<float, float>();
}
......
......@@ -128,6 +128,7 @@ union OperatorParam {
param.GeneralNorm=94,
param.MultiHeadAttn=95,
param.Resize3D = 96,
param.Cross = 97,
}
table Operator {
......
......@@ -145,6 +145,7 @@ union OperatorParam {
param.GeneralNorm=94,
param.MultiHeadAttn=95,
param.Resize3D = 96,
param.Cross = 97,
}
table Operator {
......