Unverified commit 4b0d60e7, authored by H HappyAngel, committed by GitHub

[arm] improve infer_shape profile (#3308)

* fix format, test=develop

* add some op infershape implement, test=develop
Parent 24b270d7
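
The pattern this commit adds to each param struct below is a lazily built cache of tensor pointers: the first call to input_tensor_ptrs() / output_tensor_ptrs() allocates the vector, and every later call reuses it, giving the infer_shape profiler a uniform way to walk an op's inputs and outputs without per-op glue. Here is a minimal self-contained sketch of the idea, assuming the std::unique_ptr cache members implied by the reset()/get() calls in the diff; the Tensor stub, UNLIKELY definition, and ExampleParam struct are illustrative stand-ins, not code from the commit:

#include <memory>
#include <vector>

struct Tensor {};  // stand-in for lite::Tensor
#define UNLIKELY(x) __builtin_expect(!!(x), 0)  // branch hint, as used in the diff

struct ParamBase {
  // Assumed unique_ptr members, matching the reset()/get() calls below:
  // built on first use, then reused for the lifetime of the param.
  std::unique_ptr<std::vector<const Tensor*>> input_tensor_ptrs_cache_;
  std::unique_ptr<std::vector<Tensor*>> output_tensor_ptrs_cache_;
};

struct ExampleParam : ParamBase {  // hypothetical single-input, single-output op
  const Tensor* x{};
  Tensor* output{};
  const std::vector<const Tensor*>* input_tensor_ptrs() {
    if (UNLIKELY(!input_tensor_ptrs_cache_)) {  // build only when the cache is empty
      input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
    }
    return input_tensor_ptrs_cache_.get();
  }
  const std::vector<Tensor*>* output_tensor_ptrs() {
    if (UNLIKELY(!output_tensor_ptrs_cache_)) {
      output_tensor_ptrs_cache_.reset(new std::vector<Tensor*>({output}));
    }
    return output_tensor_ptrs_cache_.get();
  }
};

After warm-up the cache is always populated, so the build branch is cold (hence the UNLIKELY hint), and each vector is allocated once per param instance rather than on every profiling pass.
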
@@ -158,6 +158,21 @@ struct MulParam : ParamBase {
int y_num_col_dims{1};
// for int8
WITH_INT8_CONFIG
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(!input_tensor_ptrs_cache_)) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x, y}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(!output_tensor_ptrs_cache_)) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
}
return output_tensor_ptrs_cache_.get();
}
};
struct MulGradParam : ParamBase {
@@ -226,6 +241,21 @@ struct ScaleParam : ParamBase {
float scale{1.};
float bias{};
bool bias_after_scale{true};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(!input_tensor_ptrs_cache_)) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(!output_tensor_ptrs_cache_)) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
}
return output_tensor_ptrs_cache_.get();
}
};
// For Softmax op
@@ -260,6 +290,21 @@ struct ReshapeParam : ParamBase {
lite::Tensor* xshape{};
bool inplace{false};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(!input_tensor_ptrs_cache_)) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(!output_tensor_ptrs_cache_)) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
}
return output_tensor_ptrs_cache_.get();
}
};
// For Concat op
@@ -268,6 +313,24 @@ struct ConcatParam : ParamBase {
lite::Tensor* output{};
int axis{0};
lite::Tensor* axis_tensor{};
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(!input_tensor_ptrs_cache_)) {
std::vector<const Tensor*> vec;
for (auto in : x) {
vec.push_back(in);
}
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>(vec));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(!output_tensor_ptrs_cache_)) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
}
return output_tensor_ptrs_cache_.get();
}
};
/// ----------------------- activation operators ----------------------
@@ -370,6 +433,21 @@ struct BatchNormParam : ParamBase {
float epsilon;
float momentum;
DataLayoutType data_layout{DATALAYOUT(kNCHW)};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(!input_tensor_ptrs_cache_)) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(!output_tensor_ptrs_cache_)) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({y}));
}
return output_tensor_ptrs_cache_.get();
}
};
// For Pooling op
@@ -394,6 +472,21 @@ struct PoolParam : ParamBase {
std::string data_format{"AnyLayout"};
// for int8
WITH_INT8_CONFIG
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(!input_tensor_ptrs_cache_)) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(!output_tensor_ptrs_cache_)) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
}
return output_tensor_ptrs_cache_.get();
}
};
// For Dropout op
@@ -418,6 +511,21 @@ struct SplitParam : ParamBase {
int axis{-1};
int num{0};
std::vector<int> sections;
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(!input_tensor_ptrs_cache_)) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(!output_tensor_ptrs_cache_)) {
std::vector<lite::Tensor*> vec;
for (auto out : output) {
vec.push_back(out);
}
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>(vec));
}
return output_tensor_ptrs_cache_.get();
}
};
// For Transpose op
@@ -429,6 +537,21 @@ struct TransposeParam : ParamBase {
std::vector<int> axis;
bool use_mkldnn{false};
std::string data_format{"AnyLayout"};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(!input_tensor_ptrs_cache_)) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(!output_tensor_ptrs_cache_)) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
}
return output_tensor_ptrs_cache_.get();
}
};
/// ----------------------- element wise operators ----------------------
@@ -754,6 +877,21 @@ struct Im2SequenceParam : ParamBase {
struct SequenceSoftmaxParam : ParamBase {
const lite::Tensor* X{};
lite::Tensor* Out{};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(!input_tensor_ptrs_cache_)) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(!output_tensor_ptrs_cache_)) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
}
return output_tensor_ptrs_cache_.get();
}
};
struct NormParam : ParamBase {
@@ -984,6 +1122,21 @@ struct SliceParam : ParamBase {
std::vector<lite::Tensor*> EndsTensorList{};
lite::Tensor* StartsTensor{nullptr};
lite::Tensor* EndsTensor{nullptr};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(!input_tensor_ptrs_cache_)) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(!output_tensor_ptrs_cache_)) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
}
return output_tensor_ptrs_cache_.get();
}
};
struct AffineChannelParam : ParamBase {
@@ -1031,6 +1184,21 @@ struct SqueezeParam : ParamBase {
lite::Tensor* Out{};
lite::Tensor* XShape{};
std::vector<int> axes{};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(!input_tensor_ptrs_cache_)) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(!output_tensor_ptrs_cache_)) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
}
return output_tensor_ptrs_cache_.get();
}
};
struct UnsqueezeParam : ParamBase {
@@ -1040,6 +1208,21 @@ struct UnsqueezeParam : ParamBase {
std::vector<int> axes{};
const lite::Tensor* axes_tensor{};
std::vector<const lite::Tensor*> axes_tensor_vct{};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(!input_tensor_ptrs_cache_)) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(!output_tensor_ptrs_cache_)) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
}
return output_tensor_ptrs_cache_.get();
}
};
/// ----------------------- expand operators ----------------------
@@ -1057,6 +1240,21 @@ struct MatMulParam : ParamBase {
bool transpose_X{false};
bool transpose_Y{false};
float alpha{1.0f};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(!input_tensor_ptrs_cache_)) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X, Y}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(!output_tensor_ptrs_cache_)) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
}
return output_tensor_ptrs_cache_.get();
}
};
struct GatherParam : ParamBase {
......