提交 00df570e 编写于 作者: W wjj19950828

resolve conflict

...@@ -46,11 +46,10 @@ def make_grid(tensor: Union[paddle.Tensor, List[paddle.Tensor]], ...@@ -46,11 +46,10 @@ def make_grid(tensor: Union[paddle.Tensor, List[paddle.Tensor]],
if tensor.dim() == 2: # single image H x W if tensor.dim() == 2: # single image H x W
tensor = tensor.unsqueeze(0) tensor = tensor.unsqueeze(0)
if tensor.dim() == 3: # single image if tensor.dim() == 3: # single image
if tensor.size(0) == 1: # if single-channel, convert to 3-channel if tensor.shape[0] == 1: # if single-channel, convert to 3-channel
tensor = paddle.concat((tensor, tensor, tensor), 0) tensor = paddle.concat((tensor, tensor, tensor), 0)
tensor = tensor.unsqueeze(0) tensor = tensor.unsqueeze(0)
if tensor.dim() == 4 and tensor.shape[1] == 1: # single-channel images
if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images
tensor = paddle.concat((tensor, tensor, tensor), 1) tensor = paddle.concat((tensor, tensor, tensor), 1)
if normalize is True: if normalize is True:
...@@ -75,11 +74,11 @@ def make_grid(tensor: Union[paddle.Tensor, List[paddle.Tensor]], ...@@ -75,11 +74,11 @@ def make_grid(tensor: Union[paddle.Tensor, List[paddle.Tensor]],
else: else:
norm_range(tensor, value_range) norm_range(tensor, value_range)
if tensor.size(0) == 1: if tensor.shape[0] == 1:
return tensor.squeeze(0) return tensor.squeeze(0)
# make the mini-batch of images into a grid # make the mini-batch of images into a grid
nmaps = tensor.size(0) nmaps = tensor.shape[0]
xmaps = min(nrow, nmaps) xmaps = min(nrow, nmaps)
ymaps = int(math.ceil(float(nmaps) / xmaps)) ymaps = int(math.ceil(float(nmaps) / xmaps))
height, width = int(tensor.shape[2] + padding), int(tensor.shape[3] + height, width = int(tensor.shape[2] + padding), int(tensor.shape[3] +
......
__version__ = "1.3.7" __version__ = "1.3.8"
from .core.program import PaddleGraph from .core.program import PaddleGraph
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -13,7 +13,15 @@ ...@@ -13,7 +13,15 @@
# limitations under the License. # limitations under the License.
import sys import sys
from x2paddle.op_mapper.onnx2paddle.opset import OpSet7, OpSet8, OpSet9, OpSet10, OpSet11, OpSet12, OpSet13, OpSet14, OpSet15 from .opset7 import OpSet7
from .opset8 import OpSet8
from .opset9 import OpSet9
from .opset10 import OpSet10
from .opset11 import OpSet11
from .opset12 import OpSet12
from .opset13 import OpSet13
from .opset14 import OpSet14
from .opset15 import OpSet15
from x2paddle.decoder.onnx_decoder import ONNXGraphNode from x2paddle.decoder.onnx_decoder import ONNXGraphNode
from x2paddle.core.program import PaddleGraph from x2paddle.core.program import PaddleGraph
......
from .opset7 import OpSet7
from .opset8 import OpSet8
from .opset9 import OpSet9
from .opset10 import OpSet10
from .opset11 import OpSet11
from .opset12 import OpSet12
from .opset13 import OpSet13
from .opset14 import OpSet14
from .opset15 import OpSet15
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .opset import OpSet from .opset_legacy import OpSet
class OpSet10(OpSet): class OpSet10(OpSet):
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .opset import OpSet from .opset_legacy import OpSet
class OpSet11(OpSet): class OpSet11(OpSet):
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .opset import OpSet from .opset_legacy import OpSet
class OpSet12(OpSet): class OpSet12(OpSet):
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .opset import OpSet from .opset_legacy import OpSet
class OpSet13(OpSet): class OpSet13(OpSet):
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .opset import OpSet from .opset_legacy import OpSet
class OpSet14(OpSet): class OpSet14(OpSet):
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .opset import OpSet from .opset_legacy import OpSet
class OpSet15(OpSet): class OpSet15(OpSet):
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .opset import OpSet from .opset_legacy import OpSet
class OpSet7(OpSet): class OpSet7(OpSet):
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .opset import OpSet from .opset_legacy import OpSet
class OpSet8(OpSet): class OpSet8(OpSet):
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .opset import OpSet from .opset_legacy import OpSet
class OpSet9(OpSet): class OpSet9(OpSet):
......
...@@ -72,9 +72,14 @@ def prim_add_(layer, ...@@ -72,9 +72,14 @@ def prim_add_(layer,
forward_func=[], forward_func=[],
layer_id=None, layer_id=None,
different_attrs=None): different_attrs=None):
line = "{} = {} + {} * {}".format(layer.outputs[0], if abs(layer.attrs["alpha"] - 1.) < 1e-6:
line = "{} = {} + {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
layer.attrs["alpha"], get_value(layer, "y", different_attrs))
else:
line = "{} = {} + {} * {}".format(
layer.outputs[0],
get_value(layer, "x", different_attrs), layer.attrs["alpha"],
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
......
...@@ -169,6 +169,9 @@ pd_reshape = partial(paddle.Tensor.reshape) ...@@ -169,6 +169,9 @@ pd_reshape = partial(paddle.Tensor.reshape)
@add_tensor_function @add_tensor_function
def reshape(self, *shape): def reshape(self, *shape):
# deal with list or tuple type
if isinstance(shape, (list, tuple)):
shape = shape[0]
return pd_reshape(self, shape) return pd_reshape(self, shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册