提交 b2f0ceb2 编写于 作者: M Megvii Engine Team

feat(dnn/naive): add convolution weight preprocess interface

GitOrigin-RevId: d0fd6c75a6c45922f18e734c134c5ddcb7bfa8d9
上级 9b908c02
...@@ -131,6 +131,13 @@ public: ...@@ -131,6 +131,13 @@ public:
} }
}; };
/*!
 * \brief Holder for the result of convolution filter (weight) preprocessing.
 *
 * A pointer to this struct is what the operator interfaces take as their
 * `preprocessed_filter` argument (exec(), exec_preprocess() and
 * get_workspace_in_bytes()).
 */
struct PreprocessedFilter {
//! Opaque user data; its lifetime should be bound to the MegDNN Convolution
//! operator.
//! NOTE(review): there is no default member initializer here, so callers
//! should set this field explicitly before use.
void* algorithm_id;
//! NOTE(review): presumably the preprocessed filter tensors, one per layout
//! reported by deduce_preprocessed_filter_layout() -- confirm against the
//! implementations.
TensorNDArray tensors;
};
protected: protected:
// Check or deduce output DType // Check or deduce output DType
void check_or_deduce_dtype_fwd(DType src, DType filter, DType& dst) const; void check_or_deduce_dtype_fwd(DType src, DType filter, DType& dst) const;
...@@ -200,13 +207,27 @@ public: ...@@ -200,13 +207,27 @@ public:
* \param[out] dst (n, oc, oh, ow) * \param[out] dst (n, oc, oh, ow)
*/ */
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; _megdnn_tensor_out dst,
const PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) = 0;
virtual void exec_preprocess(const TensorLayout& src_layout,
_megdnn_tensor_in filter,
const TensorLayout& dst_layout,
PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) = 0;
void deduce_dtype(DType src, DType filter, DType& dst); void deduce_dtype(DType src, DType filter, DType& dst);
void deduce_layout(const TensorLayout& src, const TensorLayout& filter, void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
TensorLayout& dst); TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src, virtual size_t get_workspace_in_bytes(
const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) = 0; const TensorLayout& dst,
PreprocessedFilter* preprocessed_filter) = 0;
virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) = 0;
virtual size_t get_preprocess_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) = 0;
protected: protected:
CanonizedFilterMeta check_exec(const TensorLayout& src, CanonizedFilterMeta check_exec(const TensorLayout& src,
...@@ -297,17 +318,35 @@ public: ...@@ -297,17 +318,35 @@ public:
*/ */
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
_megdnn_tensor_in bias, _megdnn_tensor_in z, _megdnn_tensor_in bias, _megdnn_tensor_in z,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; _megdnn_tensor_out dst,
const PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) = 0;
virtual void exec_preprocess(const TensorLayout& src_layout,
_megdnn_tensor_in filter,
const TensorLayout& bias_layout,
const TensorLayout& z_layout,
const TensorLayout& dst_layout,
PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) = 0;
void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst); void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst);
void deduce_layout(const TensorLayout& src, const TensorLayout& filter, void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z, const TensorLayout& bias, const TensorLayout& z,
TensorLayout& dst); TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src, virtual size_t get_workspace_in_bytes(
const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& z, const TensorLayout& dst,
const TensorLayout& dst) = 0; PreprocessedFilter* preprocessed_filter) = 0;
virtual size_t get_preprocess_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) = 0;
virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) = 0;
enum class BiasMode : uint32_t { enum class BiasMode : uint32_t {
NO_BIAS = 0, //!< no bias NO_BIAS = 0, //!< no bias
BROADCAST_CHANNEL_BIAS, //!< broadcast channel bias, [1, c, 1, 1] BROADCAST_CHANNEL_BIAS, //!< broadcast channel bias, [1, c, 1, 1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册