diff --git a/x2paddle/op_mapper/onnx_op_mapper.py b/x2paddle/op_mapper/onnx_op_mapper.py index ea670d7e815ea4f8566881aab115408526e443c8..2c9d48be5cdb69d9fbebce3936c2a88cae8bcc9f 100644 --- a/x2paddle/op_mapper/onnx_op_mapper.py +++ b/x2paddle/op_mapper/onnx_op_mapper.py @@ -1103,7 +1103,7 @@ class ONNXOpMapper(OpMapper): val_x = self.graph.get_input_node(node, idx=0, copy=True) where_name = node.layer_name + '_where' node.fluid_code.add_layer("where", - inputs=val_x.layer_name + '==1', + inputs=val_x.layer_name + '!=0', output=where_name) dims = len(val_x.out_shapes[0]) elements_count_val_x = reduce(lambda x, y: x * y, val_x.out_shapes[0])