diff --git a/mace/ops/deconv_2d.h b/mace/ops/deconv_2d.h index ec5b348e201f4048398ea2f3b8f69fca63c5337a..34f4739b95463b39a6a98792b74c01d08e4e78ff 100644 --- a/mace/ops/deconv_2d.h +++ b/mace/ops/deconv_2d.h @@ -16,6 +16,7 @@ #define MACE_OPS_DECONV_2D_H_ #include +#include #include "mace/core/operator.h" #include "mace/kernels/deconv_2d.h" @@ -34,8 +35,10 @@ class Deconv2dOp : public Operator { "padding", static_cast(SAME))), OperatorBase::GetRepeatedArgs("padding_values"), OperatorBase::GetRepeatedArgs("output_shape"), - kernels::ActivationType::NOOP, - 0.0f) {} + kernels::StringToActivationType( + OperatorBase::GetOptionalArg("activation", + "NOOP")), + OperatorBase::GetOptionalArg("max_limit", 0.0f)) {} MaceStatus Run(StatsFuture *future) override { const Tensor *input = this->Input(INPUT); diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index f6f0877de9b6ecda45a3117d2b22241ad0306203..a170de6e2bd2f64c9e529c41b22ec9f65c9f336c 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -403,6 +403,12 @@ class TensorflowConverter(base_converter.ConverterInterface): dilation_val = [1, 1] dilation_arg.ints.extend(dilation_val) else: + try: + dilation_val = tf_op.get_attr(tf_dilations_str)[1:3] + except ValueError: + dilation_val = [1, 1] + mace_check(dilation_val[0] == 1 and dilation_val[1] == 1, + "Mace only supports dilation == 1 conv2d_transpose.") mace_check(len(tf_op.inputs) >= 3, "deconv should have (>=) 3 inputs.") output_shape_arg = op.arg.add() diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 982ab5083c474ac0e6358841cdc49be4f3a4a701..9f3bd65473a499ad2d0265a596bbfe11fcddad0d 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -950,13 +950,13 @@ class Transformer(base_converter.ConverterInterface): return False def reshape_fc_weight(self): - print("Reshape fully connected weight shape") net = self._model filter_format = self.filter_format() for op in net.op: if op.type == MaceOp.FullyConnected.name: weight = self._consts[op.input[1]] if len(weight.dims) == 2: + print("Reshape fully connected weight shape") input_op = self._producer[op.input[0]] input_shape = list(input_op.output_shape[0].dims) weight.dims[:] = [weight.dims[0]] + input_shape[1:]