提交 1af350c6 编写于 作者: M Megvii Engine Team

feat(dnn): add fill kernel

GitOrigin-RevId: d2cee3a7a009849c2c608918e2ec51500bafa88b
上级 bbaf524f
...@@ -1338,6 +1338,21 @@ class CheckHasInf: public OperatorBase { ...@@ -1338,6 +1338,21 @@ class CheckHasInf: public OperatorBase {
void check_exec(const TensorLayout &src, const TensorLayout &dst, void check_exec(const TensorLayout &src, const TensorLayout &dst,
size_t workspace_in_bytes); size_t workspace_in_bytes);
}; };
/*!
* \brief fill the tensor with a scalar value
*/
class Fill: public OperatorBase {
DEF_OPR_PARAM(Fill);
DEF_OPR_IMPL(Fill, OperatorBase, 0, 1);
public:
virtual void exec(_megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0;
protected:
void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
};
} // namespace megdnn } // namespace megdnn
#include "megdnn/internal/opr_header_epilogue.h" #include "megdnn/internal/opr_header_epilogue.h"
......
...@@ -1170,4 +1170,4 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o ...@@ -1170,4 +1170,4 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
add_fields('int32', 'qmin', '-2147483648'). add_fields('int32', 'qmin', '-2147483648').
add_fields('int32', 'qmax', '2147483647') add_fields('int32', 'qmax', '2147483647')
) )
pdef('Fill').add_fields('float32', 'value', '0')
/**
* \file dnn/src/common/fill.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megdnn/oprs.h"
#include "src/common/utils.h"
namespace megdnn {
void Fill::check_exec(const TensorLayout& dst, size_t workspace_in_bytes) {
megdnn_assert_contiguous(dst);
auto required_workspace_in_bytes = get_workspace_in_bytes(dst);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
} // namespace megdnn
// vim: syntax=cpp.doxygen
...@@ -216,7 +216,8 @@ private: ...@@ -216,7 +216,8 @@ private:
cb(TQTBackward) \ cb(TQTBackward) \
cb(CheckHasInf) \ cb(CheckHasInf) \
cb(LSQForward) \ cb(LSQForward) \
cb(LSQBackward) cb(LSQBackward) \
cb(Fill)
/*! /*!
* \brief specialize HandleImpl::create_operator for a single opr type; * \brief specialize HandleImpl::create_operator for a single opr type;
......
...@@ -130,6 +130,7 @@ DEF(ChecksumForward, 1, true, false); ...@@ -130,6 +130,7 @@ DEF(ChecksumForward, 1, true, false);
DEF(CheckHasInf, 2, true, true); DEF(CheckHasInf, 2, true, true);
DEF(LSQForward, 5, true, true); DEF(LSQForward, 5, true, true);
DEF(LSQBackward, 7, true, false); DEF(LSQBackward, 7, true, false);
DEF(Fill, 1, true, false);
} // namespace megdnn } // namespace megdnn
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
/**
* \file dnn/src/cuda/fill/kern.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/fill/kern.cuh"
#include "megdnn/dtype.h"
#include "src/cuda/utils.cuh"
namespace {
template <typename T>
__global__ void kernel(T *dst, T value, uint32_t size) {
int32_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i < size) {
dst[i] = value;
}
}
} // anonymous namespace
namespace megdnn {
namespace cuda {
namespace fill {
template <typename T>
void exec_internal(T *dst, T value, size_t size, cudaStream_t stream) {
kernel<T><<<DIVUP(size, NR_THREADS), NR_THREADS, 0, stream>>>(dst, value, size);
after_kernel_launch();
}
#define INST(T) template void exec_internal<T>(T *, \
T, size_t, cudaStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
} // namespace fill
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file dnn/src/cuda/fill/kern.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <stdint.h>
#include <cuda_runtime_api.h>
namespace megdnn {
namespace cuda {
namespace fill {
template <typename T>
void exec_internal(T *dst, T value, size_t size, cudaStream_t stream);
} // namespace fill
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file dnn/src/cuda/fill/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/fill/kern.cuh"
#include "src/cuda/fill/opr_impl.h"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
void FillImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(dst.layout, workspace.size);
auto stream = cuda_stream(handle());
auto size = dst.layout.total_nr_elems();
#define cb(DType) \
if (dst.layout.dtype == DType()) { \
using ctype = typename DTypeTrait<DType>::ctype; \
fill::exec_internal<ctype>(dst.ptr<ctype>(), \
static_cast<ctype>(param().value), size, stream); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
}
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cuda/fill/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace cuda {
class FillImpl final : public Fill {
public:
using Fill::Fill;
void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout &) override {
return 0;
}
};
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
...@@ -38,6 +38,7 @@ ...@@ -38,6 +38,7 @@
#include "src/cuda/elemwise_multi_type/opr_impl.h" #include "src/cuda/elemwise_multi_type/opr_impl.h"
#include "src/cuda/eye/opr_impl.h" #include "src/cuda/eye/opr_impl.h"
#include "src/cuda/fake_quant/opr_impl.h" #include "src/cuda/fake_quant/opr_impl.h"
#include "src/cuda/fill/opr_impl.h"
#include "src/cuda/flip/opr_impl.h" #include "src/cuda/flip/opr_impl.h"
#include "src/cuda/gaussian_blur/opr_impl.h" #include "src/cuda/gaussian_blur/opr_impl.h"
#include "src/cuda/group_local/opr_impl.h" #include "src/cuda/group_local/opr_impl.h"
......
/**
* \file dnn/src/naive/fill/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/naive/fill/opr_impl.h"
#include "src/naive/handle.h"
#include "src/common/utils.h"
#include <cstring>
#include <limits>
namespace megdnn {
namespace naive {
void FillImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(dst.layout, workspace.size);
size_t size = dst.layout.total_nr_elems();
#define cb(DType) \
if (dst.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
using ctype = typename DTypeTrait<DType>::ctype; \
ctype *ptr = dst.ptr<ctype>(); \
MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<ctype>(ptr, size)); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
}
template <typename ctype>
void FillImpl::exec_internal(ctype *dst, size_t size) {
auto value = static_cast<ctype>(param().value);
for (size_t i = 0; i < size; ++i) {
dst[i] = value;
}
}
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/naive/fill/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace naive {
class FillImpl : public Fill {
public:
using Fill::Fill;
void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout &) override {
return 0;
}
private:
template <typename ctype>
void exec_internal(ctype *dst, size_t size);
};
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
...@@ -40,6 +40,7 @@ ...@@ -40,6 +40,7 @@
#include "src/naive/elemwise_multi_type/opr_impl.h" #include "src/naive/elemwise_multi_type/opr_impl.h"
#include "src/naive/eye/opr_impl.h" #include "src/naive/eye/opr_impl.h"
#include "src/naive/fake_quant/opr_impl.h" #include "src/naive/fake_quant/opr_impl.h"
#include "src/naive/fill/opr_impl.h"
#include "src/naive/flip/opr_impl.h" #include "src/naive/flip/opr_impl.h"
#include "src/naive/gaussian_blur/opr_impl.h" #include "src/naive/gaussian_blur/opr_impl.h"
#include "src/naive/group_local/opr_impl.h" #include "src/naive/group_local/opr_impl.h"
......
/**
* \file dnn/src/rocm/fill/fill.cpp.hip
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "hcc_detail/hcc_defs_prologue.h"
#include "hip_header.h"
#include "megdnn/dtype.h"
#include "src/rocm/fill/fill.h.hip"
#include "src/rocm/utils.h.hip"
namespace {
template <typename T>
__global__ void kernel(T *dst, T value, uint32_t size) {
int32_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i < size) {
dst[i] = value;
}
}
} // anonymous namespace
namespace megdnn {
namespace rocm {
namespace fill {
template <typename T>
void exec_internal(T *dst, T value, size_t size, hipStream_t stream) {
hipLaunchKernelGGL(
(kernel<T>),
dim3(DIVUP(size, NR_THREADS)),
dim3(NR_THREADS),
0, stream, dst, value, size);
after_kernel_launch();
}
#define INST(T) template void exec_internal<T>(T *, \
T, size_t, hipStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
} // namespace fill
} // namespace rocm
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file dnn/src/rocm/fill/fill.h.hip
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <stdint.h>
#include "hip_header.h"
namespace megdnn {
namespace rocm {
namespace fill {
template <typename T>
void exec_internal(T *dst, T value, size_t size, hipStream_t stream);
} // namespace fill
} // namespace rocm
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file dnn/src/rocm/fill/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "hcc_detail/hcc_defs_prologue.h"
#include "src/rocm/fill/opr_impl.h"
#include "src/rocm/fill/fill.h.hip"
#include "src/rocm/utils.h"
namespace megdnn {
namespace rocm {
void FillImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(dst.layout, workspace.size);
auto stream = hip_stream(handle());
auto size = dst.layout.total_nr_elems();
#define cb(DType) \
if (dst.layout.dtype == DType()) { \
using ctype = typename DTypeTrait<DType>::ctype; \
fill::exec_internal<ctype>(dst.ptr<ctype>(), \
static_cast<ctype>(param().value), size, stream); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
}
} // namespace rocm
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file dnn/src/rocm/fill/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace rocm {
class FillImpl final : public Fill {
public:
using Fill::Fill;
void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; }
};
} // namespace rocm
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
#include "src/rocm/sleep/opr_impl.h" #include "src/rocm/sleep/opr_impl.h"
#include "src/rocm/batch_normalization/opr_impl.h" #include "src/rocm/batch_normalization/opr_impl.h"
#include "src/rocm/param_pack/opr_impl.h" #include "src/rocm/param_pack/opr_impl.h"
#include "src/rocm/fill/opr_impl.h"
#include <miopen/version.h> #include <miopen/version.h>
#include <hip/hip_version.h> #include <hip/hip_version.h>
...@@ -176,6 +177,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward); ...@@ -176,6 +177,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ParamPackConcat); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ParamPackConcat);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Fill);
#pragma GCC diagnostic push #pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas" #pragma GCC diagnostic ignored "-Wpragmas"
......
/**
* \file dnn/test/common/fill.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/handle.h"
#include "megdnn/oprs/general.h"
#include "src/common/opr_trait.h"
#include "test/common/checker.h"
namespace megdnn {
namespace test {
namespace fill {
inline void run_fill_test(Handle* handle, DType dtype) {
Checker<Fill> checker(handle);
for (float value : {-1.23, 0.0, 0.001, 234.0, 2021.072}) {
checker.set_param({value});
checker.set_dtype(0, dtype);
checker.exec(TensorShapeArray{{1, 1}});
checker.exec(TensorShapeArray{{2, 3, 4}});
}
}
} // namespace fill
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/test/cuda/fill.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "test/common/fill.h"
#include "test/cuda/fixture.h"
namespace megdnn {
namespace test {
namespace fill {
TEST_F(CUDA, FILL_F32) {
run_fill_test(handle_cuda(), dtype::Float32{});
}
TEST_F(CUDA, FILL_I32) {
run_fill_test(handle_cuda(), dtype::Int32{});
}
#if !MEGDNN_DISABLE_FLOAT16
TEST_F(CUDA, FILL_F16) {
run_fill_test(handle_cuda(), dtype::Float16{});
}
#endif
} // namespace fill
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file dnn/test/rocm/fill.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "test/common/fill.h"
#include "test/rocm/fixture.h"
namespace megdnn {
namespace test {
namespace fill {
TEST_F(ROCM, FILL_F32) {
run_fill_test(handle_rocm(), dtype::Float32{});
}
TEST_F(ROCM, FILL_I32) {
run_fill_test(handle_rocm(), dtype::Int32{});
}
#if !MEGDNN_DISABLE_FLOAT16
TEST_F(ROCM, FILL_F16) {
run_fill_test(handle_rocm(), dtype::Float16{});
}
#endif
} // namespace fill
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册