提交 00df570e 编写于 作者: W wjj19950828

resolve conflict

......@@ -46,11 +46,10 @@ def make_grid(tensor: Union[paddle.Tensor, List[paddle.Tensor]],
if tensor.dim() == 2: # single image H x W
tensor = tensor.unsqueeze(0)
if tensor.dim() == 3: # single image
if tensor.size(0) == 1: # if single-channel, convert to 3-channel
if tensor.shape[0] == 1: # if single-channel, convert to 3-channel
tensor = paddle.concat((tensor, tensor, tensor), 0)
tensor = tensor.unsqueeze(0)
if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images
if tensor.dim() == 4 and tensor.shape[1] == 1: # single-channel images
tensor = paddle.concat((tensor, tensor, tensor), 1)
if normalize is True:
......@@ -75,11 +74,11 @@ def make_grid(tensor: Union[paddle.Tensor, List[paddle.Tensor]],
else:
norm_range(tensor, value_range)
if tensor.size(0) == 1:
if tensor.shape[0] == 1:
return tensor.squeeze(0)
# make the mini-batch of images into a grid
nmaps = tensor.size(0)
nmaps = tensor.shape[0]
xmaps = min(nrow, nmaps)
ymaps = int(math.ceil(float(nmaps) / xmaps))
height, width = int(tensor.shape[2] + padding), int(tensor.shape[3] +
......
__version__ = "1.3.7"
__version__ = "1.3.8"
from .core.program import PaddleGraph
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
......@@ -13,7 +13,15 @@
# limitations under the License.
import sys
from x2paddle.op_mapper.onnx2paddle.opset import OpSet7, OpSet8, OpSet9, OpSet10, OpSet11, OpSet12, OpSet13, OpSet14, OpSet15
from .opset7 import OpSet7
from .opset8 import OpSet8
from .opset9 import OpSet9
from .opset10 import OpSet10
from .opset11 import OpSet11
from .opset12 import OpSet12
from .opset13 import OpSet13
from .opset14 import OpSet14
from .opset15 import OpSet15
from x2paddle.decoder.onnx_decoder import ONNXGraphNode
from x2paddle.core.program import PaddleGraph
......
from .opset7 import OpSet7
from .opset8 import OpSet8
from .opset9 import OpSet9
from .opset10 import OpSet10
from .opset11 import OpSet11
from .opset12 import OpSet12
from .opset13 import OpSet13
from .opset14 import OpSet14
from .opset15 import OpSet15
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .opset import OpSet
from .opset_legacy import OpSet
class OpSet10(OpSet):
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .opset import OpSet
from .opset_legacy import OpSet
class OpSet11(OpSet):
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .opset import OpSet
from .opset_legacy import OpSet
class OpSet12(OpSet):
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .opset import OpSet
from .opset_legacy import OpSet
class OpSet13(OpSet):
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .opset import OpSet
from .opset_legacy import OpSet
class OpSet14(OpSet):
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .opset import OpSet
from .opset_legacy import OpSet
class OpSet15(OpSet):
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .opset import OpSet
from .opset_legacy import OpSet
class OpSet7(OpSet):
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .opset import OpSet
from .opset_legacy import OpSet
class OpSet8(OpSet):
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .opset import OpSet
from .opset_legacy import OpSet
class OpSet9(OpSet):
......
......@@ -72,10 +72,15 @@ def prim_add_(layer,
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {} + {} * {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs),
layer.attrs["alpha"],
get_value(layer, "y", different_attrs))
if abs(layer.attrs["alpha"] - 1.) < 1e-6:
line = "{} = {} + {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs))
else:
line = "{} = {} + {} * {}".format(
layer.outputs[0],
get_value(layer, "x", different_attrs), layer.attrs["alpha"],
get_value(layer, "y", different_attrs))
forward_func.extend(gen_codes([line], indent=indent))
......
......@@ -169,6 +169,9 @@ pd_reshape = partial(paddle.Tensor.reshape)
@add_tensor_function
def reshape(self, *shape):
# deal with list or tuple type
if isinstance(shape, (list, tuple)):
shape = shape[0]
return pd_reshape(self, shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册