diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 1e221a602a426f3f117c69b9525f2a1d85880ee0..3fdca389bca1ba09ebfe008365b6992b717270d8 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -158,6 +158,21 @@ struct MulParam : ParamBase {
   int y_num_col_dims{1};
   // for int8
   WITH_INT8_CONFIG
+  ///////////////////////////////////////////////////////////////////////////////////
+  // get a vector of input tensors (lazily built once, then served from cache)
+  const std::vector<const lite::Tensor*>* input_tensor_ptrs() {
+    if (UNLIKELY(!input_tensor_ptrs_cache_)) {
+      input_tensor_ptrs_cache_.reset(new std::vector<const lite::Tensor*>({x, y}));
+    }
+    return input_tensor_ptrs_cache_.get();
+  }
+  // get a vector of output tensors (lazily built once, then served from cache)
+  const std::vector<lite::Tensor*>* output_tensor_ptrs() {
+    if (UNLIKELY(!output_tensor_ptrs_cache_)) {
+      output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
+    }
+    return output_tensor_ptrs_cache_.get();
+  }
 };
 
 struct MulGradParam : ParamBase {
@@ -226,6 +241,21 @@ struct ScaleParam : ParamBase {
   float scale{1.};
   float bias{};
   bool bias_after_scale{true};
+  ///////////////////////////////////////////////////////////////////////////////////
+  // get a vector of input tensors
+  const std::vector<const lite::Tensor*>* input_tensor_ptrs() {
+    if (UNLIKELY(!input_tensor_ptrs_cache_)) {
+      input_tensor_ptrs_cache_.reset(new std::vector<const lite::Tensor*>({x}));
+    }
+    return input_tensor_ptrs_cache_.get();
+  }
+  // get a vector of output tensors
+  const std::vector<lite::Tensor*>* output_tensor_ptrs() {
+    if (UNLIKELY(!output_tensor_ptrs_cache_)) {
+      output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
+    }
+    return output_tensor_ptrs_cache_.get();
+  }
 };
 
 // For Softmax op
@@ -260,6 +290,21 @@ struct ReshapeParam : ParamBase {
   lite::Tensor* xshape{};
   bool inplace{false};
 
+  ///////////////////////////////////////////////////////////////////////////////////
+  // get a vector of input tensors
+  const std::vector<const lite::Tensor*>* input_tensor_ptrs() {
+    if (UNLIKELY(!input_tensor_ptrs_cache_)) {
+      input_tensor_ptrs_cache_.reset(new std::vector<const lite::Tensor*>({x}));
+    }
+    return input_tensor_ptrs_cache_.get();
+  }
+  // get a vector of output tensors
+  const std::vector<lite::Tensor*>* output_tensor_ptrs() {
+    if (UNLIKELY(!output_tensor_ptrs_cache_)) {
+      output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
+    }
+    return output_tensor_ptrs_cache_.get();
+  }
 };
 
 // For Concat op
@@ -268,6 +313,24 @@ struct ConcatParam : ParamBase {
   lite::Tensor* output{};
   int axis{0};
   lite::Tensor* axis_tensor{};
+  // get a vector of input tensors (Concat has a variadic input list `x`)
+  const std::vector<const lite::Tensor*>* input_tensor_ptrs() {
+    if (UNLIKELY(!input_tensor_ptrs_cache_)) {
+      std::vector<const lite::Tensor*> vec;
+      for (auto in : x) {
+        vec.push_back(in);
+      }
+      input_tensor_ptrs_cache_.reset(new std::vector<const lite::Tensor*>(vec));
+    }
+    return input_tensor_ptrs_cache_.get();
+  }
+  // get a vector of output tensors
+  const std::vector<lite::Tensor*>* output_tensor_ptrs() {
+    if (UNLIKELY(!output_tensor_ptrs_cache_)) {
+      output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
+    }
+    return output_tensor_ptrs_cache_.get();
+  }
 };
 
 /// ----------------------- activation operators ----------------------
@@ -370,6 +433,21 @@ struct BatchNormParam : ParamBase {
   float epsilon;
   float momentum;
   DataLayoutType data_layout{DATALAYOUT(kNCHW)};
+  ///////////////////////////////////////////////////////////////////////////////////
+  // get a vector of input tensors
+  const std::vector<const lite::Tensor*>* input_tensor_ptrs() {
+    if (UNLIKELY(!input_tensor_ptrs_cache_)) {
+      input_tensor_ptrs_cache_.reset(new std::vector<const lite::Tensor*>({x}));
+    }
+    return input_tensor_ptrs_cache_.get();
+  }
+  // get a vector of output tensors (BatchNorm writes its result into `y`)
+  const std::vector<lite::Tensor*>* output_tensor_ptrs() {
+    if (UNLIKELY(!output_tensor_ptrs_cache_)) {
+      output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({y}));
+    }
+    return output_tensor_ptrs_cache_.get();
+  }
 };
 
 // For Pooling op
@@ -394,6 +472,21 @@ struct PoolParam : ParamBase {
   std::string data_format{"AnyLayout"};
   // for int8
   WITH_INT8_CONFIG
+  ///////////////////////////////////////////////////////////////////////////////////
+  // get a vector of input tensors
+  const std::vector<const lite::Tensor*>* input_tensor_ptrs() {
+    if (UNLIKELY(!input_tensor_ptrs_cache_)) {
+      input_tensor_ptrs_cache_.reset(new std::vector<const lite::Tensor*>({x}));
+    }
+    return input_tensor_ptrs_cache_.get();
+  }
+  // get a vector of output tensors
+  const std::vector<lite::Tensor*>* output_tensor_ptrs() {
+    if (UNLIKELY(!output_tensor_ptrs_cache_)) {
+      output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
+    }
+    return output_tensor_ptrs_cache_.get();
+  }
 };
 
 // For Dropout op
@@ -418,6 +511,21 @@ struct SplitParam : ParamBase {
   int axis{-1};
   int num{0};
   std::vector<int> sections;
+  ///////////////////////////////////////////////////////////////////////////////////
+  // get a vector of input tensors
+  const std::vector<const lite::Tensor*>* input_tensor_ptrs() {
+    if (UNLIKELY(!input_tensor_ptrs_cache_)) {
+      input_tensor_ptrs_cache_.reset(new std::vector<const lite::Tensor*>({x}));
+    }
+    return input_tensor_ptrs_cache_.get();
+  }
+  // get a vector of output tensors
+  const std::vector<lite::Tensor*>* output_tensor_ptrs() {
+    if (UNLIKELY(!output_tensor_ptrs_cache_)) {
+      output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
+    }
+    return output_tensor_ptrs_cache_.get();
+  }
 };
 
 // For Transpose op
@@ -429,6 +537,21 @@ struct TransposeParam : ParamBase {
   std::vector<int> axis;
   bool use_mkldnn{false};
   std::string data_format{"AnyLayout"};
+  ///////////////////////////////////////////////////////////////////////////////////
+  // get a vector of input tensors
+  const std::vector<const lite::Tensor*>* input_tensor_ptrs() {
+    if (UNLIKELY(!input_tensor_ptrs_cache_)) {
+      input_tensor_ptrs_cache_.reset(new std::vector<const lite::Tensor*>({x}));
+    }
+    return input_tensor_ptrs_cache_.get();
+  }
+  // get a vector of output tensors
+  const std::vector<lite::Tensor*>* output_tensor_ptrs() {
+    if (UNLIKELY(!output_tensor_ptrs_cache_)) {
+      output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
+    }
+    return output_tensor_ptrs_cache_.get();
+  }
 };
 
 /// ----------------------- element wise operators ----------------------
@@ -754,6 +877,21 @@ struct Im2SequenceParam : ParamBase {
 struct SequenceSoftmaxParam : ParamBase {
   const lite::Tensor* X{};
   lite::Tensor* Out{};
+  ///////////////////////////////////////////////////////////////////////////////////
+  // get a vector of input tensors
+  const std::vector<const lite::Tensor*>* input_tensor_ptrs() {
+    if (UNLIKELY(!input_tensor_ptrs_cache_)) {
+      input_tensor_ptrs_cache_.reset(new std::vector<const lite::Tensor*>({X}));
+    }
+    return input_tensor_ptrs_cache_.get();
+  }
+  // get a vector of output tensors
+  const std::vector<lite::Tensor*>* output_tensor_ptrs() {
+    if (UNLIKELY(!output_tensor_ptrs_cache_)) {
+      output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
+    }
+    return output_tensor_ptrs_cache_.get();
+  }
 };
 
 struct NormParam : ParamBase {
@@ -984,6 +1122,21 @@ struct SliceParam : ParamBase {
   std::vector<lite::Tensor*> EndsTensorList{};
   lite::Tensor* StartsTensor{nullptr};
   lite::Tensor* EndsTensor{nullptr};
+  ///////////////////////////////////////////////////////////////////////////////////
+  // get a vector of input tensors
+  const std::vector<const lite::Tensor*>* input_tensor_ptrs() {
+    if (UNLIKELY(!input_tensor_ptrs_cache_)) {
+      input_tensor_ptrs_cache_.reset(new std::vector<const lite::Tensor*>({X}));
+    }
+    return input_tensor_ptrs_cache_.get();
+  }
+  // get a vector of output tensors
+  const std::vector<lite::Tensor*>* output_tensor_ptrs() {
+    if (UNLIKELY(!output_tensor_ptrs_cache_)) {
+      output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
+    }
+    return output_tensor_ptrs_cache_.get();
+  }
 };
 
 struct AffineChannelParam : ParamBase {
@@ -1031,6 +1184,21 @@ struct SqueezeParam : ParamBase {
   lite::Tensor* Out{};
   lite::Tensor* XShape{};
   std::vector<int> axes{};
+  ///////////////////////////////////////////////////////////////////////////////////
+  // get a vector of input tensors
+  const std::vector<const lite::Tensor*>* input_tensor_ptrs() {
+    if (UNLIKELY(!input_tensor_ptrs_cache_)) {
+      input_tensor_ptrs_cache_.reset(new std::vector<const lite::Tensor*>({X}));
+    }
+    return input_tensor_ptrs_cache_.get();
+  }
+  // get a vector of output tensors
+  const std::vector<lite::Tensor*>* output_tensor_ptrs() {
+    if (UNLIKELY(!output_tensor_ptrs_cache_)) {
+      output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
+    }
+    return output_tensor_ptrs_cache_.get();
+  }
 };
 
 struct UnsqueezeParam : ParamBase {
@@ -1040,6 +1208,21 @@ struct UnsqueezeParam : ParamBase {
   std::vector<int> axes{};
   const lite::Tensor* axes_tensor{};
   std::vector<const lite::Tensor*> axes_tensor_vct{};
+  ///////////////////////////////////////////////////////////////////////////////////
+  // get a vector of input tensors
+  const std::vector<const lite::Tensor*>* input_tensor_ptrs() {
+    if (UNLIKELY(!input_tensor_ptrs_cache_)) {
+      input_tensor_ptrs_cache_.reset(new std::vector<const lite::Tensor*>({X}));
+    }
+    return input_tensor_ptrs_cache_.get();
+  }
+  // get a vector of output tensors
+  const std::vector<lite::Tensor*>* output_tensor_ptrs() {
+    if (UNLIKELY(!output_tensor_ptrs_cache_)) {
+      output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
+    }
+    return output_tensor_ptrs_cache_.get();
+  }
 };
 
 /// ----------------------- expand operators ----------------------
@@ -1057,6 +1240,21 @@ struct MatMulParam : ParamBase {
   bool transpose_X{false};
   bool transpose_Y{false};
   float alpha{1.0f};
+  ///////////////////////////////////////////////////////////////////////////////////
+  // get a vector of input tensors
+  const std::vector<const lite::Tensor*>* input_tensor_ptrs() {
+    if (UNLIKELY(!input_tensor_ptrs_cache_)) {
+      input_tensor_ptrs_cache_.reset(new std::vector<const lite::Tensor*>({X, Y}));
+    }
+    return input_tensor_ptrs_cache_.get();
+  }
+  // get a vector of output tensors
+  const std::vector<lite::Tensor*>* output_tensor_ptrs() {
+    if (UNLIKELY(!output_tensor_ptrs_cache_)) {
+      output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
+    }
+    return output_tensor_ptrs_cache_.get();
+  }
 };
 
 struct GatherParam : ParamBase {