diff --git a/mace/core/types.h b/mace/core/types.h index aa1f9a89f911aaa076171bc8d7cc7322b767358c..5bdd4930c17aface45bf8859e3291e1d8464b228 100644 --- a/mace/core/types.h +++ b/mace/core/types.h @@ -64,6 +64,7 @@ MACE_MAPPING_DATA_TYPE_AND_ENUM(int32_t, DT_INT32); enum FrameworkType { TENSORFLOW = 0, CAFFE = 1, + ONNX = 2, }; template diff --git a/mace/ops/deconv_2d.cc b/mace/ops/deconv_2d.cc index 6453544ae92c75efc5560ef5f157dcbbfedb13d5..9eba1bd014c8c21e835a2b4d72664b6a19769951 100644 --- a/mace/ops/deconv_2d.cc +++ b/mace/ops/deconv_2d.cc @@ -66,12 +66,12 @@ class Deconv2dOp : public Deconv2dOpBase { const Tensor *filter = this->Input(1); const Tensor *bias = nullptr; const Tensor *output_shape_tensor = nullptr; - if (model_type_ == CAFFE) { - bias = this->InputSize() >= 3 ? this->Input(2) : nullptr; - } else { + if (model_type_ == TENSORFLOW) { output_shape_tensor = this->InputSize() >= 3 ? this->Input(2) : nullptr; bias = this->InputSize() >= 4 ? this->Input(3) : nullptr; + } else { + bias = this->InputSize() >= 3 ? this->Input(2) : nullptr; } Tensor *output = this->Output(0); @@ -182,13 +182,7 @@ class Deconv2dOp : public Deconv2dOpBase { context, operator_def_.get(), 1, OpenCLBufferType::CONV2D_FILTER, mem_type) == MaceStatus::MACE_SUCCESS); - if (model_type_ == FrameworkType::CAFFE) { - if (operator_def_->input_size() >= 3) { - MACE_CHECK(TransformFilter( - context, operator_def_.get(), 2, - OpenCLBufferType::ARGUMENT, mem_type) == MaceStatus::MACE_SUCCESS); - } - } else { + if (model_type_ == FrameworkType::TENSORFLOW) { if (operator_def_->input_size() >= 4) { MACE_CHECK(TransformFilter( context, @@ -197,6 +191,12 @@ class Deconv2dOp : public Deconv2dOpBase { OpenCLBufferType::ARGUMENT, mem_type) == MaceStatus::MACE_SUCCESS); } + } else { + if (operator_def_->input_size() >= 3) { + MACE_CHECK(TransformFilter( + context, operator_def_.get(), 2, + OpenCLBufferType::ARGUMENT, mem_type) == MaceStatus::MACE_SUCCESS); + } } } MaceStatus Run(OpContext *context) override { @@ -204,12 +204,12 @@ class Deconv2dOp : public Deconv2dOpBase { const Tensor *filter = this->Input(1); const Tensor *bias = nullptr; const Tensor *output_shape_tensor = nullptr; - if (model_type_ == CAFFE) { - bias = this->InputSize() >= 3 ? this->Input(2) : nullptr; - } else { + if (model_type_ == TENSORFLOW) { output_shape_tensor = this->InputSize() >= 3 ? this->Input(2) : nullptr; bias = this->InputSize() >= 4 ? this->Input(3) : nullptr; + } else { + bias = this->InputSize() >= 3 ? this->Input(2) : nullptr; } Tensor *output = this->Output(0);