Unverified · commit 17f00635 · authored by: H huzhiqiang · committed by: GitHub

[Framework] Update InferShape period (#4146)

Parent commit: dfdfa644
......@@ -83,6 +83,21 @@ struct CalibParam : ParamBase {
const lite::Tensor* input{};
lite::Tensor* output{};
float scale;
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() override {
  // Return the cached input list if it was already built.
  if (input_tensor_ptrs_cache_) {
    return input_tensor_ptrs_cache_.get();
  }
  // First call: lazily materialize the single-element input list.
  input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>{input});
  return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
std::vector<Tensor*>* output_tensor_ptrs() override {
  // Return the cached output list if it was already built.
  if (output_tensor_ptrs_cache_) {
    return output_tensor_ptrs_cache_.get();
  }
  // First call: lazily materialize the single-element output list.
  output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>{output});
  return output_tensor_ptrs_cache_.get();
}
};
struct SubgraphParam : ParamBase {
......@@ -364,6 +379,22 @@ struct ActivationParam : ParamBase {
float relu_threshold{1.0f};
// elu
float Elu_alpha{1.0f};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() override {
  // Return the cached input list if it was already built.
  if (input_tensor_ptrs_cache_) {
    return input_tensor_ptrs_cache_.get();
  }
  // First call: lazily materialize the single-element input list (X).
  input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>{X});
  return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
std::vector<Tensor*>* output_tensor_ptrs() override {
  // Return the cached output list if it was already built.
  if (output_tensor_ptrs_cache_) {
    return output_tensor_ptrs_cache_.get();
  }
  // First call: lazily materialize the single-element output list (Out).
  output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>{Out});
  return output_tensor_ptrs_cache_.get();
}
};
struct ActivationGradParam : ParamBase {
......@@ -800,6 +831,23 @@ struct BoxCoderParam : ParamBase {
bool box_normalized{true};
int axis{0};
std::vector<float> variance{};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() override {
  // Return the cached input list if it was already built.
  if (input_tensor_ptrs_cache_) {
    return input_tensor_ptrs_cache_.get();
  }
  // First call: lazily materialize the three-element input list.
  input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>{
      prior_box, prior_box_var, target_box});
  return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
std::vector<Tensor*>* output_tensor_ptrs() override {
  // Return the cached output list if it was already built.
  if (output_tensor_ptrs_cache_) {
    return output_tensor_ptrs_cache_.get();
  }
  // First call: lazily materialize the single-element output list.
  output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>{proposals});
  return output_tensor_ptrs_cache_.get();
}
};
/// ----------------------- multiclass_nms operators ----------------------
......@@ -839,6 +887,23 @@ struct PriorBoxParam : ParamBase {
// priortype: prior_min, prior_max, prior_com
std::vector<std::string> order;
bool min_max_aspect_ratios_order{false};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() override {
  // Return the cached input list if it was already built.
  if (input_tensor_ptrs_cache_) {
    return input_tensor_ptrs_cache_.get();
  }
  // First call: lazily materialize the two-element input list.
  input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>{input, image});
  return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
std::vector<Tensor*>* output_tensor_ptrs() override {
  // Return the cached output list if it was already built.
  if (output_tensor_ptrs_cache_) {
    return output_tensor_ptrs_cache_.get();
  }
  // First call: lazily materialize the two-element output list.
  output_tensor_ptrs_cache_.reset(
      new std::vector<lite::Tensor*>{boxes, variances});
  return output_tensor_ptrs_cache_.get();
}
};
struct DensityPriorBoxParam : public PriorBoxParam {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register