Unverified commit 4d5f6937, authored by Jiabin Yang, committed by GitHub

Feature/refine api for dygraph (#17907)

* WIP

* WIP

* test=develop, add api doc and example code for dygraph
Parent dd4cd352
@@ -155,7 +155,30 @@ void BindImperative(pybind11::module *m_ptr) {
  auto &m = *m_ptr;
  py::class_<imperative::detail::BackwardStrategy> backward_strategy(
      m, "BackwardStrategy", R"DOC(
      BackwardStrategy is a descriptor of how to run the backward process. Now it has:

      1. :code:`sort_sum_gradient`, which will sum the gradients in the reverse order of the trace.

      Examples:

        .. code-block:: python

          import numpy as np
          import paddle.fluid as fluid
          from paddle.fluid import FC

          x = np.ones([2, 2], np.float32)
          with fluid.dygraph.guard():
              inputs2 = []
              for _ in range(10):
                  inputs2.append(fluid.dygraph.base.to_variable(x))
              ret2 = fluid.layers.sums(inputs2)
              loss2 = fluid.layers.reduce_sum(ret2)
              backward_strategy = fluid.dygraph.BackwardStrategy()
              backward_strategy.sort_sum_gradient = True
              loss2.backward(backward_strategy)
)DOC");
  backward_strategy.def(py::init())
      .def_property("sort_sum_gradient",
                    [](const imperative::detail::BackwardStrategy &self) {
......
@@ -54,6 +54,33 @@ def _dygraph_not_support_(func):
def _no_grad_(func):
"""
This Decorator will avoid the func being decorated creating backward network in dygraph mode
Args:
func: the func don't need grad
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
@fluid.dygraph.no_grad
def test_layer():
with fluid.dygraph.guard():
inp = np.ones([3, 32, 32], dtype='float32')
t = fluid.dygraph.base.to_variable(inp)
fc1 = fluid.FC('fc1', size=4, bias_attr=False, num_flatten_dims=1)
fc2 = fluid.FC('fc2', size=4)
ret = fc1(t)
dy_ret = fc2(ret)
test_layer()
"""
    def __impl__(*args, **kwargs):
        with _switch_tracer_mode_guard_(is_train=False):
            return func(*args, **kwargs)
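
The `__impl__` closure above only runs the decorated function with the tracer switched out of training mode; the module then exposes the decorator through `wrap_decorator` (as in `not_support = wrap_decorator(_dygraph_not_support_)` in the next hunk). Below is a minimal, hypothetical sketch of that pattern using `functools` — it illustrates the wiring, not Paddle's actual `wrap_decorator` source.

    .. code-block:: python

        import functools

        def wrap_decorator(decorator_func):
            """Illustrative equivalent: turn a function like _no_grad_ into a reusable
            decorator while preserving the wrapped function's name and docstring."""
            def __outer__(func):
                @functools.wraps(func)
                def __inner__(*args, **kwargs):
                    # decorator_func(func) presumably returns the __impl__ closure above,
                    # which runs func with the tracer switched out of training mode.
                    return decorator_func(func)(*args, **kwargs)
                return __inner__
            return __outer__

        # Hypothetical wiring, mirroring `not_support = wrap_decorator(_dygraph_not_support_)`:
        # no_grad = wrap_decorator(_no_grad_)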
@@ -67,6 +94,31 @@ not_support = wrap_decorator(_dygraph_not_support_)
@signature_safe_contextmanager
def guard(place=None):
"""
This context will create a dygraph context for dygraph to run
Args:
place(fluid.CPUPlace|fluid.CUDAPlace|None): Place to run
return:
None
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
with fluid.dygraph.guard():
inp = np.ones([3, 32, 32], dtype='float32')
t = fluid.dygraph.base.to_variable(inp)
fc1 = fluid.FC('fc1', size=4, bias_attr=False, num_flatten_dims=1)
fc2 = fluid.FC('fc2', size=4)
ret = fc1(t)
dy_ret = fc2(ret)
"""
    train = framework.Program()
    startup = framework.Program()
    tracer = Tracer(train.current_block().desc)
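
The three lines above set up the objects that `guard` manages; the part elided by the diff presumably picks a place and installs the program, tracer, and place before yielding. A minimal sketch of that control flow follows — `_switch_tracer` and `_switch_place` are hypothetical placeholder helpers and the place fallback is an assumption, so treat this as the shape of the function rather than its actual body.

    .. code-block:: python

        from contextlib import contextmanager

        @contextmanager
        def guard_sketch(place=None):
            # Hypothetical sketch of guard()'s control flow, not Paddle's source.
            # framework, Tracer and core are the names used by the surrounding module.
            train = framework.Program()                   # fresh main program for this dygraph run
            startup = framework.Program()                 # fresh startup program
            tracer = Tracer(train.current_block().desc)   # tracer bound to the program's root block

            if place is None:
                # Assumption: prefer GPU when the build supports it, otherwise CPU.
                place = core.CUDAPlace(0) if core.is_compiled_with_cuda() else core.CPUPlace()

            with framework.program_guard(train, startup):          # route new ops/vars into these programs
                with _switch_tracer(tracer), _switch_place(place):  # hypothetical helpers
                    yield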
@@ -85,6 +137,29 @@ def guard(place=None):
def to_variable(value, block=None, name=None):
    """
    This function creates a Variable from a numpy ndarray.

    Args:
        value(ndarray): the numpy value to be converted.
        block(fluid.Block|None): the block this variable will belong to.
        name(str|None): name of the Variable.

    Returns:
        Variable: the Variable created from the given numpy array.

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle.fluid as fluid

        with fluid.dygraph.guard():
            x = np.ones([2, 2], np.float32)
            y = fluid.dygraph.to_variable(x)
    """
    if isinstance(value, np.ndarray):
        assert enabled(), "to_variable could only be called in dygraph mode"
......
@@ -1170,22 +1170,15 @@ class Embedding(layers.Layer):
    This layer is used to lookup embeddings of IDs, provided by :attr:`input`, in
    a lookup table. The result of this lookup is the embedding of each ID in the
    :attr:`input`.
    All the input variables are passed in as local variables to the LayerHelper constructor.

    Args:
        name_scope: See base class.
        size(tuple|list): The shape of the look up table parameter. It should have two elements which indicate the size of the dictionary of embeddings and the size of each embedding vector respectively.
        is_sparse(bool): The flag indicating whether to use sparse update.
        is_distributed(bool): Whether to run lookup table from remote parameter server.
        padding_idx(int|long|None): If :attr:`None`, it has no effect on the lookup. Otherwise the given :attr:`padding_idx` indicates padding the output with zeros whenever lookup encounters it in :attr:`input`. If :math:`padding_idx < 0`, the :attr:`padding_idx` used in the lookup is :math:`size[0] + padding_idx` (see the short sketch after this argument list).
        param_attr(ParamAttr): Parameters for this layer.
        dtype(np.dtype|core.VarDesc.VarType|str): The type of data: float32, float_16, int, etc.
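
As a side note on :attr:`padding_idx`, here is a tiny sketch of how a negative value is assumed to be resolved against :code:`size[0]`; the numbers are made up for illustration and this is not the layer's source.

    .. code-block:: python

        dict_size = 20        # size[0]: number of rows in the lookup table
        padding_idx = -1      # a negative index, as described above

        # Assumed normalization of a negative padding_idx:
        effective_padding_idx = padding_idx if padding_idx >= 0 else dict_size + padding_idx
        print(effective_padding_idx)   # 19 -> lookups of ID 19 produce an all-zero embedding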
@@ -1197,10 +1190,15 @@ class Embedding(layers.Layer):

        .. code-block:: python

          import numpy as np
          import paddle.fluid as fluid

          inp_word = np.array([[[1]]]).astype('int64')
          dict_size = 20
          with fluid.dygraph.guard():
              emb = fluid.Embedding(
                  name_scope='embedding',
                  size=[dict_size, 32],
                  param_attr='emb.w',
                  is_sparse=False)
              static_rlt3 = emb(fluid.dygraph.base.to_variable(inp_word))
""" """
def __init__(self, def __init__(self,
...@@ -1509,12 +1507,15 @@ class GRUUnit(layers.Layer): ...@@ -1509,12 +1507,15 @@ class GRUUnit(layers.Layer):
class NCE(layers.Layer): class NCE(layers.Layer):
""" """
    Compute and return the noise-contrastive estimation training loss. See
    `Noise-contrastive estimation: A new estimation principle for unnormalized
    statistical models
    <http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf>`_.
    By default this operator uses a uniform distribution for sampling.

    Args:
        name_scope (str): See base class.
        num_total_classes (int): Total number of classes in all samples.
        sample_weight (Variable|None): A Variable of shape [batch_size, 1]
            storing a weight for each sample. The default weight for each
            sample is 1.0.
@@ -1527,7 +1528,7 @@ class NCE(layers.Layer):
            If it is set to None or one attribute of ParamAttr, nce
            will create ParamAttr as bias_attr. If the Initializer of the bias_attr
            is not set, the bias is initialized zero. Default: None.
        num_neg_samples (int): The number of negative classes. The default value is 10.
        name (str|None): A name for this layer (optional). If set to None, the layer
            will be named automatically. Default: None.
        sampler (str): The sampler used to sample classes from negative classes.
@@ -1546,37 +1547,45 @@ class NCE(layers.Layer):

    Examples:

        .. code-block:: python

          import numpy as np
          import paddle.fluid as fluid

          window_size = 5
          dict_size = 20
          label_word = int(window_size // 2) + 1
          inp_word = np.array([[[1]], [[2]], [[3]], [[4]], [[5]]]).astype('int64')
          nid_freq_arr = np.random.dirichlet(np.ones(20) * 1000).astype('float32')
          with fluid.dygraph.guard():
              words = []
              for i in range(window_size):
                  words.append(fluid.dygraph.base.to_variable(inp_word[i]))

              emb = fluid.Embedding(
                  'embedding',
                  size=[dict_size, 32],
                  param_attr='emb.w',
                  is_sparse=False)

              embs3 = []
              for i in range(window_size):
                  if i == label_word:
                      continue

                  emb_rlt = emb(words[i])
                  embs3.append(emb_rlt)

              embs3 = fluid.layers.concat(input=embs3, axis=1)
              nce = fluid.NCE('nce',
                              num_total_classes=dict_size,
                              num_neg_samples=2,
                              sampler="custom_dist",
                              custom_dist=nid_freq_arr.tolist(),
                              seed=1,
                              param_attr='nce.w',
                              bias_attr='nce.b')

              nce_loss3 = nce(embs3, words[label_word])
    """
@@ -1739,13 +1748,13 @@ class PRelu(layers.Layer):
        y = \max(0, x) + \\alpha * \min(0, x)

    Args:
        name_scope (str): See base class.
        mode (string): The mode for weight sharing. It supports all, channel
            and element. all: all elements share the same weight
            channel: elements in a channel share the same weight
            element: each element has its own weight
        param_attr(ParamAttr|None): The parameter attribute for the learnable
            weight (alpha).
        name(str|None): A name for this layer (optional). If set to None, the layer
            will be named automatically.
@@ -1756,9 +1765,14 @@ class PRelu(layers.Layer):

        .. code-block:: python

          import numpy as np
          import paddle.fluid as fluid

          inp_np = np.ones([5, 200, 100, 100]).astype('float32')
          with fluid.dygraph.guard():
              mode = 'channel'
              prelu = fluid.PRelu(
                  'prelu',
                  mode=mode,
                  param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(1.0)))
              dy_rlt = prelu(fluid.dygraph.base.to_variable(inp_np))
""" """
def __init__(self, name_scope, mode, param_attr=None): def __init__(self, name_scope, mode, param_attr=None):
......
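
To make the three PRelu weight-sharing modes above concrete, the following is a small numpy sketch of the formula :math:`y = \max(0, x) + \alpha * \min(0, x)`; the alpha shapes are what the modes are assumed to imply (a single scalar, one weight per channel, one weight per element), not code taken from the layer.

    .. code-block:: python

        import numpy as np

        def prelu_ref(x, alpha):
            # Reference PReLU: y = max(0, x) + alpha * min(0, x), with alpha broadcast against x.
            return np.maximum(0, x) + alpha * np.minimum(0, x)

        x = np.random.randn(5, 200, 100, 100).astype('float32')   # NCHW input, as in the example above

        # Assumed alpha shapes implied by the three sharing modes (illustrative only):
        alpha_all = np.float32(0.25)                               # 'all': one shared scalar weight
        alpha_channel = np.full((1, 200, 1, 1), 0.25, 'float32')   # 'channel': one weight per channel
        alpha_element = np.full(x.shape[1:], 0.25, 'float32')      # 'element': one weight per element

        for alpha in (alpha_all, alpha_channel, alpha_element):
            assert prelu_ref(x, alpha).shape == x.shape            # output keeps the input shape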