Unverified · Commit 346689c6 · authored by LielinJiang · committed by GitHub

Register conv_transpose Op version for compatible Op upgrades (#26745)

* fix bug

* add version check

* fix docs, test=document_fix

* fix formula, test=document_fix
Parent 8bcb1f29
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/data_layout.h"
+#include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/platform/cudnn_workspace_helper.h"
 #ifdef PADDLE_WITH_MKLDNN
@@ -567,3 +568,14 @@ REGISTER_OP_CPU_KERNEL(
     ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
                                      double>);
+
+REGISTER_OP_VERSION(conv_transpose)
+    .AddCheckpoint(
+        R"ROC(
+      Upgrade conv_transpose, add a new attribute [output_padding].
+    )ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "output_padding",
+            "In order to add additional size to one side of each dimension "
+            "in the output",
+            {}));
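For context, the new `output_padding` attribute resolves an ambiguity in transposed convolution: with stride > 1, several input sizes produce the same output size under the forward convolution, so the transpose needs an extra term to select among the possible inverses. A minimal Python sketch (not part of this commit; the helper name is hypothetical):

    def conv_transpose_out_size(in_size, kernel, stride=1, padding=0,
                                dilation=1, output_padding=0):
        # Output-size formula used in the docstrings below, plus the new term.
        return ((in_size - 1) * stride - 2 * padding
                + dilation * (kernel - 1) + 1 + output_padding)

    # A forward conv with kernel=3, stride=2 maps both size-11 and size-12
    # inputs to size 5, so the transpose needs output_padding to choose:
    print(conv_transpose_out_size(5, kernel=3, stride=2))                    # 11
    print(conv_transpose_out_size(5, kernel=3, stride=2, output_padding=1))  # 12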
@@ -350,7 +350,7 @@ def conv2d(x,
     For each input :math:`X`, the equation is:
 
     .. math::
 
         Out = \sigma (W \\ast X + b)
@@ -377,7 +377,7 @@ def conv2d(x,
         Where
 
         .. math::
 
             H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1 \\\\
             W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
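As a quick check of the formula above, a hedged sketch (the helper is hypothetical; integer division is assumed to match the framework's floor behavior):

    def conv_out_size(in_size, kernel, stride=1, padding=0, dilation=1):
        # Mirrors the H_out/W_out expression in the docstring above.
        return (in_size + 2 * padding - (dilation * (kernel - 1) + 1)) // stride + 1

    print(conv_out_size(8, kernel=3))  # 6, matching the conv3d example shapes below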
@@ -807,10 +807,10 @@ def conv_transpose2d(x,
                      stride=1,
                      padding=0,
                      output_padding=0,
-                     groups=1,
                      dilation=1,
-                     data_format='NCHW',
+                     groups=1,
                      output_size=None,
+                     data_format='NCHW',
                      name=None):
     """
@@ -829,7 +829,7 @@ def conv_transpose2d(x,
     For each input :math:`X`, the equation is:
 
     .. math::
 
         Out = \sigma (W \\ast X + b)
@@ -856,7 +856,7 @@ def conv_transpose2d(x,
         Where
 
         .. math::
 
             H^\prime_{out} &= (H_{in} - 1) * strides[0] - pad_height_top - pad_height_bottom + dilations[0] * (H_f - 1) + 1 \\\\
             W^\prime_{out} &= (W_{in} - 1) * strides[1] - pad_width_left - pad_width_right + dilations[1] * (W_f - 1) + 1 \\\\
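The H'/W' values above are the minimum reachable output sizes; any requested output_size within one stride of them is achievable for that dimension (output_padding selects within the range). A small sketch (hypothetical helper; symmetric total padding assumed):

    def valid_output_range(in_size, kernel, stride=1, padding=0, dilation=1):
        # Inclusive range of output sizes reachable for one spatial dimension.
        h_prime = (in_size - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1
        return h_prime, h_prime + stride - 1

    print(valid_output_range(8, kernel=3, stride=2))  # (17, 18)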
@@ -883,28 +883,27 @@ def conv_transpose2d(x,
         stride(int|list|tuple, optional): The stride size. It means the stride in transposed convolution.
             If stride is a tuple, it must contain two integers, (stride_height, stride_width).
             Otherwise, stride_height = stride_width = stride. Default: stride = 1.
-        padding(int|list|str|tuple, optional): The padding size. The padding argument effectively adds
-            `dilation * (kernel - 1)` amount of zero-padding on both sides of input. If `padding` is a
-            string, either 'VALID' or 'SAME' supported, which is the padding algorithm.
-            If `padding` is a tuple or list, it could be in three forms:
-            `[pad_height, pad_width]` or
-            `[pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`, and
-            when `data_format` is `'NCHW'`,
-            `padding` can be in the form `[[0,0], [0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`.
-            when `data_format` is `'NHWC'`, `padding` can be in the form
-            `[[0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`.
-            Default: padding = 0.
+        padding(str|int|list|tuple, optional): The padding size. It means the number of zero-paddings
+            on both sides for each dimension. If `padding` is a string, either 'VALID' or
+            'SAME', which is the padding algorithm. If padding size is a tuple or list,
+            it could be in three forms: `[pad_height, pad_width]` or
+            `[pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`,
+            and when `data_format` is `"NCHW"`, `padding` can be in the form
+            `[[0,0], [0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`.
+            When `data_format` is `"NHWC"`, `padding` can be in the form
+            `[[0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`.
+            Default: padding = 0.
         output_padding(int|list|tuple, optional): Additional size added to one side
             of each dimension in the output shape. Default: 0.
+        dilation(int|list|tuple, optional): The dilation size. It means the spacing between the kernel points.
+            If dilation is a tuple, it must contain two integers, (dilation_height, dilation_width).
+            Otherwise, dilation_height = dilation_width = dilation. Default: dilation = 1.
         groups(int, optional): The groups number of the Conv2d transpose layer. Inspired by
             grouped convolution in Alex Krizhevsky's Deep CNN paper, in which
             when group=2, the first half of the filters is only connected to the
             first half of the input channels, while the second half of the
             filters is only connected to the second half of the input channels.
             Default: groups = 1.
-        dilation(int|list|tuple, optional): The dilation size. It means the spacing between the kernel points.
-            If dilation is a tuple, it must contain two integers, (dilation_height, dilation_width).
-            Otherwise, dilation_height = dilation_width = dilation. Default: dilation = 1.
         output_size(int|tuple|list, optional): The output image size. If output size is a
             tuple, it must contain two integers, (image_height, image_width). None if use
             filter_size, padding, and stride to calculate output_size.
@@ -950,7 +949,7 @@ def conv_transpose2d(x,
           paddle.disable_static()
 
           x_var = paddle.to_tensor(x)
          w_var = paddle.to_tensor(w)
-          y_var = F.conv2d_transpose(x_var, w_var)
+          y_var = F.conv_transpose2d(x_var, w_var)
           y_np = y_var.numpy()
 
           print(y_np.shape)
@@ -1070,7 +1069,7 @@ def conv3d(x,
     For each input :math:`X`, the equation is:
 
     .. math::
 
         Out = \sigma (W \\ast X + b)
@@ -1096,7 +1095,7 @@ def conv3d(x,
         Where
 
         .. math::
 
             D_{out}&= \\frac{(D_{in} + 2 * paddings[0] - (dilations[0] * (D_f - 1) + 1))}{strides[0]} + 1 \\\\
             H_{out}&= \\frac{(H_{in} + 2 * paddings[1] - (dilations[1] * (H_f - 1) + 1))}{strides[1]} + 1 \\\\
@@ -1160,20 +1159,18 @@ def conv3d(x,
     Examples:
         .. code-block:: python
 
-            from paddle import fluid
-            import paddle.nn.functional as F
-            import paddle.fluid.dygraph as dg
             import numpy as np
+            import paddle
+            import paddle.nn.functional as F
 
             x = np.random.randn(2, 3, 8, 8, 8).astype(np.float32)
             w = np.random.randn(6, 3, 3, 3, 3).astype(np.float32)
 
-            place = fluid.CPUPlace()
-            with dg.guard(place):
-                x_var = dg.to_variable(x)
-                w_var = dg.to_variable(w)
-                y_var = F.conv3d(x_var, w_var, act="relu")
-                y_np = y_var.numpy()
+            paddle.disable_static()
+
+            x_var = paddle.to_tensor(x)
+            w_var = paddle.to_tensor(w)
+            y_var = F.conv3d(x_var, w_var)
+            y_np = y_var.numpy()
 
             print(y_np.shape)
             # (2, 6, 6, 6, 6)
@@ -1260,8 +1257,8 @@ def conv_transpose3d(x,
                      output_padding=0,
                      groups=1,
                      dilation=1,
-                     data_format='NCDHW',
                      output_size=None,
+                     data_format='NCDHW',
                      name=None):
     """
     The convolution3d transpose layer calculates the output based on the input,
@@ -1279,7 +1276,7 @@ def conv_transpose3d(x,
     For each input :math:`X`, the equation is:
 
     .. math::
 
         Out = \sigma (W \\ast X + b)
@@ -1306,7 +1303,7 @@ def conv_transpose3d(x,
         Where
 
         .. math::
 
             D^\prime_{out} &= (D_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (D_f - 1) + 1 \\\\
             H^\prime_{out} &= (H_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (H_f - 1) + 1 \\\\
@@ -1338,37 +1335,37 @@ def conv_transpose3d(x,
             If stride is a tuple, it must contain three integers, (stride_depth, stride_height,
             stride_width). Otherwise, stride_depth = stride_height = stride_width = stride.
             Default: stride = 1.
-        padding(int|list|str|tuple, optional): The padding size. The padding argument effectively
-            adds `dilation * (kernel - 1)` amount of zero-padding on both sides of input. If `padding` is a string,
-            either 'VALID' or 'SAME' supported, which is the padding algorithm. If `padding`
-            is a tuple or list, it could be in three forms: `[pad_depth, pad_height, pad_width]` or
-            `[pad_depth_front, pad_depth_back, pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`,
-            and when `data_format` is `'NCDHW'`, `padding` can be in the form
-            `[[0,0], [0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`.
-            when `data_format` is `'NDHWC'`, `padding` can be in the form
-            `[[0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`.
-            Default: padding = 0.
+        padding(str|int|list|tuple, optional): The padding size. It means the number of zero-paddings
+            on both sides for each dimension. If `padding` is a string, either 'VALID' or
+            'SAME', which is the padding algorithm. If padding size is a tuple or list,
+            it could be in three forms: `[pad_depth, pad_height, pad_width]` or
+            `[pad_depth_front, pad_depth_back, pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`,
+            and when `data_format` is `"NCDHW"`, `padding` can be in the form
+            `[[0,0], [0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`.
+            When `data_format` is `"NDHWC"`, `padding` can be in the form
+            `[[0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`.
+            Default: padding = 0.
         output_padding(int|list|tuple, optional): Additional size added to one side
             of each dimension in the output shape. Default: 0.
-        dilation(int|list|tuple, optional): The dilation size. It means the spacing between the kernel points.
-            If dilation is a tuple, it must contain three integers, (dilation_depth, dilation_height,
-            dilation_width). Otherwise, dilation_depth = dilation_height = dilation_width = dilation.
-            Default: dilation = 1.
         groups(int, optional): The groups number of the Conv3d transpose layer. Inspired by
             grouped convolution in Alex Krizhevsky's Deep CNN paper, in which
             when group=2, the first half of the filters is only connected to the
             first half of the input channels, while the second half of the
            filters is only connected to the second half of the input channels.
             Default: groups = 1.
-        data_format (str, optional): Specify the data format of the input, and the data format of the output
-            will be consistent with that of the input. An optional string from: `"NCDHW"`, `"NDHWC"`.
-            The default is `"NCDHW"`. When it is `"NCDHW"`, the data is stored in the order of:
-            `[batch_size, input_channels, input_depth, input_height, input_width]`.
+        dilation(int|list|tuple, optional): The dilation size. It means the spacing between the kernel points.
+            If dilation is a tuple, it must contain three integers, (dilation_depth, dilation_height,
+            dilation_width). Otherwise, dilation_depth = dilation_height = dilation_width = dilation.
+            Default: dilation = 1.
         output_size(int|list|tuple, optional): The output image size. If output size is a
             tuple, it must contain three integers, (image_depth, image_height, image_width). This
             parameter only works when filter_size is None. If output_size and filter_size are
             specified at the same time, they should follow the formula above. Default: None.
             Output_size and filter_size should not be None at the same time.
+        data_format (str, optional): Specify the data format of the input, and the data format of the output
+            will be consistent with that of the input. An optional string from: `"NCDHW"`, `"NDHWC"`.
+            The default is `"NCDHW"`. When it is `"NCDHW"`, the data is stored in the order of:
+            `[batch_size, input_channels, input_depth, input_height, input_width]`.
         name(str, optional): For detailed information, please refer
             to :ref:`api_guide_Name`. Usually name does not need to be set and is
             None by default.
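Taken together, the reordered parameters above can be exercised with a short hedged sketch (the shapes and expected output are assumptions derived from the D'/H'/W' formulas, not output produced by this commit):

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    paddle.disable_static()

    # NCDHW input; transposed-conv weights are laid out as
    # (in_channels, out_channels/groups, kD, kH, kW).
    x = paddle.to_tensor(np.random.randn(2, 3, 4, 4, 4).astype(np.float32))
    w = paddle.to_tensor(np.random.randn(3, 6, 3, 3, 3).astype(np.float32))

    # Per spatial dimension: (4 - 1) * 2 + (3 - 1) + 1 + 1 = 10
    y = F.conv_transpose3d(x, w, stride=2, output_padding=1)
    print(y.numpy().shape)  # expected: (2, 6, 10, 10, 10)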
......
@@ -784,30 +784,30 @@ def kl_div(input, label, reduction='mean', name=None):
             import numpy as np
             import paddle.nn.functional as F
 
-            paddle.enable_imperative()
+            paddle.disable_static()
 
             shape = (5, 20)
             input = np.random.uniform(-10, 10, shape).astype('float32')
             target = np.random.uniform(-10, 10, shape).astype('float32')
 
             # 'batchmean' reduction, loss shape will be [N]
-            pred_loss = F.kl_div(paddle.to_variable(input),
-                                 paddle.to_variable(target), reduction='batchmean')
+            pred_loss = F.kl_div(paddle.to_tensor(input),
+                                 paddle.to_tensor(target), reduction='batchmean')
             # shape=[5]
 
             # 'mean' reduction, loss shape will be [1]
-            pred_loss = F.kl_div(paddle.to_variable(input),
-                                 paddle.to_variable(target), reduction='mean')
+            pred_loss = F.kl_div(paddle.to_tensor(input),
+                                 paddle.to_tensor(target), reduction='mean')
             # shape=[1]
 
             # 'sum' reduction, loss shape will be [1]
-            pred_loss = F.kl_div(paddle.to_variable(input),
-                                 paddle.to_variable(target), reduction='sum')
+            pred_loss = F.kl_div(paddle.to_tensor(input),
+                                 paddle.to_tensor(target), reduction='sum')
             # shape=[1]
 
             # 'none' reduction, loss shape is same with input shape
-            pred_loss = F.kl_div(paddle.to_variable(input),
-                                 paddle.to_variable(target), reduction='none')
+            pred_loss = F.kl_div(paddle.to_tensor(input),
+                                 paddle.to_tensor(target), reduction='none')
             # shape=[5, 20]
 
     """
......
(This diff is collapsed.)
@@ -634,9 +634,12 @@ class KLDivLoss(fluid.dygraph.Layer):
             Default is ``'mean'``.
 
     Shape:
-        - input: (N, *) where * means any number of additional dimensions.
-        - label: (N, *), same shape as input
-        - output: tensor with shape: (1) by default.
+        - input (Tensor): (N, *), where * means any number of additional dimensions.
+
+        - label (Tensor): (N, *), same shape as input.
+
+        - output (Tensor): tensor with shape: [1] by default.
 
     Examples:
@@ -646,7 +649,7 @@ class KLDivLoss(fluid.dygraph.Layer):
             import numpy as np
             import paddle.nn as nn
 
-            paddle.enable_imperative()
+            paddle.disable_static()
 
             shape = (5, 20)
             x = np.random.uniform(-10, 10, shape).astype('float32')
@@ -654,26 +657,26 @@ class KLDivLoss(fluid.dygraph.Layer):
             # 'batchmean' reduction, loss shape will be [N]
             kldiv_criterion = nn.KLDivLoss(reduction='batchmean')
-            pred_loss = kldiv_criterion(paddle.to_variable(x),
-                                        paddle.to_variable(target))
+            pred_loss = kldiv_criterion(paddle.to_tensor(x),
+                                        paddle.to_tensor(target))
             # shape=[5]
 
             # 'mean' reduction, loss shape will be [1]
             kldiv_criterion = nn.KLDivLoss(reduction='mean')
-            pred_loss = kldiv_criterion(paddle.to_variable(x),
-                                        paddle.to_variable(target))
+            pred_loss = kldiv_criterion(paddle.to_tensor(x),
+                                        paddle.to_tensor(target))
             # shape=[1]
 
             # 'sum' reduction, loss shape will be [1]
             kldiv_criterion = nn.KLDivLoss(reduction='sum')
-            pred_loss = kldiv_criterion(paddle.to_variable(x),
-                                        paddle.to_variable(target))
+            pred_loss = kldiv_criterion(paddle.to_tensor(x),
+                                        paddle.to_tensor(target))
             # shape=[1]
 
             # 'none' reduction, loss shape is same with X shape
             kldiv_criterion = nn.KLDivLoss(reduction='none')
-            pred_loss = kldiv_criterion(paddle.to_variable(x),
-                                        paddle.to_variable(target))
+            pred_loss = kldiv_criterion(paddle.to_tensor(x),
+                                        paddle.to_tensor(target))
             # shape=[5, 20]
     """
......