提交 c2f3a5a3 编写于 作者: H huzhiqiang 提交者: GitHub

[Framework][InferShape] accelerate op infer_shape period (#3434)

上级 44c4c68d
......@@ -25,16 +25,16 @@ namespace lite {
bool OpLite::InferShape() {
// if input_tensor_ptrs and output_tensor_ptrs are overloaded in param_
// InferShapeByMemoryInternal will be applied.
if (param_.input_tensor_ptrs() && param_.output_tensor_ptrs()) {
if (op_param_ && op_param_->input_tensor_ptrs() &&
op_param_->output_tensor_ptrs()) {
return this->InferShapeWithCache();
} else {
// otherwise, InferShapeImpl is applied directly.
return this->InferShapeImpl();
}
}
bool OpLite::InferShapeWithCache() {
// 1. Get vector of current input tensors
auto *current_inputs = param_.input_tensor_ptrs();
auto *current_inputs = op_param_->input_tensor_ptrs();
// 2. Get hash value of current inputs shape and lod
size_t new_hash = 0;
for (auto iter = current_inputs->begin(); iter != current_inputs->end();
......@@ -59,7 +59,7 @@ bool OpLite::InferShapeWithCache() {
if (new_hash == io_shape_lod_hash_ && new_hash != 0) {
// if current hash value is consistent with io_shape_lod_hash_,
// previous outputs shape and lod are reused.
auto *current_outputs = param_.output_tensor_ptrs();
auto *current_outputs = op_param_->output_tensor_ptrs();
for (size_t i = 0; i < current_outputs->size(); i++) {
current_outputs->at(i)->Resize(last_output_shapes[i]);
current_outputs->at(i)->set_lod(last_output_lods[i]);
......@@ -68,10 +68,12 @@ bool OpLite::InferShapeWithCache() {
// otherwise, current hash value is changed, InferShapeImpl will apply.
io_shape_lod_hash_ = new_hash;
this->InferShapeImpl();
auto *current_outputs = param_.output_tensor_ptrs();
auto *current_outputs = op_param_->output_tensor_ptrs();
last_output_shapes.clear();
last_output_lods.clear();
for (size_t i = 0; i < current_outputs->size(); i++) {
last_output_shapes[i] = current_outputs->at(i)->dims();
last_output_lods[i] = current_outputs->at(i)->lod();
last_output_shapes.push_back(current_outputs->at(i)->dims());
last_output_lods.push_back(current_outputs->at(i)->lod());
}
}
return true;
......
......@@ -77,6 +77,11 @@ class OpLite : public Registry {
// Link the external execution environ to internal context.
bool Attach(const cpp::OpDesc &opdesc, lite::Scope *scope);
template <typename T>
inline void AttachParam(T *param) {
op_param_ = static_cast<T *>(param);
}
const OpInfo *op_info() const { return op_info_.get(); }
OpInfo *mutable_op_info() { return op_info_.get(); }
......@@ -167,11 +172,10 @@ class OpLite : public Registry {
std::vector<Place> valid_places_;
Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
std::unique_ptr<OpInfo> op_info_;
std::vector<DDimLite> last_output_shapes{};
std::vector<std::vector<std::vector<uint64_t>>> last_output_lods{};
size_t io_shape_lod_hash_{};
mutable operators::ParamBase param_;
mutable operators::ParamBase *op_param_{nullptr};
private:
// Infer Shape according to memory, if current input shapes are consistent
......
......@@ -73,6 +73,7 @@ bool BatchNormOp::InferShapeImpl() const {
}
bool BatchNormOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
AttachParam(&param_);
param_.x = scope->FindVar(op_desc.Input("X").front())->GetMutable<Tensor>();
param_.bias =
scope->FindVar(op_desc.Input("Bias").front())->GetMutable<Tensor>();
......
......@@ -66,6 +66,7 @@ bool ConcatOpLite::InferShapeImpl() const {
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool ConcatOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
AttachParam(&param_);
auto inputs = op_desc.Input("X");
auto out = op_desc.Output("Out").front();
......
......@@ -38,6 +38,7 @@ class ConvOpLite : public OpLite {
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override {
AttachParam(&param_);
auto X = op_desc.Input("Input").front();
auto Filter = op_desc.Input("Filter").front();
auto Out = op_desc.Output("Output").front();
......
......@@ -87,6 +87,8 @@ bool ElementwiseOp::InferShapeImpl() const {
}
bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
AttachParam(&param_);
auto X_name = opdesc.Input("X").front();
auto Y_name = opdesc.Input("Y").front();
auto Out_name = opdesc.Output("Out").front();
......
......@@ -69,6 +69,8 @@ bool FcOpLite::InferShapeImpl() const {
}
bool FcOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
AttachParam(&param_);
auto input = op_desc.Input("Input").front();
auto W = op_desc.Input("W").front();
auto out = op_desc.Output("Out").front();
......
......@@ -132,6 +132,7 @@ bool MatMulOpLite::InferShapeImpl() const {
}
bool MatMulOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
AttachParam(&param_);
CHECK(!op_desc.Input("X").empty());
CHECK(!op_desc.Input("Y").empty());
CHECK(!op_desc.Output("Out").empty());
......
......@@ -38,6 +38,8 @@ class MulOpLite : public OpLite {
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
AttachParam(&param_);
CHECK(!op_desc.Input("X").empty());
CHECK(!op_desc.Input("Y").empty());
CHECK(!op_desc.Output("Out").empty());
......@@ -56,7 +58,6 @@ class MulOpLite : public OpLite {
param_.output = var->GetMutable<Tensor>();
param_.x_num_col_dims = op_desc.GetAttr<int>("x_num_col_dims");
param_.y_num_col_dims = op_desc.GetAttr<int>("y_num_col_dims");
return true;
}
......
......@@ -35,8 +35,11 @@ namespace operators {
struct ParamBase {
public:
const std::vector<Tensor*>* input_tensor_ptrs() const { return nullptr; }
std::vector<Tensor*>* output_tensor_ptrs() { return nullptr; }
virtual ~ParamBase() {}
virtual const std::vector<const Tensor*>* input_tensor_ptrs() {
return nullptr;
}
virtual std::vector<Tensor*>* output_tensor_ptrs() { return nullptr; }
protected:
std::shared_ptr<std::vector<const Tensor*>> input_tensor_ptrs_cache_{nullptr};
......@@ -108,15 +111,15 @@ struct FcParam : ParamBase {
WITH_INT8_CONFIG
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(input_tensor_ptrs_cache_)) {
const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({input}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(output_tensor_ptrs_cache_)) {
std::vector<Tensor*>* output_tensor_ptrs() override {
if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
}
return output_tensor_ptrs_cache_.get();
......@@ -160,15 +163,15 @@ struct MulParam : ParamBase {
WITH_INT8_CONFIG
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(input_tensor_ptrs_cache_)) {
const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x, y}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(output_tensor_ptrs_cache_)) {
std::vector<Tensor*>* output_tensor_ptrs() override {
if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
}
return output_tensor_ptrs_cache_.get();
......@@ -243,15 +246,15 @@ struct ScaleParam : ParamBase {
bool bias_after_scale{true};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(input_tensor_ptrs_cache_)) {
const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(output_tensor_ptrs_cache_)) {
std::vector<Tensor*>* output_tensor_ptrs() override {
if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
}
return output_tensor_ptrs_cache_.get();
......@@ -265,15 +268,15 @@ struct SoftmaxParam : ParamBase {
int axis{-1};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(input_tensor_ptrs_cache_)) {
const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(output_tensor_ptrs_cache_)) {
std::vector<Tensor*>* output_tensor_ptrs() override {
if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
}
return output_tensor_ptrs_cache_.get();
......@@ -292,15 +295,15 @@ struct ReshapeParam : ParamBase {
bool inplace{false};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(input_tensor_ptrs_cache_)) {
const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(output_tensor_ptrs_cache_)) {
std::vector<Tensor*>* output_tensor_ptrs() override {
if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
}
return output_tensor_ptrs_cache_.get();
......@@ -314,8 +317,8 @@ struct ConcatParam : ParamBase {
int axis{0};
lite::Tensor* axis_tensor{};
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(input_tensor_ptrs_cache_)) {
const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (!input_tensor_ptrs_cache_) {
std::vector<const Tensor*> vec;
for (auto in : x) {
vec.push_back(in);
......@@ -325,8 +328,8 @@ struct ConcatParam : ParamBase {
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(output_tensor_ptrs_cache_)) {
std::vector<Tensor*>* output_tensor_ptrs() override {
if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
}
return output_tensor_ptrs_cache_.get();
......@@ -406,15 +409,15 @@ struct ConvParam : ParamBase {
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(input_tensor_ptrs_cache_)) {
const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(output_tensor_ptrs_cache_)) {
std::vector<Tensor*>* output_tensor_ptrs() override {
if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
}
return output_tensor_ptrs_cache_.get();
......@@ -440,15 +443,15 @@ struct BatchNormParam : ParamBase {
DataLayoutType data_layout{DATALAYOUT(kNCHW)};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(input_tensor_ptrs_cache_)) {
const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(output_tensor_ptrs_cache_)) {
std::vector<Tensor*>* output_tensor_ptrs() override {
if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({y}));
}
return output_tensor_ptrs_cache_.get();
......@@ -479,15 +482,15 @@ struct PoolParam : ParamBase {
WITH_INT8_CONFIG
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(input_tensor_ptrs_cache_)) {
const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(output_tensor_ptrs_cache_)) {
std::vector<Tensor*>* output_tensor_ptrs() override {
if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
}
return output_tensor_ptrs_cache_.get();
......@@ -518,15 +521,15 @@ struct SplitParam : ParamBase {
std::vector<int> sections;
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(input_tensor_ptrs_cache_)) {
const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(output_tensor_ptrs_cache_)) {
std::vector<Tensor*>* output_tensor_ptrs() override {
if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
}
return output_tensor_ptrs_cache_.get();
......@@ -544,15 +547,15 @@ struct TransposeParam : ParamBase {
std::string data_format{"AnyLayout"};
///////////////////////////////////////////////////////////////////////////////////
// // get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(input_tensor_ptrs_cache_)) {
const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({x}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(output_tensor_ptrs_cache_)) {
std::vector<Tensor*>* output_tensor_ptrs() override {
if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({output}));
}
return output_tensor_ptrs_cache_.get();
......@@ -571,15 +574,15 @@ struct ElementwiseParam : ParamBase {
float y_input_scale{1.0};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(input_tensor_ptrs_cache_)) {
const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X, Y}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(output_tensor_ptrs_cache_)) {
std::vector<Tensor*>* output_tensor_ptrs() override {
if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
}
return output_tensor_ptrs_cache_.get();
......@@ -884,15 +887,15 @@ struct SequenceSoftmaxParam : ParamBase {
lite::Tensor* Out{};
///////////////////////////////////////////////////////////////////////////////////
// // get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(input_tensor_ptrs_cache_)) {
const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(output_tensor_ptrs_cache_)) {
std::vector<Tensor*>* output_tensor_ptrs() override {
if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
}
return output_tensor_ptrs_cache_.get();
......@@ -1135,15 +1138,15 @@ struct SliceParam : ParamBase {
lite::Tensor* EndsTensor{nullptr};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(input_tensor_ptrs_cache_)) {
const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(output_tensor_ptrs_cache_)) {
std::vector<Tensor*>* output_tensor_ptrs() override {
if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
}
return output_tensor_ptrs_cache_.get();
......@@ -1197,15 +1200,15 @@ struct SqueezeParam : ParamBase {
std::vector<int> axes{};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(input_tensor_ptrs_cache_)) {
const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(output_tensor_ptrs_cache_)) {
std::vector<Tensor*>* output_tensor_ptrs() override {
if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
}
return output_tensor_ptrs_cache_.get();
......@@ -1221,15 +1224,15 @@ struct UnsqueezeParam : ParamBase {
std::vector<const lite::Tensor*> axes_tensor_vct{};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(input_tensor_ptrs_cache_)) {
const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(output_tensor_ptrs_cache_)) {
std::vector<Tensor*>* output_tensor_ptrs() override {
if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
}
return output_tensor_ptrs_cache_.get();
......@@ -1253,15 +1256,15 @@ struct MatMulParam : ParamBase {
float alpha{1.0f};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() {
if (UNLIKELY(input_tensor_ptrs_cache_)) {
const std::vector<const Tensor*>* input_tensor_ptrs() override {
if (!input_tensor_ptrs_cache_) {
input_tensor_ptrs_cache_.reset(new std::vector<const Tensor*>({X, Y}));
}
return input_tensor_ptrs_cache_.get();
}
// get a vector of output tensors
const std::vector<Tensor*>* output_tensor_ptrs() {
if (UNLIKELY(output_tensor_ptrs_cache_)) {
std::vector<Tensor*>* output_tensor_ptrs() override {
if (!output_tensor_ptrs_cache_) {
output_tensor_ptrs_cache_.reset(new std::vector<lite::Tensor*>({Out}));
}
return output_tensor_ptrs_cache_.get();
......
......@@ -41,6 +41,7 @@ class PoolOpLite : public OpLite {
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
AttachParam(&param_);
auto x = op_desc.Input("X").front();
auto out = op_desc.Output("Out").front();
......
......@@ -56,6 +56,7 @@ bool ReshapeOp::InferShapeImpl() const {
}
bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
AttachParam(&param_);
param_.x =
scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
param_.output =
......
......@@ -30,6 +30,7 @@ bool ScaleOp::InferShapeImpl() const {
}
bool ScaleOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
AttachParam(&param_);
auto x = op_desc.Input("X").front();
auto output = op_desc.Output("Out").front();
param_.x = scope->FindVar(x)->GetMutable<Tensor>();
......
......@@ -34,6 +34,7 @@ bool SequenceSoftmaxOp::InferShapeImpl() const {
bool SequenceSoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc,
lite::Scope *scope) {
AttachParam(&param_);
param_.X =
scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
param_.Out =
......
......@@ -87,6 +87,7 @@ bool SliceOp::InferShapeImpl() const {
}
bool SliceOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
AttachParam(&param_);
param_.X =
scope->FindVar(opdesc.Input("Input").front())->GetMutable<lite::Tensor>();
param_.Out =
......
......@@ -38,6 +38,8 @@ bool SoftmaxOp::InferShapeImpl() const {
}
bool SoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
AttachParam(&param_);
param_.x = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
param_.output =
......
......@@ -75,6 +75,7 @@ bool SplitOp::InferShapeImpl() const {
}
bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
AttachParam(&param_);
param_.axis = opdesc.GetAttr<int>("axis");
param_.num = opdesc.GetAttr<int>("num");
param_.sections = opdesc.GetAttr<std::vector<int>>("sections");
......
......@@ -84,6 +84,7 @@ bool SqueezeOp::InferShapeImpl() const {
}
bool SqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
AttachParam(&param_);
auto x_var = scope->FindVar(opdesc.Input("X").front());
auto output_var = scope->FindVar(opdesc.Output("Out").front());
CHECK(x_var);
......
......@@ -70,6 +70,7 @@ bool TransposeOp::InferShapeImpl() const {
}
bool TransposeOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
AttachParam(&param_);
auto x = op_desc.Input("X").front();
auto out = op_desc.Output("Out").front();
......
......@@ -89,6 +89,7 @@ bool UnsqueezeOp::InferShapeImpl() const {
}
bool UnsqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
AttachParam(&param_);
auto x_var = scope->FindVar(opdesc.Input("X").front());
auto output_var = scope->FindVar(opdesc.Output("Out").front());
CHECK(x_var);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册