From 1af350c6d22f7d973255f9b226d4f706829e1bdd Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 16 Jul 2021 18:40:03 +0800 Subject: [PATCH] feat(dnn): add fill kernel GitOrigin-RevId: d2cee3a7a009849c2c608918e2ec51500bafa88b --- dnn/include/megdnn/oprs/general.h | 15 +++++++++ dnn/scripts/opr_param_defs.py | 2 +- dnn/src/common/fill.cpp | 25 +++++++++++++++ dnn/src/common/handle_impl.h | 3 +- dnn/src/common/opr_trait.h | 1 + dnn/src/cuda/fill/kern.cu | 45 +++++++++++++++++++++++++++ dnn/src/cuda/fill/kern.cuh | 25 +++++++++++++++ dnn/src/cuda/fill/opr_impl.cpp | 37 ++++++++++++++++++++++ dnn/src/cuda/fill/opr_impl.h | 31 +++++++++++++++++++ dnn/src/cuda/handle_create.cpp | 1 + dnn/src/naive/fill/opr_impl.cpp | 46 ++++++++++++++++++++++++++++ dnn/src/naive/fill/opr_impl.h | 33 ++++++++++++++++++++ dnn/src/naive/handle.cpp | 1 + dnn/src/rocm/fill/fill.cpp.hip | 51 +++++++++++++++++++++++++++++++ dnn/src/rocm/fill/fill.h.hip | 25 +++++++++++++++ dnn/src/rocm/fill/opr_impl.cpp | 36 ++++++++++++++++++++++ dnn/src/rocm/fill/opr_impl.h | 26 ++++++++++++++++ dnn/src/rocm/handle.cpp | 2 ++ dnn/test/common/fill.h | 37 ++++++++++++++++++++++ dnn/test/cuda/fill.cpp | 36 ++++++++++++++++++++++ dnn/test/rocm/fill.cpp | 36 ++++++++++++++++++++++ 21 files changed, 512 insertions(+), 2 deletions(-) create mode 100644 dnn/src/common/fill.cpp create mode 100644 dnn/src/cuda/fill/kern.cu create mode 100644 dnn/src/cuda/fill/kern.cuh create mode 100644 dnn/src/cuda/fill/opr_impl.cpp create mode 100644 dnn/src/cuda/fill/opr_impl.h create mode 100644 dnn/src/naive/fill/opr_impl.cpp create mode 100644 dnn/src/naive/fill/opr_impl.h create mode 100644 dnn/src/rocm/fill/fill.cpp.hip create mode 100644 dnn/src/rocm/fill/fill.h.hip create mode 100644 dnn/src/rocm/fill/opr_impl.cpp create mode 100644 dnn/src/rocm/fill/opr_impl.h create mode 100644 dnn/test/common/fill.h create mode 100644 dnn/test/cuda/fill.cpp create mode 100644 dnn/test/rocm/fill.cpp diff --git a/dnn/include/megdnn/oprs/general.h b/dnn/include/megdnn/oprs/general.h index 829d5395f..2f9abccd0 100644 --- a/dnn/include/megdnn/oprs/general.h +++ b/dnn/include/megdnn/oprs/general.h @@ -1338,6 +1338,21 @@ class CheckHasInf: public OperatorBase { void check_exec(const TensorLayout &src, const TensorLayout &dst, size_t workspace_in_bytes); }; + +/*! + * \brief fill the tensor with a scalar value + */ +class Fill: public OperatorBase { + DEF_OPR_PARAM(Fill); + DEF_OPR_IMPL(Fill, OperatorBase, 0, 1); + +public: + virtual void exec(_megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0; +protected: + void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); +}; } // namespace megdnn #include "megdnn/internal/opr_header_epilogue.h" diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index f1aa516c1..e8c661676 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -1170,4 +1170,4 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o add_fields('int32', 'qmin', '-2147483648'). add_fields('int32', 'qmax', '2147483647') ) - +pdef('Fill').add_fields('float32', 'value', '0') diff --git a/dnn/src/common/fill.cpp b/dnn/src/common/fill.cpp new file mode 100644 index 000000000..d4ccf2fa9 --- /dev/null +++ b/dnn/src/common/fill.cpp @@ -0,0 +1,25 @@ +/** + * \file dnn/src/common/fill.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megdnn/oprs.h" +#include "src/common/utils.h" + +namespace megdnn { + +void Fill::check_exec(const TensorLayout& dst, size_t workspace_in_bytes) { + megdnn_assert_contiguous(dst); + auto required_workspace_in_bytes = get_workspace_in_bytes(dst); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); +} + +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index 1f6431f26..4e393752f 100644 --- a/dnn/src/common/handle_impl.h +++ b/dnn/src/common/handle_impl.h @@ -216,7 +216,8 @@ private: cb(TQTBackward) \ cb(CheckHasInf) \ cb(LSQForward) \ - cb(LSQBackward) + cb(LSQBackward) \ + cb(Fill) /*! * \brief specialize HandleImpl::create_operator for a single opr type; diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h index 952010620..1417bdce7 100644 --- a/dnn/src/common/opr_trait.h +++ b/dnn/src/common/opr_trait.h @@ -130,6 +130,7 @@ DEF(ChecksumForward, 1, true, false); DEF(CheckHasInf, 2, true, true); DEF(LSQForward, 5, true, true); DEF(LSQBackward, 7, true, false); +DEF(Fill, 1, true, false); } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/fill/kern.cu b/dnn/src/cuda/fill/kern.cu new file mode 100644 index 000000000..1daf50a94 --- /dev/null +++ b/dnn/src/cuda/fill/kern.cu @@ -0,0 +1,45 @@ +/** + * \file dnn/src/cuda/fill/kern.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/cuda/fill/kern.cuh" +#include "megdnn/dtype.h" +#include "src/cuda/utils.cuh" + +namespace { + +template +__global__ void kernel(T *dst, T value, uint32_t size) { + int32_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i < size) { + dst[i] = value; + } +} + +} // anonymous namespace + +namespace megdnn { +namespace cuda { +namespace fill { + +template +void exec_internal(T *dst, T value, size_t size, cudaStream_t stream) { + kernel<<>>(dst, value, size); + after_kernel_launch(); +} + +#define INST(T) template void exec_internal(T *, \ + T, size_t, cudaStream_t); +#define cb(DType) INST(typename DTypeTrait::ctype) +MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + +} // namespace fill +} // namespace cuda +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/cuda/fill/kern.cuh b/dnn/src/cuda/fill/kern.cuh new file mode 100644 index 000000000..a79f93560 --- /dev/null +++ b/dnn/src/cuda/fill/kern.cuh @@ -0,0 +1,25 @@ +/** + * \file dnn/src/cuda/fill/kern.cuh + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include +#include + +namespace megdnn { +namespace cuda { +namespace fill { + +template +void exec_internal(T *dst, T value, size_t size, cudaStream_t stream); + +} // namespace fill +} // namespace cuda +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/cuda/fill/opr_impl.cpp b/dnn/src/cuda/fill/opr_impl.cpp new file mode 100644 index 000000000..f9966a2c6 --- /dev/null +++ b/dnn/src/cuda/fill/opr_impl.cpp @@ -0,0 +1,37 @@ +/** + * \file dnn/src/cuda/fill/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/cuda/fill/kern.cuh" +#include "src/cuda/fill/opr_impl.h" + +#include "src/cuda/utils.h" + +namespace megdnn { +namespace cuda { + +void FillImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { + check_exec(dst.layout, workspace.size); + auto stream = cuda_stream(handle()); + auto size = dst.layout.total_nr_elems(); +#define cb(DType) \ + if (dst.layout.dtype == DType()) { \ + using ctype = typename DTypeTrait::ctype; \ + fill::exec_internal(dst.ptr(), \ + static_cast(param().value), size, stream); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb +} + +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/fill/opr_impl.h b/dnn/src/cuda/fill/opr_impl.h new file mode 100644 index 000000000..d29eae249 --- /dev/null +++ b/dnn/src/cuda/fill/opr_impl.h @@ -0,0 +1,31 @@ +/** + * \file dnn/src/cuda/fill/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace cuda { + +class FillImpl final : public Fill { +public: + using Fill::Fill; + void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout &) override { + return 0; + } +}; + +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen + diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp index 83409e068..f9ab2233e 100644 --- a/dnn/src/cuda/handle_create.cpp +++ b/dnn/src/cuda/handle_create.cpp @@ -38,6 +38,7 @@ #include "src/cuda/elemwise_multi_type/opr_impl.h" #include "src/cuda/eye/opr_impl.h" #include "src/cuda/fake_quant/opr_impl.h" +#include "src/cuda/fill/opr_impl.h" #include "src/cuda/flip/opr_impl.h" #include "src/cuda/gaussian_blur/opr_impl.h" #include "src/cuda/group_local/opr_impl.h" diff --git a/dnn/src/naive/fill/opr_impl.cpp b/dnn/src/naive/fill/opr_impl.cpp new file mode 100644 index 000000000..b133a27f0 --- /dev/null +++ b/dnn/src/naive/fill/opr_impl.cpp @@ -0,0 +1,46 @@ +/** + * \file dnn/src/naive/fill/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/naive/fill/opr_impl.h" +#include "src/naive/handle.h" +#include "src/common/utils.h" + +#include +#include + +namespace megdnn { +namespace naive { + +void FillImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { + check_exec(dst.layout, workspace.size); + size_t size = dst.layout.total_nr_elems(); +#define cb(DType) \ + if (dst.layout.dtype.enumv() == DTypeTrait::enumv) { \ + using ctype = typename DTypeTrait::ctype; \ + ctype *ptr = dst.ptr(); \ + MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal(ptr, size)); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb +} + +template +void FillImpl::exec_internal(ctype *dst, size_t size) { + auto value = static_cast(param().value); + for (size_t i = 0; i < size; ++i) { + dst[i] = value; + } +} + +} // namespace naive +} // namespace megdnn +// vim: syntax=cpp.doxygen + diff --git a/dnn/src/naive/fill/opr_impl.h b/dnn/src/naive/fill/opr_impl.h new file mode 100644 index 000000000..68f652e5a --- /dev/null +++ b/dnn/src/naive/fill/opr_impl.h @@ -0,0 +1,33 @@ +/** + * \file dnn/src/naive/fill/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace naive { + +class FillImpl : public Fill { +public: + using Fill::Fill; + + void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout &) override { + return 0; + } +private: + template + void exec_internal(ctype *dst, size_t size); +}; + +} // namespace naive +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp index 2d6bd32d7..6c76dc339 100644 --- a/dnn/src/naive/handle.cpp +++ b/dnn/src/naive/handle.cpp @@ -40,6 +40,7 @@ #include "src/naive/elemwise_multi_type/opr_impl.h" #include "src/naive/eye/opr_impl.h" #include "src/naive/fake_quant/opr_impl.h" +#include "src/naive/fill/opr_impl.h" #include "src/naive/flip/opr_impl.h" #include "src/naive/gaussian_blur/opr_impl.h" #include "src/naive/group_local/opr_impl.h" diff --git a/dnn/src/rocm/fill/fill.cpp.hip b/dnn/src/rocm/fill/fill.cpp.hip new file mode 100644 index 000000000..6c252827f --- /dev/null +++ b/dnn/src/rocm/fill/fill.cpp.hip @@ -0,0 +1,51 @@ +/** + * \file dnn/src/rocm/fill/fill.cpp.hip + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "hcc_detail/hcc_defs_prologue.h" +#include "hip_header.h" +#include "megdnn/dtype.h" +#include "src/rocm/fill/fill.h.hip" +#include "src/rocm/utils.h.hip" + +namespace { + +template +__global__ void kernel(T *dst, T value, uint32_t size) { + int32_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i < size) { + dst[i] = value; + } +} + +} // anonymous namespace + +namespace megdnn { +namespace rocm { +namespace fill { + +template +void exec_internal(T *dst, T value, size_t size, hipStream_t stream) { + hipLaunchKernelGGL( + (kernel), + dim3(DIVUP(size, NR_THREADS)), + dim3(NR_THREADS), + 0, stream, dst, value, size); + after_kernel_launch(); +} + +#define INST(T) template void exec_internal(T *, \ + T, size_t, hipStream_t); +#define cb(DType) INST(typename DTypeTrait::ctype) +MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + +} // namespace fill +} // namespace rocm +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/rocm/fill/fill.h.hip b/dnn/src/rocm/fill/fill.h.hip new file mode 100644 index 000000000..24b9ee016 --- /dev/null +++ b/dnn/src/rocm/fill/fill.h.hip @@ -0,0 +1,25 @@ +/** + * \file dnn/src/rocm/fill/fill.h.hip + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include +#include "hip_header.h" + +namespace megdnn { +namespace rocm { +namespace fill { + +template +void exec_internal(T *dst, T value, size_t size, hipStream_t stream); + +} // namespace fill +} // namespace rocm +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/rocm/fill/opr_impl.cpp b/dnn/src/rocm/fill/opr_impl.cpp new file mode 100644 index 000000000..e15bcfde2 --- /dev/null +++ b/dnn/src/rocm/fill/opr_impl.cpp @@ -0,0 +1,36 @@ +/** + * \file dnn/src/rocm/fill/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "hcc_detail/hcc_defs_prologue.h" +#include "src/rocm/fill/opr_impl.h" + +#include "src/rocm/fill/fill.h.hip" +#include "src/rocm/utils.h" + +namespace megdnn { +namespace rocm { + +void FillImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { + check_exec(dst.layout, workspace.size); + auto stream = hip_stream(handle()); + auto size = dst.layout.total_nr_elems(); +#define cb(DType) \ + if (dst.layout.dtype == DType()) { \ + using ctype = typename DTypeTrait::ctype; \ + fill::exec_internal(dst.ptr(), \ + static_cast(param().value), size, stream); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb +} + +} // namespace rocm +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/rocm/fill/opr_impl.h b/dnn/src/rocm/fill/opr_impl.h new file mode 100644 index 000000000..4c2f0cc47 --- /dev/null +++ b/dnn/src/rocm/fill/opr_impl.h @@ -0,0 +1,26 @@ +/** + * \file dnn/src/rocm/fill/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace rocm { + +class FillImpl final : public Fill { +public: + using Fill::Fill; + void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; } +}; + +} // namespace rocm +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/rocm/handle.cpp b/dnn/src/rocm/handle.cpp index 3b5b8a9ec..b3752e763 100644 --- a/dnn/src/rocm/handle.cpp +++ b/dnn/src/rocm/handle.cpp @@ -37,6 +37,7 @@ #include "src/rocm/sleep/opr_impl.h" #include "src/rocm/batch_normalization/opr_impl.h" #include "src/rocm/param_pack/opr_impl.h" +#include "src/rocm/fill/opr_impl.h" #include #include @@ -176,6 +177,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ParamPackConcat); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(Fill); #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wpragmas" diff --git a/dnn/test/common/fill.h b/dnn/test/common/fill.h new file mode 100644 index 000000000..26fff73cb --- /dev/null +++ b/dnn/test/common/fill.h @@ -0,0 +1,37 @@ +/** + * \file dnn/test/common/fill.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "megdnn/handle.h" +#include "megdnn/oprs/general.h" + +#include "src/common/opr_trait.h" +#include "test/common/checker.h" + +namespace megdnn { +namespace test { +namespace fill { + +inline void run_fill_test(Handle* handle, DType dtype) { + Checker checker(handle); + for (float value : {-1.23, 0.0, 0.001, 234.0, 2021.072}) { + checker.set_param({value}); + checker.set_dtype(0, dtype); + checker.exec(TensorShapeArray{{1, 1}}); + checker.exec(TensorShapeArray{{2, 3, 4}}); + } +} + +} // namespace fill +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/cuda/fill.cpp b/dnn/test/cuda/fill.cpp new file mode 100644 index 000000000..cd57c1014 --- /dev/null +++ b/dnn/test/cuda/fill.cpp @@ -0,0 +1,36 @@ +/** + * \file dnn/test/cuda/fill.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "test/common/fill.h" + +#include "test/cuda/fixture.h" + +namespace megdnn { +namespace test { +namespace fill { + +TEST_F(CUDA, FILL_F32) { + run_fill_test(handle_cuda(), dtype::Float32{}); +} + +TEST_F(CUDA, FILL_I32) { + run_fill_test(handle_cuda(), dtype::Int32{}); +} + +#if !MEGDNN_DISABLE_FLOAT16 +TEST_F(CUDA, FILL_F16) { + run_fill_test(handle_cuda(), dtype::Float16{}); +} +#endif + +} // namespace fill +} // namespace test +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/test/rocm/fill.cpp b/dnn/test/rocm/fill.cpp new file mode 100644 index 000000000..731722920 --- /dev/null +++ b/dnn/test/rocm/fill.cpp @@ -0,0 +1,36 @@ +/** + * \file dnn/test/rocm/fill.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "test/common/fill.h" + +#include "test/rocm/fixture.h" + +namespace megdnn { +namespace test { +namespace fill { + +TEST_F(ROCM, FILL_F32) { + run_fill_test(handle_rocm(), dtype::Float32{}); +} + +TEST_F(ROCM, FILL_I32) { + run_fill_test(handle_rocm(), dtype::Int32{}); +} + +#if !MEGDNN_DISABLE_FLOAT16 +TEST_F(ROCM, FILL_F16) { + run_fill_test(handle_rocm(), dtype::Float16{}); +} +#endif + +} // namespace fill +} // namespace test +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} -- GitLab