Commit 686245f3 authored by chenjiaoAngel

add reshape infershape, test=develop

Parent 91fe997a
@@ -290,6 +290,21 @@ struct ReshapeParam : ParamBase {
   lite::Tensor* xshape{};
   bool inplace{false};
+  ///////////////////////////////////////////////////////////////////////////////////
+  // get a vector of input tensors
+  const std::vector<const Tensor*>* input_tensor_ptrs() {
+    if (UNLIKELY(input_tensor_ptrs_cache_)) {
+      input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
+    }
+    return input_tensor_ptrs_cache_.get();
+  }
+  // get a vector of output tensors
+  const std::vector<Tensor*>* output_tensor_ptrs() {
+    if (UNLIKELY(output_tensor_ptrs_cache_)) {
+      output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
+    }
+    return output_tensor_ptrs_cache_.get();
+  }
 };
 // For Concat op
@@ -301,7 +316,11 @@ struct ConcatParam : ParamBase {
   // get a vector of input tensors
   const std::vector<const Tensor*>* input_tensor_ptrs() {
     if (UNLIKELY(input_tensor_ptrs_cache_)) {
-      input_tensor_ptrs_cache_.reset(x);
+      std::vector<const Tensor*> vec;
+      for (auto in : x) {
+        vec.push_back(in);
+      }
+      input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>(vec));
     }
     return input_tensor_ptrs_cache_.get();
   }
@@ -425,7 +444,7 @@ struct BatchNormParam : ParamBase {
   // get a vector of output tensors
   const std::vector<Tensor*>* output_tensor_ptrs() {
     if (UNLIKELY(output_tensor_ptrs_cache_)) {
-      output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
+      output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({y}));
     }
     return output_tensor_ptrs_cache_.get();
   }
@@ -518,6 +537,21 @@ struct TransposeParam : ParamBase {
   std::vector<int> axis;
   bool use_mkldnn{false};
   std::string data_format{"AnyLayout"};
+  ///////////////////////////////////////////////////////////////////////////////////
+  // get a vector of input tensors
+  const std::vector<const Tensor*>* input_tensor_ptrs() {
+    if (UNLIKELY(input_tensor_ptrs_cache_)) {
+      input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
+    }
+    return input_tensor_ptrs_cache_.get();
+  }
+  // get a vector of output tensors
+  const std::vector<Tensor*>* output_tensor_ptrs() {
+    if (UNLIKELY(output_tensor_ptrs_cache_)) {
+      output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
+    }
+    return output_tensor_ptrs_cache_.get();
+  }
 };
 /// ----------------------- element wise operators ----------------------
@@ -843,6 +877,21 @@ struct Im2SequenceParam : ParamBase {
 struct SequenceSoftmaxParam : ParamBase {
   const lite::Tensor* X{};
   lite::Tensor* Out{};
+  ///////////////////////////////////////////////////////////////////////////////////
+  // get a vector of input tensors
+  const std::vector<const Tensor*>* input_tensor_ptrs() {
+    if (UNLIKELY(input_tensor_ptrs_cache_)) {
+      input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X}));
+    }
+    return input_tensor_ptrs_cache_.get();
+  }
+  // get a vector of output tensors
+  const std::vector<Tensor*>* output_tensor_ptrs() {
+    if (UNLIKELY(output_tensor_ptrs_cache_)) {
+      output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
+    }
+    return output_tensor_ptrs_cache_.get();
+  }
 };
 struct NormParam : ParamBase {
......
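Each hunk above adds the same accessor pattern: the param struct exposes input_tensor_ptrs() / output_tensor_ptrs() getters that return a cached std::vector of its tensor pointers. Below is a minimal, self-contained sketch of that idea; DemoParam, the bare-bones Tensor stand-in, the unique_ptr cache members, and the lazy !cache_ check are illustrative assumptions for this sketch, not the actual Paddle-Lite ParamBase API.

// Minimal sketch of the cached tensor-pointer getters added above.
// Tensor and DemoParam are simplified stand-ins (assumptions), not the real
// Paddle-Lite types; the real structs derive from ParamBase.
#include <iostream>
#include <memory>
#include <vector>

struct Tensor {};  // stand-in for lite::Tensor

struct DemoParam {
  const Tensor* x{};
  Tensor* output{};

  // Build the vector of input tensor pointers on first use, then reuse it.
  const std::vector<const Tensor*>* input_tensor_ptrs() {
    if (!input_tensor_ptrs_cache_) {
      input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
    }
    return input_tensor_ptrs_cache_.get();
  }

  // Same idea for the output tensor pointers.
  const std::vector<Tensor*>* output_tensor_ptrs() {
    if (!output_tensor_ptrs_cache_) {
      output_tensor_ptrs_cache_.reset(new std::vector<Tensor*>({output}));
    }
    return output_tensor_ptrs_cache_.get();
  }

 private:
  std::unique_ptr<std::vector<const Tensor*>> input_tensor_ptrs_cache_;
  std::unique_ptr<std::vector<Tensor*>> output_tensor_ptrs_cache_;
};

int main() {
  Tensor in, out;
  DemoParam param;
  param.x = &in;
  param.output = &out;
  std::cout << "inputs: " << param.input_tensor_ptrs()->size()
            << " outputs: " << param.output_tensor_ptrs()->size() << std::endl;
  return 0;
}

The point of the cache is that repeated queries of an op's inputs and outputs (for example, from an optimization pass) reuse one vector instead of rebuilding it on every call.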