提交 3e8e6647 编写于 作者: L liutuo

fix deconv fold activation

上级 c2523ee2
......@@ -16,6 +16,7 @@
#define MACE_OPS_DECONV_2D_H_
#include <memory>
#include <string>
#include "mace/core/operator.h"
#include "mace/kernels/deconv_2d.h"
......@@ -34,8 +35,10 @@ class Deconv2dOp : public Operator<D, T> {
"padding", static_cast<int>(SAME))),
OperatorBase::GetRepeatedArgs<int>("padding_values"),
OperatorBase::GetRepeatedArgs<index_t>("output_shape"),
kernels::ActivationType::NOOP,
0.0f) {}
kernels::StringToActivationType(
OperatorBase::GetOptionalArg<std::string>("activation",
"NOOP")),
OperatorBase::GetOptionalArg<float>("max_limit", 0.0f)) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
......
......@@ -403,6 +403,12 @@ class TensorflowConverter(base_converter.ConverterInterface):
dilation_val = [1, 1]
dilation_arg.ints.extend(dilation_val)
else:
try:
dilation_val = tf_op.get_attr(tf_dilations_str)[1:3]
except ValueError:
dilation_val = [1, 1]
mace_check(dilation_val[0] == 1 and dilation_val[1] == 1,
"Mace only supports dilation == 1 conv2d_transpose.")
mace_check(len(tf_op.inputs) >= 3,
"deconv should have (>=) 3 inputs.")
output_shape_arg = op.arg.add()
......
......@@ -950,13 +950,13 @@ class Transformer(base_converter.ConverterInterface):
return False
def reshape_fc_weight(self):
print("Reshape fully connected weight shape")
net = self._model
filter_format = self.filter_format()
for op in net.op:
if op.type == MaceOp.FullyConnected.name:
weight = self._consts[op.input[1]]
if len(weight.dims) == 2:
print("Reshape fully connected weight shape")
input_op = self._producer[op.input[0]]
input_shape = list(input_op.output_shape[0].dims)
weight.dims[:] = [weight.dims[0]] + input_shape[1:]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册