Unverified commit 075df09f, authored by Jiabin Yang, committed by GitHub

Merge pull request #15470 from JiabinYang/feature/imperative

Add simple RNN in imperative
......@@ -31,6 +31,7 @@ void CreateGradOp(const framework::OpDesc& op_desc,
framework::OpInfoMap::Instance()
.Get(op_desc.Type())
.GradOpMaker()(op_desc, no_grad_set, grad_to_var, grad_sub_block);
for (auto& desc : descs) {
grad_op_descs->emplace_back(desc.release());
}
......
......@@ -13,7 +13,9 @@
// limitations under the License.
#pragma once
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/platform/port.h"
......
cc_library(benchmark SRCS benchmark.cc DEPS enforce)
cc_test(test_benchmark SRCS benchmark_tester.cc DEPS benchmark)
cc_binary(visualizer SRCS visualizer.cc DEPS analysis
paddle_pass_builder ir_pass_manager pass graph_viz_pass analysis_passes)
#cc_binary(visualizer SRCS visualizer.cc DEPS analysis
# paddle_pass_builder ir_pass_manager pass graph_viz_pass analysis_passes)
......@@ -22,13 +22,7 @@ from . import layers
from ..framework import Variable, OpProtoHolder
from ..param_attr import ParamAttr
from ..initializer import Normal, Constant
__all__ = [
'Conv2D',
'Pool2D',
'FC',
'BatchNorm',
]
__all__ = ['Conv2D', 'Pool2D', 'FC', 'BatchNorm', 'Embedding']
class Conv2D(layers.Layer):
......@@ -414,3 +408,91 @@ class BatchNorm(layers.Layer):
# Currently, we don't support inplace in imperative mode
return self._helper.append_activation(batch_norm_out)
class Embedding(layers.Layer):
"""
**Embedding Layer**
This layer is used to lookup embeddings of IDs, provided by :attr:`input`, in
a lookup table. The result of this lookup is the embedding of each ID in the
:attr:`input`.
All the input variables are passed in as local variables to the LayerHelper
constructor.
Args:
size(tuple|list): The shape of the look up table parameter. It should
have two elements which indicate the size of the dictionary of
embeddings and the size of each embedding vector respectively.
is_sparse(bool): The flag indicating whether to use sparse update.
is_distributed(bool): Whether to run lookup table from remote parameter server.
padding_idx(int|long|None): If :attr:`None`, it has no effect on the lookup.
Otherwise the output is padded with zeros whenever :attr:`padding_idx`
is encountered in :attr:`input`. If :math:`padding_idx < 0`, the index
used in the lookup is :math:`size[0] + padding_idx`.
param_attr(ParamAttr): Parameters for this layer
dtype(np.dtype|core.VarDesc.VarType|str): The data type: float32, float16, int64, etc.
Returns:
Variable: The tensor variable storing the embeddings of the \
supplied inputs.
Examples:
.. code-block:: python
dict_size = len(dataset.ids)
input = fluid.layers.data(name='ids', shape=[32, 32], dtype='int64')
embedding = fluid.imperative.Embedding(size=[dict_size, 16])
fc = embedding(input)
"""
def __init__(self,
size,
is_sparse=False,
is_distributed=False,
padding_idx=None,
param_attr=None,
dtype='float32'):
super(Embedding, self).__init__()
self._size = size
self._is_sparse = is_sparse
self._is_distributed = is_distributed
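# Normalize padding_idx: None disables padding (stored as -1); a negative value wraps around as size[0] + padding_idx.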
self._padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else (
size[0] + padding_idx)
self._param_attr = param_attr
self._dtype = dtype
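# Remote prefetch only applies to sparse updates when the lookup table is not distributed.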
self._remote_prefetch = self._is_sparse and (not self._is_distributed)
if self._remote_prefetch:
assert self._is_sparse is True and self._is_distributed is False
from ..layer_helper import LayerHelper
self._helper = LayerHelper('embedding', param_attr=param_attr)
self._w = self._helper.create_parameter(
attr=self._param_attr,
shape=self._size,
dtype=self._dtype,
is_bias=False)
def parameters(self):
return [self._w]
def forward(self, input):
out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type='lookup_table',
inputs={'Ids': input,
'W': self._w},
outputs={'Out': out},
attrs={
'is_sparse': self._is_sparse,
'is_distributed': self._is_distributed,
'remote_prefetch': self._remote_prefetch,
'padding_idx': self._padding_idx
})
return out
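A minimal usage sketch of the new imperative Embedding layer (the vocabulary size, embedding width, and ids below are illustrative assumptions, not taken from this PR):

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.imperative.nn import Embedding
from paddle.fluid.imperative.base import to_variable

with fluid.imperative.guard():
    # hypothetical lookup table: 100 ids, 16-dimensional embeddings
    emb = Embedding(size=[100, 16])
    # lookup_table expects int64 ids with a trailing dimension of 1
    ids = to_variable(np.array([[1], [3], [5]]).astype('int64'))
    out = emb(ids)
    # expected shape under these assumptions: (3, 16)
    print(out._numpy().shape)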
......@@ -66,6 +66,128 @@ class MLP(fluid.imperative.Layer):
return x
class SimpleRNNCell(fluid.imperative.Layer):
def __init__(self, step_input_size, hidden_size, output_size, param_attr):
super(SimpleRNNCell, self).__init__()
self.step_input_size = step_input_size
self.hidden_size = hidden_size
self.output_size = output_size
self._dtype = core.VarDesc.VarType.FP32
from paddle.fluid.layer_helper import LayerHelper
self._helper = LayerHelper(
'SimpleRNNCell', act="tanh", param_attr=param_attr)
def _build_once(self, inputs, pre_hidden):
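# _build_once runs lazily on the first call, so parameter creation can depend on the runtime inputs.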
i2h_param_shape = [self.step_input_size, self.hidden_size]
h2h_param_shape = [self.hidden_size, self.hidden_size]
h2o_param_shape = [self.output_size, self.hidden_size]
self._i2h_w = self._helper.create_parameter(
attr=self._helper.param_attr,
shape=i2h_param_shape,
dtype=self._dtype,
is_bias=False)
self._h2h_w = self._helper.create_parameter(
attr=self._helper.param_attr,
shape=h2h_param_shape,
dtype=self._dtype,
is_bias=False)
self._h2o_w = self._helper.create_parameter(
attr=self._helper.param_attr,
shape=h2o_param_shape,
dtype=self._dtype,
is_bias=False)
def forward(self, input, pre_hidden):
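# One cell step: hidden = tanh(mul(input, W_i2h) + mul(pre_hidden, W_h2h)),
# out = mul(hidden, W_h2o); softmax and reduce_sum then give a scalar the test
# can backpropagate from.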
tmp_i2h = self._helper.create_variable_for_type_inference(self._dtype)
tmp_h2h = self._helper.create_variable_for_type_inference(self._dtype)
hidden = self._helper.create_variable_for_type_inference(self._dtype)
out = self._helper.create_variable_for_type_inference(self._dtype)
softmax_out = self._helper.create_variable_for_type_inference(
self._dtype)
reduce_out = self._helper.create_variable_for_type_inference(
self._dtype)
self._helper.append_op(
type="mul",
inputs={"X": input,
"Y": self._i2h_w},
outputs={"Out": tmp_i2h},
attrs={"x_num_col_dims": 1,
"y_num_col_dims": 1})
self._helper.append_op(
type="mul",
inputs={"X": pre_hidden,
"Y": self._h2h_w},
outputs={"Out": tmp_h2h},
attrs={"x_num_col_dims": 1,
"y_num_col_dims": 1})
self._helper.append_op(
type="elementwise_add",
inputs={'X': tmp_h2h,
'Y': tmp_i2h},
outputs={'Out': hidden},
attrs={'axis': -1,
'use_mkldnn': False})
hidden = self._helper.append_activation(hidden)
self._helper.append_op(
type="mul",
inputs={"X": hidden,
"Y": self._h2o_w},
outputs={"Out": out},
attrs={"x_num_col_dims": 1,
"y_num_col_dims": 1})
self._helper.append_op(
type="softmax",
inputs={"X": out},
outputs={"Out": softmax_out},
attrs={"use_cudnn": False})
self._helper.append_op(
type='reduce_sum',
inputs={'X': softmax_out},
outputs={'Out': reduce_out},
attrs={'dim': None,
'keep_dim': False,
'reduce_all': True})
return reduce_out, hidden
class SimpleRNN(fluid.imperative.Layer):
def __init__(self):
super(SimpleRNN, self).__init__()
self.seq_len = 4
self._cell = SimpleRNNCell(
3,
3,
3,
fluid.ParamAttr(initializer=fluid.initializer.Constant(value=0.1)))
def forward(self, inputs):
outs = list()
pre_hiddens = list()
init_hidden = fluid.layers.tensor.create_parameter(
attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.1)),
shape=[1, 3],
dtype='float32',
is_bias=False)
pre_hidden = init_hidden
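# Unroll the cell over seq_len time steps, feeding each step's hidden state into the next.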
for i in range(self.seq_len):
input = fluid.layers.slice(
inputs, axes=[1], starts=[i], ends=[i + 1])
input = fluid.layers.reshape(input, shape=[1, 3])
out_softmax, pre_hidden = self._cell(input, pre_hidden)
outs.append(out_softmax)
return outs, pre_hiddens
class TestImperative(unittest.TestCase):
def test_sum_op(self):
x = np.ones([2, 2], np.float32)
......@@ -211,6 +333,41 @@ class TestImperative(unittest.TestCase):
self.assertTrue(np.allclose(dy_out, static_out))
self.assertTrue(np.allclose(dy_grad, static_grad))
def test_rnn(self):
np_inp = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0],
[10.0, 11.0, 12.0]])
np_inp = np_inp.reshape((1, 4, 3))
np_inp = np_inp.astype(np.float32)
with fluid.imperative.guard():
var_inp = fluid.imperative.base.to_variable(np_inp)
var_inp = fluid.layers.reshape(var_inp, shape=[1, 4, 3])
simple_rnn = SimpleRNN()
outs, pre_hiddens = simple_rnn.forward(var_inp)
dy_out = outs[3]._numpy()
outs[3]._backward()
dy_grad_h2o = simple_rnn._cell._h2o_w._gradient()
dy_grad_h2h = simple_rnn._cell._h2h_w._gradient()
dy_grad_i2h = simple_rnn._cell._i2h_w._gradient()
with new_program_scope():
inp = fluid.layers.data(
name="inp", shape=[1, 4, 3], append_batch_size=False)
simple_rnn = SimpleRNN()
outs, pre_hiddens = simple_rnn(inp)
param_grads = fluid.backward.append_backward(outs[3])
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
static_out, static_grad_h2o, static_grad_h2h, static_grad_i2h = exe.run(
feed={inp.name: np_inp},
fetch_list=[
outs[3].name, param_grads[0][1].name,
param_grads[1][1].name, param_grads[2][1].name
])
self.assertTrue(np.allclose(dy_out, static_out))
self.assertTrue(np.allclose(dy_grad_h2o, static_grad_h2o))
self.assertTrue(np.allclose(dy_grad_h2h, static_grad_h2h))
self.assertTrue(np.allclose(dy_grad_i2h, static_grad_i2h))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
from paddle.fluid.imperative.nn import Embedding
import paddle.fluid.framework as framework
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.imperative.base import to_variable
from test_imperative_base import new_program_scope
import numpy as np
import six
from paddle.fluid.backward import append_backward
class SimpleLSTMRNN(fluid.imperative.Layer):
def __init__(self,
hidden_size,
num_steps,
num_layers=2,
init_scale=0.1,
dropout=None):
super(SimpleLSTMRNN, self).__init__()
self._hidden_size = hidden_size
self._num_layers = num_layers
self._init_scale = init_scale
self._dropout = dropout
self._input = None
self._num_steps = num_steps
def _build_once(self, input_embedding, init_hidden=None, init_cell=None):
self.weight_1_arr = []
self.weight_2_arr = []
self.bias_arr = []
self.hidden_array = []
self.cell_array = []
self.mask_array = []
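# For each layer: create the fused gate weight and bias once, and slice out that layer's initial hidden and cell state.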
for i in range(self._num_layers):
weight_1 = fluid.layers.create_parameter(
shape=[self._hidden_size * 2, self._hidden_size * 4],
dtype="float32",
name="fc_weight1_" + str(i),
default_initializer=fluid.initializer.UniformInitializer(
low=-self._init_scale, high=self._init_scale))
self.weight_1_arr.append(weight_1)
bias_1 = fluid.layers.create_parameter(
[self._hidden_size * 4],
dtype="float32",
name="fc_bias1_" + str(i),
default_initializer=fluid.initializer.Constant(0.0))
self.bias_arr.append(bias_1)
pre_hidden = fluid.layers.slice(
init_hidden, axes=[0], starts=[i], ends=[i + 1])
pre_cell = fluid.layers.slice(
init_cell, axes=[0], starts=[i], ends=[i + 1])
pre_hidden = fluid.layers.reshape(
pre_hidden, shape=[-1, self._hidden_size])
pre_cell = fluid.layers.reshape(
pre_cell, shape=[-1, self._hidden_size])
self.hidden_array.append(pre_hidden)
self.cell_array.append(pre_cell)
def parameters(self):
parameters = list()
for param in self.weight_1_arr:
parameters.append(param)
for param in self.weight_2_arr:
parameters.append(param)
for bias in self.bias_arr:
parameters.append(bias)
return parameters
def forward(self, input_embedding, init_hidden=None, init_cell=None):
res = []
for index in range(self._num_steps):
self._input = fluid.layers.slice(
input_embedding, axes=[1], starts=[index], ends=[index + 1])
self._input = fluid.layers.reshape(
self._input, shape=[-1, self._hidden_size])
for k in range(self._num_layers):
pre_hidden = self.hidden_array[k]
pre_cell = self.cell_array[k]
weight_1 = self.weight_1_arr[k]
bias = self.bias_arr[k]
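# Fused LSTM gates: one matmul yields i (input), j (cell candidate), f (forget), o (output);
# c_t = sigmoid(f) * c_{t-1} + sigmoid(i) * tanh(j), m_t = tanh(c_t) * sigmoid(o).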
nn = fluid.layers.concat([self._input, pre_hidden], 1)
gate_input = fluid.layers.matmul(x=nn, y=weight_1)
gate_input = fluid.layers.elementwise_add(gate_input, bias)
i, j, f, o = fluid.layers.split(
gate_input, num_or_sections=4, dim=-1)
c = pre_cell * fluid.layers.sigmoid(f) + fluid.layers.sigmoid(
i) * fluid.layers.tanh(j)
m = fluid.layers.tanh(c) * fluid.layers.sigmoid(o)
self.hidden_array[k] = m
self.cell_array[k] = c
self._input = m
if self._dropout is not None and self._dropout > 0.0:
self._input = fluid.layers.dropout(
self._input,
dropout_prob=self._dropout,
dropout_implementation='upscale_in_train')
res.append(
fluid.layers.reshape(
self._input, shape=[1, -1, self._hidden_size]))
real_res = fluid.layers.concat(res, 0)
real_res = fluid.layers.transpose(x=real_res, perm=[1, 0, 2])
last_hidden = fluid.layers.concat(self.hidden_array, 1)
last_hidden = fluid.layers.reshape(
last_hidden, shape=[-1, self._num_layers, self._hidden_size])
last_hidden = fluid.layers.transpose(x=last_hidden, perm=[1, 0, 2])
last_cell = fluid.layers.concat(self.cell_array, 1)
last_cell = fluid.layers.reshape(
last_cell, shape=[-1, self._num_layers, self._hidden_size])
last_cell = fluid.layers.transpose(x=last_cell, perm=[1, 0, 2])
return real_res, last_hidden, last_cell
class PtbModel(fluid.imperative.Layer):
def __init__(self,
hidden_size,
vocab_size,
num_layers=2,
num_steps=20,
init_scale=0.1,
dropout=None):
super(PtbModel, self).__init__()
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.init_scale = init_scale
self.num_layers = num_layers
self.num_steps = num_steps
self.dropout = dropout
self.simple_lstm_rnn = SimpleLSTMRNN(
hidden_size,
num_steps,
num_layers=num_layers,
init_scale=init_scale,
dropout=dropout)
self.embedding = Embedding(
size=[vocab_size, hidden_size],
dtype='float32',
is_sparse=False,
param_attr=fluid.ParamAttr(
name='embedding_para',
initializer=fluid.initializer.UniformInitializer(
low=-init_scale, high=init_scale)))
self.softmax_weight = fluid.layers.create_parameter(
[self.hidden_size, self.vocab_size],
dtype="float32",
name="softmax_weight",
default_initializer=fluid.initializer.UniformInitializer(
low=-self.init_scale, high=self.init_scale))
self.softmax_bias = fluid.layers.create_parameter(
[self.vocab_size],
dtype="float32",
name='softmax_bias',
default_initializer=fluid.initializer.UniformInitializer(
low=-self.init_scale, high=self.init_scale))
def _build_once(self, input, label, init_hidden, init_cell):
pass
def parameters(self):
parameters = self.simple_lstm_rnn.parameters() + [
self.softmax_weight, self.softmax_bias
] + self.embedding.parameters()
return parameters
def forward(self, input, label, init_hidden, init_cell):
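# embedding -> stacked LSTM -> linear projection over the vocabulary -> softmax
# cross-entropy, averaged over the batch and summed over the time steps.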
init_h = fluid.layers.reshape(
init_hidden, shape=[self.num_layers, -1, self.hidden_size])
init_c = fluid.layers.reshape(
init_cell, shape=[self.num_layers, -1, self.hidden_size])
x_emb = self.embedding(input)
x_emb = fluid.layers.reshape(
x_emb, shape=[-1, self.num_steps, self.hidden_size])
if self.dropout is not None and self.dropout > 0.0:
x_emb = fluid.layers.dropout(
x_emb,
dropout_prob=self.dropout,
dropout_implementation='upscale_in_train')
rnn_out, last_hidden, last_cell = self.simple_lstm_rnn(x_emb, init_h,
init_c)
rnn_out = fluid.layers.reshape(
rnn_out, shape=[-1, self.num_steps, self.hidden_size])
projection = fluid.layers.matmul(rnn_out, self.softmax_weight)
projection = fluid.layers.elementwise_add(projection, self.softmax_bias)
projection = fluid.layers.reshape(
projection, shape=[-1, self.vocab_size])
loss = fluid.layers.softmax_with_cross_entropy(
logits=projection, label=label, soft_label=False)
loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps])
loss = fluid.layers.reduce_mean(loss, dim=[0])
loss = fluid.layers.reduce_sum(loss)
loss.permissions = True
return loss, last_hidden, last_cell
class TestImperativePtbRnn(unittest.TestCase):
def test_ptb_rnn_cpu_float32(self):
seed = 90
hidden_size = 10
vocab_size = 1000
num_layers = 1
num_steps = 3
init_scale = 0.1
batch_size = 4
with fluid.imperative.guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
# TODO: marsyang1993 Change seed to
ptb_model = PtbModel(
hidden_size=hidden_size,
vocab_size=vocab_size,
num_layers=num_layers,
num_steps=num_steps,
init_scale=init_scale)
sgd = SGDOptimizer(learning_rate=1e-3)
dy_param_updated = dict()
dy_param_init = dict()
dy_loss = None
last_hidden = None
last_cell = None
for i in range(2):
x_data = np.arange(12).reshape(4, 3).astype('int64')
y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
x_data = x_data.reshape((-1, num_steps, 1))
y_data = y_data.reshape((-1, 1))
init_hidden_data = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32')
init_cell_data = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32')
x = to_variable(x_data)
y = to_variable(y_data)
init_hidden = to_variable(init_hidden_data)
init_cell = to_variable(init_cell_data)
dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden,
init_cell)
if i == 0:
for param in ptb_model.parameters():
dy_param_init[param.name] = param._numpy()
dy_loss._backward()
sgd.minimize(dy_loss)
for param in ptb_model.parameters():
dy_param_updated[param.name] = param._numpy()
# print("dy_loss is {}".format(dy_loss._numpy()))
# print("last_hidden is {}".format(last_hidden._numpy()))
# print("last_cell is {}".format(last_cell._numpy()))
with new_program_scope():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
# TODO: marsyang1993 Change seed to
ptb_model = PtbModel(
hidden_size=hidden_size,
vocab_size=vocab_size,
num_layers=num_layers,
num_steps=num_steps,
init_scale=init_scale)
exe = fluid.Executor(fluid.CPUPlace())
sgd = SGDOptimizer(learning_rate=1e-3)
x = fluid.layers.data(name="x", shape=[-1, 3, 1], dtype='int64')
y = fluid.layers.data(name="y", shape=[-1, 1], dtype='float32')
init_hidden = fluid.layers.data(
name="init_hidden", shape=[1], dtype='float32')
init_cell = fluid.layers.data(
name="init_cell", shape=[1], dtype='float32')
static_loss, static_last_hidden, static_last_cell = ptb_model(
x, y, init_hidden, init_cell)
sgd.minimize(static_loss)
static_param_updated = dict()
static_param_init = dict()
static_param_name_list = list()
for param in ptb_model.parameters():
static_param_name_list.append(param.name)
out = exe.run(framework.default_startup_program(),
fetch_list=static_param_name_list)
for i in range(len(static_param_name_list)):
static_param_init[static_param_name_list[i]] = out[i]
static_loss_value = None
static_last_cell_value = None
static_last_hidden_value = None
for i in range(2):
x_data = np.arange(12).reshape(4, 3).astype('int64')
y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
x_data = x_data.reshape((-1, num_steps, 1))
y_data = y_data.reshape((-1, 1))
init_hidden_data = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32')
init_cell_data = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32')
fetch_list = [static_loss, static_last_hidden, static_last_cell]
fetch_list.extend(static_param_name_list)
out = exe.run(fluid.default_main_program(),
feed={
"x": x_data,
"y": y_data,
"init_hidden": init_hidden_data,
"init_cell": init_cell_data
},
fetch_list=fetch_list)
static_loss_value = out[0]
static_last_cell_value = out[1]
static_last_hidden_value = out[2]
for k in range(3, len(out)):
static_param_updated[static_param_name_list[k - 3]] = out[k]
self.assertTrue(np.allclose(static_loss_value, dy_loss._numpy()))
self.assertTrue(
np.allclose(static_last_cell_value, last_cell._numpy()))
self.assertTrue(
np.allclose(static_last_hidden_value, last_hidden._numpy()))
for key, value in six.iteritems(static_param_init):
self.assertTrue(np.allclose(value, dy_param_init[key]))
for key, value in six.iteritems(static_param_updated):
self.assertTrue(np.allclose(value, dy_param_updated[key]))
if __name__ == '__main__':
unittest.main()