提交 aeb2d848 编写于 作者: H Haonan 提交者: GitHub

Merge pull request #76 from emailweixu/fix_RecurrentGradientMachine

Further fix for memory of RecurrentGradientMachine
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "hl_matrix_apply.cuh"
#include "hl_sequence.h"
#include "paddle/utils/Logging.h"
#include "hl_device_functions.cuh"
DEFINE_MATRIX_UNARY_OP(Zero, a = 0);
DEFINE_MATRIX_TERNARY_PARAMETER_OP(_add, TWO_PARAMETER, c = p1*a + p2*b);
......
......@@ -434,23 +434,25 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs,
}
}
seqLengthAndStart_.clear();
info_.clear();
info_.resize(inFrameLines_.size());
seqLengthAndStart_.resize(inFrameLines_.size());
seqInfos_.clear();
seqInfos_.resize(inFrameLines_.size());
{
AsyncGpuBlock asyncGpuBlock;
// if shareInlinkInfo, only calculate info of the first inlink
// else, calculate info for each inlink
if (shareInlinkInfo) {
input.getSeqLengthAndStart(&seqLengthAndStart_[0], &maxSequenceLength_);
input.getSeqInfo(&seqInfos_[0]);
maxSequenceLength_ = seqInfos_[0][0].topLevelLength;
createInFrameInfo(0, input, passType);
} else {
for (size_t i = 0; i < inFrameLines_.size(); i++) {
const Argument& input1 = inFrameLines_[i].inLayer->getOutput();
input1.getSeqLengthAndStart(&seqLengthAndStart_[i],
&maxSequenceLength_);
input1.getSeqInfo(&seqInfos_[i]);
maxSequenceLength_ = seqInfos_[i][0].topLevelLength;
createInFrameInfo(i, input1, passType);
}
}
......@@ -614,7 +616,7 @@ void RecurrentGradientMachine::removeBeamSearchStatisticsCallbacks() {
* for all realLayer of inFrameLines one time.
*/
void RecurrentGradientMachine::createInFrameInfo(int inlinks_id,
void RecurrentGradientMachine::createInFrameInfo(int inlinkId,
const Argument& input,
PassType passType) {
bool hasSubseq = input.hasSubseq();
......@@ -622,66 +624,67 @@ void RecurrentGradientMachine::createInFrameInfo(int inlinks_id,
size_t numSequences = input.getNumSequences();
std::vector<int> allIds;
auto& seqInfo = seqInfos_[inlinkId];
numSeqs_.clear();
Info* inlink_info = &info_[inlinks_id];
inlink_info->idIndex.clear();
inlink_info->idIndex.push_back(0); // first idIndex = 0
Info* inlinkInfo = &info_[inlinkId];
inlinkInfo->idIndex.clear();
inlinkInfo->idIndex.push_back(0); // first idIndex = 0
std::vector<int> sequenceStartPositions;
const int* subSequenceStartPositions = nullptr;
if (hasSubseq) { // for sequenceScatterAgentLayer
// numSubSequences : all sentences within all samples(batch)
size_t numSubSequences = input.getNumSubSequences();
std::vector<int> sequenceStartPositions;
inlink_info->seqStartPosIndex.clear();
inlink_info->seqStartPosIndex.push_back(0); // first seqStartPosIndex = 0
// maxSequenceLength_: max number of sentences(subseq) in allsamples
for (int i = 0; i < maxSequenceLength_; ++i) {
subSequenceStartPositions =
input.subSequenceStartPositions->getData(false);
inlinkInfo->seqStartPosIndex.clear();
inlinkInfo->seqStartPosIndex.push_back(0); // first seqStartPosIndex = 0
}
// maxSequenceLength_: max topLevelLength in allsamples
for (int i = 0; i < maxSequenceLength_; ++i) {
if (hasSubseq) {
sequenceStartPositions.push_back(0); // first element = 0
int numSeqs = 0;
for (size_t j = 0; j < numSubSequences; ++j) { // for each sentence
// seqLengthAndStart_[inlinks_id][j]:
// a 4-tuple including <subseqlen, subseqstart, seqid, subseqid>
if (std::get<3>(seqLengthAndStart_[inlinks_id][j]) == i) {
++numSeqs;
// subseqstart: the cpuSubSequenceStartPositions of this subseq
int subSeqStart = std::get<1>(seqLengthAndStart_[inlinks_id][j]);
int subSeqLength = std::get<0>(seqLengthAndStart_[inlinks_id][j]);
for (int k = subSeqStart; k < subSeqStart + subSeqLength; ++k) {
allIds.push_back(k);
}
sequenceStartPositions.push_back(sequenceStartPositions.back() +
subSeqLength);
}
}
inlink_info->idIndex.push_back(allIds.size());
inlink_info->seqStartPosIndex.push_back(sequenceStartPositions.size());
numSeqs_.push_back(numSeqs);
}
// inFrameLine create sequenceStartPositions one time
CHECK_EQ(sequenceStartPositions.size(),
maxSequenceLength_ + numSubSequences);
CHECK_EQ(inlink_info->seqStartPosIndex.size(),
static_cast<size_t>(maxSequenceLength_ + 1));
createSeqPos(sequenceStartPositions, &inlink_info->sequenceStartPositions);
} else { // for scatterAgentLayer
for (int i = 0; i < maxSequenceLength_; ++i) {
int numSeqs = 0;
for (size_t j = 0; j < numSequences; ++j) {
int seqLength = std::get<0>(seqLengthAndStart_[inlinks_id][j]);
if (i >= seqLength) {
break;
int numSeqs = 0;
for (size_t j = 0; j < numSequences; ++j) {
int seqLength = seqInfo[j].topLevelLength;
if (i >= seqLength) {
break;
}
++numSeqs;
if (hasSubseq) {
int subSeqStart = subSequenceStartPositions[seqInfo[j].subSeqStart + i];
int subSeqEnd =
subSequenceStartPositions[seqInfo[j].subSeqStart + i + 1];
for (int k = subSeqStart; k < subSeqEnd; ++k) {
allIds.push_back(k);
}
++numSeqs;
int seqStart = std::get<1>(seqLengthAndStart_[inlinks_id][j]);
sequenceStartPositions.push_back(sequenceStartPositions.back() +
subSeqEnd - subSeqStart);
} else {
int seqStart = seqInfo[j].seqStart;
allIds.push_back(reversed_ ? (seqStart + seqLength - 1 - i)
: (seqStart + i));
}
inlink_info->idIndex.push_back(allIds.size());
numSeqs_.push_back(numSeqs);
}
inlinkInfo->idIndex.push_back(allIds.size());
numSeqs_.push_back(numSeqs);
if (hasSubseq) {
inlinkInfo->seqStartPosIndex.push_back(sequenceStartPositions.size());
}
}
if (hasSubseq) {
// inFrameLine create sequenceStartPositions one time
CHECK_EQ(sequenceStartPositions.size(),
maxSequenceLength_ + input.getNumSubSequences());
CHECK_EQ(inlinkInfo->seqStartPosIndex.size(),
static_cast<size_t>(maxSequenceLength_ + 1));
createSeqPos(sequenceStartPositions, &inlinkInfo->sequenceStartPositions);
}
// copy and check scatterId
copyScattedId(allIds, &inlink_info->allIds, input.getBatchSize());
CHECK_EQ(inlink_info->idIndex.size(),
copyScattedId(allIds, &inlinkInfo->allIds, input.getBatchSize());
CHECK_EQ(inlinkInfo->idIndex.size(),
static_cast<size_t>(maxSequenceLength_ + 1));
}
......@@ -701,7 +704,7 @@ void RecurrentGradientMachine::createMemoryFrameInfo(
const int* starts = input.sequenceStartPositions->getData(false);
for (size_t i = 0; i < numSequences; ++i) {
// memory info adopt info of inlinks[0]
int seqId = std::get<2>(seqLengthAndStart_[0][i]);
int seqId = seqInfos_[0][i].seqId;
for (int k = starts[seqId]; k < starts[seqId + 1]; ++k) {
allIds.push_back(k);
}
......@@ -713,7 +716,7 @@ void RecurrentGradientMachine::createMemoryFrameInfo(
} else { // for scatterAgentLayer
for (size_t i = 0; i < numSequences; ++i) {
allIds.push_back(std::get<2>(seqLengthAndStart_[0][i]));
allIds.push_back(seqInfos_[0][i].seqId);
}
}
// copy and check scatterId
......
......@@ -337,11 +337,7 @@ protected:
// data) or has more than i subsequences (for subsequence data)
std::vector<int> numSeqs_;
// each inlinks has a "std::vector<std::tuple<int, int, int, int>>" denotes
// its sequence info:
// if hasSubSeq, tuple of (subSeqLength, subSeqStart, seqIndex, subSeqIndex)
// else, tuple of (seqLength, seqStart, seqIndex, seqIndex)
std::vector<std::vector<std::tuple<int, int, int, int>>> seqLengthAndStart_;
std::vector<std::vector<Argument::SeqInfo>> seqInfos_;
// the id of inlink which share info with outlinks
int targetInfoInlinkId_;
......
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Layer.h"
namespace paddle {
class PrintLayer : public Layer {
public:
explicit PrintLayer(const LayerConfig& config)
: Layer(config) {}
void forward(PassType passType);
void backward(const UpdateCallback& callback) {}
};
void PrintLayer::forward(PassType passType) {
Layer::forward(passType);
for (size_t i = 0; i != inputLayers_.size(); ++i) {
const auto& argu = getInput(i);
const std::string& name = inputLayers_[i]->getName();
if (argu.value) {
std::ostringstream os;
argu.value->print(os);
LOG(INFO) << "layer=" << name << " value matrix:\n" << os.str();
}
if (argu.ids) {
std::ostringstream os;
argu.ids->print(os, argu.ids->getSize());
LOG(INFO) << "layer=" << name << " ids vector:\n" << os.str();
}
if (auto startPos = argu.sequenceStartPositions) {
std::ostringstream os;
startPos->getVector(false)->print(os, startPos->getSize());
LOG(INFO) << "layer=" << name << " sequence pos vector:\n" << os.str();
}
if (auto subStartPos = argu.subSequenceStartPositions) {
std::ostringstream os;
subStartPos->getVector(false)->print(os, subStartPos->getSize());
LOG(INFO) << "layer=" << name << " sub-sequence pos vector:\n"
<< os.str();
}
}
}
REGISTER_LAYER(print, PrintLayer);
} // namespace paddle
......@@ -42,14 +42,16 @@ def outer_step(x):
inner_mem = memory(name="inner_rnn_state",
size=hidden_dim,
boot_layer=outer_mem)
return fc_layer(input=[y, inner_mem],
out = fc_layer(input=[y, inner_mem],
size=hidden_dim,
act=TanhActivation(),
bias_attr=True,
name="inner_rnn_state")
return out
inner_rnn_output = recurrent_group(
step=inner_step,
name="inner",
input=x)
last = last_seq(input=inner_rnn_output, name="outer_rnn_state")
......@@ -60,11 +62,10 @@ def outer_step(x):
return inner_rnn_output
out = recurrent_group(
name="outer",
step=outer_step,
input=SubsequenceInput(emb))
value_printer_evaluator(input=out)
rep = last_seq(input=out)
prob = fc_layer(size=label_dim,
input=rep,
......
......@@ -35,18 +35,18 @@ emb = embedding_layer(input=data, size=word_dim)
def step(y):
mem = memory(name="rnn_state", size=hidden_dim)
return fc_layer(input=[y, mem],
out = fc_layer(input=[y, mem],
size=hidden_dim,
act=TanhActivation(),
bias_attr=True,
name="rnn_state")
return out
out = recurrent_group(
name="rnn",
step=step,
input=emb)
value_printer_evaluator(input=out)
rep = last_seq(input=out)
prob = fc_layer(size=label_dim,
input=rep,
......
......@@ -92,7 +92,7 @@ void CalCost(const string& conf, const string& dir, real* cost,
rmDir(dir.c_str());
}
void test(const string& conf1, const string& conf2) {
void test(const string& conf1, const string& conf2, double eps) {
int num_passes = 5;
real* cost1 = new real[num_passes];
const string dir1 = "gserver/tests/t1";
......@@ -104,8 +104,9 @@ void test(const string& conf1, const string& conf2) {
for (int i = 0; i < num_passes; i++) {
LOG(INFO) << "num_passes: " << i << ", cost1=" << cost1[i]
<< ", cost2=" << cost2[i];
ASSERT_NEAR(cost1[i], cost2[i], 1e-3);
<< ", cost2=" << cost2[i]
<< ", diff=" << std::abs(cost1[i] - cost2[i]);
ASSERT_NEAR(cost1[i], cost2[i], eps);
}
delete[] cost1;
delete[] cost2;
......@@ -113,12 +114,14 @@ void test(const string& conf1, const string& conf2) {
TEST(RecurrentGradientMachine, HasSubSequence) {
test("gserver/tests/sequence_layer_group.conf",
"gserver/tests/sequence_nest_layer_group.conf");
"gserver/tests/sequence_nest_layer_group.conf",
1e-5);
}
TEST(RecurrentGradientMachine, rnn) {
test("gserver/tests/sequence_rnn.conf",
"gserver/tests/sequence_nest_rnn.conf");
"gserver/tests/sequence_nest_rnn.conf",
0);
}
......
......@@ -477,51 +477,34 @@ void Argument::splitByDataId(const std::vector<Argument>& argus,
}
}
void Argument::getSeqLengthAndStart(
std::vector<std::tuple<int, int, int, int>>* seqLengthAndStart,
int* maxSequenceLength) const {
void Argument::getSeqInfo(std::vector<SeqInfo>* seqInfo) const {
const int* starts = sequenceStartPositions->getData(false);
if (hasSubseq()) {
size_t numSubSequences = getNumSubSequences();
(*seqLengthAndStart).reserve(numSubSequences);
const int* subStarts = subSequenceStartPositions->getData(false);
int seqIndex = 0;
int subSeqIndex = 0;
*maxSequenceLength = 0;
for (size_t i = 0; i < numSubSequences; ++i) {
if (subStarts[i] == starts[seqIndex]) {
subSeqIndex = 0;
(*seqLengthAndStart)
.push_back(std::make_tuple<int, int, int, int>(
subStarts[i + 1] - subStarts[i], (int)subStarts[i],
(int)seqIndex, (int)subSeqIndex));
++subSeqIndex;
++seqIndex;
} else if (subStarts[i] < starts[seqIndex]) {
(*seqLengthAndStart)
.push_back(std::make_tuple<int, int, int, int>(
subStarts[i + 1] - subStarts[i], (int)subStarts[i],
(int)seqIndex - 1, (int)subSeqIndex));
++subSeqIndex;
const int* subStarts = hasSubseq()
? subSequenceStartPositions->getData(false) : nullptr;
size_t numSequences = getNumSequences();
seqInfo->reserve(numSequences);
int subSeqEnd = 0;
for (size_t i = 0; i < numSequences; ++i) {
SeqInfo info;
info.seqStart = starts[i];
info.subLevelLength = starts[i + 1] - starts[i];
info.seqId = i;
if (hasSubseq()) {
info.subSeqStart = subSeqEnd;
while (subStarts[subSeqEnd] < starts[i + 1]) {
++subSeqEnd;
}
// maxSequenceLength_ = 1 + max(subSeqIndex) in each Seq.
if (*maxSequenceLength < std::get<3>((*seqLengthAndStart)[i]))
*maxSequenceLength = std::get<3>((*seqLengthAndStart)[i]);
}
*maxSequenceLength += 1;
} else {
size_t numSequences = getNumSequences();
(*seqLengthAndStart).reserve(numSequences);
for (size_t i = 0; i < numSequences; ++i) {
(*seqLengthAndStart)
.push_back(std::make_tuple<int, int, int, int>(
starts[i + 1] - starts[i], (int)starts[i], (int)i, (int)i));
info.topLevelLength = subSeqEnd - info.subSeqStart;
} else {
info.topLevelLength = info.subLevelLength;
info.subSeqStart = 0; // not used
}
std::sort((*seqLengthAndStart).begin(), (*seqLengthAndStart).end(),
std::greater<std::tuple<int, int, int, int>>());
*maxSequenceLength = std::get<0>((*seqLengthAndStart)[0]);
seqInfo->push_back(info);
}
std::sort(seqInfo->begin(), seqInfo->end(),
[](const SeqInfo& a, const SeqInfo& b) {
return a.topLevelLength > b.topLevelLength;
});
}
void Argument::checkSubset() const {
......
......@@ -253,21 +253,29 @@ struct Argument {
static void splitByDataId(const std::vector<Argument>& argus,
std::vector<std::vector<Argument>>* arguGroups);
struct SeqInfo {
// Equal to sequence length for sequence data
// Equal to number of subsequences for subsequence data
int topLevelLength;
int seqStart;
int seqId;
// Equal to topLevelLength for sequence data
// Equal to sum of the length of subsequences for subsequence data
int subLevelLength;
// Only used for subsequence data, start position of this sequence
// is subSequenceStartPositions, i.e.
// subSequenceStartPositions[subSeqStart] == seqStart
int subSeqStart;
};
/*
Get Sequence Length, startPositions and max Length according to input
1. For sequence data:
Each tuple is (seq_length, seq_start, seq_id, seq_id)
The tuples are sorted according to seq_length or subseq_length
*maxSequenceLength is the maximal sequence length
2. For subsequence data:
Each tuple is (subseq_length, subseq_start, seq_id, subseq_id)
The tuples are not sorted. They are in the original order.
*maxSequenceLenth is the maximal number of subsequences in each sequence.
*/
void getSeqLengthAndStart(
std::vector<std::tuple<int, int, int, int>>* seqLengthAndStart,
int* maxSequenceLength) const;
Get SeqInfo for each sequence of this argument
Elements in *seqInfo are sorted by topLevelLength in descending order
*/
void getSeqInfo(std::vector<SeqInfo>* segInfo) const;
/*
Check Whether sequenceStartPositions is subset of
subSequenceStartPositions.
......
......@@ -1408,6 +1408,14 @@ class SelectiveFCLayer(LayerBase):
input_index, psize, dims, sparse, format)
self.create_bias_parameter(bias, self.config.size)
@config_layer('print')
class PrintLayer(LayerBase):
def __init__(
self,
name,
inputs):
super(PrintLayer, self).__init__(name, 'print', 0, inputs)
@config_layer('data')
class DataLayer(LayerBase):
def __init__(
......
......@@ -21,7 +21,6 @@ from .evaluators import *
from .poolings import MaxPooling, AvgPooling, BasePoolingType
from .attrs import *
from .default_decorators import *
try:
import cPickle as pickle
except ImportError:
......@@ -52,7 +51,7 @@ __all__ = ["full_matrix_projection", "AggregateLevel", "ExpandLevel",
'cross_entropy_with_selfnorm', 'cross_entropy',
'multi_binary_label_cross_entropy',
'rank_cost', 'lambda_cost', 'huber_cost',
'block_expand_layer', 'out_prod_layer',
'block_expand_layer', 'out_prod_layer', 'print_layer'
]
......@@ -108,6 +107,8 @@ class LayerType(object):
LINEAR_COMBINATION_LAYER = "convex_comb"
BLOCK_EXPAND = "blockexpand"
PRINT_LAYER = "print"
CTC_LAYER = "ctc"
CRF_LAYER = "crf"
CRF_DECODING_LAYER = "crf_decoding"
......@@ -202,6 +203,25 @@ ERROR_CLIPPING = 'error_clipping_threshold'
DROPOUT = 'drop_rate'
def check_input(input):
"""
Check input is a LayerOutput or list of LayerOutput or tuple of LayerOutput
if is a LayerOutput,
:param input: The input layer. Could be a list/tuple of input layer.
:type input: LayerOutput|list|tuple
:return: list of LayerOutput
:rtype: list of LayerOutput
"""
if isinstance(input, LayerOutput):
return [LayerOutput]
assert isinstance(input, list)
for inp in input:
assert isinstance(inp, LayerOutput)
return list(input)
def layer_support(*attrs):
def decorator(method):
@functools.wraps(method)
......@@ -730,6 +750,27 @@ def fc_layer(input, size, act=None, name=None,
size=size)
@wrap_name_default("print")
def print_layer(input, name=None):
"""
Print the output value of input layers. This layer is useful for debugging.
:param name: The Layer Name.
:type name: basestring
:param input: The input layer. Could be a list/tuple of input layer.
:type input: LayerOutput|list|tuple
:return: No return
"""
check_input(input)
Layer(
name=name,
type=LayerType.PRINT_LAYER,
inputs=[l.name for l in input],
)
LayerOutput(name, LayerType.PRINT_LAYER, input)
@wrap_name_default("seq_pooling")
@wrap_bias_attr_default(has_bias=False)
@wrap_param_default(['pooling_type'], default_factory=lambda _: MaxPooling())
......
......@@ -34,6 +34,8 @@ out = fc_layer(input=[cos1, cos3, linear_comb, z],
size=num_classes,
act=SoftmaxActivation())
print_layer(input=[out])
outputs(classification_cost(out, data_layer(name="label", size=num_classes)))
# for ctc
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册