未验证 提交 2168ff38 编写于 作者: H huzhiqiang 提交者: GitHub

[Framework][InferShape] accelerate op infer_shape period (#3434)

上级 212c3227
...@@ -25,16 +25,16 @@ namespace lite { ...@@ -25,16 +25,16 @@ namespace lite {
bool OpLite::InferShape() { bool OpLite::InferShape() {
// if input_tensor_ptrs and output_tensor_ptrs are overloaded in param_ // if input_tensor_ptrs and output_tensor_ptrs are overloaded in param_
// InferShapeByMemoryInternal will be applied. // InferShapeByMemoryInternal will be applied.
if (param_.input_tensor_ptrs() && param_.output_tensor_ptrs()) { if (op_param_ && op_param_->input_tensor_ptrs() &&
op_param_->output_tensor_ptrs()) {
return this->InferShapeWithCache(); return this->InferShapeWithCache();
} else { } else {
// otherwise, InferShapeImpl is applied directly.
return this->InferShapeImpl(); return this->InferShapeImpl();
} }
} }
bool OpLite::InferShapeWithCache() { bool OpLite::InferShapeWithCache() {
// 1. Get vector of current input tensors // 1. Get vector of current input tensors
auto *current_inputs = param_.input_tensor_ptrs(); auto *current_inputs = op_param_->input_tensor_ptrs();
// 2. Get hash value of current inputs shape and lod // 2. Get hash value of current inputs shape and lod
size_t new_hash = 0; size_t new_hash = 0;
for (auto iter = current_inputs->begin(); iter != current_inputs->end(); for (auto iter = current_inputs->begin(); iter != current_inputs->end();
...@@ -59,7 +59,7 @@ bool OpLite::InferShapeWithCache() { ...@@ -59,7 +59,7 @@ bool OpLite::InferShapeWithCache() {
if (new_hash == io_shape_lod_hash_ && new_hash != 0) { if (new_hash == io_shape_lod_hash_ && new_hash != 0) {
// if current hash value is consistent with io_shape_lod_hash_, // if current hash value is consistent with io_shape_lod_hash_,
// previous outputs shape and lod are reused. // previous outputs shape and lod are reused.
auto *current_outputs = param_.output_tensor_ptrs(); auto *current_outputs = op_param_->output_tensor_ptrs();
for (size_t i = 0; i < current_outputs->size(); i++) { for (size_t i = 0; i < current_outputs->size(); i++) {
current_outputs->at(i)->Resize(last_output_shapes[i]); current_outputs->at(i)->Resize(last_output_shapes[i]);
current_outputs->at(i)->set_lod(last_output_lods[i]); current_outputs->at(i)->set_lod(last_output_lods[i]);
...@@ -68,10 +68,12 @@ bool OpLite::InferShapeWithCache() { ...@@ -68,10 +68,12 @@ bool OpLite::InferShapeWithCache() {
// otherwise, current hash value is changed, InferShapeImpl will apply. // otherwise, current hash value is changed, InferShapeImpl will apply.
io_shape_lod_hash_ = new_hash; io_shape_lod_hash_ = new_hash;
this->InferShapeImpl(); this->InferShapeImpl();
auto *current_outputs = param_.output_tensor_ptrs(); auto *current_outputs = op_param_->output_tensor_ptrs();
last_output_shapes.clear();
last_output_lods.clear();
for (size_t i = 0; i < current_outputs->size(); i++) { for (size_t i = 0; i < current_outputs->size(); i++) {
last_output_shapes[i] = current_outputs->at(i)->dims(); last_output_shapes.push_back(current_outputs->at(i)->dims());
last_output_lods[i] = current_outputs->at(i)->lod(); last_output_lods.push_back(current_outputs->at(i)->lod());
} }
} }
return true; return true;
......
...@@ -77,6 +77,11 @@ class OpLite : public Registry { ...@@ -77,6 +77,11 @@ class OpLite : public Registry {
// Link the external execution environ to internal context. // Link the external execution environ to internal context.
bool Attach(const cpp::OpDesc &opdesc, lite::Scope *scope); bool Attach(const cpp::OpDesc &opdesc, lite::Scope *scope);
template <typename T>
inline void AttachParam(T *param) {
op_param_ = static_cast<T *>(param);
}
const OpInfo *op_info() const { return op_info_.get(); } const OpInfo *op_info() const { return op_info_.get(); }
OpInfo *mutable_op_info() { return op_info_.get(); } OpInfo *mutable_op_info() { return op_info_.get(); }
...@@ -167,11 +172,10 @@ class OpLite : public Registry { ...@@ -167,11 +172,10 @@ class OpLite : public Registry {
std::vector<Place> valid_places_; std::vector<Place> valid_places_;
Place kernel_place_{TARGET(kHost), PRECISION(kFloat)}; Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
std::unique_ptr<OpInfo> op_info_; std::unique_ptr<OpInfo> op_info_;
std::vector<DDimLite> last_output_shapes{}; std::vector<DDimLite> last_output_shapes{};
std::vector<std::vector<std::vector<uint64_t>>> last_output_lods{}; std::vector<std::vector<std::vector<uint64_t>>> last_output_lods{};
size_t io_shape_lod_hash_{}; size_t io_shape_lod_hash_{};
mutable operators::ParamBase param_; mutable operators::ParamBase *op_param_{nullptr};
private: private:
// Infer Shape according to memory, if current input shapes are consistent // Infer Shape according to memory, if current input shapes are consistent
......
...@@ -73,6 +73,7 @@ bool BatchNormOp::InferShapeImpl() const { ...@@ -73,6 +73,7 @@ bool BatchNormOp::InferShapeImpl() const {
} }
bool BatchNormOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { bool BatchNormOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
AttachParam(&param_);
param_.x = scope->FindVar(op_desc.Input("X").front())->GetMutable<Tensor>(); param_.x = scope->FindVar(op_desc.Input("X").front())->GetMutable<Tensor>();
param_.bias = param_.bias =
scope->FindVar(op_desc.Input("Bias").front())->GetMutable<Tensor>(); scope->FindVar(op_desc.Input("Bias").front())->GetMutable<Tensor>();
......
...@@ -66,6 +66,7 @@ bool ConcatOpLite::InferShapeImpl() const { ...@@ -66,6 +66,7 @@ bool ConcatOpLite::InferShapeImpl() const {
// TODO(Superjomn) replace framework::OpDesc with a lite one. // TODO(Superjomn) replace framework::OpDesc with a lite one.
bool ConcatOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { bool ConcatOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
AttachParam(&param_);
auto inputs = op_desc.Input("X"); auto inputs = op_desc.Input("X");
auto out = op_desc.Output("Out").front(); auto out = op_desc.Output("Out").front();
......
...@@ -38,6 +38,7 @@ class ConvOpLite : public OpLite { ...@@ -38,6 +38,7 @@ class ConvOpLite : public OpLite {
// TODO(Superjomn) replace framework::OpDesc with a lite one. // TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override { bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override {
AttachParam(&param_);
auto X = op_desc.Input("Input").front(); auto X = op_desc.Input("Input").front();
auto Filter = op_desc.Input("Filter").front(); auto Filter = op_desc.Input("Filter").front();
auto Out = op_desc.Output("Output").front(); auto Out = op_desc.Output("Output").front();
......
...@@ -87,6 +87,8 @@ bool ElementwiseOp::InferShapeImpl() const { ...@@ -87,6 +87,8 @@ bool ElementwiseOp::InferShapeImpl() const {
} }
bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
AttachParam(&param_);
auto X_name = opdesc.Input("X").front(); auto X_name = opdesc.Input("X").front();
auto Y_name = opdesc.Input("Y").front(); auto Y_name = opdesc.Input("Y").front();
auto Out_name = opdesc.Output("Out").front(); auto Out_name = opdesc.Output("Out").front();
......
...@@ -69,6 +69,8 @@ bool FcOpLite::InferShapeImpl() const { ...@@ -69,6 +69,8 @@ bool FcOpLite::InferShapeImpl() const {
} }
bool FcOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { bool FcOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
AttachParam(&param_);
auto input = op_desc.Input("Input").front(); auto input = op_desc.Input("Input").front();
auto W = op_desc.Input("W").front(); auto W = op_desc.Input("W").front();
auto out = op_desc.Output("Out").front(); auto out = op_desc.Output("Out").front();
......
...@@ -132,6 +132,7 @@ bool MatMulOpLite::InferShapeImpl() const { ...@@ -132,6 +132,7 @@ bool MatMulOpLite::InferShapeImpl() const {
} }
bool MatMulOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { bool MatMulOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
AttachParam(&param_);
CHECK(!op_desc.Input("X").empty()); CHECK(!op_desc.Input("X").empty());
CHECK(!op_desc.Input("Y").empty()); CHECK(!op_desc.Input("Y").empty());
CHECK(!op_desc.Output("Out").empty()); CHECK(!op_desc.Output("Out").empty());
......
...@@ -38,6 +38,8 @@ class MulOpLite : public OpLite { ...@@ -38,6 +38,8 @@ class MulOpLite : public OpLite {
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
// TODO(Superjomn) replace framework::OpDesc with a lite one. // TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
AttachParam(&param_);
CHECK(!op_desc.Input("X").empty()); CHECK(!op_desc.Input("X").empty());
CHECK(!op_desc.Input("Y").empty()); CHECK(!op_desc.Input("Y").empty());
CHECK(!op_desc.Output("Out").empty()); CHECK(!op_desc.Output("Out").empty());
...@@ -56,7 +58,6 @@ class MulOpLite : public OpLite { ...@@ -56,7 +58,6 @@ class MulOpLite : public OpLite {
param_.output = var->GetMutable<Tensor>(); param_.output = var->GetMutable<Tensor>();
param_.x_num_col_dims = op_desc.GetAttr<int>("x_num_col_dims"); param_.x_num_col_dims = op_desc.GetAttr<int>("x_num_col_dims");
param_.y_num_col_dims = op_desc.GetAttr<int>("y_num_col_dims"); param_.y_num_col_dims = op_desc.GetAttr<int>("y_num_col_dims");
return true; return true;
} }
......
...@@ -35,8 +35,11 @@ namespace operators { ...@@ -35,8 +35,11 @@ namespace operators {
struct ParamBase { struct ParamBase {
public: public:
const std::vector<Tensor*>* input_tensor_ptrs() const { return nullptr; } virtual ~ParamBase() {}
std::vector<Tensor*>* output_tensor_ptrs() { return nullptr; } virtual const std::vector<const Tensor*>* input_tensor_ptrs() {
return nullptr;
}
virtual std::vector<Tensor*>* output_tensor_ptrs() { return nullptr; }
protected: protected:
std::shared_ptr<std::vector<const Tensor*>> input_tensor_ptrs_cache_{nullptr}; std::shared_ptr<std::vector<const Tensor*>> input_tensor_ptrs_cache_{nullptr};
...@@ -108,15 +111,15 @@ struct FcParam : ParamBase { ...@@ -108,15 +111,15 @@ struct FcParam : ParamBase {
WITH_INT8_CONFIG WITH_INT8_CONFIG
/////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors // get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() { const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (UNLIKELY(input_tensor_ptrs_cache_)) { if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({input})); input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({input}));
} }
return input_tensor_ptrs_cache_.get(); return input_tensor_ptrs_cache_.get();
} }
// get a vector of output tensors // get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() { std::vector<Tensor*>* output_tensor_ptrs() override {
if (UNLIKELY(output_tensor_ptrs_cache_)) { if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output})); output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
} }
return output_tensor_ptrs_cache_.get(); return output_tensor_ptrs_cache_.get();
...@@ -160,15 +163,15 @@ struct MulParam : ParamBase { ...@@ -160,15 +163,15 @@ struct MulParam : ParamBase {
WITH_INT8_CONFIG WITH_INT8_CONFIG
/////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors // get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() { const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (UNLIKELY(input_tensor_ptrs_cache_)) { if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x, y})); input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x, y}));
} }
return input_tensor_ptrs_cache_.get(); return input_tensor_ptrs_cache_.get();
} }
// get a vector of output tensors // get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() { std::vector<Tensor*>* output_tensor_ptrs() override {
if (UNLIKELY(output_tensor_ptrs_cache_)) { if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output})); output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
} }
return output_tensor_ptrs_cache_.get(); return output_tensor_ptrs_cache_.get();
...@@ -243,15 +246,15 @@ struct ScaleParam : ParamBase { ...@@ -243,15 +246,15 @@ struct ScaleParam : ParamBase {
bool bias_after_scale{true}; bool bias_after_scale{true};
/////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors // get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() { const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (UNLIKELY(input_tensor_ptrs_cache_)) { if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x})); input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
} }
return input_tensor_ptrs_cache_.get(); return input_tensor_ptrs_cache_.get();
} }
// get a vector of output tensors // get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() { std::vector<Tensor*>* output_tensor_ptrs() override {
if (UNLIKELY(output_tensor_ptrs_cache_)) { if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output})); output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
} }
return output_tensor_ptrs_cache_.get(); return output_tensor_ptrs_cache_.get();
...@@ -265,15 +268,15 @@ struct SoftmaxParam : ParamBase { ...@@ -265,15 +268,15 @@ struct SoftmaxParam : ParamBase {
int axis{-1}; int axis{-1};
/////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors // get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() { const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (UNLIKELY(input_tensor_ptrs_cache_)) { if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x})); input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
} }
return input_tensor_ptrs_cache_.get(); return input_tensor_ptrs_cache_.get();
} }
// get a vector of output tensors // get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() { std::vector<Tensor*>* output_tensor_ptrs() override {
if (UNLIKELY(output_tensor_ptrs_cache_)) { if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output})); output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
} }
return output_tensor_ptrs_cache_.get(); return output_tensor_ptrs_cache_.get();
...@@ -292,15 +295,15 @@ struct ReshapeParam : ParamBase { ...@@ -292,15 +295,15 @@ struct ReshapeParam : ParamBase {
bool inplace{false}; bool inplace{false};
/////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors // get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() { const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (UNLIKELY(input_tensor_ptrs_cache_)) { if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x})); input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
} }
return input_tensor_ptrs_cache_.get(); return input_tensor_ptrs_cache_.get();
} }
// get a vector of output tensors // get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() { std::vector<Tensor*>* output_tensor_ptrs() override {
if (UNLIKELY(output_tensor_ptrs_cache_)) { if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output})); output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
} }
return output_tensor_ptrs_cache_.get(); return output_tensor_ptrs_cache_.get();
...@@ -314,8 +317,8 @@ struct ConcatParam : ParamBase { ...@@ -314,8 +317,8 @@ struct ConcatParam : ParamBase {
int axis{0}; int axis{0};
lite::Tensor* axis_tensor{}; lite::Tensor* axis_tensor{};
// get a vector of input tensors // get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() { const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (UNLIKELY(input_tensor_ptrs_cache_)) { if (!input_tensor_ptrs_cache_) {
std::vector<const Tensor*> vec; std::vector<const Tensor*> vec;
for (auto in : x) { for (auto in : x) {
vec.push_back(in); vec.push_back(in);
...@@ -325,8 +328,8 @@ struct ConcatParam : ParamBase { ...@@ -325,8 +328,8 @@ struct ConcatParam : ParamBase {
return input_tensor_ptrs_cache_.get(); return input_tensor_ptrs_cache_.get();
} }
// get a vector of output tensors // get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() { std::vector<Tensor*>* output_tensor_ptrs() override {
if (UNLIKELY(output_tensor_ptrs_cache_)) { if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output})); output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
} }
return output_tensor_ptrs_cache_.get(); return output_tensor_ptrs_cache_.get();
...@@ -406,15 +409,15 @@ struct ConvParam : ParamBase { ...@@ -406,15 +409,15 @@ struct ConvParam : ParamBase {
/////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors // get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() { const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (UNLIKELY(input_tensor_ptrs_cache_)) { if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x})); input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
} }
return input_tensor_ptrs_cache_.get(); return input_tensor_ptrs_cache_.get();
} }
// get a vector of output tensors // get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() { std::vector<Tensor*>* output_tensor_ptrs() override {
if (UNLIKELY(output_tensor_ptrs_cache_)) { if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output})); output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
} }
return output_tensor_ptrs_cache_.get(); return output_tensor_ptrs_cache_.get();
...@@ -440,15 +443,15 @@ struct BatchNormParam : ParamBase { ...@@ -440,15 +443,15 @@ struct BatchNormParam : ParamBase {
DataLayoutType data_layout{DATALAYOUT(kNCHW)}; DataLayoutType data_layout{DATALAYOUT(kNCHW)};
/////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors // get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() { const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (UNLIKELY(input_tensor_ptrs_cache_)) { if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x})); input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
} }
return input_tensor_ptrs_cache_.get(); return input_tensor_ptrs_cache_.get();
} }
// get a vector of output tensors // get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() { std::vector<Tensor*>* output_tensor_ptrs() override {
if (UNLIKELY(output_tensor_ptrs_cache_)) { if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({y})); output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({y}));
} }
return output_tensor_ptrs_cache_.get(); return output_tensor_ptrs_cache_.get();
...@@ -479,15 +482,15 @@ struct PoolParam : ParamBase { ...@@ -479,15 +482,15 @@ struct PoolParam : ParamBase {
WITH_INT8_CONFIG WITH_INT8_CONFIG
/////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors // get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() { const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (UNLIKELY(input_tensor_ptrs_cache_)) { if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x})); input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
} }
return input_tensor_ptrs_cache_.get(); return input_tensor_ptrs_cache_.get();
} }
// get a vector of output tensors // get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() { std::vector<Tensor*>* output_tensor_ptrs() override {
if (UNLIKELY(output_tensor_ptrs_cache_)) { if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output})); output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
} }
return output_tensor_ptrs_cache_.get(); return output_tensor_ptrs_cache_.get();
...@@ -518,15 +521,15 @@ struct SplitParam : ParamBase { ...@@ -518,15 +521,15 @@ struct SplitParam : ParamBase {
std::vector<int> sections; std::vector<int> sections;
/////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors // get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() { const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (UNLIKELY(input_tensor_ptrs_cache_)) { if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x})); input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
} }
return input_tensor_ptrs_cache_.get(); return input_tensor_ptrs_cache_.get();
} }
// get a vector of output tensors // get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() { std::vector<Tensor*>* output_tensor_ptrs() override {
if (UNLIKELY(output_tensor_ptrs_cache_)) { if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output})); output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
} }
return output_tensor_ptrs_cache_.get(); return output_tensor_ptrs_cache_.get();
...@@ -544,15 +547,15 @@ struct TransposeParam : ParamBase { ...@@ -544,15 +547,15 @@ struct TransposeParam : ParamBase {
std::string data_format{"AnyLayout"}; std::string data_format{"AnyLayout"};
/////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////
// // get a vector of input tensors // // get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() { const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (UNLIKELY(input_tensor_ptrs_cache_)) { if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x})); input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
} }
return input_tensor_ptrs_cache_.get(); return input_tensor_ptrs_cache_.get();
} }
// get a vector of output tensors // get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() { std::vector<Tensor*>* output_tensor_ptrs() override {
if (UNLIKELY(output_tensor_ptrs_cache_)) { if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output})); output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
} }
return output_tensor_ptrs_cache_.get(); return output_tensor_ptrs_cache_.get();
...@@ -571,15 +574,15 @@ struct ElementwiseParam : ParamBase { ...@@ -571,15 +574,15 @@ struct ElementwiseParam : ParamBase {
float y_input_scale{1.0}; float y_input_scale{1.0};
/////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors // get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() { const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (UNLIKELY(input_tensor_ptrs_cache_)) { if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X, Y})); input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X, Y}));
} }
return input_tensor_ptrs_cache_.get(); return input_tensor_ptrs_cache_.get();
} }
// get a vector of output tensors // get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() { std::vector<Tensor*>* output_tensor_ptrs() override {
if (UNLIKELY(output_tensor_ptrs_cache_)) { if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out})); output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
} }
return output_tensor_ptrs_cache_.get(); return output_tensor_ptrs_cache_.get();
...@@ -884,15 +887,15 @@ struct SequenceSoftmaxParam : ParamBase { ...@@ -884,15 +887,15 @@ struct SequenceSoftmaxParam : ParamBase {
lite::Tensor* Out{}; lite::Tensor* Out{};
/////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////
// // get a vector of input tensors // // get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() { const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (UNLIKELY(input_tensor_ptrs_cache_)) { if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X})); input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X}));
} }
return input_tensor_ptrs_cache_.get(); return input_tensor_ptrs_cache_.get();
} }
// get a vector of output tensors // get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() { std::vector<Tensor*>* output_tensor_ptrs() override {
if (UNLIKELY(output_tensor_ptrs_cache_)) { if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out})); output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
} }
return output_tensor_ptrs_cache_.get(); return output_tensor_ptrs_cache_.get();
...@@ -1135,15 +1138,15 @@ struct SliceParam : ParamBase { ...@@ -1135,15 +1138,15 @@ struct SliceParam : ParamBase {
lite::Tensor* EndsTensor{nullptr}; lite::Tensor* EndsTensor{nullptr};
/////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors // get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() { const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (UNLIKELY(input_tensor_ptrs_cache_)) { if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X})); input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X}));
} }
return input_tensor_ptrs_cache_.get(); return input_tensor_ptrs_cache_.get();
} }
// get a vector of output tensors // get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() { std::vector<Tensor*>* output_tensor_ptrs() override {
if (UNLIKELY(output_tensor_ptrs_cache_)) { if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out})); output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
} }
return output_tensor_ptrs_cache_.get(); return output_tensor_ptrs_cache_.get();
...@@ -1197,15 +1200,15 @@ struct SqueezeParam : ParamBase { ...@@ -1197,15 +1200,15 @@ struct SqueezeParam : ParamBase {
std::vector<int> axes{}; std::vector<int> axes{};
/////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors // get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() { const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (UNLIKELY(input_tensor_ptrs_cache_)) { if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X})); input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X}));
} }
return input_tensor_ptrs_cache_.get(); return input_tensor_ptrs_cache_.get();
} }
// get a vector of output tensors // get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() { std::vector<Tensor*>* output_tensor_ptrs() override {
if (UNLIKELY(output_tensor_ptrs_cache_)) { if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out})); output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
} }
return output_tensor_ptrs_cache_.get(); return output_tensor_ptrs_cache_.get();
...@@ -1221,15 +1224,15 @@ struct UnsqueezeParam : ParamBase { ...@@ -1221,15 +1224,15 @@ struct UnsqueezeParam : ParamBase {
std::vector<const lite::Tensor*> axes_tensor_vct{}; std::vector<const lite::Tensor*> axes_tensor_vct{};
/////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors // get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() { const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (UNLIKELY(input_tensor_ptrs_cache_)) { if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X})); input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X}));
} }
return input_tensor_ptrs_cache_.get(); return input_tensor_ptrs_cache_.get();
} }
// get a vector of output tensors // get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() { std::vector<Tensor*>* output_tensor_ptrs() override {
if (UNLIKELY(output_tensor_ptrs_cache_)) { if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out})); output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
} }
return output_tensor_ptrs_cache_.get(); return output_tensor_ptrs_cache_.get();
...@@ -1253,15 +1256,15 @@ struct MatMulParam : ParamBase { ...@@ -1253,15 +1256,15 @@ struct MatMulParam : ParamBase {
float alpha{1.0f}; float alpha{1.0f};
/////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors // get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() { const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (UNLIKELY(input_tensor_ptrs_cache_)) { if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X, Y})); input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X, Y}));
} }
return input_tensor_ptrs_cache_.get(); return input_tensor_ptrs_cache_.get();
} }
// get a vector of output tensors // get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() { std::vector<Tensor*>* output_tensor_ptrs() override {
if (UNLIKELY(output_tensor_ptrs_cache_)) { if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out})); output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
} }
return output_tensor_ptrs_cache_.get(); return output_tensor_ptrs_cache_.get();
......
...@@ -41,6 +41,7 @@ class PoolOpLite : public OpLite { ...@@ -41,6 +41,7 @@ class PoolOpLite : public OpLite {
// TODO(Superjomn) replace framework::OpDesc with a lite one. // TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
AttachParam(&param_);
auto x = op_desc.Input("X").front(); auto x = op_desc.Input("X").front();
auto out = op_desc.Output("Out").front(); auto out = op_desc.Output("Out").front();
......
...@@ -56,6 +56,7 @@ bool ReshapeOp::InferShapeImpl() const { ...@@ -56,6 +56,7 @@ bool ReshapeOp::InferShapeImpl() const {
} }
bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
AttachParam(&param_);
param_.x = param_.x =
scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>(); scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
param_.output = param_.output =
......
...@@ -30,6 +30,7 @@ bool ScaleOp::InferShapeImpl() const { ...@@ -30,6 +30,7 @@ bool ScaleOp::InferShapeImpl() const {
} }
bool ScaleOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { bool ScaleOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
AttachParam(&param_);
auto x = op_desc.Input("X").front(); auto x = op_desc.Input("X").front();
auto output = op_desc.Output("Out").front(); auto output = op_desc.Output("Out").front();
param_.x = scope->FindVar(x)->GetMutable<Tensor>(); param_.x = scope->FindVar(x)->GetMutable<Tensor>();
......
...@@ -34,6 +34,7 @@ bool SequenceSoftmaxOp::InferShapeImpl() const { ...@@ -34,6 +34,7 @@ bool SequenceSoftmaxOp::InferShapeImpl() const {
bool SequenceSoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc, bool SequenceSoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc,
lite::Scope *scope) { lite::Scope *scope) {
AttachParam(&param_);
param_.X = param_.X =
scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>(); scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
param_.Out = param_.Out =
......
...@@ -87,6 +87,7 @@ bool SliceOp::InferShapeImpl() const { ...@@ -87,6 +87,7 @@ bool SliceOp::InferShapeImpl() const {
} }
bool SliceOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { bool SliceOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
AttachParam(&param_);
param_.X = param_.X =
scope->FindVar(opdesc.Input("Input").front())->GetMutable<lite::Tensor>(); scope->FindVar(opdesc.Input("Input").front())->GetMutable<lite::Tensor>();
param_.Out = param_.Out =
......
...@@ -38,6 +38,8 @@ bool SoftmaxOp::InferShapeImpl() const { ...@@ -38,6 +38,8 @@ bool SoftmaxOp::InferShapeImpl() const {
} }
bool SoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { bool SoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
AttachParam(&param_);
param_.x = const_cast<lite::Tensor *>( param_.x = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>()); &scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
param_.output = param_.output =
......
...@@ -75,6 +75,7 @@ bool SplitOp::InferShapeImpl() const { ...@@ -75,6 +75,7 @@ bool SplitOp::InferShapeImpl() const {
} }
bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
AttachParam(&param_);
param_.axis = opdesc.GetAttr<int>("axis"); param_.axis = opdesc.GetAttr<int>("axis");
param_.num = opdesc.GetAttr<int>("num"); param_.num = opdesc.GetAttr<int>("num");
param_.sections = opdesc.GetAttr<std::vector<int>>("sections"); param_.sections = opdesc.GetAttr<std::vector<int>>("sections");
......
...@@ -84,6 +84,7 @@ bool SqueezeOp::InferShapeImpl() const { ...@@ -84,6 +84,7 @@ bool SqueezeOp::InferShapeImpl() const {
} }
bool SqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { bool SqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
AttachParam(&param_);
auto x_var = scope->FindVar(opdesc.Input("X").front()); auto x_var = scope->FindVar(opdesc.Input("X").front());
auto output_var = scope->FindVar(opdesc.Output("Out").front()); auto output_var = scope->FindVar(opdesc.Output("Out").front());
CHECK(x_var); CHECK(x_var);
......
...@@ -70,6 +70,7 @@ bool TransposeOp::InferShapeImpl() const { ...@@ -70,6 +70,7 @@ bool TransposeOp::InferShapeImpl() const {
} }
bool TransposeOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { bool TransposeOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
AttachParam(&param_);
auto x = op_desc.Input("X").front(); auto x = op_desc.Input("X").front();
auto out = op_desc.Output("Out").front(); auto out = op_desc.Output("Out").front();
......
...@@ -89,6 +89,7 @@ bool UnsqueezeOp::InferShapeImpl() const { ...@@ -89,6 +89,7 @@ bool UnsqueezeOp::InferShapeImpl() const {
} }
bool UnsqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { bool UnsqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
AttachParam(&param_);
auto x_var = scope->FindVar(opdesc.Input("X").front()); auto x_var = scope->FindVar(opdesc.Input("X").front());
auto output_var = scope->FindVar(opdesc.Output("Out").front()); auto output_var = scope->FindVar(opdesc.Output("Out").front());
CHECK(x_var); CHECK(x_var);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册