提交 e732bdd4 编写于 作者: Y Yi Wang

Merge branch 'develop' of https://github.com/paddlepaddle/paddle into memory_design

group: deprecated-2017Q2
language: cpp language: cpp
cache: cache:
directories: directories:
......
...@@ -241,11 +241,14 @@ void NeuralNetwork::forward(const std::vector<Argument>& inArgs, ...@@ -241,11 +241,14 @@ void NeuralNetwork::forward(const std::vector<Argument>& inArgs,
dataLayers_[i]->setData(inArgs[i]); dataLayers_[i]->setData(inArgs[i]);
} }
gLayerStackTrace.set_stage(true);
{ {
for (auto& layer : layers_) { for (auto& layer : layers_) {
REGISTER_TIMER_INFO("ForwardTimer", layer->getName().c_str()); REGISTER_TIMER_INFO("ForwardTimer", layer->getName().c_str());
gLayerStackTrace.push(layer->getName()); gLayerStackTrace.push(layer->getName());
layer->forward(passType); layer->forward(passType);
gLayerStackTrace.pop(layer->getName());
} }
} }
...@@ -254,9 +257,6 @@ void NeuralNetwork::forward(const std::vector<Argument>& inArgs, ...@@ -254,9 +257,6 @@ void NeuralNetwork::forward(const std::vector<Argument>& inArgs,
for (auto& layer : outputLayers_) { for (auto& layer : outputLayers_) {
outArgs->push_back(layer->getOutput()); outArgs->push_back(layer->getOutput());
} }
if (passType == PASS_TEST) {
gLayerStackTrace.clear();
}
} }
void NeuralNetwork::resetState() { void NeuralNetwork::resetState() {
...@@ -283,9 +283,10 @@ void NeuralNetwork::getState(MachineState& machineState) { ...@@ -283,9 +283,10 @@ void NeuralNetwork::getState(MachineState& machineState) {
} }
void NeuralNetwork::backward(const UpdateCallback& callback) { void NeuralNetwork::backward(const UpdateCallback& callback) {
gLayerStackTrace.pop(""); // tell layer trace is during backward. gLayerStackTrace.set_stage(false);
FOR_EACH_R(layer, layers_) { FOR_EACH_R(layer, layers_) {
REGISTER_TIMER_INFO("BackwardTimer", (*layer)->getName().c_str()); REGISTER_TIMER_INFO("BackwardTimer", (*layer)->getName().c_str());
gLayerStackTrace.push((*layer)->getName());
if ((*layer)->needGradient()) { if ((*layer)->needGradient()) {
(*layer)->backward(callback); (*layer)->backward(callback);
} }
...@@ -320,7 +321,7 @@ public: ...@@ -320,7 +321,7 @@ public:
} }
} }
virtual void eval(const NeuralNetwork& nn) { virtual void eval(const NeuralNetwork& nn) override {
for (auto& evaluator : evaluators_) { for (auto& evaluator : evaluators_) {
evaluator->eval(nn); evaluator->eval(nn);
} }
...@@ -395,6 +396,30 @@ private: ...@@ -395,6 +396,30 @@ private:
} }
}; };
class SubnetEvaluator : public CombinedEvaluator {
public:
SubnetEvaluator(const std::string& layerName,
std::unique_ptr<Evaluator>&& evaluator)
: layerName_(layerName) {
addEvaluator(std::move(evaluator));
}
virtual void eval(const NeuralNetwork& nn) override {
const LayerPtr& layer = nn.getLayer(layerName_);
CHECK(layer) << "Nonexisted layer: " << layerName_ << " in submodel "
<< nn.getName();
bool accessed = false;
layer->accessSubNetwork([this, &accessed](NeuralNetwork& subnet) {
subnet.eval(evaluators_[0].get());
accessed = true;
});
CHECK(accessed) << "There is no subnetwork for layer " << layerName_
<< " in submodel " << nn.getName();
}
protected:
std::string layerName_;
};
Evaluator* NeuralNetwork::makeEvaluator() const { Evaluator* NeuralNetwork::makeEvaluator() const {
CombinedEvaluator* combinedEvaluator = new CombinedEvaluator(); CombinedEvaluator* combinedEvaluator = new CombinedEvaluator();
auto subModelConfig = std::find_if(config_.sub_models().begin(), auto subModelConfig = std::find_if(config_.sub_models().begin(),
...@@ -421,6 +446,15 @@ Evaluator* NeuralNetwork::makeEvaluator() const { ...@@ -421,6 +446,15 @@ Evaluator* NeuralNetwork::makeEvaluator() const {
combinedEvaluator->addEvaluator(std::move(evaluator)); combinedEvaluator->addEvaluator(std::move(evaluator));
} }
} }
for (auto& layer : layers_) {
layer->accessSubNetwork(
[layer, combinedEvaluator](NeuralNetwork& subnet) {
std::unique_ptr<Evaluator> subEvaluator(new SubnetEvaluator(
layer->getName(),
std::unique_ptr<Evaluator>(subnet.makeEvaluator())));
combinedEvaluator->addEvaluator(std::move(subEvaluator));
});
}
} else { } else {
for (const EvaluatorConfig& evalConfig : config_.evaluators()) { for (const EvaluatorConfig& evalConfig : config_.evaluators()) {
std::unique_ptr<Evaluator> evaluator(Evaluator::create(evalConfig)); std::unique_ptr<Evaluator> evaluator(Evaluator::create(evalConfig));
......
...@@ -129,6 +129,8 @@ public: ...@@ -129,6 +129,8 @@ public:
static NeuralNetwork* newNeuralNetwork(const std::string& name = "", static NeuralNetwork* newNeuralNetwork(const std::string& name = "",
NeuralNetwork* rootNetwork = nullptr); NeuralNetwork* rootNetwork = nullptr);
const std::string& getName() const { return subModelName_; }
protected: protected:
/** /**
* The constructor of NeuralNetwork. * The constructor of NeuralNetwork.
......
...@@ -208,6 +208,7 @@ void RecurrentGradientMachine::init( ...@@ -208,6 +208,7 @@ void RecurrentGradientMachine::init(
}); });
CHECK(subModelConfig != config.sub_models().end()); CHECK(subModelConfig != config.sub_models().end());
reversed_ = subModelConfig->reversed(); reversed_ = subModelConfig->reversed();
generating_ = subModelConfig->has_generator();
inFrameLines_.resize(subModelConfig->in_links_size()); inFrameLines_.resize(subModelConfig->in_links_size());
for (size_t i = 0; i < inFrameLines_.size(); ++i) { for (size_t i = 0; i < inFrameLines_.size(); ++i) {
...@@ -287,10 +288,6 @@ void RecurrentGradientMachine::init( ...@@ -287,10 +288,6 @@ void RecurrentGradientMachine::init(
parameterIds_.push_back(para->getID()); parameterIds_.push_back(para->getID());
} }
} }
if (subModelConfig->evaluator_names_size() > 0) {
evaluator_.reset(frames_[0]->makeEvaluator());
}
} }
void RecurrentGradientMachine::resizeOrCreateFrames(int numFrames) { void RecurrentGradientMachine::resizeOrCreateFrames(int numFrames) {
...@@ -538,7 +535,7 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs, ...@@ -538,7 +535,7 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs,
The outputs are outFramesLines_[i].agentLayer The outputs are outFramesLines_[i].agentLayer
*/ */
if (inFrameLines_.empty() && passType == PASS_TEST) { if (generating_) {
generateSequence(); generateSequence();
return; return;
} // else forward.. } // else forward..
...@@ -561,14 +558,14 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs, ...@@ -561,14 +558,14 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs,
std::vector<Argument> outArgs; std::vector<Argument> outArgs;
frames_[i]->forward(inArgs, &outArgs, passType); frames_[i]->forward(inArgs, &outArgs, passType);
} }
if (evaluator_ && passType == PASS_TEST) {
this->eval(evaluator_.get());
}
reorganizeOutput(passType); reorganizeOutput(passType);
} }
void RecurrentGradientMachine::backward(const UpdateCallback& callback) { void RecurrentGradientMachine::backward(const UpdateCallback& callback) {
if (generating_) {
return;
}
REGISTER_TIMER_INFO("RecurrentBwTime", "RecurrentBwTime"); REGISTER_TIMER_INFO("RecurrentBwTime", "RecurrentBwTime");
AsyncGpuBlock asyncGpuBlock; AsyncGpuBlock asyncGpuBlock;
for (int i = maxSequenceLength_ - 1; i >= 0; --i) { for (int i = maxSequenceLength_ - 1; i >= 0; --i) {
...@@ -577,11 +574,6 @@ void RecurrentGradientMachine::backward(const UpdateCallback& callback) { ...@@ -577,11 +574,6 @@ void RecurrentGradientMachine::backward(const UpdateCallback& callback) {
for (auto& memoryFrameLine : memoryFrameLines_) { for (auto& memoryFrameLine : memoryFrameLines_) {
memoryFrameLine.bootLayer->backward(nullptr); memoryFrameLine.bootLayer->backward(nullptr);
} }
// call printers here so the gradient can be printed
if (evaluator_) {
this->eval(evaluator_.get());
}
} }
void RecurrentGradientMachine::forwardBackward( void RecurrentGradientMachine::forwardBackward(
...@@ -595,9 +587,9 @@ void RecurrentGradientMachine::forwardBackward( ...@@ -595,9 +587,9 @@ void RecurrentGradientMachine::forwardBackward(
void RecurrentGradientMachine::eval(Evaluator* evaluator) const { void RecurrentGradientMachine::eval(Evaluator* evaluator) const {
// call printers frame by frame // call printers frame by frame
for (int i = 0; i < maxSequenceLength_; ++i) { for (int i = 0; i < maxSequenceLength_; ++i) {
LOG(INFO) << "Recurrent Layer Group eval frame " << i << " begin"; VLOG(2) << "Recurrent Layer Group eval frame " << i << " begin";
evaluator->eval(*(frames_[i].get())); evaluator->eval(*(frames_[i].get()));
LOG(INFO) << "Recurrent Layer Group eval frame " << i << " end"; VLOG(2) << "Recurrent Layer Group eval frame " << i << " end";
} }
} }
...@@ -1093,10 +1085,6 @@ void RecurrentGradientMachine::oneWaySearch(size_t batchSize) { ...@@ -1093,10 +1085,6 @@ void RecurrentGradientMachine::oneWaySearch(size_t batchSize) {
copyDataOutlinkFrame(machineCur); copyDataOutlinkFrame(machineCur);
// call value printer
if (evaluator_) {
evaluator_->eval(*(frames_[machineCur].get()));
}
// check eos // check eos
const IVectorPtr& eosVec = const IVectorPtr& eosVec =
eosFrameLine_->layers[machineCur]->getOutput().ids; eosFrameLine_->layers[machineCur]->getOutput().ids;
...@@ -1321,11 +1309,10 @@ void RecurrentGradientMachine::fillGenOutputs() { ...@@ -1321,11 +1309,10 @@ void RecurrentGradientMachine::fillGenOutputs() {
batchMachineIdVec_.clear(); batchMachineIdVec_.clear();
generator_.ids.clear(); generator_.ids.clear();
int* starts = generator_.outArg.sequenceStartPositions->getMutableData(false);
starts[0] = 0;
if (numResults > 1) { if (numResults > 1) {
real* probs = generator_.outArg.in->getData(); real* probs = generator_.outArg.in->getData();
int* starts =
generator_.outArg.sequenceStartPositions->getMutableData(false);
starts[0] = 0;
for (size_t i = 0; i < finalPaths_.size(); ++i) { for (size_t i = 0; i < finalPaths_.size(); ++i) {
for (size_t j = 0; j < finalPaths_[i].size(); ++j) { for (size_t j = 0; j < finalPaths_[i].size(); ++j) {
Path& path = finalPaths_[i][j]; Path& path = finalPaths_[i][j];
...@@ -1348,7 +1335,10 @@ void RecurrentGradientMachine::fillGenOutputs() { ...@@ -1348,7 +1335,10 @@ void RecurrentGradientMachine::fillGenOutputs() {
} else { } else {
for (size_t i = 0; i < finalPaths_.size(); ++i) { for (size_t i = 0; i < finalPaths_.size(); ++i) {
CHECK(!finalPaths_[i].empty()); CHECK(!finalPaths_[i].empty());
generator_.ids = finalPaths_[i][0].ids; generator_.ids.insert(generator_.ids.begin(),
finalPaths_[i][0].ids.begin(),
finalPaths_[i][0].ids.end());
starts[i + 1] = starts[i] + finalPaths_[i][0].ids.size();
} }
} }
} }
......
...@@ -414,6 +414,7 @@ protected: ...@@ -414,6 +414,7 @@ protected:
std::vector<int> ids; // store generated sequences std::vector<int> ids; // store generated sequences
Argument outArg; // final output argument Argument outArg; // final output argument
}; };
bool generating_;
Generator generator_; Generator generator_;
std::vector<std::unique_ptr<NeuralNetwork>> frames_; std::vector<std::unique_ptr<NeuralNetwork>> frames_;
...@@ -428,8 +429,6 @@ protected: ...@@ -428,8 +429,6 @@ protected:
std::vector<int> std::vector<int>
parameterIds_; // parameters actually used by this Layer Group parameterIds_; // parameters actually used by this Layer Group
std::unique_ptr<Evaluator> evaluator_; // frame printers in this layer group
// store final argument of outFrameLines_ // store final argument of outFrameLines_
std::vector<Argument> dataArgs_; std::vector<Argument> dataArgs_;
// store each frame's output argument of outFrameLines_ // store each frame's output argument of outFrameLines_
......
...@@ -109,6 +109,40 @@ void GatherAgentLayer::forwardValue(PassType passType) { ...@@ -109,6 +109,40 @@ void GatherAgentLayer::forwardValue(PassType passType) {
} }
} }
namespace {
// dest[index[i]] <- src[i] for each i
void copyElements(const IVector& srcVec,
const IVector& indexVec,
IVector& destVec) {
const int* src = srcVec.getData();
const int* index = indexVec.getData();
int* dest = destVec.getData();
int len = indexVec.getSize();
CHECK_EQ(srcVec.getSize(), indexVec.getSize());
for (int i = 0; i < len; ++i) {
dest[index[i]] = src[i];
}
}
}
void GatherAgentLayer::forwardIds(PassType passType) {
IVectorPtr realId = realLayers_[0]->getOutputLabel();
if (!realId) return;
IVector::resizeOrCreate(output_.ids, allIds_->getSize(), useGpu_);
IVectorPtr outId = output_.ids;
idsVec_.resize(idIndex_.size());
for (size_t i = 0; i < realLayers_.size(); ++i) {
const IVectorPtr& realId = realLayers_[i]->getOutputLabel();
idsVec_[i] = IVector::create(allIds_->getData() + idIndex_[i],
/* size */ realId->getSize(),
useGpu_);
execViaCpu(&copyElements, *realId, *idsVec_[i], *outId);
}
}
void GatherAgentLayer::backward(const UpdateCallback& callback) { void GatherAgentLayer::backward(const UpdateCallback& callback) {
(void)callback; (void)callback;
const MatrixPtr& outputGrad = getOutputGrad(); const MatrixPtr& outputGrad = getOutputGrad();
...@@ -136,23 +170,22 @@ void ScatterAgentLayer::forward(PassType passType) { ...@@ -136,23 +170,22 @@ void ScatterAgentLayer::forward(PassType passType) {
CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId()); CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId());
int width = this->getSize(); int width = this->getSize();
if (selectionMode_) {
forwardWithSelection(passType);
} else {
if (realOutArg_.hasSeq()) { if (realOutArg_.hasSeq()) {
forwardSequence(passType); output_.subArgFrom(realOutArg_,
} else if (realOutArg_.value || realOutArg_.ids) { /* offset */ idIndex_,
idSize_,
width,
useGpu_,
/* trans */ false,
/* seqFlag */ true,
/* seqStart */ seqStartPosIndex_,
/* seqSize */ numSequences_);
} else {
output_.subArgFrom( output_.subArgFrom(
realOutArg_, /* offset */ idIndex_, idSize_, width, useGpu_); realOutArg_, /* offset */ idIndex_, idSize_, width, useGpu_);
} else { // used in generation
if (realLayer_->getOutput().ids) {
IVector::resizeOrCreate(output_.ids, ids_->getSize(), useGpu_);
output_.ids->selectFrom(*realLayer_->getOutput().ids, *ids_);
}
if (realLayer_->getOutput().value) {
int height = ids_->getSize();
resetOutput(height, width);
const MatrixPtr& outV = getOutputValue();
const MatrixPtr& realV = realLayer_->getOutputValue();
outV->selectRows(*realV, *ids_);
} }
} }
} }
...@@ -160,6 +193,8 @@ void ScatterAgentLayer::forward(PassType passType) { ...@@ -160,6 +193,8 @@ void ScatterAgentLayer::forward(PassType passType) {
void ScatterAgentLayer::backward(const UpdateCallback& callback) { void ScatterAgentLayer::backward(const UpdateCallback& callback) {
(void)callback; (void)callback;
CHECK(!selectionMode_);
const MatrixPtr& outputGrad = realOutArg_.grad; const MatrixPtr& outputGrad = realOutArg_.grad;
const MatrixPtr& realGrad = realLayer_->getOutputGrad(); const MatrixPtr& realGrad = realLayer_->getOutputGrad();
if (realGrad) { if (realGrad) {
...@@ -174,42 +209,7 @@ void ScatterAgentLayer::backward(const UpdateCallback& callback) { ...@@ -174,42 +209,7 @@ void ScatterAgentLayer::backward(const UpdateCallback& callback) {
REGISTER_LAYER(gather_agent, GatherAgentLayer); REGISTER_LAYER(gather_agent, GatherAgentLayer);
REGISTER_LAYER(scatter_agent, ScatterAgentLayer); REGISTER_LAYER(scatter_agent, ScatterAgentLayer);
void GatherAgentLayer::forwardIds(PassType passType) { void ScatterAgentLayer::forwardWithSelection(PassType passType) {
int height = 0;
IVectorPtr idReal = realLayers_[0]->getOutputLabel();
if (!idReal) return;
if (output_.subSequenceStartPositions) {
int* starts = output_.subSequenceStartPositions->getMutableData(false);
// Gather generator.idsVec
// if is beam search generation result. Get first result.
if (idReal->getData()[idReal->getSize() - 1] == -1) {
for (size_t i = 0; i < realLayers_.size(); ++i) {
// The first element stores first result size
idReal = realLayers_[i]->getOutputLabel();
idReal->subVecFrom(*idReal, 1, idReal->getData()[0]);
}
}
for (size_t i = 0; i < realLayers_.size(); ++i) {
CHECK(realLayers_[i]->getOutputLabel());
starts[i] = height;
height += realLayers_[i]->getOutputLabel()->getSize();
}
starts[realLayers_.size()] = height;
output_.sequenceStartPositions->getMutableData(false)[1] = height;
IVector::resizeOrCreate(output_.ids, height, false);
for (size_t i = 0; i < realLayers_.size(); ++i) {
output_.ids->subVec(starts[i], starts[i + 1] - starts[i])
->copyFrom(*realLayers_[i]->getOutputLabel());
}
} else {
LOG(FATAL) << "Not implemented";
}
}
void ScatterAgentLayer::forwardSequence(PassType passType) {
Layer::forward(passType); Layer::forward(passType);
CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId()); CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId());
...@@ -220,17 +220,19 @@ void ScatterAgentLayer::forwardSequence(PassType passType) { ...@@ -220,17 +220,19 @@ void ScatterAgentLayer::forwardSequence(PassType passType) {
AsyncGpuBlock asyncGpuBlock; AsyncGpuBlock asyncGpuBlock;
REGISTER_TIMER_INFO("SequenceAgentLayerForward", getName().c_str()); REGISTER_TIMER_INFO("SequenceAgentLayerForward", getName().c_str());
if (realOutArg_.value || realOutArg_.ids) { if (!input.hasSeq()) {
CHECK(realOutArg_.sequenceStartPositions); if (realLayer_->getOutput().ids) {
output_.subArgFrom(realOutArg_, IVector::resizeOrCreate(output_.ids, ids_->getSize(), useGpu_);
/* offset */ idIndex_, output_.ids->selectFrom(*realLayer_->getOutput().ids, *ids_);
idSize_, }
width, if (realLayer_->getOutput().value) {
useGpu_, int height = ids_->getSize();
/* trans */ false, resetOutput(height, width);
/* seqFlag */ true,
/* seqStart */ seqStartPosIndex_, const MatrixPtr& outV = getOutputValue();
/* seqSize */ numSequences_); const MatrixPtr& realV = realLayer_->getOutputValue();
outV->selectRows(*realV, *ids_);
}
} else { } else {
// Putting the generation logic here is really an ugly hack! // Putting the generation logic here is really an ugly hack!
// used in generation // used in generation
......
...@@ -110,6 +110,9 @@ protected: ...@@ -110,6 +110,9 @@ protected:
// of real layer. // of real layer.
ICpuGpuVectorPtr inputStartPos_; ICpuGpuVectorPtr inputStartPos_;
// true for setRealLayer, false for setRealLayerAndOutput
bool selectionMode_;
public: public:
explicit ScatterAgentLayer(const LayerConfig& config) : Layer(config) {} explicit ScatterAgentLayer(const LayerConfig& config) : Layer(config) {}
...@@ -137,6 +140,7 @@ public: ...@@ -137,6 +140,7 @@ public:
} else { } else {
cpuIds_ = ids_; cpuIds_ = ids_;
} }
selectionMode_ = true;
} }
// set real layer and output, [idIndex, idIndex + idSize) of *ids* // set real layer and output, [idIndex, idIndex + idSize) of *ids*
...@@ -153,6 +157,7 @@ public: ...@@ -153,6 +157,7 @@ public:
idIndex_ = idIndex; idIndex_ = idIndex;
idSize_ = idSize; idSize_ = idSize;
handleBackward_ = handleBackward; handleBackward_ = handleBackward;
selectionMode_ = false;
} }
void setSequenceStartPositions(const ICpuGpuVectorPtr& sequenceStartPositions, void setSequenceStartPositions(const ICpuGpuVectorPtr& sequenceStartPositions,
...@@ -166,7 +171,7 @@ public: ...@@ -166,7 +171,7 @@ public:
void forward(PassType passType) override; void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override; void backward(const UpdateCallback& callback) override;
void forwardSequence(PassType passType); void forwardWithSelection(PassType passType);
}; };
} // namespace paddle } // namespace paddle
...@@ -53,7 +53,7 @@ def outer_step(dummy_data): ...@@ -53,7 +53,7 @@ def outer_step(dummy_data):
bos_id=0, bos_id=0,
eos_id=num_words-1, eos_id=num_words-1,
beam_size=2 if beam_flag else 1, beam_size=2 if beam_flag else 1,
num_results_per_sample=2 if beam_flag else 1, num_results_per_sample=1,
max_length=10) max_length=10)
return beam_gen return beam_gen
......
...@@ -55,13 +55,17 @@ public: ...@@ -55,13 +55,17 @@ public:
* Else, just set status to popping. * Else, just set status to popping.
*/ */
void pop(const T& item) { void pop(const T& item) {
pushing() = false;
auto& s = this->stack(); auto& s = this->stack();
if (item == s.top()) { if (item == s.top()) {
s.pop(); s.pop();
} }
} }
/**
* @brief Indicate whether we are at forward or backward stage of computation
*/
void set_stage(bool isForward) { pushing() = isForward; }
/** /**
* @brief clear current thread stack. * @brief clear current thread stack.
*/ */
......
...@@ -72,7 +72,6 @@ TEST(CustomStackTrace, normalTrain) { ...@@ -72,7 +72,6 @@ TEST(CustomStackTrace, normalTrain) {
for (size_t i = 0; i < layerSize; ++i) { for (size_t i = 0; i < layerSize; ++i) {
tracer.push("layer_" + paddle::str::to_string(i)); tracer.push("layer_" + paddle::str::to_string(i));
} }
tracer.pop("");
for (size_t i = 0; i < layerSize; ++i) { for (size_t i = 0; i < layerSize; ++i) {
tracer.pop("layer_" + paddle::str::to_string(layerSize - 1 - i)); tracer.pop("layer_" + paddle::str::to_string(layerSize - 1 - i));
} }
......
...@@ -45,12 +45,12 @@ __all__ = ['data', 'parse_network'] ...@@ -45,12 +45,12 @@ __all__ = ['data', 'parse_network']
def __need_to_keep__(name): def __need_to_keep__(name):
return name in [ return name in [
'StaticInput', 'SubsequenceInput', 'GeneratedInput', 'LayerType', 'StaticInput', 'SubsequenceInput', 'GeneratedInput', 'LayerType',
'layer_support' 'layer_support', 'BaseGeneratedInput'
] ]
def __need_to_wrap__(name): def __need_to_wrap__(name):
return name not in ['AggregateLevel', 'ExpandLevel'] return name not in ['AggregateLevel', 'ExpandLevel', 'BaseGeneratedInput']
def __convert_name__(inname): def __convert_name__(inname):
...@@ -199,6 +199,15 @@ def __get_used_submodels__(layer_names): ...@@ -199,6 +199,15 @@ def __get_used_submodels__(layer_names):
return submodel_names return submodel_names
def __get_submodel_data_out_links__():
data_links = set()
for submodel in cp.g_config.model_config.sub_models:
for link in submodel.out_links:
if cp.g_layer_map[link.link_name].type == 'data':
data_links.add(link.link_name)
return data_links
def __get_used_evaluators__(layer_names): def __get_used_evaluators__(layer_names):
evaluator_names = set() evaluator_names = set()
for e in cp.g_config.model_config.evaluators: for e in cp.g_config.model_config.evaluators:
...@@ -264,6 +273,7 @@ def parse_network(output_layers, extra_layers=None): ...@@ -264,6 +273,7 @@ def parse_network(output_layers, extra_layers=None):
submodel_names = __get_used_submodels__(layer_names) submodel_names = __get_used_submodels__(layer_names)
submodel_names.add('root') submodel_names.add('root')
evaluator_names = __get_used_evaluators__(layer_names) evaluator_names = __get_used_evaluators__(layer_names)
data_out_links = __get_submodel_data_out_links__()
input_layer_names = set() input_layer_names = set()
output_layer_names = set() output_layer_names = set()
...@@ -279,7 +289,7 @@ def parse_network(output_layers, extra_layers=None): ...@@ -279,7 +289,7 @@ def parse_network(output_layers, extra_layers=None):
continue continue
model_config.layers.extend([l]) model_config.layers.extend([l])
if l.type == 'data': if l.type == 'data':
if l.name in model_config.output_layer_names: if l.name in data_out_links:
""" """
In text generation, the outlink to save the generated word In text generation, the outlink to save the generated word
indices is a data_layer defined in recurrent_group. This indices is a data_layer defined in recurrent_group. This
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册