diff --git a/dnn/include/megdnn/oprs/general.h b/dnn/include/megdnn/oprs/general.h index 30df4dd3a03425769d2d53c7a32170ab842c4a49..cfdcc3da2dffa0bd660b8b5f0c7274270a4e63d4 100644 --- a/dnn/include/megdnn/oprs/general.h +++ b/dnn/include/megdnn/oprs/general.h @@ -998,6 +998,28 @@ protected: void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); }; +class Diag : public OperatorBase { + DEF_OPR_IMPL(Diag, OperatorBase, 1, 1); + DEF_OPR_PARAM(Diag); + +public: + /** + * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.diag.html + */ + + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + void deduce_layout(const TensorLayout& src, TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; + +protected: + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); +}; + class IndexingOneHotBase : public OperatorBase { DEF_OPR_IMPL_CTOR(IndexingOneHotBase, OperatorBase); DEF_OPR_PARAM(Axis); diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index b4feb266962626e969333021799081270df0801a..481c63b595b0952c45a9c128901339606f722cc5 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -759,6 +759,14 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0) 'dtype', Doc('dtype', 'data type of output value'), 'DTypeEnum::Float32')) +(pdef('Diag'). + add_fields( + 'int32', + Doc('k', 'Index of the diagonal: 0 (the default) refers to the main ' + 'diagonal, a positive value refers to an upper diagonal, and a ' + 'negative value to a lower diagonal.'), + 0)) + (pdef('UniformRNG', version=0, is_legacy=True). 
add_fields('uint64', 'seed', 0))
diff --git a/dnn/src/common/diag.cpp b/dnn/src/common/diag.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..7fed7636ae6adfa730fea3b065a241594d523a30
--- /dev/null
+++ b/dnn/src/common/diag.cpp
@@ -0,0 +1,63 @@
+/**
+ * \file dnn/src/common/diag.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "megdnn/oprs.h"
+
+#include "src/common/utils.h"
+
+namespace megdnn {
+
+/*!
+ * \brief deduce the output layout of Diag from the input layout and param().k
+ *
+ * A 1-d input of length l yields a square matrix with (l + |k|) rows whose
+ * k-th diagonal holds the input. A 2-d input of shape (m, n) yields its k-th
+ * diagonal as a 1-d tensor of length min(n - k, m) for k >= 0 and
+ * min(m + k, n) for k < 0.
+ */
+void Diag::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
+    megdnn_assert(
+            src.ndim == 1 || src.ndim == 2, "Only support vector or matrix as input.");
+    int k = param().k;
+    if (src.ndim == 1) {
+        size_t o = src.total_nr_elems() + std::abs(k);
+        dst = TensorLayout(TensorShape({o, o}), src.dtype);
+    } else { // src.ndim == 2
+        // Compute the length with signed arithmetic: on unsigned operands an
+        // out-of-range k wraps around to a huge value, std::min() then picks
+        // the other dimension and the assertion below can never fire.
+        auto m = static_cast<ptrdiff_t>(src.shape[0]);
+        auto n = static_cast<ptrdiff_t>(src.shape[1]);
+        ptrdiff_t o = (k >= 0 ? std::min(n - k, m) : std::min(m + k, n));
+        megdnn_assert(o > 0, "The moved diagonal is out of the input matrix.");
+        dst = TensorLayout(TensorShape({static_cast<size_t>(o)}), src.dtype);
+    }
+}
+
+/*!
+ * \brief validate the arguments of exec(): src and dst must share a dtype,
+ *      dst must be contiguous and equal to the layout deduced from src, and
+ *      the workspace must be at least get_workspace_in_bytes() large
+ */
+void Diag::check_exec(
+        const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
+    TensorLayout dst_expected;
+    megdnn_assert_eq_dtype(src, dst);
+    deduce_layout(src, dst_expected);
+    megdnn_assert_eq_layout(dst_expected, dst);
+
+    megdnn_assert_contiguous(dst);
+    auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
+    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
+}
+
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h
index 1e717915281e14c0714d0fd22f69c505377d6625..bfb35bb9b7781d3db86112ad48a207df9d200c51 100644
--- a/dnn/src/common/handle_impl.h
+++ b/dnn/src/common/handle_impl.h
@@ -146,6 +146,7 @@ private:
     cb(BatchedSetMeshIndexing) \
     cb(Linspace) \
     cb(Eye) \
+    cb(Diag) \
     cb(SleepForward) \
     cb(UniformRNG) \
     cb(GaussianRNG) \
diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h
index 6a2e4d33f50be84594d6287e17e8a37b6ee9c840..4cb6fa1f873fd91ad8e08713b3b01077eb637584 100644
--- a/dnn/src/common/opr_trait.h
+++ b/dnn/src/common/opr_trait.h
@@ -88,6 +88,7 @@ DEF(IndexingRemapForward, 3, true, true);
 DEF(IndexingRemapBackward, 3, true, false);
 DEF(Linspace, 1, true, false);
 DEF(Eye, 1, true, false);
+DEF(Diag, 2, true, true);
 DEF(Flip, 2, true, true);
 DEF(ROICopy, 2, true, true);
 DEF(Rotate, 2, true, true);
diff --git a/dnn/src/cuda/diag/diag.cu b/dnn/src/cuda/diag/diag.cu
new file mode 100644
index 0000000000000000000000000000000000000000..6581e255f664aa7efbb64abae5a39170a0640a0d
--- /dev/null
+++ b/dnn/src/cuda/diag/diag.cu
@@ -0,0 +1,87 @@
+/**
+ * \file dnn/src/cuda/diag/diag.cu
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "megdnn/dtype.h" +#include "src/cuda/diag/diag.cuh" +#include "src/cuda/utils.cuh" + +namespace { + +template +__global__ void kernel_to_vector( + T* src, T* dst, ptrdiff_t start, ptrdiff_t size, ptrdiff_t stride_sum, + ptrdiff_t dst_stride) { + ptrdiff_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i < size) { + dst[dst_stride * i] = src[start + stride_sum * i]; + } +} + +template +__global__ void kernel_to_matrix( + T* src, T* dst, ptrdiff_t offset, ptrdiff_t n, ptrdiff_t k, + ptrdiff_t dst_stride0, ptrdiff_t dst_stride1, ptrdiff_t src_stride) { + ptrdiff_t i = threadIdx.x + blockIdx.x * blockDim.x; + ptrdiff_t x = i % n; + ptrdiff_t y = i / n; + ptrdiff_t p = dst_stride0 * y + dst_stride1 * x; + if (i < n * n) { + if (y + k == x) + dst[p] = src[src_stride * (y - offset)]; + else + dst[p] = 0; + } +} + +} // anonymous namespace + +namespace megdnn { +namespace cuda { +namespace diag { + +template +void exec_internal_to_vector( + T* src, T* dst, ptrdiff_t start, ptrdiff_t size, ptrdiff_t stride_sum, + ptrdiff_t dst_stride, cudaStream_t stream) { + kernel_to_vector<<>>( + src, dst, start, size, stride_sum, dst_stride); + after_kernel_launch(); +} + +template +void exec_internal_to_matrix( + T* src, T* dst, ptrdiff_t offset, ptrdiff_t n, ptrdiff_t k, + ptrdiff_t dst_stride0, ptrdiff_t dst_stride1, ptrdiff_t src_stride, + cudaStream_t stream) { + kernel_to_matrix<<>>( + src, dst, offset, n, k, dst_stride0, dst_stride1, src_stride); + after_kernel_launch(); +} + +#define INST(T) \ + template void exec_internal_to_vector( \ + T*, T*, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, cudaStream_t); +#define cb(DType) INST(typename DTypeTrait::ctype) +MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +cb(::megdnn::dtype::Bool) +#undef INST +#undef cb + 
+#define INST(T) \ + template void exec_internal_to_matrix( \ + T*, T*, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, \ + cudaStream_t); +#define cb(DType) INST(typename DTypeTrait::ctype) + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) cb(::megdnn::dtype::Bool) + +} // namespace diag +} // namespace cuda +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/cuda/diag/diag.cuh b/dnn/src/cuda/diag/diag.cuh new file mode 100644 index 0000000000000000000000000000000000000000..d88ac11f532cb9204920ad1d1a65f8a9da8ae429 --- /dev/null +++ b/dnn/src/cuda/diag/diag.cuh @@ -0,0 +1,33 @@ +/** + * \file dnn/src/cuda/diag/diag.cuh + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once +#include +#include + +namespace megdnn { +namespace cuda { +namespace diag { + +template +void exec_internal_to_vector( + T* src, T* dst, ptrdiff_t start, ptrdiff_t size, ptrdiff_t stride_sum, + ptrdiff_t dst_stride, cudaStream_t stream); + +template +void exec_internal_to_matrix( + T* src, T* dst, ptrdiff_t start, ptrdiff_t n, ptrdiff_t k, + ptrdiff_t dst_stride0, ptrdiff_t dst_stride1, ptrdiff_t src_stride, + cudaStream_t stream); + +} // namespace diag +} // namespace cuda +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/cuda/diag/opr_impl.cpp b/dnn/src/cuda/diag/opr_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7f5e89543d90889d3410a1a5b4574bc0db5cffbc --- /dev/null +++ b/dnn/src/cuda/diag/opr_impl.cpp @@ -0,0 +1,61 @@ +/** + * \file dnn/src/cuda/diag/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/cuda/diag/opr_impl.h" + +#include "src/cuda/diag/diag.cuh" +#include "src/cuda/utils.h" + +namespace megdnn { +namespace cuda { + +void DiagImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { + check_exec(src.layout, dst.layout, workspace.size); + if (src.layout.ndim == 2) { + auto src_stride0 = src.layout.stride[0]; + auto src_stride1 = src.layout.stride[1]; + auto dst_stride = dst.layout.stride[0]; + auto start = + (param().k >= 0) ? 
param().k * src_stride1 : -param().k * src_stride0; + +#define cb(DType) \ + if (dst.layout.dtype.enumv() == DTypeTrait::enumv) { \ + using ctype = typename DTypeTrait::ctype; \ + diag::exec_internal_to_vector( \ + src.ptr(), dst.ptr(), start, dst.layout.shape[0], \ + src_stride0 + src_stride1, dst_stride, cuda_stream(handle())); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) +#undef cb + } else { + auto n = dst.layout.shape[0]; + auto src_stride = src.layout.stride[0]; + auto dst_stride0 = dst.layout.stride[0]; + auto dst_stride1 = dst.layout.stride[1]; + auto offset = (param().k >= 0) ? 0 : -param().k; + +#define cb(DType) \ + if (dst.layout.dtype.enumv() == DTypeTrait::enumv) { \ + using ctype = typename DTypeTrait::ctype; \ + diag::exec_internal_to_matrix( \ + src.ptr(), dst.ptr(), offset, n, param().k, dst_stride0, \ + dst_stride1, src_stride, cuda_stream(handle())); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) +#undef cb + } +} + +} // namespace cuda +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/cuda/diag/opr_impl.h b/dnn/src/cuda/diag/opr_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..9c65f82f70f5ac4b077802b49355e2fb1e2030e7 --- /dev/null +++ b/dnn/src/cuda/diag/opr_impl.h @@ -0,0 +1,31 @@ +/** + * \file dnn/src/cuda/diag/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */
+#pragma once
+#include "megdnn/oprs.h"
+
+namespace megdnn {
+namespace cuda {
+
+class DiagImpl final : public Diag {
+public:
+    using Diag::Diag;
+    void exec(
+            _megdnn_tensor_in src, _megdnn_tensor_out dst,
+            _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(
+            const TensorLayout& src, const TensorLayout& dst) override {
+        return 0;
+    }
+};
+
+} // namespace cuda
+} // namespace megdnn
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp
index ac003713f298c0c86e811429ef8c1cf182bb66de..09f83606ab23b3bd4a3a11b381d67ce5e4e659ee 100644
--- a/dnn/src/cuda/handle_create.cpp
+++ b/dnn/src/cuda/handle_create.cpp
@@ -33,6 +33,7 @@
 #include "src/cuda/dct/opr_impl.h"
 #include "src/cuda/deformable_conv/opr_impl.h"
 #include "src/cuda/deformable_ps_roi_pooling/opr_impl.h"
+#include "src/cuda/diag/opr_impl.h"
 #include "src/cuda/dot/opr_impl.h"
 #include "src/cuda/dropout/opr_impl.h"
 #include "src/cuda/elemwise/opr_impl.h"
@@ -154,6 +155,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedIncrMeshIndexing);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedSetMeshIndexing);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(Linspace);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(Eye);
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(Diag);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(UniformRNG);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(GaussianRNG);
diff --git a/dnn/src/naive/diag/opr_impl.cpp b/dnn/src/naive/diag/opr_impl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..1811f297d60199454d9a8be3191fdb1ef636cb51
--- /dev/null
+++ b/dnn/src/naive/diag/opr_impl.cpp
@@ -0,0 +1,75 @@
+/**
+ * \file dnn/src/naive/diag/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "src/naive/diag/opr_impl.h"
+#include "src/common/utils.h"
+#include "src/naive/handle.h"
+
+namespace megdnn {
+namespace naive {
+
+/*!
+ * \brief reference implementation of Diag
+ *
+ * For a 1-d input (input_ndim == 1) dst is zero-filled and src is then written
+ * onto its k-th diagonal; otherwise the k-th diagonal of the 2-d src is
+ * gathered into dst. The diagonal starts k columns to the right of element
+ * (0, 0) for k >= 0 and -k rows below it for k < 0; each step along it
+ * advances by one row stride plus one column stride.
+ */
+template <typename ctype>
+void DiagImpl::exec_internal(
+        ctype* src, const TensorLayout& src_layout, ctype* dst,
+        const TensorLayout& dst_layout, size_t input_ndim, int k) {
+    if (input_ndim == 1) {
+        size_t l = src_layout.shape[0];
+        size_t s0 = dst_layout.stride[0];
+        size_t s1 = dst_layout.stride[1];
+        size_t start = (k >= 0) ? (k * s1) : (-k * s0);
+        for (size_t i = 0; i < dst_layout.shape[0]; ++i)
+            for (size_t j = 0; j < dst_layout.shape[1]; ++j)
+                dst[i * s0 + j * s1] = 0;
+        // honour the stride of src: check_exec() only requires dst to be
+        // contiguous, so the input vector may be strided
+        ptrdiff_t src_stride = src_layout.stride[0];
+        for (size_t i = 0; i < l; ++i)
+            dst[start + i * (s0 + s1)] = src[static_cast<ptrdiff_t>(i) * src_stride];
+    } else {
+        size_t l = dst_layout.shape[0];
+        size_t s0 = src_layout.stride[0];
+        size_t s1 = src_layout.stride[1];
+        size_t start = (k >= 0) ? (k * s1) : (-k * s0);
+        for (size_t i = 0; i < l; ++i)
+            dst[i] = src[start + i * (s0 + s1)];
+    }
+}
+
+/*!
+ * \brief check the arguments, then dispatch exec_internal() on the dtype of src
+ */
+void DiagImpl::exec(
+        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
+    check_exec(src.layout, dst.layout, workspace.size);
+#define cb(DType)                                                           \
+    if (src.layout.dtype == DType()) {                                      \
+        using ctype = typename DTypeTrait<DType>::ctype;                    \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<ctype>(                  \
+                src.ptr<ctype>(), src.layout, dst.ptr<ctype>(), dst.layout, \
+                src.layout.ndim, param().k));                               \
+    }
+    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+    cb(::megdnn::dtype::Bool)
+#undef cb
+}
+
+} // namespace naive
+} // namespace megdnn
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/naive/diag/opr_impl.h b/dnn/src/naive/diag/opr_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..cdff2c35f55c79c9cb17cc60dfa6d905e63507c8
--- /dev/null
+++ b/dnn/src/naive/diag/opr_impl.h
@@ -0,0 +1,37 @@
+/**
+ * \file dnn/src/naive/diag/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#pragma once
+#include "megdnn/oprs.h"
+
+namespace megdnn {
+namespace naive {
+
+class DiagImpl : public Diag {
+public:
+    using Diag::Diag;
+    void exec(
+            _megdnn_tensor_in src, _megdnn_tensor_out dst,
+            _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override {
+        return 0;
+    }
+
+private:
+    template <typename ctype>
+    void exec_internal(
+            ctype* src, const TensorLayout& src_layout, ctype* dst,
+            const TensorLayout& dst_layout, size_t input_ndim, int k);
+};
+
+} // namespace naive
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp
index be79091bbedd2df30bf4652c4f511425b31f8a45..e4dcc7af8ae49b2e62d7a2d5e350f4aef86647ab 100644
--- a/dnn/src/naive/handle.cpp
+++ b/dnn/src/naive/handle.cpp
@@ -34,6 +34,7 @@
 #include "src/naive/dct/opr_impl.h"
 #include "src/naive/deformable_conv/opr_impl.h"
 #include "src/naive/deformable_ps_roi_pooling/opr_impl.h"
+#include "src/naive/diag/opr_impl.h"
 #include "src/naive/dot/opr_impl.h"
 #include "src/naive/dropout/opr_impl.h"
 #include "src/naive/elemwise/opr_impl.h"
diff --git a/dnn/test/cuda/diag.cpp b/dnn/test/cuda/diag.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..dacf8dd29729aec724a536f622ed9b9570df93d8
--- /dev/null
+++ b/dnn/test/cuda/diag.cpp
@@ -0,0 +1,47 @@
+/**
+ * \file dnn/test/cuda/diag.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "test/cuda/fixture.h"
+
+#include "megdnn/oprs.h"
+#include "test/common/checker.h"
+
+namespace megdnn {
+namespace test {
+
+// Exercise Diag on the CUDA handle in both directions (vector -> matrix and
+// matrix -> vector) for three dtypes and every k in [-5, 5).
+TEST_F(CUDA, DIAG) {
+    Checker<Diag> checker(handle_cuda());
+    for (DType dtype :
+         std::vector<DType>{dtype::Float16(), dtype::Int32(), dtype::Float32()})
+        for (int k = -5; k < 5; ++k) {
+            checker.set_param({k});
+            checker.set_dtype(0, dtype);
+            checker.set_dtype(1, dtype);
+            size_t absk = static_cast<size_t>(std::abs(k));
+            checker.exec(TensorShapeArray{{8}, {8 + absk, 8 + absk}});
+
+            // Shape of the k-th diagonal of an (m, n) matrix: a 1-d tensor
+            // of length min(n - k, m) for k >= 0 and min(m + k, n) otherwise,
+            // matching Diag::deduce_layout().
+            auto oshape = [&](int m, int n) -> TensorShape {
+                size_t o = (k >= 0 ? std::min(n - k, m) : std::min(m + k, n));
+                return {o};
+            };
+            checker.exec(TensorShapeArray{{8, 6}, oshape(8, 6)});
+            checker.exec(TensorShapeArray{{6, 8}, oshape(6, 8)});
+            checker.exec(TensorShapeArray{{8, 8}, oshape(8, 8)});
+        }
+}
+
+} // namespace test
+} // namespace megdnn
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/dnn/test/naive/diag.cpp b/dnn/test/naive/diag.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4a560617b3ce56088870b92e780c7bcc9cd23a41
--- /dev/null
+++ b/dnn/test/naive/diag.cpp
@@ -0,0 +1,111 @@
+/**
+ * \file dnn/test/naive/diag.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */ + +#include "megdnn/dtype.h" +#include "megdnn/oprs.h" +#include "test/common/checker.h" +#include "test/naive/fixture.h" + +namespace megdnn { +namespace test { + +TEST_F(NAIVE, DiagVector2Matrix) { + Checker checker(handle(), false); + Diag::Param param; + param.k = 0; + checker.set_param(param).exect( + Testcase{TensorValue({3}, dtype::Float32(), {1, 2, 3}), {}}, + Testcase{ + {}, + // clang-format off + TensorValue({3, 3}, dtype::Float32(), {1, 0, 0, + 0, 2, 0, + 0, 0, 3})}); + // clang-format on +} + +TEST_F(NAIVE, DiagVector2Matrix_PositiveK) { + Checker checker(handle(), false); + Diag::Param param; + param.k = 1; + checker.set_param(param).exect( + Testcase{TensorValue({3}, dtype::Float32(), {1, 2, 3}), {}}, + Testcase{ + {}, + // clang-format off + TensorValue({4, 4}, dtype::Float32(), {0, 1, 0, 0, + 0, 0, 2, 0, + 0, 0, 0, 3, + 0, 0, 0, 0,})}); + // clang-format on +} + +TEST_F(NAIVE, DiagVector2Matrix_NegativeK) { + Checker checker(handle(), false); + Diag::Param param; + param.k = -1; + checker.set_param(param).exect( + Testcase{TensorValue({3}, dtype::Float32(), {1, 2, 3}), {}}, + Testcase{ + {}, + // clang-format off + TensorValue({4, 4}, dtype::Float32(), {0, 0, 0, 0, + 1, 0, 0, 0, + 0, 2, 0, 0, + 0, 0, 3, 0,})}); + // clang-format on +} + +TEST_F(NAIVE, DiagMatrix2Vector) { + Checker checker(handle(), false); + Diag::Param param; + param.k = 0; + checker.set_param(param).exect( + // clang-format off + Testcase{TensorValue({3, 3}, dtype::Float32(), {1, 2, 3, + 4, 5, 6, + 7, 8, 9}), + // clang-format on + {}}, + Testcase{{}, TensorValue({3}, dtype::Float32(), {1, 5, 9})}); +} + +TEST_F(NAIVE, DiagMatrix2Vector_PositiveK) { + Checker checker(handle(), false); + Diag::Param param; + param.k = 1; + checker.set_param(param).exect( + // clang-format off + Testcase{TensorValue({3, 3}, dtype::Float32(), {1, 2, 3, + 4, 5, 6, + 7, 8, 9}), + // clang-format on + {}}, + Testcase{{}, TensorValue({2}, dtype::Float32(), {2, 6})}); +} + +TEST_F(NAIVE, 
DiagMatrix2Vector_NegativeK) { + Checker checker(handle(), false); + Diag::Param param; + param.k = -1; + checker.set_param(param).exect( + // clang-format off + Testcase{TensorValue({3, 3}, dtype::Float32(), {1, 2, 3, + 4, 5, 6, + 7, 8, 9}), + // clang-format on + {}}, + Testcase{{}, TensorValue({2}, dtype::Float32(), {4, 8})}); +} + +} // namespace test +} // namespace megdnn diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 77c7df3ab7576bbfd13980677921b39e95d5f5db..adc781b031994c7d6aa00f773257a7b2c41354a2 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -28,6 +28,7 @@ __all__ = [ "concat", "cond_take", "cumsum", + "diag", "expand_dims", "eye", "flatten", @@ -53,6 +54,32 @@ __all__ = [ ] +def diag(inp, k=0) -> Tensor: + r"""If ``inp`` is a 1D tensor, then returns a 2D tensor with the elements of ``inp`` as the diagonal. + If ``inp`` is a 2D tensor, then returns a 1D tensor with the diagonal elements of ``inp``. + + Args: + inp: input tensor. + k: diagonal in consider. Use :math:`k=0` for the main diagonal, :math:`k>0` for diagonals above the + main diagonal, and :math:`k<0` for diagonals below the main diagonal. Default: 0. + + Returns: + the extracted diagonal or constructed diagonal array. + + Examples: + >>> inp = F.arange(6, dtype='int32').reshape(2,3) + >>> out = F.diag(inp, k=1) + >>> out + Tensor([1 5], dtype=int32, device=xpux:0) + >>> F.diag(out) + Tensor([[1 0] + [0 5]], dtype=int32, device=xpux:0) + """ + op = builtin.Diag(k=k) + (result,) = apply(op, inp) + return result + + def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor: r"""Returns a 2D tensor with ones on the diagonal and zeros elsewhere. 
diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index 8f4f8d2562f99f9c1e8a3fe86d8d37bf8f6357d1..b103d659c77f0ad7533ee06a0a74e989d12537aa 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -42,6 +42,26 @@ def test_eye(): ) +@pytest.mark.parametrize("is_varnode", [False, True]) +def test_diag(is_varnode): + if is_varnode: + network = Network() + else: + network = None + + shapes = [(10, 10), (6, 9), (8, 7), (8,)] + cases = [] + for shp in shapes: + cases.append({"input": [np.random.random(shp).astype("float32")]}) + + for axis in range(-2, 3): + + def run(data): + return F.diag(data, k=axis) + + opr_test(cases, run, ref_fn=lambda x: np.diag(x, axis), network=network) + + def test_full(): shape = (2, 3) values = [True, 4, 5.0] diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp index 5d3562a2b7aea8052a1230d1903279a8748bb861..d3819a0471a81c29e1bf861aaa03f4e894165441 100644 --- a/imperative/src/impl/ops/specializations.cpp +++ b/imperative/src/impl/ops/specializations.cpp @@ -432,6 +432,19 @@ OP_TRAIT_REG(Eye, Eye).apply_on_var_node(apply_on_var_node).fallback(); } // namespace eye } // namespace +namespace { +namespace diag { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 1); + cg::OperatorNodeConfig config{op.make_name()}; + opr::Diag::Param param{op.k}; + return opr::Diag::make(inputs[0], param, config); +} +OP_TRAIT_REG(Diag, Diag).apply_on_var_node(apply_on_var_node).fallback(); +} // namespace diag +} // namespace + namespace { namespace roi_pooling { VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 
894dd5fc7699f574fc0a0e0941faec5b255ff03e..d90bdcadf82b3e577fba9e4df8643f5bd452122d 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -240,6 +240,8 @@ def Eye: MgbHashableOp<"Eye", [EyeParam]> { ); } +def Diag: MgbHashableOp<"Diag", [DiagParam]>; + def GetVarShape : MgbHashableOp<"GetVarShape", [OptionalAxisV1Param]>; def Concat: MgbHashableOp<"Concat", [AxisParam]> { diff --git a/src/opr/impl/indexing.cpp b/src/opr/impl/indexing.cpp index 99f5d0e756e194648e3d9be311014075d0360919..c01ba2b6190407951ea0ea08296474a3ae570444 100644 --- a/src/opr/impl/indexing.cpp +++ b/src/opr/impl/indexing.cpp @@ -75,6 +75,91 @@ struct MegDNNOprInitInputsModifier } // namespace opr } // namespace mgb +/* ==================== Diag ==================== */ +MGB_DYN_TYPE_OBJ_FINAL_IMPL(Diag); +MEGDNN_OPR_INIT1(Diag, "diag") + +#if MGB_ENABLE_GRAD +MGB_IMPL_OPR_GRAD(Diag) { + if (wrt_idx == 0) { + SymbolVar data_sym{opr.input(0)}; + return DiagBackward::make(data_sym.symshape(), out_grad[0], opr.param()).node(); + } + return InvalidGrad::make(opr, wrt_idx); +} +#endif + +/* ==================== DiagBackward ==================== */ +MGB_DYN_TYPE_OBJ_FINAL_IMPL(DiagBackward); + +DiagBackward::DiagBackward( + VarNode* shape, VarNode* value, const Param& param, + const OperatorNodeConfig& config) + : Super{shape->owner_graph(), config, "diag_backward", {shape, value}}, + m_param{param} { + add_input({shape, value}); + add_output(None)->dtype(value->dtype()); + add_equivalence_component>(&m_param); +} + +SymbolVar DiagBackward::make( + SymbolVar shape, SymbolVar value, const Param& param, + const OperatorNodeConfig& config) { + return shape.insert_single_output_opr( + shape.node(), value.node(), param, config); +} + +cg::OperatorNodeBase::NodeProp* DiagBackward::do_make_node_prop() const { + auto prop = Super::do_make_node_prop(); + using D = NodeProp::DepType; + prop->add_dep_type(input(0), D::HOST_VALUE); + return prop; +} + +void 
DiagBackward::scn_do_execute() { + auto&& dest = output(0)->dev_tensor(); + auto&& val = input(1)->dev_tensor(); + auto&& layout = dest.layout(); + mgb_assert(layout.ndim == 1 || layout.ndim == 2); + if (layout.ndim == 2) { + dev_tensor_memset(dest, 0); + size_t offset = (m_param.k >= 0) ? (m_param.k * layout.stride[1]) + : (-m_param.k * layout.stride[0]); + auto dest_sub = dest.sub(SubTensorSpec::make_from_offset_elem( + {val.shape(), {layout.stride[0] + layout.stride[1]}, val.dtype()}, + offset)); + dest_sub.copy_from_fixlayout(val); + } else { + auto&& opr = m_dnn_opr; + if (!opr) { + opr = intl::create_megdnn_opr(comp_node()); + opr->param() = m_param; + } + opr->exec(val.as_megdnn(), dest.as_megdnn(), {}); + } +} + +void DiagBackward::record_execute_deps(ExecDependencyArray& deps) { + deps.emplace_back(std::make_unique(std::move(m_dnn_opr))); +} + +void DiagBackward::init_output_static_infer_desc() { + using namespace cg::static_infer; + auto&& mgr = owner_graph()->static_infer_manager(); + auto infer_shape = [](TensorShape& dest, const InpVal& inp) { + cg::copy_tensor_value_to_shape(dest, inp.val.at(0).value()); + return true; + }; + mgr.register_shape_infer( + output(0), {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_shape}); +} + +#if MGB_ENABLE_GRAD +MGB_IMPL_OPR_GRAD(DiagBackward) { + return InvalidGrad::make(opr, wrt_idx); +} +#endif + /* ==================== IndexingOneHot ==================== */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingOneHot); MEGDNN_OPR_INIT2(IndexingOneHot, "indexing_one_hot") diff --git a/src/opr/impl/indexing.oprdecl b/src/opr/impl/indexing.oprdecl index 1cd40c1d4b934d84a795307a51193c84cf5883fc..cdafdbfc00b9e3053dac8599c034f1a54a14c5dc 100644 --- a/src/opr/impl/indexing.oprdecl +++ b/src/opr/impl/indexing.oprdecl @@ -1,3 +1,25 @@ +decl_opr( + 'Diag', + desc='Extract a diagonal or construct a diagonal array', + inputs=[ + Doc('k', 'Index of the diagonal: 0 (the default) refers to the main ' + 'diagonal, a positive value refers to 
an upper diagonal, and a ' + 'negative value to a lower diagonal.') + ], + params='Diag' +) + +decl_opr( + 'DiagBackward', + desc='backward function of Diag', + inputs=[ + Doc('k', 'Index of the diagonal: 0 (the default) refers to the main ' + 'diagonal, a positive value refers to an upper diagonal, and a ' + 'negative value to a lower diagonal.') + ], + params='Diag' +) + decl_opr('IndexingOneHot', pyname='_indexing_one_hot', inputs=['src', 'index'], params=[('axis', 'Axis')]) diff --git a/src/opr/impl/indexing.sereg.h b/src/opr/impl/indexing.sereg.h index d910b4d791748d7aba926ad527e32eb7f3357d9b..8796783f575ca55bbbda8f13018ab956673c3970 100644 --- a/src/opr/impl/indexing.sereg.h +++ b/src/opr/impl/indexing.sereg.h @@ -25,6 +25,8 @@ MGB_SEREG_MODIFY_SUBTENSOR_OPR(BatchedSetMeshIndexing); namespace mgb { namespace opr { +MGB_SEREG_OPR(Diag, 1); +MGB_SEREG_OPR(DiagBackward, 2); MGB_SEREG_OPR(IndexingOneHot, 2); MGB_SEREG_OPR(IndexingRemap, 2); MGB_SEREG_OPR(IndexingRemapBackward, 3); diff --git a/src/opr/include/megbrain/opr/indexing.h b/src/opr/include/megbrain/opr/indexing.h index a8993830d1e7a8fcb0962a1d798fd8c07b4e256a..58d0b281c3a6f1e953ecb4d8f58bd3860c3339df 100644 --- a/src/opr/include/megbrain/opr/indexing.h +++ b/src/opr/include/megbrain/opr/indexing.h @@ -19,6 +19,37 @@ namespace mgb { namespace opr { +MGB_DEFINE_OPR_CLASS(Diag, intl::MegDNNOprWrapperFwd) // { +public: + MGE_WIN_DECLSPEC_FUC Diag( + VarNode* src, const Param& param, const OperatorNodeConfig& config); + MGE_WIN_DECLSPEC_FUC static SymbolVar make( + SymbolVar src, const Param& param, const OperatorNodeConfig& config = {}); +}; + +MGB_DEFINE_OPR_CLASS(DiagBackward, cg::SingleCNOperatorNodeBase) // { +public: + using Param = megdnn::Diag::Param; + MGE_WIN_DECLSPEC_FUC DiagBackward( + VarNode* shape, VarNode* value, const Param& param, + const OperatorNodeConfig& config); + MGE_WIN_DECLSPEC_FUC static SymbolVar make( + SymbolVar shape, SymbolVar value, const Param& param, + const 
OperatorNodeConfig& config = {}); + + const Param& param() const { return m_param; } + +private: + Param m_param; + intl::UniqPtrWithCN m_dnn_opr; + + void scn_do_execute() override; + void init_output_static_infer_desc() override; + NodeProp* do_make_node_prop() const override; + + void record_execute_deps(ExecDependencyArray& deps) override; +}; + MGB_DEFINE_OPR_CLASS( IndexingOneHot, intl::MegDNNOprWrapperFwd) // { public: diff --git a/src/opr/test/indexing.cpp b/src/opr/test/indexing.cpp index 0510936f10fe7d2802b44e0aa219a8634c7183c4..2f6433737dcf6cabade0d383b38f5489c9a90179 100644 --- a/src/opr/test/indexing.cpp +++ b/src/opr/test/indexing.cpp @@ -52,6 +52,37 @@ void gen_index_onehot(int* max_value, HostTensorND& dest) { } } +void test_diag(int32_t axis, const TensorShapeArray& test_cases) { + using Checker = AutoOprChecker<1, 1>; + auto nopr = megdnn_naive_handle()->create_operator(); + nopr->param() = {axis}; + + auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { + return {opr::Diag::make(inputs[0], {axis})}; + }; + + auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { + auto&& src = *inp[0]; + TensorShape oshp(src.shape()); + if (oshp.ndim == 1) { + size_t o = oshp.shape[0] + std::abs(axis); + oshp = {o, o}; + } else { + size_t m = oshp.shape[0]; + size_t n = oshp.shape[1]; + size_t o = (axis >= 0) ? 
std::min(n - axis, m) : std::min(m + axis, n); + oshp = {o}; + } + dest[0].resize(oshp); + nopr->exec(src.as_megdnn(), dest[0].as_megdnn(), {}); + }; + + Checker checker{make_graph, fwd}; + for (auto&& i : test_cases) { + checker.run({i}); + } +} + void test_one_hot_get(int32_t axis, const TensorShapeArray& test_cases) { using Checker = AutoOprChecker<2, 1>; @@ -145,6 +176,12 @@ void test_one_hot(int32_t axis, const TensorShapeArray& test_cases) { } // anonymous namespace +TEST(TestOprDiag, Diag) { + TensorShapeArray cases = {{7, 7}, {7, 9}, {9, 7}, {8}}; + for (int32_t k = -3; k < 3; ++k) + test_diag(k, cases); +} + TEST(TestOprIndexing, OneHot2D) { TensorShapeArray cases = {{1, 1}, {2, 2}, {10, 8}, {8, 10}}; test_one_hot(0, cases); diff --git a/src/serialization/impl/schema.fbs b/src/serialization/impl/schema.fbs index b6c46c39869a08da1726b03ee0c24047a71a2649..a7ac763bd9383faaf9f765f8207b36bdd20fe875 100644 --- a/src/serialization/impl/schema.fbs +++ b/src/serialization/impl/schema.fbs @@ -122,6 +122,7 @@ union OperatorParam { param.RNN = 88, param.LSTM = 89, param.Softmax = 90, + param.Diag = 91, } table Operator {