未验证 提交 1f8b5c2b 编写于 作者: H huzhiqiang 提交者: GitHub

[operator] add InferShapeImpl method (#3294)

上级 f81db03a
......@@ -22,6 +22,61 @@
namespace paddle {
namespace lite {
bool OpLite::InferShape() {
// if input_tensor_ptrs and output_tensor_ptrs are overloaded in param_
// InferShapeByMemoryInternal will be applied.
if (param_.input_tensor_ptrs() && param_.output_tensor_ptrs()) {
return this->InferShapeWithCache();
} else {
// otherwise, InferShapeImpl is applied directly.
return this->InferShapeImpl();
}
}
bool OpLite::InferShapeWithCache() {
// 1. Get vector of current input tensors
auto *current_inputs = param_.input_tensor_ptrs();
// 2. Get hash value of current inputs shape and lod
size_t new_hash = 0;
for (auto iter = current_inputs->begin(); iter != current_inputs->end();
iter++) {
// combined dims value into new_hash value.
auto &element_dims = (*iter)->dims();
for (int i = 0; i < element_dims.size(); i++) {
new_hash =
lite::hash_combine(new_hash, static_cast<int>(element_dims[i]));
}
// combine lod value into new_hash valud.
auto &emement_lods = (*iter)->lod();
for (auto lod_iter = emement_lods.begin(); lod_iter != emement_lods.end();
lod_iter++) {
for (int i = 0; i < lod_iter->size(); i++) {
new_hash =
lite::hash_combine(new_hash, static_cast<int>(lod_iter->at(i)));
}
}
}
// 3. infer shapes of output tensors
if (new_hash == io_shape_lod_hash_ && new_hash != 0) {
// if current hash value is consistent with io_shape_lod_hash_,
// previous outputs shape and lod are reused.
auto *current_outputs = param_.output_tensor_ptrs();
for (int i = 0; i < current_outputs->size(); i++) {
current_outputs->at(i)->Resize(last_output_shapes[i]);
current_outputs->at(i)->set_lod(last_output_lods[i]);
}
} else {
// otherwise, current hash value is changed, InferShapeImpl will apply.
io_shape_lod_hash_ = new_hash;
this->InferShapeImpl();
auto *current_outputs = param_.output_tensor_ptrs();
for (int i = 0; i < current_outputs->size(); i++) {
last_output_shapes[i] = current_outputs->at(i)->dims();
last_output_lods[i] = current_outputs->at(i)->lod();
}
}
return true;
}
std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
const std::vector<Place> &places, const std::string &kernel_type) {
std::vector<std::unique_ptr<KernelBase>> kernels;
......
......@@ -14,6 +14,7 @@
#pragma once
#include <functional>
#include <list>
#include <map>
#include <memory>
......@@ -24,6 +25,7 @@
#include "lite/core/kernel.h"
#include "lite/core/scope.h"
#include "lite/model_parser/cpp/op_desc.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
......@@ -64,8 +66,8 @@ class OpLite : public Registry {
// Check the shape.
virtual bool CheckShape() const { return true; }
// Inference the outputs' shape.
virtual bool InferShape() const { return true; }
virtual bool SmartInferShape() { return this->InferShape(); }
virtual bool InferShapeImpl() const { return true; }
virtual bool InferShape();
// Run this operator.
virtual bool Run();
// Indicate whether the Op runs only once or not
......@@ -151,10 +153,16 @@ class OpLite : public Registry {
std::vector<Place> valid_places_;
Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
std::unique_ptr<OpInfo> op_info_;
std::vector<DDimLite> last_input_shapes;
std::vector<DDimLite> last_output_shapes;
std::vector<std::vector<std::vector<uint64_t>>> last_output_lods;
std::vector<std::vector<std::vector<uint64_t>>> last_input_lods;
std::vector<DDimLite> last_output_shapes{};
std::vector<std::vector<std::vector<uint64_t>>> last_output_lods{};
size_t io_shape_lod_hash_{};
mutable operators::ParamBase param_;
private:
// Infer Shape according to memory, if current input shapes are consistent
// with that of previous inputs, output shapes of last time will be reused.
bool InferShapeWithCache();
};
/*
......
......@@ -286,8 +286,7 @@ void Instruction::Run() {
return;
}
// op_->InferShape();
op_->SmartInferShape();
op_->InferShape();
kernel_->Launch();
has_run_ = true;
}
......
......@@ -25,7 +25,7 @@ bool ActivationGradOp::CheckShape() const {
return true;
}
bool ActivationGradOp::InferShape() const {
bool ActivationGradOp::InferShapeImpl() const {
param_.X_grad->Resize(param_.Out_grad->dims());
return true;
}
......
......@@ -26,7 +26,7 @@ class ActivationGradOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
......
......@@ -25,7 +25,7 @@ bool ActivationOp::CheckShape() const {
return true;
}
bool ActivationOp::InferShape() const {
bool ActivationOp::InferShapeImpl() const {
param_.Out->Resize(param_.X->dims());
auto out_lod = param_.Out->mutable_lod();
*out_lod = param_.X->lod();
......
......@@ -26,7 +26,7 @@ class ActivationOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
......
......@@ -44,7 +44,7 @@ bool AffineChannelOpLite::CheckShape() const {
return true;
}
bool AffineChannelOpLite::InferShape() const {
bool AffineChannelOpLite::InferShapeImpl() const {
const auto x_dims = param_.X->dims();
param_.Out->Resize(x_dims);
return true;
......
......@@ -31,7 +31,7 @@ class AffineChannelOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -31,7 +31,7 @@ bool AnchorGeneratorOpLite::CheckShape() const {
return true;
}
bool AnchorGeneratorOpLite::InferShape() const {
bool AnchorGeneratorOpLite::InferShapeImpl() const {
auto input_dims = param_.Input->dims();
size_t num_anchors = param_.aspect_ratios.size() * param_.anchor_sizes.size();
std::vector<int64_t> output_shape(
......
......@@ -32,7 +32,7 @@ class AnchorGeneratorOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -29,7 +29,7 @@ bool ArgmaxOpLite::CheckShape() const {
return true;
}
bool ArgmaxOpLite::InferShape() const {
bool ArgmaxOpLite::InferShapeImpl() const {
auto x_dims = param_.X->dims();
int x_rank = x_dims.size();
int axis = param_.Axis;
......
......@@ -31,7 +31,7 @@ class ArgmaxOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -26,7 +26,7 @@ bool AssignOpLite::CheckShape() const {
return true;
}
bool AssignOpLite::InferShape() const {
bool AssignOpLite::InferShapeImpl() const {
lite::DDim input_dims;
input_dims = param_.X->dims();
param_.Out->Resize(lite::DDim(input_dims));
......
......@@ -30,7 +30,7 @@ class AssignOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -35,7 +35,7 @@ bool AssignValueOpLite::CheckShape() const {
return true;
}
bool AssignValueOpLite::InferShape() const {
bool AssignValueOpLite::InferShapeImpl() const {
std::vector<int> shape = param_.shape;
std::vector<int64_t> out_shape;
for (size_t i = 0; i < shape.size(); i++) out_shape.push_back(shape[i]);
......
......@@ -31,7 +31,7 @@ class AssignValueOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -28,7 +28,7 @@ bool AttentionPaddingMaskOp::CheckShape() const {
return true;
}
bool AttentionPaddingMaskOp::InferShape() const {
bool AttentionPaddingMaskOp::InferShapeImpl() const {
auto src_len = param_.X->lod()[0][1];
CHECK_EQ(src_len, param_.X->dims()[1])
<< "Mismatch source length, expect: " << src_len
......
......@@ -29,7 +29,7 @@ class AttentionPaddingMaskOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -34,7 +34,7 @@ bool AxpyOpLite::CheckShape() const {
return true;
}
bool AxpyOpLite::InferShape() const {
bool AxpyOpLite::InferShapeImpl() const {
auto dims = param_.Bias->dims();
// Set output dims
......
......@@ -31,7 +31,7 @@ class AxpyOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -46,7 +46,7 @@ bool BatchNormOp::CheckShape() const {
return true;
}
bool BatchNormOp::InferShape() const {
bool BatchNormOp::InferShapeImpl() const {
auto x_dims = param_.x->dims();
int64_t channel_size = 0;
switch (param_.data_layout) {
......
......@@ -30,7 +30,7 @@ class BatchNormOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -28,7 +28,7 @@ bool BeamSearchDecodeOpLite::CheckShape() const {
return true;
}
bool BeamSearchDecodeOpLite::InferShape() const { return true; }
bool BeamSearchDecodeOpLite::InferShapeImpl() const { return true; }
bool BeamSearchDecodeOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
......
......@@ -31,7 +31,7 @@ class BeamSearchDecodeOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -30,7 +30,7 @@ bool BeamSearchOp::CheckShape() const {
return true;
}
bool BeamSearchOp::InferShape() const { return true; }
bool BeamSearchOp::InferShapeImpl() const { return true; }
bool BeamSearchOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.pre_ids = scope->FindTensor(opdesc.Input("pre_ids").front());
......
......@@ -30,7 +30,7 @@ class BeamSearchOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -35,7 +35,7 @@ bool BoxClipOpLite::CheckShape() const {
return true;
}
bool BoxClipOpLite::InferShape() const {
bool BoxClipOpLite::InferShapeImpl() const {
auto* input = param_.Input;
auto* output = param_.Output;
output->Resize(input->dims());
......
......@@ -31,7 +31,7 @@ class BoxClipOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -35,7 +35,7 @@ bool BoxCoderOpLite::CheckShape() const {
return true;
}
bool BoxCoderOpLite::InferShape() const {
bool BoxCoderOpLite::InferShapeImpl() const {
auto prior_box_dims = param_.prior_box->dims();
auto target_box_dims = param_.target_box->dims();
std::string code_type = param_.code_type;
......
......@@ -29,7 +29,7 @@ class BoxCoderOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -24,7 +24,7 @@ bool CalibOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.output);
return true;
}
bool CalibOpLite::InferShape() const {
bool CalibOpLite::InferShapeImpl() const {
param_.output->Resize(param_.input->dims());
return true;
}
......
......@@ -42,7 +42,7 @@ class CalibOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope);
......
......@@ -25,7 +25,7 @@ bool CastOp::CheckShape() const {
return true;
}
bool CastOp::InferShape() const {
bool CastOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.Out);
// TODO(Superjomn) Enable data sharing.
auto out_dims = param_.X->dims();
......
......@@ -30,7 +30,7 @@ class CastOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -43,7 +43,7 @@ bool CollectFpnProposalsOpLite::CheckShape() const {
return true;
}
bool CollectFpnProposalsOpLite::InferShape() const {
bool CollectFpnProposalsOpLite::InferShapeImpl() const {
param_.fpn_rois->Resize({param_.post_nms_topN, 4});
return true;
......
......@@ -32,7 +32,7 @@ class CollectFpnProposalsOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -26,7 +26,7 @@ bool CompareOp::CheckShape() const {
return true;
}
bool CompareOp::InferShape() const {
bool CompareOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.Out);
// TODO(Superjomn) Enable data sharing.
auto input_dims = param_.X->dims();
......
......@@ -30,7 +30,7 @@ class CompareOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -26,7 +26,7 @@ bool ConcatOpLite::CheckShape() const {
return true;
}
bool ConcatOpLite::InferShape() const {
bool ConcatOpLite::InferShapeImpl() const {
const std::vector<Tensor *> &inputs = param_.x;
const size_t n = inputs.size();
CHECK_GT_OR_FALSE(n, 0);
......
......@@ -30,7 +30,7 @@ class ConcatOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -27,7 +27,7 @@ bool ConditionalBlockOpLite::CheckShape() const {
return true;
}
bool ConditionalBlockOpLite::InferShape() const { return true; }
bool ConditionalBlockOpLite::InferShapeImpl() const { return true; }
bool ConditionalBlockOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
......
......@@ -31,7 +31,7 @@ class ConditionalBlockOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -80,35 +80,7 @@ void UpdatePaddingAndDilation(std::vector<int>* paddings,
}
}
bool ConvOpLite::SmartInferShape() {
if (!last_input_shapes.empty()) {
if (last_input_shapes[0] == param_.x->dims() &&
last_input_lods[0] == param_.x->lod()) {
param_.output->Resize(last_output_shapes[0]);
param_.output->set_lod(last_output_lods[0]);
return true;
}
}
this->InferShape();
if (!last_input_shapes.empty()) {
last_input_shapes.clear();
last_input_lods.clear();
}
last_input_shapes.push_back(param_.x->dims());
last_input_lods.push_back(param_.x->lod());
if (!last_output_shapes.empty()) {
last_output_shapes.clear();
last_output_lods.clear();
}
last_output_shapes.push_back(param_.output->dims());
last_output_lods.push_back(param_.output->lod());
return true;
}
bool ConvOpLite::InferShape() const {
bool ConvOpLite::InferShapeImpl() const {
const auto in_dims = param_.x->dims();
const auto filter_dims = param_.filter->dims();
......
......@@ -34,9 +34,7 @@ class ConvOpLite : public OpLite {
explicit ConvOpLite(const std::string& type) : OpLite(type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool SmartInferShape() override;
bool InferShapeImpl() const override;
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override {
......
......@@ -52,7 +52,7 @@ inline int ConvTransposeOutputSize(int input_size,
return output_size;
}
bool ConvTransposeOpLite::InferShape() const {
bool ConvTransposeOpLite::InferShapeImpl() const {
const auto in_dims = param_.x->dims();
const auto filter_dims = param_.filter->dims();
......
......@@ -34,7 +34,7 @@ class ConvTransposeOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -60,7 +60,7 @@ bool CrfDecodingOpLite::CheckShape() const {
return true;
}
bool CrfDecodingOpLite::InferShape() const {
bool CrfDecodingOpLite::InferShapeImpl() const {
auto emission_dims = param_.emission->dims();
if (param_.length == nullptr) {
param_.viterbi_path->Resize({emission_dims[0], 1});
......
......@@ -31,7 +31,7 @@ class CrfDecodingOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -26,7 +26,7 @@ bool CropOpLite::CheckShape() const {
return true;
}
bool CropOpLite::InferShape() const {
bool CropOpLite::InferShapeImpl() const {
// nchw
auto x_dims = param_.X->dims();
lite::DDim output_shape(x_dims);
......
......@@ -30,7 +30,7 @@ class CropOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -29,7 +29,7 @@ bool DecodeBboxesOpLite::CheckShape() const {
return true;
}
bool DecodeBboxesOpLite::InferShape() const {
bool DecodeBboxesOpLite::InferShapeImpl() const {
param_.bbox_data->Resize(param_.loc_data->dims());
return true;
}
......
......@@ -29,7 +29,7 @@ class DecodeBboxesOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -27,7 +27,7 @@ bool DensityPriorBoxOpLite::CheckShape() const {
return true;
}
bool DensityPriorBoxOpLite::InferShape() const { return true; }
bool DensityPriorBoxOpLite::InferShapeImpl() const { return true; }
bool DensityPriorBoxOpLite::AttachImpl(const cpp::OpDesc& opdesc,
lite::Scope* scope) {
......
......@@ -30,7 +30,7 @@ class DensityPriorBoxOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -32,7 +32,7 @@ bool DistributeFpnProposalsOpLite::CheckShape() const {
return true;
}
bool DistributeFpnProposalsOpLite::InferShape() const {
bool DistributeFpnProposalsOpLite::InferShapeImpl() const {
int num_out_rois = param_.max_level - param_.min_level + 1;
for (int i = 0; i < num_out_rois; i++) {
param_.multi_fpn_rois[i]->Resize({-1, 4});
......
......@@ -32,7 +32,7 @@ class DistributeFpnProposalsOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -26,7 +26,7 @@ bool DropoutOp::CheckShape() const {
return true;
}
bool DropoutOp::InferShape() const {
bool DropoutOp::InferShapeImpl() const {
const auto x_dims = param_.x->dims();
param_.output->Resize(x_dims);
if (param_.is_test == false) {
......
......@@ -28,7 +28,7 @@ class DropoutOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
// TODO(Superjomn) replace framework::OpDesc with a lite one.
......
......@@ -26,7 +26,7 @@ bool ElementwiseGradOp::CheckShape() const {
return true;
}
bool ElementwiseGradOp::InferShape() const {
bool ElementwiseGradOp::InferShapeImpl() const {
auto x_dim = param_.X->dims();
auto y_dim = param_.Y->dims();
if (param_.XGrad) {
......
......@@ -27,7 +27,7 @@ class ElementwiseGradOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
......
......@@ -26,39 +26,8 @@ bool ElementwiseOp::CheckShape() const {
CHECK_OR_FALSE(param_.Out);
return true;
}
bool ElementwiseOp::SmartInferShape() {
if (!last_input_shapes.empty()) {
if (last_input_shapes[0] == param_.X->dims() &&
last_input_shapes[1] == param_.Y->dims() &&
last_input_lods[0] == param_.X->lod() &&
last_input_lods[1] == param_.Y->lod()) {
param_.Out->Resize(last_output_shapes[0]);
param_.Out->set_lod(last_output_lods[0]);
return true;
}
}
this->InferShape();
if (!last_input_shapes.empty()) {
last_input_shapes.clear();
last_input_lods.clear();
}
last_input_shapes.push_back(param_.X->dims());
last_input_lods.push_back(param_.X->lod());
last_input_shapes.push_back(param_.Y->dims());
last_input_lods.push_back(param_.Y->lod());
if (!last_output_shapes.empty()) {
last_output_shapes.clear();
last_output_lods.clear();
}
last_output_shapes.push_back(param_.Out->dims());
last_output_lods.push_back(param_.Out->lod());
return true;
}
bool ElementwiseOp::InferShape() const {
bool ElementwiseOp::InferShapeImpl() const {
auto x_dim = param_.X->dims();
auto y_dim = param_.Y->dims();
if (x_dim == y_dim) {
......@@ -136,7 +105,7 @@ bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
// return true;
//}
// bool ElementwiseGradExplicitOp::InferShape() const {
// bool ElementwiseGradExplicitOp::InferShapeImpl() const {
// param_.X_grad->Resize(param_.Out_grad->dims());
// if (param_.Y_grad) param_.Y_grad->Resize(param_.Y->dims());
// return true;
......
......@@ -27,8 +27,7 @@ class ElementwiseOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool SmartInferShape() override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
......@@ -48,7 +47,7 @@ class ElementwiseOp : public OpLite {
// bool CheckShape() const override;
// bool InferShape() const override;
// bool InferShapeImpl() const override;
// bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
......
......@@ -32,7 +32,7 @@ bool ExpandOpLite::CheckShape() const {
return true;
}
bool ExpandOpLite::InferShape() const {
bool ExpandOpLite::InferShapeImpl() const {
DDim out_dims(param_.X->dims());
for (size_t i = 0; i < param_.expand_times.size(); ++i) {
out_dims[i] *= param_.expand_times[i];
......
......@@ -28,7 +28,7 @@ class ExpandOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -36,7 +36,7 @@ class FakeChannelWiseDequantizeMaxAbsOpLite : public OpLite {
bool CheckShape() const override { return true; }
bool InferShape() const override { return true; }
bool InferShapeImpl() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front();
......
......@@ -35,7 +35,7 @@ class FakeDequantizeMaxAbsOpLite : public OpLite {
bool CheckShape() const override { return true; }
bool InferShape() const override { return true; }
bool InferShapeImpl() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front();
......
......@@ -36,7 +36,7 @@ class FakeQuantizeDequantizeMovingAvgMaxAbsOpLite : public OpLite {
bool CheckShape() const override { return true; }
bool InferShape() const override { return true; }
bool InferShapeImpl() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front();
......
......@@ -36,7 +36,7 @@ class FakeQuantizeMovingAvgMaxAbsOpLite : public OpLite {
bool CheckShape() const override { return true; }
bool InferShape() const override { return true; }
bool InferShapeImpl() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front();
......
......@@ -36,7 +36,7 @@ class FakeQuantizeRangeMaxAbsOpLite : public OpLite {
bool CheckShape() const override { return true; }
bool InferShape() const override { return true; }
bool InferShapeImpl() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front();
......
......@@ -48,34 +48,7 @@ bool FcOpLite::CheckShape() const {
return true;
}
bool FcOpLite::SmartInferShape() {
if (!last_input_shapes.empty() && !last_output_shapes.empty()) {
if (last_input_shapes[0] == param_.input->dims() &&
last_input_lods[0] == param_.input->lod()) {
param_.output->Resize(last_output_shapes[0]);
param_.output->set_lod(last_output_lods[0]);
return true;
}
}
this->InferShape();
if (!last_input_shapes.empty()) {
last_input_shapes.clear();
last_input_lods.clear();
}
last_input_shapes.push_back(param_.input->dims());
last_input_lods.push_back(param_.input->lod());
if (!last_output_shapes.empty()) {
last_output_shapes.clear();
last_output_lods.clear();
}
last_output_shapes.push_back(param_.output->dims());
last_output_lods.push_back(param_.output->lod());
return true;
}
bool FcOpLite::InferShape() const {
bool FcOpLite::InferShapeImpl() const {
const auto& input_dims = param_.input->dims();
const auto& w_dims = param_.w->dims();
int in_num_col_dims = param_.in_num_col_dims;
......
......@@ -35,8 +35,7 @@ class FcOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool SmartInferShape() override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override;
......
......@@ -29,7 +29,7 @@ class FeedOp : public OpLite {
return true;
}
bool InferShape() const override { return true; }
bool InferShapeImpl() const override { return true; }
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
......
......@@ -29,7 +29,7 @@ class FetchOp : public OpLite {
return true;
}
bool InferShape() const override { return true; }
bool InferShapeImpl() const override { return true; }
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
protected:
......
......@@ -28,7 +28,7 @@ bool FillConstantBatchSizeLikeOp::CheckShape() const {
return true;
}
bool FillConstantBatchSizeLikeOp::InferShape() const {
bool FillConstantBatchSizeLikeOp::InferShapeImpl() const {
std::vector<int64_t> output_dim{param_.shape.begin(), param_.shape.end()};
if (param_.input_dim_idx == 0 && !param_.input->lod().empty()) {
output_dim[param_.output_dim_idx] = param_.input->lod().back().size() - 1;
......
......@@ -32,7 +32,7 @@ class FillConstantBatchSizeLikeOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -24,7 +24,7 @@ bool FillConstantOp::CheckShape() const {
return true;
}
bool FillConstantOp::InferShape() const {
bool FillConstantOp::InferShapeImpl() const {
std::vector<int64_t> out_shape;
auto shape_tensor = param_.shape_tensor;
auto shape_tensor_list = param_.shape_tensor_list;
......
......@@ -31,7 +31,7 @@ class FillConstantOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -25,7 +25,7 @@ bool FlattenOp::CheckShape() const {
return true;
}
bool FlattenOp::InferShape() const {
bool FlattenOp::InferShapeImpl() const {
auto x_dims = param_.x->dims();
auto out_lod = param_.output->mutable_lod();
......@@ -71,8 +71,8 @@ bool Flatten2Op::CheckShape() const {
return true;
}
bool Flatten2Op::InferShape() const {
FlattenOp::InferShape();
bool Flatten2Op::InferShapeImpl() const {
FlattenOp::InferShapeImpl();
auto x_dims = param_.x->dims();
std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 0);
for (size_t i = 0; i < x_dims.size(); i++) {
......
......@@ -30,7 +30,7 @@ class FlattenOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......@@ -49,7 +49,7 @@ class Flatten2Op : public FlattenOp {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -27,7 +27,7 @@ bool FusionElementwiseActivationOp::CheckShape() const {
return true;
}
bool FusionElementwiseActivationOp::InferShape() const {
bool FusionElementwiseActivationOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.X->dims().size() >= param_.Y->dims().size());
param_.Out->Resize(param_.X->dims());
return true;
......@@ -59,7 +59,7 @@ bool FusionElementwiseActivationOp::AttachImpl(const cpp::OpDesc& opdesc,
// return true;
// }
// bool FusionElementwiseActivationGradExplicitOp::InferShape() const {
// bool FusionElementwiseActivationGradExplicitOp::InferShapeImpl() const {
// param_.X_grad->Resize(param_.Out_grad->dims());
// param_.Y_grad->Resize(param_.Y->dims());
// return true;
......
......@@ -29,7 +29,7 @@ class FusionElementwiseActivationOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
......@@ -51,7 +51,7 @@ class FusionElementwiseActivationOp : public OpLite {
// bool CheckShape() const override;
// bool InferShape() const override;
// bool InferShapeImpl() const override;
// bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
......
......@@ -26,7 +26,7 @@ bool GatherOp::CheckShape() const {
return true;
}
bool GatherOp::InferShape() const {
bool GatherOp::InferShapeImpl() const {
auto index_dims = param_.Index->dims();
CHECK(index_dims.size() == 1 ||
(index_dims.size() == 2 && index_dims[1] == 1))
......
......@@ -30,7 +30,7 @@ class GatherOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -43,7 +43,7 @@ bool GenerateProposalsOpLite::CheckShape() const {
return true;
}
bool GenerateProposalsOpLite::InferShape() const {
bool GenerateProposalsOpLite::InferShapeImpl() const {
param_.RpnRois->Resize(std::vector<int64_t>({-1, 4}));
param_.RpnRoiProbs->Resize(std::vector<int64_t>({-1, 1}));
return true;
......
......@@ -32,7 +32,7 @@ class GenerateProposalsOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -42,7 +42,7 @@ bool GridSamplerOp::CheckShape() const {
return true;
}
bool GridSamplerOp::InferShape() const {
bool GridSamplerOp::InferShapeImpl() const {
auto x_dims = param_.x->dims();
param_.out->Resize(x_dims);
return true;
......
......@@ -31,7 +31,7 @@ class GridSamplerOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -51,7 +51,7 @@ bool GRUOpLite::CheckShape() const {
return true;
}
bool GRUOpLite::InferShape() const {
bool GRUOpLite::InferShapeImpl() const {
const auto& input_dims = param_.input->dims();
const auto& weight_dims = param_.weight->dims();
int frame_size = weight_dims[0];
......
......@@ -30,7 +30,7 @@ class GRUOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -51,7 +51,7 @@ bool GRUUnitOpLite::CheckShape() const {
return true;
}
bool GRUUnitOpLite::InferShape() const {
bool GRUUnitOpLite::InferShapeImpl() const {
auto input_dims = param_.input->dims();
auto hidden_prev_dims = param_.hidden_prev->dims();
auto weight_dims = param_.weight->dims();
......
......@@ -30,7 +30,7 @@ class GRUUnitOpLite : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -26,7 +26,7 @@ inline int Im2SeqOutputSize(
}
bool Im2SequenceOp::CheckShape() const { return true; }
bool Im2SequenceOp::InferShape() const {
bool Im2SequenceOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.Out);
// TODO(Superjomn) Enable data sharing.
auto input_dims = param_.X->dims();
......
......@@ -30,7 +30,7 @@ class Im2SequenceOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -25,7 +25,7 @@ bool IncrementOp::CheckShape() const {
return true;
}
bool IncrementOp::InferShape() const {
bool IncrementOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.Out);
// TODO(Superjomn) Enable data sharing.
auto out_dims = param_.X->dims();
......
......@@ -30,7 +30,7 @@ class IncrementOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -42,7 +42,7 @@ bool InstanceNormOp::CheckShape() const {
return true;
}
bool InstanceNormOp::InferShape() const {
bool InstanceNormOp::InferShapeImpl() const {
auto x_dims = param_.x->dims();
int64_t batch_size = x_dims[0];
int64_t channel_size = x_dims[1];
......
......@@ -31,7 +31,7 @@ class InstanceNormOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
......@@ -34,7 +34,7 @@ bool InterpolateOp::CheckShape() const {
return true;
}
bool InterpolateOp::InferShape() const {
bool InterpolateOp::InferShapeImpl() const {
auto X = param_.X;
int n = X->dims()[0];
......
......@@ -31,7 +31,7 @@ class InterpolateOp : public OpLite {
bool CheckShape() const override;
bool InferShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册