Commit 95ac0555 authored by Megvii Engine Team

feat(dnn,mgb,imperative): add diag opr implementation

GitOrigin-RevId: 43016ffa2b99b3b459eea63e43ccb11deaf9eb36
Parent 39d77fb5
@@ -998,6 +998,28 @@ protected:
void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
};
class Diag : public OperatorBase {
DEF_OPR_IMPL(Diag, OperatorBase, 1, 1);
DEF_OPR_PARAM(Diag);
public:
/**
* \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.diag.html
*/
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) = 0;
protected:
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
};
class IndexingOneHotBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(IndexingOneHotBase, OperatorBase);
DEF_OPR_PARAM(Axis);
......
@@ -759,6 +759,14 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
'dtype', Doc('dtype', 'data type of output value'),
'DTypeEnum::Float32'))
(pdef('Diag').
add_fields(
'int32',
Doc('k', 'Index of the diagonal: 0 (the default) refers to the main '
'diagonal, a positive value refers to an upper diagonal, and a '
'negative value to a lower diagonal.'),
0))
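# Worked example of the field above (illustrative comment, not part of the generated param):
# for a 4x4 matrix a, k=0 selects a[0][0], a[1][1], a[2][2], a[3][3];
# k=1 selects a[0][1], a[1][2], a[2][3]; k=-1 selects a[1][0], a[2][1], a[3][2],
# matching numpy.diag.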
(pdef('UniformRNG', version=0, is_legacy=True).
add_fields('uint64', 'seed', 0))
......
/**
* \file dnn/src/common/diag.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megdnn/oprs.h"
#include "src/common/utils.h"
namespace megdnn {
void Diag::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
megdnn_assert(
src.ndim == 1 || src.ndim == 2, "Only support vector or matrix as input.");
int k = param().k;
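// Output-shape rules (mirroring numpy.diag):
//   - 1-D input of length n: a square (n + |k|) x (n + |k|) matrix whose k-th
//     diagonal holds the input;
//   - 2-D (m x n) input: a 1-D tensor holding the k-th diagonal, of length
//     min(n - k, m) for k >= 0 and min(m + k, n) for k < 0.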
if (src.ndim == 1) {
size_t o = src.total_nr_elems() + std::abs(k);
dst = TensorLayout(TensorShape({o, o}), src.dtype);
} else { // src.ndim == 2
size_t m = src.shape[0];
size_t n = src.shape[1];
size_t o = (k >= 0 ? std::min(n - k, m) : std::min(m + k, n));
megdnn_assert(o > 0, "The moved diagonal is out of the input matrix.");
dst = TensorLayout(TensorShape({o}), src.dtype);
}
}
void Diag::check_exec(
const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
TensorLayout dst_expected;
megdnn_assert_eq_dtype(src, dst);
deduce_layout(src, dst_expected);
megdnn_assert_eq_layout(dst_expected, dst);
megdnn_assert_contiguous(dst);
auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
} // namespace megdnn
// vim: syntax=cpp.doxygen
@@ -146,6 +146,7 @@ private:
cb(BatchedSetMeshIndexing) \
cb(Linspace) \
cb(Eye) \
cb(Diag) \
cb(SleepForward) \
cb(UniformRNG) \
cb(GaussianRNG) \
......
@@ -88,6 +88,7 @@ DEF(IndexingRemapForward, 3, true, true);
DEF(IndexingRemapBackward, 3, true, false);
DEF(Linspace, 1, true, false);
DEF(Eye, 1, true, false);
DEF(Diag, 2, true, true);
DEF(Flip, 2, true, true);
DEF(ROICopy, 2, true, true);
DEF(Rotate, 2, true, true);
......
/**
* \file dnn/src/cuda/diag/diag.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megdnn/dtype.h"
#include "src/cuda/diag/diag.cuh"
#include "src/cuda/utils.cuh"
namespace {
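// kernel_to_vector: each thread copies one element of the k-th diagonal of the
// source matrix into the destination vector; element i is read at linear offset
// start + i * stride_sum (the caller passes the sum of the two source strides).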
template <typename T>
__global__ void kernel_to_vector(
T* src, T* dst, ptrdiff_t start, ptrdiff_t size, ptrdiff_t stride_sum,
ptrdiff_t dst_stride) {
ptrdiff_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i < size) {
dst[dst_stride * i] = src[start + stride_sum * i];
}
}
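// kernel_to_matrix: builds an n x n matrix from a vector; thread i handles the
// element at row y = i / n, column x = i % n, writing source element (y - offset)
// when it lies on the k-th diagonal (y + k == x) and zero elsewhere.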
template <typename T>
__global__ void kernel_to_matrix(
T* src, T* dst, ptrdiff_t offset, ptrdiff_t n, ptrdiff_t k,
ptrdiff_t dst_stride0, ptrdiff_t dst_stride1, ptrdiff_t src_stride) {
ptrdiff_t i = threadIdx.x + blockIdx.x * blockDim.x;
ptrdiff_t x = i % n;
ptrdiff_t y = i / n;
ptrdiff_t p = dst_stride0 * y + dst_stride1 * x;
if (i < n * n) {
if (y + k == x)
dst[p] = src[src_stride * (y - offset)];
else
dst[p] = 0;
}
}
} // anonymous namespace
namespace megdnn {
namespace cuda {
namespace diag {
template <typename T>
void exec_internal_to_vector(
T* src, T* dst, ptrdiff_t start, ptrdiff_t size, ptrdiff_t stride_sum,
ptrdiff_t dst_stride, cudaStream_t stream) {
kernel_to_vector<T><<<DIVUP(size, NR_THREADS), NR_THREADS, 0, stream>>>(
src, dst, start, size, stride_sum, dst_stride);
after_kernel_launch();
}
template <typename T>
void exec_internal_to_matrix(
T* src, T* dst, ptrdiff_t offset, ptrdiff_t n, ptrdiff_t k,
ptrdiff_t dst_stride0, ptrdiff_t dst_stride1, ptrdiff_t src_stride,
cudaStream_t stream) {
kernel_to_matrix<T><<<DIVUP(n * n, NR_THREADS), NR_THREADS, 0, stream>>>(
src, dst, offset, n, k, dst_stride0, dst_stride1, src_stride);
after_kernel_launch();
}
#define INST(T) \
template void exec_internal_to_vector<T>( \
T*, T*, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, cudaStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef INST
#undef cb
#define INST(T) \
template void exec_internal_to_matrix<T>( \
T*, T*, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, \
cudaStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) cb(::megdnn::dtype::Bool)
} // namespace diag
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file dnn/src/cuda/diag/diag.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <cuda_runtime_api.h>
#include <stdint.h>
namespace megdnn {
namespace cuda {
namespace diag {
template <typename T>
void exec_internal_to_vector(
T* src, T* dst, ptrdiff_t start, ptrdiff_t size, ptrdiff_t stride_sum,
ptrdiff_t dst_stride, cudaStream_t stream);
template <typename T>
void exec_internal_to_matrix(
T* src, T* dst, ptrdiff_t start, ptrdiff_t n, ptrdiff_t k,
ptrdiff_t dst_stride0, ptrdiff_t dst_stride1, ptrdiff_t src_stride,
cudaStream_t stream);
} // namespace diag
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file dnn/src/cuda/diag/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/diag/opr_impl.h"
#include "src/cuda/diag/diag.cuh"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
void DiagImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
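// Two execution paths: a 2-D source extracts its k-th diagonal into a vector
// (start points at the first diagonal element: column k for k >= 0, row -k for
// k < 0); a 1-D source is scattered onto the k-th diagonal of a zero-filled
// square matrix.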
if (src.layout.ndim == 2) {
auto src_stride0 = src.layout.stride[0];
auto src_stride1 = src.layout.stride[1];
auto dst_stride = dst.layout.stride[0];
auto start =
(param().k >= 0) ? param().k * src_stride1 : -param().k * src_stride0;
#define cb(DType) \
if (dst.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
using ctype = typename DTypeTrait<DType>::ctype; \
diag::exec_internal_to_vector<ctype>( \
src.ptr<ctype>(), dst.ptr<ctype>(), start, dst.layout.shape[0], \
src_stride0 + src_stride1, dst_stride, cuda_stream(handle())); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
} else {
auto n = dst.layout.shape[0];
auto src_stride = src.layout.stride[0];
auto dst_stride0 = dst.layout.stride[0];
auto dst_stride1 = dst.layout.stride[1];
auto offset = (param().k >= 0) ? 0 : -param().k;
#define cb(DType) \
if (dst.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
using ctype = typename DTypeTrait<DType>::ctype; \
diag::exec_internal_to_matrix<ctype>( \
src.ptr<ctype>(), dst.ptr<ctype>(), offset, n, param().k, dst_stride0, \
dst_stride1, src_stride, cuda_stream(handle())); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
}
}
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file dnn/src/cuda/diag/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace cuda {
class DiagImpl final : public Diag {
public:
using Diag::Diag;
void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) override {
return 0;
}
};
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -33,6 +33,7 @@
#include "src/cuda/dct/opr_impl.h"
#include "src/cuda/deformable_conv/opr_impl.h"
#include "src/cuda/deformable_ps_roi_pooling/opr_impl.h"
#include "src/cuda/diag/opr_impl.h"
#include "src/cuda/dot/opr_impl.h"
#include "src/cuda/dropout/opr_impl.h"
#include "src/cuda/elemwise/opr_impl.h"
@@ -154,6 +155,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedIncrMeshIndexing);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedSetMeshIndexing);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Linspace);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Eye);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Diag);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(UniformRNG);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GaussianRNG);
......
/**
* \file dnn/src/naive/diag/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/naive/diag/opr_impl.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
namespace megdnn {
namespace naive {
template <typename ctype>
void DiagImpl::exec_internal(
ctype* src, const TensorLayout& src_layout, ctype* dst,
const TensorLayout& dst_layout, size_t input_ndim, int k) {
if (input_ndim == 1) {
size_t l = src_layout.shape[0];
size_t s0 = dst_layout.stride[0];
size_t s1 = dst_layout.stride[1];
size_t start = (k >= 0) ? (k * s1) : (-k * s0);
for (size_t i = 0; i < dst_layout.shape[0]; ++i)
for (size_t j = 0; j < dst_layout.shape[1]; ++j)
dst[i * s0 + j * s1] = 0;
for (size_t i = 0; i < l; ++i)
dst[start + i * (s0 + s1)] = src[i];
} else {
size_t l = dst_layout.shape[0];
size_t s0 = src_layout.stride[0];
size_t s1 = src_layout.stride[1];
size_t start = (k >= 0) ? (k * s1) : (-k * s0);
for (size_t i = 0; i < l; ++i)
dst[i] = src[start + i * (s0 + s1)];
}
}
void DiagImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
#define cb(DType) \
if (src.layout.dtype == DType()) { \
using ctype = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<ctype>( \
src.ptr<ctype>(), src.layout, dst.ptr<ctype>(), dst.layout, \
src.layout.ndim, param().k)); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
}
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/naive/diag/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace naive {
class DiagImpl : public Diag {
public:
using Diag::Diag;
void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override {
return 0;
}
private:
template <typename ctype>
void exec_internal(
ctype* src, const TensorLayout& src_layout, ctype* dst,
const TensorLayout& dst_layout, size_t input_ndim, int k);
};
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -34,6 +34,7 @@
#include "src/naive/dct/opr_impl.h"
#include "src/naive/deformable_conv/opr_impl.h"
#include "src/naive/deformable_ps_roi_pooling/opr_impl.h"
#include "src/naive/diag/opr_impl.h"
#include "src/naive/dot/opr_impl.h"
#include "src/naive/dropout/opr_impl.h"
#include "src/naive/elemwise/opr_impl.h"
......
/**
* \file dnn/test/cuda/diag.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "test/cuda/fixture.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
namespace megdnn {
namespace test {
TEST_F(CUDA, DIAG) {
Checker<Diag> checker(handle_cuda());
for (DType dtype :
std::vector<DType>{dtype::Float16(), dtype::Int32(), dtype::Float32()})
for (int k = -5; k < 5; ++k) {
checker.set_param({k});
checker.set_dtype(0, dtype);
checker.set_dtype(1, dtype);
size_t absk = static_cast<size_t>(std::abs(k));
checker.exec(TensorShapeArray{{8}, {8 + absk, 8 + absk}});
auto oshape = [&](int n, int m) -> TensorShape {
size_t o = (k >= 0 ? std::min(n - k, m) : std::min(m + k, n));
return {o, o};
};
checker.exec(TensorShapeArray{{8, 6}, oshape(8, 6)});
checker.exec(TensorShapeArray{{6, 8}, oshape(6, 8)});
checker.exec(TensorShapeArray{{8, 8}, oshape(8, 8)});
}
}
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file dnn/test/naive/diag.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/naive/fixture.h"
namespace megdnn {
namespace test {
TEST_F(NAIVE, DiagVector2Matrix) {
Checker<Diag> checker(handle(), false);
Diag::Param param;
param.k = 0;
checker.set_param(param).exect(
Testcase{TensorValue({3}, dtype::Float32(), {1, 2, 3}), {}},
Testcase{
{},
// clang-format off
TensorValue({3, 3}, dtype::Float32(), {1, 0, 0,
0, 2, 0,
0, 0, 3})});
// clang-format on
}
TEST_F(NAIVE, DiagVector2Matrix_PositiveK) {
Checker<Diag> checker(handle(), false);
Diag::Param param;
param.k = 1;
checker.set_param(param).exect(
Testcase{TensorValue({3}, dtype::Float32(), {1, 2, 3}), {}},
Testcase{
{},
// clang-format off
TensorValue({4, 4}, dtype::Float32(), {0, 1, 0, 0,
0, 0, 2, 0,
0, 0, 0, 3,
0, 0, 0, 0,})});
// clang-format on
}
TEST_F(NAIVE, DiagVector2Matrix_NegativeK) {
Checker<Diag> checker(handle(), false);
Diag::Param param;
param.k = -1;
checker.set_param(param).exect(
Testcase{TensorValue({3}, dtype::Float32(), {1, 2, 3}), {}},
Testcase{
{},
// clang-format off
TensorValue({4, 4}, dtype::Float32(), {0, 0, 0, 0,
1, 0, 0, 0,
0, 2, 0, 0,
0, 0, 3, 0,})});
// clang-format on
}
TEST_F(NAIVE, DiagMatrix2Vector) {
Checker<Diag> checker(handle(), false);
Diag::Param param;
param.k = 0;
checker.set_param(param).exect(
// clang-format off
Testcase{TensorValue({3, 3}, dtype::Float32(), {1, 2, 3,
4, 5, 6,
7, 8, 9}),
// clang-format on
{}},
Testcase{{}, TensorValue({3}, dtype::Float32(), {1, 5, 9})});
}
TEST_F(NAIVE, DiagMatrix2Vector_PositiveK) {
Checker<Diag> checker(handle(), false);
Diag::Param param;
param.k = 1;
checker.set_param(param).exect(
// clang-format off
Testcase{TensorValue({3, 3}, dtype::Float32(), {1, 2, 3,
4, 5, 6,
7, 8, 9}),
// clang-format on
{}},
Testcase{{}, TensorValue({2}, dtype::Float32(), {2, 6})});
}
TEST_F(NAIVE, DiagMatrix2Vector_NegativeK) {
Checker<Diag> checker(handle(), false);
Diag::Param param;
param.k = -1;
checker.set_param(param).exect(
// clang-format off
Testcase{TensorValue({3, 3}, dtype::Float32(), {1, 2, 3,
4, 5, 6,
7, 8, 9}),
// clang-format on
{}},
Testcase{{}, TensorValue({2}, dtype::Float32(), {4, 8})});
}
} // namespace test
} // namespace megdnn
@@ -28,6 +28,7 @@ __all__ = [
"concat",
"cond_take",
"cumsum",
"diag",
"expand_dims",
"eye",
"flatten",
@@ -53,6 +54,32 @@ __all__ = [
]
def diag(inp, k=0) -> Tensor:
    r"""If ``inp`` is a 1D tensor, returns a 2D tensor with the elements of ``inp`` on the diagonal.
    If ``inp`` is a 2D tensor, returns a 1D tensor with the diagonal elements of ``inp``.

    Args:
        inp: input tensor.
        k: diagonal to consider. Use :math:`k=0` for the main diagonal, :math:`k>0` for diagonals above
            the main diagonal, and :math:`k<0` for diagonals below the main diagonal. Default: 0.

    Returns:
        the extracted diagonal or the constructed diagonal tensor.

    Examples:
        >>> inp = F.arange(6, dtype='int32').reshape(2,3)
        >>> out = F.diag(inp, k=1)
        >>> out
        Tensor([1 5], dtype=int32, device=xpux:0)
        >>> F.diag(out)
        Tensor([[1 0]
         [0 5]], dtype=int32, device=xpux:0)
    """
    op = builtin.Diag(k=k)
    (result,) = apply(op, inp)
    return result
def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor:
r"""Returns a 2D tensor with ones on the diagonal and zeros elsewhere.
......
@@ -42,6 +42,26 @@ def test_eye():
)
@pytest.mark.parametrize("is_varnode", [False, True])
def test_diag(is_varnode):
    if is_varnode:
        network = Network()
    else:
        network = None

    shapes = [(10, 10), (6, 9), (8, 7), (8,)]
    cases = []
    for shp in shapes:
        cases.append({"input": [np.random.random(shp).astype("float32")]})

    for axis in range(-2, 3):

        def run(data):
            return F.diag(data, k=axis)

        opr_test(cases, run, ref_fn=lambda x: np.diag(x, axis), network=network)
def test_full():
shape = (2, 3)
values = [True, 4, 5.0]
......
@@ -432,6 +432,19 @@ OP_TRAIT_REG(Eye, Eye).apply_on_var_node(apply_on_var_node).fallback();
} // namespace eye
} // namespace
namespace {
namespace diag {
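// Imperative-to-graph bridge: lower the Diag OpDef onto the computing graph by
// forwarding its single input and the `k` parameter to opr::Diag::make.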
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const Diag&>(def);
mgb_assert(inputs.size() == 1);
cg::OperatorNodeConfig config{op.make_name()};
opr::Diag::Param param{op.k};
return opr::Diag::make(inputs[0], param, config);
}
OP_TRAIT_REG(Diag, Diag).apply_on_var_node(apply_on_var_node).fallback();
} // namespace diag
} // namespace
namespace {
namespace roi_pooling {
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
......
@@ -240,6 +240,8 @@ def Eye: MgbHashableOp<"Eye", [EyeParam]> {
);
}
def Diag: MgbHashableOp<"Diag", [DiagParam]>;
def GetVarShape : MgbHashableOp<"GetVarShape", [OptionalAxisV1Param]>;
def Concat: MgbHashableOp<"Concat", [AxisParam]> {
......
@@ -75,6 +75,91 @@ struct MegDNNOprInitInputsModifier<IndexingSetOneHot>
} // namespace opr
} // namespace mgb
/* ==================== Diag ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Diag);
MEGDNN_OPR_INIT1(Diag, "diag")
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Diag) {
if (wrt_idx == 0) {
SymbolVar data_sym{opr.input(0)};
return DiagBackward::make(data_sym.symshape(), out_grad[0], opr.param()).node();
}
return InvalidGrad::make(opr, wrt_idx);
}
#endif
/* ==================== DiagBackward ==================== */
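// DiagBackward computes the gradient of Diag: input 0 carries the forward
// input's shape and input 1 the output gradient; the result has that shape,
// with the gradient scattered back onto the k-th diagonal when the forward
// input was a matrix, or the k-th diagonal of the gradient extracted when the
// forward input was a vector.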
MGB_DYN_TYPE_OBJ_FINAL_IMPL(DiagBackward);
DiagBackward::DiagBackward(
VarNode* shape, VarNode* value, const Param& param,
const OperatorNodeConfig& config)
: Super{shape->owner_graph(), config, "diag_backward", {shape, value}},
m_param{param} {
add_input({shape, value});
add_output(None)->dtype(value->dtype());
add_equivalence_component<PODHash<Param>>(&m_param);
}
SymbolVar DiagBackward::make(
SymbolVar shape, SymbolVar value, const Param& param,
const OperatorNodeConfig& config) {
return shape.insert_single_output_opr<DiagBackward>(
shape.node(), value.node(), param, config);
}
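// input(0) carries the target shape: do_make_node_prop adds a HOST_VALUE
// dependency on it, and init_output_static_infer_desc reads that value to
// infer the output shape statically.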
cg::OperatorNodeBase::NodeProp* DiagBackward::do_make_node_prop() const {
auto prop = Super::do_make_node_prop();
using D = NodeProp::DepType;
prop->add_dep_type(input(0), D::HOST_VALUE);
return prop;
}
void DiagBackward::scn_do_execute() {
auto&& dest = output(0)->dev_tensor();
auto&& val = input(1)->dev_tensor();
auto&& layout = dest.layout();
mgb_assert(layout.ndim == 1 || layout.ndim == 2);
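// Forward input was a matrix: zero-fill the gradient, then view its k-th
// diagonal as a 1-D sub-tensor (stride = stride[0] + stride[1], offset =
// k * stride[1] for k >= 0, -k * stride[0] otherwise) and copy the incoming
// vector gradient into it.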
if (layout.ndim == 2) {
dev_tensor_memset(dest, 0);
size_t offset = (m_param.k >= 0) ? (m_param.k * layout.stride[1])
: (-m_param.k * layout.stride[0]);
auto dest_sub = dest.sub(SubTensorSpec::make_from_offset_elem(
{val.shape(), {layout.stride[0] + layout.stride[1]}, val.dtype()},
offset));
dest_sub.copy_from_fixlayout(val);
} else {
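// Forward input was a vector: the incoming gradient is a matrix, so reuse the
// megdnn Diag kernel (created lazily on this comp node) to pull out its k-th
// diagonal.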
auto&& opr = m_dnn_opr;
if (!opr) {
opr = intl::create_megdnn_opr<megdnn::Diag>(comp_node());
opr->param() = m_param;
}
opr->exec(val.as_megdnn(), dest.as_megdnn(), {});
}
}
void DiagBackward::record_execute_deps(ExecDependencyArray& deps) {
deps.emplace_back(std::make_unique<intl::MegDNNGraphDep>(std::move(m_dnn_opr)));
}
void DiagBackward::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto&& mgr = owner_graph()->static_infer_manager();
auto infer_shape = [](TensorShape& dest, const InpVal& inp) {
cg::copy_tensor_value_to_shape(dest, inp.val.at(0).value());
return true;
};
mgr.register_shape_infer(
output(0), {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_shape});
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(DiagBackward) {
return InvalidGrad::make(opr, wrt_idx);
}
#endif
/* ==================== IndexingOneHot ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingOneHot);
MEGDNN_OPR_INIT2(IndexingOneHot, "indexing_one_hot")
......
decl_opr(
'Diag',
desc='Extract a diagonal or construct a diagonal array',
inputs=[
Doc('k', 'Index of the diagonal: 0 (the default) refers to the main '
'diagonal, a positive value refers to an upper diagonal, and a '
'negative value to a lower diagonal.')
],
params='Diag'
)
decl_opr(
'DiagBackward',
desc='backward function of Diag',
inputs=[
Doc('k', 'Index of the diagonal: 0 (the default) refers to the main '
'diagonal, a positive value refers to an upper diagonal, and a '
'negative value to a lower diagonal.')
],
params='Diag'
)
decl_opr('IndexingOneHot', pyname='_indexing_one_hot',
inputs=['src', 'index'],
params=[('axis', 'Axis')])
......
@@ -25,6 +25,8 @@ MGB_SEREG_MODIFY_SUBTENSOR_OPR(BatchedSetMeshIndexing);
namespace mgb {
namespace opr {
MGB_SEREG_OPR(Diag, 1);
MGB_SEREG_OPR(DiagBackward, 2);
MGB_SEREG_OPR(IndexingOneHot, 2);
MGB_SEREG_OPR(IndexingRemap, 2);
MGB_SEREG_OPR(IndexingRemapBackward, 3);
......
@@ -19,6 +19,37 @@
namespace mgb {
namespace opr {
MGB_DEFINE_OPR_CLASS(Diag, intl::MegDNNOprWrapperFwd<megdnn::Diag>) // {
public:
MGE_WIN_DECLSPEC_FUC Diag(
VarNode* src, const Param& param, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, const Param& param, const OperatorNodeConfig& config = {});
};
MGB_DEFINE_OPR_CLASS(DiagBackward, cg::SingleCNOperatorNodeBase) // {
public:
using Param = megdnn::Diag::Param;
MGE_WIN_DECLSPEC_FUC DiagBackward(
VarNode* shape, VarNode* value, const Param& param,
const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar shape, SymbolVar value, const Param& param,
const OperatorNodeConfig& config = {});
const Param& param() const { return m_param; }
private:
Param m_param;
intl::UniqPtrWithCN<megdnn::Diag> m_dnn_opr;
void scn_do_execute() override;
void init_output_static_infer_desc() override;
NodeProp* do_make_node_prop() const override;
void record_execute_deps(ExecDependencyArray& deps) override;
};
MGB_DEFINE_OPR_CLASS(
IndexingOneHot, intl::MegDNNOprWrapperFwd<megdnn::IndexingOneHotForward>) // {
public:
......
@@ -52,6 +52,37 @@ void gen_index_onehot(int* max_value, HostTensorND& dest) {
}
}
void test_diag(int32_t axis, const TensorShapeArray& test_cases) {
using Checker = AutoOprChecker<1, 1>;
auto nopr = megdnn_naive_handle()->create_operator<megdnn::Diag>();
nopr->param() = {axis};
auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
return {opr::Diag::make(inputs[0], {axis})};
};
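// Reference forward pass for AutoOprChecker: recompute the expected output
// shape exactly as Diag::deduce_layout does, then run the naive megdnn kernel
// to produce the ground-truth output.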
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
auto&& src = *inp[0];
TensorShape oshp(src.shape());
if (oshp.ndim == 1) {
size_t o = oshp.shape[0] + std::abs(axis);
oshp = {o, o};
} else {
size_t m = oshp.shape[0];
size_t n = oshp.shape[1];
size_t o = (axis >= 0) ? std::min(n - axis, m) : std::min(m + axis, n);
oshp = {o};
}
dest[0].resize(oshp);
nopr->exec(src.as_megdnn(), dest[0].as_megdnn(), {});
};
Checker checker{make_graph, fwd};
for (auto&& i : test_cases) {
checker.run({i});
}
}
void test_one_hot_get(int32_t axis, const TensorShapeArray& test_cases) {
using Checker = AutoOprChecker<2, 1>;
@@ -145,6 +176,12 @@ void test_one_hot(int32_t axis, const TensorShapeArray& test_cases) {
} // anonymous namespace
TEST(TestOprDiag, Diag) {
TensorShapeArray cases = {{7, 7}, {7, 9}, {9, 7}, {8}};
for (int32_t k = -3; k < 3; ++k)
test_diag(k, cases);
}
TEST(TestOprIndexing, OneHot2D) {
TensorShapeArray cases = {{1, 1}, {2, 2}, {10, 8}, {8, 10}};
test_one_hot(0, cases);
......
@@ -122,6 +122,7 @@ union OperatorParam {
param.RNN = 88,
param.LSTM = 89,
param.Softmax = 90,
param.Diag = 91,
}
table Operator {
......