未验证 提交 418a0967 编写于 作者: A Aurelius84 提交者: GitHub

move match_matrix var_conv2d et.al api into fluid.contrib test=develop (#19859)

上级 baccd7e2
......@@ -261,7 +261,6 @@ paddle.fluid.layers.maxout (ArgSpec(args=['x', 'groups', 'name'], varargs=None,
paddle.fluid.layers.space_to_depth (ArgSpec(args=['x', 'blocksize', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '26decdea9376b6b9a0d3432d82ca207b'))
paddle.fluid.layers.affine_grid (ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'f85b263b7b6698d000977529a28f202b'))
paddle.fluid.layers.sequence_reverse (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '65c8362e48810b8226e311c5d046db51'))
paddle.fluid.layers.sequence_topk_avg_pooling (ArgSpec(args=['input', 'row', 'col', 'topks', 'channel_num'], varargs=None, keywords=None, defaults=None), ('document', '1cee1bbbba8b567ae50509a38d9ec42a'))
paddle.fluid.layers.affine_channel (ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name', 'act'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None, None)), ('document', '9f303c67538e468a36c5904a0a3aa110'))
paddle.fluid.layers.similarity_focus (ArgSpec(args=['input', 'axis', 'indexes', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '18ec2e3afeb90e70c8b73d2b71c40fdb'))
paddle.fluid.layers.hash (ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None)), ('document', 'a0b73c21be618cec0281e7903039e5e3'))
......@@ -290,9 +289,7 @@ paddle.fluid.layers.sign (ArgSpec(args=['x'], varargs=None, keywords=None, defau
paddle.fluid.layers.deformable_conv (ArgSpec(args=['input', 'offset', 'mask', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'deformable_groups', 'im2col_step', 'param_attr', 'bias_attr', 'modulated', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, None, None, True, None)), ('document', '335193ac57d41d7199f8d26d30c069b1'))
paddle.fluid.layers.unfold (ArgSpec(args=['x', 'kernel_sizes', 'strides', 'paddings', 'dilations', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None)), ('document', '3f884662ad443d9ecc2b3734b4f61ad6'))
paddle.fluid.layers.deformable_roi_pooling (ArgSpec(args=['input', 'rois', 'trans', 'no_trans', 'spatial_scale', 'group_size', 'pooled_height', 'pooled_width', 'part_size', 'sample_per_part', 'trans_std', 'position_sensitive', 'name'], varargs=None, keywords=None, defaults=(False, 1.0, [1, 1], 1, 1, None, 1, 0.1, False, None)), ('document', '99c03e3f249e36854f87dedaa17c8f35'))
paddle.fluid.layers.match_matrix_tensor (ArgSpec(args=['x', 'y', 'channel_num', 'act', 'param_attr', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, 'float32', None)), ('document', 'b6ea7d4ddeacae85e37d1e47d5262948'))
paddle.fluid.layers.filter_by_instag (ArgSpec(args=['ins', 'ins_tag', 'filter_tag', 'is_lod'], varargs=None, keywords=None, defaults=None), ('document', '7703a2088af8de4128b143ff1164ca4a'))
paddle.fluid.layers.var_conv_2d (ArgSpec(args=['input', 'row', 'col', 'input_channel', 'output_channel', 'filter_size', 'stride', 'param_attr', 'act', 'dtype', 'name'], varargs=None, keywords=None, defaults=(1, None, None, 'float32', None)), ('document', '7a8b8ade5512c95f9ea30261d33ded6c'))
paddle.fluid.layers.shard_index (ArgSpec(args=['input', 'index_num', 'nshards', 'shard_id', 'ignore_value'], varargs=None, keywords=None, defaults=(-1,)), ('document', '5786fdbba6753ecd6cbce5e6b0889924'))
paddle.fluid.layers.hard_swish (ArgSpec(args=['x', 'threshold', 'scale', 'offset', 'name'], varargs=None, keywords=None, defaults=(6.0, 6.0, 3.0, None)), ('document', '6a5152a7015c62cb8278fc24cb456459'))
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '9d7806e31bdf727c1a23b8782a09b545'))
......@@ -513,6 +510,9 @@ paddle.fluid.contrib.mixed_precision.decorate (ArgSpec(args=['optimizer', 'amp_l
paddle.fluid.contrib.mixed_precision.AutoMixedPrecisionLists ('paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists', ('document', 'c116ec6bb5d30998792daea8db21ee40'))
paddle.fluid.contrib.mixed_precision.AutoMixedPrecisionLists.__init__ (ArgSpec(args=['self', 'custom_white_list', 'custom_black_list'], varargs=None, keywords=None, defaults=(None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.fused_elemwise_activation (ArgSpec(args=['x', 'y', 'functor_list', 'axis', 'scale', 'save_intermediate_out'], varargs=None, keywords=None, defaults=(-1, 0.0, True)), ('document', '1c4b247a2858cea8d9d8750693688270'))
paddle.fluid.contrib.sequence_topk_avg_pooling (ArgSpec(args=['input', 'row', 'col', 'topks', 'channel_num'], varargs=None, keywords=None, defaults=None), ('document', '5218c85dd4122b626da9bb92f3b50042'))
paddle.fluid.contrib.var_conv_2d (ArgSpec(args=['input', 'row', 'col', 'input_channel', 'output_channel', 'filter_size', 'stride', 'param_attr', 'act', 'dtype', 'name'], varargs=None, keywords=None, defaults=(1, None, None, 'float32', None)), ('document', 'f52a6edf6d3e970568788604da3329c2'))
paddle.fluid.contrib.match_matrix_tensor (ArgSpec(args=['x', 'y', 'channel_num', 'act', 'param_attr', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, 'float32', None)), ('document', '3bdc4b2891c1460bc630fdcd22766b21'))
paddle.fluid.contrib.BasicGRUUnit ('paddle.fluid.contrib.layers.rnn_impl.BasicGRUUnit', ('document', '2aed2540ed1540f081be9f4d08f2a65e'))
paddle.fluid.contrib.BasicGRUUnit.__init__ (ArgSpec(args=['self', 'name_scope', 'hidden_size', 'param_attr', 'bias_attr', 'gate_activation', 'activation', 'dtype'], varargs=None, keywords=None, defaults=(None, None, None, None, 'float32')), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.BasicGRUUnit.add_parameter (ArgSpec(args=['self', 'name', 'parameter'], varargs=None, keywords=None, defaults=None), ('document', 'f35ab374c7d5165c3daf3bd64a5a2ec1'))
......
......@@ -22,8 +22,14 @@ import six
import os
import inspect
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers import utils
__all__ = ['fused_elemwise_activation', ]
__all__ = [
'fused_elemwise_activation',
'sequence_topk_avg_pooling',
'var_conv_2d',
'match_matrix_tensor',
]
def fused_elemwise_activation(x,
......@@ -88,3 +94,270 @@ def fused_elemwise_activation(x,
'functor_list': functor_list
})
return out
def var_conv_2d(input,
row,
col,
input_channel,
output_channel,
filter_size,
stride=1,
param_attr=None,
act=None,
dtype='float32',
name=None):
"""
The var_conv_2d layer calculates the output base on the :attr:`input` with variable length,
row, col, input channel, filter size and strides. Both :attr:`input`, :attr:`row`,
and :attr:`col` are 1-level LodTensor. The covolution operation is same as conv2d layer with
padding. Besides, input.dims[1] should be 1.
.. code-block:: text
If input_channel is 2 and given row lodTensor and col lodTensor as follows:
row.lod = [[5, 4]]
col.lod = [[6, 7]]
input is a lodTensor:
input.lod = [[60, 56]] # where 60 = input_channel * 5 * 6
input.dims = [116, 1] # where 116 = 60 + 56
If set output_channel is 3, filter_size is [3, 3], stride is [1, 1]:
output.lod = [[90, 84]] # where 90 = output_channel * [(5-1)/stride + 1] * [(6-1)/stride + 1]
output.dims = [174, 1] # where 174 = 90 + 84
Args:
input (Variable): The input shoud be 1-level LodTensor with dims[1] equals 1.
row (Variable): The row shoud be 1-level LodTensor to provide height information.
col (Variable): The col shoud be 1-level LodTensor to provide width information.
input_channel (int): The number of input channel.
output_channel (int): The number of output channel.
filter_size (int|tuple|None): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square.
stride (int|tuple): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: stride = 1.
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of var_conv2d. If it is set to None or one attribute of ParamAttr, var_conv2d
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with :math:`Normal(0.0, std)`,
and the :math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
act (str): Activation type, if it is set to None, activation is not appended.
Default: None
dtype ('float32'): The data type of parameter and output.
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None
Returns:
Variable: Output variable with LoD specified by this layer.
Examples:
.. code-block:: python
import numpy as np
from paddle.fluid import layers
from paddle.fluid import contrib
x_lod_tensor = layers.data(name='x', shape=[1], lod_level=1)
row_lod_tensor = layers.data(name='row', shape=[6], lod_level=1)
col_lod_tensor = layers.data(name='col', shape=[6], lod_level=1)
out = contrib.var_conv_2d(input=x_lod_tensor,
row=row_lod_tensor,
col=col_lod_tensor,
input_channel=3,
output_channel=5,
filter_size=[3, 3],
stride=1)
"""
helper = LayerHelper('var_conv_2d', **locals())
x_shape = list(input.shape)
assert len(x_shape) == 2
filter_size = utils.convert_to_list(filter_size, 2, 'filter_size')
stride = utils.convert_to_list(stride, 2, 'stride')
filter_shape = [
int(output_channel),
int(input_channel) * filter_size[0] * filter_size[1]
]
filter_param = helper.create_parameter(
attr=helper.param_attr,
shape=filter_shape,
dtype=dtype, )
conv_res = helper.create_variable_for_type_inference(dtype)
tmp_res = helper.create_variable_for_type_inference(
dtype, stop_gradient=True)
helper.append_op(
type='var_conv_2d',
inputs={
'X': input,
'ROW': row,
'COLUMN': col,
'W': filter_param,
},
outputs={"Out": conv_res,
"Col": tmp_res},
attrs={
'InputChannel': input_channel,
'OutputChannel': output_channel,
'StrideH': stride[0],
'StrideW': stride[1],
'KernelH': filter_size[0],
'KernelW': filter_size[1],
})
return helper.append_activation(conv_res)
def match_matrix_tensor(x,
y,
channel_num,
act=None,
param_attr=None,
dtype='float32',
name=None):
"""
Calculate the semantic matching matrix of two word sequences with variable length.
Given a query A of length `n` and a title B of length `m`, the input shape are respectively
[n, h] and [m, h], which h is hidden_size. If :attr:`channel_num` is set to 3,
it will generate a learnable parameter matrix W with shape [h, 3, h].
Then the semantic matching matrix of query A and title B is calculated by
A * W * B.T = [n, h]*[h, 3, h]*[h, m] = [n, 3, m]. The learnable parameter matrix `W`
is equivalent to a fully connected layer in the calculation process. If :attr:`act` is provided,
the corresponding activation function will be applied to output matrix.
The :attr:`x` and :attr:`y` should be LodTensor and only one level LoD is supported.
.. code-block:: text
Given a 1-level LoDTensor x:
x.lod = [[2, 3, ]]
x.data = [[0.3, 0.1], [0.2, 0.3], [0.5, 0.6], [0.7, 0.1], [0.3, 0.4]]
x.dims = [5, 2]
y is a Tensor:
y.lod = [[3, 1, ]]
y.data = [[0.1, 0.2], [0.3, 0.7], [0.9, 0.2], [0.4, 0.1]]
y.dims = [4, 2]
set channel_num 2, then we get a 1-level LoDTensor:
out.lod = [[12, 6]] # where 12 = channel_num * x.lod[0][0] * y.lod[0][0]
out.dims = [18, 1] # where 18 = 12 + 6
Args:
x (Variable): Input variable x which should be 1-level LodTensor.
y (Variable): Input variable y which should be 1-level LodTensor.
channel_num (int): The channel number of learnable parameter W.
act (str, default None): Activation to be applied to the output of this layer.
param_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for learnable
parameters/weights of this layer.
dtype ('float32'): The data type of w data.
name (str|None): A name for this layer(optional). If set None, the layer will be named automatically. Default: None
Returns:
Variable: output with LoD specified by this layer.
Examples:
.. code-block:: python
import numpy as np
from paddle.fluid import layers
from paddle.fluid import contrib
x_lod_tensor = layers.data(name='x', shape=[10], lod_level=1)
y_lod_tensor = layers.data(name='y', shape=[10], lod_level=1)
out, out_tmp = contrib.match_matrix_tensor(x=x_lod_tensor, y=y_lod_tensor, channel_num=3)
"""
helper = LayerHelper('match_matrix_tensor', **locals())
x_shape = list(x.shape)
y_shape = list(y.shape)
assert len(x_shape) == 2 and len(y_shape) == 2 and x_shape[-1] == y_shape[
-1]
weight_shape = [x_shape[-1], channel_num, y_shape[-1]]
w = helper.create_parameter(
attr=helper.param_attr, shape=weight_shape, dtype=dtype, is_bias=False)
mm_res = helper.create_variable_for_type_inference(dtype)
tmp_res = helper.create_variable_for_type_inference(
dtype, stop_gradient=True)
helper.append_op(
type='match_matrix_tensor',
inputs={
'X': x,
'Y': y,
'W': w,
},
outputs={"Out": mm_res,
"Tmp": tmp_res},
attrs={'dim_t': channel_num})
return helper.append_activation(mm_res), tmp_res
def sequence_topk_avg_pooling(input, row, col, topks, channel_num):
"""
The :attr:`topks` is a list with incremental values in this function. For each topk,
it will average the topk features as an output feature for each channel of every
input sequence. Both :attr:`row` and :attr:`col` are LodTensor, which provide height
and width information for :attr:`input` tensor. If feature size of input sequence is less
than topk, it will padding 0 at the back.
.. code-block:: text
If channel_num is 2 and given row LoDTensor and col LoDTensor as follows:
row.lod = [[5, 4]]
col.lod = [[6, 7]]
input is a LoDTensor with input.lod[0][i] = channel_num * row.lod[0][i] * col.lod[0][i]
input.lod = [[60, 56]] # where 60 = channel_num * 5 * 6
input.dims = [116, 1] # where 116 = 60 + 56
If topks is [1, 3, 5], then we get a 1-level LoDTensor:
out.lod = [[5, 4]] # share Lod info with row LodTensor
out.dims = [9, 6] # where 6 = len(topks) * channel_num
Args:
input (Variable): The input should be 2D LodTensor with dims[1] equals 1.
row (Variable): The row shoud be 1-level LodTensor to provide the height information
of the input tensor data.
col (Variable): The col shoud be 1-level LodTensor to provide the width information
of the input tensor data.
topks (list): A list of incremental value to average the topk feature.
channel_num (int): The number of input channel.
Returns:
Variable: output LodTensor specified by this layer.
Examples:
.. code-block:: python
import numpy as np
from paddle.fluid import layers
from paddle.fluid import contrib
x_lod_tensor = layers.data(name='x', shape=[1], lod_level=1)
row_lod_tensor = layers.data(name='row', shape=[6], lod_level=1)
col_lod_tensor = layers.data(name='col', shape=[6], lod_level=1)
out = contrib.sequence_topk_avg_pooling(input=x_lod_tensor,
row=row_lod_tensor,
col=col_lod_tensor,
topks=[1, 3, 5],
channel_num=5)
"""
helper = LayerHelper('sequence_topk_avg_pooling', **locals())
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
pos = helper.create_variable_for_type_inference(
dtype=helper.input_dtype(), stop_gradient=True)
helper.append_op(
type='sequence_topk_avg_pooling',
inputs={'X': input,
'ROW': row,
'COLUMN': col},
outputs={'Out': out,
'pos': pos},
attrs={'topks': topks,
'channel_num': channel_num})
return out
......@@ -189,7 +189,6 @@ __all__ = [
'space_to_depth',
'affine_grid',
'sequence_reverse',
'sequence_topk_avg_pooling',
'affine_channel',
'similarity_focus',
'hash',
......@@ -218,9 +217,7 @@ __all__ = [
'deformable_conv',
'unfold',
'deformable_roi_pooling',
'match_matrix_tensor',
'filter_by_instag',
'var_conv_2d',
'shard_index',
'hard_swish',
]
......@@ -12009,73 +12006,6 @@ def sequence_reverse(x, name=None):
return out
def sequence_topk_avg_pooling(input, row, col, topks, channel_num):
"""
The :attr:`topks` is a list with incremental values in this function. For each topk,
it will average the topk features as an output feature for each channel of every
input sequence. Both :attr:`row` and :attr:`col` are LodTensor, which provide height
and width information for :attr:`input` tensor. If feature size of input sequence is less
than topk, it will padding 0 at the back.
.. code-block:: text
If channel_num is 2 and given row LoDTensor and col LoDTensor as follows:
row.lod = [[5, 4]]
col.lod = [[6, 7]]
input is a LoDTensor with input.lod[0][i] = channel_num * row.lod[0][i] * col.lod[0][i]
input.lod = [[60, 56]] # where 60 = channel_num * 5 * 6
input.dims = [116, 1] # where 116 = 60 + 56
If topks is [1, 3, 5], then we get a 1-level LoDTensor:
out.lod = [[5, 4]] # share Lod info with row LodTensor
out.dims = [9, 6] # where 6 = len(topks) * channel_num
Args:
input (Variable): The input should be 2D LodTensor with dims[1] equals 1.
row (Variable): The row shoud be 1-level LodTensor to provide the height information
of the input tensor data.
col (Variable): The col shoud be 1-level LodTensor to provide the width information
of the input tensor data.
topks (list): A list of incremental value to average the topk feature.
channel_num (int): The number of input channel.
Returns:
Variable: output LodTensor specified by this layer.
Examples:
.. code-block:: python
import numpy as np
from paddle.fluid import layers
x_lod_tensor = layers.data(name='x', shape=[1], lod_level=1)
row_lod_tensor = layers.data(name='row', shape=[6], lod_level=1)
col_lod_tensor = layers.data(name='col', shape=[6], lod_level=1)
out = layers.sequence_topk_avg_pooling(input=x_lod_tensor,
row=row_lod_tensor,
col=col_lod_tensor,
topks=[1, 3, 5],
channel_num=5)
"""
helper = LayerHelper('sequence_topk_avg_pooling', **locals())
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
pos = helper.create_variable_for_type_inference(
dtype=helper.input_dtype(), stop_gradient=True)
helper.append_op(
type='sequence_topk_avg_pooling',
inputs={'X': input,
'ROW': row,
'COLUMN': col},
outputs={'Out': out,
'pos': pos},
attrs={'topks': topks,
'channel_num': channel_num})
return out
def affine_channel(x,
scale=None,
bias=None,
......@@ -14116,203 +14046,6 @@ def deformable_roi_pooling(input,
return output
def var_conv_2d(input,
row,
col,
input_channel,
output_channel,
filter_size,
stride=1,
param_attr=None,
act=None,
dtype='float32',
name=None):
"""
The var_conv_2d layer calculates the output base on the :attr:`input` with variable length,
row, col, input channel, filter size and strides. Both :attr:`input`, :attr:`row`,
and :attr:`col` are 1-level LodTensor. The covolution operation is same as conv2d layer with
padding. Besides, input.dims[1] should be 1.
.. code-block:: text
If input_channel is 2 and given row lodTensor and col lodTensor as follows:
row.lod = [[5, 4]]
col.lod = [[6, 7]]
input is a lodTensor:
input.lod = [[60, 56]] # where 60 = input_channel * 5 * 6
input.dims = [116, 1] # where 116 = 60 + 56
If set output_channel is 3, filter_size is [3, 3], stride is [1, 1]:
output.lod = [[90, 84]] # where 90 = output_channel * [(5-1)/stride + 1] * [(6-1)/stride + 1]
output.dims = [174, 1] # where 174 = 90 + 84
Args:
input (Variable): The input shoud be 1-level LodTensor with dims[1] equals 1.
row (Variable): The row shoud be 1-level LodTensor to provide height information.
col (Variable): The col shoud be 1-level LodTensor to provide width information.
input_channel (int): The number of input channel.
output_channel (int): The number of output channel.
filter_size (int|tuple|None): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square.
stride (int|tuple): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: stride = 1.
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of var_conv2d. If it is set to None or one attribute of ParamAttr, var_conv2d
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with :math:`Normal(0.0, std)`,
and the :math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
act (str): Activation type, if it is set to None, activation is not appended.
Default: None
dtype ('float32'): The data type of parameter and output.
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None
Returns:
Variable: Output variable with LoD specified by this layer.
Examples:
.. code-block:: python
import numpy as np
from paddle.fluid import layers
x_lod_tensor = layers.data(name='x', shape=[1], lod_level=1)
row_lod_tensor = layers.data(name='row', shape=[6], lod_level=1)
col_lod_tensor = layers.data(name='col', shape=[6], lod_level=1)
out = layers.var_conv_2d(input=x_lod_tensor,
row=row_lod_tensor,
col=col_lod_tensor,
input_channel=3,
output_channel=5,
filter_size=[3, 3],
stride=1)
"""
helper = LayerHelper('var_conv_2d', **locals())
x_shape = list(input.shape)
assert len(x_shape) == 2
filter_size = utils.convert_to_list(filter_size, 2, 'filter_size')
stride = utils.convert_to_list(stride, 2, 'stride')
filter_shape = [
int(output_channel),
int(input_channel) * filter_size[0] * filter_size[1]
]
filter_param = helper.create_parameter(
attr=helper.param_attr,
shape=filter_shape,
dtype=dtype, )
conv_res = helper.create_variable_for_type_inference(dtype)
tmp_res = helper.create_variable_for_type_inference(
dtype, stop_gradient=True)
helper.append_op(
type='var_conv_2d',
inputs={
'X': input,
'ROW': row,
'COLUMN': col,
'W': filter_param,
},
outputs={"Out": conv_res,
"Col": tmp_res},
attrs={
'InputChannel': input_channel,
'OutputChannel': output_channel,
'StrideH': stride[0],
'StrideW': stride[1],
'KernelH': filter_size[0],
'KernelW': filter_size[1],
})
return helper.append_activation(conv_res)
def match_matrix_tensor(x,
y,
channel_num,
act=None,
param_attr=None,
dtype='float32',
name=None):
"""
Calculate the semantic matching matrix of two word sequences with variable length.
Given a query A of length `n` and a title B of length `m`, the input shape are respectively
[n, h] and [m, h], which h is hidden_size. If :attr:`channel_num` is set to 3,
it will generate a learnable parameter matrix W with shape [h, 3, h].
Then the semantic matching matrix of query A and title B is calculated by
A * W * B.T = [n, h]*[h, 3, h]*[h, m] = [n, 3, m]. The learnable parameter matrix `W`
is equivalent to a fully connected layer in the calculation process. If :attr:`act` is provided,
the corresponding activation function will be applied to output matrix.
The :attr:`x` and :attr:`y` should be LodTensor and only one level LoD is supported.
.. code-block:: text
Given a 1-level LoDTensor x:
x.lod = [[2, 3, ]]
x.data = [[0.3, 0.1], [0.2, 0.3], [0.5, 0.6], [0.7, 0.1], [0.3, 0.4]]
x.dims = [5, 2]
y is a Tensor:
y.lod = [[3, 1, ]]
y.data = [[0.1, 0.2], [0.3, 0.7], [0.9, 0.2], [0.4, 0.1]]
y.dims = [4, 2]
set channel_num 2, then we get a 1-level LoDTensor:
out.lod = [[12, 6]] # where 12 = channel_num * x.lod[0][0] * y.lod[0][0]
out.dims = [18, 1] # where 18 = 12 + 6
Args:
x (Variable): Input variable x which should be 1-level LodTensor.
y (Variable): Input variable y which should be 1-level LodTensor.
channel_num (int): The channel number of learnable parameter W.
act (str, default None): Activation to be applied to the output of this layer.
param_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for learnable
parameters/weights of this layer.
dtype ('float32'): The data type of w data.
name (str|None): A name for this layer(optional). If set None, the layer will be named automatically. Default: None
Returns:
Variable: output with LoD specified by this layer.
Examples:
.. code-block:: python
import numpy as np
from paddle.fluid import layers
x_lod_tensor = layers.data(name='x', shape=[10], lod_level=1)
y_lod_tensor = layers.data(name='y', shape=[10], lod_level=1)
out, out_tmp = layers.match_matrix_tensor(x=x_lod_tensor, y=y_lod_tensor, channel_num=3)
"""
helper = LayerHelper('match_matrix_tensor', **locals())
x_shape = list(x.shape)
y_shape = list(y.shape)
assert len(x_shape) == 2 and len(y_shape) == 2 and x_shape[-1] == y_shape[
-1]
weight_shape = [x_shape[-1], channel_num, y_shape[-1]]
w = helper.create_parameter(
attr=helper.param_attr, shape=weight_shape, dtype=dtype, is_bias=False)
mm_res = helper.create_variable_for_type_inference(dtype)
tmp_res = helper.create_variable_for_type_inference(
dtype, stop_gradient=True)
helper.append_op(
type='match_matrix_tensor',
inputs={
'X': x,
'Y': y,
'W': w,
},
outputs={"Out": mm_res,
"Tmp": tmp_res},
attrs={'dim_t': channel_num})
return helper.append_activation(mm_res), tmp_res
def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
"""
This layer creates the sharded index for input. This layers is used in
......
......@@ -111,7 +111,7 @@ class TestMatchMatrixTensorOpCase4(TestMatchMatrixTensorOp):
def test_api(self):
x_lod_tensor = fluid.layers.data(name='x', shape=[10], lod_level=1)
y_lod_tensor = fluid.layers.data(name='y', shape=[10], lod_level=1)
out, out_tmp = fluid.layers.match_matrix_tensor(
out, out_tmp = fluid.contrib.match_matrix_tensor(
x=x_lod_tensor, y=y_lod_tensor, channel_num=3)
place = fluid.CPUPlace()
......
......@@ -133,7 +133,7 @@ class TestSequenceTopkAvgPoolingOpCase1(TestSequenceTopkAvgPoolingOp):
x = fluid.layers.data(name='x', shape=[1], lod_level=1)
row = fluid.layers.data(name='row', shape=[10], lod_level=1)
col = fluid.layers.data(name='col', shape=[10], lod_level=1)
topk_avg = fluid.layers.sequence_topk_avg_pooling(
topk_avg = fluid.contrib.sequence_topk_avg_pooling(
input=x, row=row, col=col, topks=[1, 3, 5], channel_num=5)
place = fluid.CPUPlace()
......
......@@ -267,5 +267,39 @@ class TestVarConv2dOpCase7(TestVarConv2dOp):
col)
class TestVarConv2dApi(unittest.TestCase):
def test_api(self):
import paddle.fluid as fluid
x = fluid.layers.data(name='x', shape=[1], lod_level=1)
row = fluid.layers.data(name='row', shape=[6], lod_level=1)
col = fluid.layers.data(name='col', shape=[6], lod_level=1)
out = fluid.contrib.var_conv_2d(
input=x,
row=row,
col=col,
input_channel=3,
output_channel=5,
filter_size=[3, 3],
stride=1)
place = fluid.CPUPlace()
x_tensor = fluid.create_lod_tensor(
np.random.rand(116, 1).astype('float32'), [[60, 56]], place)
row_tensor = fluid.create_lod_tensor(
np.random.rand(9, 6).astype('float32'), [[5, 4]], place)
col_tensor = fluid.create_lod_tensor(
np.random.rand(13, 6).astype('float32'), [[6, 7]], place)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
ret = exe.run(
feed={'x': x_tensor,
'row': row_tensor,
'col': col_tensor},
fetch_list=[out],
return_numpy=False)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册