未验证 提交 c62674f4 编写于 作者: C chengduo 提交者: GitHub

Refine StaticRnn (#16707)

* enable recurrent op test=develop
上级 e9409665
......@@ -301,12 +301,12 @@ paddle.fluid.layers.DynamicRNN.static_input (ArgSpec(args=['self', 'x'], varargs
paddle.fluid.layers.DynamicRNN.step_input (ArgSpec(args=['self', 'x', 'level'], varargs=None, keywords=None, defaults=(0,)), ('document', '7568c5ac7622a10288d3307a94134655'))
paddle.fluid.layers.DynamicRNN.update_memory (ArgSpec(args=['self', 'ex_mem', 'new_mem'], varargs=None, keywords=None, defaults=None), ('document', '5d83987da13b98363d6a807a52d8024f'))
paddle.fluid.layers.StaticRNN.__init__ (ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.layers.StaticRNN.memory (ArgSpec(args=['self', 'init', 'shape', 'batch_ref', 'init_value', 'init_batch_dim_idx', 'ref_batch_dim_idx'], varargs=None, keywords=None, defaults=(None, None, None, 0.0, 0, 1)), ('document', 'c24e368e23afac1ed91a78a639d7a9c7'))
paddle.fluid.layers.StaticRNN.output (ArgSpec(args=['self'], varargs='outputs', keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.layers.StaticRNN.step (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.layers.StaticRNN.step_input (ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.layers.StaticRNN.step_output (ArgSpec(args=['self', 'o'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.layers.StaticRNN.update_memory (ArgSpec(args=['self', 'mem', 'var'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.layers.StaticRNN.memory (ArgSpec(args=['self', 'init', 'shape', 'batch_ref', 'init_value', 'init_batch_dim_idx', 'ref_batch_dim_idx'], varargs=None, keywords=None, defaults=(None, None, None, 0.0, 0, 1)), ('document', '72530f299d6451a567cf4a12dc3fb1ff'))
paddle.fluid.layers.StaticRNN.output (ArgSpec(args=['self'], varargs='outputs', keywords=None, defaults=None), ('document', 'df6ceab6e6c9bd31e97914d7e7538137'))
paddle.fluid.layers.StaticRNN.step (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6d3e0a5d9aa519a9773a36e1620ea9b7'))
paddle.fluid.layers.StaticRNN.step_input (ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None), ('document', '903387ec11f3d0bf46821d31a68cffa5'))
paddle.fluid.layers.StaticRNN.step_output (ArgSpec(args=['self', 'o'], varargs=None, keywords=None, defaults=None), ('document', '252890d4c3199a7623ab8667e13fd837'))
paddle.fluid.layers.StaticRNN.update_memory (ArgSpec(args=['self', 'mem', 'var'], varargs=None, keywords=None, defaults=None), ('document', '7a0000520f179f35239956a5ba55119f'))
paddle.fluid.layers.reorder_lod_tensor_by_rank (ArgSpec(args=['x', 'rank_table'], varargs=None, keywords=None, defaults=None), ('document', '3545f529ef04e8f6ecb76b47fa3df01a'))
paddle.fluid.layers.Print (ArgSpec(args=['input', 'first_n', 'message', 'summarize', 'print_tensor_name', 'print_tensor_type', 'print_tensor_shape', 'print_tensor_lod', 'print_phase'], varargs=None, keywords=None, defaults=(-1, None, -1, True, True, True, True, 'both')), ('document', '5fef91b0e21c93610785f2b1f7161732'))
paddle.fluid.layers.is_empty (ArgSpec(args=['x', 'cond'], varargs=None, keywords=None, defaults=(None,)), ('document', 'bbe578dbb49ad13e15b014e98c22b519'))
......
......@@ -23,6 +23,7 @@ constexpr char kInitialStates[] = "initial_states";
constexpr char kParameters[] = "parameters";
constexpr char kOutputs[] = "outputs";
constexpr char kStepScopes[] = "step_scopes";
constexpr char kHasStates[] = "has_states";
constexpr char kExStates[] = "ex_states";
constexpr char kStates[] = "states";
constexpr char kStepBlock[] = "sub_block";
......@@ -241,11 +242,16 @@ class RecurrentOp : public RecurrentBase {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
bool has_state = Attr<bool>(kHasStates);
auto seq_len = static_cast<size_t>(this->GetSequenceLength(scope));
VLOG(3) << "Static RNN input sequence length = " << seq_len;
StepScopes scopes = CreateStepScopes(scope, seq_len);
auto reverse = Attr<bool>(kReverse);
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::Executor executor(place);
auto *block = Attr<framework::BlockDesc *>(kStepBlock);
......@@ -269,15 +275,17 @@ class RecurrentOp : public RecurrentBase {
inside->Resize(framework::make_ddim(dims));
});
if (i == 0) {
// Link initial states --> ex_states
LinkTensor(scope, Inputs(kInitialStates), &cur_scope,
Attr<std::vector<std::string>>(kExStates));
} else {
auto &ex_scope = scopes.ExScope();
// Link ex_scope::state --> cur_scope::ex_state
LinkTensor(ex_scope, Attr<std::vector<std::string>>(kStates),
&cur_scope, Attr<std::vector<std::string>>(kExStates));
if (has_state) {
if (i == 0) {
// Link initial states --> ex_states
LinkTensor(scope, Inputs(kInitialStates), &cur_scope,
Attr<std::vector<std::string>>(kExStates));
} else {
auto &ex_scope = scopes.ExScope();
// Link ex_scope::state --> cur_scope::ex_state
LinkTensor(ex_scope, Attr<std::vector<std::string>>(kStates),
&cur_scope, Attr<std::vector<std::string>>(kExStates));
}
}
// Every inputs are linked now, execute!
......@@ -286,11 +294,6 @@ class RecurrentOp : public RecurrentBase {
std::vector<std::string>() /*skip_ref_cnt_vars*/,
true /*force_disable_gc*/);
// get device context from pool
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
// Copy inside::output -> outside::output
// outside::output[seq_offset: seq_offset + 1] = inside::output
this->LinkTensorWithCallback(
......@@ -333,13 +336,13 @@ class RecurrentGradOp : public RecurrentBase {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto seq_len = static_cast<size_t>(GetSequenceLength(scope));
bool has_state = Attr<bool>(kHasStates);
const size_t seq_len = static_cast<size_t>(GetSequenceLength(scope));
StepScopes scopes = CreateStepScopes(scope, seq_len);
auto reverse = Attr<bool>(kReverse);
framework::Executor executor(place);
auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program();
// get device context from pool
......@@ -350,6 +353,7 @@ class RecurrentGradOp : public RecurrentBase {
size_t seq_offset = reverse ? step_id : seq_len - step_id - 1;
VLOG(3) << "Recurrent backward operate at the time step " << seq_offset;
auto &cur_scope = scopes.CurScope();
// Link outside::output_grads --> inside::output_grads
// inside::output_grad = outside::output_grad[seq_offset:seq_offset+1]
LinkTensorWithCallback(
......@@ -370,30 +374,32 @@ class RecurrentGradOp : public RecurrentBase {
VLOG(10) << " RNN output gradients = [" << sout.str() << "]";
}
// Link states
// if cur_scope::cur_state_grad in out_grads:
// cur_scope::cur_state_grad += ex_scope::ex_state_grad
// else:
// ex_scope::ex_state_grad --> cur_scope::cur_state_grad
if (step_id != 0) { // not at beginning
auto &ex_scope = scopes.ExScope();
auto ex_state_grads =
GradVarLists(Attr<std::vector<std::string>>(kExStates));
auto cur_state_grads =
GradVarLists(Attr<std::vector<std::string>>(kStates));
PADDLE_ENFORCE_EQ(ex_state_grads.size(), cur_state_grads.size());
for (size_t i = 0; i < ex_state_grads.size(); ++i) {
auto &cur_grad = cur_state_grads[i];
auto &ex_grad = ex_state_grads[i];
auto &ex_tensor =
ex_scope.FindVar(ex_grad)->Get<framework::LoDTensor>();
VLOG(10) << " RNN link " << cur_grad << " from " << ex_grad;
auto *cur_grad_var = cur_scope.Var(cur_grad);
auto cur_grad_tensor =
cur_grad_var->GetMutable<framework::LoDTensor>();
framework::TensorCopy(ex_tensor, place, dev_ctx, cur_grad_tensor);
if (has_state) {
// Link states
// if cur_scope::cur_state_grad in out_grads:
// cur_scope::cur_state_grad += ex_scope::ex_state_grad
// else:
// ex_scope::ex_state_grad --> cur_scope::cur_state_grad
if (step_id != 0) { // not at beginning
auto &ex_scope = scopes.ExScope();
auto ex_state_grads =
GradVarLists(Attr<std::vector<std::string>>(kExStates));
auto cur_state_grads =
GradVarLists(Attr<std::vector<std::string>>(kStates));
PADDLE_ENFORCE_EQ(ex_state_grads.size(), cur_state_grads.size());
for (size_t i = 0; i < ex_state_grads.size(); ++i) {
auto &cur_grad = cur_state_grads[i];
auto &ex_grad = ex_state_grads[i];
auto &ex_tensor =
ex_scope.FindVar(ex_grad)->Get<framework::LoDTensor>();
VLOG(10) << " RNN link " << cur_grad << " from " << ex_grad;
auto *cur_grad_var = cur_scope.Var(cur_grad);
auto cur_grad_tensor =
cur_grad_var->GetMutable<framework::LoDTensor>();
framework::TensorCopy(ex_tensor, place, dev_ctx, cur_grad_tensor);
}
}
}
......@@ -442,8 +448,8 @@ class RecurrentGradOp : public RecurrentBase {
}
auto new_inside_name = cur_scope.Rename(inside_grad_name);
// sum gradient
// sum gradient
auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {pg_names[param_id], new_inside_name}}},
{{"Out", {pg_names[param_id]}}},
......@@ -475,22 +481,33 @@ class RecurrentGradOp : public RecurrentBase {
true /*is_backward*/);
VLOG(5) << "Link outside gradient finished ";
if (step_id + 1 == seq_len) { // at_end
// copy initialize states gradient from inside to outside
LinkTensorWithCallback(
cur_scope, GradVarLists(Attr<std::vector<std::string>>(kExStates)),
scope, Outputs(kInitStateGrads),
[&](const framework::LoDTensor &inside,
framework::LoDTensor *outside) {
outside->Resize(inside.dims());
outside->mutable_data(place, inside.type());
framework::TensorCopy(inside, place, dev_ctx, outside);
},
true /*is_backward*/);
VLOG(5) << "Link initialize state gradient finished ";
if (has_state) {
if (step_id + 1 == seq_len) { // at_end
// copy initialize states gradient from inside to outside
LinkTensorWithCallback(
cur_scope,
GradVarLists(Attr<std::vector<std::string>>(kExStates)), scope,
Outputs(kInitStateGrads),
[&](const framework::LoDTensor &inside,
framework::LoDTensor *outside) {
outside->Resize(inside.dims());
outside->mutable_data(place, inside.type());
framework::TensorCopy(inside, place, dev_ctx, outside);
},
true /*is_backward*/);
VLOG(5) << "Link initialize state gradient finished ";
}
}
scopes.Next();
}
// Delete the scope of StepScopes
dev_ctx.Wait();
auto *var = scope.FindVar(Input(kStepScopes));
PADDLE_ENFORCE(var != nullptr);
auto step_scopes = var->GetMutable<StepScopeVar>();
for (auto *sub_scope : *step_scopes) {
const_cast<framework::Scope &>(scope).DeleteScope(sub_scope);
}
}
private:
......@@ -541,6 +558,7 @@ class RecurrentOpProtoMaker : public framework::OpProtoAndCheckerMaker {
.AsDuplicable();
AddOutput(kStepScopes,
"StepScopes contain all local variables in each time step.");
AddAttr<bool>(kHasStates, "Whether has states.").SetDefault(false);
AddAttr<std::vector<std::string>>(kExStates,
string::Sprintf(
R"DOC(The ex-state variable names.
......@@ -624,20 +642,44 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
class RecurrentGradOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
std::vector<std::string> input{kInputs, kInitialStates};
std::vector<std::string> output{kOutputs};
for (auto &s : input) {
// NOTE(zcd): In some case, some of kInputs doesn't have gradient.
PADDLE_ENFORCE(ctx->HasInputs(s));
}
for (auto &s : output) {
PADDLE_ENFORCE(ctx->HasInputs(s));
// In some cases kInitialStates is empty.
// If kInitialStates is empty, both kExStates and kStates must be empty too.
if (!ctx->HasInputs(kInitialStates)) {
PADDLE_ENFORCE_EQ(
ctx->Attrs().Get<std::vector<std::string>>(kExStates).size(), 0,
"The Attr(%s) should be empty.", kExStates);
PADDLE_ENFORCE_EQ(
ctx->Attrs().Get<std::vector<std::string>>(kStates).size(), 0,
"The Attr(%s) should be empty.", kStates);
}
for (auto &s : input) {
ctx->SetOutputsDim(framework::GradVarName(s), ctx->GetInputsDim(s));
PADDLE_ENFORCE(ctx->HasInputs(kInputs),
"The input(%s) should not be empty.", kInputs);
PADDLE_ENFORCE(ctx->HasInputs(kOutputs),
"The input(%s) should not be empty.", kOutputs);
// In some case the kInitialStates is empty.
if (ctx->HasInputs(kInitialStates)) {
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(kInitialStates)),
"The output of(%s) should not be empty.",
framework::GradVarName(kInitialStates));
ctx->SetOutputsDim(framework::GradVarName(kInitialStates),
ctx->GetInputsDim(kInitialStates));
}
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(kInputs)),
"The output of(%s) should not be empty.",
framework::GradVarName(kInputs));
ctx->SetOutputsDim(framework::GradVarName(kInputs),
ctx->GetInputsDim(kInputs));
// In some case the kParameters is empty.
if (ctx->HasInputs(kParameters)) {
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(kParameters)));
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(kParameters)),
"The output of(%s) should not be empty.",
framework::GradVarName(kParameters));
ctx->SetOutputsDim(framework::GradVarName(kParameters),
ctx->GetInputsDim(kParameters));
}
......
......@@ -40,9 +40,12 @@ class RNNMemoryHelperOp : public framework::OperatorBase {
"Cannot find out_var in scope, out_var_name is %s",
out_name);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
auto *out_tensor = out_var->GetMutable<framework::LoDTensor>();
auto &mem_tensor = mem_var->Get<framework::LoDTensor>();
framework::TensorCopySync(mem_tensor, dev_place, out_tensor);
framework::TensorCopy(mem_tensor, dev_place, dev_ctx, out_tensor);
out_tensor->set_lod(mem_tensor.lod());
}
};
......@@ -92,6 +95,9 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
"Cannot find in_grad_var in scope, name is %s",
in_grad_var_name);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
if (out_grad_var == nullptr) {
VLOG(5) << "Using fill constant 0 as starting gradient";
auto in_var_name = Input("X");
......@@ -109,7 +115,8 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
} else {
auto &out_grad_tensor = out_grad_var->Get<framework::LoDTensor>();
auto *in_grad_tensor = in_grad_var->GetMutable<framework::LoDTensor>();
framework::TensorCopySync(out_grad_tensor, dev_place, in_grad_tensor);
framework::TensorCopy(out_grad_tensor, dev_place, dev_ctx,
in_grad_tensor);
in_grad_tensor->set_lod(out_grad_tensor.lod());
}
}
......
......@@ -267,8 +267,44 @@ class StaticRNN(object):
"""
StaticRNN class.
StaticRNN class is used to create a StaticRNN. The RNN will have its
own parameters like inputs, outputs, memories, status and length.
The StaticRNN can process a batch of sequence data. All sample sequences
must have the same length. The StaticRNN has its own parameters, such as
inputs, outputs and memories. **Note that the first dimension of each input
represents the sequence length, and the sequence lengths of all inputs must
be the same. The meaning of each axis is the same for inputs and outputs.**
Examples:
>>> import paddle.fluid as fluid
>>> import paddle.fluid.layers as layers
>>>
>>> vocab_size, hidden_size=10000, 200
>>> x = layers.data(name="x", shape=[-1, 1, 1], dtype='int64')
>>> x_emb = layers.embedding(
>>> input=x,
>>> size=[vocab_size, hidden_size],
>>> dtype='float32',
>>> is_sparse=False)
>>> x_emb = layers.transpose(x_emb, perm=[1, 0, 2])
>>>
>>> rnn = fluid.layers.StaticRNN()
>>> with rnn.step():
>>> word = rnn.step_input(x_emb)
>>> prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
>>> hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
>>> rnn.update_memory(prev, hidden) # set prev to hidden
>>> rnn.step_output(hidden)
>>>
>>> result = rnn()
The StaticRNN unfolds a sequence into time steps. Users define how each
time step is processed inside the :code:`with` step.
The :code:`memory` is used to carry data across time steps. The initial
value of a memory can be either a variable filled with a constant value or
a specified variable.
The StaticRNN can mark multiple variables as its outputs. Use `rnn()` to
get the output sequence.
"""
BEFORE_RNN_BLOCK = 0
IN_RNN_BLOCK = 1
......@@ -284,6 +320,9 @@ class StaticRNN(object):
self.seq_len = None
def step(self):
"""
The block in which the user defines the operators of the RNN.
"""
return BlockGuardWithCompletion(self)
def _assert_in_rnn_block_(self, method):
......@@ -298,13 +337,28 @@ class StaticRNN(object):
init_batch_dim_idx=0,
ref_batch_dim_idx=1):
"""
Create a memory variable for the static RNN.
If :code:`init` is not None, :code:`memory` will be initialized by
this Variable. If :code:`init` is None, :code:`shape` and :code:`batch_ref`
must be set, and this function will initialize an :code:`init` Variable.
Args:
init: boot memory, if not set, a shape, batch_ref must be provided
shape: shape of the boot memory
batch_ref: batch size reference variable
init_value: the init value of boot memory
init_batch_dim_idx: the index of batch size in init's dimension
ref_batch_dim_idx: the index of batch size in batch_ref's dimension
init(Variable|None): The initialized variable. If it is not set,
:code:`shape` and :code:`batch_ref` must be provided.
Default: None.
shape(list|tuple): The shape of the boot memory. NOTE the shape
does not contain batch_size. Default: None.
batch_ref(Variable|None): The batch size reference Variable.
Default: None.
init_value(float): the init value of boot memory. Default: 0.0.
init_batch_dim_idx(int): the batch_size axis of the
:code:`init` Variable. Default: 0.
ref_batch_dim_idx(int): the batch_size axis of the
:code:`batch_ref` Variable. Default: 1.
Returns:
The memory variable.
"""
self._assert_in_rnn_block_('memory')
if init is None:
......@@ -343,6 +397,16 @@ class StaticRNN(object):
return pre_mem
def step_input(self, x):
"""
Mark a sequence as a StaticRNN input.
Args:
x(Variable): The input sequence, the shape of x
should be [seq_len, ...].
Returns:
The current time step in the input sequence.
"""
self._assert_in_rnn_block_('step_input')
if not isinstance(x, Variable):
raise TypeError("step input takes a Variable")
......@@ -357,6 +421,15 @@ class StaticRNN(object):
return ipt
def step_output(self, o):
"""
Mark a sequence as a StaticRNN output.
Args:
o(Variable): The output sequence.
Returns:
None.
"""
self._assert_in_rnn_block_('step_output')
if not isinstance(o, Variable):
raise TypeError("step output takes a Variable")
......@@ -376,10 +449,30 @@ class StaticRNN(object):
self.outputs.append(out_var)
def output(self, *outputs):
"""
Mark the StaticRNN output variables.
Args:
outputs: The output Variables.
Returns:
None
"""
for each in outputs:
self.step_output(each)
def update_memory(self, mem, var):
"""
Update the memory from :code:`mem` to :code:`var`. NOTE that the shape and
data type of :code:`mem` and :code:`var` must be the same.
Args:
mem(Variable): the memory variable.
var(Variable): the plain variable generated in RNN block.
Returns:
None
"""
if not isinstance(mem, Variable) or not isinstance(var, Variable):
raise TypeError("update memory should take variables")
self.memories[mem.name].mem = var
......@@ -419,6 +512,9 @@ class StaticRNN(object):
for m in self.memories:
local_inputs.add(m)
# NOTE(zcd): the params have two categories of variables.
# - the variables that are the out of StaticRnn.
# - the variables that are the parameters of some layers, for example, conv2d.
params = list()
for op in rnn_block.ops:
assert isinstance(op, Operator)
......@@ -435,17 +531,19 @@ class StaticRNN(object):
inlinks = [parent_block.var(i.name) for i in self.inputs]
outlinks = self.outputs
# NOTE(zcd): the states may be empty in some cases.
boot_memories = []
pre_memories = []
memories = []
for _, mem in six.iteritems(self.memories):
boot_memories.append(mem.init)
pre_memories.append(mem.pre_mem.name)
assert mem.mem is not None, "%s should be updated in every step." % (
mem.init.name)
mem_var = rnn_block.var(mem.mem.name)
assert isinstance(mem_var, Variable)
new_mem = self.helper.create_variable_for_type_inference(
dtype=mem_var.dtype)
rnn_block.append_op(
type='rnn_memory_helper',
inputs={'X': [mem_var]},
......@@ -464,6 +562,7 @@ class StaticRNN(object):
outputs={'outputs': outlinks,
'step_scopes': [step_scope]},
attrs={
'has_states': len(pre_memories) > 0,
'ex_states': pre_memories,
'states': memories,
'sub_block': rnn_block
......
......@@ -25,7 +25,6 @@ endif()
list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290
list(REMOVE_ITEM TEST_OPS test_modified_huber_loss_op) # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5184
list(REMOVE_ITEM TEST_OPS test_lstm_unit_op) # # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5185
list(REMOVE_ITEM TEST_OPS test_recurrent_op) # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/6152
list(REMOVE_ITEM TEST_OPS test_cond_op) # FIXME(qijun): https://github.com/PaddlePaddle/Paddle/issues/5101#issuecomment-339814957
list(REMOVE_ITEM TEST_OPS op_test) # op_test is a helper python file, not a test
......
......@@ -15,7 +15,7 @@
from __future__ import print_function
import unittest
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.framework import Program, grad_var_name
from paddle.fluid.executor import Executor
......@@ -115,10 +115,6 @@ class RecurrentOpTest1(unittest.TestCase):
def setup_program(self):
self.main_program = Program()
self.startup_program = Program()
self.p_info = {
"main_program": self.main_program,
"startup_program": self.startup_program
}
self.place = core.CPUPlace()
def setUp(self):
......@@ -129,33 +125,29 @@ class RecurrentOpTest1(unittest.TestCase):
self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
self.py_rnn = PySimpleRNN1(self.input_shape, self.output_shape)
self.output = layers.mean(self.create_rnn_op(), **self.p_info)
with fluid.program_guard(self.main_program, self.startup_program):
self.output = layers.mean(self.create_rnn_op())
def create_rnn_op(self):
x = layers.data(
shape=[self.sent_len, self.batch_size, self.input_dim],
dtype='float32',
name='x',
append_batch_size=False,
**self.p_info)
append_batch_size=False)
x.stop_gradient = False
h_boot = layers.data(
shape=[self.input_dim],
dtype='float32',
name='h_boot',
**self.p_info)
shape=[self.input_dim], dtype='float32', name='h_boot')
h_boot.stop_gradient = False
rnn = layers.StaticRNN(main_program=self.main_program)
rnn = layers.StaticRNN()
with rnn.step():
h_pre = rnn.memory(init=h_boot)
x_t = rnn.step_input(x)
h = layers.scale(
x=layers.elementwise_add(
x=h_pre, y=x_t, **self.p_info),
scale=self.py_rnn.scale,
**self.p_info)
x=h_pre, y=x_t),
scale=self.py_rnn.scale)
rnn.update_memory(h_pre, h)
rnn.output(h)
......@@ -193,7 +185,8 @@ class RecurrentOpTest1(unittest.TestCase):
def test_backward(self):
self.check_forward()
append_backward(self.output)
with fluid.program_guard(self.main_program, self.startup_program):
append_backward(self.output)
ana_grad = [np.array(x) for x in self.backward()]
......@@ -205,12 +198,8 @@ class RecurrentOpTest1(unittest.TestCase):
num_grad[idx], ana_grad[idx], rtol=0.1).all())
def check_forward(self):
print('test recurrent op forward')
pd_output = self.forward()
py_output = self.py_rnn.forward()
print('pd_output', pd_output)
print
print('py_output', py_output)
self.assertEqual(pd_output.shape, py_output.shape)
self.assertTrue(np.isclose(pd_output, py_output, rtol=0.1).all())
......@@ -263,24 +252,21 @@ class RecurrentOpTest2(RecurrentOpTest1):
self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
self.py_rnn = PySimpleRNN2(self.input_shape, self.output_shape)
self.output = layers.mean(self.create_rnn_op(), **self.p_info)
with fluid.program_guard(self.main_program, self.startup_program):
self.output = layers.mean(self.create_rnn_op())
def create_rnn_op(self):
x = layers.data(
shape=[self.sent_len, self.batch_size, self.input_dim],
dtype='float32',
name='x',
append_batch_size=False,
**self.p_info)
append_batch_size=False)
x.stop_gradient = False
h_boot = layers.data(
shape=[self.input_dim],
dtype='float32',
name='h_boot',
**self.p_info)
shape=[self.input_dim], dtype='float32', name='h_boot')
h_boot.stop_gradient = False
rnn = layers.StaticRNN(main_program=self.main_program)
rnn = layers.StaticRNN()
with rnn.step():
h_pre = rnn.memory(init=h_boot)
x_t = rnn.step_input(x)
......@@ -288,18 +274,13 @@ class RecurrentOpTest2(RecurrentOpTest1):
temp_l = layers.fc(input=x_t,
size=self.input_dim,
param_attr='W',
bias_attr=False,
**self.p_info)
bias_attr=False)
temp_r = layers.fc(input=h_pre,
size=self.input_dim,
param_attr='U',
bias_attr=False,
**self.p_info)
bias_attr=False)
h = layers.sigmoid(
x=layers.elementwise_add(
x=temp_l, y=temp_r, **self.p_info),
**self.p_info)
h = layers.sigmoid(x=layers.elementwise_add(x=temp_l, y=temp_r))
rnn.update_memory(h_pre, h)
rnn.output(h)
......@@ -362,40 +343,38 @@ class RecurrentOpMultipleMemoryTest(RecurrentOpTest1):
self.py_rnn = RecurrentOpMultipleMemoryTest.PySimpleRNN3(
self.input_shape, self.output_shape)
self.output = layers.mean(self.create_rnn_op(), **self.p_info)
with fluid.program_guard(self.main_program, self.startup_program):
self.output = layers.mean(self.create_rnn_op())
def create_rnn_op(self):
x = layers.data(
shape=[self.sent_len, self.batch_size, self.input_dim],
dtype='float32',
name='x',
append_batch_size=False,
**self.p_info)
append_batch_size=False)
x.stop_gradient = False
h_boot1 = layers.data(
shape=[self.batch_size, self.input_dim],
dtype='float32',
name='h_boot1',
append_batch_size=False,
**self.p_info)
append_batch_size=False)
h_boot1.stop_gradient = False
h_boot2 = layers.data(
shape=[self.batch_size, self.input_dim],
dtype='float32',
name='h_boot2',
append_batch_size=False,
**self.p_info)
append_batch_size=False)
h_boot2.stop_gradient = False
rnn = layers.StaticRNN(main_program=self.main_program)
rnn = layers.StaticRNN()
with rnn.step():
h_pre1 = rnn.memory(init=h_boot1)
h_pre2 = rnn.memory(init=h_boot2)
x_t = rnn.step_input(x)
mem1 = layers.scale(x=h_pre1, scale=1.0, **self.p_info)
mem2 = layers.scale(x=h_pre2, scale=1.0, **self.p_info)
out = layers.sums(input=[mem1, x_t, mem2], **self.p_info)
mem1 = layers.scale(x=h_pre1, scale=1.0)
mem2 = layers.scale(x=h_pre2, scale=1.0)
out = layers.sums(input=[mem1, x_t, mem2])
rnn.update_memory(h_pre1, mem1)
rnn.update_memory(h_pre2, mem2)
......@@ -446,23 +425,23 @@ class RecurrentOpNoMemBootTest(RecurrentOpTest1):
self.output_shape = (self.sent_len, self.batch_size, self.input_dim)
self.py_rnn = RecurrentOpNoMemBootTest.PySimpleRNN4(self.input_shape,
self.output_shape)
self.output = layers.mean(self.create_rnn_op(), **self.p_info)
print(self.main_program)
with fluid.program_guard(self.main_program, self.startup_program):
self.output = layers.mean(self.create_rnn_op())
def create_rnn_op(self):
x = layers.data(
shape=[self.sent_len, self.batch_size, self.input_dim],
dtype='float32',
name='x',
append_batch_size=False,
**self.p_info)
append_batch_size=False)
x.stop_gradient = False
rnn = layers.StaticRNN(main_program=self.main_program)
rnn = layers.StaticRNN()
with rnn.step():
mem_pre = rnn.memory(shape=[-1, self.input_dim], batch_ref=x)
x_t = rnn.step_input(x)
mem = layers.elementwise_add(x=mem_pre, y=x_t, **self.p_info)
mem = layers.elementwise_add(x=mem_pre, y=x_t)
rnn.update_memory(mem_pre, mem)
rnn.output(mem)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册