提交 e540cbaa 编写于 作者: L liutuo

fix scaler eltwise op tranform bug

上级 2b1defc7
...@@ -30,7 +30,7 @@ class EltwiseOp : public Operator<D, T> { ...@@ -30,7 +30,7 @@ class EltwiseOp : public Operator<D, T> {
static_cast<kernels::EltwiseType>(OperatorBase::GetOptionalArg<int>( static_cast<kernels::EltwiseType>(OperatorBase::GetOptionalArg<int>(
"type", static_cast<int>(kernels::EltwiseType::NONE))), "type", static_cast<int>(kernels::EltwiseType::NONE))),
OperatorBase::GetRepeatedArgs<float>("coeff"), OperatorBase::GetRepeatedArgs<float>("coeff"),
OperatorBase::GetOptionalArg<float>("x", 1.0)) {} OperatorBase::GetOptionalArg<float>("value", 1.0)) {}
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
const Tensor *input0 = this->Input(0); const Tensor *input0 = this->Input(0);
......
...@@ -40,7 +40,7 @@ void SimpleTensorScalar(const kernels::EltwiseType type, ...@@ -40,7 +40,7 @@ void SimpleTensorScalar(const kernels::EltwiseType type,
OpDefBuilder("Eltwise", "EltwiseTest") OpDefBuilder("Eltwise", "EltwiseTest")
.Input("TInput") .Input("TInput")
.AddIntArg("type", static_cast<int>(type)) .AddIntArg("type", static_cast<int>(type))
.AddFloatArg("x", x) .AddFloatArg("value", x)
.Output("TOutput") .Output("TOutput")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
...@@ -52,7 +52,7 @@ void SimpleTensorScalar(const kernels::EltwiseType type, ...@@ -52,7 +52,7 @@ void SimpleTensorScalar(const kernels::EltwiseType type,
OpDefBuilder("Eltwise", "EltwiseTest") OpDefBuilder("Eltwise", "EltwiseTest")
.Input("InputImg") .Input("InputImg")
.AddIntArg("type", static_cast<int>(type)) .AddIntArg("type", static_cast<int>(type))
.AddFloatArg("x", x) .AddFloatArg("value", x)
.Output("OutputImg") .Output("OutputImg")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
...@@ -321,7 +321,7 @@ void RandomTensorScalar(const kernels::EltwiseType type, ...@@ -321,7 +321,7 @@ void RandomTensorScalar(const kernels::EltwiseType type,
OpDefBuilder("Eltwise", "EltwiseTest") OpDefBuilder("Eltwise", "EltwiseTest")
.Input("TInput") .Input("TInput")
.AddIntArg("type", static_cast<int>(type)) .AddIntArg("type", static_cast<int>(type))
.AddFloatArg("x", 0.1) .AddFloatArg("value", 0.1)
.Output("TOutput") .Output("TOutput")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
...@@ -336,7 +336,7 @@ void RandomTensorScalar(const kernels::EltwiseType type, ...@@ -336,7 +336,7 @@ void RandomTensorScalar(const kernels::EltwiseType type,
OpDefBuilder("Eltwise", "EltwiseTest") OpDefBuilder("Eltwise", "EltwiseTest")
.Input("InputImg") .Input("InputImg")
.AddIntArg("type", static_cast<int>(type)) .AddIntArg("type", static_cast<int>(type))
.AddFloatArg("x", 0.1) .AddFloatArg("value", 0.1)
.Output("OutputImg") .Output("OutputImg")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
......
...@@ -139,6 +139,7 @@ class MaceKeyword(object): ...@@ -139,6 +139,7 @@ class MaceKeyword(object):
mace_shape_str = 'shape' mace_shape_str = 'shape'
mace_winograd_filter_transformed = 'is_filter_transformed' mace_winograd_filter_transformed = 'is_filter_transformed'
mace_device = 'device' mace_device = 'device'
mace_value_str = 'value'
class TransformerRule(Enum): class TransformerRule(Enum):
......
...@@ -309,6 +309,19 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -309,6 +309,19 @@ class TensorflowConverter(base_converter.ConverterInterface):
type_arg.name = MaceKeyword.mace_element_type_str type_arg.name = MaceKeyword.mace_element_type_str
type_arg.i = self.eltwise_type[tf_op.type].value type_arg.i = self.eltwise_type[tf_op.type].value
if len(tf_op.inputs[0].shape) == 0:
value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_value_str
value_arg.f = tf_op.inputs[0].eval().astype(np.float32)
self._skip_tensor.add(tf_op.inputs[0].name)
del op.input[0]
elif len(tf_op.inputs[1].shape) == 0:
value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_value_str
value_arg.f = tf_op.inputs[1].eval().astype(np.float32)
self._skip_tensor.add(tf_op.inputs[1].name)
del op.input[1]
def convert_biasadd(self, tf_op): def convert_biasadd(self, tf_op):
op = self.convert_general_op(tf_op) op = self.convert_general_op(tf_op)
op.type = MaceOp.BiasAdd.name op.type = MaceOp.BiasAdd.name
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册