diff --git a/dnn/include/megdnn/oprs/general.h b/dnn/include/megdnn/oprs/general.h index 829d5395f6177ca37ae5c8ff98dd79133284db49..2f9abccd0ebe18b3b01f62cb69097c4ee46e3ef9 100644 --- a/dnn/include/megdnn/oprs/general.h +++ b/dnn/include/megdnn/oprs/general.h @@ -1338,6 +1338,21 @@ class CheckHasInf: public OperatorBase { void check_exec(const TensorLayout &src, const TensorLayout &dst, size_t workspace_in_bytes); }; + +/*! + * \brief fill the tensor with a scalar value + */ +class Fill: public OperatorBase { + DEF_OPR_PARAM(Fill); + DEF_OPR_IMPL(Fill, OperatorBase, 0, 1); + +public: + virtual void exec(_megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0; +protected: + void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); +}; } // namespace megdnn #include "megdnn/internal/opr_header_epilogue.h" diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index f1aa516c10cce4fab3a580f9ef039c6976d05d11..e8c6616767d68ee20f0ccac29fa9e991029d5e2e 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -1170,4 +1170,4 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o add_fields('int32', 'qmin', '-2147483648'). add_fields('int32', 'qmax', '2147483647') ) - +pdef('Fill').add_fields('float32', 'value', '0') diff --git a/dnn/src/common/fill.cpp b/dnn/src/common/fill.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d4ccf2fa932a4b20268c66657d65087d67db5164 --- /dev/null +++ b/dnn/src/common/fill.cpp @@ -0,0 +1,25 @@ +/** + * \file dnn/src/common/fill.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megdnn/oprs.h" +#include "src/common/utils.h" + +namespace megdnn { + +void Fill::check_exec(const TensorLayout& dst, size_t workspace_in_bytes) { + megdnn_assert_contiguous(dst); + auto required_workspace_in_bytes = get_workspace_in_bytes(dst); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); +} + +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index 1f6431f26ec2f55a1178e5cf3ee7df86ee7d0261..4e393752f6d272b2139be3d5e4ef69dfccf55132 100644 --- a/dnn/src/common/handle_impl.h +++ b/dnn/src/common/handle_impl.h @@ -216,7 +216,8 @@ private: cb(TQTBackward) \ cb(CheckHasInf) \ cb(LSQForward) \ - cb(LSQBackward) + cb(LSQBackward) \ + cb(Fill) /*! * \brief specialize HandleImpl::create_operator for a single opr type; diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h index 952010620884d38c34c4996f2a97fc86ee9d0cfa..1417bdce731ff30d7108876a616f7992f8ae972c 100644 --- a/dnn/src/common/opr_trait.h +++ b/dnn/src/common/opr_trait.h @@ -130,6 +130,7 @@ DEF(ChecksumForward, 1, true, false); DEF(CheckHasInf, 2, true, true); DEF(LSQForward, 5, true, true); DEF(LSQBackward, 7, true, false); +DEF(Fill, 1, true, false); } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/fill/kern.cu b/dnn/src/cuda/fill/kern.cu new file mode 100644 index 0000000000000000000000000000000000000000..1daf50a9452b34364ba107abdefc0d3de153f720 --- /dev/null +++ b/dnn/src/cuda/fill/kern.cu @@ -0,0 +1,45 @@ +/** + * \file dnn/src/cuda/fill/kern.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/cuda/fill/kern.cuh" +#include "megdnn/dtype.h" +#include "src/cuda/utils.cuh" + +namespace { + +template +__global__ void kernel(T *dst, T value, uint32_t size) { + int32_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i < size) { + dst[i] = value; + } +} + +} // anonymous namespace + +namespace megdnn { +namespace cuda { +namespace fill { + +template +void exec_internal(T *dst, T value, size_t size, cudaStream_t stream) { + kernel<<>>(dst, value, size); + after_kernel_launch(); +} + +#define INST(T) template void exec_internal(T *, \ + T, size_t, cudaStream_t); +#define cb(DType) INST(typename DTypeTrait::ctype) +MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + +} // namespace fill +} // namespace cuda +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/cuda/fill/kern.cuh b/dnn/src/cuda/fill/kern.cuh new file mode 100644 index 0000000000000000000000000000000000000000..a79f93560b1ab0ae7c88fb70d9d70badfa6e355e --- /dev/null +++ b/dnn/src/cuda/fill/kern.cuh @@ -0,0 +1,25 @@ +/** + * \file dnn/src/cuda/fill/kern.cuh + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include +#include + +namespace megdnn { +namespace cuda { +namespace fill { + +template +void exec_internal(T *dst, T value, size_t size, cudaStream_t stream); + +} // namespace fill +} // namespace cuda +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/cuda/fill/opr_impl.cpp b/dnn/src/cuda/fill/opr_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f9966a2c65c13f7868a2980a01c2cb6e3d273d71 --- /dev/null +++ b/dnn/src/cuda/fill/opr_impl.cpp @@ -0,0 +1,37 @@ +/** + * \file dnn/src/cuda/fill/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/cuda/fill/kern.cuh" +#include "src/cuda/fill/opr_impl.h" + +#include "src/cuda/utils.h" + +namespace megdnn { +namespace cuda { + +void FillImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { + check_exec(dst.layout, workspace.size); + auto stream = cuda_stream(handle()); + auto size = dst.layout.total_nr_elems(); +#define cb(DType) \ + if (dst.layout.dtype == DType()) { \ + using ctype = typename DTypeTrait::ctype; \ + fill::exec_internal(dst.ptr(), \ + static_cast(param().value), size, stream); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb +} + +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/fill/opr_impl.h b/dnn/src/cuda/fill/opr_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..d29eae249b84869132efc7efc2bc1915f2103066 --- /dev/null +++ b/dnn/src/cuda/fill/opr_impl.h @@ -0,0 +1,31 @@ +/** + * \file dnn/src/cuda/fill/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace cuda { + +class FillImpl final : public Fill { +public: + using Fill::Fill; + void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout &) override { + return 0; + } +}; + +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen + diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp index 83409e0682960e4cedf81e8685584dc90df8bda1..f9ab2233e6a89cf4c0c16c4517aee5396a2d3ee7 100644 --- a/dnn/src/cuda/handle_create.cpp +++ b/dnn/src/cuda/handle_create.cpp @@ -38,6 +38,7 @@ #include "src/cuda/elemwise_multi_type/opr_impl.h" #include "src/cuda/eye/opr_impl.h" #include "src/cuda/fake_quant/opr_impl.h" +#include "src/cuda/fill/opr_impl.h" #include "src/cuda/flip/opr_impl.h" #include "src/cuda/gaussian_blur/opr_impl.h" #include "src/cuda/group_local/opr_impl.h" diff --git a/dnn/src/naive/fill/opr_impl.cpp b/dnn/src/naive/fill/opr_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b133a27f04bbf4f7debb8ee91d67774a97db914b --- /dev/null +++ b/dnn/src/naive/fill/opr_impl.cpp @@ -0,0 +1,46 @@ +/** + * \file dnn/src/naive/fill/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/naive/fill/opr_impl.h" +#include "src/naive/handle.h" +#include "src/common/utils.h" + +#include +#include + +namespace megdnn { +namespace naive { + +void FillImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { + check_exec(dst.layout, workspace.size); + size_t size = dst.layout.total_nr_elems(); +#define cb(DType) \ + if (dst.layout.dtype.enumv() == DTypeTrait::enumv) { \ + using ctype = typename DTypeTrait::ctype; \ + ctype *ptr = dst.ptr(); \ + MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal(ptr, size)); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb +} + +template +void FillImpl::exec_internal(ctype *dst, size_t size) { + auto value = static_cast(param().value); + for (size_t i = 0; i < size; ++i) { + dst[i] = value; + } +} + +} // namespace naive +} // namespace megdnn +// vim: syntax=cpp.doxygen + diff --git a/dnn/src/naive/fill/opr_impl.h b/dnn/src/naive/fill/opr_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..68f652e5a139700c3d5ab155de75a5461cb182d7 --- /dev/null +++ b/dnn/src/naive/fill/opr_impl.h @@ -0,0 +1,33 @@ +/** + * \file dnn/src/naive/fill/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace naive { + +class FillImpl : public Fill { +public: + using Fill::Fill; + + void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout &) override { + return 0; + } +private: + template + void exec_internal(ctype *dst, size_t size); +}; + +} // namespace naive +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp index 2d6bd32d7d488391a780886e38f5843a486e78da..6c76dc3391d90a0e5f6d528620dfb8d34c8e65d0 100644 --- a/dnn/src/naive/handle.cpp +++ b/dnn/src/naive/handle.cpp @@ -40,6 +40,7 @@ #include "src/naive/elemwise_multi_type/opr_impl.h" #include "src/naive/eye/opr_impl.h" #include "src/naive/fake_quant/opr_impl.h" +#include "src/naive/fill/opr_impl.h" #include "src/naive/flip/opr_impl.h" #include "src/naive/gaussian_blur/opr_impl.h" #include "src/naive/group_local/opr_impl.h" diff --git a/dnn/src/rocm/fill/fill.cpp.hip b/dnn/src/rocm/fill/fill.cpp.hip new file mode 100644 index 0000000000000000000000000000000000000000..6c252827fe8ed30d510f2987977c43f0f684db8b --- /dev/null +++ b/dnn/src/rocm/fill/fill.cpp.hip @@ -0,0 +1,51 @@ +/** + * \file dnn/src/rocm/fill/fill.cpp.hip + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "hcc_detail/hcc_defs_prologue.h" +#include "hip_header.h" +#include "megdnn/dtype.h" +#include "src/rocm/fill/fill.h.hip" +#include "src/rocm/utils.h.hip" + +namespace { + +template +__global__ void kernel(T *dst, T value, uint32_t size) { + int32_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i < size) { + dst[i] = value; + } +} + +} // anonymous namespace + +namespace megdnn { +namespace rocm { +namespace fill { + +template +void exec_internal(T *dst, T value, size_t size, hipStream_t stream) { + hipLaunchKernelGGL( + (kernel), + dim3(DIVUP(size, NR_THREADS)), + dim3(NR_THREADS), + 0, stream, dst, value, size); + after_kernel_launch(); +} + +#define INST(T) template void exec_internal(T *, \ + T, size_t, hipStream_t); +#define cb(DType) INST(typename DTypeTrait::ctype) +MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + +} // namespace fill +} // namespace rocm +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/rocm/fill/fill.h.hip b/dnn/src/rocm/fill/fill.h.hip new file mode 100644 index 0000000000000000000000000000000000000000..24b9ee01670392804012cf7e86ea70f120a9f183 --- /dev/null +++ b/dnn/src/rocm/fill/fill.h.hip @@ -0,0 +1,25 @@ +/** + * \file dnn/src/rocm/fill/fill.h.hip + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include +#include "hip_header.h" + +namespace megdnn { +namespace rocm { +namespace fill { + +template +void exec_internal(T *dst, T value, size_t size, hipStream_t stream); + +} // namespace fill +} // namespace rocm +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/rocm/fill/opr_impl.cpp b/dnn/src/rocm/fill/opr_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e15bcfde25f4b94cbd1e4344cc086a1de94c67a9 --- /dev/null +++ b/dnn/src/rocm/fill/opr_impl.cpp @@ -0,0 +1,36 @@ +/** + * \file dnn/src/rocm/fill/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "hcc_detail/hcc_defs_prologue.h" +#include "src/rocm/fill/opr_impl.h" + +#include "src/rocm/fill/fill.h.hip" +#include "src/rocm/utils.h" + +namespace megdnn { +namespace rocm { + +void FillImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { + check_exec(dst.layout, workspace.size); + auto stream = hip_stream(handle()); + auto size = dst.layout.total_nr_elems(); +#define cb(DType) \ + if (dst.layout.dtype == DType()) { \ + using ctype = typename DTypeTrait::ctype; \ + fill::exec_internal(dst.ptr(), \ + static_cast(param().value), size, stream); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb +} + +} // namespace rocm +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/rocm/fill/opr_impl.h b/dnn/src/rocm/fill/opr_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..4c2f0cc47e7dac6a77330fb2d7d048b2227d1ea6 --- /dev/null +++ b/dnn/src/rocm/fill/opr_impl.h @@ -0,0 +1,26 @@ +/** + * \file dnn/src/rocm/fill/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace rocm { + +class FillImpl final : public Fill { +public: + using Fill::Fill; + void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; } +}; + +} // namespace rocm +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/rocm/handle.cpp b/dnn/src/rocm/handle.cpp index 3b5b8a9ec6bfd716f8d7f483e57a9fae729fb150..b3752e76346b263690d71b798d4644bbc8590a2a 100644 --- a/dnn/src/rocm/handle.cpp +++ b/dnn/src/rocm/handle.cpp @@ -37,6 +37,7 @@ #include "src/rocm/sleep/opr_impl.h" #include "src/rocm/batch_normalization/opr_impl.h" #include "src/rocm/param_pack/opr_impl.h" +#include "src/rocm/fill/opr_impl.h" #include #include @@ -176,6 +177,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ParamPackConcat); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(Fill); #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wpragmas" diff --git a/dnn/test/common/fill.h b/dnn/test/common/fill.h new file mode 100644 index 0000000000000000000000000000000000000000..26fff73cb5c7e5f8dad22c0eb52bd1a5a785bb32 --- /dev/null +++ b/dnn/test/common/fill.h @@ -0,0 +1,37 @@ +/** + * \file dnn/test/common/fill.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "megdnn/handle.h" +#include "megdnn/oprs/general.h" + +#include "src/common/opr_trait.h" +#include "test/common/checker.h" + +namespace megdnn { +namespace test { +namespace fill { + +inline void run_fill_test(Handle* handle, DType dtype) { + Checker checker(handle); + for (float value : {-1.23, 0.0, 0.001, 234.0, 2021.072}) { + checker.set_param({value}); + checker.set_dtype(0, dtype); + checker.exec(TensorShapeArray{{1, 1}}); + checker.exec(TensorShapeArray{{2, 3, 4}}); + } +} + +} // namespace fill +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/cuda/fill.cpp b/dnn/test/cuda/fill.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cd57c10146dc5fad3fe6a50a351e87afe54a6753 --- /dev/null +++ b/dnn/test/cuda/fill.cpp @@ -0,0 +1,36 @@ +/** + * \file dnn/test/cuda/fill.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "test/common/fill.h" + +#include "test/cuda/fixture.h" + +namespace megdnn { +namespace test { +namespace fill { + +TEST_F(CUDA, FILL_F32) { + run_fill_test(handle_cuda(), dtype::Float32{}); +} + +TEST_F(CUDA, FILL_I32) { + run_fill_test(handle_cuda(), dtype::Int32{}); +} + +#if !MEGDNN_DISABLE_FLOAT16 +TEST_F(CUDA, FILL_F16) { + run_fill_test(handle_cuda(), dtype::Float16{}); +} +#endif + +} // namespace fill +} // namespace test +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/test/rocm/fill.cpp b/dnn/test/rocm/fill.cpp new file mode 100644 index 0000000000000000000000000000000000000000..731722920fcae82a005ee1629c1936e58d75fdd4 --- /dev/null +++ b/dnn/test/rocm/fill.cpp @@ -0,0 +1,36 @@ +/** + * \file dnn/test/rocm/fill.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "test/common/fill.h" + +#include "test/rocm/fixture.h" + +namespace megdnn { +namespace test { +namespace fill { + +TEST_F(ROCM, FILL_F32) { + run_fill_test(handle_rocm(), dtype::Float32{}); +} + +TEST_F(ROCM, FILL_I32) { + run_fill_test(handle_rocm(), dtype::Int32{}); +} + +#if !MEGDNN_DISABLE_FLOAT16 +TEST_F(ROCM, FILL_F16) { + run_fill_test(handle_rocm(), dtype::Float16{}); +} +#endif + +} // namespace fill +} // namespace test +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}