diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index b3bbd648ed612cc9d835e6550261311bf02cb8fa..93986a19031307ea585f401b747f524b180faccb 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -83,6 +83,21 @@ struct CalibParam : ParamBase { const lite::Tensor* input{}; lite::Tensor* output{}; float scale; + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { + input_tensor_ptrs_cache_.reset(new std::vector({input})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { + output_tensor_ptrs_cache_.reset(new std::vector({output})); + } + return output_tensor_ptrs_cache_.get(); + } }; struct SubgraphParam : ParamBase { @@ -364,6 +379,22 @@ struct ActivationParam : ParamBase { float relu_threshold{1.0f}; // elu float Elu_alpha{1.0f}; + + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { + input_tensor_ptrs_cache_.reset(new std::vector({X})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { + output_tensor_ptrs_cache_.reset(new std::vector({Out})); + } + return output_tensor_ptrs_cache_.get(); + } }; struct ActivationGradParam : ParamBase { @@ -800,6 +831,23 @@ struct BoxCoderParam : ParamBase { bool box_normalized{true}; int axis{0}; std::vector variance{}; + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { + input_tensor_ptrs_cache_.reset(new std::vector( + {prior_box, prior_box_var, target_box})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { + output_tensor_ptrs_cache_.reset( + new std::vector({proposals})); + } + return output_tensor_ptrs_cache_.get(); + } }; /// ----------------------- multiclass_nms operators ---------------------- @@ -839,6 +887,23 @@ struct PriorBoxParam : ParamBase { // priortype: prior_min, prior_max, prior_com std::vector order; bool min_max_aspect_ratios_order{false}; + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { + input_tensor_ptrs_cache_.reset( + new std::vector({input, image})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { + output_tensor_ptrs_cache_.reset( + new std::vector({boxes, variances})); + } + return output_tensor_ptrs_cache_.get(); + } }; struct DensityPriorBoxParam : public PriorBoxParam {