提交 c754a38f 编写于 作者: H huzhiqiang 提交者: GitHub

[operator] add InferShapeImpl method (#3294)

上级 50638e96
...@@ -22,6 +22,61 @@ ...@@ -22,6 +22,61 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
bool OpLite::InferShape() {
// if input_tensor_ptrs and output_tensor_ptrs are overloaded in param_
// InferShapeByMemoryInternal will be applied.
if (param_.input_tensor_ptrs() && param_.output_tensor_ptrs()) {
return this->InferShapeWithCache();
} else {
// otherwise, InferShapeImpl is applied directly.
return this->InferShapeImpl();
}
}
bool OpLite::InferShapeWithCache() {
// 1. Get vector of current input tensors
auto *current_inputs = param_.input_tensor_ptrs();
// 2. Get hash value of current inputs shape and lod
size_t new_hash = 0;
for (auto iter = current_inputs->begin(); iter != current_inputs->end();
iter++) {
// combined dims value into new_hash value.
auto &element_dims = (*iter)->dims();
for (int i = 0; i < element_dims.size(); i++) {
new_hash =
lite::hash_combine(new_hash, static_cast<int>(element_dims[i]));
}
// combine lod value into new_hash valud.
auto &emement_lods = (*iter)->lod();
for (auto lod_iter = emement_lods.begin(); lod_iter != emement_lods.end();
lod_iter++) {
for (int i = 0; i < lod_iter->size(); i++) {
new_hash =
lite::hash_combine(new_hash, static_cast<int>(lod_iter->at(i)));
}
}
}
// 3. infer shapes of output tensors
if (new_hash == io_shape_lod_hash_ && new_hash != 0) {
// if current hash value is consistent with io_shape_lod_hash_,
// previous outputs shape and lod are reused.
auto *current_outputs = param_.output_tensor_ptrs();
for (int i = 0; i < current_outputs->size(); i++) {
current_outputs->at(i)->Resize(last_output_shapes[i]);
current_outputs->at(i)->set_lod(last_output_lods[i]);
}
} else {
// otherwise, current hash value is changed, InferShapeImpl will apply.
io_shape_lod_hash_ = new_hash;
this->InferShapeImpl();
auto *current_outputs = param_.output_tensor_ptrs();
for (int i = 0; i < current_outputs->size(); i++) {
last_output_shapes[i] = current_outputs->at(i)->dims();
last_output_lods[i] = current_outputs->at(i)->lod();
}
}
return true;
}
std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels( std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
const std::vector<Place> &places, const std::string &kernel_type) { const std::vector<Place> &places, const std::string &kernel_type) {
std::vector<std::unique_ptr<KernelBase>> kernels; std::vector<std::unique_ptr<KernelBase>> kernels;
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <functional>
#include <list> #include <list>
#include <map> #include <map>
#include <memory> #include <memory>
...@@ -24,6 +25,7 @@ ...@@ -24,6 +25,7 @@
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
#include "lite/core/scope.h" #include "lite/core/scope.h"
#include "lite/model_parser/cpp/op_desc.h" #include "lite/model_parser/cpp/op_desc.h"
#include "lite/operators/op_params.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -64,8 +66,8 @@ class OpLite : public Registry { ...@@ -64,8 +66,8 @@ class OpLite : public Registry {
// Check the shape. // Check the shape.
virtual bool CheckShape() const { return true; } virtual bool CheckShape() const { return true; }
// Inference the outputs' shape. // Inference the outputs' shape.
virtual bool InferShape() const { return true; } virtual bool InferShapeImpl() const { return true; }
virtual bool SmartInferShape() { return this->InferShape(); } virtual bool InferShape();
// Run this operator. // Run this operator.
virtual bool Run(); virtual bool Run();
// Indicate whether the Op runs only once or not // Indicate whether the Op runs only once or not
...@@ -151,10 +153,16 @@ class OpLite : public Registry { ...@@ -151,10 +153,16 @@ class OpLite : public Registry {
std::vector<Place> valid_places_; std::vector<Place> valid_places_;
Place kernel_place_{TARGET(kHost), PRECISION(kFloat)}; Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
std::unique_ptr<OpInfo> op_info_; std::unique_ptr<OpInfo> op_info_;
std::vector<DDimLite> last_input_shapes;
std::vector<DDimLite> last_output_shapes; std::vector<DDimLite> last_output_shapes{};
std::vector<std::vector<std::vector<uint64_t>>> last_output_lods; std::vector<std::vector<std::vector<uint64_t>>> last_output_lods{};
std::vector<std::vector<std::vector<uint64_t>>> last_input_lods; size_t io_shape_lod_hash_{};
mutable operators::ParamBase param_;
private:
// Infer Shape according to memory, if current input shapes are consistent
// with that of previous inputs, output shapes of last time will be reused.
bool InferShapeWithCache();
}; };
/* /*
......
...@@ -286,8 +286,7 @@ void Instruction::Run() { ...@@ -286,8 +286,7 @@ void Instruction::Run() {
return; return;
} }
// op_->InferShape(); op_->InferShape();
op_->SmartInferShape();
kernel_->Launch(); kernel_->Launch();
has_run_ = true; has_run_ = true;
} }
......
...@@ -25,7 +25,7 @@ bool ActivationGradOp::CheckShape() const { ...@@ -25,7 +25,7 @@ bool ActivationGradOp::CheckShape() const {
return true; return true;
} }
bool ActivationGradOp::InferShape() const { bool ActivationGradOp::InferShapeImpl() const {
param_.X_grad->Resize(param_.Out_grad->dims()); param_.X_grad->Resize(param_.Out_grad->dims());
return true; return true;
} }
......
...@@ -26,7 +26,7 @@ class ActivationGradOp : public OpLite { ...@@ -26,7 +26,7 @@ class ActivationGradOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
......
...@@ -25,7 +25,7 @@ bool ActivationOp::CheckShape() const { ...@@ -25,7 +25,7 @@ bool ActivationOp::CheckShape() const {
return true; return true;
} }
bool ActivationOp::InferShape() const { bool ActivationOp::InferShapeImpl() const {
param_.Out->Resize(param_.X->dims()); param_.Out->Resize(param_.X->dims());
auto out_lod = param_.Out->mutable_lod(); auto out_lod = param_.Out->mutable_lod();
*out_lod = param_.X->lod(); *out_lod = param_.X->lod();
......
...@@ -26,7 +26,7 @@ class ActivationOp : public OpLite { ...@@ -26,7 +26,7 @@ class ActivationOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
......
...@@ -44,7 +44,7 @@ bool AffineChannelOpLite::CheckShape() const { ...@@ -44,7 +44,7 @@ bool AffineChannelOpLite::CheckShape() const {
return true; return true;
} }
bool AffineChannelOpLite::InferShape() const { bool AffineChannelOpLite::InferShapeImpl() const {
const auto x_dims = param_.X->dims(); const auto x_dims = param_.X->dims();
param_.Out->Resize(x_dims); param_.Out->Resize(x_dims);
return true; return true;
......
...@@ -31,7 +31,7 @@ class AffineChannelOpLite : public OpLite { ...@@ -31,7 +31,7 @@ class AffineChannelOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -31,7 +31,7 @@ bool AnchorGeneratorOpLite::CheckShape() const { ...@@ -31,7 +31,7 @@ bool AnchorGeneratorOpLite::CheckShape() const {
return true; return true;
} }
bool AnchorGeneratorOpLite::InferShape() const { bool AnchorGeneratorOpLite::InferShapeImpl() const {
auto input_dims = param_.Input->dims(); auto input_dims = param_.Input->dims();
size_t num_anchors = param_.aspect_ratios.size() * param_.anchor_sizes.size(); size_t num_anchors = param_.aspect_ratios.size() * param_.anchor_sizes.size();
std::vector<int64_t> output_shape( std::vector<int64_t> output_shape(
......
...@@ -32,7 +32,7 @@ class AnchorGeneratorOpLite : public OpLite { ...@@ -32,7 +32,7 @@ class AnchorGeneratorOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -29,7 +29,7 @@ bool ArgmaxOpLite::CheckShape() const { ...@@ -29,7 +29,7 @@ bool ArgmaxOpLite::CheckShape() const {
return true; return true;
} }
bool ArgmaxOpLite::InferShape() const { bool ArgmaxOpLite::InferShapeImpl() const {
auto x_dims = param_.X->dims(); auto x_dims = param_.X->dims();
int x_rank = x_dims.size(); int x_rank = x_dims.size();
int axis = param_.Axis; int axis = param_.Axis;
......
...@@ -31,7 +31,7 @@ class ArgmaxOpLite : public OpLite { ...@@ -31,7 +31,7 @@ class ArgmaxOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -26,7 +26,7 @@ bool AssignOpLite::CheckShape() const { ...@@ -26,7 +26,7 @@ bool AssignOpLite::CheckShape() const {
return true; return true;
} }
bool AssignOpLite::InferShape() const { bool AssignOpLite::InferShapeImpl() const {
lite::DDim input_dims; lite::DDim input_dims;
input_dims = param_.X->dims(); input_dims = param_.X->dims();
param_.Out->Resize(lite::DDim(input_dims)); param_.Out->Resize(lite::DDim(input_dims));
......
...@@ -30,7 +30,7 @@ class AssignOpLite : public OpLite { ...@@ -30,7 +30,7 @@ class AssignOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -35,7 +35,7 @@ bool AssignValueOpLite::CheckShape() const { ...@@ -35,7 +35,7 @@ bool AssignValueOpLite::CheckShape() const {
return true; return true;
} }
bool AssignValueOpLite::InferShape() const { bool AssignValueOpLite::InferShapeImpl() const {
std::vector<int> shape = param_.shape; std::vector<int> shape = param_.shape;
std::vector<int64_t> out_shape; std::vector<int64_t> out_shape;
for (size_t i = 0; i < shape.size(); i++) out_shape.push_back(shape[i]); for (size_t i = 0; i < shape.size(); i++) out_shape.push_back(shape[i]);
......
...@@ -31,7 +31,7 @@ class AssignValueOpLite : public OpLite { ...@@ -31,7 +31,7 @@ class AssignValueOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -28,7 +28,7 @@ bool AttentionPaddingMaskOp::CheckShape() const { ...@@ -28,7 +28,7 @@ bool AttentionPaddingMaskOp::CheckShape() const {
return true; return true;
} }
bool AttentionPaddingMaskOp::InferShape() const { bool AttentionPaddingMaskOp::InferShapeImpl() const {
auto src_len = param_.X->lod()[0][1]; auto src_len = param_.X->lod()[0][1];
CHECK_EQ(src_len, param_.X->dims()[1]) CHECK_EQ(src_len, param_.X->dims()[1])
<< "Mismatch source length, expect: " << src_len << "Mismatch source length, expect: " << src_len
......
...@@ -29,7 +29,7 @@ class AttentionPaddingMaskOp : public OpLite { ...@@ -29,7 +29,7 @@ class AttentionPaddingMaskOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -34,7 +34,7 @@ bool AxpyOpLite::CheckShape() const { ...@@ -34,7 +34,7 @@ bool AxpyOpLite::CheckShape() const {
return true; return true;
} }
bool AxpyOpLite::InferShape() const { bool AxpyOpLite::InferShapeImpl() const {
auto dims = param_.Bias->dims(); auto dims = param_.Bias->dims();
// Set output dims // Set output dims
......
...@@ -31,7 +31,7 @@ class AxpyOpLite : public OpLite { ...@@ -31,7 +31,7 @@ class AxpyOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -46,7 +46,7 @@ bool BatchNormOp::CheckShape() const { ...@@ -46,7 +46,7 @@ bool BatchNormOp::CheckShape() const {
return true; return true;
} }
bool BatchNormOp::InferShape() const { bool BatchNormOp::InferShapeImpl() const {
auto x_dims = param_.x->dims(); auto x_dims = param_.x->dims();
int64_t channel_size = 0; int64_t channel_size = 0;
switch (param_.data_layout) { switch (param_.data_layout) {
......
...@@ -30,7 +30,7 @@ class BatchNormOp : public OpLite { ...@@ -30,7 +30,7 @@ class BatchNormOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -28,7 +28,7 @@ bool BeamSearchDecodeOpLite::CheckShape() const { ...@@ -28,7 +28,7 @@ bool BeamSearchDecodeOpLite::CheckShape() const {
return true; return true;
} }
bool BeamSearchDecodeOpLite::InferShape() const { return true; } bool BeamSearchDecodeOpLite::InferShapeImpl() const { return true; }
bool BeamSearchDecodeOpLite::AttachImpl(const cpp::OpDesc &op_desc, bool BeamSearchDecodeOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) { lite::Scope *scope) {
......
...@@ -31,7 +31,7 @@ class BeamSearchDecodeOpLite : public OpLite { ...@@ -31,7 +31,7 @@ class BeamSearchDecodeOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -30,7 +30,7 @@ bool BeamSearchOp::CheckShape() const { ...@@ -30,7 +30,7 @@ bool BeamSearchOp::CheckShape() const {
return true; return true;
} }
bool BeamSearchOp::InferShape() const { return true; } bool BeamSearchOp::InferShapeImpl() const { return true; }
bool BeamSearchOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { bool BeamSearchOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.pre_ids = scope->FindTensor(opdesc.Input("pre_ids").front()); param_.pre_ids = scope->FindTensor(opdesc.Input("pre_ids").front());
......
...@@ -30,7 +30,7 @@ class BeamSearchOp : public OpLite { ...@@ -30,7 +30,7 @@ class BeamSearchOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -35,7 +35,7 @@ bool BoxClipOpLite::CheckShape() const { ...@@ -35,7 +35,7 @@ bool BoxClipOpLite::CheckShape() const {
return true; return true;
} }
bool BoxClipOpLite::InferShape() const { bool BoxClipOpLite::InferShapeImpl() const {
auto* input = param_.Input; auto* input = param_.Input;
auto* output = param_.Output; auto* output = param_.Output;
output->Resize(input->dims()); output->Resize(input->dims());
......
...@@ -31,7 +31,7 @@ class BoxClipOpLite : public OpLite { ...@@ -31,7 +31,7 @@ class BoxClipOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -35,7 +35,7 @@ bool BoxCoderOpLite::CheckShape() const { ...@@ -35,7 +35,7 @@ bool BoxCoderOpLite::CheckShape() const {
return true; return true;
} }
bool BoxCoderOpLite::InferShape() const { bool BoxCoderOpLite::InferShapeImpl() const {
auto prior_box_dims = param_.prior_box->dims(); auto prior_box_dims = param_.prior_box->dims();
auto target_box_dims = param_.target_box->dims(); auto target_box_dims = param_.target_box->dims();
std::string code_type = param_.code_type; std::string code_type = param_.code_type;
......
...@@ -29,7 +29,7 @@ class BoxCoderOpLite : public OpLite { ...@@ -29,7 +29,7 @@ class BoxCoderOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -24,7 +24,7 @@ bool CalibOpLite::CheckShape() const { ...@@ -24,7 +24,7 @@ bool CalibOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.output); CHECK_OR_FALSE(param_.output);
return true; return true;
} }
bool CalibOpLite::InferShape() const { bool CalibOpLite::InferShapeImpl() const {
param_.output->Resize(param_.input->dims()); param_.output->Resize(param_.input->dims());
return true; return true;
} }
......
...@@ -42,7 +42,7 @@ class CalibOpLite : public OpLite { ...@@ -42,7 +42,7 @@ class CalibOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope); bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope);
......
...@@ -25,7 +25,7 @@ bool CastOp::CheckShape() const { ...@@ -25,7 +25,7 @@ bool CastOp::CheckShape() const {
return true; return true;
} }
bool CastOp::InferShape() const { bool CastOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.Out); CHECK_OR_FALSE(param_.Out);
// TODO(Superjomn) Enable data sharing. // TODO(Superjomn) Enable data sharing.
auto out_dims = param_.X->dims(); auto out_dims = param_.X->dims();
......
...@@ -30,7 +30,7 @@ class CastOp : public OpLite { ...@@ -30,7 +30,7 @@ class CastOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -43,7 +43,7 @@ bool CollectFpnProposalsOpLite::CheckShape() const { ...@@ -43,7 +43,7 @@ bool CollectFpnProposalsOpLite::CheckShape() const {
return true; return true;
} }
bool CollectFpnProposalsOpLite::InferShape() const { bool CollectFpnProposalsOpLite::InferShapeImpl() const {
param_.fpn_rois->Resize({param_.post_nms_topN, 4}); param_.fpn_rois->Resize({param_.post_nms_topN, 4});
return true; return true;
......
...@@ -32,7 +32,7 @@ class CollectFpnProposalsOpLite : public OpLite { ...@@ -32,7 +32,7 @@ class CollectFpnProposalsOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -26,7 +26,7 @@ bool CompareOp::CheckShape() const { ...@@ -26,7 +26,7 @@ bool CompareOp::CheckShape() const {
return true; return true;
} }
bool CompareOp::InferShape() const { bool CompareOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.Out); CHECK_OR_FALSE(param_.Out);
// TODO(Superjomn) Enable data sharing. // TODO(Superjomn) Enable data sharing.
auto input_dims = param_.X->dims(); auto input_dims = param_.X->dims();
......
...@@ -30,7 +30,7 @@ class CompareOp : public OpLite { ...@@ -30,7 +30,7 @@ class CompareOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -26,7 +26,7 @@ bool ConcatOpLite::CheckShape() const { ...@@ -26,7 +26,7 @@ bool ConcatOpLite::CheckShape() const {
return true; return true;
} }
bool ConcatOpLite::InferShape() const { bool ConcatOpLite::InferShapeImpl() const {
const std::vector<Tensor *> &inputs = param_.x; const std::vector<Tensor *> &inputs = param_.x;
const size_t n = inputs.size(); const size_t n = inputs.size();
CHECK_GT_OR_FALSE(n, 0); CHECK_GT_OR_FALSE(n, 0);
......
...@@ -30,7 +30,7 @@ class ConcatOpLite : public OpLite { ...@@ -30,7 +30,7 @@ class ConcatOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -27,7 +27,7 @@ bool ConditionalBlockOpLite::CheckShape() const { ...@@ -27,7 +27,7 @@ bool ConditionalBlockOpLite::CheckShape() const {
return true; return true;
} }
bool ConditionalBlockOpLite::InferShape() const { return true; } bool ConditionalBlockOpLite::InferShapeImpl() const { return true; }
bool ConditionalBlockOpLite::AttachImpl(const cpp::OpDesc &op_desc, bool ConditionalBlockOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) { lite::Scope *scope) {
......
...@@ -31,7 +31,7 @@ class ConditionalBlockOpLite : public OpLite { ...@@ -31,7 +31,7 @@ class ConditionalBlockOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -80,35 +80,7 @@ void UpdatePaddingAndDilation(std::vector<int>* paddings, ...@@ -80,35 +80,7 @@ void UpdatePaddingAndDilation(std::vector<int>* paddings,
} }
} }
bool ConvOpLite::SmartInferShape() { bool ConvOpLite::InferShapeImpl() const {
if (!last_input_shapes.empty()) {
if (last_input_shapes[0] == param_.x->dims() &&
last_input_lods[0] == param_.x->lod()) {
param_.output->Resize(last_output_shapes[0]);
param_.output->set_lod(last_output_lods[0]);
return true;
}
}
this->InferShape();
if (!last_input_shapes.empty()) {
last_input_shapes.clear();
last_input_lods.clear();
}
last_input_shapes.push_back(param_.x->dims());
last_input_lods.push_back(param_.x->lod());
if (!last_output_shapes.empty()) {
last_output_shapes.clear();
last_output_lods.clear();
}
last_output_shapes.push_back(param_.output->dims());
last_output_lods.push_back(param_.output->lod());
return true;
}
bool ConvOpLite::InferShape() const {
const auto in_dims = param_.x->dims(); const auto in_dims = param_.x->dims();
const auto filter_dims = param_.filter->dims(); const auto filter_dims = param_.filter->dims();
......
...@@ -34,9 +34,7 @@ class ConvOpLite : public OpLite { ...@@ -34,9 +34,7 @@ class ConvOpLite : public OpLite {
explicit ConvOpLite(const std::string& type) : OpLite(type) {} explicit ConvOpLite(const std::string& type) : OpLite(type) {}
bool CheckShape() const override; bool CheckShape() const override;
bool InferShapeImpl() const override;
bool InferShape() const override;
bool SmartInferShape() override;
// TODO(Superjomn) replace framework::OpDesc with a lite one. // TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override { bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override {
......
...@@ -52,7 +52,7 @@ inline int ConvTransposeOutputSize(int input_size, ...@@ -52,7 +52,7 @@ inline int ConvTransposeOutputSize(int input_size,
return output_size; return output_size;
} }
bool ConvTransposeOpLite::InferShape() const { bool ConvTransposeOpLite::InferShapeImpl() const {
const auto in_dims = param_.x->dims(); const auto in_dims = param_.x->dims();
const auto filter_dims = param_.filter->dims(); const auto filter_dims = param_.filter->dims();
......
...@@ -34,7 +34,7 @@ class ConvTransposeOpLite : public OpLite { ...@@ -34,7 +34,7 @@ class ConvTransposeOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -60,7 +60,7 @@ bool CrfDecodingOpLite::CheckShape() const { ...@@ -60,7 +60,7 @@ bool CrfDecodingOpLite::CheckShape() const {
return true; return true;
} }
bool CrfDecodingOpLite::InferShape() const { bool CrfDecodingOpLite::InferShapeImpl() const {
auto emission_dims = param_.emission->dims(); auto emission_dims = param_.emission->dims();
if (param_.length == nullptr) { if (param_.length == nullptr) {
param_.viterbi_path->Resize({emission_dims[0], 1}); param_.viterbi_path->Resize({emission_dims[0], 1});
......
...@@ -31,7 +31,7 @@ class CrfDecodingOpLite : public OpLite { ...@@ -31,7 +31,7 @@ class CrfDecodingOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -26,7 +26,7 @@ bool CropOpLite::CheckShape() const { ...@@ -26,7 +26,7 @@ bool CropOpLite::CheckShape() const {
return true; return true;
} }
bool CropOpLite::InferShape() const { bool CropOpLite::InferShapeImpl() const {
// nchw // nchw
auto x_dims = param_.X->dims(); auto x_dims = param_.X->dims();
lite::DDim output_shape(x_dims); lite::DDim output_shape(x_dims);
......
...@@ -30,7 +30,7 @@ class CropOpLite : public OpLite { ...@@ -30,7 +30,7 @@ class CropOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -29,7 +29,7 @@ bool DecodeBboxesOpLite::CheckShape() const { ...@@ -29,7 +29,7 @@ bool DecodeBboxesOpLite::CheckShape() const {
return true; return true;
} }
bool DecodeBboxesOpLite::InferShape() const { bool DecodeBboxesOpLite::InferShapeImpl() const {
param_.bbox_data->Resize(param_.loc_data->dims()); param_.bbox_data->Resize(param_.loc_data->dims());
return true; return true;
} }
......
...@@ -29,7 +29,7 @@ class DecodeBboxesOpLite : public OpLite { ...@@ -29,7 +29,7 @@ class DecodeBboxesOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -27,7 +27,7 @@ bool DensityPriorBoxOpLite::CheckShape() const { ...@@ -27,7 +27,7 @@ bool DensityPriorBoxOpLite::CheckShape() const {
return true; return true;
} }
bool DensityPriorBoxOpLite::InferShape() const { return true; } bool DensityPriorBoxOpLite::InferShapeImpl() const { return true; }
bool DensityPriorBoxOpLite::AttachImpl(const cpp::OpDesc& opdesc, bool DensityPriorBoxOpLite::AttachImpl(const cpp::OpDesc& opdesc,
lite::Scope* scope) { lite::Scope* scope) {
......
...@@ -30,7 +30,7 @@ class DensityPriorBoxOpLite : public OpLite { ...@@ -30,7 +30,7 @@ class DensityPriorBoxOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -32,7 +32,7 @@ bool DistributeFpnProposalsOpLite::CheckShape() const { ...@@ -32,7 +32,7 @@ bool DistributeFpnProposalsOpLite::CheckShape() const {
return true; return true;
} }
bool DistributeFpnProposalsOpLite::InferShape() const { bool DistributeFpnProposalsOpLite::InferShapeImpl() const {
int num_out_rois = param_.max_level - param_.min_level + 1; int num_out_rois = param_.max_level - param_.min_level + 1;
for (int i = 0; i < num_out_rois; i++) { for (int i = 0; i < num_out_rois; i++) {
param_.multi_fpn_rois[i]->Resize({-1, 4}); param_.multi_fpn_rois[i]->Resize({-1, 4});
......
...@@ -32,7 +32,7 @@ class DistributeFpnProposalsOpLite : public OpLite { ...@@ -32,7 +32,7 @@ class DistributeFpnProposalsOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -26,7 +26,7 @@ bool DropoutOp::CheckShape() const { ...@@ -26,7 +26,7 @@ bool DropoutOp::CheckShape() const {
return true; return true;
} }
bool DropoutOp::InferShape() const { bool DropoutOp::InferShapeImpl() const {
const auto x_dims = param_.x->dims(); const auto x_dims = param_.x->dims();
param_.output->Resize(x_dims); param_.output->Resize(x_dims);
if (param_.is_test == false) { if (param_.is_test == false) {
......
...@@ -28,7 +28,7 @@ class DropoutOp : public OpLite { ...@@ -28,7 +28,7 @@ class DropoutOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
// TODO(Superjomn) replace framework::OpDesc with a lite one. // TODO(Superjomn) replace framework::OpDesc with a lite one.
......
...@@ -26,7 +26,7 @@ bool ElementwiseGradOp::CheckShape() const { ...@@ -26,7 +26,7 @@ bool ElementwiseGradOp::CheckShape() const {
return true; return true;
} }
bool ElementwiseGradOp::InferShape() const { bool ElementwiseGradOp::InferShapeImpl() const {
auto x_dim = param_.X->dims(); auto x_dim = param_.X->dims();
auto y_dim = param_.Y->dims(); auto y_dim = param_.Y->dims();
if (param_.XGrad) { if (param_.XGrad) {
......
...@@ -27,7 +27,7 @@ class ElementwiseGradOp : public OpLite { ...@@ -27,7 +27,7 @@ class ElementwiseGradOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
......
...@@ -26,39 +26,8 @@ bool ElementwiseOp::CheckShape() const { ...@@ -26,39 +26,8 @@ bool ElementwiseOp::CheckShape() const {
CHECK_OR_FALSE(param_.Out); CHECK_OR_FALSE(param_.Out);
return true; return true;
} }
bool ElementwiseOp::SmartInferShape() {
if (!last_input_shapes.empty()) {
if (last_input_shapes[0] == param_.X->dims() &&
last_input_shapes[1] == param_.Y->dims() &&
last_input_lods[0] == param_.X->lod() &&
last_input_lods[1] == param_.Y->lod()) {
param_.Out->Resize(last_output_shapes[0]);
param_.Out->set_lod(last_output_lods[0]);
return true;
}
}
this->InferShape();
if (!last_input_shapes.empty()) {
last_input_shapes.clear();
last_input_lods.clear();
}
last_input_shapes.push_back(param_.X->dims());
last_input_lods.push_back(param_.X->lod());
last_input_shapes.push_back(param_.Y->dims());
last_input_lods.push_back(param_.Y->lod());
if (!last_output_shapes.empty()) { bool ElementwiseOp::InferShapeImpl() const {
last_output_shapes.clear();
last_output_lods.clear();
}
last_output_shapes.push_back(param_.Out->dims());
last_output_lods.push_back(param_.Out->lod());
return true;
}
bool ElementwiseOp::InferShape() const {
auto x_dim = param_.X->dims(); auto x_dim = param_.X->dims();
auto y_dim = param_.Y->dims(); auto y_dim = param_.Y->dims();
if (x_dim == y_dim) { if (x_dim == y_dim) {
...@@ -136,7 +105,7 @@ bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { ...@@ -136,7 +105,7 @@ bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
// return true; // return true;
//} //}
// bool ElementwiseGradExplicitOp::InferShape() const { // bool ElementwiseGradExplicitOp::InferShapeImpl() const {
// param_.X_grad->Resize(param_.Out_grad->dims()); // param_.X_grad->Resize(param_.Out_grad->dims());
// if (param_.Y_grad) param_.Y_grad->Resize(param_.Y->dims()); // if (param_.Y_grad) param_.Y_grad->Resize(param_.Y->dims());
// return true; // return true;
......
...@@ -27,8 +27,7 @@ class ElementwiseOp : public OpLite { ...@@ -27,8 +27,7 @@ class ElementwiseOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool SmartInferShape() override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
...@@ -48,7 +47,7 @@ class ElementwiseOp : public OpLite { ...@@ -48,7 +47,7 @@ class ElementwiseOp : public OpLite {
// bool CheckShape() const override; // bool CheckShape() const override;
// bool InferShape() const override; // bool InferShapeImpl() const override;
// bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; // bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
......
...@@ -32,7 +32,7 @@ bool ExpandOpLite::CheckShape() const { ...@@ -32,7 +32,7 @@ bool ExpandOpLite::CheckShape() const {
return true; return true;
} }
bool ExpandOpLite::InferShape() const { bool ExpandOpLite::InferShapeImpl() const {
DDim out_dims(param_.X->dims()); DDim out_dims(param_.X->dims());
for (size_t i = 0; i < param_.expand_times.size(); ++i) { for (size_t i = 0; i < param_.expand_times.size(); ++i) {
out_dims[i] *= param_.expand_times[i]; out_dims[i] *= param_.expand_times[i];
......
...@@ -28,7 +28,7 @@ class ExpandOpLite : public OpLite { ...@@ -28,7 +28,7 @@ class ExpandOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -36,7 +36,7 @@ class FakeChannelWiseDequantizeMaxAbsOpLite : public OpLite { ...@@ -36,7 +36,7 @@ class FakeChannelWiseDequantizeMaxAbsOpLite : public OpLite {
bool CheckShape() const override { return true; } bool CheckShape() const override { return true; }
bool InferShape() const override { return true; } bool InferShapeImpl() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front(); auto x = op_desc.Input("X").front();
......
...@@ -35,7 +35,7 @@ class FakeDequantizeMaxAbsOpLite : public OpLite { ...@@ -35,7 +35,7 @@ class FakeDequantizeMaxAbsOpLite : public OpLite {
bool CheckShape() const override { return true; } bool CheckShape() const override { return true; }
bool InferShape() const override { return true; } bool InferShapeImpl() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front(); auto x = op_desc.Input("X").front();
......
...@@ -36,7 +36,7 @@ class FakeQuantizeDequantizeMovingAvgMaxAbsOpLite : public OpLite { ...@@ -36,7 +36,7 @@ class FakeQuantizeDequantizeMovingAvgMaxAbsOpLite : public OpLite {
bool CheckShape() const override { return true; } bool CheckShape() const override { return true; }
bool InferShape() const override { return true; } bool InferShapeImpl() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front(); auto x = op_desc.Input("X").front();
......
...@@ -36,7 +36,7 @@ class FakeQuantizeMovingAvgMaxAbsOpLite : public OpLite { ...@@ -36,7 +36,7 @@ class FakeQuantizeMovingAvgMaxAbsOpLite : public OpLite {
bool CheckShape() const override { return true; } bool CheckShape() const override { return true; }
bool InferShape() const override { return true; } bool InferShapeImpl() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front(); auto x = op_desc.Input("X").front();
......
...@@ -36,7 +36,7 @@ class FakeQuantizeRangeMaxAbsOpLite : public OpLite { ...@@ -36,7 +36,7 @@ class FakeQuantizeRangeMaxAbsOpLite : public OpLite {
bool CheckShape() const override { return true; } bool CheckShape() const override { return true; }
bool InferShape() const override { return true; } bool InferShapeImpl() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front(); auto x = op_desc.Input("X").front();
......
...@@ -48,34 +48,7 @@ bool FcOpLite::CheckShape() const { ...@@ -48,34 +48,7 @@ bool FcOpLite::CheckShape() const {
return true; return true;
} }
bool FcOpLite::SmartInferShape() { bool FcOpLite::InferShapeImpl() const {
if (!last_input_shapes.empty() && !last_output_shapes.empty()) {
if (last_input_shapes[0] == param_.input->dims() &&
last_input_lods[0] == param_.input->lod()) {
param_.output->Resize(last_output_shapes[0]);
param_.output->set_lod(last_output_lods[0]);
return true;
}
}
this->InferShape();
if (!last_input_shapes.empty()) {
last_input_shapes.clear();
last_input_lods.clear();
}
last_input_shapes.push_back(param_.input->dims());
last_input_lods.push_back(param_.input->lod());
if (!last_output_shapes.empty()) {
last_output_shapes.clear();
last_output_lods.clear();
}
last_output_shapes.push_back(param_.output->dims());
last_output_lods.push_back(param_.output->lod());
return true;
}
bool FcOpLite::InferShape() const {
const auto& input_dims = param_.input->dims(); const auto& input_dims = param_.input->dims();
const auto& w_dims = param_.w->dims(); const auto& w_dims = param_.w->dims();
int in_num_col_dims = param_.in_num_col_dims; int in_num_col_dims = param_.in_num_col_dims;
......
...@@ -35,8 +35,7 @@ class FcOpLite : public OpLite { ...@@ -35,8 +35,7 @@ class FcOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool SmartInferShape() override;
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override;
......
...@@ -29,7 +29,7 @@ class FeedOp : public OpLite { ...@@ -29,7 +29,7 @@ class FeedOp : public OpLite {
return true; return true;
} }
bool InferShape() const override { return true; } bool InferShapeImpl() const override { return true; }
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
......
...@@ -29,7 +29,7 @@ class FetchOp : public OpLite { ...@@ -29,7 +29,7 @@ class FetchOp : public OpLite {
return true; return true;
} }
bool InferShape() const override { return true; } bool InferShapeImpl() const override { return true; }
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
protected: protected:
......
...@@ -28,7 +28,7 @@ bool FillConstantBatchSizeLikeOp::CheckShape() const { ...@@ -28,7 +28,7 @@ bool FillConstantBatchSizeLikeOp::CheckShape() const {
return true; return true;
} }
bool FillConstantBatchSizeLikeOp::InferShape() const { bool FillConstantBatchSizeLikeOp::InferShapeImpl() const {
std::vector<int64_t> output_dim{param_.shape.begin(), param_.shape.end()}; std::vector<int64_t> output_dim{param_.shape.begin(), param_.shape.end()};
if (param_.input_dim_idx == 0 && !param_.input->lod().empty()) { if (param_.input_dim_idx == 0 && !param_.input->lod().empty()) {
output_dim[param_.output_dim_idx] = param_.input->lod().back().size() - 1; output_dim[param_.output_dim_idx] = param_.input->lod().back().size() - 1;
......
...@@ -32,7 +32,7 @@ class FillConstantBatchSizeLikeOp : public OpLite { ...@@ -32,7 +32,7 @@ class FillConstantBatchSizeLikeOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -24,7 +24,7 @@ bool FillConstantOp::CheckShape() const { ...@@ -24,7 +24,7 @@ bool FillConstantOp::CheckShape() const {
return true; return true;
} }
bool FillConstantOp::InferShape() const { bool FillConstantOp::InferShapeImpl() const {
std::vector<int64_t> out_shape; std::vector<int64_t> out_shape;
auto shape_tensor = param_.shape_tensor; auto shape_tensor = param_.shape_tensor;
auto shape_tensor_list = param_.shape_tensor_list; auto shape_tensor_list = param_.shape_tensor_list;
......
...@@ -31,7 +31,7 @@ class FillConstantOp : public OpLite { ...@@ -31,7 +31,7 @@ class FillConstantOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -25,7 +25,7 @@ bool FlattenOp::CheckShape() const { ...@@ -25,7 +25,7 @@ bool FlattenOp::CheckShape() const {
return true; return true;
} }
bool FlattenOp::InferShape() const { bool FlattenOp::InferShapeImpl() const {
auto x_dims = param_.x->dims(); auto x_dims = param_.x->dims();
auto out_lod = param_.output->mutable_lod(); auto out_lod = param_.output->mutable_lod();
...@@ -71,8 +71,8 @@ bool Flatten2Op::CheckShape() const { ...@@ -71,8 +71,8 @@ bool Flatten2Op::CheckShape() const {
return true; return true;
} }
bool Flatten2Op::InferShape() const { bool Flatten2Op::InferShapeImpl() const {
FlattenOp::InferShape(); FlattenOp::InferShapeImpl();
auto x_dims = param_.x->dims(); auto x_dims = param_.x->dims();
std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 0); std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 0);
for (size_t i = 0; i < x_dims.size(); i++) { for (size_t i = 0; i < x_dims.size(); i++) {
......
...@@ -30,7 +30,7 @@ class FlattenOp : public OpLite { ...@@ -30,7 +30,7 @@ class FlattenOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
...@@ -49,7 +49,7 @@ class Flatten2Op : public FlattenOp { ...@@ -49,7 +49,7 @@ class Flatten2Op : public FlattenOp {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -27,7 +27,7 @@ bool FusionElementwiseActivationOp::CheckShape() const { ...@@ -27,7 +27,7 @@ bool FusionElementwiseActivationOp::CheckShape() const {
return true; return true;
} }
bool FusionElementwiseActivationOp::InferShape() const { bool FusionElementwiseActivationOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.X->dims().size() >= param_.Y->dims().size()); CHECK_OR_FALSE(param_.X->dims().size() >= param_.Y->dims().size());
param_.Out->Resize(param_.X->dims()); param_.Out->Resize(param_.X->dims());
return true; return true;
...@@ -59,7 +59,7 @@ bool FusionElementwiseActivationOp::AttachImpl(const cpp::OpDesc& opdesc, ...@@ -59,7 +59,7 @@ bool FusionElementwiseActivationOp::AttachImpl(const cpp::OpDesc& opdesc,
// return true; // return true;
// } // }
// bool FusionElementwiseActivationGradExplicitOp::InferShape() const { // bool FusionElementwiseActivationGradExplicitOp::InferShapeImpl() const {
// param_.X_grad->Resize(param_.Out_grad->dims()); // param_.X_grad->Resize(param_.Out_grad->dims());
// param_.Y_grad->Resize(param_.Y->dims()); // param_.Y_grad->Resize(param_.Y->dims());
// return true; // return true;
......
...@@ -29,7 +29,7 @@ class FusionElementwiseActivationOp : public OpLite { ...@@ -29,7 +29,7 @@ class FusionElementwiseActivationOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
...@@ -51,7 +51,7 @@ class FusionElementwiseActivationOp : public OpLite { ...@@ -51,7 +51,7 @@ class FusionElementwiseActivationOp : public OpLite {
// bool CheckShape() const override; // bool CheckShape() const override;
// bool InferShape() const override; // bool InferShapeImpl() const override;
// bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; // bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
......
...@@ -26,7 +26,7 @@ bool GatherOp::CheckShape() const { ...@@ -26,7 +26,7 @@ bool GatherOp::CheckShape() const {
return true; return true;
} }
bool GatherOp::InferShape() const { bool GatherOp::InferShapeImpl() const {
auto index_dims = param_.Index->dims(); auto index_dims = param_.Index->dims();
CHECK(index_dims.size() == 1 || CHECK(index_dims.size() == 1 ||
(index_dims.size() == 2 && index_dims[1] == 1)) (index_dims.size() == 2 && index_dims[1] == 1))
......
...@@ -30,7 +30,7 @@ class GatherOp : public OpLite { ...@@ -30,7 +30,7 @@ class GatherOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -43,7 +43,7 @@ bool GenerateProposalsOpLite::CheckShape() const { ...@@ -43,7 +43,7 @@ bool GenerateProposalsOpLite::CheckShape() const {
return true; return true;
} }
bool GenerateProposalsOpLite::InferShape() const { bool GenerateProposalsOpLite::InferShapeImpl() const {
param_.RpnRois->Resize(std::vector<int64_t>({-1, 4})); param_.RpnRois->Resize(std::vector<int64_t>({-1, 4}));
param_.RpnRoiProbs->Resize(std::vector<int64_t>({-1, 1})); param_.RpnRoiProbs->Resize(std::vector<int64_t>({-1, 1}));
return true; return true;
......
...@@ -32,7 +32,7 @@ class GenerateProposalsOpLite : public OpLite { ...@@ -32,7 +32,7 @@ class GenerateProposalsOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -42,7 +42,7 @@ bool GridSamplerOp::CheckShape() const { ...@@ -42,7 +42,7 @@ bool GridSamplerOp::CheckShape() const {
return true; return true;
} }
bool GridSamplerOp::InferShape() const { bool GridSamplerOp::InferShapeImpl() const {
auto x_dims = param_.x->dims(); auto x_dims = param_.x->dims();
param_.out->Resize(x_dims); param_.out->Resize(x_dims);
return true; return true;
......
...@@ -31,7 +31,7 @@ class GridSamplerOp : public OpLite { ...@@ -31,7 +31,7 @@ class GridSamplerOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -51,7 +51,7 @@ bool GRUOpLite::CheckShape() const { ...@@ -51,7 +51,7 @@ bool GRUOpLite::CheckShape() const {
return true; return true;
} }
bool GRUOpLite::InferShape() const { bool GRUOpLite::InferShapeImpl() const {
const auto& input_dims = param_.input->dims(); const auto& input_dims = param_.input->dims();
const auto& weight_dims = param_.weight->dims(); const auto& weight_dims = param_.weight->dims();
int frame_size = weight_dims[0]; int frame_size = weight_dims[0];
......
...@@ -30,7 +30,7 @@ class GRUOpLite : public OpLite { ...@@ -30,7 +30,7 @@ class GRUOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -51,7 +51,7 @@ bool GRUUnitOpLite::CheckShape() const { ...@@ -51,7 +51,7 @@ bool GRUUnitOpLite::CheckShape() const {
return true; return true;
} }
bool GRUUnitOpLite::InferShape() const { bool GRUUnitOpLite::InferShapeImpl() const {
auto input_dims = param_.input->dims(); auto input_dims = param_.input->dims();
auto hidden_prev_dims = param_.hidden_prev->dims(); auto hidden_prev_dims = param_.hidden_prev->dims();
auto weight_dims = param_.weight->dims(); auto weight_dims = param_.weight->dims();
......
...@@ -30,7 +30,7 @@ class GRUUnitOpLite : public OpLite { ...@@ -30,7 +30,7 @@ class GRUUnitOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -26,7 +26,7 @@ inline int Im2SeqOutputSize( ...@@ -26,7 +26,7 @@ inline int Im2SeqOutputSize(
} }
bool Im2SequenceOp::CheckShape() const { return true; } bool Im2SequenceOp::CheckShape() const { return true; }
bool Im2SequenceOp::InferShape() const { bool Im2SequenceOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.Out); CHECK_OR_FALSE(param_.Out);
// TODO(Superjomn) Enable data sharing. // TODO(Superjomn) Enable data sharing.
auto input_dims = param_.X->dims(); auto input_dims = param_.X->dims();
......
...@@ -30,7 +30,7 @@ class Im2SequenceOp : public OpLite { ...@@ -30,7 +30,7 @@ class Im2SequenceOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -25,7 +25,7 @@ bool IncrementOp::CheckShape() const { ...@@ -25,7 +25,7 @@ bool IncrementOp::CheckShape() const {
return true; return true;
} }
bool IncrementOp::InferShape() const { bool IncrementOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.Out); CHECK_OR_FALSE(param_.Out);
// TODO(Superjomn) Enable data sharing. // TODO(Superjomn) Enable data sharing.
auto out_dims = param_.X->dims(); auto out_dims = param_.X->dims();
......
...@@ -30,7 +30,7 @@ class IncrementOp : public OpLite { ...@@ -30,7 +30,7 @@ class IncrementOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -42,7 +42,7 @@ bool InstanceNormOp::CheckShape() const { ...@@ -42,7 +42,7 @@ bool InstanceNormOp::CheckShape() const {
return true; return true;
} }
bool InstanceNormOp::InferShape() const { bool InstanceNormOp::InferShapeImpl() const {
auto x_dims = param_.x->dims(); auto x_dims = param_.x->dims();
int64_t batch_size = x_dims[0]; int64_t batch_size = x_dims[0];
int64_t channel_size = x_dims[1]; int64_t channel_size = x_dims[1];
......
...@@ -31,7 +31,7 @@ class InstanceNormOp : public OpLite { ...@@ -31,7 +31,7 @@ class InstanceNormOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -34,7 +34,7 @@ bool InterpolateOp::CheckShape() const { ...@@ -34,7 +34,7 @@ bool InterpolateOp::CheckShape() const {
return true; return true;
} }
bool InterpolateOp::InferShape() const { bool InterpolateOp::InferShapeImpl() const {
auto X = param_.X; auto X = param_.X;
int n = X->dims()[0]; int n = X->dims()[0];
......
...@@ -31,7 +31,7 @@ class InterpolateOp : public OpLite { ...@@ -31,7 +31,7 @@ class InterpolateOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -24,7 +24,7 @@ bool IoCopyOp::CheckShape() const { ...@@ -24,7 +24,7 @@ bool IoCopyOp::CheckShape() const {
CHECK_OR_FALSE(param_.y); CHECK_OR_FALSE(param_.y);
return true; return true;
} }
bool IoCopyOp::InferShape() const { bool IoCopyOp::InferShapeImpl() const {
param_.y->Resize(param_.x->dims()); param_.y->Resize(param_.x->dims());
return true; return true;
} }
......
...@@ -24,7 +24,7 @@ class IoCopyOp : public OpLite { ...@@ -24,7 +24,7 @@ class IoCopyOp : public OpLite {
public: public:
explicit IoCopyOp(const std::string &type) : OpLite(type) {} explicit IoCopyOp(const std::string &type) : OpLite(type) {}
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool Run() override; bool Run() override;
std::string DebugString() const override; std::string DebugString() const override;
......
...@@ -21,7 +21,7 @@ namespace operators { ...@@ -21,7 +21,7 @@ namespace operators {
bool IsEmptyOp::CheckShape() const { return true; } bool IsEmptyOp::CheckShape() const { return true; }
bool IsEmptyOp::InferShape() const { return true; } bool IsEmptyOp::InferShapeImpl() const { return true; }
bool IsEmptyOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { bool IsEmptyOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.X = param_.X =
......
...@@ -30,7 +30,7 @@ class IsEmptyOp : public OpLite { ...@@ -30,7 +30,7 @@ class IsEmptyOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -27,7 +27,7 @@ bool LayerNormOp::CheckShape() const { ...@@ -27,7 +27,7 @@ bool LayerNormOp::CheckShape() const {
return true; return true;
} }
bool LayerNormOp::InferShape() const { bool LayerNormOp::InferShapeImpl() const {
auto out_dims = param_.X->dims(); auto out_dims = param_.X->dims();
param_.Y->Resize(out_dims); param_.Y->Resize(out_dims);
auto inner_size = out_dims.Flatten2D(param_.begin_norm_axis)[0]; auto inner_size = out_dims.Flatten2D(param_.begin_norm_axis)[0];
......
...@@ -30,7 +30,7 @@ class LayerNormOp : public OpLite { ...@@ -30,7 +30,7 @@ class LayerNormOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -24,7 +24,7 @@ bool LayoutOp::CheckShape() const { ...@@ -24,7 +24,7 @@ bool LayoutOp::CheckShape() const {
CHECK_OR_FALSE(param_.y); CHECK_OR_FALSE(param_.y);
return true; return true;
} }
bool LayoutOp::InferShape() const { bool LayoutOp::InferShapeImpl() const {
param_.y->Resize(param_.x->dims()); param_.y->Resize(param_.x->dims());
return true; return true;
} }
......
...@@ -24,7 +24,7 @@ class LayoutOp : public OpLite { ...@@ -24,7 +24,7 @@ class LayoutOp : public OpLite {
public: public:
explicit LayoutOp(const std::string &type) : OpLite(type) {} explicit LayoutOp(const std::string &type) : OpLite(type) {}
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool Run() override; bool Run() override;
std::string DebugString() const override; std::string DebugString() const override;
......
...@@ -25,7 +25,7 @@ bool LodResetOp::CheckShape() const { ...@@ -25,7 +25,7 @@ bool LodResetOp::CheckShape() const {
return true; return true;
} }
bool LodResetOp::InferShape() const { bool LodResetOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.Out); CHECK_OR_FALSE(param_.Out);
// TODO(Superjomn) Enable data sharing. // TODO(Superjomn) Enable data sharing.
param_.Out->Resize(param_.X->dims()); param_.Out->Resize(param_.X->dims());
......
...@@ -30,7 +30,7 @@ class LodResetOp : public OpLite { ...@@ -30,7 +30,7 @@ class LodResetOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -26,7 +26,7 @@ bool BinaryLogicalOp::CheckShape() const { ...@@ -26,7 +26,7 @@ bool BinaryLogicalOp::CheckShape() const {
return true; return true;
} }
bool BinaryLogicalOp::InferShape() const { bool BinaryLogicalOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.Out); CHECK_OR_FALSE(param_.Out);
// TODO(Superjomn) Enable data sharing. // TODO(Superjomn) Enable data sharing.
auto input_dims = param_.X->dims(); auto input_dims = param_.X->dims();
...@@ -53,7 +53,7 @@ bool UnaryLogicalOp::CheckShape() const { ...@@ -53,7 +53,7 @@ bool UnaryLogicalOp::CheckShape() const {
return true; return true;
} }
bool UnaryLogicalOp::InferShape() const { bool UnaryLogicalOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.Out); CHECK_OR_FALSE(param_.Out);
// TODO(Superjomn) Enable data sharing. // TODO(Superjomn) Enable data sharing.
auto input_dims = param_.X->dims(); auto input_dims = param_.X->dims();
......
...@@ -30,7 +30,7 @@ class BinaryLogicalOp : public OpLite { ...@@ -30,7 +30,7 @@ class BinaryLogicalOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
...@@ -49,7 +49,7 @@ class UnaryLogicalOp : public OpLite { ...@@ -49,7 +49,7 @@ class UnaryLogicalOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -36,7 +36,7 @@ bool LookupTableDequantOpLite::CheckShape() const { ...@@ -36,7 +36,7 @@ bool LookupTableDequantOpLite::CheckShape() const {
return true; return true;
} }
bool LookupTableDequantOpLite::InferShape() const { bool LookupTableDequantOpLite::InferShapeImpl() const {
const auto& table_dims = param_.W->dims(); const auto& table_dims = param_.W->dims();
const auto& ids_dims = param_.Ids->dims(); const auto& ids_dims = param_.Ids->dims();
......
...@@ -31,7 +31,7 @@ class LookupTableDequantOpLite : public OpLite { ...@@ -31,7 +31,7 @@ class LookupTableDequantOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -36,7 +36,7 @@ bool LookupTableOpLite::CheckShape() const { ...@@ -36,7 +36,7 @@ bool LookupTableOpLite::CheckShape() const {
return true; return true;
} }
bool LookupTableOpLite::InferShape() const { bool LookupTableOpLite::InferShapeImpl() const {
const auto& table_dims = param_.W->dims(); const auto& table_dims = param_.W->dims();
const auto& ids_dims = param_.Ids->dims(); const auto& ids_dims = param_.Ids->dims();
......
...@@ -30,7 +30,7 @@ class LookupTableOpLite : public OpLite { ...@@ -30,7 +30,7 @@ class LookupTableOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -32,7 +32,7 @@ bool LookupTableV2OpLite::CheckShape() const { ...@@ -32,7 +32,7 @@ bool LookupTableV2OpLite::CheckShape() const {
return true; return true;
} }
bool LookupTableV2OpLite::InferShape() const { bool LookupTableV2OpLite::InferShapeImpl() const {
auto table_dims = param_.W->dims(); auto table_dims = param_.W->dims();
auto ids_dims = param_.Ids->dims(); auto ids_dims = param_.Ids->dims();
......
...@@ -30,7 +30,7 @@ class LookupTableV2OpLite : public OpLite { ...@@ -30,7 +30,7 @@ class LookupTableV2OpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -27,7 +27,7 @@ bool LrnOpLite::CheckShape() const { ...@@ -27,7 +27,7 @@ bool LrnOpLite::CheckShape() const {
return true; return true;
} }
bool LrnOpLite::InferShape() const { bool LrnOpLite::InferShapeImpl() const {
param_.Out->Resize(param_.X->dims()); param_.Out->Resize(param_.X->dims());
return true; return true;
} }
......
...@@ -28,7 +28,7 @@ class LrnOpLite : public OpLite { ...@@ -28,7 +28,7 @@ class LrnOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -26,7 +26,7 @@ bool LstmOp::CheckShape() const { ...@@ -26,7 +26,7 @@ bool LstmOp::CheckShape() const {
return true; return true;
} }
bool LstmOp::InferShape() const { bool LstmOp::InferShapeImpl() const {
auto in_dims = param_.Input->dims(); auto in_dims = param_.Input->dims();
if (param_.H0) { if (param_.H0) {
CHECK(param_.C0) << "lstm must has H0 and C0 in the same time"; CHECK(param_.C0) << "lstm must has H0 and C0 in the same time";
......
...@@ -30,7 +30,7 @@ class LstmOp : public OpLite { ...@@ -30,7 +30,7 @@ class LstmOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -42,7 +42,7 @@ bool MatchMatrixTensorOpLite::CheckShape() const { ...@@ -42,7 +42,7 @@ bool MatchMatrixTensorOpLite::CheckShape() const {
return true; return true;
} }
bool MatchMatrixTensorOpLite::InferShape() const { bool MatchMatrixTensorOpLite::InferShapeImpl() const {
const Tensor* x = param_.x; const Tensor* x = param_.x;
const Tensor* y = param_.y; const Tensor* y = param_.y;
DDim x_dims = param_.x->dims(); DDim x_dims = param_.x->dims();
......
...@@ -32,7 +32,7 @@ class MatchMatrixTensorOpLite : public OpLite { ...@@ -32,7 +32,7 @@ class MatchMatrixTensorOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -27,7 +27,7 @@ bool MatMulOpLite::CheckShape() const { ...@@ -27,7 +27,7 @@ bool MatMulOpLite::CheckShape() const {
return true; return true;
} }
bool MatMulOpLite::InferShape() const { bool MatMulOpLite::InferShapeImpl() const {
const auto x_dims = param_.X->dims(); const auto x_dims = param_.X->dims();
const auto y_dims = param_.Y->dims(); const auto y_dims = param_.Y->dims();
bool x_transpose = param_.transpose_X; bool x_transpose = param_.transpose_X;
......
...@@ -33,7 +33,7 @@ class MatMulOpLite : public OpLite { ...@@ -33,7 +33,7 @@ class MatMulOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
......
...@@ -28,7 +28,7 @@ bool MeanGradOp::CheckShape() const { ...@@ -28,7 +28,7 @@ bool MeanGradOp::CheckShape() const {
return true; return true;
} }
bool MeanGradOp::InferShape() const { bool MeanGradOp::InferShapeImpl() const {
param_.X_grad->Resize(param_.X->dims()); param_.X_grad->Resize(param_.X->dims());
return true; return true;
} }
......
...@@ -27,7 +27,7 @@ class MeanGradOp : public OpLite { ...@@ -27,7 +27,7 @@ class MeanGradOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override;
......
...@@ -27,7 +27,7 @@ bool MeanOp::CheckShape() const { ...@@ -27,7 +27,7 @@ bool MeanOp::CheckShape() const {
return true; return true;
} }
bool MeanOp::InferShape() const { bool MeanOp::InferShapeImpl() const {
param_.Out->Resize(std::vector<int64_t>{1}); param_.Out->Resize(std::vector<int64_t>{1});
return true; return true;
} }
......
...@@ -27,7 +27,7 @@ class MeanOp : public OpLite { ...@@ -27,7 +27,7 @@ class MeanOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
......
...@@ -34,7 +34,7 @@ bool MergeLodTensorOpLite::CheckShape() const { ...@@ -34,7 +34,7 @@ bool MergeLodTensorOpLite::CheckShape() const {
return true; return true;
} }
bool MergeLodTensorOpLite::InferShape() const { bool MergeLodTensorOpLite::InferShapeImpl() const {
auto dims = param_.in_true->dims(); auto dims = param_.in_true->dims();
param_.out->Resize(dims); param_.out->Resize(dims);
return true; return true;
......
...@@ -31,7 +31,7 @@ class MergeLodTensorOpLite : public OpLite { ...@@ -31,7 +31,7 @@ class MergeLodTensorOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -46,7 +46,7 @@ bool MulGradOpLite::CheckShape() const { ...@@ -46,7 +46,7 @@ bool MulGradOpLite::CheckShape() const {
return true; return true;
} }
bool MulGradOpLite::InferShape() const { bool MulGradOpLite::InferShapeImpl() const {
const auto x_dims = param_.x->dims(); const auto x_dims = param_.x->dims();
const auto y_dims = param_.y->dims(); const auto y_dims = param_.y->dims();
if (param_.x_grad) { if (param_.x_grad) {
......
...@@ -33,7 +33,7 @@ class MulGradOpLite : public OpLite { ...@@ -33,7 +33,7 @@ class MulGradOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
......
...@@ -35,7 +35,7 @@ bool MulOpLite::CheckShape() const { ...@@ -35,7 +35,7 @@ bool MulOpLite::CheckShape() const {
return true; return true;
} }
bool MulOpLite::InferShape() const { bool MulOpLite::InferShapeImpl() const {
const auto x_dims = param_.x->dims(); const auto x_dims = param_.x->dims();
const auto y_dims = param_.y->dims(); const auto y_dims = param_.y->dims();
......
...@@ -33,7 +33,7 @@ class MulOpLite : public OpLite { ...@@ -33,7 +33,7 @@ class MulOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
// TODO(Superjomn) replace framework::OpDesc with a lite one. // TODO(Superjomn) replace framework::OpDesc with a lite one.
......
...@@ -41,7 +41,7 @@ bool MulticlassNmsOpLite::CheckShape() const { ...@@ -41,7 +41,7 @@ bool MulticlassNmsOpLite::CheckShape() const {
return true; return true;
} }
bool MulticlassNmsOpLite::InferShape() const { bool MulticlassNmsOpLite::InferShapeImpl() const {
auto box_dims = param_.bboxes->dims(); auto box_dims = param_.bboxes->dims();
auto score_dims = param_.scores->dims(); auto score_dims = param_.scores->dims();
auto score_size = score_dims.size(); auto score_size = score_dims.size();
......
...@@ -29,7 +29,7 @@ class MulticlassNmsOpLite : public OpLite { ...@@ -29,7 +29,7 @@ class MulticlassNmsOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -26,7 +26,7 @@ bool NegativeOpLite::CheckShape() const { ...@@ -26,7 +26,7 @@ bool NegativeOpLite::CheckShape() const {
return true; return true;
} }
bool NegativeOpLite::InferShape() const { bool NegativeOpLite::InferShapeImpl() const {
lite::DDim input_dims; lite::DDim input_dims;
input_dims = param_.X->dims(); input_dims = param_.X->dims();
param_.Out->Resize(lite::DDim(input_dims)); param_.Out->Resize(lite::DDim(input_dims));
......
...@@ -30,7 +30,7 @@ class NegativeOpLite : public OpLite { ...@@ -30,7 +30,7 @@ class NegativeOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -25,7 +25,7 @@ bool NormOp::CheckShape() const { ...@@ -25,7 +25,7 @@ bool NormOp::CheckShape() const {
return true; return true;
} }
bool NormOp::InferShape() const { bool NormOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.Out); CHECK_OR_FALSE(param_.Out);
// TODO(Superjomn) Enable data sharing. // TODO(Superjomn) Enable data sharing.
auto out_dims = param_.X->dims(); auto out_dims = param_.X->dims();
......
...@@ -30,7 +30,7 @@ class NormOp : public OpLite { ...@@ -30,7 +30,7 @@ class NormOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "lite/model_parser/cpp/block_desc.h" #include "lite/model_parser/cpp/block_desc.h"
#include "lite/model_parser/desc_apis.h" #include "lite/model_parser/desc_apis.h"
#include "lite/utils/all.h" #include "lite/utils/all.h"
#include "lite/utils/variant.h"
/* /*
* This file contains all the argument parameter data structure for operators. * This file contains all the argument parameter data structure for operators.
*/ */
...@@ -32,6 +33,16 @@ namespace paddle { ...@@ -32,6 +33,16 @@ namespace paddle {
namespace lite { namespace lite {
namespace operators { namespace operators {
struct ParamBase {
public:
const std::vector<Tensor*>* input_tensor_ptrs() const { return nullptr; }
std::vector<Tensor*>* output_tensor_ptrs() { return nullptr; }
protected:
std::shared_ptr<std::vector<const Tensor*>> input_tensor_ptrs_cache_{nullptr};
std::shared_ptr<std::vector<Tensor*>> output_tensor_ptrs_cache_{nullptr};
};
using param_t = Any; using param_t = Any;
#define WITH_INT8_CONFIG \ #define WITH_INT8_CONFIG \
bool enable_int8{false}; \ bool enable_int8{false}; \
...@@ -41,38 +52,38 @@ using param_t = Any; ...@@ -41,38 +52,38 @@ using param_t = Any;
int bit_length{8}; int bit_length{8};
/// ----------------------- Functional operators ------------------------------ /// ----------------------- Functional operators ------------------------------
struct FeedParam { struct FeedParam : ParamBase {
std::vector<lite::Tensor>* feed_list{}; std::vector<lite::Tensor>* feed_list{};
lite::Tensor* out{}; lite::Tensor* out{};
int col; int col;
}; };
struct FetchParam { struct FetchParam : ParamBase {
const lite::Tensor* input{}; const lite::Tensor* input{};
std::vector<lite::Tensor>* fetch_list{}; std::vector<lite::Tensor>* fetch_list{};
int col; int col;
}; };
// Helper op for lite framework // Helper op for lite framework
struct IoCopyParam { struct IoCopyParam : ParamBase {
const lite::Tensor* x{}; const lite::Tensor* x{};
lite::Tensor* y{}; lite::Tensor* y{};
int process_type{0}; int process_type{0};
}; };
struct LayoutParam { struct LayoutParam : ParamBase {
const lite::Tensor* x{}; const lite::Tensor* x{};
lite::Tensor* y{}; lite::Tensor* y{};
int process_type{0}; int process_type{0};
}; };
struct CalibParam { struct CalibParam : ParamBase {
const lite::Tensor* input{}; const lite::Tensor* input{};
lite::Tensor* output{}; lite::Tensor* output{};
float scale; float scale;
}; };
struct SubgraphParam { struct SubgraphParam : ParamBase {
std::vector<std::string> input_names{}; std::vector<std::string> input_names{};
std::vector<std::string> output_names{}; std::vector<std::string> output_names{};
std::vector<std::string> input_data_names{}; std::vector<std::string> input_data_names{};
...@@ -84,7 +95,7 @@ struct SubgraphParam { ...@@ -84,7 +95,7 @@ struct SubgraphParam {
/// -------------------------- NN operators ------------------------------------ /// -------------------------- NN operators ------------------------------------
struct FcParam { struct FcParam : ParamBase {
lite::Tensor* input{nullptr}; lite::Tensor* input{nullptr};
lite::Tensor* w{nullptr}; lite::Tensor* w{nullptr};
lite::Tensor* bias{nullptr}; lite::Tensor* bias{nullptr};
...@@ -95,9 +106,24 @@ struct FcParam { ...@@ -95,9 +106,24 @@ struct FcParam {
bool padding_weights{false}; bool padding_weights{false};
// for int8 // for int8
WITH_INT8_CONFIG WITH_INT8_CONFIG
}; ///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
struct SearchSeqFcParam { const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(input_tensor_ptrs_cache_)) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({input}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(output_tensor_ptrs_cache_)) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
}
return output_tensor_ptrs_cache_.get();
}
};
struct SearchSeqFcParam : ParamBase {
lite::Tensor* x{nullptr}; lite::Tensor* x{nullptr};
lite::Tensor* w{nullptr}; lite::Tensor* w{nullptr};
lite::Tensor* b{nullptr}; lite::Tensor* b{nullptr};
...@@ -106,7 +132,7 @@ struct SearchSeqFcParam { ...@@ -106,7 +132,7 @@ struct SearchSeqFcParam {
}; };
// For Interpolate Op // For Interpolate Op
struct InterpolateParam { struct InterpolateParam : ParamBase {
lite::Tensor* X{}; lite::Tensor* X{};
lite::Tensor* OutSize{}; lite::Tensor* OutSize{};
lite::Tensor* Out{}; lite::Tensor* Out{};
...@@ -123,7 +149,7 @@ struct InterpolateParam { ...@@ -123,7 +149,7 @@ struct InterpolateParam {
}; };
// For Mul Op // For Mul Op
struct MulParam { struct MulParam : ParamBase {
const lite::Tensor* x{}; const lite::Tensor* x{};
const lite::Tensor* y{}; const lite::Tensor* y{};
lite::Tensor* output{}; lite::Tensor* output{};
...@@ -134,7 +160,7 @@ struct MulParam { ...@@ -134,7 +160,7 @@ struct MulParam {
WITH_INT8_CONFIG WITH_INT8_CONFIG
}; };
struct MulGradParam { struct MulGradParam : ParamBase {
const lite::Tensor* x{}; const lite::Tensor* x{};
const lite::Tensor* y{}; const lite::Tensor* y{};
const lite::Tensor* output_grad{}; const lite::Tensor* output_grad{};
...@@ -146,7 +172,7 @@ struct MulGradParam { ...@@ -146,7 +172,7 @@ struct MulGradParam {
}; };
// For ReduceMean Op // For ReduceMean Op
struct ReduceMeanParam { struct ReduceMeanParam : ParamBase {
lite::Tensor* X{}; lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
...@@ -155,7 +181,7 @@ struct ReduceMeanParam { ...@@ -155,7 +181,7 @@ struct ReduceMeanParam {
}; };
// For Stack Op // For Stack Op
struct StackParam { struct StackParam : ParamBase {
std::vector<lite::Tensor*> X; std::vector<lite::Tensor*> X;
lite::Tensor* Out{}; lite::Tensor* Out{};
...@@ -163,7 +189,7 @@ struct StackParam { ...@@ -163,7 +189,7 @@ struct StackParam {
}; };
// For Power Op // For Power Op
struct PowerParam { struct PowerParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
...@@ -172,7 +198,7 @@ struct PowerParam { ...@@ -172,7 +198,7 @@ struct PowerParam {
float power{}; float power{};
}; };
struct ShuffleChannelParam { struct ShuffleChannelParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
...@@ -180,7 +206,7 @@ struct ShuffleChannelParam { ...@@ -180,7 +206,7 @@ struct ShuffleChannelParam {
}; };
// For Yolobox // For Yolobox
struct YoloBoxParam { struct YoloBoxParam : ParamBase {
lite::Tensor* X{}; lite::Tensor* X{};
lite::Tensor* ImgSize{}; lite::Tensor* ImgSize{};
lite::Tensor* Boxes{}; lite::Tensor* Boxes{};
...@@ -193,7 +219,7 @@ struct YoloBoxParam { ...@@ -193,7 +219,7 @@ struct YoloBoxParam {
}; };
// For Scale Op // For Scale Op
struct ScaleParam { struct ScaleParam : ParamBase {
lite::Tensor* x{}; lite::Tensor* x{};
lite::Tensor* output{}; lite::Tensor* output{};
...@@ -203,14 +229,29 @@ struct ScaleParam { ...@@ -203,14 +229,29 @@ struct ScaleParam {
}; };
// For Softmax op // For Softmax op
struct SoftmaxParam { struct SoftmaxParam : ParamBase {
lite::Tensor* x{}; lite::Tensor* x{};
lite::Tensor* output{}; lite::Tensor* output{};
int axis{-1}; int axis{-1};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(input_tensor_ptrs_cache_)) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(output_tensor_ptrs_cache_)) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
}
return output_tensor_ptrs_cache_.get();
}
}; };
// For Reshape and Reshape2 Op // For Reshape and Reshape2 Op
struct ReshapeParam { struct ReshapeParam : ParamBase {
const lite::Tensor* x{}; const lite::Tensor* x{};
std::vector<const lite::Tensor*> shape_tensor_vct{}; std::vector<const lite::Tensor*> shape_tensor_vct{};
const lite::Tensor* shape_tensor{}; const lite::Tensor* shape_tensor{};
...@@ -222,7 +263,7 @@ struct ReshapeParam { ...@@ -222,7 +263,7 @@ struct ReshapeParam {
}; };
// For Concat op // For Concat op
struct ConcatParam { struct ConcatParam : ParamBase {
std::vector<lite::Tensor*> x{}; std::vector<lite::Tensor*> x{};
lite::Tensor* output{}; lite::Tensor* output{};
int axis{0}; int axis{0};
...@@ -230,7 +271,7 @@ struct ConcatParam { ...@@ -230,7 +271,7 @@ struct ConcatParam {
}; };
/// ----------------------- activation operators ---------------------- /// ----------------------- activation operators ----------------------
struct ActivationParam { struct ActivationParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
float Leaky_relu_alpha{0}; // leaky_relu param float Leaky_relu_alpha{0}; // leaky_relu param
float Relu_clipped_coef{6}; // relu_clipped param float Relu_clipped_coef{6}; // relu_clipped param
...@@ -245,7 +286,7 @@ struct ActivationParam { ...@@ -245,7 +286,7 @@ struct ActivationParam {
lite_api::ActivationType active_type; lite_api::ActivationType active_type;
}; };
struct ActivationGradParam { struct ActivationGradParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
const lite::Tensor* Out{}; const lite::Tensor* Out{};
// for backward // for backward
...@@ -254,7 +295,7 @@ struct ActivationGradParam { ...@@ -254,7 +295,7 @@ struct ActivationGradParam {
}; };
// For Convolution op // For Convolution op
struct ConvParam { struct ConvParam : ParamBase {
lite::Tensor* x{}; lite::Tensor* x{};
lite::Tensor* filter{}; lite::Tensor* filter{};
lite::Tensor* bias{nullptr}; lite::Tensor* bias{nullptr};
...@@ -294,10 +335,26 @@ struct ConvParam { ...@@ -294,10 +335,26 @@ struct ConvParam {
std::vector<int> output_size; std::vector<int> output_size;
// for int8 // for int8
WITH_INT8_CONFIG WITH_INT8_CONFIG
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(input_tensor_ptrs_cache_)) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(output_tensor_ptrs_cache_)) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
}
return output_tensor_ptrs_cache_.get();
}
}; };
// For BatchNorm op // For BatchNorm op
struct BatchNormParam { struct BatchNormParam : ParamBase {
lite::Tensor* x{}; lite::Tensor* x{};
lite::Tensor* bias{}; lite::Tensor* bias{};
lite::Tensor* scale{}; lite::Tensor* scale{};
...@@ -316,7 +373,7 @@ struct BatchNormParam { ...@@ -316,7 +373,7 @@ struct BatchNormParam {
}; };
// For Pooling op // For Pooling op
struct PoolParam { struct PoolParam : ParamBase {
lite::Tensor* x{}; lite::Tensor* x{};
lite::Tensor* output{}; lite::Tensor* output{};
std::string pooling_type{""}; std::string pooling_type{""};
...@@ -340,7 +397,7 @@ struct PoolParam { ...@@ -340,7 +397,7 @@ struct PoolParam {
}; };
// For Dropout op // For Dropout op
struct DropoutParam { struct DropoutParam : ParamBase {
const lite::Tensor* x{}; const lite::Tensor* x{};
lite::Tensor* output{}; lite::Tensor* output{};
lite::Tensor* mask{}; lite::Tensor* mask{};
...@@ -352,7 +409,7 @@ struct DropoutParam { ...@@ -352,7 +409,7 @@ struct DropoutParam {
}; };
// For Split op // For Split op
struct SplitParam { struct SplitParam : ParamBase {
lite::Tensor* x{}; lite::Tensor* x{};
std::vector<lite::Tensor*> output{}; std::vector<lite::Tensor*> output{};
lite::Tensor* axis_tensor; lite::Tensor* axis_tensor;
...@@ -364,7 +421,7 @@ struct SplitParam { ...@@ -364,7 +421,7 @@ struct SplitParam {
}; };
// For Transpose op // For Transpose op
struct TransposeParam { struct TransposeParam : ParamBase {
const lite::Tensor* x{}; const lite::Tensor* x{};
lite::Tensor* output{}; lite::Tensor* output{};
lite::Tensor* xshape{}; lite::Tensor* xshape{};
...@@ -375,7 +432,7 @@ struct TransposeParam { ...@@ -375,7 +432,7 @@ struct TransposeParam {
}; };
/// ----------------------- element wise operators ---------------------- /// ----------------------- element wise operators ----------------------
struct ElementwiseParam { struct ElementwiseParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
const lite::Tensor* Y{}; const lite::Tensor* Y{};
lite::Tensor* Out{}; lite::Tensor* Out{};
...@@ -384,9 +441,24 @@ struct ElementwiseParam { ...@@ -384,9 +441,24 @@ struct ElementwiseParam {
WITH_INT8_CONFIG WITH_INT8_CONFIG
float x_input_scale{1.0}; float x_input_scale{1.0};
float y_input_scale{1.0}; float y_input_scale{1.0};
}; ///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
struct ElementwiseGradParam { const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(input_tensor_ptrs_cache_)) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X, Y}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(output_tensor_ptrs_cache_)) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
}
return output_tensor_ptrs_cache_.get();
}
};
struct ElementwiseGradParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
const lite::Tensor* Y{}; const lite::Tensor* Y{};
const lite::Tensor* OutGrad{}; const lite::Tensor* OutGrad{};
...@@ -404,12 +476,12 @@ struct FusionElementwiseActivationGradParam : public ElementwiseGradParam { ...@@ -404,12 +476,12 @@ struct FusionElementwiseActivationGradParam : public ElementwiseGradParam {
}; };
/// ----------------------- mean operators ---------------------- /// ----------------------- mean operators ----------------------
struct MeanParam { struct MeanParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
}; };
struct MeanGradParam { struct MeanGradParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
const lite::Tensor* Out_grad{}; const lite::Tensor* Out_grad{};
// for backward // for backward
...@@ -417,7 +489,7 @@ struct MeanGradParam { ...@@ -417,7 +489,7 @@ struct MeanGradParam {
}; };
/// ----------------------- fill_constant operators ---------------------- /// ----------------------- fill_constant operators ----------------------
struct FillConstantParam { struct FillConstantParam : ParamBase {
int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)}; int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)};
std::vector<int64_t> shape{}; std::vector<int64_t> shape{};
lite::Tensor* shape_tensor{nullptr}; lite::Tensor* shape_tensor{nullptr};
...@@ -429,7 +501,7 @@ struct FillConstantParam { ...@@ -429,7 +501,7 @@ struct FillConstantParam {
lite::Tensor* out{}; lite::Tensor* out{};
}; };
struct FillConstantBatchSizeLikeParam { struct FillConstantBatchSizeLikeParam : ParamBase {
const lite::Tensor* input{nullptr}; const lite::Tensor* input{nullptr};
lite::Tensor* out{nullptr}; lite::Tensor* out{nullptr};
...@@ -443,7 +515,7 @@ struct FillConstantBatchSizeLikeParam { ...@@ -443,7 +515,7 @@ struct FillConstantBatchSizeLikeParam {
}; };
// //
struct FakeQuantizeMovingAvgMaxAbsParam { struct FakeQuantizeMovingAvgMaxAbsParam : ParamBase {
const lite::Tensor* x{}; const lite::Tensor* x{};
const lite::Tensor* in_scale{}; const lite::Tensor* in_scale{};
const lite::Tensor* in_accum{}; const lite::Tensor* in_accum{};
...@@ -457,14 +529,14 @@ struct FakeQuantizeMovingAvgMaxAbsParam { ...@@ -457,14 +529,14 @@ struct FakeQuantizeMovingAvgMaxAbsParam {
float moving_rate{0.9}; float moving_rate{0.9};
}; };
struct FakeDequantizeMaxAbsParam { struct FakeDequantizeMaxAbsParam : ParamBase {
const lite::Tensor* x{}; const lite::Tensor* x{};
const lite::Tensor* in_scale{}; const lite::Tensor* in_scale{};
lite::Tensor* out{}; lite::Tensor* out{};
float max_range; float max_range;
}; };
struct FakeChannelWiseDequantizeMaxAbsParam { struct FakeChannelWiseDequantizeMaxAbsParam : ParamBase {
const lite::Tensor* x{}; const lite::Tensor* x{};
std::vector<const lite::Tensor*> scale_tensors{}; std::vector<const lite::Tensor*> scale_tensors{};
lite::Tensor* out{}; lite::Tensor* out{};
...@@ -472,7 +544,7 @@ struct FakeChannelWiseDequantizeMaxAbsParam { ...@@ -472,7 +544,7 @@ struct FakeChannelWiseDequantizeMaxAbsParam {
}; };
/// ----------------------- sgd operators ---------------------- /// ----------------------- sgd operators ----------------------
struct SGDParam { struct SGDParam : ParamBase {
int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)}; int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)};
const lite::Tensor* Param{}; const lite::Tensor* Param{};
...@@ -482,7 +554,7 @@ struct SGDParam { ...@@ -482,7 +554,7 @@ struct SGDParam {
}; };
/// ----------------------- uniform_random operators ---------------------- /// ----------------------- uniform_random operators ----------------------
struct UniformRandomParam { struct UniformRandomParam : ParamBase {
std::vector<int64_t> shape{}; std::vector<int64_t> shape{};
float min{-1.0f}; float min{-1.0f};
float max{1.0f}; float max{1.0f};
...@@ -491,12 +563,12 @@ struct UniformRandomParam { ...@@ -491,12 +563,12 @@ struct UniformRandomParam {
lite::Tensor* Out{}; lite::Tensor* Out{};
}; };
/// ----------------------- negative operators -------------- /// ----------------------- negative operators --------------
struct NegativeParam { struct NegativeParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
}; };
/// ----------------------- pad2d operators ---------------------- /// ----------------------- pad2d operators ----------------------
struct Pad2dParam { struct Pad2dParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
std::vector<int> paddings{0, 0, 0, 0}; std::vector<int> paddings{0, 0, 0, 0};
...@@ -506,7 +578,7 @@ struct Pad2dParam { ...@@ -506,7 +578,7 @@ struct Pad2dParam {
}; };
/// ----------------------- Crop operators ---------------------- /// ----------------------- Crop operators ----------------------
struct CropParam { struct CropParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
std::vector<int> offsets; std::vector<int> offsets;
...@@ -514,21 +586,21 @@ struct CropParam { ...@@ -514,21 +586,21 @@ struct CropParam {
}; };
///----------------------- argmax operators ---------------------- ///----------------------- argmax operators ----------------------
struct ArgmaxParam { struct ArgmaxParam : ParamBase {
lite::Tensor* X{}; lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
int Axis{0}; int Axis{0};
}; };
///----------------------- axpy operators ---------------------- ///----------------------- axpy operators ----------------------
struct AxpyParam { struct AxpyParam : ParamBase {
lite::Tensor* Scale{}; lite::Tensor* Scale{};
lite::Tensor* X{}; lite::Tensor* X{};
lite::Tensor* Bias{}; lite::Tensor* Bias{};
lite::Tensor* Out{}; lite::Tensor* Out{};
}; };
/// ----------------------- GRU unit operators ----------------------f /// ----------------------- GRU unit operators ----------------------f
struct GRUUnitParam { struct GRUUnitParam : ParamBase {
enum ActType { identity, sigmoid, tanh, relu }; enum ActType { identity, sigmoid, tanh, relu };
const lite::Tensor* input{nullptr}; const lite::Tensor* input{nullptr};
const lite::Tensor* hidden_prev{nullptr}; const lite::Tensor* hidden_prev{nullptr};
...@@ -544,7 +616,7 @@ struct GRUUnitParam { ...@@ -544,7 +616,7 @@ struct GRUUnitParam {
}; };
/// ------------------------------ lrn operators ------------------------------ /// ------------------------------ lrn operators ------------------------------
struct LrnParam { struct LrnParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
int n{5}; int n{5};
...@@ -555,7 +627,7 @@ struct LrnParam { ...@@ -555,7 +627,7 @@ struct LrnParam {
}; };
/// ----------------------- decode_bboxes operators ---------------------- /// ----------------------- decode_bboxes operators ----------------------
struct DecodeBboxesParam { struct DecodeBboxesParam : ParamBase {
const lite::Tensor* loc_data{}; const lite::Tensor* loc_data{};
const lite::Tensor* prior_data{}; const lite::Tensor* prior_data{};
lite::Tensor* bbox_data{}; lite::Tensor* bbox_data{};
...@@ -571,7 +643,7 @@ struct DecodeBboxesParam { ...@@ -571,7 +643,7 @@ struct DecodeBboxesParam {
}; };
/// ----------------------- box_coder operators ---------------------- /// ----------------------- box_coder operators ----------------------
struct BoxCoderParam { struct BoxCoderParam : ParamBase {
const lite::Tensor* prior_box{}; const lite::Tensor* prior_box{};
const lite::Tensor* prior_box_var{}; const lite::Tensor* prior_box_var{};
const lite::Tensor* target_box{}; const lite::Tensor* target_box{};
...@@ -584,7 +656,7 @@ struct BoxCoderParam { ...@@ -584,7 +656,7 @@ struct BoxCoderParam {
}; };
/// ----------------------- multiclass_nms operators ---------------------- /// ----------------------- multiclass_nms operators ----------------------
struct MulticlassNmsParam { struct MulticlassNmsParam : ParamBase {
const lite::Tensor* bboxes{}; const lite::Tensor* bboxes{};
const lite::Tensor* scores{}; const lite::Tensor* scores{};
lite::Tensor* out{}; lite::Tensor* out{};
...@@ -599,7 +671,7 @@ struct MulticlassNmsParam { ...@@ -599,7 +671,7 @@ struct MulticlassNmsParam {
}; };
/// ----------------------- priorbox operators ---------------------- /// ----------------------- priorbox operators ----------------------
struct PriorBoxParam { struct PriorBoxParam : ParamBase {
lite::Tensor* input{}; lite::Tensor* input{};
lite::Tensor* image{}; lite::Tensor* image{};
lite::Tensor* boxes{}; lite::Tensor* boxes{};
...@@ -628,7 +700,7 @@ struct DensityPriorBoxParam : public PriorBoxParam { ...@@ -628,7 +700,7 @@ struct DensityPriorBoxParam : public PriorBoxParam {
std::vector<int> density_sizes; std::vector<int> density_sizes;
}; };
/// ----------------------- GRU operators ----------------------f /// ----------------------- GRU operators ----------------------f
struct GRUParam { struct GRUParam : ParamBase {
const lite::Tensor* input{nullptr}; const lite::Tensor* input{nullptr};
const lite::Tensor* h0{nullptr}; const lite::Tensor* h0{nullptr};
const lite::Tensor* weight{nullptr}; const lite::Tensor* weight{nullptr};
...@@ -645,7 +717,7 @@ struct GRUParam { ...@@ -645,7 +717,7 @@ struct GRUParam {
}; };
/// ----------------------- BeamSearchDecode operators ----------------------f /// ----------------------- BeamSearchDecode operators ----------------------f
struct BeamSearchDecodeParam { struct BeamSearchDecodeParam : ParamBase {
std::vector<lite::Tensor>* ids{nullptr}; std::vector<lite::Tensor>* ids{nullptr};
std::vector<lite::Tensor>* scores{nullptr}; std::vector<lite::Tensor>* scores{nullptr};
lite::Tensor* sentence_ids{nullptr}; lite::Tensor* sentence_ids{nullptr};
...@@ -655,21 +727,21 @@ struct BeamSearchDecodeParam { ...@@ -655,21 +727,21 @@ struct BeamSearchDecodeParam {
}; };
/// ----------------------- LookupTable operators ----------------------f /// ----------------------- LookupTable operators ----------------------f
struct LookupTableParam { struct LookupTableParam : ParamBase {
const lite::Tensor* W{nullptr}; const lite::Tensor* W{nullptr};
const lite::Tensor* Ids{nullptr}; const lite::Tensor* Ids{nullptr};
lite::Tensor* Out{nullptr}; lite::Tensor* Out{nullptr};
int64_t padding_idx{-1}; int64_t padding_idx{-1};
}; };
struct LookupTableDequantParam { struct LookupTableDequantParam : ParamBase {
lite::Tensor* W{nullptr}; lite::Tensor* W{nullptr};
lite::Tensor* Ids{nullptr}; lite::Tensor* Ids{nullptr};
lite::Tensor* Out{nullptr}; lite::Tensor* Out{nullptr};
int64_t padding_idx{-1}; int64_t padding_idx{-1};
}; };
struct Im2SequenceParam { struct Im2SequenceParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
const lite::Tensor* Y{}; const lite::Tensor* Y{};
lite::Tensor* Out{}; lite::Tensor* Out{};
...@@ -679,19 +751,19 @@ struct Im2SequenceParam { ...@@ -679,19 +751,19 @@ struct Im2SequenceParam {
std::vector<int> out_strides{1, 1}; std::vector<int> out_strides{1, 1};
}; };
struct SequenceSoftmaxParam { struct SequenceSoftmaxParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
}; };
struct NormParam { struct NormParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
lite::Tensor* Norm{}; lite::Tensor* Norm{};
int axis{1}; int axis{1};
float epsilon{1e-10}; float epsilon{1e-10};
}; };
struct LayerNormParam { struct LayerNormParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
const lite::Tensor* Scale{}; const lite::Tensor* Scale{};
const lite::Tensor* Bias{}; const lite::Tensor* Bias{};
...@@ -702,13 +774,13 @@ struct LayerNormParam { ...@@ -702,13 +774,13 @@ struct LayerNormParam {
float epsilon{1e-5}; float epsilon{1e-5};
}; };
struct LogicalParam { struct LogicalParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
const lite::Tensor* Y{}; const lite::Tensor* Y{};
lite::Tensor* Out{}; lite::Tensor* Out{};
}; };
struct CompareParam { struct CompareParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
const lite::Tensor* Y{}; const lite::Tensor* Y{};
bool force_cpu{0}; bool force_cpu{0};
...@@ -716,7 +788,7 @@ struct CompareParam { ...@@ -716,7 +788,7 @@ struct CompareParam {
lite::Tensor* Out{}; lite::Tensor* Out{};
}; };
struct WhileParam { struct WhileParam : ParamBase {
Scope* scope{}; Scope* scope{};
Tensor* cond{}; Tensor* cond{};
cpp::BlockDesc* sub_block{}; cpp::BlockDesc* sub_block{};
...@@ -724,32 +796,32 @@ struct WhileParam { ...@@ -724,32 +796,32 @@ struct WhileParam {
std::vector<Tensor*> outs{}; std::vector<Tensor*> outs{};
}; };
struct TopkParam { struct TopkParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
lite::Tensor* Indices{}; lite::Tensor* Indices{};
int K{1}; int K{1};
}; };
struct IncrementParam { struct IncrementParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
float step{1}; float step{1};
}; };
struct WriteToArrayParam { struct WriteToArrayParam : ParamBase {
const lite::Tensor* X{nullptr}; const lite::Tensor* X{nullptr};
const lite::Tensor* I{nullptr}; const lite::Tensor* I{nullptr};
std::vector<lite::Tensor>* Out{nullptr}; std::vector<lite::Tensor>* Out{nullptr};
}; };
struct ReadFromArrayParam { struct ReadFromArrayParam : ParamBase {
const std::vector<lite::Tensor>* X{nullptr}; const std::vector<lite::Tensor>* X{nullptr};
const lite::Tensor* I{nullptr}; const lite::Tensor* I{nullptr};
lite::Tensor* Out{nullptr}; lite::Tensor* Out{nullptr};
}; };
struct BeamSearchParam { struct BeamSearchParam : ParamBase {
const lite::Tensor* pre_ids{}; const lite::Tensor* pre_ids{};
const lite::Tensor* pre_scores{}; const lite::Tensor* pre_scores{};
const lite::Tensor* ids{}; const lite::Tensor* ids{};
...@@ -763,7 +835,7 @@ struct BeamSearchParam { ...@@ -763,7 +835,7 @@ struct BeamSearchParam {
bool is_accumulated; bool is_accumulated;
}; };
struct SequencePoolParam { struct SequencePoolParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
std::string pool_type{"AVERAGE"}; std::string pool_type{"AVERAGE"};
...@@ -773,7 +845,7 @@ struct SequencePoolParam { ...@@ -773,7 +845,7 @@ struct SequencePoolParam {
#endif #endif
}; };
struct SequenceConvParam { struct SequenceConvParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
const lite::Tensor* Filter{}; const lite::Tensor* Filter{};
lite::Tensor* Out{}; lite::Tensor* Out{};
...@@ -782,13 +854,13 @@ struct SequenceConvParam { ...@@ -782,13 +854,13 @@ struct SequenceConvParam {
int contextLength; int contextLength;
}; };
struct SequencePoolConcatParam { struct SequencePoolConcatParam : ParamBase {
std::vector<lite::Tensor*> X{}; std::vector<lite::Tensor*> X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
std::vector<std::string> pool_type{}; std::vector<std::string> pool_type{};
}; };
struct SearchGroupPaddingParam { struct SearchGroupPaddingParam : ParamBase {
lite::Tensor* x{}; lite::Tensor* x{};
lite::Tensor* out_emb_padding{}; lite::Tensor* out_emb_padding{};
lite::Tensor* out_new{}; lite::Tensor* out_new{};
...@@ -796,36 +868,36 @@ struct SearchGroupPaddingParam { ...@@ -796,36 +868,36 @@ struct SearchGroupPaddingParam {
int pad_id; int pad_id;
}; };
struct SequenceReshapeParam { struct SequenceReshapeParam : ParamBase {
lite::Tensor* x{}; lite::Tensor* x{};
lite::Tensor* output{}; lite::Tensor* output{};
int new_dim; int new_dim;
}; };
struct SequenceExpandParam { struct SequenceExpandParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
const lite::Tensor* Y{}; const lite::Tensor* Y{};
lite::Tensor* Out{}; lite::Tensor* Out{};
int ref_level{-1}; int ref_level{-1};
}; };
struct SequenceExpandAsParam { struct SequenceExpandAsParam : ParamBase {
const lite::Tensor* x{nullptr}; const lite::Tensor* x{nullptr};
const lite::Tensor* y{nullptr}; const lite::Tensor* y{nullptr};
lite::Tensor* out{nullptr}; lite::Tensor* out{nullptr};
}; };
struct SequenceReverseParam { struct SequenceReverseParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
}; };
struct SequenceConcatParam { struct SequenceConcatParam : ParamBase {
std::vector<lite::Tensor*> X{}; std::vector<lite::Tensor*> X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
}; };
struct AttentionPaddingMaskParam { struct AttentionPaddingMaskParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
const lite::Tensor* Y{}; const lite::Tensor* Y{};
int pad_id; int pad_id;
...@@ -834,21 +906,21 @@ struct AttentionPaddingMaskParam { ...@@ -834,21 +906,21 @@ struct AttentionPaddingMaskParam {
lite::Tensor* pad_begin{}; lite::Tensor* pad_begin{};
}; };
struct SequenceArithmeticParam { struct SequenceArithmeticParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
const lite::Tensor* Y{}; const lite::Tensor* Y{};
int op_type{1}; int op_type{1};
lite::Tensor* Out{}; lite::Tensor* Out{};
}; };
struct ReduceMaxParam { struct ReduceMaxParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
std::vector<int> dim{}; std::vector<int> dim{};
bool keep_dim{false}; bool keep_dim{false};
}; };
struct LodResetParam { struct LodResetParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
const lite::Tensor* Y{}; const lite::Tensor* Y{};
lite::Tensor* Out{}; lite::Tensor* Out{};
...@@ -856,12 +928,12 @@ struct LodResetParam { ...@@ -856,12 +928,12 @@ struct LodResetParam {
bool append; bool append;
}; };
struct IsEmptyParam { struct IsEmptyParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
}; };
struct ReduceParam { struct ReduceParam : ParamBase {
lite::Tensor* x{}; lite::Tensor* x{};
lite::Tensor* output{}; lite::Tensor* output{};
std::vector<int> dim{0}; std::vector<int> dim{0};
...@@ -869,7 +941,7 @@ struct ReduceParam { ...@@ -869,7 +941,7 @@ struct ReduceParam {
bool reduce_all{false}; bool reduce_all{false};
}; };
struct VarConv2DParam { struct VarConv2DParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
const lite::Tensor* ROW{}; const lite::Tensor* ROW{};
const lite::Tensor* COLUMN{}; const lite::Tensor* COLUMN{};
...@@ -888,19 +960,19 @@ struct VarConv2DParam { ...@@ -888,19 +960,19 @@ struct VarConv2DParam {
}; };
/// ----------------------- shape operators ---------------------- /// ----------------------- shape operators ----------------------
struct ShapeParam { struct ShapeParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
}; };
struct CastParam { struct CastParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
int out_dtype{2}; int out_dtype{2};
int in_dtype{2}; int in_dtype{2};
}; };
struct SliceParam { struct SliceParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
std::vector<int> axes{}; std::vector<int> axes{};
...@@ -914,7 +986,7 @@ struct SliceParam { ...@@ -914,7 +986,7 @@ struct SliceParam {
lite::Tensor* EndsTensor{nullptr}; lite::Tensor* EndsTensor{nullptr};
}; };
struct AffineChannelParam { struct AffineChannelParam : ParamBase {
const lite::Tensor* X{}; // X is 4D tensor const lite::Tensor* X{}; // X is 4D tensor
const lite::Tensor* Scale{}; const lite::Tensor* Scale{};
const lite::Tensor* Bias{}; const lite::Tensor* Bias{};
...@@ -922,7 +994,7 @@ struct AffineChannelParam { ...@@ -922,7 +994,7 @@ struct AffineChannelParam {
lite::Tensor* Out{}; lite::Tensor* Out{};
}; };
struct AnchorGeneratorParam { struct AnchorGeneratorParam : ParamBase {
const lite::Tensor* Input{}; const lite::Tensor* Input{};
std::vector<float> anchor_sizes{}; std::vector<float> anchor_sizes{};
std::vector<float> aspect_ratios{}; std::vector<float> aspect_ratios{};
...@@ -934,7 +1006,7 @@ struct AnchorGeneratorParam { ...@@ -934,7 +1006,7 @@ struct AnchorGeneratorParam {
lite::Tensor* Variances{}; lite::Tensor* Variances{};
}; };
struct GenerateProposalsParam { struct GenerateProposalsParam : ParamBase {
// inputs // inputs
const lite::Tensor* Scores{}; const lite::Tensor* Scores{};
const lite::Tensor* BboxDeltas{}; const lite::Tensor* BboxDeltas{};
...@@ -954,14 +1026,14 @@ struct GenerateProposalsParam { ...@@ -954,14 +1026,14 @@ struct GenerateProposalsParam {
lite::Tensor* RpnRoiProbs{}; lite::Tensor* RpnRoiProbs{};
}; };
/// ----------------------- squeeze operators ---------------------- /// ----------------------- squeeze operators ----------------------
struct SqueezeParam { struct SqueezeParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
lite::Tensor* XShape{}; lite::Tensor* XShape{};
std::vector<int> axes{}; std::vector<int> axes{};
}; };
struct UnsqueezeParam { struct UnsqueezeParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
lite::Tensor* XShape{}; lite::Tensor* XShape{};
...@@ -971,14 +1043,14 @@ struct UnsqueezeParam { ...@@ -971,14 +1043,14 @@ struct UnsqueezeParam {
}; };
/// ----------------------- expand operators ---------------------- /// ----------------------- expand operators ----------------------
struct ExpandParam { struct ExpandParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
std::vector<int> expand_times{}; std::vector<int> expand_times{};
}; };
/// ----------------------- matmul operators ---------------------- /// ----------------------- matmul operators ----------------------
struct MatMulParam { struct MatMulParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
const lite::Tensor* Y{}; const lite::Tensor* Y{};
lite::Tensor* Out{}; lite::Tensor* Out{};
...@@ -987,20 +1059,20 @@ struct MatMulParam { ...@@ -987,20 +1059,20 @@ struct MatMulParam {
float alpha{1.0f}; float alpha{1.0f};
}; };
struct GatherParam { struct GatherParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
const lite::Tensor* Index{}; const lite::Tensor* Index{};
lite::Tensor* Out{}; lite::Tensor* Out{};
}; };
/// ----------------------- assign operators ----------------------- /// ----------------------- assign operators -----------------------
struct AssignParam { struct AssignParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
}; };
/// ----------------------- roi_align operators ----------------------- /// ----------------------- roi_align operators -----------------------
struct RoiAlignParam { struct RoiAlignParam : ParamBase {
lite::Tensor* X{}; lite::Tensor* X{};
lite::Tensor* ROIs{}; lite::Tensor* ROIs{};
lite::Tensor* Out{}; lite::Tensor* Out{};
...@@ -1011,13 +1083,13 @@ struct RoiAlignParam { ...@@ -1011,13 +1083,13 @@ struct RoiAlignParam {
}; };
/// ----------------------- box_clip operators ----------------------- /// ----------------------- box_clip operators -----------------------
struct BoxClipParam { struct BoxClipParam : ParamBase {
const lite::Tensor* Input{}; const lite::Tensor* Input{};
const lite::Tensor* ImInfo{}; const lite::Tensor* ImInfo{};
lite::Tensor* Output{}; lite::Tensor* Output{};
}; };
struct RangeParam { struct RangeParam : ParamBase {
const lite::Tensor* Start; const lite::Tensor* Start;
const lite::Tensor* End; const lite::Tensor* End;
const lite::Tensor* Step; const lite::Tensor* Step;
...@@ -1025,7 +1097,7 @@ struct RangeParam { ...@@ -1025,7 +1097,7 @@ struct RangeParam {
}; };
/// ----------------------- assign_value operators ----------------------- /// ----------------------- assign_value operators -----------------------
struct AssignValueParam { struct AssignValueParam : ParamBase {
std::vector<int> shape{}; std::vector<int> shape{};
int dtype{}; int dtype{};
std::vector<float> fp32_values{}; std::vector<float> fp32_values{};
...@@ -1034,7 +1106,7 @@ struct AssignValueParam { ...@@ -1034,7 +1106,7 @@ struct AssignValueParam {
}; };
/// --------------- sequence_topk_avg_pooling operators ------------------ /// --------------- sequence_topk_avg_pooling operators ------------------
struct SequenceTopkAvgPoolingParam { struct SequenceTopkAvgPoolingParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
const lite::Tensor* ROW{}; const lite::Tensor* ROW{};
const lite::Tensor* COLUMN{}; const lite::Tensor* COLUMN{};
...@@ -1045,7 +1117,7 @@ struct SequenceTopkAvgPoolingParam { ...@@ -1045,7 +1117,7 @@ struct SequenceTopkAvgPoolingParam {
}; };
/// --------------- search_fc operators ------------------ /// --------------- search_fc operators ------------------
struct SearchFcParam { struct SearchFcParam : ParamBase {
const lite::Tensor* X{}; const lite::Tensor* X{};
const lite::Tensor* W{}; const lite::Tensor* W{};
const lite::Tensor* b{}; const lite::Tensor* b{};
...@@ -1053,7 +1125,7 @@ struct SearchFcParam { ...@@ -1053,7 +1125,7 @@ struct SearchFcParam {
int out_size{}; int out_size{};
}; };
/// --------------------- match_matrix_tensor operators -------------------- /// --------------------- match_matrix_tensor operators --------------------
struct MatchMatrixTensorParam { struct MatchMatrixTensorParam : ParamBase {
const lite::Tensor* x{}; const lite::Tensor* x{};
const lite::Tensor* y{}; const lite::Tensor* y{};
const lite::Tensor* w{}; const lite::Tensor* w{};
...@@ -1064,14 +1136,14 @@ struct MatchMatrixTensorParam { ...@@ -1064,14 +1136,14 @@ struct MatchMatrixTensorParam {
}; };
/// --------------------- search_seq_depadding operators -------------------- /// --------------------- search_seq_depadding operators --------------------
struct SearchSeqDepaddingParam { struct SearchSeqDepaddingParam : ParamBase {
const lite::Tensor* pad{}; const lite::Tensor* pad{};
const lite::Tensor* src{}; const lite::Tensor* src{};
lite::Tensor* out{}; lite::Tensor* out{};
}; };
/// --------------------- search_grnn operators -------------------- /// --------------------- search_grnn operators --------------------
struct SearchGrnnParam { struct SearchGrnnParam : ParamBase {
const lite::Tensor* x{}; const lite::Tensor* x{};
const lite::Tensor* wi{}; const lite::Tensor* wi{};
const lite::Tensor* wh{}; const lite::Tensor* wh{};
...@@ -1084,7 +1156,7 @@ struct SearchGrnnParam { ...@@ -1084,7 +1156,7 @@ struct SearchGrnnParam {
lite::Tensor* layout_input{}; lite::Tensor* layout_input{};
}; };
struct SplitLodTensorParam { struct SplitLodTensorParam : ParamBase {
const lite::Tensor* x{}; const lite::Tensor* x{};
const lite::Tensor* mask{}; const lite::Tensor* mask{};
lite::Tensor* out_true{}; lite::Tensor* out_true{};
...@@ -1092,7 +1164,7 @@ struct SplitLodTensorParam { ...@@ -1092,7 +1164,7 @@ struct SplitLodTensorParam {
int level{}; int level{};
}; };
struct MergeLodTensorParam { struct MergeLodTensorParam : ParamBase {
const lite::Tensor* x{}; const lite::Tensor* x{};
const lite::Tensor* mask{}; const lite::Tensor* mask{};
const lite::Tensor* in_true{}; const lite::Tensor* in_true{};
...@@ -1101,7 +1173,7 @@ struct MergeLodTensorParam { ...@@ -1101,7 +1173,7 @@ struct MergeLodTensorParam {
int level{}; int level{};
}; };
struct ConditionalBlockParam { struct ConditionalBlockParam : ParamBase {
const lite::Tensor* cond{}; const lite::Tensor* cond{};
std::vector<lite::Tensor*> x{}; std::vector<lite::Tensor*> x{};
std::vector<lite::Tensor*> outs{}; std::vector<lite::Tensor*> outs{};
...@@ -1110,14 +1182,14 @@ struct ConditionalBlockParam { ...@@ -1110,14 +1182,14 @@ struct ConditionalBlockParam {
bool is_scalar_condition{}; bool is_scalar_condition{};
}; };
struct CollectFpnProposalsParam { struct CollectFpnProposalsParam : ParamBase {
std::vector<lite::Tensor*> multi_level_rois{}; std::vector<lite::Tensor*> multi_level_rois{};
std::vector<lite::Tensor*> multi_level_scores{}; std::vector<lite::Tensor*> multi_level_scores{};
lite::Tensor* fpn_rois{}; lite::Tensor* fpn_rois{};
int post_nms_topN{}; int post_nms_topN{};
}; };
struct DistributeFpnProposalsParam { struct DistributeFpnProposalsParam : ParamBase {
const lite::Tensor* fpn_rois{}; const lite::Tensor* fpn_rois{};
std::vector<lite::Tensor*> multi_fpn_rois{}; std::vector<lite::Tensor*> multi_fpn_rois{};
lite::Tensor* restore_index{}; lite::Tensor* restore_index{};
...@@ -1128,7 +1200,7 @@ struct DistributeFpnProposalsParam { ...@@ -1128,7 +1200,7 @@ struct DistributeFpnProposalsParam {
}; };
/// --------------------- instance_norm operators -------------------- /// --------------------- instance_norm operators --------------------
struct InstanceNormParam { struct InstanceNormParam : ParamBase {
lite::Tensor* x{}; lite::Tensor* x{};
lite::Tensor* out{}; lite::Tensor* out{};
lite::Tensor* bias{}; lite::Tensor* bias{};
...@@ -1138,12 +1210,12 @@ struct InstanceNormParam { ...@@ -1138,12 +1210,12 @@ struct InstanceNormParam {
float epsilon; float epsilon;
}; };
/// --------------------- grid sampler operators -------------------- /// --------------------- grid sampler operators --------------------
struct GridSamplerParam { struct GridSamplerParam : ParamBase {
lite::Tensor* x{}; lite::Tensor* x{};
lite::Tensor* out{}; lite::Tensor* out{};
lite::Tensor* grid{}; lite::Tensor* grid{};
}; };
struct LstmParam { struct LstmParam : ParamBase {
lite::Tensor* Input{}; lite::Tensor* Input{};
lite::Tensor* Weight{}; lite::Tensor* Weight{};
lite::Tensor* Bias{}; lite::Tensor* Bias{};
...@@ -1160,7 +1232,7 @@ struct LstmParam { ...@@ -1160,7 +1232,7 @@ struct LstmParam {
std::string candidate_activation; std::string candidate_activation;
}; };
struct CrfDecodingParam { struct CrfDecodingParam : ParamBase {
lite::Tensor* emission{}; lite::Tensor* emission{};
lite::Tensor* transition{}; lite::Tensor* transition{};
lite::Tensor* label{}; lite::Tensor* label{};
......
...@@ -30,7 +30,7 @@ bool Pad2dOpLite::CheckShape() const { ...@@ -30,7 +30,7 @@ bool Pad2dOpLite::CheckShape() const {
return true; return true;
} }
bool Pad2dOpLite::InferShape() const { bool Pad2dOpLite::InferShapeImpl() const {
// nchw // nchw
auto x_dims = param_.X->dims(); auto x_dims = param_.X->dims();
int out_h = x_dims[2] + param_.paddings[0] + param_.paddings[1]; int out_h = x_dims[2] + param_.paddings[0] + param_.paddings[1];
......
...@@ -30,7 +30,7 @@ class Pad2dOpLite : public OpLite { ...@@ -30,7 +30,7 @@ class Pad2dOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -60,7 +60,7 @@ int PoolOutputSize(int input_size, ...@@ -60,7 +60,7 @@ int PoolOutputSize(int input_size,
return output_size; return output_size;
} }
bool PoolOpLite::InferShape() const { bool PoolOpLite::InferShapeImpl() const {
const auto x_dims = param_.x->dims(); const auto x_dims = param_.x->dims();
std::vector<int>& ksize = param_.ksize; std::vector<int>& ksize = param_.ksize;
// dynamic update 4-pad // dynamic update 4-pad
......
...@@ -37,7 +37,7 @@ class PoolOpLite : public OpLite { ...@@ -37,7 +37,7 @@ class PoolOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
// TODO(Superjomn) replace framework::OpDesc with a lite one. // TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
......
...@@ -27,7 +27,7 @@ bool PowerOp::CheckShape() const { ...@@ -27,7 +27,7 @@ bool PowerOp::CheckShape() const {
return true; return true;
} }
bool PowerOp::InferShape() const { bool PowerOp::InferShapeImpl() const {
param_.Out->Resize(param_.X->dims()); param_.Out->Resize(param_.X->dims());
return true; return true;
} }
......
...@@ -31,7 +31,7 @@ class PowerOp : public OpLite { ...@@ -31,7 +31,7 @@ class PowerOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -27,7 +27,7 @@ bool PriorBoxOpLite::CheckShape() const { ...@@ -27,7 +27,7 @@ bool PriorBoxOpLite::CheckShape() const {
return true; return true;
} }
bool PriorBoxOpLite::InferShape() const { return true; } bool PriorBoxOpLite::InferShapeImpl() const { return true; }
bool PriorBoxOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { bool PriorBoxOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
auto input = opdesc.Input("Input").front(); auto input = opdesc.Input("Input").front();
......
...@@ -29,7 +29,7 @@ class PriorBoxOpLite : public OpLite { ...@@ -29,7 +29,7 @@ class PriorBoxOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -41,7 +41,7 @@ void GetSize(T start, T end, T step, int64_t* size) { ...@@ -41,7 +41,7 @@ void GetSize(T start, T end, T step, int64_t* size) {
: std::ceil(std::abs((end - start) / step)); : std::ceil(std::abs((end - start) / step));
} }
bool RangeOpLite::InferShape() const { bool RangeOpLite::InferShapeImpl() const {
int start = param_.Start->data<float>()[0]; int start = param_.Start->data<float>()[0];
int end = param_.End->data<float>()[0]; int end = param_.End->data<float>()[0];
int step = param_.Step->data<float>()[0]; int step = param_.Step->data<float>()[0];
......
...@@ -29,7 +29,7 @@ class RangeOpLite : public OpLite { ...@@ -29,7 +29,7 @@ class RangeOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -26,7 +26,7 @@ bool ReadFromArrayOp::CheckShape() const { ...@@ -26,7 +26,7 @@ bool ReadFromArrayOp::CheckShape() const {
return true; return true;
} }
bool ReadFromArrayOp::InferShape() const { bool ReadFromArrayOp::InferShapeImpl() const {
int id = param_.I->data<int64_t>()[0]; int id = param_.I->data<int64_t>()[0];
auto out_dims = (*param_.X)[id].dims(); auto out_dims = (*param_.X)[id].dims();
param_.Out->Resize(out_dims); param_.Out->Resize(out_dims);
......
...@@ -30,7 +30,7 @@ class ReadFromArrayOp : public OpLite { ...@@ -30,7 +30,7 @@ class ReadFromArrayOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -39,7 +39,7 @@ bool ReduceMaxOp::CheckShape() const { ...@@ -39,7 +39,7 @@ bool ReduceMaxOp::CheckShape() const {
return true; return true;
} }
bool ReduceMaxOp::InferShape() const { bool ReduceMaxOp::InferShapeImpl() const {
auto dims = param_.dim; auto dims = param_.dim;
auto x_dims = param_.X->dims(); auto x_dims = param_.X->dims();
bool reduce_all = false; bool reduce_all = false;
......
...@@ -28,7 +28,7 @@ class ReduceMaxOp : public OpLite { ...@@ -28,7 +28,7 @@ class ReduceMaxOp : public OpLite {
ReduceMaxOp() {} ReduceMaxOp() {}
explicit ReduceMaxOp(const std::string &op_type) : OpLite(op_type) {} explicit ReduceMaxOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
......
...@@ -39,7 +39,7 @@ bool ReduceMeanOp::CheckShape() const { ...@@ -39,7 +39,7 @@ bool ReduceMeanOp::CheckShape() const {
return true; return true;
} }
bool ReduceMeanOp::InferShape() const { bool ReduceMeanOp::InferShapeImpl() const {
auto dims = param_.dim; auto dims = param_.dim;
auto x_dims = param_.X->dims(); auto x_dims = param_.X->dims();
bool reduce_all = false; bool reduce_all = false;
......
...@@ -28,7 +28,7 @@ class ReduceMeanOp : public OpLite { ...@@ -28,7 +28,7 @@ class ReduceMeanOp : public OpLite {
ReduceMeanOp() {} ReduceMeanOp() {}
explicit ReduceMeanOp(const std::string &op_type) : OpLite(op_type) {} explicit ReduceMeanOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
......
...@@ -28,7 +28,7 @@ bool ReduceOp::CheckShape() const { ...@@ -28,7 +28,7 @@ bool ReduceOp::CheckShape() const {
return true; return true;
} }
bool ReduceOp::InferShape() const { bool ReduceOp::InferShapeImpl() const {
const auto &x_dims = param_.x->dims(); const auto &x_dims = param_.x->dims();
auto x_rank = x_dims.size(); auto x_rank = x_dims.size();
auto dims = param_.dim; auto dims = param_.dim;
......
...@@ -30,7 +30,7 @@ class ReduceOp : public OpLite { ...@@ -30,7 +30,7 @@ class ReduceOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -28,7 +28,7 @@ bool ReduceProdOpLite::CheckShape() const { ...@@ -28,7 +28,7 @@ bool ReduceProdOpLite::CheckShape() const {
return true; return true;
} }
bool ReduceProdOpLite::InferShape() const { bool ReduceProdOpLite::InferShapeImpl() const {
auto x = param_.x; auto x = param_.x;
auto out = param_.output; auto out = param_.output;
std::vector<int> dim = param_.dim; std::vector<int> dim = param_.dim;
......
...@@ -29,7 +29,7 @@ class ReduceProdOpLite : public OpLite { ...@@ -29,7 +29,7 @@ class ReduceProdOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -20,7 +20,7 @@ namespace lite { ...@@ -20,7 +20,7 @@ namespace lite {
namespace operators { namespace operators {
bool ReluOp::CheckShape() const { return true; } bool ReluOp::CheckShape() const { return true; }
bool ReluOp::InferShape() const { bool ReluOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.X); CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out); CHECK_OR_FALSE(param_.Out);
// TODO(Superjomn) Enable data sharing. // TODO(Superjomn) Enable data sharing.
......
...@@ -30,7 +30,7 @@ class ReluOp : public OpLite { ...@@ -30,7 +30,7 @@ class ReluOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -26,7 +26,7 @@ bool ReshapeOp::CheckShape() const { ...@@ -26,7 +26,7 @@ bool ReshapeOp::CheckShape() const {
return true; return true;
} }
bool ReshapeOp::InferShape() const { bool ReshapeOp::InferShapeImpl() const {
const auto &shape_tensor_vct = param_.shape_tensor_vct; const auto &shape_tensor_vct = param_.shape_tensor_vct;
auto *shape_tensor = param_.shape_tensor; auto *shape_tensor = param_.shape_tensor;
const auto &shape_vct = param_.shape_vct; const auto &shape_vct = param_.shape_vct;
...@@ -97,8 +97,8 @@ bool Reshape2Op::CheckShape() const { ...@@ -97,8 +97,8 @@ bool Reshape2Op::CheckShape() const {
return true; return true;
} }
bool Reshape2Op::InferShape() const { bool Reshape2Op::InferShapeImpl() const {
ReshapeOp::InferShape(); ReshapeOp::InferShapeImpl();
const auto &x_dims = param_.x->dims(); const auto &x_dims = param_.x->dims();
std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1); std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0; xshape_dims[0] = 0;
......
...@@ -30,7 +30,7 @@ class ReshapeOp : public OpLite { ...@@ -30,7 +30,7 @@ class ReshapeOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
...@@ -48,7 +48,7 @@ class Reshape2Op : public ReshapeOp { ...@@ -48,7 +48,7 @@ class Reshape2Op : public ReshapeOp {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -38,7 +38,7 @@ bool RoiAlignOpLite::CheckShape() const { ...@@ -38,7 +38,7 @@ bool RoiAlignOpLite::CheckShape() const {
return true; return true;
} }
bool RoiAlignOpLite::InferShape() const { bool RoiAlignOpLite::InferShapeImpl() const {
auto x_dims = param_.X->dims(); auto x_dims = param_.X->dims();
auto rois_dims = param_.ROIs->dims(); auto rois_dims = param_.ROIs->dims();
......
...@@ -31,7 +31,7 @@ class RoiAlignOpLite : public OpLite { ...@@ -31,7 +31,7 @@ class RoiAlignOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -24,7 +24,7 @@ bool ScaleOp::CheckShape() const { ...@@ -24,7 +24,7 @@ bool ScaleOp::CheckShape() const {
return true; return true;
} }
bool ScaleOp::InferShape() const { bool ScaleOp::InferShapeImpl() const {
param_.output->Resize(param_.x->dims()); param_.output->Resize(param_.x->dims());
return true; return true;
} }
......
...@@ -30,7 +30,7 @@ class ScaleOp : public OpLite { ...@@ -30,7 +30,7 @@ class ScaleOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -27,7 +27,7 @@ bool SearchAlignedMatMulOpLite::CheckShape() const { ...@@ -27,7 +27,7 @@ bool SearchAlignedMatMulOpLite::CheckShape() const {
return true; return true;
} }
bool SearchAlignedMatMulOpLite::InferShape() const { bool SearchAlignedMatMulOpLite::InferShapeImpl() const {
const auto x_dims = param_.X->dims(); const auto x_dims = param_.X->dims();
const auto y_dims = param_.Y->dims(); const auto y_dims = param_.Y->dims();
const auto& x_lod = param_.X->lod(); const auto& x_lod = param_.X->lod();
......
...@@ -31,7 +31,7 @@ class SearchAlignedMatMulOpLite : public OpLite { ...@@ -31,7 +31,7 @@ class SearchAlignedMatMulOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
......
...@@ -50,7 +50,7 @@ bool SearchFcOpLite::CheckShape() const { ...@@ -50,7 +50,7 @@ bool SearchFcOpLite::CheckShape() const {
return true; return true;
} }
bool SearchFcOpLite::InferShape() const { bool SearchFcOpLite::InferShapeImpl() const {
auto out_size = param_.out_size; auto out_size = param_.out_size;
lite::DDim dims(std::vector<int64_t>({-1, out_size})); lite::DDim dims(std::vector<int64_t>({-1, out_size}));
param_.Out->Resize(dims); param_.Out->Resize(dims);
......
...@@ -30,7 +30,7 @@ class SearchFcOpLite : public OpLite { ...@@ -30,7 +30,7 @@ class SearchFcOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -51,7 +51,7 @@ bool SearchGrnnOpLite::CheckShape() const { ...@@ -51,7 +51,7 @@ bool SearchGrnnOpLite::CheckShape() const {
return true; return true;
} }
bool SearchGrnnOpLite::InferShape() const { bool SearchGrnnOpLite::InferShapeImpl() const {
const auto& x_dims = param_.x->dims(); const auto& x_dims = param_.x->dims();
const auto& x_lod = param_.x->lod(); const auto& x_lod = param_.x->lod();
CHECK_OR_FALSE(!x_lod.empty()); CHECK_OR_FALSE(!x_lod.empty());
......
...@@ -31,7 +31,7 @@ class SearchGrnnOpLite : public OpLite { ...@@ -31,7 +31,7 @@ class SearchGrnnOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -31,7 +31,7 @@ bool SearchGroupPaddingOp::CheckShape() const { ...@@ -31,7 +31,7 @@ bool SearchGroupPaddingOp::CheckShape() const {
return true; return true;
} }
bool SearchGroupPaddingOp::InferShape() const { bool SearchGroupPaddingOp::InferShapeImpl() const {
std::vector<int64_t> x_dims = param_.x->dims().Vectorize(); std::vector<int64_t> x_dims = param_.x->dims().Vectorize();
param_.out_emb_padding->Resize({-1, x_dims[1]}); param_.out_emb_padding->Resize({-1, x_dims[1]});
......
...@@ -27,7 +27,7 @@ class SearchGroupPaddingOp : public OpLite { ...@@ -27,7 +27,7 @@ class SearchGroupPaddingOp : public OpLite {
SearchGroupPaddingOp() {} SearchGroupPaddingOp() {}
explicit SearchGroupPaddingOp(const std::string &op_type) : OpLite(op_type) {} explicit SearchGroupPaddingOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "search_group_padding"; } std::string DebugString() const override { return "search_group_padding"; }
......
...@@ -44,7 +44,7 @@ bool SearchSeqDepaddingOpLite::CheckShape() const { ...@@ -44,7 +44,7 @@ bool SearchSeqDepaddingOpLite::CheckShape() const {
return true; return true;
} }
bool SearchSeqDepaddingOpLite::InferShape() const { bool SearchSeqDepaddingOpLite::InferShapeImpl() const {
DDim pad_dims = param_.pad->dims(); DDim pad_dims = param_.pad->dims();
param_.out->Resize({-1, pad_dims[1]}); param_.out->Resize({-1, pad_dims[1]});
return true; return true;
......
...@@ -32,7 +32,7 @@ class SearchSeqDepaddingOpLite : public OpLite { ...@@ -32,7 +32,7 @@ class SearchSeqDepaddingOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -26,7 +26,7 @@ bool SearchSeqFcOpLite::CheckShape() const { ...@@ -26,7 +26,7 @@ bool SearchSeqFcOpLite::CheckShape() const {
return true; return true;
} }
bool SearchSeqFcOpLite::InferShape() const { bool SearchSeqFcOpLite::InferShapeImpl() const {
const auto x_dims = param_.x->dims(); const auto x_dims = param_.x->dims();
const auto w_dims = param_.w->dims(); const auto w_dims = param_.w->dims();
const auto& x_lod = param_.x->lod(); const auto& x_lod = param_.x->lod();
......
...@@ -31,7 +31,7 @@ class SearchSeqFcOpLite : public OpLite { ...@@ -31,7 +31,7 @@ class SearchSeqFcOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
......
...@@ -25,7 +25,7 @@ bool SearchSeqSoftmaxOp::CheckShape() const { ...@@ -25,7 +25,7 @@ bool SearchSeqSoftmaxOp::CheckShape() const {
return true; return true;
} }
bool SearchSeqSoftmaxOp::InferShape() const { bool SearchSeqSoftmaxOp::InferShapeImpl() const {
param_.output->Resize(param_.x->dims()); param_.output->Resize(param_.x->dims());
param_.output->set_lod(param_.x->lod()); param_.output->set_lod(param_.x->lod());
return true; return true;
......
...@@ -31,7 +31,7 @@ class SearchSeqSoftmaxOp : public OpLite { ...@@ -31,7 +31,7 @@ class SearchSeqSoftmaxOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -28,7 +28,7 @@ bool SequenceArithmeticOp::CheckShape() const { ...@@ -28,7 +28,7 @@ bool SequenceArithmeticOp::CheckShape() const {
return true; return true;
} }
bool SequenceArithmeticOp::InferShape() const { bool SequenceArithmeticOp::InferShapeImpl() const {
param_.Out->Resize(param_.X->dims()); param_.Out->Resize(param_.X->dims());
param_.Out->set_lod(param_.X->lod()); param_.Out->set_lod(param_.X->lod());
return true; return true;
......
...@@ -29,7 +29,7 @@ class SequenceArithmeticOp : public OpLite { ...@@ -29,7 +29,7 @@ class SequenceArithmeticOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -26,7 +26,7 @@ bool SequenceConcatOp::CheckShape() const { ...@@ -26,7 +26,7 @@ bool SequenceConcatOp::CheckShape() const {
return true; return true;
} }
bool SequenceConcatOp::InferShape() const { return true; } bool SequenceConcatOp::InferShapeImpl() const { return true; }
bool SequenceConcatOp::AttachImpl(const cpp::OpDesc &opdesc, bool SequenceConcatOp::AttachImpl(const cpp::OpDesc &opdesc,
lite::Scope *scope) { lite::Scope *scope) {
......
...@@ -27,7 +27,7 @@ class SequenceConcatOp : public OpLite { ...@@ -27,7 +27,7 @@ class SequenceConcatOp : public OpLite {
SequenceConcatOp() {} SequenceConcatOp() {}
explicit SequenceConcatOp(const std::string &op_type) : OpLite(op_type) {} explicit SequenceConcatOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "sequence_concat"; } std::string DebugString() const override { return "sequence_concat"; }
......
...@@ -44,7 +44,7 @@ bool SequenceConvOp::CheckShape() const { ...@@ -44,7 +44,7 @@ bool SequenceConvOp::CheckShape() const {
return true; return true;
} }
bool SequenceConvOp::InferShape() const { bool SequenceConvOp::InferShapeImpl() const {
const auto *input = param_.X; const auto *input = param_.X;
const auto *filter = param_.Filter; const auto *filter = param_.Filter;
auto in_dims = input->dims(); auto in_dims = input->dims();
......
...@@ -28,7 +28,7 @@ class SequenceConvOp : public OpLite { ...@@ -28,7 +28,7 @@ class SequenceConvOp : public OpLite {
SequenceConvOp() {} SequenceConvOp() {}
explicit SequenceConvOp(const std::string &op_type) : OpLite(op_type) {} explicit SequenceConvOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
......
...@@ -34,7 +34,7 @@ bool SequenceExpandAsOpLite::CheckShape() const { ...@@ -34,7 +34,7 @@ bool SequenceExpandAsOpLite::CheckShape() const {
return true; return true;
} }
bool SequenceExpandAsOpLite::InferShape() const { bool SequenceExpandAsOpLite::InferShapeImpl() const {
auto x_dims = param_.x->dims(); auto x_dims = param_.x->dims();
auto y_lod = param_.y->lod(); auto y_lod = param_.y->lod();
auto out_dims = x_dims; auto out_dims = x_dims;
......
...@@ -31,7 +31,7 @@ class SequenceExpandAsOpLite : public OpLite { ...@@ -31,7 +31,7 @@ class SequenceExpandAsOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -40,7 +40,7 @@ bool SequenceExpandOp::CheckShape() const { ...@@ -40,7 +40,7 @@ bool SequenceExpandOp::CheckShape() const {
return true; return true;
} }
bool SequenceExpandOp::InferShape() const { bool SequenceExpandOp::InferShapeImpl() const {
const auto x_lod = param_.X->lod(); const auto x_lod = param_.X->lod();
auto x_dims = param_.X->dims(); auto x_dims = param_.X->dims();
int ref_level = param_.ref_level; int ref_level = param_.ref_level;
......
...@@ -30,7 +30,7 @@ class SequenceExpandOp : public OpLite { ...@@ -30,7 +30,7 @@ class SequenceExpandOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -26,7 +26,7 @@ bool SequencePoolConcatOp::CheckShape() const { ...@@ -26,7 +26,7 @@ bool SequencePoolConcatOp::CheckShape() const {
return true; return true;
} }
bool SequencePoolConcatOp::InferShape() const { bool SequencePoolConcatOp::InferShapeImpl() const {
int out_dim = 0; int out_dim = 0;
for (int i = 0; i < param_.X.size(); ++i) { for (int i = 0; i < param_.X.size(); ++i) {
out_dim += param_.X[i]->dims().count(1, param_.X[i]->dims().size()); out_dim += param_.X[i]->dims().count(1, param_.X[i]->dims().size());
......
...@@ -28,7 +28,7 @@ class SequencePoolConcatOp : public OpLite { ...@@ -28,7 +28,7 @@ class SequencePoolConcatOp : public OpLite {
SequencePoolConcatOp() {} SequencePoolConcatOp() {}
explicit SequencePoolConcatOp(const std::string &op_type) : OpLite(op_type) {} explicit SequencePoolConcatOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
......
...@@ -29,7 +29,7 @@ bool SequencePoolOp::CheckShape() const { ...@@ -29,7 +29,7 @@ bool SequencePoolOp::CheckShape() const {
return true; return true;
} }
bool SequencePoolOp::InferShape() const { bool SequencePoolOp::InferShapeImpl() const {
const auto *input = param_.X; const auto *input = param_.X;
auto out_dims = input->dims(); auto out_dims = input->dims();
out_dims[0] = input->lod()[0].size() - 1; out_dims[0] = input->lod()[0].size() - 1;
......
...@@ -28,7 +28,7 @@ class SequencePoolOp : public OpLite { ...@@ -28,7 +28,7 @@ class SequencePoolOp : public OpLite {
SequencePoolOp() {} SequencePoolOp() {}
explicit SequencePoolOp(const std::string &op_type) : OpLite(op_type) {} explicit SequencePoolOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
......
...@@ -27,7 +27,7 @@ bool SequenceReshapeOp::CheckShape() const { ...@@ -27,7 +27,7 @@ bool SequenceReshapeOp::CheckShape() const {
return true; return true;
} }
bool SequenceReshapeOp::InferShape() const { bool SequenceReshapeOp::InferShapeImpl() const {
int new_dim = param_.new_dim; int new_dim = param_.new_dim;
auto x_numel = param_.x->dims().production(); auto x_numel = param_.x->dims().production();
std::vector<int64_t> out_shape{x_numel / new_dim, std::vector<int64_t> out_shape{x_numel / new_dim,
......
...@@ -31,7 +31,7 @@ class SequenceReshapeOp : public OpLite { ...@@ -31,7 +31,7 @@ class SequenceReshapeOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -30,7 +30,7 @@ bool SequenceReverseOp::CheckShape() const { ...@@ -30,7 +30,7 @@ bool SequenceReverseOp::CheckShape() const {
return true; return true;
} }
bool SequenceReverseOp::InferShape() const { bool SequenceReverseOp::InferShapeImpl() const {
const auto *input = param_.X; const auto *input = param_.X;
auto out_dims = input->dims(); auto out_dims = input->dims();
param_.Out->Resize(out_dims); param_.Out->Resize(out_dims);
......
...@@ -27,7 +27,7 @@ class SequenceReverseOp : public OpLite { ...@@ -27,7 +27,7 @@ class SequenceReverseOp : public OpLite {
SequenceReverseOp() {} SequenceReverseOp() {}
explicit SequenceReverseOp(const std::string &op_type) : OpLite(op_type) {} explicit SequenceReverseOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "sequence_reverse"; } std::string DebugString() const override { return "sequence_reverse"; }
......
...@@ -24,7 +24,7 @@ bool SequenceSoftmaxOp::CheckShape() const { ...@@ -24,7 +24,7 @@ bool SequenceSoftmaxOp::CheckShape() const {
CHECK_OR_FALSE(param_.Out); CHECK_OR_FALSE(param_.Out);
return true; return true;
} }
bool SequenceSoftmaxOp::InferShape() const { bool SequenceSoftmaxOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.Out); CHECK_OR_FALSE(param_.Out);
// TODO(Superjomn) Enable data sharing. // TODO(Superjomn) Enable data sharing.
auto input_dims = param_.X->dims(); auto input_dims = param_.X->dims();
......
...@@ -30,7 +30,7 @@ class SequenceSoftmaxOp : public OpLite { ...@@ -30,7 +30,7 @@ class SequenceSoftmaxOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -43,7 +43,7 @@ bool SequenceTopkAvgPoolingOpLite::CheckShape() const { ...@@ -43,7 +43,7 @@ bool SequenceTopkAvgPoolingOpLite::CheckShape() const {
return true; return true;
} }
bool SequenceTopkAvgPoolingOpLite::InferShape() const { bool SequenceTopkAvgPoolingOpLite::InferShapeImpl() const {
int channel_num = param_.channel_num; int channel_num = param_.channel_num;
std::vector<int> topks = param_.topks; std::vector<int> topks = param_.topks;
auto row_dim = param_.ROW->dims(); auto row_dim = param_.ROW->dims();
......
...@@ -31,7 +31,7 @@ class SequenceTopkAvgPoolingOpLite : public OpLite { ...@@ -31,7 +31,7 @@ class SequenceTopkAvgPoolingOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -30,7 +30,7 @@ bool SGDOpLite::CheckShape() const { ...@@ -30,7 +30,7 @@ bool SGDOpLite::CheckShape() const {
return true; return true;
} }
bool SGDOpLite::InferShape() const { bool SGDOpLite::InferShapeImpl() const {
param_.ParamOut->Resize(param_.Param->dims()); param_.ParamOut->Resize(param_.Param->dims());
return true; return true;
} }
......
...@@ -33,7 +33,7 @@ class SGDOpLite : public OpLite { ...@@ -33,7 +33,7 @@ class SGDOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
......
...@@ -25,7 +25,7 @@ bool ShapeOpLite::CheckShape() const { ...@@ -25,7 +25,7 @@ bool ShapeOpLite::CheckShape() const {
return true; return true;
} }
bool ShapeOpLite::InferShape() const { bool ShapeOpLite::InferShapeImpl() const {
std::vector<int64_t> shape_vec; std::vector<int64_t> shape_vec;
shape_vec.push_back(static_cast<int64_t>(param_.X->dims().size())); shape_vec.push_back(static_cast<int64_t>(param_.X->dims().size()));
param_.Out->Resize(shape_vec); param_.Out->Resize(shape_vec);
......
...@@ -28,7 +28,7 @@ class ShapeOpLite : public OpLite { ...@@ -28,7 +28,7 @@ class ShapeOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -27,7 +27,7 @@ bool ShuffleChannelOpLite::CheckShape() const { ...@@ -27,7 +27,7 @@ bool ShuffleChannelOpLite::CheckShape() const {
return true; return true;
} }
bool ShuffleChannelOpLite::InferShape() const { bool ShuffleChannelOpLite::InferShapeImpl() const {
param_.Out->Resize(param_.X->dims()); param_.Out->Resize(param_.X->dims());
return true; return true;
} }
......
...@@ -33,7 +33,7 @@ class ShuffleChannelOpLite : public OpLite { ...@@ -33,7 +33,7 @@ class ShuffleChannelOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
......
...@@ -27,7 +27,7 @@ bool SliceOp::CheckShape() const { ...@@ -27,7 +27,7 @@ bool SliceOp::CheckShape() const {
return true; return true;
} }
bool SliceOp::InferShape() const { bool SliceOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.Out); CHECK_OR_FALSE(param_.Out);
// TODO(Superjomn) Enable data sharing. // TODO(Superjomn) Enable data sharing.
auto in_dims = param_.X->dims(); auto in_dims = param_.X->dims();
......
...@@ -30,7 +30,7 @@ class SliceOp : public OpLite { ...@@ -30,7 +30,7 @@ class SliceOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -29,35 +29,7 @@ bool SoftmaxOp::CheckShape() const { ...@@ -29,35 +29,7 @@ bool SoftmaxOp::CheckShape() const {
return true; return true;
} }
bool SoftmaxOp::SmartInferShape() { bool SoftmaxOp::InferShapeImpl() const {
if (!last_input_shapes.empty() && !last_output_shapes.empty()) {
if (param_.x->dims() == last_input_shapes[0] &&
param_.x->lod() == last_input_lods[0]) {
param_.output->Resize(last_output_shapes[0]);
param_.output->set_lod(last_output_lods[0]);
return true;
}
}
this->InferShape();
if (!last_input_shapes.empty()) {
last_input_shapes.clear();
last_input_lods.clear();
}
last_input_shapes.push_back(param_.x->dims());
last_input_lods.push_back(param_.x->lod());
if (!last_output_shapes.empty()) {
last_output_shapes.clear();
last_output_lods.clear();
}
last_output_shapes.push_back(param_.output->dims());
last_output_lods.push_back(param_.output->lod());
return true;
}
bool SoftmaxOp::InferShape() const {
param_.output->Resize(param_.x->dims()); param_.output->Resize(param_.x->dims());
auto out_lod = param_.output->mutable_lod(); auto out_lod = param_.output->mutable_lod();
*out_lod = param_.x->lod(); *out_lod = param_.x->lod();
......
...@@ -30,8 +30,7 @@ class SoftmaxOp : public OpLite { ...@@ -30,8 +30,7 @@ class SoftmaxOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool SmartInferShape() override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -33,7 +33,7 @@ bool SplitLodTensorOpLite::CheckShape() const { ...@@ -33,7 +33,7 @@ bool SplitLodTensorOpLite::CheckShape() const {
return true; return true;
} }
bool SplitLodTensorOpLite::InferShape() const { bool SplitLodTensorOpLite::InferShapeImpl() const {
auto x_dims = param_.x->dims(); auto x_dims = param_.x->dims();
param_.out_true->Resize(x_dims); param_.out_true->Resize(x_dims);
param_.out_false->Resize(x_dims); param_.out_false->Resize(x_dims);
......
...@@ -31,7 +31,7 @@ class SplitLodTensorOpLite : public OpLite { ...@@ -31,7 +31,7 @@ class SplitLodTensorOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -29,7 +29,7 @@ bool SplitOp::CheckShape() const { ...@@ -29,7 +29,7 @@ bool SplitOp::CheckShape() const {
return true; return true;
} }
bool SplitOp::InferShape() const { bool SplitOp::InferShapeImpl() const {
const auto &outs = param_.output; const auto &outs = param_.output;
auto in_dims = param_.x->dims(); auto in_dims = param_.x->dims();
int axis = param_.axis; int axis = param_.axis;
......
...@@ -30,7 +30,7 @@ class SplitOp : public OpLite { ...@@ -30,7 +30,7 @@ class SplitOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -75,7 +75,7 @@ bool SqueezeOp::CheckShape() const { ...@@ -75,7 +75,7 @@ bool SqueezeOp::CheckShape() const {
return true; return true;
} }
bool SqueezeOp::InferShape() const { bool SqueezeOp::InferShapeImpl() const {
std::vector<int> squeeze_dims = param_.axes; std::vector<int> squeeze_dims = param_.axes;
DDim in_dims = param_.X->dims(); DDim in_dims = param_.X->dims();
DDim out_dim = GetOutputShape(squeeze_dims, in_dims, true); DDim out_dim = GetOutputShape(squeeze_dims, in_dims, true);
...@@ -105,8 +105,8 @@ bool Squeeze2Op::CheckShape() const { ...@@ -105,8 +105,8 @@ bool Squeeze2Op::CheckShape() const {
return true; return true;
} }
bool Squeeze2Op::InferShape() const { bool Squeeze2Op::InferShapeImpl() const {
SqueezeOp::InferShape(); SqueezeOp::InferShapeImpl();
auto x_dims = param_.X->dims(); auto x_dims = param_.X->dims();
std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 1); std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 1);
for (size_t i = 0; i < x_dims.size(); i++) { for (size_t i = 0; i < x_dims.size(); i++) {
......
...@@ -30,7 +30,7 @@ class SqueezeOp : public OpLite { ...@@ -30,7 +30,7 @@ class SqueezeOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
...@@ -48,7 +48,7 @@ class Squeeze2Op : public SqueezeOp { ...@@ -48,7 +48,7 @@ class Squeeze2Op : public SqueezeOp {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -32,7 +32,7 @@ bool StackOp::CheckShape() const { ...@@ -32,7 +32,7 @@ bool StackOp::CheckShape() const {
return true; return true;
} }
bool StackOp::InferShape() const { bool StackOp::InferShapeImpl() const {
auto input = param_.X; auto input = param_.X;
auto input_dims = input[0]->dims(); auto input_dims = input[0]->dims();
int axis = param_.axis; int axis = param_.axis;
......
...@@ -31,7 +31,7 @@ class StackOp : public OpLite { ...@@ -31,7 +31,7 @@ class StackOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -22,7 +22,7 @@ namespace operators { ...@@ -22,7 +22,7 @@ namespace operators {
bool SubgraphOp::CheckShape() const { return true; } bool SubgraphOp::CheckShape() const { return true; }
bool SubgraphOp::InferShape() const { return CheckShape(); /* enrich me */ } bool SubgraphOp::InferShapeImpl() const { return CheckShape(); /* enrich me */ }
bool SubgraphOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { bool SubgraphOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
param_.input_names = op_desc.Input("Inputs"); param_.input_names = op_desc.Input("Inputs");
......
...@@ -35,7 +35,7 @@ class SubgraphOp : public OpLite { ...@@ -35,7 +35,7 @@ class SubgraphOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override;
......
...@@ -25,7 +25,7 @@ bool TopkOp::CheckShape() const { ...@@ -25,7 +25,7 @@ bool TopkOp::CheckShape() const {
return true; return true;
} }
bool TopkOp::InferShape() const { bool TopkOp::InferShapeImpl() const {
auto out_dims = param_.X->dims(); auto out_dims = param_.X->dims();
out_dims[out_dims.size() - 1] = param_.K; out_dims[out_dims.size() - 1] = param_.K;
auto out = param_.Out; auto out = param_.Out;
......
...@@ -30,7 +30,7 @@ class TopkOp : public OpLite { ...@@ -30,7 +30,7 @@ class TopkOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -42,7 +42,7 @@ bool TransposeOp::CheckShape() const { ...@@ -42,7 +42,7 @@ bool TransposeOp::CheckShape() const {
return true; return true;
} }
bool TransposeOp::InferShape() const { bool TransposeOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.x); CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output); CHECK_OR_FALSE(param_.output);
auto x_dims = param_.x->dims(); auto x_dims = param_.x->dims();
...@@ -111,7 +111,7 @@ bool Transpose2Op::CheckShape() const { ...@@ -111,7 +111,7 @@ bool Transpose2Op::CheckShape() const {
return true; return true;
} }
bool Transpose2Op::InferShape() const { bool Transpose2Op::InferShapeImpl() const {
CHECK_OR_FALSE(param_.x); CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output); CHECK_OR_FALSE(param_.output);
auto x_dims = param_.x->dims(); auto x_dims = param_.x->dims();
......
...@@ -31,7 +31,7 @@ class TransposeOp : public OpLite { ...@@ -31,7 +31,7 @@ class TransposeOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
...@@ -50,7 +50,7 @@ class Transpose2Op : public OpLite { ...@@ -50,7 +50,7 @@ class Transpose2Op : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -22,7 +22,7 @@ namespace operators { ...@@ -22,7 +22,7 @@ namespace operators {
bool UniformRandomOpLite::CheckShape() const { return true; } bool UniformRandomOpLite::CheckShape() const { return true; }
bool UniformRandomOpLite::InferShape() const { bool UniformRandomOpLite::InferShapeImpl() const {
param_.Out->Resize(param_.shape); param_.Out->Resize(param_.shape);
return true; return true;
} }
......
...@@ -33,7 +33,7 @@ class UniformRandomOpLite : public OpLite { ...@@ -33,7 +33,7 @@ class UniformRandomOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
......
...@@ -62,7 +62,7 @@ bool UnsqueezeOp::CheckShape() const { ...@@ -62,7 +62,7 @@ bool UnsqueezeOp::CheckShape() const {
return true; return true;
} }
bool UnsqueezeOp::InferShape() const { bool UnsqueezeOp::InferShapeImpl() const {
std::vector<int> final_axes; std::vector<int> final_axes;
auto axes = param_.axes; auto axes = param_.axes;
auto *axes_tensor = param_.axes_tensor; auto *axes_tensor = param_.axes_tensor;
...@@ -129,8 +129,8 @@ bool Unsqueeze2Op::CheckShape() const { ...@@ -129,8 +129,8 @@ bool Unsqueeze2Op::CheckShape() const {
return true; return true;
} }
bool Unsqueeze2Op::InferShape() const { bool Unsqueeze2Op::InferShapeImpl() const {
UnsqueezeOp::InferShape(); UnsqueezeOp::InferShapeImpl();
auto x_dims = param_.X->dims(); auto x_dims = param_.X->dims();
std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 1); std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 1);
for (size_t i = 0; i < x_dims.size(); i++) { for (size_t i = 0; i < x_dims.size(); i++) {
......
...@@ -30,7 +30,7 @@ class UnsqueezeOp : public OpLite { ...@@ -30,7 +30,7 @@ class UnsqueezeOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
...@@ -48,7 +48,7 @@ class Unsqueeze2Op : public UnsqueezeOp { ...@@ -48,7 +48,7 @@ class Unsqueeze2Op : public UnsqueezeOp {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -21,7 +21,7 @@ namespace operators { ...@@ -21,7 +21,7 @@ namespace operators {
bool VarConv2dOp::CheckShape() const { return true; } bool VarConv2dOp::CheckShape() const { return true; }
bool VarConv2dOp::InferShape() const { return true; } bool VarConv2dOp::InferShapeImpl() const { return true; }
bool VarConv2dOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { bool VarConv2dOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.X = const_cast<lite::Tensor *>( param_.X = const_cast<lite::Tensor *>(
......
...@@ -27,7 +27,7 @@ class VarConv2dOp : public OpLite { ...@@ -27,7 +27,7 @@ class VarConv2dOp : public OpLite {
VarConv2dOp() {} VarConv2dOp() {}
explicit VarConv2dOp(const std::string &op_type) : OpLite(op_type) {} explicit VarConv2dOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "var_conv_2d"; } std::string DebugString() const override { return "var_conv_2d"; }
......
...@@ -27,7 +27,7 @@ bool WhileOpLite::CheckShape() const { ...@@ -27,7 +27,7 @@ bool WhileOpLite::CheckShape() const {
return true; return true;
} }
bool WhileOpLite::InferShape() const { return true; } bool WhileOpLite::InferShapeImpl() const { return true; }
bool WhileOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { bool WhileOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
auto inputs = op_desc.Input("X"); auto inputs = op_desc.Input("X");
......
...@@ -30,7 +30,7 @@ class WhileOpLite : public OpLite { ...@@ -30,7 +30,7 @@ class WhileOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -26,7 +26,7 @@ bool WriteToArrayOp::CheckShape() const { ...@@ -26,7 +26,7 @@ bool WriteToArrayOp::CheckShape() const {
return true; return true;
} }
bool WriteToArrayOp::InferShape() const { bool WriteToArrayOp::InferShapeImpl() const {
int id = param_.I->data<int64_t>()[0]; int id = param_.I->data<int64_t>()[0];
if (param_.Out->size() < id + 1) { if (param_.Out->size() < id + 1) {
param_.Out->resize(id + 1); param_.Out->resize(id + 1);
......
...@@ -30,7 +30,7 @@ class WriteToArrayOp : public OpLite { ...@@ -30,7 +30,7 @@ class WriteToArrayOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -46,7 +46,7 @@ bool YoloBoxOp::CheckShape() const { ...@@ -46,7 +46,7 @@ bool YoloBoxOp::CheckShape() const {
return true; return true;
} }
bool YoloBoxOp::InferShape() const { bool YoloBoxOp::InferShapeImpl() const {
auto* X = param_.X; auto* X = param_.X;
auto anchors = param_.anchors; auto anchors = param_.anchors;
int anchor_num = anchors.size() / 2; int anchor_num = anchors.size() / 2;
......
...@@ -30,7 +30,7 @@ class YoloBoxOp : public OpLite { ...@@ -30,7 +30,7 @@ class YoloBoxOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册