Unverified commit f05c213f, authored by hong, committed by GitHub

fix basic gru lstm parameter attr bug; test=develop (#22508)

* fix basic gru lstm parameter attr bug; test=develop

* fix bias attr bug; test=develop

* add basic lstm gru name unitest; test=develop
Parent 0fb5ea78
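
The bug: when a ParamAttr with an explicit name was passed to basic_gru or basic_lstm, the same attr object was reused for every layer, direction, and gate, so every create_parameter call resolved to the same parameter name and the per-layer parameters collided on a single name. The fix deep-copies the attr at each use site and appends a unique suffix before creating the parameter. A minimal sketch of the pattern (variable names here are illustrative, not from the patch):

    import copy
    from paddle.fluid import ParamAttr

    # One named ParamAttr reused everywhere makes every create_parameter
    # call resolve to the same parameter name.
    shared_attr = ParamAttr(name="test1")

    # The fix: deep-copy per use site and append a unique suffix.
    layer_attr = copy.deepcopy(shared_attr)
    layer_attr.name += "_fw_w_0"          # layer 0, forward direction
    gate_attr = copy.deepcopy(layer_attr)
    gate_attr.name += "_gate"             # gate weight inside the unit
    print(gate_attr.name)                 # test1_fw_w_0_gate
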
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import copy
+
 from paddle.fluid import layers, unique_name
 from paddle.fluid.dygraph import Layer
 from paddle.fluid.dygraph.layer_object_helper import LayerObjectHelper
@@ -98,23 +100,41 @@ class BasicGRUUnit(Layer):
         self._input_size = input.shape[-1]
         assert (self._input_size > 0)

+        if self._param_attr is not None and self._param_attr.name is not None:
+            gate_param_attr = copy.deepcopy(self._param_attr)
+            candidate_param_attr = copy.deepcopy(self._param_attr)
+            gate_param_attr.name += "_gate"
+            candidate_param_attr.name += "_candidate"
+        else:
+            gate_param_attr = self._param_attr
+            candidate_param_attr = self._param_attr
+
         self._gate_weight = self.create_parameter(
-            attr=self._param_attr,
+            attr=gate_param_attr,
             shape=[self._input_size + self._hiden_size, 2 * self._hiden_size],
             dtype=self._dtype)

         self._candidate_weight = self.create_parameter(
-            attr=self._param_attr,
+            attr=candidate_param_attr,
             shape=[self._input_size + self._hiden_size, self._hiden_size],
             dtype=self._dtype)

+        if self._bias_attr is not None and self._bias_attr.name is not None:
+            gate_bias_attr = copy.deepcopy(self._bias_attr)
+            candidate_bias_attr = copy.deepcopy(self._bias_attr)
+            gate_bias_attr.name += "_gate"
+            candidate_bias_attr.name += "_candidate"
+        else:
+            gate_bias_attr = self._bias_attr
+            candidate_bias_attr = self._bias_attr
+
         self._gate_bias = self.create_parameter(
-            attr=self._bias_attr,
+            attr=gate_bias_attr,
             shape=[2 * self._hiden_size],
             dtype=self._dtype,
             is_bias=True)
         self._candidate_bias = self.create_parameter(
-            attr=self._bias_attr,
+            attr=candidate_bias_attr,
             shape=[self._hiden_size],
             dtype=self._dtype,
             is_bias=True)
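
With named attrs, BasicGRUUnit now yields four distinct parameter names instead of four collisions. A sketch of the resulting names, mirroring the suffix logic above (not a library API):

    import copy
    from paddle.fluid import ParamAttr

    param_attr = ParamAttr(name="w")
    bias_attr = ParamAttr(name="b")
    for base in (param_attr, bias_attr):
        for suffix in ("_gate", "_candidate"):
            attr = copy.deepcopy(base)
            attr.name += suffix
            print(attr.name)
    # w_gate
    # w_candidate
    # b_gate
    # b_candidate
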
@@ -244,17 +264,39 @@ def basic_gru(input,
     for i in range(num_layers):
         new_name = name + "_layers_" + str(i)
+        if param_attr is not None and param_attr.name is not None:
+            layer_param_attr = copy.deepcopy(param_attr)
+            layer_param_attr.name += "_fw_w_" + str(i)
+        else:
+            layer_param_attr = param_attr
+        if bias_attr is not None and bias_attr.name is not None:
+            layer_bias_attr = copy.deepcopy(bias_attr)
+            layer_bias_attr.name += "_fw_b_" + str(i)
+        else:
+            layer_bias_attr = bias_attr
         fw_unit_list.append(
-            BasicGRUUnit(new_name, hidden_size, param_attr, bias_attr,
-                         gate_activation, activation, dtype))
+            BasicGRUUnit(new_name, hidden_size, layer_param_attr,
+                         layer_bias_attr, gate_activation, activation, dtype))

     if bidirectional:
         bw_unit_list = []
         for i in range(num_layers):
             new_name = name + "_reverse_layers_" + str(i)
+            if param_attr is not None and param_attr.name is not None:
+                layer_param_attr = copy.deepcopy(param_attr)
+                layer_param_attr.name += "_bw_w_" + str(i)
+            else:
+                layer_param_attr = param_attr
+            if bias_attr is not None and bias_attr.name is not None:
+                layer_bias_attr = copy.deepcopy(bias_attr)
+                layer_bias_attr.name += "_bw_b_" + str(i)
+            else:
+                layer_bias_attr = bias_attr
             bw_unit_list.append(
-                BasicGRUUnit(new_name, hidden_size, param_attr, bias_attr,
-                             gate_activation, activation, dtype))
+                BasicGRUUnit(new_name, hidden_size, layer_param_attr,
+                             layer_bias_attr, gate_activation, activation,
+                             dtype))

     if batch_first:
         input = layers.transpose(input, [1, 0, 2])
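
Stacked on top of the unit-level suffixes, basic_gru first adds a per-layer, per-direction suffix, so a named attr ends up as, e.g., test1_fw_w_0_gate. A sketch that enumerates the expected names for one bidirectional layer, assuming param_attr and bias_attr share the same base name as in the unit test below:

    def basic_gru_param_names(base, num_layers, bidirectional):
        # Compose the suffixes added by basic_gru and BasicGRUUnit;
        # a sketch, not part of the library.
        names = []
        directions = ["fw"] + (["bw"] if bidirectional else [])
        for d in directions:
            for i in range(num_layers):
                for kind in ("w", "b"):
                    for part in ("gate", "candidate"):
                        names.append("%s_%s_%s_%d_%s" % (base, d, kind, i, part))
        return names

    print(basic_gru_param_names("test1", num_layers=1, bidirectional=True))
    # ['test1_fw_w_0_gate', 'test1_fw_w_0_candidate', 'test1_fw_b_0_gate',
    #  'test1_fw_b_0_candidate', 'test1_bw_w_0_gate', 'test1_bw_w_0_candidate',
    #  'test1_bw_b_0_gate', 'test1_bw_b_0_candidate']
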
@@ -479,12 +521,22 @@ def basic_lstm(input,
     for i in range(num_layers):
         new_name = name + "_layers_" + str(i)
+        if param_attr is not None and param_attr.name is not None:
+            layer_param_attr = copy.deepcopy(param_attr)
+            layer_param_attr.name += "_fw_w_" + str(i)
+        else:
+            layer_param_attr = param_attr
+        if bias_attr is not None and bias_attr.name is not None:
+            layer_bias_attr = copy.deepcopy(bias_attr)
+            layer_bias_attr.name += "_fw_b_" + str(i)
+        else:
+            layer_bias_attr = bias_attr
         fw_unit_list.append(
             BasicLSTMUnit(
                 new_name,
                 hidden_size,
-                param_attr=param_attr,
-                bias_attr=bias_attr,
+                param_attr=layer_param_attr,
+                bias_attr=layer_bias_attr,
                 gate_activation=gate_activation,
                 activation=activation,
                 forget_bias=forget_bias,
@@ -494,12 +546,22 @@ def basic_lstm(input,
         for i in range(num_layers):
             new_name = name + "_reverse_layers_" + str(i)
+            if param_attr is not None and param_attr.name is not None:
+                layer_param_attr = copy.deepcopy(param_attr)
+                layer_param_attr.name += "_bw_w_" + str(i)
+            else:
+                layer_param_attr = param_attr
+            if bias_attr is not None and bias_attr.name is not None:
+                layer_bias_attr = copy.deepcopy(bias_attr)
+                layer_bias_attr.name += "_bw_b_" + str(i)
+            else:
+                layer_bias_attr = bias_attr
             bw_unit_list.append(
                 BasicLSTMUnit(
                     new_name,
                     hidden_size,
-                    param_attr=param_attr,
-                    bias_attr=bias_attr,
+                    param_attr=layer_param_attr,
+                    bias_attr=layer_bias_attr,
                     gate_activation=gate_activation,
                     activation=activation,
                     forget_bias=forget_bias,
new file: unit test for basic_gru / basic_lstm parameter names
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import unittest

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
from paddle.fluid.contrib.layers import basic_gru, basic_lstm
from paddle.fluid.executor import Executor
from paddle.fluid import framework
from test_imperative_base import new_program_scope


class TestBasicGRUApiName(unittest.TestCase):
def setUp(self):
self.name_set = set([
"test1_fw_w_0_gate", "test1_fw_w_0_candidate", "test1_fw_b_0_gate",
"test1_fw_b_0_candidate", "test1_bw_w_0_gate",
"test1_bw_w_0_candidate", "test1_bw_b_0_gate",
"test1_bw_b_0_candidate"
])

    def test_name(self):
batch_size = 20
input_size = 128
hidden_size = 256
num_layers = 1
dropout = 0.5
bidirectional = True
batch_first = False
with new_program_scope():
input = layers.data(
name="input",
shape=[-1, batch_size, input_size],
dtype='float32')
pre_hidden = layers.data(
name="pre_hidden", shape=[-1, hidden_size], dtype='float32')
sequence_length = layers.data(
name="sequence_length", shape=[-1], dtype='int32')
            rnn_out, last_hidden = basic_gru(
                input,
                pre_hidden,
                hidden_size,
                num_layers=num_layers,
                sequence_length=sequence_length,
                dropout_prob=dropout,
                bidirectional=bidirectional,
                batch_first=batch_first,
                param_attr=fluid.ParamAttr(name="test1"),
                bias_attr=fluid.ParamAttr(name="test1"),
                name="basic_gru")
var_list = fluid.io.get_program_parameter(
fluid.default_main_program())
for var in var_list:
self.assertTrue(var.name in self.name_set)


class TestBasicLSTMApiName(unittest.TestCase):
def setUp(self):
self.name_set = set([
"test1_fw_w_0", "test1_fw_b_0", "test1_fw_w_1", "test1_fw_b_1",
"test1_bw_w_0", "test1_bw_b_0", "test1_bw_w_1", "test1_bw_b_1"
])

    def test_name(self):
batch_size = 20
input_size = 128
hidden_size = 256
num_layers = 2
dropout = 0.5
bidirectional = True
batch_first = False
with new_program_scope():
input = layers.data(
name="input",
shape=[-1, batch_size, input_size],
dtype='float32')
pre_hidden = layers.data(
name="pre_hidden", shape=[-1, hidden_size], dtype='float32')
pre_cell = layers.data(
name="pre_cell", shape=[-1, hidden_size], dtype='float32')
sequence_length = layers.data(
name="sequence_length", shape=[-1], dtype='int32')
            rnn_out, last_hidden, last_cell = basic_lstm(
                input,
                pre_hidden,
                pre_cell,
                hidden_size,
                num_layers=num_layers,
                sequence_length=sequence_length,
                dropout_prob=dropout,
                bidirectional=bidirectional,
                param_attr=fluid.ParamAttr(name="test1"),
                bias_attr=fluid.ParamAttr(name="test1"),
                batch_first=batch_first)
var_list = fluid.io.get_program_parameter(
fluid.default_main_program())
for var in var_list:
self.assertTrue(var.name in self.name_set)


if __name__ == '__main__':
unittest.main()