提交 d8bdd9bf 编写于 作者: L liutuo

fix onnx deconv

上级 708aa21e
......@@ -64,6 +64,7 @@ MACE_MAPPING_DATA_TYPE_AND_ENUM(int32_t, DT_INT32);
enum FrameworkType {
TENSORFLOW = 0,
CAFFE = 1,
ONNX = 2,
};
template <typename T>
......
......@@ -66,12 +66,12 @@ class Deconv2dOp<DeviceType::CPU, float> : public Deconv2dOpBase {
const Tensor *filter = this->Input(1);
const Tensor *bias = nullptr;
const Tensor *output_shape_tensor = nullptr;
if (model_type_ == CAFFE) {
bias = this->InputSize() >= 3 ? this->Input(2) : nullptr;
} else {
if (model_type_ == TENSORFLOW) {
output_shape_tensor =
this->InputSize() >= 3 ? this->Input(2) : nullptr;
bias = this->InputSize() >= 4 ? this->Input(3) : nullptr;
} else {
bias = this->InputSize() >= 3 ? this->Input(2) : nullptr;
}
Tensor *output = this->Output(0);
......@@ -182,13 +182,7 @@ class Deconv2dOp<DeviceType::GPU, float> : public Deconv2dOpBase {
context, operator_def_.get(), 1,
OpenCLBufferType::CONV2D_FILTER, mem_type)
== MaceStatus::MACE_SUCCESS);
if (model_type_ == FrameworkType::CAFFE) {
if (operator_def_->input_size() >= 3) {
MACE_CHECK(TransformFilter(
context, operator_def_.get(), 2,
OpenCLBufferType::ARGUMENT, mem_type) == MaceStatus::MACE_SUCCESS);
}
} else {
if (model_type_ == FrameworkType::TENSORFLOW) {
if (operator_def_->input_size() >= 4) {
MACE_CHECK(TransformFilter(
context,
......@@ -197,6 +191,12 @@ class Deconv2dOp<DeviceType::GPU, float> : public Deconv2dOpBase {
OpenCLBufferType::ARGUMENT,
mem_type) == MaceStatus::MACE_SUCCESS);
}
} else {
if (operator_def_->input_size() >= 3) {
MACE_CHECK(TransformFilter(
context, operator_def_.get(), 2,
OpenCLBufferType::ARGUMENT, mem_type) == MaceStatus::MACE_SUCCESS);
}
}
}
MaceStatus Run(OpContext *context) override {
......@@ -204,12 +204,12 @@ class Deconv2dOp<DeviceType::GPU, float> : public Deconv2dOpBase {
const Tensor *filter = this->Input(1);
const Tensor *bias = nullptr;
const Tensor *output_shape_tensor = nullptr;
if (model_type_ == CAFFE) {
bias = this->InputSize() >= 3 ? this->Input(2) : nullptr;
} else {
if (model_type_ == TENSORFLOW) {
output_shape_tensor =
this->InputSize() >= 3 ? this->Input(2) : nullptr;
bias = this->InputSize() >= 4 ? this->Input(3) : nullptr;
} else {
bias = this->InputSize() >= 3 ? this->Input(2) : nullptr;
}
Tensor *output = this->Output(0);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册