From 17f006359cf04bb7d58d292571aea520fbb77feb Mon Sep 17 00:00:00 2001 From: huzhiqiang <912790387@qq.com> Date: Wed, 19 Aug 2020 11:47:09 +0800 Subject: [PATCH] [Framework] Update InferShape period (#4146) --- lite/operators/op_params.h | 65 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index b3bbd648ed..93986a1903 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -83,6 +83,21 @@ struct CalibParam : ParamBase { const lite::Tensor* input{}; lite::Tensor* output{}; float scale; + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { + input_tensor_ptrs_cache_.reset(new std::vector({input})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { + output_tensor_ptrs_cache_.reset(new std::vector({output})); + } + return output_tensor_ptrs_cache_.get(); + } }; struct SubgraphParam : ParamBase { @@ -364,6 +379,22 @@ struct ActivationParam : ParamBase { float relu_threshold{1.0f}; // elu float Elu_alpha{1.0f}; + + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { + input_tensor_ptrs_cache_.reset(new std::vector({X})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { + output_tensor_ptrs_cache_.reset(new std::vector({Out})); + } + return output_tensor_ptrs_cache_.get(); + } }; struct ActivationGradParam : ParamBase { @@ -800,6 +831,23 @@ struct BoxCoderParam : ParamBase { bool box_normalized{true}; int axis{0}; std::vector variance{}; + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { + input_tensor_ptrs_cache_.reset(new std::vector( + {prior_box, prior_box_var, target_box})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { + output_tensor_ptrs_cache_.reset( + new std::vector({proposals})); + } + return output_tensor_ptrs_cache_.get(); + } }; /// ----------------------- multiclass_nms operators ---------------------- @@ -839,6 +887,23 @@ struct PriorBoxParam : ParamBase { // priortype: prior_min, prior_max, prior_com std::vector order; bool min_max_aspect_ratios_order{false}; + /////////////////////////////////////////////////////////////////////////////////// + // get a vector of input tensors + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { + input_tensor_ptrs_cache_.reset( + new std::vector({input, image})); + } + return input_tensor_ptrs_cache_.get(); + } + // get a vector of output tensors + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { + output_tensor_ptrs_cache_.reset( + new std::vector({boxes, variances})); + } + return output_tensor_ptrs_cache_.get(); + } }; struct DensityPriorBoxParam : public PriorBoxParam { -- GitLab