Unverified commit 4b3e22b8 authored by Guo Sheng, committed by GitHub

Merge pull request #7574 from lcy-seso/wraper_for_l2_normalize

add python wrapper for l2 normalize layer.
@@ -499,3 +499,8 @@ swish
------
.. autofunction:: paddle.v2.fluid.layers.swish
:noindex:
l2_normalize
------------
.. autofunction:: paddle.v2.fluid.layers.l2_normalize
:noindex:
@@ -51,8 +51,8 @@ class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
Clip Operator.
The clip operator limits the value of given input within an interval. The
interval is specified with arguments 'min' and 'max':
$$
Out = \min(\max(X, min), max)
......
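As a quick illustration of the formula above, here is a NumPy sketch; the values of 'min' and 'max' below are assumed example attributes, not part of the operator's code:

```python
import numpy as np

# Illustrates Out = min(max(X, min), max) with assumed attributes min=-1.0, max=1.0.
x = np.array([-2.5, -0.3, 0.0, 0.8, 3.1], dtype="float32")
out = np.minimum(np.maximum(x, -1.0), 1.0)
print(out)  # [-1.  -0.3  0.   0.8  1. ]
```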
@@ -26,9 +26,9 @@ class ElementwiseOp : public framework::OperatorWithKernel {
using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of elementwise op should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of elementwise op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of elementwise op should not be null.");
@@ -45,12 +45,12 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ElementwiseOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor), The first input tensor of elementwise op.");
AddInput("Y", "(Tensor), The second input tensor of elementwise op.");
AddOutput("Out", "The output of elementwise op.");
AddAttr<int>("axis",
"(int, default -1). The start dimension index "
"for broadcasting Y onto X.")
.SetDefault(-1)
.EqualGreaterThan(-1);
comment_ = R"DOC(
@@ -58,19 +58,18 @@ Limited Elementwise {name} Operator.
The equation is:
$${equation}$$
$X$ is a tensor of any dimension and the dimensions of tensor $Y$ must be
smaller than or equal to the dimensions of $X$.
There are two cases for this operator:
1. The shape of $Y$ is the same as that of $X$;
2. The shape of $Y$ is a subset of $X$.
For case 2:
$Y$ will be broadcast to match the shape of $X$, and axis should be set to
the index of the start dimension for broadcasting $Y$ onto $X$.
For example
.. code-block:: python
@@ -81,7 +80,8 @@ For example
shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
Either of the inputs $X$ and $Y$ or none can carry the LoD (Level of Details)
information. However, the output only shares the LoD information with input $X$.
)DOC";
AddComment(comment_);
......
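The broadcasting rule described above can be sketched with NumPy, using the shapes from the examples in the DOC string; the reshape with singleton dimensions is only an illustration of the rule, not the operator's actual implementation:

```python
import numpy as np

# shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), axis = 1:
# Y is aligned with dimensions 1 and 2 of X and then broadcast.
x = np.random.rand(2, 3, 4, 5).astype("float32")
y = np.random.rand(3, 4).astype("float32")
out = x * y.reshape((1, 3, 4, 1))  # the same alignment applies to add/sub/div
assert out.shape == x.shape
```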
@@ -58,21 +58,21 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
ExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"X is the input to be expanded.");
AddOutput("Out",
"(Tensor, default Tensor<float>) A tensor with rank in [1, 6]." "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"The rank of Output(Out) is same as Input(X) except that each " "The rank of Output(Out) have the same with Input(X). "
"dimension size of Output(Out) is equal to corresponding " "After expanding, size of each dimension of Output(Out) is equal "
"dimension size of Input(X) multiplying corresponding value of " "to size of the corresponding dimension of Input(X) multiplying "
"Attr(expand_times)."); "the corresponding value given by Attr(expand_times).");
AddAttr<std::vector<int>>("expand_times",
"Expand times number for each dimension.");
AddComment(R"DOC(
Expand operator tiles the input by given times number. You should set times
number for each dimension by providing attribute 'expand_times'. The rank of X
should be in [1, 6]. Please note that the size of 'expand_times' must be the
same as X's rank. The following is a usage case:
Input(X) is a 3-D tensor with shape [2, 3, 1]:
......
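The tiling described in the DOC string can be mirrored with NumPy's np.tile; this is an assumed equivalence for illustration only, using the [2, 3, 1] input shape mentioned above:

```python
import numpy as np

x = np.random.rand(2, 3, 1).astype("float32")  # Input(X) with rank 3
expand_times = [1, 1, 4]                       # Attr(expand_times), one entry per dimension
out = np.tile(x, expand_times)                 # each dimension size is multiplied accordingly
assert out.shape == (2, 3, 4)
```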
@@ -16,13 +16,22 @@ from paddle.trainer.config_parser import *
from default_decorators import *
__all__ = [
"evaluator_base", "classification_error_evaluator", "auc_evaluator", "evaluator_base",
"pnpair_evaluator", "precision_recall_evaluator", "ctc_error_evaluator", "classification_error_evaluator",
"chunk_evaluator", "sum_evaluator", "column_sum_evaluator", "auc_evaluator",
"value_printer_evaluator", "gradient_printer_evaluator", "pnpair_evaluator",
"maxid_printer_evaluator", "maxframe_printer_evaluator", "precision_recall_evaluator",
"seqtext_printer_evaluator", "classification_error_printer_evaluator", "ctc_error_evaluator",
"detection_map_evaluator" "chunk_evaluator",
"sum_evaluator",
"column_sum_evaluator",
"value_printer_evaluator",
"gradient_printer_evaluator",
"maxid_printer_evaluator",
"maxframe_printer_evaluator",
"seqtext_printer_evaluator",
"classification_error_printer_evaluator",
"detection_map_evaluator",
]
......
@@ -116,8 +116,8 @@ def _debug_string_(proto, throw_on_error=True):
"""
error_fields = list()
if not proto.IsInitialized(error_fields) and throw_on_error:
raise ValueError("{0} are not initialized\nThe message is {1}".format( raise ValueError("{0} are not initialized.\nThe message is {1}:\n".
error_fields, proto)) format(error_fields, proto))
return proto.__str__()
@@ -374,12 +374,13 @@ class Operator(object):
>>> outputs={"Out": [var1]})
Args:
block(Block): The block that holds the current operator.
desc(core.OpDesc): The protobuf description.
type(str): The type of operator.
inputs(dict): The input dictionary. Key is the input parameter name.
Value is a list of variables.
outputs(dict): The output dictionary, which has the same format as
inputs.
attrs(dict): The attributes dictionary. Key is attribute name. Value
is the attribute value. The attribute type should be the same as
the type registered in C++.
@@ -436,10 +437,11 @@ class Operator(object):
for m in proto.outputs:
need.add(m.name)
if not given == need:
raise ValueError(("Incorrect setting for output(s) of "
"operator \"%s\". Need: [%s] Given: [%s]") %
(type, ", ".join(str(e) for e in need),
", ".join(str(e) for e in given)))
for out_proto in proto.outputs:
out_args = outputs[out_proto.name]
@@ -818,9 +820,8 @@ class Program(object):
if isinstance(t, Variable):
t = t.op
else:
raise ValueError(("All targets of prune() can only be "
"Variable or Operator."))
targets_idx.append([t.block.idx, t.idx])
res = Program()
......
@@ -28,9 +28,9 @@ def data(name,
**Data Layer**
This function takes in the input and based on whether data has
to be returned back as a minibatch, it creates the global variable by using
the helper functions. The global variables can be accessed by all the
following operators in the graph.
All the input variables of this function are passed in as local variables
to the LayerHelper constructor.
......
@@ -50,6 +50,7 @@ __all__ = [
'sequence_last_step',
'dropout',
'split',
'l2_normalize',
'matmul',
]
@@ -946,7 +947,8 @@ def pool2d(input,
pool_type,
pool_stride=None,
pool_padding=None,
global_pooling=False,
name=None):
""" """
This function adds the operator for pooling in 2 dimensions, using the This function adds the operator for pooling in 2 dimensions, using the
pooling configurations mentioned in input parameters. pooling configurations mentioned in input parameters.
...@@ -992,7 +994,8 @@ def batch_norm(input, ...@@ -992,7 +994,8 @@ def batch_norm(input,
epsilon=1e-05, epsilon=1e-05,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
data_layout='NCHW'): data_layout='NCHW',
name=None):
""" """
This function helps create an operator to implement This function helps create an operator to implement
the BatchNorm layer using the configurations from the input parameters. the BatchNorm layer using the configurations from the input parameters.
...@@ -1068,7 +1071,7 @@ def batch_norm(input, ...@@ -1068,7 +1071,7 @@ def batch_norm(input,
return helper.append_activation(batch_norm_out) return helper.append_activation(batch_norm_out)
def beam_search_decode(ids, scores): def beam_search_decode(ids, scores, name=None):
helper = LayerHelper('beam_search_decode', **locals()) helper = LayerHelper('beam_search_decode', **locals())
sentence_ids = helper.create_tmp_variable(dtype=ids.dtype) sentence_ids = helper.create_tmp_variable(dtype=ids.dtype)
sentence_scores = helper.create_tmp_variable(dtype=ids.dtype) sentence_scores = helper.create_tmp_variable(dtype=ids.dtype)
...@@ -1092,7 +1095,8 @@ def conv2d_transpose(input, ...@@ -1092,7 +1095,8 @@ def conv2d_transpose(input,
padding=None, padding=None,
stride=None, stride=None,
dilation=None, dilation=None,
param_attr=None): param_attr=None,
name=None):
""" """
The transpose of conv2d layer. The transpose of conv2d layer.
...@@ -1119,8 +1123,8 @@ def conv2d_transpose(input, ...@@ -1119,8 +1123,8 @@ def conv2d_transpose(input,
contain two integers, (dilation_H, dilation_W). Otherwise, the contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. dilation_H = dilation_W = dilation.
param_attr: Parameter Attribute. param_attr: Parameter Attribute.
name(str|None): A name for this layer (optional). If set None, the layer
will be named automatically.
Returns:
Variable: Output image.
@@ -1184,7 +1188,7 @@ def conv2d_transpose(input,
return out
def sequence_expand(x, y, name=None):
"""Sequence Expand Layer. This layer will expand the input variable **x**
according to LoD information of **y**. And the following examples will
explain how sequence_expand works:
@@ -1228,6 +1232,8 @@ def sequence_expand(x, y):
Args:
x (Variable): The input variable which is a Tensor or LoDTensor.
y (Variable): The input variable which is a LoDTensor.
name(str|None): A name for this layer (optional). If set None, the layer
will be named automatically.
Returns:
Variable: The expanded variable which is a LoDTensor.
@@ -1254,7 +1260,8 @@ def lstm_unit(x_t,
cell_t_prev,
forget_bias=0.0,
param_attr=None,
bias_attr=None,
name=None):
"""Lstm unit layer. The equation of a lstm step is:
.. math::
@@ -1301,6 +1308,8 @@ def lstm_unit(x_t,
initializer, name etc.
bias_attr (ParamAttr): The attributes of bias weights, if not False,
bias weights will be created and be set to default value.
name(str|None): A name for this layer (optional). If set None, the layer
will be named automatically.
Returns:
tuple: The hidden value and cell value of lstm unit.
@@ -1366,7 +1375,7 @@ def lstm_unit(x_t,
return h, c
def reduce_sum(input, dim=None, keep_dim=False, name=None):
"""
Computes the sum of tensor elements over the given dimension.
@@ -1380,6 +1389,8 @@ def reduce_sum(input, dim=None, keep_dim=False):
keep_dim (bool): Whether to reserve the reduced dimension in the
output Tensor. The result tensor will have one fewer dimension
than the :attr:`input` unless :attr:`keep_dim` is true.
name(str|None): A name for this layer (optional). If set None, the layer
will be named automatically.
Returns:
Variable: The reduced Tensor variable.
@@ -1410,7 +1421,7 @@ def reduce_sum(input, dim=None, keep_dim=False):
return out
def reduce_mean(input, dim=None, keep_dim=False, name=None):
"""
Computes the mean of tensor elements over the given dimension.
@@ -1424,6 +1435,8 @@ def reduce_mean(input, dim=None, keep_dim=False):
keep_dim (bool): Whether to reserve the reduced dimension in the
output Tensor. The result tensor will have one fewer dimension
than the :attr:`input` unless :attr:`keep_dim` is true.
name(str|None): A name for this layer (optional). If set None, the layer
will be named automatically.
Returns:
Variable: The reduced Tensor variable.
@@ -1454,7 +1467,7 @@ def reduce_mean(input, dim=None, keep_dim=False):
return out
def reduce_max(input, dim=None, keep_dim=False, name=None):
"""
Computes the maximum of tensor elements over the given dimension.
@@ -1468,6 +1481,8 @@ def reduce_max(input, dim=None, keep_dim=False):
keep_dim (bool): Whether to reserve the reduced dimension in the
output Tensor. The result tensor will have one fewer dimension
than the :attr:`input` unless :attr:`keep_dim` is true.
name(str|None): A name for this layer (optional). If set None, the layer
will be named automatically.
Returns:
Variable: The reduced Tensor variable.
@@ -1498,7 +1513,7 @@ def reduce_max(input, dim=None, keep_dim=False):
return out
def reduce_min(input, dim=None, keep_dim=False, name=None):
"""
Computes the minimum of tensor elements over the given dimension.
@@ -1512,6 +1527,8 @@ def reduce_min(input, dim=None, keep_dim=False):
keep_dim (bool): Whether to reserve the reduced dimension in the
output Tensor. The result tensor will have one fewer dimension
than the :attr:`input` unless :attr:`keep_dim` is true.
name(str|None): A name for this layer (optional). If set None, the layer
will be named automatically.
Returns:
Variable: The reduced Tensor variable.
@@ -1542,20 +1559,22 @@ def reduce_min(input, dim=None, keep_dim=False):
return out
def split(input, num_or_sections, dim=-1, name=None):
"""
Split the input tensor into multiple sub-tensors.
Args:
input (Variable): The input variable which is a Tensor or LoDTensor.
num_or_sections (int|list): If :attr:`num_or_sections` is an integer,
then the integer indicates the number of equal sized sub-tensors
that the tensor will be divided into. If :attr:`num_or_sections`
is a list of integers, the length of list indicates the number of
sub-tensors and the integers indicate the sizes of sub-tensors'
:attr:`dim` dimension orderly.
dim (int): The dimension along which to split. If :math:`dim < 0`, the
dimension to split along is :math:`rank(input) + dim`.
name(str|None): A name for this layer (optional). If set None, the layer
will be named automatically.
Returns:
List: The list of segmented tensor variables.
@@ -1600,6 +1619,87 @@ def split(input, num_or_sections, dim=-1):
return outs
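A minimal usage sketch for the extended split signature; the variable name and shape below are hypothetical and only illustrate the two forms of num_or_sections:

```python
import paddle.v2.fluid as fluid

# x has shape [-1, 9, 5] once the batch dimension is prepended by layers.data.
x = fluid.layers.data(name="x", shape=[9, 5], dtype="float32")

# Three equal sub-tensors of shape [-1, 3, 5] along dim=1.
x0, x1, x2 = fluid.layers.split(x, num_or_sections=3, dim=1)

# Sub-tensors with sizes 2, 3 and 4 along dim=1.
y0, y1, y2 = fluid.layers.split(x, num_or_sections=[2, 3, 4], dim=1)
```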
def l2_normalize(x, axis, epsilon=1e-12, name=None):
"""
**L2 normalize Layer**
The l2 normalize layer normalizes `x` along dimension `axis` using an L2
norm. For a 1-D tensor (`axis` is fixed to 0), this layer computes
output = x / sqrt(max(sum(x**2), epsilon))
For `x` with more dimensions, this layer independently normalizes each 1-D
slice along dimension `axis`.
Args:
x(Variable|list): The input tensor to l2_normalize layer.
axis(int): Dimension along which to normalize the input.
epsilon(float): A lower bound value for `x`'s l2 norm. sqrt(epsilon) will
be used as the divisor if the l2 norm of `x` is less than
sqrt(epsilon).
name(str|None): A name for this layer (optional). If set None, the layer
will be named automatically.
Returns:
Variable: The output tensor variable.
Examples:
.. code-block:: python
data = fluid.layers.data(name="data",
shape=(3, 17, 13),
dtype="float32")
fc = fluid.layers.l2_normalize(x=data, axis=1)
"""
if len(x.shape) == 1: axis = 0
helper = LayerHelper("l2_normalize", **locals())
square = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(type="square", inputs={"X": x}, outputs={"Out": square})
reduced_sum = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(
type="reduce_sum",
inputs={"X": square},
outputs={"Out": reduced_sum},
attrs={
"dim": 1 if axis is None else axis,
"keep_dim": True,
"reduce_all": False
})
# TODO(caoying) A lower bound value epsilon for the norm is needed to
# improve the numeric stability of reciprocal. This requires a maximum_op.
rsquare = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(
type="reciprocal", inputs={"X": reduced_sum}, outputs={"Out": rsquare})
# TODO(caoying) the current elementwise_mul operator does not support a
# general broadcast rule which broadcasts input(Y) to have the same
# dimension with Input(X) starting from a specified dimension. So this
# expansion is required. Once a general broadcast rule is supported, this
# expanding can be removed.
rsquare_expanded = helper.create_tmp_variable(dtype=x.dtype)
expand_times = [1] * len(x.shape)
expand_times[axis] = int(x.shape[axis])
helper.append_op(
type="expand",
inputs={"X": rsquare},
outputs={"Out": rsquare_expanded},
attrs={"expand_times": expand_times})
out = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(
type="elementwise_mul",
inputs={"X": x,
"Y": rsquare_expanded},
outputs={"Out": out})
return out
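For reference, a NumPy sketch of the operator sequence appended above (square, reduce_sum with keep_dim, reciprocal, expand, elementwise_mul); the helper name is hypothetical and the snippet only mirrors the graph built by this wrapper:

```python
import numpy as np

def l2_normalize_ref(x, axis):
    # square -> reduce_sum(keep_dim=True) -> reciprocal
    reciprocal = np.reciprocal(np.sum(np.square(x), axis=axis, keepdims=True))
    # expand the reciprocal back to x's shape (mirrors the expand workaround),
    # then multiply elementwise
    expand_times = [1] * x.ndim
    expand_times[axis] = x.shape[axis]
    return x * np.tile(reciprocal, expand_times)

x = np.random.random((2, 3, 7)).astype("float32")
out = l2_normalize_ref(x, axis=1)
```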
def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
"""
Applies matrix multiplication to two tensors. Currently only rank 1 to rank
@@ -1653,6 +1753,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
# x: [K], y: [K]
fluid.layers.matmul(x, y) # out: [1]
# x: [M], y: [N]
fluid.layers.matmul(x, y, True, True) # out: [M, N]
"""
helper = LayerHelper('matmul', **locals())
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import unittest
import paddle.v2.fluid as fluid
import paddle.v2.fluid.core as core
import numpy as np
class TestNormalization(unittest.TestCase):
data_desc = {"name": "input", "shape": (2, 3, 7)}
def gen_random_input(self):
"""Generate random input data.
"""
self.data = np.random.random(
size=self.data_desc["shape"]).astype("float32")
def set_program(self, axis, epsilon):
"""Build the test program.
"""
data = fluid.layers.data(
name=self.data_desc["name"],
shape=self.data_desc["shape"],
dtype="float32",
append_batch_size=False)
data.stop_gradient = False
l2_norm = fluid.layers.l2_normalize(x=data, axis=axis, epsilon=epsilon)
out = fluid.layers.reduce_sum(l2_norm, dim=None)
fluid.backward.append_backward(loss=out)
self.fetch_list = [l2_norm]
def run_program(self):
"""Run the test program.
"""
places = [core.CPUPlace()]
if core.is_compile_gpu():
places.append(core.CUDAPlace(0))
for place in places:
self.set_inputs(place)
exe = fluid.Executor(place)
output = exe.run(fluid.default_main_program(),
feed=self.inputs,
fetch_list=self.fetch_list,
return_numpy=True)
self.op_output = output
def set_inputs(self, place):
"""Set the randomly generated data to the test program.
"""
self.inputs = {}
tensor = fluid.Tensor()
tensor.set(self.data, place)
self.inputs[self.data_desc["name"]] = tensor
def l2_normalize(self, data, axis, epsilon):
""" Compute the groundtruth.
"""
output = data * np.reciprocal(
np.sum(np.square(data), axis=axis, keepdims=True))
return output
def test_l2_normalize(self):
""" Test the python wrapper for l2_normalize.
"""
axis = 1
#TODO(caoying) epsilon is not supported due to lack of a maximum_op.
epsilon = 1e-6
self.gen_random_input()
self.set_program(axis, epsilon)
self.run_program()
expect_output = self.l2_normalize(self.data, axis, epsilon)
# check output
self.assertTrue(np.allclose(self.op_output, expect_output, atol=0.001))
if __name__ == '__main__':
unittest.main()