Commit 17994e38 authored by xuwei06

RecurrentGroup with mixed input sequence types

SubsequenceInput no longer needs to be used; the framework will detect the sequence type of each input automatically.
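Illustration (not part of this commit; the layer names and step function below are hypothetical): with this change a recurrent_group can take a sub-sequence input, a plain sequence input, and a StaticInput side by side, without wrapping anything in SubsequenceInput. A minimal config sketch:

# Minimal sketch of a mixed-input recurrent_group (hypothetical names).
# Before this commit the nested input had to be passed as
# SubsequenceInput(subseq_emb); now plain layer outputs are enough and the
# framework detects each input's sequence type.
out = recurrent_group(
    name="outer",
    step=outer_step,               # user-defined step function
    input=[
        subseq_emb,                # sub-sequence (nested) input
        seq_emb,                   # plain sequence input
        StaticInput(encoding),     # non-varying input, unchanged
    ])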
Parent 14c0e71d
...@@ -284,6 +284,16 @@ public:
}
protected:
std::vector<Argument::SeqInfo> commonSeqInfo_;
ICpuGpuVectorPtr sequenceStartPositions_;
void calcSequenceStartPositions();
void checkInputConsistency(int inlinkId,
const std::vector<Argument::SeqInfo>& seqInfo);
void reorganizeInput(PassType passType);
void reorganizeOutput(PassType passType);
void connectFrames(PassType passType);
void calcNumSequencesAtEachStep();
void resizeOrCreateFrames(int numFrames);
void resizeBootFrame(int numSequences);
...@@ -295,8 +305,7 @@ protected:
std::string linkName;
LayerPtr inLayer;
std::vector<LayerPtr> agents; // Scatter Agents to reform batch input
bool hasSubseq;
Argument outArg; // scatter output argument
};
std::vector<InFrameLine> inFrameLines_;
...@@ -318,7 +327,6 @@ protected:
std::vector<LayerPtr> agents;
std::vector<LayerPtr> scatterAgents; // scatter agent used by beam search
Argument outArg; // scatter output argument
bool is_sequence;
// Different memoryFrameLine have different element as follows
IVectorPtr allIds; // scattered id of realLayer
ICpuGpuVectorPtr
...@@ -330,22 +338,27 @@ protected:
// and all outFrameLines(outlinks) share the info with one inFrameLine,
// which is assigned by targetInfoInlinkId_.
struct Info {
IVectorPtr allIds; // scattered id of realLayer
std::vector<int> idIndex; // index of allIds
// The original positions in the original batch
IVectorPtr allIds; // scattered id of realLayer [batchSize]
// index of allIds for each step [maxSequenceLength_]
// idIndex[i] is the total length of the first i sequences
std::vector<int> idIndex;
ICpuGpuVectorPtr
sequenceStartPositions; // scattered sequenceStartPositions
std::vector<int> seqStartPosIndex; // index of sequenceStartPositions
};
std::vector<Info> info_; // for input
// numSeqs_[i] is the number sequences which is longer than i (for sequence
// data) or has more than i subsequences (for subsequence data)
// Equivalently, numSeqs_[i] is the number of sequences at step i;
std::vector<int> numSeqs_;
std::vector<std::vector<Argument::SeqInfo>> seqInfos_;
// the id of inlink which share info with outlinks
void checkOutputConsistency(OutFrameLine& outFrameLine);
int targetInfoInlinkId_;
/* create scattered id infomation for all realLayer of inFrameLines one time.
* If hasSubseq, will also create scattered sequenceStartPositions infomation
...@@ -354,6 +367,28 @@ protected:
void createInFrameInfo(int inlinks_id,
const Argument& input,
PassType passType);
void createInFrameInfo_nonseq(int inlinks_id,
const Argument& input,
PassType passType);
void createInFrameInfo_seq(int inlinks_id,
const Argument& input,
PassType passType);
void createInFrameInfo_subseq(int inlinks_id,
const Argument& input,
PassType passType);
void createOutFrameInfo(OutFrameLine& outFrameLine,
Info& info,
ICpuGpuVectorPtr& sequenceStartPositions,
ICpuGpuVectorPtr& subSequenceStartPositions);
void createOutFrameInfo_seq(OutFrameLine& outFrameLine,
Info& info,
ICpuGpuVectorPtr& sequenceStartPositions,
ICpuGpuVectorPtr& subSequenceStartPositions);
void createOutFrameInfo_subseq(OutFrameLine& outFrameLine,
Info& info,
ICpuGpuVectorPtr& sequenceStartPositions,
ICpuGpuVectorPtr& subSequenceStartPositions);
void createMemoryFrameInfo(MemoryFrameLine* memoryFrameLine,
PassType passType);
...@@ -386,9 +421,7 @@ protected:
NeuralNetwork* rootNetwork_;
bool reversed_;
// if hasSubseq: max number of sentences(subseq)in batchsize samples
// else: max number of tokens in batchsize samples(sentences)
int maxSequenceLength_;
int maxSequenceLength_; // Max top-level length
bool useGpu_;
bool stopBeamSearch_;
......
...@@ -36,14 +36,23 @@ void AgentLayer::forward(PassType passType) {
Layer::forward(passType);
Argument& realOutput = realLayer_->getOutput();
int realHeight = realOutput.getBatchSize();
CHECK_LE(numSamples_, realHeight);
int realNumSequences = realOutput.getNumSequences();
CHECK_LE(numSamples_, realNumSequences);
// get Arguments from real layers
if (numSamples_ > 0 && numSamples_ < realHeight) {
if (realOutput.ids) {
output_.ids =
IVector::create(realOutput.ids->getData(), numSamples_, useGpu_);
if (numSamples_ > 0 && numSamples_ < realNumSequences) {
if (realOutput.hasSeq()) {
int numRows =
realOutput.sequenceStartPositions->getData(false)[numSamples_];
output_.subArgFrom(realOutput,
/* offset */ 0,
numRows,
getSize(),
useGpu_,
/* trans */ false,
/* seqFlag */ true,
/* seqStart */ 0,
/* seqSize */ numSamples_ + 1);
} else {
output_.subArgFrom(
realOutput, /* offset */ 0, numSamples_, getSize(), useGpu_);
...@@ -53,34 +62,6 @@ void AgentLayer::forward(PassType passType) {
}
}
void SequenceAgentLayer::forward(PassType passType) {
Layer::forward(passType);
Argument& realOutput = realLayer_->getOutput();
int realNumSequences = realOutput.getNumSequences();
CHECK_LE(numSamples_, realNumSequences);
// get Arguments from real layers
if (numSamples_ > 0 && numSamples_ < realNumSequences) {
int numRows =
realOutput.sequenceStartPositions->getData(false)[numSamples_];
CHECK(!realOutput.ids) << "Not supported";
output_.subArgFrom(realOutput,
/* offset */ 0,
numRows,
getSize(),
useGpu_,
/* trans */ false,
/* seqFlag */ true,
/* seqStart */ 0,
/* seqSize */ numSamples_ + 1);
} else {
output_ = realOutput;
}
}
REGISTER_LAYER(sequence_agent, SequenceAgentLayer);
bool GatherAgentLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
CHECK_EQ(config_.inputs_size(), 0);
...@@ -91,18 +72,26 @@ bool GatherAgentLayer::init(const LayerMap& layerMap,
return true;
}
void GatherAgentLayer::copyIdAndSequenceInfo(const Argument& input,
const IVectorPtr& ids,
const std::vector<int>& idIndex) {
output_.sequenceStartPositions = input.sequenceStartPositions;
output_.subSequenceStartPositions = input.subSequenceStartPositions;
realLayers_.clear();
void GatherAgentLayer::copyIdAndSequenceInfo(
ICpuGpuVectorPtr sequenceStartPositions,
ICpuGpuVectorPtr subSequenceStartPositions,
const IVectorPtr& ids,
const std::vector<int>& idIndex) {
output_.sequenceStartPositions = sequenceStartPositions;
output_.subSequenceStartPositions = subSequenceStartPositions;
allIds_ = ids;
idIndex_ = idIndex;
}
void GatherAgentLayer::forward(PassType passType) {
Layer::forward(passType);
forwardIds(passType);
forwardValue(passType);
}
void GatherAgentLayer::forwardValue(PassType passType) {
MatrixPtr valueReal = realLayers_[0]->getOutputValue();
if (!valueReal) return;
int height = allIds_->getSize();
int width = this->getSize();
...@@ -147,7 +136,9 @@ void ScatterAgentLayer::forward(PassType passType) {
CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId());
int width = this->getSize();
if (realOutArg_.value || realOutArg_.ids) {
if (realOutArg_.hasSeq()) {
forwardSequence(passType);
} else if (realOutArg_.value || realOutArg_.ids) {
output_.subArgFrom(
realOutArg_, /* offset */ idIndex_, idSize_, width, useGpu_);
} else { // used in generation
...@@ -174,7 +165,7 @@ void ScatterAgentLayer::backward(const UpdateCallback& callback) {
if (realGrad) {
// for agent in inFrameLines and memoryFrameLines,
// only first scatterAgentLayer should do addToRows in backward
if (idIndex_ == 0) {
if (handleBackward_) {
outputGrad->addToRows(*realGrad, *ids_);
}
}
...@@ -183,12 +174,14 @@ void ScatterAgentLayer::backward(const UpdateCallback& callback) {
REGISTER_LAYER(gather_agent, GatherAgentLayer);
REGISTER_LAYER(scatter_agent, ScatterAgentLayer);
void SequenceGatherAgentLayer::forward(PassType passType) {
Layer::forward(passType);
int height = 0;
int* starts = output_.subSequenceStartPositions->getMutableData(false);
IVectorPtr idReal = realLayers_[0]->getOutputLabel();
if (idReal) {
void GatherAgentLayer::forwardIds(PassType passType) {
int height = 0;
IVectorPtr idReal = realLayers_[0]->getOutputLabel();
if (!idReal) return;
if (output_.subSequenceStartPositions) {
int* starts = output_.subSequenceStartPositions->getMutableData(false);
// Gather generator.idsVec
// if is beam search generation result. Get first result.
if (idReal->getData()[idReal->getSize() - 1] == -1) {
...@@ -212,13 +205,11 @@ void SequenceGatherAgentLayer::forward(PassType passType) {
->copyFrom(*realLayers_[i]->getOutputLabel());
}
} else {
// Gather output.value, same as GatherAgentLayer
CHECK(output_.subSequenceStartPositions);
GatherAgentLayer::forward(passType);
LOG(FATAL) << "Not implemented";
}
}
void SequenceScatterAgentLayer::forward(PassType passType) {
void ScatterAgentLayer::forwardSequence(PassType passType) {
Layer::forward(passType);
CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId());
...@@ -241,6 +232,7 @@ void SequenceScatterAgentLayer::forward(PassType passType) {
/* seqStart */ seqStartPosIndex_,
/* seqSize */ numSequences_);
} else {
// Putting the generation logic here is really an ugly hack!
// used in generation
int height = 0;
size_t numSequences = ids_->getSize();
...@@ -284,7 +276,4 @@ void SequenceScatterAgentLayer::forward(PassType passType) {
}
}
REGISTER_LAYER(sequence_gather_agent, SequenceGatherAgentLayer);
REGISTER_LAYER(sequence_scatter_agent, SequenceScatterAgentLayer);
} // namespace paddle
...@@ -49,18 +49,6 @@ public:
void backward(const UpdateCallback& callback = nullptr) override {}
};
/**
* like AgentLayer, but use first *numSamples* sequences
*/
class SequenceAgentLayer : public AgentLayer {
public:
explicit SequenceAgentLayer(const LayerConfig& config) : AgentLayer(config) {}
~SequenceAgentLayer() {}
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override {}
};
/**
* Like AgentLayer, but it can gather many real layers. Each real
* layer give a few rows of a sequence, after gather all real layers,
...@@ -83,7 +71,10 @@ public:
const ParameterMap& parameterMap) override;
// call before addRealLayer
void copyIdAndSequenceInfo(const Argument& input,
void clearRealLayers() { realLayers_.clear(); }
void copyIdAndSequenceInfo(ICpuGpuVectorPtr sequenceStartPositions,
ICpuGpuVectorPtr subSequenceStartPositions,
const IVectorPtr& allIds,
const std::vector<int>& idIndex);
...@@ -92,24 +83,8 @@ public:
void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
};
void forwardValue(PassType passType);
void forwardIds(PassType passType);
/**
* Like GatherAgentLayer, but select a few sequence in real layer.
* *ids* in addRealLayer() are the ids of selected sequence.
* It's used to reorder sequence output.
*/
class SequenceGatherAgentLayer : public GatherAgentLayer {
public:
explicit SequenceGatherAgentLayer(const LayerConfig& config)
: GatherAgentLayer(config) {}
virtual ~SequenceGatherAgentLayer() {}
void forward(PassType passType);
void backward(const UpdateCallback& callback) {
// same as GatherAgentLayer
GatherAgentLayer::backward(callback);
}
};
/**
...@@ -129,6 +104,11 @@ protected:
int idSize_;
int seqStartPosIndex_;
int numSequences_; // number of sequences in this scatterAgentLayer
bool handleBackward_;
// use to store expanded cpuStartPositions or subSequenceStartPositions
// of real layer.
ICpuGpuVectorPtr inputStartPos_;
public:
explicit ScatterAgentLayer(const LayerConfig& config) : Layer(config) {}
...@@ -147,19 +127,15 @@ public:
* false(default) in ScatterAgentLayer, and
* true in SequenceScatterAgentLayer.
*/
void setRealLayer(LayerPtr layer,
const std::vector<int>& ids,
bool copyId = false) {
void setRealLayer(LayerPtr layer, const std::vector<int>& ids) {
realLayer_ = layer;
IVector::resizeOrCreate(ids_, ids.size(), useGpu_);
ids_->copyFrom(ids.data(), ids.size());
if (copyId) {
if (useGpu_) {
IVector::resizeOrCreate(cpuIds_, ids.size(), false);
cpuIds_->copyFrom(ids.data(), ids.size());
} else {
cpuIds_ = ids_;
}
}
}
...@@ -169,12 +145,14 @@ public:
const Argument& outArg,
const IVectorPtr& ids,
int idIndex,
int idSize) {
int idSize,
bool handleBackward) {
realLayer_ = layer;
realOutArg_ = outArg;
ids_ = ids;
idIndex_ = idIndex;
idSize_ = idSize;
handleBackward_ = handleBackward;
}
void setSequenceStartPositions(const ICpuGpuVectorPtr& sequenceStartPositions,
...@@ -187,28 +165,8 @@ public:
void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
};
void forwardSequence(PassType passType);
/**
* Like ScatterAgentLayer, but select a few sequence in real layer.
* *ids* in setRealLayer() or setRealLayerAndOutput() are the ids of
* selected sequence. It's used to reorder sequence input.
*/
class SequenceScatterAgentLayer : public ScatterAgentLayer {
protected:
// use to store expanded cpuStartPositions or subSequenceStartPositions
// of real layer.
ICpuGpuVectorPtr inputStartPos_;
public:
explicit SequenceScatterAgentLayer(const LayerConfig& config)
: ScatterAgentLayer(config) {}
virtual ~SequenceScatterAgentLayer() {}
void forward(PassType passType);
void backward(const UpdateCallback& callback) {
ScatterAgentLayer::backward(callback);
}
};
} // namespace paddle
...@@ -46,6 +46,9 @@ void SequencePoolLayer::forward(PassType passType) {
Layer::forward(passType);
const Argument& input = getInput(0);
CHECK(input.hasSeq() || input.hasSubseq())
<< "Input should be a sequence or subsequence for layer " << getName();
newBatchSize_ = type_ ? input.getNumSubSequences() : input.getNumSequences();
size_t dim = getSize();
// check
......
...@@ -95,3 +95,22 @@ def process_unequalength_seq(settings, file_name):
words1 = reduce(lambda x, y: x + y, d[0])
words2 = reduce(lambda x, y: x + y, d[1])
yield words1, words2, d[2]
###########################################################
data3 = [
[[[1, 2], [4, 5, 2]], [1, 2], 0],
[[[0, 2], [2, 5], [0, 1, 2]], [2, 3, 0], 1],
]
# Used for sequence_nest_mixed_inputs.conf
@provider(
input_types=[
integer_value_sub_sequence(10), integer_value_sequence(10),
integer_value(2)
],
should_shuffle=False)
def process_mixed(settings, file_name):
for d in data3:
yield d
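As a side note, the nesting of each sample in data3 matches the declared input types: a list of lists for integer_value_sub_sequence, a flat list for integer_value_sequence, and a bare integer for integer_value. A hypothetical sanity check of that structure (not part of the commit):

# Hypothetical check of data3's structure against the declared input types.
for subseq, seq, label in data3:
    assert all(isinstance(s, list) for s in subseq)  # sub-sequence input
    assert all(isinstance(w, int) for w in seq)      # flat sequence input
    assert isinstance(label, int)                    # non-sequence label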
...@@ -19,7 +19,7 @@ from paddle.trainer_config_helpers import *
define_py_data_sources2(train_list='gserver/tests/Sequence/dummy.list',
test_list=None,
module='rnn_data_provider',
obj='process_subseq2')
obj='process_subseq')
settings(batch_size=2, learning_rate=0.01)
...@@ -57,7 +57,7 @@ def outer_step(wid, x):
last = last_seq(input=inner_rnn_output, name="outer_rnn_state")
# "return last" should also work. But currently RecurrentGradientMachine
# does not handle it, and will report error: In hierachical RNN, all out
# links should be from sequences now.
return inner_rnn_output
......
# edit-mode: -*- python -*-
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer_config_helpers import *
######################## data source ################################
define_py_data_sources2(
train_list='gserver/tests/Sequence/dummy.list',
test_list=None,
module='rnn_data_provider',
obj='process_mixed')
settings(batch_size=2, learning_rate=0.01)
######################## network configure ################################
dict_dim = 10
word_dim = 2
hidden_dim = 2
label_dim = 2
data1 = data_layer(name="word1", size=dict_dim)
data2 = data_layer(name="word2", size=dict_dim)
label = data_layer(name="label", size=label_dim)
encoding = embedding_layer(input=data2, size=word_dim)
subseq = embedding_layer(input=data1, size=word_dim)
seq = embedding_layer(input=data2, size=word_dim)
nonseq = embedding_layer(input=label, size=word_dim)
# This hierarchical RNN is designed to be equivalent to the simple RNN in
# sequence_rnn_multi_unequalength_inputs.conf
def outer_step(subseq, seq, nonseq, encoding):
outer_mem = memory(name="outer_rnn_state", size=hidden_dim)
def inner_step(subseq, seq, nonseq):
inner_mem = memory(
name="inner_rnn_state", size=hidden_dim, boot_layer=outer_mem)
out = fc_layer(
input=[subseq, seq, nonseq, inner_mem],
size=hidden_dim,
act=TanhActivation(),
bias_attr=True,
name='inner_rnn_state')
return out
decoder = recurrent_group(
step=inner_step, name='inner', input=[subseq, seq, nonseq])
last = last_seq(name="outer_rnn_state", input=decoder)
context = simple_attention(
encoded_sequence=encoding, encoded_proj=encoding, decoder_state=last)
return context
out = recurrent_group(
name="outer",
step=outer_step,
input=[
subseq, expand_layer(
seq, expand_as=subseq,
expand_level=ExpandLevel.FROM_SEQUENCE), expand_layer(
nonseq,
expand_as=subseq,
expand_level=ExpandLevel.FROM_NO_SEQUENCE),
StaticInput(encoding)
])
rep = last_seq(input=out)
prob = fc_layer(
size=label_dim, input=rep, act=SoftmaxActivation(), bias_attr=True)
outputs(classification_cost(input=prob, label=label))
# edit-mode: -*- python -*-
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer_config_helpers import *
######################## data source ################################
define_py_data_sources2(
train_list='gserver/tests/Sequence/dummy.list',
test_list=None,
module='rnn_data_provider',
obj='process_mixed')
settings(batch_size=2, learning_rate=0.01)
######################## network configure ################################
dict_dim = 10
word_dim = 2
hidden_dim = 2
label_dim = 2
data1 = data_layer(name="word1", size=dict_dim)
data2 = data_layer(name="word2", size=dict_dim)
label = data_layer(name="label", size=label_dim)
encoding = embedding_layer(input=data2, size=word_dim)
# This hierarchical RNN is designed to be equivalent to the simple RNN in
# sequence_rnn_multi_unequalength_inputs.conf
def outer_step(subseq, seq, nonseq, encoding):
outer_mem = memory(name="outer_rnn_state", size=hidden_dim)
def inner_step(data1, data2, label):
inner_mem = memory(
name="inner_rnn_state", size=hidden_dim, boot_layer=outer_mem)
subseq = embedding_layer(input=data1, size=word_dim)
seq = embedding_layer(input=data2, size=word_dim)
nonseq = embedding_layer(input=label, size=word_dim)
print_layer(input=[data1, seq, label, inner_mem])
out = fc_layer(
input=[subseq, seq, nonseq, inner_mem],
size=hidden_dim,
act=TanhActivation(),
bias_attr=True,
name='inner_rnn_state')
return out
decoder = recurrent_group(
step=inner_step, name='inner', input=[subseq, seq, nonseq])
last = last_seq(name="outer_rnn_state", input=decoder)
context = simple_attention(
encoded_sequence=encoding, encoded_proj=encoding, decoder_state=last)
return context
out = recurrent_group(
name="outer",
step=outer_step,
input=[data1, data2, label, StaticInput(encoding)])
rep = last_seq(input=out)
prob = fc_layer(
size=label_dim, input=rep, act=SoftmaxActivation(), bias_attr=True)
outputs(classification_cost(input=prob, label=label))
...@@ -19,7 +19,7 @@ from paddle.trainer_config_helpers import *
define_py_data_sources2(train_list='gserver/tests/Sequence/dummy.list',
test_list=None,
module='rnn_data_provider',
obj='process_seq2')
obj='process_seq')
settings(batch_size=2, learning_rate=0.01)
......
...@@ -155,6 +155,15 @@ TEST(RecurrentGradientMachine, rnn_multi_unequalength_input) {
}
}
TEST(RecurrentGradientMachine, rnn_mixed_input) {
for (bool useGpu : {false, true}) {
test("gserver/tests/sequence_rnn_mixed_inputs.py",
"gserver/tests/sequence_rnn_matched_inputs.py",
1e-6,
useGpu);
}
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
......
...@@ -908,12 +908,13 @@ const T* CpuGpuVectorT<T>::getData(bool useGpu) const {
// Operation will change data and need to reset sync_ & syncFlag_.
#define MUTABLE_VECTOR_OP(OP, useGpu, args...) \
do { \
setSync(useGpu); \
if (useGpu) { \
copyToGpu(); \
setSync(useGpu); \
return gpuVectorT_->OP(args); \
} else { \
copyToCpu(); \
setSync(useGpu); \
return cpuVectorT_->OP(args); \
} \
} while (0)
...@@ -1030,7 +1031,7 @@ void CpuGpuVectorT<T>::copyToCpu() {
case DATA_AT_GPU:
CHECK(gpuVectorT_);
this->resizeOrCreate(gpuVectorT_->getSize(), false);
cpuVectorT_->copyFrom(*gpuVectorT_, HPPL_STREAM_DEFAULT);
cpuVectorT_->copyFrom(*gpuVectorT_);
setSync(SYNCED);
break;
case DATA_AT_CPU:
...@@ -1049,7 +1050,7 @@ void CpuGpuVectorT<T>::copyToGpu() {
case DATA_AT_CPU:
CHECK(cpuVectorT_);
this->resizeOrCreate(cpuVectorT_->getSize(), true);
gpuVectorT_->copyFrom(*cpuVectorT_, HPPL_STREAM_DEFAULT);
gpuVectorT_->copyFrom(*cpuVectorT_);
setSync(SYNCED);
break;
case DATA_AT_GPU:
......
...@@ -149,6 +149,7 @@ struct Argument {
: getBatchSize();
}
bool hasSeq() const { return sequenceStartPositions != nullptr; }
bool hasSubseq() const { return subSequenceStartPositions != nullptr; }
const int* getCpuStartPositions() const {
......
...@@ -124,6 +124,8 @@ TEST(RecurrentGradientMachine, test_generation) {
bool beam_search) {
FLAGS_config_args = beam_search ? "beam_search=1" : "beam_search=0";
for (auto useGpu : useGpuConfs) {
LOG(INFO) << configFile << " useGpu=" << useGpu
<< " beam_search=" << beam_search;
testGeneration(configFile, useGpu, hasSubseq, expRetFile);
}
};
......
...@@ -333,48 +333,32 @@ def RecurrentLayerGroupWithoutOutLinksBegin(name,
for linkid, link in enumerate(in_links):
if isinstance(link, basestring):
name = link
has_subseq = False
else:
name = link.link_name
has_subseq = link.has_subseq
# assign target_inlinkid according to target_inlinkname
if target_inlinkname == name:
g_current_submodel.target_inlinkid = linkid
if in_links_count == 0:
in_links_has_subseq = has_subseq
else:
config_assert(
in_links_has_subseq == has_subseq,
"The sequence type of in_links should be the same in RecurrentLayerGroup"
)
in_links_count += 1
layer_name = MakeLayerNameInParentSubmodel(name)
layer = g_layer_map[layer_name]
if has_subseq:
SequenceScatterAgentLayer(name=name, size=layer.size)
else:
ScatterAgentLayer(name=name, size=layer.size)
ScatterAgentLayer(name=name, size=layer.size)
pair = g_current_submodel.in_links.add()
pair.layer_name = layer_name
pair.link_name = MakeLayerNameInSubmodel(name)
pair.has_subseq = has_subseq
@config_func
def RecurrentLayerGroupSetOutLink(link):
if isinstance(link, basestring):
name = link
has_subseq = False
else:
name = link.link_name
has_subseq = link.has_subseq
layer_name = MakeLayerNameInParentSubmodel(name)
pair = g_current_submodel.out_links.add()
pair.layer_name = MakeLayerNameInSubmodel(name)
pair.link_name = layer_name
pair.has_subseq = has_subseq
def RecurrentLayerGroupSetGenerator(generator=None):
...@@ -425,8 +409,6 @@ def RecurrentLayerGroupEnd(name):
agent_name = GetLayerBaseName(pair.link_name)
if prev_submodel.HasField("generator"):
DataLayer(name=agent_name, size=layer.size)
elif pair.has_subseq:
SequenceGatherAgentLayer(name=agent_name, size=layer.size)
else:
GatherAgentLayer(name=agent_name, size=layer.size)
...@@ -2253,13 +2235,6 @@ class AgentLayer(LayerBase):
name, 'agent', size, inputs=[], device=device)
@config_layer('sequence_agent')
class SequenceAgentLayer(LayerBase):
def __init__(self, name, size, device=None):
super(SequenceAgentLayer, self).__init__(
name, 'sequence_agent', size, inputs=[], device=device)
@config_layer('gather_agent')
class GatherAgentLayer(LayerBase):
def __init__(self, name, size, device=None):
...@@ -2274,20 +2249,6 @@ class ScatterAgentLayer(LayerBase):
name, 'scatter_agent', size, inputs=[], device=device)
@config_layer('sequence_gather_agent')
class SequenceGatherAgentLayer(LayerBase):
def __init__(self, name, size, device=None):
super(SequenceGatherAgentLayer, self).__init__(
name, 'sequence_gather_agent', size, inputs=[], device=device)
@config_layer('sequence_scatter_agent')
class SequenceScatterAgentLayer(LayerBase):
def __init__(self, name, size, device=None):
super(SequenceScatterAgentLayer, self).__init__(
name, 'sequence_scatter_agent', size, inputs=[], device=device)
@config_layer('multiplex')
class MultiplexLayer(LayerBase):
def __init__(self, name, inputs, size, device=None):
...@@ -2303,12 +2264,12 @@ class MultiplexLayer(LayerBase):
@config_func
def Link(
name,
has_subseq=False, ):
def Link(name, has_subseq=False):
"""
Still keeping has_subseq for backward compatibility
"""
link_config = LinkConfig()
link_config.link_name = name
link_config.has_subseq = has_subseq
return link_config
...@@ -2341,13 +2302,7 @@ def Memory(name,
config_assert(name is not None, "name needs cannot be None")
memory_name = name + "+delay1"
agent_name = memory_name
if is_sequence:
config_assert(
boot_layer is not None,
"there must be boot_layer in network when is_sequence = True")
agent_layer = SequenceAgentLayer(agent_name, size)
else:
agent_layer = AgentLayer(agent_name, size)
agent_layer = AgentLayer(agent_name, size)
config_assert(g_current_submodel.is_recurrent_layer_group,
'Memory should be used in recurrent layer group only')
memory = g_current_submodel.memories.add()
......
...@@ -3329,8 +3329,9 @@ class StaticInput(object):
input.size = size
class SubsequenceInput(object):
def SubsequenceInput(input):
"""
DEPRECATED.
Input sequence has sub-sequence, used in recurrent_group.
The example usage is:
...@@ -3339,11 +3340,7 @@ class SubsequenceInput(object):
input = SubsequenceInput(layer)
"""
return input
def __init__(self, input):
assert isinstance(input, LayerOutput)
assert input.size is not None
self.input = input
@wrap_name_default("recurrent_group")
...@@ -3407,7 +3404,8 @@ def recurrent_group(step,
input sequence in a reverse order.
:type reverse: bool
:param targetInlink: the input layer which share info with layer group's output
:param targetInlink: DEPRECATED.
The input layer which share info with layer group's output
Param input specifies multiple input layers. For
SubsequenceInput inputs, config should assign one input
...@@ -3429,46 +3427,21 @@ def recurrent_group(step,
model_type('recurrent_nn')
def is_single_input(x):
return isinstance(x, LayerOutput) or isinstance(x, StaticInput) \
or isinstance(x, SubsequenceInput)
return isinstance(x, LayerOutput) or isinstance(x, StaticInput)
if is_single_input(input):
input = [input]
assert isinstance(input, collections.Sequence)
def is_in_links(x):
return isinstance(x, LayerOutput) or isinstance(x, SubsequenceInput)
return isinstance(x, LayerOutput)
in_links = filter(is_in_links, input)
def targetInlink_in_inlinks():
for inlink in in_links:
if isinstance(inlink, SubsequenceInput):
if targetInlink == inlink.input:
return True
elif targetInlink == inlink:
return True
return False
assert (targetInlink == None or targetInlink_in_inlinks())
targetInlinkName = None if targetInlink == None \
else targetInlink.name if isinstance(targetInlink, LayerOutput) \
else targetInlink.input.name
contains_sub_seq = [False]
def map_in_links(x):
if isinstance(x, SubsequenceInput):
contains_sub_seq[0] = True
return Link(name=x.input.name, has_subseq=True)
else:
return x.name
RecurrentLayerGroupWithoutOutLinksBegin(
name=name,
in_links=map(map_in_links, in_links),
seq_reversed=reverse,
target_inlinkname=targetInlinkName)
in_links=map(lambda x: x.name, in_links),
seq_reversed=reverse)
in_args = []
has_LayerOutput = False
for each_input in input:
...@@ -3476,10 +3449,7 @@ def recurrent_group(step,
if isinstance(each_input, LayerOutput):
in_args.append(each_input)
has_LayerOutput = True
elif isinstance(each_input, SubsequenceInput):
in_args.append(each_input.input)
has_LayerOutput = True
else:
else: # StaticInput
mem_name = "__%s_memory__" % each_input.input.name
mem = memory(
name=mem_name,
...@@ -3503,10 +3473,7 @@ def recurrent_group(step,
for ot in layer_outs:
assert isinstance(ot, LayerOutput)
ot.reverse = reverse
if contains_sub_seq[0]:
RecurrentLayerGroupSetOutLink(Link(ot.name, has_subseq=True))
else:
RecurrentLayerGroupSetOutLink(ot.name)
RecurrentLayerGroupSetOutLink(ot.name)
RecurrentLayerGroupEnd(name=name)
...@@ -5608,13 +5575,13 @@ def row_conv_layer(input,
to deploy in an online and low-latency setting. The lookahead convolution
incorporates information from future subsequences in a computationally
efficient manner to improve unidirectional recurrent neural networks.
The connection of row convolution is different form the 1D sequence
convolution. Assumed that, the future context-length is k, that is to say,
it can get the output at timestep t by using the the input feature from t-th
timestep to (t+k+1)-th timestep. Assumed that the hidden dim of input
activations are d, the activations r_t for the new layer at time-step t are:
.. math::
r_{t,r} = \sum_{j=1}^{k + 1} {w_{i,j}h_{t+j-1, i}}
......
...@@ -261,12 +261,10 @@ sub_models {
in_links {
layer_name: "__simple_gru_0___transform"
link_name: "__simple_gru_0___transform@__simple_gru_0___recurrent_group"
has_subseq: false
}
out_links {
layer_name: "__simple_gru_0__@__simple_gru_0___recurrent_group"
link_name: "__simple_gru_0__"
has_subseq: false
}
target_inlinkid: -1
}
...@@ -285,12 +283,10 @@ sub_models {
in_links {
layer_name: "__simple_gru_1___transform"
link_name: "__simple_gru_1___transform@__simple_gru_1___recurrent_group"
has_subseq: false
}
out_links {
layer_name: "__simple_gru_1__@__simple_gru_1___recurrent_group"
link_name: "__simple_gru_1__"
has_subseq: false
}
target_inlinkid: -1
}
......
...@@ -351,12 +351,10 @@ sub_models {
in_links {
layer_name: "__mixed_0__"
link_name: "__mixed_0__@__lstm_group_0___recurrent_group"
has_subseq: false
}
out_links {
layer_name: "__lstm_group_0__@__lstm_group_0___recurrent_group"
link_name: "__lstm_group_0__"
has_subseq: false
}
target_inlinkid: -1
}
...@@ -383,12 +381,10 @@ sub_models {
in_links {
layer_name: "__mixed_1__"
link_name: "__mixed_1__@__lstm_group_1___recurrent_group"
has_subseq: false
}
out_links {
layer_name: "__lstm_group_1__@__lstm_group_1___recurrent_group"
link_name: "__lstm_group_1__"
has_subseq: false
}
target_inlinkid: -1
}
......
...@@ -155,7 +155,7 @@ layers {
}
layers {
name: "sub_seq_input@__recurrent_group_2__"
type: "sequence_scatter_agent"
type: "scatter_agent"
size: 100
active_type: ""
}
...@@ -182,7 +182,7 @@ layers {
}
layers {
name: "rnn_subseq_forward"
type: "sequence_gather_agent"
type: "gather_agent"
size: 200
active_type: ""
}
...@@ -623,12 +623,10 @@ sub_models {
in_links {
layer_name: "seq_input"
link_name: "seq_input@__recurrent_group_0__"
has_subseq: false
}
out_links {
layer_name: "rnn_forward@__recurrent_group_0__"
link_name: "rnn_forward"
has_subseq: false
}
target_inlinkid: -1
}
...@@ -647,12 +645,10 @@ sub_models {
in_links {
layer_name: "seq_input"
link_name: "seq_input@__recurrent_group_1__"
has_subseq: false
}
out_links {
layer_name: "rnn_back@__recurrent_group_1__"
link_name: "rnn_back"
has_subseq: false
}
target_inlinkid: -1
}
...@@ -671,12 +667,10 @@ sub_models {
in_links {
layer_name: "sub_seq_input"
link_name: "sub_seq_input@__recurrent_group_2__"
has_subseq: true
}
out_links {
layer_name: "rnn_subseq_forward@__recurrent_group_2__"
link_name: "rnn_subseq_forward"
has_subseq: true
}
target_inlinkid: -1
}
...@@ -703,12 +697,10 @@ sub_models {
in_links {
layer_name: "__mixed_0__"
link_name: "__mixed_0__@__lstm_group_0___recurrent_group"
has_subseq: false
}
out_links {
layer_name: "__lstm_group_0__@__lstm_group_0___recurrent_group"
link_name: "__lstm_group_0__"
has_subseq: false
}
target_inlinkid: -1
}
...@@ -727,12 +719,10 @@ sub_models {
in_links {
layer_name: "__mixed_1__"
link_name: "__mixed_1__@__gru_group_0___recurrent_group"
has_subseq: false
}
out_links {
layer_name: "__gru_group_0__@__gru_group_0___recurrent_group"
link_name: "__gru_group_0__"
has_subseq: false
}
target_inlinkid: -1
}
...@@ -751,12 +741,10 @@ sub_models {
in_links {
layer_name: "seq_input"
link_name: "seq_input@__recurrent_group_3__"
has_subseq: false
}
out_links {
layer_name: "__fc_layer_0__@__recurrent_group_3__"
link_name: "__fc_layer_0__"
has_subseq: false
}
target_inlinkid: -1
}
......