...
 
Commits (6)

- fix(opencl): fix MGE_WITH_OPENCL_ONLY crash on some models, e.g. ones that use type_cvt. At load stage some naive op->exec calls (StaticInferOpr) are made; do not check elemwise_multi_type there, as OpenCL does not run quantized models.
  2023-07-18T10:04:11+08:00 · Megvii Engine Team <megengine@megvii.com>
  https://gitcode.net/megvii/megengine/-/commit/53a765267d831d6eb5ced676f80ffff41623d26f
  GitOrigin-RevId: 510df237428ca76b02f7c776cdec683e52c1a7bc
- fix(mge/dist): give more error messages when the device check fails
  2023-07-20T14:04:09+08:00 · Megvii Engine Team <megengine@megvii.com>
  https://gitcode.net/megvii/megengine/-/commit/fca4ae57faea24d2bffdc6f3fac90930cdb8428c
  GitOrigin-RevId: 5601aca458157404244fa16c9623428722544308
- fix(imperative): fix format transformation handling of nchw format tensors
  2023-07-25T12:04:24+08:00 · Megvii Engine Team <megengine@megvii.com>
  https://gitcode.net/megvii/megengine/-/commit/b9e0319fc9591501a93c3480bcc301634ba64563
  GitOrigin-RevId: f5838c1f7fbc1a1f4ffd9fc8951ed0cbdd422dc2
- feat(imperative): add additive noise aug
  2023-07-25T17:05:42+08:00 · Megvii Engine Team <megengine@megvii.com>
  https://gitcode.net/megvii/megengine/-/commit/ba9f67eb49e46d2a81b353b71ca3270e7fbdabc8
  GitOrigin-RevId: e1ef0c0e3b6d9638e236905ab034a88d4a686155
- feat(lite): support setting a consistent stream id with multiple compnodes
  2023-07-26T13:04:04+08:00 · Megvii Engine Team <megengine@megvii.com>
  https://gitcode.net/megvii/megengine/-/commit/6bb9772cbcbbac43948b8d497c956189325122d8
  GitOrigin-RevId: 2eed8c2fbd6079c9cfc24e40e2746361c99f9796
- feat(dnn,megbrain,imperative): add cross opr
  2023-07-27T20:04:11+08:00 · Megvii Engine Team <megengine@megvii.com>
  https://gitcode.net/megvii/megengine/-/commit/b0470e738f7b5d451e8fa38cd6c4d38a3340212d
  GitOrigin-RevId: 2323a94753acef03bab685e40d2f5e0219450e09
......@@ -208,6 +208,31 @@ protected:
using SVD = SVDForward;
//! return the cross product of two (arrays of) vectors.
class Cross : public OperatorBase {
DEF_OPR_IMPL(Cross, OperatorBase, 2, 1);
DEF_OPR_PARAM(Cross);
public:
/**
* \see https://numpy.org/doc/stable/reference/generated/numpy.cross.html
*/
virtual void exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace) = 0;
MGE_WIN_DECLSPEC_FUC void deduce_layout(
const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
virtual size_t get_workspace_in_bytes(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;
void get_ABC(
const TensorShape& shape, size_t& A, size_t& B, size_t& C, int32_t axis);
protected:
void check_exec(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_in_bytes);
};
} // namespace megdnn
#include "megdnn/internal/opr_header_epilogue.h"
......
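The new megdnn::Cross operator above follows the numpy.cross interface referenced in its doc comment: axisa/axisb select the vector axis of each input, and axisc selects where the result axis is placed in the output. A minimal NumPy sketch of the intended semantics (illustrative only, not MegEngine code):

import numpy as np

a = np.random.rand(2, 3, 4)  # vectors of length 3 lie along axis 1
b = np.random.rand(2, 3, 4)
c = np.cross(a, b, axisa=1, axisb=1, axisc=1)
print(c.shape)               # (2, 3, 4): axisc puts the result axis back at position 1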
......@@ -803,6 +803,23 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
'negative value to a lower diagonal.'),
0))
(pdef('Cross').
add_fields(
'int32',
Doc('axisa', 'axis of a that defines the vector(s). By default, the last axis.'),
'-1').
add_fields(
'int32',
Doc('axisb', 'axis of b that defines the vector(s). By default, the last axis.'),
'-1').
add_fields(
'int32',
Doc('axisc', 'axis of c containing the cross product vector(s). Ignored if both '
'input vectors have dimension 2, as the return is scalar. By default, the '
'last axis.'),
'-1')
)
(pdef('UniformRNG', version=0, is_legacy=True).
add_fields('uint64', 'seed', 0))
......
#include "megdnn/oprs.h"
#include "src/common/utils.h"
#include <algorithm>
#include <numeric>
namespace megdnn {
void Cross::deduce_layout(
const TensorLayout& A, const TensorLayout& B, TensorLayout& C) {
auto calibrated_axis = [](int ndim, int axis) {
return axis < 0 ? (axis + ndim) : axis;
};
int axis_a = calibrated_axis(A.ndim, param().axisa);
int axis_b = calibrated_axis(B.ndim, param().axisb);
int axis_c = calibrated_axis(A.ndim, param().axisc);
megdnn_assert(
A[axis_a] == 3 && B[axis_b] == 3,
"incompatible dimensions for cross product (dimension must be 3)");
bool matched = true;
TensorShape shp;
if (A.ndim != B.ndim) {
matched = false;
} else {
for (int i = 0, j = 0, k = 0; i < static_cast<int>(A.ndim); i++) {
if (i == axis_a)
continue;
if (j == axis_b)
++j;
if (A[i] != B[j]) {
matched = false;
break;
}
if (k == axis_c)
++k;
shp[k++] = A[i];
++j;
}
}
megdnn_assert(
matched, "cross op shape mismatch: %s vs %s", A.to_string().c_str(),
B.to_string().c_str());
shp.ndim = A.ndim;
shp[axis_c] = A[axis_a];
C = TensorLayout{shp, A.dtype};
}
void Cross::check_exec(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_in_bytes) {
megdnn_assert_eq_dtype(A, B);
megdnn_assert_eq_dtype(B, C);
TensorLayout c_expected;
deduce_layout(A, B, c_expected);
megdnn_assert_eq_layout(c_expected, C);
megdnn_assert_contiguous(A);
megdnn_assert_contiguous(B);
megdnn_assert_contiguous(C);
auto required_workspace_in_bytes = get_workspace_in_bytes(A, B, C);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void Cross::get_ABC(
const TensorShape& shape, size_t& A, size_t& B, size_t& C, int32_t axis) {
auto shape_arr = shape.shape;
auto ndim = shape.ndim;
if (axis < 0)
axis += ndim;
A = std::accumulate(shape_arr, shape_arr + axis, 1_z, SafeMultiplies<size_t>());
B = shape_arr[axis];
C = std::accumulate(
shape_arr + (axis + 1), shape_arr + ndim, 1_z, SafeMultiplies<size_t>());
}
} // namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
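Note on the two helpers above: deduce_layout mirrors numpy.cross shape deduction (the non-vector axes of A and B must match once axisa/axisb are skipped), while get_ABC collapses a shape around the vector axis into three extents: A leading elements, B = 3 vector components, and C trailing elements. The backends then use (B*C, C) as per-operand strides, so component k of flat vector i lives at offset (i / C) * B * C + i % C + k * C. A hedged Python sketch of both helpers (illustrative names, not part of the source):

from functools import reduce
from operator import mul
import numpy as np

def deduce_output_shape(shape_a, shape_b, axisa, axisb, axisc):
    # numpy.cross performs the same shape deduction the C++ code implements by hand
    return np.cross(np.empty(shape_a), np.empty(shape_b),
                    axisa=axisa, axisb=axisb, axisc=axisc).shape

def get_abc(shape, axis):
    if axis < 0:
        axis += len(shape)
    leading = reduce(mul, shape[:axis], 1)       # "A": elements before the vector axis
    vector = shape[axis]                         # "B": length of the vector axis (3)
    trailing = reduce(mul, shape[axis + 1:], 1)  # "C": elements after the vector axis
    return leading, vector, trailing

print(deduce_output_shape((2, 3, 4), (2, 3, 4), 1, 1, -1))  # (2, 4, 3)
print(get_abc((2, 5, 3, 4), axis=-2))                       # (10, 3, 4)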
......@@ -224,7 +224,8 @@ private:
cb(GroupNormBackward) \
cb(MaskedFill) \
cb(MultiHeadAttnForward)\
cb(MultiHeadAttnBackward)
cb(MultiHeadAttnBackward) \
cb(Cross)
// clang-format on
/*!
......
......@@ -80,6 +80,7 @@ DEF(IndexingRemapBackward, 3, true, false);
DEF(Linspace, 1, true, false);
DEF(Eye, 1, true, false);
DEF(Diag, 2, true, true);
DEF(Cross, 3, true, true);
DEF(Flip, 2, true, true);
DEF(ROICopy, 2, true, true);
DEF(Rotate, 2, true, true);
......
#include "megdnn/dtype.h"
#include "src/cuda/cross/cross.cuh"
#include "src/cuda/utils.cuh"
namespace {
template <typename T>
__global__ void cross_kernel(
T* A, size_t stride_a0, size_t stride_a1, T* B, size_t stride_b0,
size_t stride_b1, T* C, size_t stride_c0, size_t stride_c1, size_t N) {
size_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i < N) {
size_t ida = (i / stride_a1) * stride_a0 + i % stride_a1;
size_t idb = (i / stride_b1) * stride_b0 + i % stride_b1;
size_t idc = (i / stride_c1) * stride_c0 + i % stride_c1;
C[idc] = A[ida + stride_a1] * B[idb + 2 * stride_b1] -
A[ida + 2 * stride_a1] * B[idb + stride_b1];
C[idc + stride_c1] =
A[ida + 2 * stride_a1] * B[idb] - A[ida] * B[idb + 2 * stride_b1];
C[idc + 2 * stride_c1] =
A[ida] * B[idb + stride_b1] - A[ida + stride_a1] * B[idb];
}
}
} // anonymous namespace
namespace megdnn {
namespace cuda {
namespace cross {
template <typename T>
void exec_internal(
T* A, size_t stride_a0, size_t stride_a1, T* B, size_t stride_b0,
size_t stride_b1, T* C, size_t stride_c0, size_t stride_c1, size_t N,
cudaStream_t stream) {
cross_kernel<T><<<DIVUP(N, NR_THREADS), NR_THREADS, 0, stream>>>(
A, stride_a0, stride_a1, B, stride_b0, stride_b1, C, stride_c0, stride_c1,
N);
after_kernel_launch();
}
#define INST(T) \
template void exec_internal<T>( \
T*, size_t, size_t, T*, size_t, size_t, T*, size_t, size_t, size_t, \
cudaStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef INST
#undef cb
} // namespace cross
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
#pragma once
#include <cuda_runtime_api.h>
#include <stdint.h>
namespace megdnn {
namespace cuda {
namespace cross {
template <typename T>
void exec_internal(
T* A, size_t stride_a0, size_t stride_a1, T* B, size_t stride_b0,
size_t stride_b1, T* C, size_t stride_c0, size_t stride_c1, size_t N,
cudaStream_t stream);
} // namespace cross
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
#include "src/cuda/cross/opr_impl.h"
#include "src/cuda/cross/cross.cuh"
#include "src/cuda/utils.h"
#include <algorithm>
#include <numeric>
namespace megdnn {
namespace cuda {
void CrossImpl::exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) {
check_exec(A.layout, B.layout, C.layout, workspace.size);
size_t a1, b1, c1, a2, b2, c2, a3, b3, c3;
get_ABC(A.layout, a1, b1, c1, param().axisa);
get_ABC(B.layout, a2, b2, c2, param().axisb);
get_ABC(C.layout, a3, b3, c3, param().axisc);
#define cb(DType) \
if (C.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
using ctype = typename DTypeTrait<DType>::ctype; \
cross::exec_internal<ctype>( \
A.ptr<ctype>(), b1 * c1, c1, B.ptr<ctype>(), b2 * c2, c2, \
C.ptr<ctype>(), b3 * c3, c3, a1 * c1, cuda_stream(handle())); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
}
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace cuda {
class CrossImpl final : public Cross {
public:
using Cross::Cross;
void exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& A, const TensorLayout& B,
const TensorLayout& C) override {
return 0;
}
};
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
......@@ -16,6 +16,7 @@
#include "src/cuda/convolution3d/opr_impl.h"
#include "src/cuda/convpooling/opr_impl.h"
#include "src/cuda/correlation/opr_impl.h"
#include "src/cuda/cross/opr_impl.h"
#include "src/cuda/cumsum/opr_impl.h"
#include "src/cuda/cvt_color/opr_impl.h"
#include "src/cuda/dct/opr_impl.h"
......@@ -234,6 +235,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(RegionRestrictedConvolutionBackwardData);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RegionRestrictedConvolutionBackwardFilter);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MultiHeadAttnForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MultiHeadAttnBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Cross);
template <typename Opr>
std::unique_ptr<Opr> HandleImpl::create_operator() {
......
#include "src/naive/cross/opr_impl.h"
#include <algorithm>
#include <numeric>
#include "src/common/utils.h"
#include "src/naive/handle.h"
namespace megdnn {
namespace naive {
template <typename ctype>
void CrossImpl::exec_internal(
ctype* A, size_t a1, size_t b1, size_t c1, ctype* B, size_t a2, size_t b2,
size_t c2, ctype* C, size_t a3, size_t b3, size_t c3) {
(void)a2;
(void)a3;
size_t N = a1 * c1;
for (size_t i = 0; i < N; ++i) {
size_t ida = (i / c1) * b1 * c1 + i % c1;
size_t idb = (i / c2) * b2 * c2 + i % c2;
size_t idc = (i / c3) * b3 * c3 + i % c3;
C[idc] = A[ida + c1] * B[idb + 2 * c2] - A[ida + 2 * c1] * B[idb + c2];
C[idc + c3] = A[ida + 2 * c1] * B[idb] - A[ida] * B[idb + 2 * c2];
C[idc + 2 * c3] = A[ida] * B[idb + c2] - A[ida + c1] * B[idb];
}
}
void CrossImpl::exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) {
check_exec(A.layout, B.layout, C.layout, workspace.size);
size_t a1, b1, c1, a2, b2, c2, a3, b3, c3;
get_ABC(A.layout, a1, b1, c1, param().axisa);
get_ABC(B.layout, a2, b2, c2, param().axisb);
get_ABC(C.layout, a3, b3, c3, param().axisc);
#define cb(DType) \
if (A.layout.dtype == DType()) { \
using ctype = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<ctype>( \
A.ptr<ctype>(), a1, b1, c1, B.ptr<ctype>(), a2, b2, c2, \
C.ptr<ctype>(), a3, b3, c3)); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
}
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace naive {
class CrossImpl : public Cross {
public:
using Cross::Cross;
void exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&) override {
return 0;
}
private:
template <typename ctype>
void exec_internal(
ctype* A, size_t a1, size_t b1, size_t c1, ctype* B, size_t a2, size_t b2,
size_t c2, ctype* C, size_t a3, size_t b3, size_t c3);
};
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
......@@ -18,6 +18,7 @@
#include "src/naive/convolution3d/opr_impl.h"
#include "src/naive/convpooling/opr_impl.h"
#include "src/naive/correlation/opr_impl.h"
#include "src/naive/cross/opr_impl.h"
#include "src/naive/cumsum/opr_impl.h"
#include "src/naive/cvt_color/opr_impl.h"
#include "src/naive/dct/opr_impl.h"
......
......@@ -112,7 +112,6 @@ void reduce_fwd(
}
}
#if !MGE_BUILD_WITHOUT_NAIVE_EXEC
template <>
void reduce_fwd<Mode::SUM>(
const dt_quint8* __restrict, dt_quint8* __restrict, size_t, size_t, size_t) {
......@@ -176,7 +175,6 @@ void reduce_fwd<Mode::PRODUCT>(
"Reduce (PRODUCT) with DEFAULT DataType is not supported "
"on QuantizedS8");
}
#endif
template <Mode mode>
void dispatch_dtype(
......@@ -240,7 +238,6 @@ size_t ReduceForwardImpl::get_workspace_in_bytes(
void ReduceForwardImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
#if !MGE_BUILD_WITHOUT_NAIVE_EXEC
using namespace reduce;
check_exec(src.layout, dst.layout, workspace.size);
size_t A, B, C;
......@@ -307,9 +304,6 @@ void ReduceForwardImpl::exec(
megdnn_assert_internal(false);
}
#undef CASE
#else
__builtin_trap();
#endif
}
} // namespace naive
......
......@@ -78,7 +78,6 @@ void on_dest_ctype(HandleImpl* handle, const TensorND& dest, const TensorND& src
} // anonymous namespace
void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
#if !MGE_BUILD_WITHOUT_NAIVE_EXEC
check_exec(src.layout, dst.layout);
// exec
......@@ -96,9 +95,6 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
#undef cb
default : megdnn_throw("bad dtype");
}
#else
__builtin_trap();
#endif
}
// vim: syntax=cpp.doxygen
#include "test/cuda/fixture.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
namespace megdnn {
namespace test {
TEST_F(CUDA, CROSS) {
Checker<Cross> checker(handle_cuda());
for (DType dtype :
std::vector<DType>{dtype::Float16(), dtype::Int32(), dtype::Float32()}) {
checker.set_param({-2, 1, -1})
.set_dtype(0, dtype)
.set_dtype(1, dtype)
.set_dtype(2, dtype);
checker.exec(TensorShapeArray{{2, 3, 4}, {2, 3, 4}, {2, 4, 3}});
checker.set_param({0, -1, 2})
.set_dtype(0, dtype)
.set_dtype(1, dtype)
.set_dtype(2, dtype);
checker.exec(TensorShapeArray{{3, 2, 3, 4}, {2, 3, 4, 3}, {2, 3, 3, 4}});
}
}
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
#include "test/naive/fixture.h"
#include "megdnn/oprs/linalg.h"
#include "test/common/checker.h"
using namespace megdnn;
using namespace test;
TEST_F(NAIVE, CROSS) {
Checker<Cross> checker(handle(), /* check_dispatch */ false);
Cross::Param param{2, 0, -3};
TensorND A = TensorValue(
{2, 2, 3, 3}, dtype::Float32(),
{0.14156187, 0.91729394, 0.74985146, 0.06110339, 0.67560272, 0.38569672,
0.63086542, 0.96403808, 0.44215527, 0.04618851, 0.17031024, 0.83291294,
0.15374447, 0.57640597, 0.90293485, 0.55324231, 0.75830697, 0.43003672,
0.18093289, 0.73350109, 0.38130576, 0.29138499, 0.77249111, 0.0878262,
0.11699259, 0.73904961, 0.52122415, 0.10095503, 0.15948783, 0.88822725,
0.12915866, 0.59800798, 0.17838123, 0.37576929, 0.9885241, 0.25812663});
TensorND B = TensorValue(
{3, 2, 2, 3}, dtype::Float32(),
{0.87391774, 0.65328165, 0.44478127, 0.94795241, 0.84836271, 0.8054875,
0.44598355, 0.50247015, 0.36997338, 0.77503215, 0.87235448, 0.21397323,
0.85153319, 0.16009368, 0.77536431, 0.04486964, 0.37626156, 0.81259061,
0.08839276, 0.50477703, 0.91626721, 0.44417563, 0.03038398, 0.21980224,
0.71472471, 0.85791626, 0.40345219, 0.16018583, 0.38511319, 0.23809232,
0.58711753, 0.92031592, 0.65426633, 0.14788665, 0.73273333, 0.3182309});
TensorND C = TensorValue(
{2, 3, 2, 3}, dtype::Float32(),
{-0.49353074, 0.42527416, -0.18722123, -0.0001961, -0.06334022,
-0.13446194, 0.45014671, -0.157173, -0.10586683, 0.51704864,
0.57773064, 0.14807902, 0.0671453, -0.29450591, 0.40985739,
-0.14366998, -0.42492013, -0.05048549, 0.16073594, 0.3378806,
-0.42011888, -0.14780672, 0.40814509, 0.00002961, -0.0540521,
-0.30370236, -0.05663646, 0.27630338, 0.74548138, -0.22742917,
-0.11395976, -0.01789922, 0.31688461, -0.05526035, -0.51682907,
0.15706553});
checker.set_param(param).exect(Testcase{A, B, {}}, Testcase{{}, {}, C});
}
......@@ -153,13 +153,15 @@ def synchronized(func: Callable):
def _check_device_initialized(device_type: str, rank: int):
try:
test = Tensor(1, device=(device_type + str(rank)))
inited = False
del test
except:
inited = True
errmsg = "The cuda env is set before the forked thread starts. Please do not use any cuda function or variable before forking."
if inited:
raise RuntimeError(errmsg)
except Exception as e:
errmsg = (
"Device initialization check failed, which may be caused "
"by using CUDA before forking the thread. Please review "
"the code to ensure that no CUDA functions or variables "
"are used before forking."
)
raise RuntimeError(errmsg) from e
def _check_interpreter_status():
......
......@@ -13,12 +13,13 @@ from ..core.tensor.utils import _normalize_axis
from ..tensor import Tensor
from ..utils.deprecation import deprecated_kwargs_default
from .elemwise import _elemwise_multi_type, clip
from .tensor import broadcast_to, expand_dims, squeeze
from .tensor import broadcast_to, concat, expand_dims, squeeze, zeros
__all__ = [
"argmax",
"argmin",
"argsort",
"cross",
"dot",
"isinf",
"isnan",
......@@ -720,6 +721,50 @@ def matinv(inp: Tensor) -> Tensor:
return result
def cross(
a: Tensor,
b: Tensor,
axisa: int = -1,
axisb: int = -1,
axisc: int = -1,
axis: int = None,
) -> Tensor:
r"""Return the cross product of two (arrays of) vectors.
The cross product of ``a`` and ``b`` in `R^3` is a vector perpendicular to both ``a`` and ``b``. If ``a`` and ``b`` are arrays of vectors, the vectors are defined by the last axis of ``a`` and ``b`` by default, and these axes can have dimensions 2 or 3.
Where the dimension of either ``a`` or ``b`` is 2, the third component of the input vector is assumed to be zero and the cross product is calculated accordingly.
Args:
a: components of the first vector(s).
b: components of the second vector(s).
axisa: axis of a that defines the vector(s). By default, the last axis.
axisb: axis of b that defines the vector(s). By default, the last axis.
axisc: axis of c containing the cross product vector(s). By default, the last axis.
axis: if defined, the axis of a, b and c that defines the vector(s) and cross product(s). Overrides axisa, axisb and axisc.
Returns:
vector cross product(s).
Examples:
>>> a = Tensor([1.0, 2.0, 3.0])
>>> b = Tensor([4.0, 5.0, 6.0])
>>> out = F.cross(a, b)
>>> out.numpy()
array([-3., 6., -3.], dtype=float32)
"""
if axis is not None:
axisa = axisb = axisc = axis
if a.ndim == 1 and len(a) == 2:
a = Tensor(np.append(a, 0))
if b.ndim == 1 and len(b) == 2:
b = Tensor(np.append(b, 0))
op = builtin.Cross(axisa, axisb, axisc)
(result,) = apply(op, a, b)
return result
def matmul(
inp1: Tensor,
inp2: Tensor,
......
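A hedged usage sketch of the new functional API with non-default axes (it mirrors the test cases further below and assumes a build that includes this patch):

import numpy as np
import megengine as mge
import megengine.functional as F

a = np.arange(9, dtype="float32").reshape(3, 3)
b = np.arange(9, 0, -1, dtype="float32").reshape(3, 3)
# vectors stored along axis 0 of both inputs; the result is also written along axis 0
out = F.cross(mge.tensor(a), mge.tensor(b), axisa=0, axisb=0, axisc=0)
np.testing.assert_allclose(out.numpy(), np.cross(a, b, axisa=0, axisb=0, axisc=0), rtol=1e-4)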
......@@ -37,3 +37,9 @@ from .quant_dequant import DequantStub, QuantStub
from .rnn import LSTM, RNN, LSTMCell, RNNCell
from .sequential import Sequential
from .sliding_window import SlidingWindow, SlidingWindowTranspose
from .vision import (
AdditiveElemwise,
AdditiveGaussianNoise,
AdditiveLaplaceNoise,
AdditivePoissonNoise,
)
import numpy as np
from ..functional.elemwise import abs, add, log
from ..functional.math import sign
from ..functional.tensor import broadcast_to
from ..random.rng import RNG
from ..tensor import Tensor
from .module import Module
class AdditiveElemwise(Module):
def __init__(self, per_channel=False, **kwargs):
self._per_channel = per_channel
super().__init__(**kwargs)
def forward(self, inp):
assert isinstance(
inp, Tensor
), "expected input is megengine.Tensor, but got {}".format(type(inp))
if self._per_channel is True:
noise = self.sample(inp.shape).to(inp.device)
elif self._per_channel is False:
if inp.format == "nchw":
N, C, H, W = inp.shape
c_noise = self.sample((N, 1, H, W))
# TODO: fix this: inp.shape is always reported in NCHW order, even when the format is "nhwc" (cjs).
elif inp.format == "nhwc":
N, H, W, C = inp.shape
c_noise = self.sample((N, H, W, 1))
else:
raise RuntimeError(
"expected the Tensor to be created with an explicit format when per_channel is False, but got format {}".format(
inp.format
)
)
noise = broadcast_to(c_noise, inp.shape).to(inp.device)
else:
raise NotImplementedError("per_channel must be a bool; floating-point per_channel is not implemented")
return add(inp, noise)
def sample(self, size):
raise NotImplementedError()
@property
def per_channel(self):
return self._per_channel
@per_channel.setter
def per_channel(self, per_channel):
self._per_channel = per_channel
class AdditiveLaplaceNoise(AdditiveElemwise):
r"""Add random laplace noise to the input data.
Laplace noise is generated with given mean and std, sampled from Laplace distribution
ref to this page to learn more: https://en.wikipedia.org/wiki/Laplace_distribution
Args:
mean: mean of the Laplace distribution used to generate noise.
std: standard deviation of the Laplace distribution used to generate noise.
per_channel: Whether to use (imagewise) the same sample(s) for all channels (False) or to sample value(s) for each channel (True). Setting this to True will therefore lead to different transformations per image and channel, otherwise only per image.
seed: random number seed of generator
"""
def __init__(self, mean=0.0, std=1.0, per_channel=False, seed=None):
assert seed is None or isinstance(seed, int)
super().__init__(per_channel)
self.mean = Tensor(mean, dtype=np.float32)
self.std = Tensor(std, dtype=np.float32)
self.rng_func = RNG(seed).uniform
self.finfo = np.finfo(np.dtype(self.mean.dtype))
self._seed = seed
def sample(self, size):
u = self.rng_func((self.finfo.eps - 1).item(), 1, size)
value = self.mean - self.std * sign(u) * log(1 - abs(u))
return value
@property
def seed(self):
return self._seed
@seed.setter
def seed(self, seed):
assert isinstance(seed, int)
self._seed = seed
self.rng_func = RNG(seed).uniform
class AdditivePoissonNoise(AdditiveElemwise):
r"""Add random poisson noise to the input data.
poission noise is generated with given mean and std.
Args:
lam: lam parameter of poisson distribution used to generate noise.
per_channel: Whether to use (imagewise) the same sample(s) for all channels (False) or to sample value(s) for each channel (True). Setting this to True will therefore lead to different transformations per image and channel, otherwise only per image.
seed: random number seed of generator
"""
def __init__(self, lam=1.0, per_channel=False, seed=None):
assert seed is None or isinstance(seed, int)
super().__init__(per_channel)
self.lam = Tensor(lam, dtype=np.float32)
self.rng_func = RNG(seed).poisson
self._seed = seed
def sample(self, size):
value = self.rng_func(self.lam, size)
return value
@property
def seed(self):
return self._seed
@seed.setter
def seed(self, seed):
assert isinstance(seed, int)
self._seed = seed
self.rng_func = RNG(seed).poisson
class AdditiveGaussianNoise(AdditiveElemwise):
r"""Add random gaussian noise to the input data.
Gaussian noise is generated with given mean and std.
Args:
mean: Gaussian mean used to generate noise.
std: Gaussian standard deviation used to generate noise.
per_channel: Whether to use (imagewise) the same sample(s) for all channels (False) or to sample value(s) for each channel (True). Setting this to True will therefore lead to different transformations per image and channel, otherwise only per image.
seed: random number seed of generator
"""
def __init__(self, mean=0.0, std=1.0, per_channel=False, seed=None):
assert seed is None or isinstance(seed, int)
super().__init__(per_channel)
self.mean = Tensor(mean, dtype=np.float32)
self.std = Tensor(std, dtype=np.float32)
self.rng_func = RNG(seed).normal
self._seed = seed
def sample(self, size):
value = self.rng_func(self.mean, self.std, size)
return value
@property
def seed(self):
return self._seed
@seed.setter
def seed(self, seed):
assert isinstance(seed, int)
self._seed = seed
self.rng_func = RNG(seed).normal
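A hedged usage sketch of the new noise modules (per_channel=False requires the input Tensor to carry an explicit "nchw" or "nhwc" format, as enforced in forward above; assumes a build that includes this patch):

import numpy as np
from megengine import Tensor
from megengine.module import AdditiveGaussianNoise

aug = AdditiveGaussianNoise(mean=0.0, std=0.1, per_channel=False, seed=1024)
img = Tensor(np.random.random((2, 3, 8, 8)).astype("float32"), format="nchw")
out = aug(img)  # one noise map per image, broadcast across the channel axis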
......@@ -20,6 +20,10 @@ def test_basic():
b = tensor(a)
assert b.format == "nhwc"
b = tensor(data, format="nchw")
result = F.pad(b, ((0, 0), (0, 0), (1, 1), (1, 1)), mode="reflect")
assert result.format == "default"
# TODO: init from tensor with new format
# c = tensor(a, format="nchw")
# assert c.format == "nchw"
......
......@@ -1680,3 +1680,49 @@ def test_softmax():
actual = F.softmax(data)
np.testing.assert_allclose(actual.numpy(), desired, rtol=1e-5)
def test_cross():
shape1 = 3
shape2 = 3
shape3 = (4, 2, 3)
shape4 = (4, 2, 3)
data1 = np.random.random(shape1).astype("float32")
data2 = np.random.random(shape2).astype("float32")
data3 = np.random.random(shape3).astype("float32")
data4 = np.random.random(shape4).astype("float32")
cases = [
{"input": [data1, data2]},
{"input": [data3, data4]},
]
opr_test(
cases,
F.cross,
compare_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=1e-4),
ref_fn=np.cross,
)
data5 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
data6 = np.array([[7, 8, 9], [4, 5, 6], [1, 2, 3]])
res = F.cross(mge.tensor(data5), mge.tensor(data6), axisa=0, axisb=0)
dst = np.cross(data5, data6, axisa=0, axisb=0)
np.testing.assert_allclose(res.numpy(), dst, rtol=1e-4)
data7 = np.array([1, 2, 3])
data8 = np.array([4, 5, 6])
res = F.cross(mge.tensor(data7), mge.tensor(data8), axisa=0, axisb=0)
dst = np.cross(data7, data8, axisa=0, axisb=0)
np.testing.assert_allclose(res.numpy(), dst, rtol=1e-4)
data9 = np.array([[6, 7, 4], [15, 2, 8], [9, 40, 39]])
data10 = np.array([[5, 8, 9], [14, 21, 17], [10, 3, 47]])
res = F.cross(mge.tensor(data9), mge.tensor(data10), axisa=1, axisb=-1)
dst = np.cross(data9, data10, axisa=1, axisb=-1)
np.testing.assert_allclose(res.numpy(), dst, rtol=1e-4)
data11 = np.array([1, 2])
data12 = np.array([4, 5, 6])
res = F.cross(mge.tensor(data11), mge.tensor(data12))
dst = np.cross(data11, data12)
np.testing.assert_allclose(res.numpy(), dst, rtol=1e-4)
import time
import numpy as np
import pytest
from megengine import Tensor
from megengine.module import (
AdditiveGaussianNoise,
AdditiveLaplaceNoise,
AdditivePoissonNoise,
)
@pytest.mark.parametrize(
"cls", [AdditiveGaussianNoise, AdditiveLaplaceNoise, AdditivePoissonNoise]
)
@pytest.mark.parametrize("per_channel", [False, True])
@pytest.mark.parametrize(
"shape, format",
[
((128, 3, 160, 160), "default"),
((128, 160, 160, 3), "nhwc"),
((128, 3, 160, 160), "nchw"),
],
)
@pytest.mark.parametrize("seed", [1024, None])
def test_AdditiveNoise(cls, per_channel, shape, format, seed):
if not per_channel and format == "default":
return
input_tensor = Tensor(
np.random.random(shape), np.float32, device="xpux", format=format
)
aug = cls(per_channel=per_channel, seed=seed)
aug_data = aug(input_tensor)
if seed is not None: # fix rng seed
aug_ref = cls(per_channel=per_channel, seed=seed)
aug_data_ref = aug_ref(input_tensor)
np.testing.assert_allclose(aug_data, aug_data_ref)
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/blas.h"
#include "megdnn/oprs.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
namespace mgb {
namespace imperative {
namespace cross {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = def.cast_final_safe<Cross>();
mgb_assert(inputs.size() == 2);
cg::OperatorNodeConfig config{op.make_name()};
return opr::Cross::make(inputs[0], inputs[1], op.param(), config);
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
mgb_assert(inputs.size() == 2, "Cross expects two inputs");
auto&& op_def = def.cast_final_safe<Cross>();
auto comp_node = inputs[0].comp_node;
if (!inputs[0].layout.ndim) {
return {{{inputs[0].layout, comp_node}}, false};
}
if (!inputs[1].layout.ndim) {
return {{{inputs[1].layout, comp_node}}, false};
}
DnnOprHelper<megdnn::Cross> dnn_op(op_def.param());
auto oup_layout = dnn_op.deduce_layout(inputs[0].layout, inputs[1].layout);
return {{{oup_layout, comp_node}}, true};
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto comp_node = inputs[0]->comp_node();
auto&& op_def = def.cast_final_safe<Cross>();
DnnOprCaller<megdnn::Cross> dnn_op(comp_node, op_def.param());
auto dst = [&] {
if (validated) {
return output_descs[0].layout;
} else {
return dnn_op.deduce_layout(inputs[0]->layout(), inputs[1]->layout());
}
}();
auto out = Tensor::make(dst, comp_node);
if (!inputs[0]->layout().is_empty() && !inputs[1]->layout().is_empty()) {
dnn_op.exec_with_ws(inputs[0], inputs[1], out);
}
return {out};
}
OP_TRAIT_REG(Cross, Cross)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace cross
} // namespace imperative
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -435,13 +435,22 @@ inline FT get_inputs_format(Span<ValueRef>& inputs, const FormatTransformation&
return format;
}
inline bool if_convert_format(const Format src_fmt, const FT& dst_fmt) {
if ((src_fmt == FT::NCHW && dst_fmt == FT::DEFAULT) ||
(src_fmt == FT::DEFAULT && dst_fmt == FT::NCHW)) {
return false;
} else {
return true;
}
}
inline ValueRefList unify_inputs_format(
const Span<ValueRef>& inputs, const FT& dst_fmt, const std::string& scope,
const FormatTransformation& t) {
ValueRefList unified_inputs(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
auto&& inp = inputs[i].cast(t.value_type());
if (inp.format() != dst_fmt) {
if (inp.format() != dst_fmt && if_convert_format(inp.format(), dst_fmt)) {
unified_inputs[i] = t.to(inp, dst_fmt, scope);
} else {
unified_inputs[i] = inputs[i];
......
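The change above skips format conversion between NCHW and DEFAULT tensors (they describe the same contiguous layout, so no transpose is needed) while still converting to or from NHWC. A hedged Python sketch of a predicate equivalent to if_convert_format above:

def if_convert_format(src_fmt: str, dst_fmt: str) -> bool:
    # NCHW and DEFAULT share the same memory layout, so converting between them is a no-op
    return {src_fmt, dst_fmt} != {"nchw", "default"}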
29b2127eb4034bf24e473945d70ead4a ../../dnn/scripts/opr_param_defs.py
639ff50d64fcb78374de266c88942c2c ../../src/core/include/megbrain/ir/ops.td
16654743e01160eeee879107cc4cac41 generated/opdef.h.inl
97c541ed45b0be98f1ac2700f5b4d8a6 generated/opdef.cpp.inl
6f9c6a7a1d71cca195c1e30743a1f542 generated/opdef.py.inl
806c5ceb34f571fc5c9d98d2ca8cad63 generated/opdef.cpy.inl
e4035bfefce3a2cc0e8cc6ec7fcac227 ../../dnn/scripts/opr_param_defs.py
13ab898fce3749ebbcabf7c145876147 ../../src/core/include/megbrain/ir/ops.td
9dda6e2db75279373ec6809b297a2370 generated/opdef.h.inl
aabc2d8146742faacabf56e376177e7b generated/opdef.cpp.inl
8a5dffac1df3286178b3fd304c39b5da generated/opdef.py.inl
04322b642bba8f684034fcc5dc27efcf generated/opdef.cpy.inl
911001ef0dd771024919f7a1a3a009db generated/enum_macro.h
......@@ -2303,6 +2303,49 @@ OP_TRAIT_REG(Correlation, Correlation)
.props(Correlation_props_impl)
.make_name(Correlation_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Cross);
namespace {
size_t Cross_hash_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<Cross>();
static_cast<void>(op_);
size_t val = mgb::hash(op_.dyn_typeinfo());
val = mgb::hash_pair_combine(val, mgb::hash(op_.axisa));
val = mgb::hash_pair_combine(val, mgb::hash(op_.axisb));
val = mgb::hash_pair_combine(val, mgb::hash(op_.axisc));
return val;
}
bool Cross_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
auto &&a_ = lhs_.cast_final_safe<Cross>(),
&&b_ = rhs_.cast_final_safe<Cross>();
static_cast<void>(a_);
static_cast<void>(b_);
if (a_.axisa != b_.axisa) return false;
if (a_.axisb != b_.axisb) return false;
if (a_.axisc != b_.axisc) return false;
return true;
}
std::vector<std::pair<const char*, std::string>> Cross_props_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<Cross>();
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
props_.emplace_back("axisa", std::to_string(op_.axisa));
props_.emplace_back("axisb", std::to_string(op_.axisb));
props_.emplace_back("axisc", std::to_string(op_.axisc));
return props_;
}
std::string Cross_make_name_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<Cross>();
static_cast<void>(op_);
return "Cross";
}
} // anonymous namespace
OP_TRAIT_REG(Cross, Cross)
.hash(Cross_hash_impl)
.is_same_st(Cross_is_same_st_impl)
.props(Cross_props_impl)
.make_name(Cross_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Cumsum);
namespace {
......
......@@ -7274,6 +7274,151 @@ void _init_py_Correlation(py::module m) {
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Correlation::typeinfo(), &py_type).second);
}
PyOpDefBegin(Cross) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
static PyObject* getstate(PyObject* self, PyObject*) {
auto& opdef = reinterpret_cast<PyOp(Cross)*>(self)->inst();
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {
{"axisa", serialization<decltype(opdef.axisa)>::dump(opdef.axisa)},
{"axisb", serialization<decltype(opdef.axisb)>::dump(opdef.axisb)},
{"axisc", serialization<decltype(opdef.axisc)>::dump(opdef.axisc)}
};
return py::cast(state).release().ptr();
}
static PyObject* setstate(PyObject* self, PyObject* args) {
PyObject* dict = PyTuple_GetItem(args, 0);
if (!dict) return NULL;
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
auto& opdef = reinterpret_cast<PyOp(Cross)*>(self)->inst();
static_cast<void>(opdef);
{
auto&& iter = state.find("axisa");
if (iter != state.end()) {
opdef.axisa = serialization<decltype(opdef.axisa)>::load(iter->second);
}
}
{
auto&& iter = state.find("axisb");
if (iter != state.end()) {
opdef.axisb = serialization<decltype(opdef.axisb)>::load(iter->second);
}
}
{
auto&& iter = state.find("axisc");
if (iter != state.end()) {
opdef.axisc = serialization<decltype(opdef.axisc)>::load(iter->second);
}
}
Py_RETURN_NONE;
}
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds);
static PyMethodDef py_init_methoddef;
// };
PyOpDefEnd(Cross)
int PyOp(Cross)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
static const char* kwlist[] = {"axisa", "axisb", "axisc", "scope", NULL};
PyObject *axisa = NULL, *axisb = NULL, *axisc = NULL, *scope = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOO", const_cast<char**>(kwlist), &axisa, &axisb, &axisc, &scope))
return -1;
if (axisa) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(Cross)*>(self)->inst().axisa =
py::cast<decltype(Cross::axisa)>(py::handle(axisa));
} CATCH_ALL(-1)
}
if (axisb) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(Cross)*>(self)->inst().axisb =
py::cast<decltype(Cross::axisb)>(py::handle(axisb));
} CATCH_ALL(-1)
}
if (axisc) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(Cross)*>(self)->inst().axisc =
py::cast<decltype(Cross::axisc)>(py::handle(axisc));
} CATCH_ALL(-1)
}
if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(py::cast<std::string>(py::handle(scope)));
} CATCH_ALL(-1)
}
return 0;
}
PyGetSetDef PyOp(Cross)::py_getsetters[] = {
{const_cast<char*>("axisa"), py_get_generic(Cross, axisa), py_set_generic(Cross, axisa), const_cast<char*>("axisa"), NULL},
{const_cast<char*>("axisb"), py_get_generic(Cross, axisb), py_set_generic(Cross, axisb), const_cast<char*>("axisb"), NULL},
{const_cast<char*>("axisc"), py_get_generic(Cross, axisc), py_set_generic(Cross, axisc), const_cast<char*>("axisc"), NULL},
{NULL} /* Sentinel */
};
PyMethodDef PyOp(Cross)::tp_methods[] = {
{const_cast<char*>("__getstate__"), PyOp(Cross)::getstate, METH_NOARGS, "Cross getstate"},
{const_cast<char*>("__setstate__"), PyOp(Cross)::setstate, METH_VARARGS, "Cross setstate"},
{NULL} /* Sentinel */
};
PyObject *PyOp(Cross)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) {
if (PyOp(Cross)::py_init(self, args, kwds) < 0) {
return NULL;
}
Py_RETURN_NONE;
}
PyMethodDef PyOp(Cross)::py_init_methoddef = {
"__init__",
(PyCFunction)PyOp(Cross)::py_init_proxy,
METH_VARARGS | METH_KEYWORDS,
"__init__(self, axisa: int = ..., axisb: int = ..., axisc: int = ...) -> None\n"
};
void _init_py_Cross(py::module m) {
using py_op = PyOp(Cross);
auto& py_type = PyOpType(Cross);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.Cross";
py_type.tp_basicsize = sizeof(PyOp(Cross));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "Cross";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
py_type.tp_dict = PyDict_New();
PyObject* descr = PyDescr_NewMethod(&PyOpType(Cross), &PyOp(Cross)::py_init_methoddef);
PyDict_SetItemString(py_type.tp_dict, "__init__", descr);
mgb_assert(PyType_Ready(&py_type) >= 0);
PyType_Modified(&py_type);
m.add_object("Cross", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Cross::typeinfo(), &py_type).second);
}
PyOpDefBegin(Cumsum) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
......@@ -23420,6 +23565,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) {
_init_py_ConvolutionBackwardData(m); \
_init_py_Copy(m); \
_init_py_Correlation(m); \
_init_py_Cross(m); \
_init_py_Cumsum(m); \
_init_py_CvtColor(m); \
_init_py_DeformableConv(m); \
......
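The generated CPython type above is registered as megengine.core._imperative_rt.ops.Cross (see tp_name), which is presumably what builtin.Cross resolves to in the functional wrapper earlier in this diff. A hedged construction sketch:

from megengine.core._imperative_rt.ops import Cross

op = Cross(axisa=-1, axisb=-1, axisc=-1)
print(op.axisa, op.axisb, op.axisc)  # -1 -1 -1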
......@@ -567,6 +567,21 @@ public:
}
};
class Cross : public OpDefImplBase<Cross> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
int32_t axisa = -1;
int32_t axisb = -1;
int32_t axisc = -1;
Cross() = default;
Cross(int32_t axisa_, int32_t axisb_, int32_t axisc_, std::string scope_ = {}): axisa(axisa_), axisb(axisb_), axisc(axisc_) { set_scope(scope_); }
Cross(::megdnn::param::Cross packed_param_0): axisa(packed_param_0.axisa), axisb(packed_param_0.axisb), axisc(packed_param_0.axisc) {}
::megdnn::param::Cross param() const {
return {axisa, axisb, axisc};
}
};
class Cumsum : public OpDefImplBase<Cumsum> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
......
......@@ -678,6 +678,14 @@ CorrelationInst
.def_readwrite("pad_size", &Correlation::pad_size)
.def_readwrite("is_multiply", &Correlation::is_multiply);
py::class_<Cross, std::shared_ptr<Cross>, OpDef> CrossInst(m, "Cross");
CrossInst
.def(py::init<int32_t, int32_t, int32_t, std::string>(), py::arg("axisa") = -1, py::arg("axisb") = -1, py::arg("axisc") = -1, py::arg("scope") = {})
.def_readwrite("axisa", &Cross::axisa)
.def_readwrite("axisb", &Cross::axisb)
.def_readwrite("axisc", &Cross::axisc);
py::class_<Cumsum, std::shared_ptr<Cumsum>, OpDef> CumsumInst(m, "Cumsum");
CumsumInst
......
......@@ -112,10 +112,22 @@ void NetworkImplDft::application_config() {
loc.type = m_compnode_locator.type;
}
loc.device = m_compnode_locator.device;
//! the user configured stream
auto stream = m_compnode_locator.stream;
//! if user set the thread number and the compnode is multithread
if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD &&
m_nr_threads != 1) {
loc.stream = m_nr_threads;
if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD) {
if (m_nr_threads != 1) {
loc.nr_threads = m_nr_threads;
}
//! the user sets the stream id to separate different multithread compnodes
if (stream != 0) {
auto device = m_compnode_locator.device;
//! the device is also set by user, so combine them to one
//! int
if (device == -1) {
loc.device = stream;
}
}
} else {
loc.stream = m_compnode_locator.stream;
}
......
......@@ -148,7 +148,8 @@ public:
virtual void set_device_id(int device_id) = 0;
virtual int get_device_id() const = 0;
virtual LiteBackend get_backend_type() const = 0;
//! set stream id, default stream id = 0
//! set stream id, default stream id = 0; if there are multiple compnodes in a
//! model, set all of their streams to the given stream_id
virtual void set_stream_id(int stream_id) = 0;
virtual int get_stream_id() const = 0;
......
......@@ -1665,6 +1665,30 @@ TEST(TestNetWork, AtlasLoadAtlasDeviceInput) {
TEST(TestNetWork, AtlasDeviceID) {
load_device_id(LiteDeviceType::LITE_ATLAS, 1, "./model_atlas.mgb");
}
TEST(TestNetWork, AtlasCrossCompnodeStreamID) {
auto thread = [](int stream_id) {
lite::Config config;
config.device_type = LiteDeviceType::LITE_ATLAS;
auto network = std::make_shared<lite::Network>(config);
network->set_stream_id(stream_id);
network->load_model("./model_atlas.mgb");
std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0);
std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
for (int i = 0; i < 10; i++) {
network->forward();
network->wait();
}
};
std::thread t0(thread, 1);
std::thread t1(thread, 2);
std::thread t2(thread, 3);
std::thread t3(thread, 4);
t0.join();
t1.join();
t2.join();
t3.join();
}
#endif
#if MGB_CAMBRICON
......
......@@ -548,7 +548,6 @@ def Dropout: MgbHashableOp<"Dropout", [DropoutParam]> {
);
}];
let cmpFunction = [{return $0.handle == $1.handle && $0.drop_prob == $1.drop_prob;}];
}
def MeshGrid: MgbHashableOp<"MeshGrid"> {
let extraArguments = (ins
......@@ -639,4 +638,6 @@ def MultiHeadAttn: MgbHashableOp<"MultiHeadAttn", [MultiHeadAttnParam]> {
}
def Cross: MgbHashableOp<"Cross", [CrossParam]>;
#endif // MGB_OPS
......@@ -736,4 +736,29 @@ SymbolVarArray SVD::make(
return ret;
}
/* ================= Cross ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Cross);
MEGDNN_OPR_INIT2(Cross, "cross")
void Cross::add_input_layout_constraint() {
input(0)->add_layout_constraint_contiguous();
input(1)->add_layout_constraint_contiguous();
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Cross) {
SymbolVar grad, i0{opr.input(0)}, i1{opr.input(1)}, og{out_grad[0]};
if (wrt_idx == 0) {
grad = Cross::make(
i1, og, {opr.param().axisb, opr.param().axisc, opr.param().axisa});
} else {
mgb_assert(wrt_idx == 1);
grad = Cross::make(
og, i0, {opr.param().axisc, opr.param().axisa, opr.param().axisb});
}
return grad.node();
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
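The gradient registered above relies on the scalar triple-product identity: for c = a x b and upstream gradient g = dL/dc, the inner product <g, a x b> equals <b x g, a> and <g x a, b>, so each input gradient is itself a cross product. That is why each grad is expressed as another Cross with cyclically permuted axis parameters. As a derivation sketch in LaTeX:

\frac{\partial L}{\partial a} = b \times g, \qquad
\frac{\partial L}{\partial b} = g \times a,
\qquad \text{since} \qquad
\langle g,\ a \times b \rangle = \langle b \times g,\ a \rangle = \langle g \times a,\ b \rangle .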
......@@ -48,4 +48,9 @@ decl_opr('SVD',
desc='Computes the singular value decompositions of matrices. '
'The input must has shape ``[..., M, N]``.')
decl_opr('Cross',
inputs=['A', 'B'],
params='Cross',
desc='Computes the cross product of two (arrays of) vectors.')
# vim: ft=python
......@@ -92,7 +92,7 @@ MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(BatchedMatrixMulV2, 2, opr_shallow_copy_matmu
MGB_SEREG_OPR(Dot, 2);
MGB_SEREG_OPR(MatrixInverse, 1);
MGB_SEREG_OPR(SVD, 1);
MGB_SEREG_OPR(Cross, 2);
} // namespace opr
} // namespace mgb
......
......@@ -122,6 +122,19 @@ public:
const OperatorNodeConfig& config = {});
};
MGB_DEFINE_OPR_CLASS(Cross, intl::MegDNNOprWrapperFwd<megdnn::Cross>) // {
public:
MGE_WIN_DECLSPEC_FUC Cross(
VarNode* A, VarNode* B, const Param& param,
const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar A, SymbolVar B, const Param& param,
const OperatorNodeConfig& config = {});
private:
void add_input_layout_constraint() override;
};
} // namespace opr
} // namespace mgb
......
......@@ -634,6 +634,28 @@ TEST(TestOprBlas, Dot) {
.run({TensorShape{0}, TensorShape{0}});
}
TEST(TestOprBlas, Cross) {
using Checker = AutoOprChecker<2, 1>;
auto nopr = megdnn_naive_handle()->create_operator<megdnn::Cross>();
nopr->param() = {-1, -1, -1};
auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
return {opr::Cross::make(inputs[0], inputs[1], {-1, -1, -1})};
};
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
auto &&a = *inp[0], &&b = *inp[1];
dest[0].resize(a.shape());
nopr->exec(a.as_megdnn(), b.as_megdnn(), dest[0].as_megdnn(), {});
};
Checker(make_graph, fwd)
.run({TensorShape{3}, TensorShape{3}})
.run({TensorShape{6, 3}, TensorShape{6, 3}})
.run({TensorShape{2, 4, 3}, TensorShape{2, 4, 3}})
.run({TensorShape{2, 5, 2, 3}, TensorShape{2, 5, 2, 3}});
}
TEST(TestOprBlas, TransMatMul) {
run_trans_inp_test<float, float>();
}
......
......@@ -128,6 +128,7 @@ union OperatorParam {
param.GeneralNorm=94,
param.MultiHeadAttn=95,
param.Resize3D = 96,
param.Cross = 97,
}
table Operator {
......
......@@ -145,6 +145,7 @@ union OperatorParam {
param.GeneralNorm=94,
param.MultiHeadAttn=95,
param.Resize3D = 96,
param.Cross = 97,
}
table Operator {
......