diff --git a/lite/core/op_lite.cc b/lite/core/op_lite.cc index 941a9e9f88cf04ef47487237b1a3f6509dea762b..de76f404f8a129eb94e645dc731a0d09c1ee3c77 100644 --- a/lite/core/op_lite.cc +++ b/lite/core/op_lite.cc @@ -25,16 +25,16 @@ namespace lite { bool OpLite::InferShape() { // if input_tensor_ptrs and output_tensor_ptrs are overloaded in param_ // InferShapeByMemoryInternal will be applied. - if (param_.input_tensor_ptrs() && param_.output_tensor_ptrs()) { + if (op_param_ && op_param_->input_tensor_ptrs() && + op_param_->output_tensor_ptrs()) { return this->InferShapeWithCache(); } else { - // otherwise, InferShapeImpl is applied directly. return this->InferShapeImpl(); } } bool OpLite::InferShapeWithCache() { // 1. Get vector of current input tensors - auto *current_inputs = param_.input_tensor_ptrs(); + auto *current_inputs = op_param_->input_tensor_ptrs(); // 2. Get hash value of current inputs shape and lod size_t new_hash = 0; for (auto iter = current_inputs->begin(); iter != current_inputs->end(); @@ -59,7 +59,7 @@ bool OpLite::InferShapeWithCache() { if (new_hash == io_shape_lod_hash_ && new_hash != 0) { // if current hash value is consistent with io_shape_lod_hash_, // previous outputs shape and lod are reused. - auto *current_outputs = param_.output_tensor_ptrs(); + auto *current_outputs = op_param_->output_tensor_ptrs(); for (size_t i = 0; i < current_outputs->size(); i++) { current_outputs->at(i)->Resize(last_output_shapes[i]); current_outputs->at(i)->set_lod(last_output_lods[i]); @@ -68,10 +68,12 @@ bool OpLite::InferShapeWithCache() { // otherwise, current hash value is changed, InferShapeImpl will apply. 
io_shape_lod_hash_ = new_hash; this->InferShapeImpl(); - auto *current_outputs = param_.output_tensor_ptrs(); + auto *current_outputs = op_param_->output_tensor_ptrs(); + last_output_shapes.clear(); + last_output_lods.clear(); for (size_t i = 0; i < current_outputs->size(); i++) { - last_output_shapes[i] = current_outputs->at(i)->dims(); - last_output_lods[i] = current_outputs->at(i)->lod(); + last_output_shapes.push_back(current_outputs->at(i)->dims()); + last_output_lods.push_back(current_outputs->at(i)->lod()); } } return true; diff --git a/lite/core/op_lite.h b/lite/core/op_lite.h index 428b188c468ded790e74c9cc4f5da5c7efe2fd00..656f992b1736d88abd1ed95759b19519ec11aff7 100644 --- a/lite/core/op_lite.h +++ b/lite/core/op_lite.h @@ -77,6 +77,11 @@ class OpLite : public Registry { // Link the external execution environ to internal context. bool Attach(const cpp::OpDesc &opdesc, lite::Scope *scope); + template <typename T> + inline void AttachParam(T *param) { + op_param_ = static_cast<T *>(param); + } + const OpInfo *op_info() const { return op_info_.get(); } OpInfo *mutable_op_info() { return op_info_.get(); } @@ -167,11 +172,10 @@ class OpLite : public Registry { std::vector<Place> valid_places_; Place kernel_place_{TARGET(kHost), PRECISION(kFloat)}; std::unique_ptr<OpInfo> op_info_; - std::vector<DDimLite> last_output_shapes{}; std::vector<std::vector<std::vector<uint64_t>>> last_output_lods{}; size_t io_shape_lod_hash_{}; - mutable operators::ParamBase param_; + mutable operators::ParamBase *op_param_{nullptr}; private: // Infer Shape according to memory, if current input shapes are consistent diff --git a/lite/operators/batch_norm_op.cc b/lite/operators/batch_norm_op.cc index 67e037fba349e811f1faf991c84310b11ab7a13c..b043aad2aca05c7d42edec1960f5335b5fc91fc6 100644 --- a/lite/operators/batch_norm_op.cc +++ b/lite/operators/batch_norm_op.cc @@ -73,6 +73,7 @@ bool BatchNormOp::InferShapeImpl() const { } bool BatchNormOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { + AttachParam(&param_); param_.x =
scope->FindVar(op_desc.Input("X").front())->GetMutable(); param_.bias = scope->FindVar(op_desc.Input("Bias").front())->GetMutable(); diff --git a/lite/operators/concat_op.cc b/lite/operators/concat_op.cc index c15bf292897006b3c6d5e67bcfaea5d0e590a82d..052b9cdca0a898185649cfdbddb933230e968b14 100644 --- a/lite/operators/concat_op.cc +++ b/lite/operators/concat_op.cc @@ -66,6 +66,7 @@ bool ConcatOpLite::InferShapeImpl() const { // TODO(Superjomn) replace framework::OpDesc with a lite one. bool ConcatOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { + AttachParam(¶m_); auto inputs = op_desc.Input("X"); auto out = op_desc.Output("Out").front(); diff --git a/lite/operators/conv_op.h b/lite/operators/conv_op.h index eab17fe6db0a59a9eb0eea0ab7344758a8232d15..49452fc44f1b114efc7eb2fb433000bebdb577a6 100644 --- a/lite/operators/conv_op.h +++ b/lite/operators/conv_op.h @@ -38,6 +38,7 @@ class ConvOpLite : public OpLite { // TODO(Superjomn) replace framework::OpDesc with a lite one. 
bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override { + AttachParam(¶m_); auto X = op_desc.Input("Input").front(); auto Filter = op_desc.Input("Filter").front(); auto Out = op_desc.Output("Output").front(); diff --git a/lite/operators/elementwise_ops.cc b/lite/operators/elementwise_ops.cc index 3996c933407233538a62ae9e197978f799ce06b0..6cc41f0a66cfac4a0baa0153765a59766fa045f4 100644 --- a/lite/operators/elementwise_ops.cc +++ b/lite/operators/elementwise_ops.cc @@ -87,6 +87,8 @@ bool ElementwiseOp::InferShapeImpl() const { } bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { + AttachParam(¶m_); + auto X_name = opdesc.Input("X").front(); auto Y_name = opdesc.Input("Y").front(); auto Out_name = opdesc.Output("Out").front(); diff --git a/lite/operators/fc_op.cc b/lite/operators/fc_op.cc index d58a9e5b881048dd47340082fe9c94a618a7a5fb..d4032c5e8b98ff6d5763d2d06610d2e214ad90ca 100644 --- a/lite/operators/fc_op.cc +++ b/lite/operators/fc_op.cc @@ -69,6 +69,8 @@ bool FcOpLite::InferShapeImpl() const { } bool FcOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { + AttachParam(¶m_); + auto input = op_desc.Input("Input").front(); auto W = op_desc.Input("W").front(); auto out = op_desc.Output("Out").front(); diff --git a/lite/operators/matmul_op.cc b/lite/operators/matmul_op.cc index 04a0fc97d77a181e45e3e829010934e22381ae12..d3e2e963abbb68adf890a5ba42d3d187d3e611c4 100644 --- a/lite/operators/matmul_op.cc +++ b/lite/operators/matmul_op.cc @@ -132,6 +132,7 @@ bool MatMulOpLite::InferShapeImpl() const { } bool MatMulOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { + AttachParam(¶m_); CHECK(!op_desc.Input("X").empty()); CHECK(!op_desc.Input("Y").empty()); CHECK(!op_desc.Output("Out").empty()); diff --git a/lite/operators/mul_op.h b/lite/operators/mul_op.h index 10a2e2efaa4db0e106e3c56c2f9b1cec9fb55ac4..74b64f11ae2ec75efa61a7799da49187c9e684ea 100644 --- a/lite/operators/mul_op.h +++ 
b/lite/operators/mul_op.h @@ -38,6 +38,8 @@ class MulOpLite : public OpLite { void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } // TODO(Superjomn) replace framework::OpDesc with a lite one. bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { + AttachParam(¶m_); + CHECK(!op_desc.Input("X").empty()); CHECK(!op_desc.Input("Y").empty()); CHECK(!op_desc.Output("Out").empty()); @@ -56,7 +58,6 @@ class MulOpLite : public OpLite { param_.output = var->GetMutable(); param_.x_num_col_dims = op_desc.GetAttr("x_num_col_dims"); param_.y_num_col_dims = op_desc.GetAttr("y_num_col_dims"); - return true; } diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 30ee736de494e1a93902d1252db2672aeef38f2e..c5e595672eb580a282ee101294a48a053d8d9c02 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -35,8 +35,11 @@ namespace operators { struct ParamBase { public: - const std::vector* input_tensor_ptrs() const { return nullptr; } - std::vector* output_tensor_ptrs() { return nullptr; } + virtual ~ParamBase() {} + virtual const std::vector* input_tensor_ptrs() { + return nullptr; + } + virtual std::vector* output_tensor_ptrs() { return nullptr; } protected: std::shared_ptr> input_tensor_ptrs_cache_{nullptr}; @@ -108,15 +111,15 @@ struct FcParam : ParamBase { WITH_INT8_CONFIG /////////////////////////////////////////////////////////////////////////////////// // get a vector of input tensors - const std::vector* input_tensor_ptrs() { - if (UNLIKELY(input_tensor_ptrs_cache_)) { + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { input_tensor_ptrs_cache_.reset(new std::vector({input})); } return input_tensor_ptrs_cache_.get(); } // get a vector of output tensors - const std::vector* output_tensor_ptrs() { - if (UNLIKELY(output_tensor_ptrs_cache_)) { + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { output_tensor_ptrs_cache_.reset(new 
std::vector({output})); } return output_tensor_ptrs_cache_.get(); @@ -160,15 +163,15 @@ struct MulParam : ParamBase { WITH_INT8_CONFIG /////////////////////////////////////////////////////////////////////////////////// // get a vector of input tensors - const std::vector* input_tensor_ptrs() { - if (UNLIKELY(input_tensor_ptrs_cache_)) { + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { input_tensor_ptrs_cache_.reset(new std::vector({x, y})); } return input_tensor_ptrs_cache_.get(); } // get a vector of output tensors - const std::vector* output_tensor_ptrs() { - if (UNLIKELY(output_tensor_ptrs_cache_)) { + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { output_tensor_ptrs_cache_.reset(new std::vector({output})); } return output_tensor_ptrs_cache_.get(); @@ -243,15 +246,15 @@ struct ScaleParam : ParamBase { bool bias_after_scale{true}; /////////////////////////////////////////////////////////////////////////////////// // get a vector of input tensors - const std::vector* input_tensor_ptrs() { - if (UNLIKELY(input_tensor_ptrs_cache_)) { + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { input_tensor_ptrs_cache_.reset(new std::vector({x})); } return input_tensor_ptrs_cache_.get(); } // get a vector of output tensors - const std::vector* output_tensor_ptrs() { - if (UNLIKELY(output_tensor_ptrs_cache_)) { + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { output_tensor_ptrs_cache_.reset(new std::vector({output})); } return output_tensor_ptrs_cache_.get(); @@ -265,15 +268,15 @@ struct SoftmaxParam : ParamBase { int axis{-1}; /////////////////////////////////////////////////////////////////////////////////// // get a vector of input tensors - const std::vector* input_tensor_ptrs() { - if (UNLIKELY(input_tensor_ptrs_cache_)) { + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { 
input_tensor_ptrs_cache_.reset(new std::vector({x})); } return input_tensor_ptrs_cache_.get(); } // get a vector of output tensors - const std::vector* output_tensor_ptrs() { - if (UNLIKELY(output_tensor_ptrs_cache_)) { + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { output_tensor_ptrs_cache_.reset(new std::vector({output})); } return output_tensor_ptrs_cache_.get(); @@ -292,15 +295,15 @@ struct ReshapeParam : ParamBase { bool inplace{false}; /////////////////////////////////////////////////////////////////////////////////// // get a vector of input tensors - const std::vector* input_tensor_ptrs() { - if (UNLIKELY(input_tensor_ptrs_cache_)) { + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { input_tensor_ptrs_cache_.reset(new std::vector({x})); } return input_tensor_ptrs_cache_.get(); } // get a vector of output tensors - const std::vector* output_tensor_ptrs() { - if (UNLIKELY(output_tensor_ptrs_cache_)) { + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { output_tensor_ptrs_cache_.reset(new std::vector({output})); } return output_tensor_ptrs_cache_.get(); @@ -314,8 +317,8 @@ struct ConcatParam : ParamBase { int axis{0}; lite::Tensor* axis_tensor{}; // get a vector of input tensors - const std::vector* input_tensor_ptrs() { - if (UNLIKELY(input_tensor_ptrs_cache_)) { + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { std::vector vec; for (auto in : x) { vec.push_back(in); @@ -325,8 +328,8 @@ struct ConcatParam : ParamBase { return input_tensor_ptrs_cache_.get(); } // get a vector of output tensors - const std::vector* output_tensor_ptrs() { - if (UNLIKELY(output_tensor_ptrs_cache_)) { + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { output_tensor_ptrs_cache_.reset(new std::vector({output})); } return output_tensor_ptrs_cache_.get(); @@ -406,15 +409,15 @@ struct ConvParam : ParamBase { 
/////////////////////////////////////////////////////////////////////////////////// // get a vector of input tensors - const std::vector* input_tensor_ptrs() { - if (UNLIKELY(input_tensor_ptrs_cache_)) { + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { input_tensor_ptrs_cache_.reset(new std::vector({x})); } return input_tensor_ptrs_cache_.get(); } // get a vector of output tensors - const std::vector* output_tensor_ptrs() { - if (UNLIKELY(output_tensor_ptrs_cache_)) { + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { output_tensor_ptrs_cache_.reset(new std::vector({output})); } return output_tensor_ptrs_cache_.get(); @@ -440,15 +443,15 @@ struct BatchNormParam : ParamBase { DataLayoutType data_layout{DATALAYOUT(kNCHW)}; /////////////////////////////////////////////////////////////////////////////////// // get a vector of input tensors - const std::vector* input_tensor_ptrs() { - if (UNLIKELY(input_tensor_ptrs_cache_)) { + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { input_tensor_ptrs_cache_.reset(new std::vector({x})); } return input_tensor_ptrs_cache_.get(); } // get a vector of output tensors - const std::vector* output_tensor_ptrs() { - if (UNLIKELY(output_tensor_ptrs_cache_)) { + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { output_tensor_ptrs_cache_.reset(new std::vector({y})); } return output_tensor_ptrs_cache_.get(); @@ -479,15 +482,15 @@ struct PoolParam : ParamBase { WITH_INT8_CONFIG /////////////////////////////////////////////////////////////////////////////////// // get a vector of input tensors - const std::vector* input_tensor_ptrs() { - if (UNLIKELY(input_tensor_ptrs_cache_)) { + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { input_tensor_ptrs_cache_.reset(new std::vector({x})); } return input_tensor_ptrs_cache_.get(); } // get a vector of output tensors - const 
std::vector* output_tensor_ptrs() { - if (UNLIKELY(output_tensor_ptrs_cache_)) { + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { output_tensor_ptrs_cache_.reset(new std::vector({output})); } return output_tensor_ptrs_cache_.get(); @@ -518,15 +521,15 @@ struct SplitParam : ParamBase { std::vector sections; /////////////////////////////////////////////////////////////////////////////////// // get a vector of input tensors - const std::vector* input_tensor_ptrs() { - if (UNLIKELY(input_tensor_ptrs_cache_)) { + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { input_tensor_ptrs_cache_.reset(new std::vector({x})); } return input_tensor_ptrs_cache_.get(); } // get a vector of output tensors - const std::vector* output_tensor_ptrs() { - if (UNLIKELY(output_tensor_ptrs_cache_)) { + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { output_tensor_ptrs_cache_.reset(new std::vector({output})); } return output_tensor_ptrs_cache_.get(); @@ -544,15 +547,15 @@ struct TransposeParam : ParamBase { std::string data_format{"AnyLayout"}; /////////////////////////////////////////////////////////////////////////////////// // // get a vector of input tensors - const std::vector* input_tensor_ptrs() { - if (UNLIKELY(input_tensor_ptrs_cache_)) { + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { input_tensor_ptrs_cache_.reset(new std::vector({x})); } return input_tensor_ptrs_cache_.get(); } // get a vector of output tensors - const std::vector* output_tensor_ptrs() { - if (UNLIKELY(output_tensor_ptrs_cache_)) { + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { output_tensor_ptrs_cache_.reset(new std::vector({output})); } return output_tensor_ptrs_cache_.get(); @@ -571,15 +574,15 @@ struct ElementwiseParam : ParamBase { float y_input_scale{1.0}; 
/////////////////////////////////////////////////////////////////////////////////// // get a vector of input tensors - const std::vector* input_tensor_ptrs() { - if (UNLIKELY(input_tensor_ptrs_cache_)) { + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { input_tensor_ptrs_cache_.reset(new std::vector({X, Y})); } return input_tensor_ptrs_cache_.get(); } // get a vector of output tensors - const std::vector* output_tensor_ptrs() { - if (UNLIKELY(output_tensor_ptrs_cache_)) { + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { output_tensor_ptrs_cache_.reset(new std::vector({Out})); } return output_tensor_ptrs_cache_.get(); @@ -884,15 +887,15 @@ struct SequenceSoftmaxParam : ParamBase { lite::Tensor* Out{}; /////////////////////////////////////////////////////////////////////////////////// // // get a vector of input tensors - const std::vector* input_tensor_ptrs() { - if (UNLIKELY(input_tensor_ptrs_cache_)) { + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { input_tensor_ptrs_cache_.reset(new std::vector({X})); } return input_tensor_ptrs_cache_.get(); } // get a vector of output tensors - const std::vector* output_tensor_ptrs() { - if (UNLIKELY(output_tensor_ptrs_cache_)) { + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { output_tensor_ptrs_cache_.reset(new std::vector({Out})); } return output_tensor_ptrs_cache_.get(); @@ -1135,15 +1138,15 @@ struct SliceParam : ParamBase { lite::Tensor* EndsTensor{nullptr}; /////////////////////////////////////////////////////////////////////////////////// // get a vector of input tensors - const std::vector* input_tensor_ptrs() { - if (UNLIKELY(input_tensor_ptrs_cache_)) { + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { input_tensor_ptrs_cache_.reset(new std::vector({X})); } return input_tensor_ptrs_cache_.get(); } // get a vector of output tensors - const 
std::vector* output_tensor_ptrs() { - if (UNLIKELY(output_tensor_ptrs_cache_)) { + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { output_tensor_ptrs_cache_.reset(new std::vector({Out})); } return output_tensor_ptrs_cache_.get(); @@ -1197,15 +1200,15 @@ struct SqueezeParam : ParamBase { std::vector axes{}; /////////////////////////////////////////////////////////////////////////////////// // get a vector of input tensors - const std::vector* input_tensor_ptrs() { - if (UNLIKELY(input_tensor_ptrs_cache_)) { + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { input_tensor_ptrs_cache_.reset(new std::vector({X})); } return input_tensor_ptrs_cache_.get(); } // get a vector of output tensors - const std::vector* output_tensor_ptrs() { - if (UNLIKELY(output_tensor_ptrs_cache_)) { + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { output_tensor_ptrs_cache_.reset(new std::vector({Out})); } return output_tensor_ptrs_cache_.get(); @@ -1221,15 +1224,15 @@ struct UnsqueezeParam : ParamBase { std::vector axes_tensor_vct{}; /////////////////////////////////////////////////////////////////////////////////// // get a vector of input tensors - const std::vector* input_tensor_ptrs() { - if (UNLIKELY(input_tensor_ptrs_cache_)) { + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { input_tensor_ptrs_cache_.reset(new std::vector({X})); } return input_tensor_ptrs_cache_.get(); } // get a vector of output tensors - const std::vector* output_tensor_ptrs() { - if (UNLIKELY(output_tensor_ptrs_cache_)) { + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { output_tensor_ptrs_cache_.reset(new std::vector({Out})); } return output_tensor_ptrs_cache_.get(); @@ -1253,15 +1256,15 @@ struct MatMulParam : ParamBase { float alpha{1.0f}; /////////////////////////////////////////////////////////////////////////////////// // get a vector of 
input tensors - const std::vector* input_tensor_ptrs() { - if (UNLIKELY(input_tensor_ptrs_cache_)) { + const std::vector* input_tensor_ptrs() override { + if (!input_tensor_ptrs_cache_) { input_tensor_ptrs_cache_.reset(new std::vector({X, Y})); } return input_tensor_ptrs_cache_.get(); } // get a vector of output tensors - const std::vector* output_tensor_ptrs() { - if (UNLIKELY(output_tensor_ptrs_cache_)) { + std::vector* output_tensor_ptrs() override { + if (!output_tensor_ptrs_cache_) { output_tensor_ptrs_cache_.reset(new std::vector({Out})); } return output_tensor_ptrs_cache_.get(); diff --git a/lite/operators/pool_op.h b/lite/operators/pool_op.h index 97f4a8a0083550fdcb0bc2d011e5e33d2d02011d..9c29f9597cde534ba158aa5d1b055c3d21a70474 100644 --- a/lite/operators/pool_op.h +++ b/lite/operators/pool_op.h @@ -41,6 +41,7 @@ class PoolOpLite : public OpLite { // TODO(Superjomn) replace framework::OpDesc with a lite one. bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { + AttachParam(¶m_); auto x = op_desc.Input("X").front(); auto out = op_desc.Output("Out").front(); diff --git a/lite/operators/reshape_op.cc b/lite/operators/reshape_op.cc index 32bc91a3a0b9b852024e2e0f2ea36585e2a29892..93f4ad9048779d1ea6861a273ff09c73cbd89281 100644 --- a/lite/operators/reshape_op.cc +++ b/lite/operators/reshape_op.cc @@ -56,6 +56,7 @@ bool ReshapeOp::InferShapeImpl() const { } bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + AttachParam(¶m_); param_.x = scope->FindVar(opdesc.Input("X").front())->GetMutable(); param_.output = diff --git a/lite/operators/scale_op.cc b/lite/operators/scale_op.cc index 3236277187462dd1185e698e5cb8fe919fe20b97..d2090076fe387198bbb2db904a73940504ba7841 100644 --- a/lite/operators/scale_op.cc +++ b/lite/operators/scale_op.cc @@ -30,6 +30,7 @@ bool ScaleOp::InferShapeImpl() const { } bool ScaleOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { + AttachParam(¶m_); auto x = 
op_desc.Input("X").front(); auto output = op_desc.Output("Out").front(); param_.x = scope->FindVar(x)->GetMutable(); diff --git a/lite/operators/sequence_softmax_op.cc b/lite/operators/sequence_softmax_op.cc index eb1821129d8b036a252fb36ab69094c8a58cce95..c13c4cc7392a931e0066c8a177f2c2ca56bc76f4 100644 --- a/lite/operators/sequence_softmax_op.cc +++ b/lite/operators/sequence_softmax_op.cc @@ -34,6 +34,7 @@ bool SequenceSoftmaxOp::InferShapeImpl() const { bool SequenceSoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + AttachParam(¶m_); param_.X = scope->FindVar(opdesc.Input("X").front())->GetMutable(); param_.Out = diff --git a/lite/operators/slice_op.cc b/lite/operators/slice_op.cc index ecbcc5c2c5925d320c0334889634e57ed894695f..c18fc989411b8e074f562af0f1685810872151c6 100644 --- a/lite/operators/slice_op.cc +++ b/lite/operators/slice_op.cc @@ -87,6 +87,7 @@ bool SliceOp::InferShapeImpl() const { } bool SliceOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + AttachParam(¶m_); param_.X = scope->FindVar(opdesc.Input("Input").front())->GetMutable(); param_.Out = diff --git a/lite/operators/softmax_op.cc b/lite/operators/softmax_op.cc index 000953007c27e37bc05d85d810880f6ccd7728ce..e95e355bda428d724e3b89ee80fc01f592032765 100644 --- a/lite/operators/softmax_op.cc +++ b/lite/operators/softmax_op.cc @@ -38,6 +38,8 @@ bool SoftmaxOp::InferShapeImpl() const { } bool SoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + AttachParam(¶m_); + param_.x = const_cast( &scope->FindVar(opdesc.Input("X").front())->Get()); param_.output = diff --git a/lite/operators/split_op.cc b/lite/operators/split_op.cc index 14cff7d692e3aaa37d95233931760f37c31e4526..ed913a72bc1174f7919dc677b78059771146391a 100644 --- a/lite/operators/split_op.cc +++ b/lite/operators/split_op.cc @@ -75,6 +75,7 @@ bool SplitOp::InferShapeImpl() const { } bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + AttachParam(¶m_); param_.axis = 
opdesc.GetAttr<int>("axis"); param_.num = opdesc.GetAttr<int>("num"); param_.sections = opdesc.GetAttr<std::vector<int>>("sections"); diff --git a/lite/operators/squeeze_op.cc b/lite/operators/squeeze_op.cc index c34ad06debb0c4bb99d083bc7938ea26b2dcac9f..8dada8fed06de4dc44149c0fd7583fe646cc2dd2 100644 --- a/lite/operators/squeeze_op.cc +++ b/lite/operators/squeeze_op.cc @@ -84,6 +84,7 @@ bool SqueezeOp::InferShapeImpl() const { } bool SqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + AttachParam(&param_); auto x_var = scope->FindVar(opdesc.Input("X").front()); auto output_var = scope->FindVar(opdesc.Output("Out").front()); CHECK(x_var); diff --git a/lite/operators/transpose_op.cc b/lite/operators/transpose_op.cc index 40780346d038c875a2eb96b11aff9d1c2a578a2f..fe40bf6fa2f84ce7c999b41435aed00cd6555887 100644 --- a/lite/operators/transpose_op.cc +++ b/lite/operators/transpose_op.cc @@ -70,6 +70,7 @@ bool TransposeOp::InferShapeImpl() const { } bool TransposeOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { + AttachParam(&param_); auto x = op_desc.Input("X").front(); auto out = op_desc.Output("Out").front(); diff --git a/lite/operators/unsqueeze_op.cc b/lite/operators/unsqueeze_op.cc index 0a7487d34eeb6fe149f956e2f48bdb411a690f14..23865aaabbb6c7617b21fffd4cddea1e358f302f 100644 --- a/lite/operators/unsqueeze_op.cc +++ b/lite/operators/unsqueeze_op.cc @@ -89,6 +89,7 @@ bool UnsqueezeOp::InferShapeImpl() const { } bool UnsqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + AttachParam(&param_); auto x_var = scope->FindVar(opdesc.Input("X").front()); auto output_var = scope->FindVar(opdesc.Output("Out").front()); CHECK(x_var);