Unverified commit 4d5f6937 authored by Jiabin Yang and committed by GitHub

Feature/refine api for dygraph (#17907)

* WIP

* WIP

* test=develop, add api doc and example code for dygraph
Parent dd4cd352
......@@ -155,7 +155,30 @@ void BindImperative(pybind11::module *m_ptr) {
auto &m = *m_ptr;
py::class_<imperative::detail::BackwardStrategy> backward_strategy(
m, "BackwardStrategy", R"DOC()DOC");
m, "BackwardStrategy", R"DOC(
BackwardStrategy is a descriptor of how to run the backward process. Currently it has:
1. :code:`sort_sum_gradient`, which sums gradients in the reverse order of the trace.
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard():
inputs2 = []
for _ in range(10):
inputs2.append(fluid.dygraph.base.to_variable(x))
ret2 = fluid.layers.sums(inputs2)
loss2 = fluid.layers.reduce_sum(ret2)
backward_strategy = fluid.dygraph.BackwardStrategy()
backward_strategy.sort_sum_gradient = True
loss2.backward(backward_strategy)
)DOC");
backward_strategy.def(py::init())
.def_property("sort_sum_gradient",
[](const imperative::detail::BackwardStrategy &self) {
......
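As a quick companion to the docstring above, a minimal sketch (assuming the binding is exposed as `fluid.dygraph.BackwardStrategy`, as the example documents) of setting and reading the `sort_sum_gradient` property defined by this `def_property` call:

.. code-block:: python

    import paddle.fluid as fluid

    strategy = fluid.dygraph.BackwardStrategy()
    # sort_sum_gradient is assumed to default to False; enable reverse-order summation.
    strategy.sort_sum_gradient = True
    print(strategy.sort_sum_gradient)  # True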
......@@ -54,6 +54,33 @@ def _dygraph_not_support_(func):
def _no_grad_(func):
"""
This decorator prevents the decorated function from creating a backward network in dygraph mode.
Args:
    func: the function that does not need gradients
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
@fluid.dygraph.no_grad
def test_layer():
with fluid.dygraph.guard():
inp = np.ones([3, 32, 32], dtype='float32')
t = fluid.dygraph.base.to_variable(inp)
fc1 = fluid.FC('fc1', size=4, bias_attr=False, num_flatten_dims=1)
fc2 = fluid.FC('fc2', size=4)
ret = fc1(t)
dy_ret = fc2(ret)
test_layer()
"""
def __impl__(*args, **kwargs):
with _switch_tracer_mode_guard_(is_train=False):
return func(*args, **kwargs)
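For context, `_no_grad_` is not used directly; like `_dygraph_not_support_` (see the `not_support = wrap_decorator(_dygraph_not_support_)` line in the next hunk header), it is wrapped by `wrap_decorator` and exposed as a plain decorator. A rough, illustrative sketch of that pattern (the helper names here are simplified stand-ins, not Paddle's exact internals):

.. code-block:: python

    import functools

    def wrap_decorator(decorator_func):
        # Turn a "func -> impl" style decorator into one that preserves
        # the wrapped function's name and docstring.
        @functools.wraps(decorator_func)
        def __outer__(func):
            return functools.wraps(func)(decorator_func(func))
        return __outer__

    def _my_no_grad_(func):
        def __impl__(*args, **kwargs):
            # Paddle switches the tracer out of training mode here;
            # this sketch simply calls through.
            return func(*args, **kwargs)
        return __impl__

    no_grad = wrap_decorator(_my_no_grad_)  # mirrors `not_support = wrap_decorator(_dygraph_not_support_)`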
......@@ -67,6 +94,31 @@ not_support = wrap_decorator(_dygraph_not_support_)
@signature_safe_contextmanager
def guard(place=None):
"""
This context manager creates a dygraph context in which dygraph operations can run.
Args:
    place(fluid.CPUPlace|fluid.CUDAPlace|None): The place on which to run the dygraph context
Returns:
    None
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
with fluid.dygraph.guard():
inp = np.ones([3, 32, 32], dtype='float32')
t = fluid.dygraph.base.to_variable(inp)
fc1 = fluid.FC('fc1', size=4, bias_attr=False, num_flatten_dims=1)
fc2 = fluid.FC('fc2', size=4)
ret = fc1(t)
dy_ret = fc2(ret)
"""
train = framework.Program()
startup = framework.Program()
tracer = Tracer(train.current_block().desc)
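The `place` argument documented above can also be passed explicitly; a minimal sketch, assuming `fluid.is_compiled_with_cuda()` is available to choose between the two supported place types:

.. code-block:: python

    import numpy as np
    import paddle.fluid as fluid

    place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
    with fluid.dygraph.guard(place=place):
        t = fluid.dygraph.base.to_variable(np.ones([2, 2], np.float32))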
......@@ -85,6 +137,29 @@ def guard(place=None):
def to_variable(value, block=None, name=None):
"""
This function creates a Variable from a numpy ndarray.
Args:
    value(ndarray): the numpy value to be converted
    block(fluid.Block|None): the block that the created variable will belong to
    name(str|None): Name of the Variable
Returns:
    Variable: The variable created from the given numpy array
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
with fluid.dygraph.guard():
x = np.ones([2, 2], np.float32)
y = fluid.dygraph.to_variable(x)
"""
if isinstance(value, np.ndarray):
assert enabled(), "to_variable could only be called in dygraph mode"
......
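As the `assert enabled()` check above implies, calling `to_variable` on an ndarray outside of `fluid.dygraph.guard()` should fail with an AssertionError; a minimal sketch of that behaviour:

.. code-block:: python

    import numpy as np
    import paddle.fluid as fluid

    x = np.ones([2, 2], np.float32)
    try:
        fluid.dygraph.to_variable(x)  # called outside fluid.dygraph.guard()
    except AssertionError as e:
        print("to_variable requires dygraph mode:", e)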
......@@ -1170,22 +1170,15 @@ class Embedding(layers.Layer):
This layer is used to lookup embeddings of IDs, provided by :attr:`input`, in
a lookup table. The result of this lookup is the embedding of each ID in the
:attr:`input`.
All the input variables are passed in as local variables to the LayerHelper
constructor.
All the input variables are passed in as local variables to the LayerHelper constructor
Args:
name_scope: See base class.
size(tuple|list): The shape of the look up table parameter. It should
have two elements which indicate the size of the dictionary of
embeddings and the size of each embedding vector respectively.
size(tuple|list): The shape of the look up table parameter. It should have two elements which indicate the size of the dictionary of embeddings and the size of each embedding vector respectively.
is_sparse(bool): The flag indicating whether to use sparse update.
is_distributed(bool): Whether to run lookup table from remote parameter server.
padding_idx(int|long|None): If :attr:`None`, it makes no effect to lookup.
Otherwise the given :attr:`padding_idx` indicates padding the output
with zeros whenever lookup encounters it in :attr:`input`. If
:math:`padding_idx < 0`, the :attr:`padding_idx` to use in lookup is
:math:`size[0] + dim`.
padding_idx(int|long|None): If :attr:`None`, it makes no effect to lookup. Otherwise the given :attr:`padding_idx` indicates padding the output with zeros whenever lookup encounters it in :attr:`input`. If :math:`padding_idx < 0`, the :attr:`padding_idx` to use in lookup is :math:`size[0] + dim`.
param_attr(ParamAttr): Parameters for this layer
dtype(np.dtype|core.VarDesc.VarType|str): The data type: float32, float16, int, etc.
......@@ -1197,10 +1190,15 @@ class Embedding(layers.Layer):
.. code-block:: python
dict_size = len(dataset.ids)
input = fluid.layers.data(name='ids', shape=[32, 32], dtype='float32')
embedding = fluid.Embedding(size=[dict_size, 16])
fc = embedding(input)
import numpy as np
import paddle.fluid as fluid

inp_word = np.array([[[1]]]).astype('int64')
dict_size = 20
with fluid.dygraph.guard():
    emb = fluid.Embedding(
        name_scope='embedding',
        size=[dict_size, 32],
        param_attr='emb.w',
        is_sparse=False)
    emb_rlt = emb(fluid.dygraph.base.to_variable(inp_word))
"""
def __init__(self,
......@@ -1509,12 +1507,15 @@ class GRUUnit(layers.Layer):
class NCE(layers.Layer):
"""
${comment}
Compute and return the noise-contrastive estimation training loss. See
`Noise-contrastive estimation: A new estimation principle for unnormalized
statistical models
<http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf>`_.
By default this operator uses a uniform distribution for sampling.
Args:
input (Variable): input variable.
label (Variable): label.
num_total_classes (int):${num_total_classes_comment}
name_scope (str): See base class.
num_total_classes (int): Total number of classes in all samples
sample_weight (Variable|None): A Variable of shape [batch_size, 1]
storing a weight for each sample. The default weight for each
sample is 1.0.
......@@ -1527,7 +1528,7 @@ class NCE(layers.Layer):
If it is set to None or one attribute of ParamAttr, nce
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
num_neg_samples (int): ${num_neg_samples_comment}
num_neg_samples (int): The number of negative classes. The default value is 10.
name (str|None): A name for this layer (optional). If set to None, the layer
    will be named automatically. Default: None.
sampler (str): The sampler used to sample classes from negative classes.
......@@ -1546,37 +1547,45 @@ class NCE(layers.Layer):
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
window_size = 5
words = []
for i in xrange(window_size):
words.append(layers.data(
name='word_{0}'.format(i), shape=[1], dtype='int64'))
dict_size = 10000
label_word = int(window_size / 2) + 1
embs = []
for i in xrange(window_size):
if i == label_word:
continue
emb = layers.embedding(input=words[i], size=[dict_size, 32],
param_attr='emb.w', is_sparse=True)
embs.append(emb)
embs = layers.concat(input=embs, axis=1)
loss = layers.nce(input=embs, label=words[label_word],
num_total_classes=dict_size, param_attr='nce.w',
bias_attr='nce.b')
#or use custom distribution
dist = fluid.layers.assign(input=np.array([0.05,0.5,0.1,0.3,0.05]).astype("float32"))
loss = layers.nce(input=embs, label=words[label_word],
num_total_classes=5, param_attr='nce.w',
bias_attr='nce.b',
num_neg_samples=3,
sampler="custom_dist",
custom_dist=dist)
dict_size = 20
label_word = int(window_size // 2) + 1
inp_word = np.array([[[1]], [[2]], [[3]], [[4]], [[5]]]).astype('int64')
nid_freq_arr = np.random.dirichlet(np.ones(20) * 1000).astype('float32')
with fluid.dygraph.guard():
words = []
for i in range(window_size):
words.append(fluid.dygraph.base.to_variable(inp_word[i]))
emb = fluid.Embedding(
'embedding',
size=[dict_size, 32],
param_attr='emb.w',
is_sparse=False)
embs3 = []
for i in range(window_size):
if i == label_word:
continue
emb_rlt = emb(words[i])
embs3.append(emb_rlt)
embs3 = fluid.layers.concat(input=embs3, axis=1)
nce = fluid.NCE('nce',
num_total_classes=dict_size,
num_neg_samples=2,
sampler="custom_dist",
custom_dist=nid_freq_arr.tolist(),
seed=1,
param_attr='nce.w',
bias_attr='nce.b')
nce_loss3 = nce(embs3, words[label_word])
"""
......@@ -1739,13 +1748,13 @@ class PRelu(layers.Layer):
y = \max(0, x) + \\alpha * \min(0, x)
Args:
x (Variable): The input tensor.
param_attr(ParamAttr|None): The parameter attribute for the learnable
weight (alpha).
name_scope (str): See base class.
mode (string): The mode for weight sharing. It supports all, channel
    and element. all: all elements share the same weight;
    channel: elements in a channel share the same weight;
    element: each element has its own weight
param_attr(ParamAttr|None): The parameter attribute for the learnable
weight (alpha).
name(str|None): A name for this layer (optional). If set to None, the layer
    will be named automatically.
......@@ -1756,9 +1765,14 @@ class PRelu(layers.Layer):
.. code-block:: python
x = fluid.layers.data(name="x", shape=[10,10], dtype="float32")
import numpy as np
import paddle.fluid as fluid

inp_np = np.ones([5, 200, 100, 100]).astype('float32')
with fluid.dygraph.guard():
mode = 'channel'
output = fluid.layers.prelu(x,mode)
prelu = fluid.PRelu(
'prelu',
mode=mode,
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(1.0)))
dy_rlt = prelu(fluid.dygraph.base.to_variable(inp_np))
"""
def __init__(self, name_scope, mode, param_attr=None):
......
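For comparison with the 'channel' example in the docstring above, a minimal sketch of the 'all' mode, where a single alpha is shared by every element (same dygraph API as shown in this diff is assumed; the 0.25 initial value is only an illustrative choice):

.. code-block:: python

    import numpy as np
    import paddle.fluid as fluid

    inp_np = np.ones([5, 200, 100, 100]).astype('float32')
    with fluid.dygraph.guard():
        prelu_all = fluid.PRelu(
            'prelu_all',
            mode='all',
            param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(0.25)))
        dy_rlt = prelu_all(fluid.dygraph.base.to_variable(inp_np))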