diff --git a/dnn/include/megdnn/oprs/nn.h b/dnn/include/megdnn/oprs/nn.h
index 05cfd5bb1e2175437906a1995a58eae7dd3849d3..4d8e24182da3a3c9bb9acafaa9b371beb76d647e 100644
--- a/dnn/include/megdnn/oprs/nn.h
+++ b/dnn/include/megdnn/oprs/nn.h
@@ -131,6 +131,21 @@ public:
         }
     };
 
+    /*!
+     * \brief bundle passed by pointer to exec(), exec_preprocess() and
+     *      get_workspace_in_bytes() in addition to the raw filter tensor
+     *
+     * NOTE(review): presumably exec_preprocess() fills \p tensors and exec()
+     * consumes them -- confirm against the backend implementations.
+     */
+    struct PreprocessedFilter {
+        //! user data; its lifetime should be bound to the MegDNN Convolution
+        //! operator
+        void* algorithm_id;
+        //! NOTE(review): presumably the preprocessed filter data -- confirm
+        TensorNDArray tensors;
+    };
+
 protected:
     // Check or deduce output DType
     void check_or_deduce_dtype_fwd(DType src, DType filter, DType& dst) const;
@@ -200,13 +215,40 @@ public:
      * \param[out] dst (n, oc, oh, ow)
      */
     virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
-                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
+                      _megdnn_tensor_out dst,
+                      const PreprocessedFilter* preprocessed_filter,
+                      _megdnn_workspace workspace) = 0;
+    /*!
+     * \brief companion of exec() that takes only the layouts of src and dst,
+     *      the filter tensor itself and a mutable PreprocessedFilter
+     *
+     * NOTE(review): presumably writes the transformed filter into
+     * \p preprocessed_filter -- confirm.
+     */
+    virtual void exec_preprocess(const TensorLayout& src_layout,
+                                 _megdnn_tensor_in filter,
+                                 const TensorLayout& dst_layout,
+                                 PreprocessedFilter* preprocessed_filter,
+                                 _megdnn_workspace workspace) = 0;
     void deduce_dtype(DType src, DType filter, DType& dst);
     void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                        TensorLayout& dst);
-    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
-                                          const TensorLayout& filter,
-                                          const TensorLayout& dst) = 0;
+    //! NOTE(review): exec() takes \p preprocessed_filter as a pointer to
+    //! const; consider the same qualifier here.
+    virtual size_t get_workspace_in_bytes(
+            const TensorLayout& src, const TensorLayout& filter,
+            const TensorLayout& dst,
+            PreprocessedFilter* preprocessed_filter) = 0;
+    //! NOTE(review): presumably the layouts expected for
+    //! PreprocessedFilter::tensors (inferred from the name) -- confirm
+    virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
+            const TensorLayout& src, const TensorLayout& filter,
+            const TensorLayout& dst) = 0;
+    //! NOTE(review): presumably the workspace size needed by
+    //! exec_preprocess() -- confirm
+    virtual size_t get_preprocess_workspace_in_bytes(
+            const TensorLayout& src, const TensorLayout& filter,
+            const TensorLayout& dst) = 0;
 
 protected:
     CanonizedFilterMeta check_exec(const TensorLayout& src,
@@ -297,17 +339,48 @@ public:
      */
     virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                       _megdnn_tensor_in bias, _megdnn_tensor_in z,
-                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
+                      _megdnn_tensor_out dst,
+                      const PreprocessedFilter* preprocessed_filter,
+                      _megdnn_workspace workspace) = 0;
+    /*!
+     * \brief companion of exec() that takes only the layouts of src, bias, z
+     *      and dst, the filter tensor itself and a mutable PreprocessedFilter
+     *
+     * NOTE(review): presumably writes the transformed filter into
+     * \p preprocessed_filter -- confirm.
+     */
+    virtual void exec_preprocess(const TensorLayout& src_layout,
+                                 _megdnn_tensor_in filter,
+                                 const TensorLayout& bias_layout,
+                                 const TensorLayout& z_layout,
+                                 const TensorLayout& dst_layout,
+                                 PreprocessedFilter* preprocessed_filter,
+                                 _megdnn_workspace workspace) = 0;
     void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst);
     void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
                        const TensorLayout& bias, const TensorLayout& z,
                        TensorLayout& dst);
 
-    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
-                                          const TensorLayout& filter,
-                                          const TensorLayout& bias,
-                                          const TensorLayout& z,
-                                          const TensorLayout& dst) = 0;
+    //! NOTE(review): exec() takes \p preprocessed_filter as a pointer to
+    //! const; consider the same qualifier here.
+    virtual size_t get_workspace_in_bytes(
+            const TensorLayout& src, const TensorLayout& filter,
+            const TensorLayout& bias, const TensorLayout& z,
+            const TensorLayout& dst,
+            PreprocessedFilter* preprocessed_filter) = 0;
+    //! NOTE(review): presumably the workspace size needed by
+    //! exec_preprocess() -- confirm
+    virtual size_t get_preprocess_workspace_in_bytes(
+            const TensorLayout& src, const TensorLayout& filter,
+            const TensorLayout& bias, const TensorLayout& z,
+            const TensorLayout& dst) = 0;
+    //! NOTE(review): presumably the layouts expected for
+    //! PreprocessedFilter::tensors (inferred from the name) -- confirm
+    virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
+            const TensorLayout& src, const TensorLayout& filter,
+            const TensorLayout& bias, const TensorLayout& z,
+            const TensorLayout& dst) = 0;
+
     enum class BiasMode : uint32_t {
         NO_BIAS = 0,             //!< no bias
         BROADCAST_CHANNEL_BIAS,  //!< broadcast channel bias, [1, c, 1, 1]