提交 b2f0ceb2 编写于 作者: M Megvii Engine Team

feat(dnn/naive): add convolution weight preprocess interface

GitOrigin-RevId: d0fd6c75a6c45922f18e734c134c5ddcb7bfa8d9
上级 9b908c02
......@@ -131,6 +131,13 @@ public:
}
};
struct PreprocessedFilter {
//! user data; its lifetime should be bound to MegDNN Convolution
//! operator
void* algorithm_id;
TensorNDArray tensors;
};
protected:
// Check or deduce output DType
void check_or_deduce_dtype_fwd(DType src, DType filter, DType& dst) const;
......@@ -200,12 +207,26 @@ public:
* \param[out] dst (n, oc, oh, ow)
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
_megdnn_tensor_out dst,
const PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) = 0;
virtual void exec_preprocess(const TensorLayout& src_layout,
_megdnn_tensor_in filter,
const TensorLayout& dst_layout,
PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) = 0;
void deduce_dtype(DType src, DType filter, DType& dst);
void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& filter,
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst,
PreprocessedFilter* preprocessed_filter) = 0;
virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) = 0;
virtual size_t get_preprocess_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) = 0;
protected:
......@@ -297,17 +318,35 @@ public:
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
_megdnn_tensor_in bias, _megdnn_tensor_in z,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
_megdnn_tensor_out dst,
const PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) = 0;
virtual void exec_preprocess(const TensorLayout& src_layout,
_megdnn_tensor_in filter,
const TensorLayout& bias_layout,
const TensorLayout& z_layout,
const TensorLayout& dst_layout,
PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) = 0;
void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst);
void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& bias,
const TensorLayout& z,
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst,
PreprocessedFilter* preprocessed_filter) = 0;
virtual size_t get_preprocess_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) = 0;
virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) = 0;
enum class BiasMode : uint32_t {
NO_BIAS = 0, //!< no bias
BROADCAST_CHANNEL_BIAS, //!< broadcast channel bias, [1, c, 1, 1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册