From d8bdd9bf2c6aa4f2cf5cc1e10cd99adc085c93e8 Mon Sep 17 00:00:00 2001 From: liutuo Date: Wed, 14 Aug 2019 14:53:46 +0800 Subject: [PATCH] fix onnx deconv --- mace/core/types.h | 1 + mace/ops/deconv_2d.cc | 26 +++++++++++++------------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/mace/core/types.h b/mace/core/types.h index aa1f9a89..5bdd4930 100644 --- a/mace/core/types.h +++ b/mace/core/types.h @@ -64,6 +64,7 @@ MACE_MAPPING_DATA_TYPE_AND_ENUM(int32_t, DT_INT32); enum FrameworkType { TENSORFLOW = 0, CAFFE = 1, + ONNX = 2, }; template diff --git a/mace/ops/deconv_2d.cc b/mace/ops/deconv_2d.cc index 6453544a..9eba1bd0 100644 --- a/mace/ops/deconv_2d.cc +++ b/mace/ops/deconv_2d.cc @@ -66,12 +66,12 @@ class Deconv2dOp : public Deconv2dOpBase { const Tensor *filter = this->Input(1); const Tensor *bias = nullptr; const Tensor *output_shape_tensor = nullptr; - if (model_type_ == CAFFE) { - bias = this->InputSize() >= 3 ? this->Input(2) : nullptr; - } else { + if (model_type_ == TENSORFLOW) { output_shape_tensor = this->InputSize() >= 3 ? this->Input(2) : nullptr; bias = this->InputSize() >= 4 ? this->Input(3) : nullptr; + } else { + bias = this->InputSize() >= 3 ? this->Input(2) : nullptr; } Tensor *output = this->Output(0); @@ -182,13 +182,7 @@ class Deconv2dOp : public Deconv2dOpBase { context, operator_def_.get(), 1, OpenCLBufferType::CONV2D_FILTER, mem_type) == MaceStatus::MACE_SUCCESS); - if (model_type_ == FrameworkType::CAFFE) { - if (operator_def_->input_size() >= 3) { - MACE_CHECK(TransformFilter( - context, operator_def_.get(), 2, - OpenCLBufferType::ARGUMENT, mem_type) == MaceStatus::MACE_SUCCESS); - } - } else { + if (model_type_ == FrameworkType::TENSORFLOW) { if (operator_def_->input_size() >= 4) { MACE_CHECK(TransformFilter( context, @@ -197,6 +191,12 @@ class Deconv2dOp : public Deconv2dOpBase { OpenCLBufferType::ARGUMENT, mem_type) == MaceStatus::MACE_SUCCESS); } + } else { + if (operator_def_->input_size() >= 3) { + MACE_CHECK(TransformFilter( + context, operator_def_.get(), 2, + OpenCLBufferType::ARGUMENT, mem_type) == MaceStatus::MACE_SUCCESS); + } } } MaceStatus Run(OpContext *context) override { @@ -204,12 +204,12 @@ class Deconv2dOp : public Deconv2dOpBase { const Tensor *filter = this->Input(1); const Tensor *bias = nullptr; const Tensor *output_shape_tensor = nullptr; - if (model_type_ == CAFFE) { - bias = this->InputSize() >= 3 ? this->Input(2) : nullptr; - } else { + if (model_type_ == TENSORFLOW) { output_shape_tensor = this->InputSize() >= 3 ? this->Input(2) : nullptr; bias = this->InputSize() >= 4 ? this->Input(3) : nullptr; + } else { + bias = this->InputSize() >= 3 ? this->Input(2) : nullptr; } Tensor *output = this->Output(0); -- GitLab