Commit 382de964 authored by Cao Ying, committed by GitHub

Merge pull request #3249 from lcy-seso/kmax_score_layer

Add a Kmax sequence score layer.
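For context, the layer's contract can be sketched in a few lines of NumPy (a minimal illustration for this review, not part of the diff; the helper name `kmax_seq_score` is made up): for each (sub)sequence, return the within-sequence indices of the `beam_size` highest scores, padding rows with -1 when a sequence is shorter than `beam_size`.

```python
import numpy as np

def kmax_seq_score(scores, seq_starts, beam_size):
    """Per-sequence indices of the beam_size highest scores, -1 padded."""
    num_seqs = len(seq_starts) - 1
    out = -np.ones((num_seqs, beam_size))  # -1 marks padding
    for i in range(num_seqs):
        seq = scores[seq_starts[i]:seq_starts[i + 1]]
        k = min(beam_size, len(seq))
        out[i, :k] = np.argsort(-seq)[:k]  # indices of the k largest scores
    return out

# e.g. scores [0.1, 0.9, 0.3] split as [0, 2, 3] -> rows [[1, 0], [0, -1]]
print(kmax_seq_score(np.array([0.1, 0.9, 0.3]), [0, 2, 3], 2))
```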
......@@ -257,6 +257,11 @@ seq_concat
.. autoclass:: paddle.v2.layer.seq_concat
    :noindex:

kmax_sequence_score
-------------------
.. autoclass:: paddle.v2.layer.kmax_sequence_score
    :noindex:

sub_nested_seq
--------------
.. autoclass:: paddle.v2.layer.sub_nested_seq
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Layer.h"
namespace paddle {
class KmaxSeqScoreLayer : public Layer {
private:
  MatrixPtr scores_;
  size_t beamSize_;
  void kmaxScorePerSeq(const real* scores,
                       real* sortedIds,
                       const ICpuGpuVectorPtr seqStartPos);

public:
  explicit KmaxSeqScoreLayer(const LayerConfig& config) : Layer(config) {}

  bool init(const LayerMap& layerMap,
            const ParameterMap& parameterMap) override;

  void forward(PassType passType) override;
  void backward(const UpdateCallback& callback = nullptr) override;
};

REGISTER_LAYER(kmax_seq_score, KmaxSeqScoreLayer);
bool KmaxSeqScoreLayer::init(const LayerMap& layerMap,
                             const ParameterMap& parameterMap) {
  bool ret = Layer::init(layerMap, parameterMap);
  CHECK_EQ(1U, inputLayers_.size());

  beamSize_ = config_.beam_size();
  CHECK_GE(beamSize_, 1U);

  // The output of this layer is a set of selected indices, not a
  // differentiable value, so no sequence info or gradient is needed.
  setNeedSequenceInfo(false);
  setNeedGradient(false);
  return ret;
}
void KmaxSeqScoreLayer::kmaxScorePerSeq(const real* scores,
                                        real* sortedIds,
                                        const ICpuGpuVectorPtr seqStartPos) {
  int* starts = seqStartPos->getMutableData(false);
  std::vector<real> indices;
  for (size_t i = 0; i < seqStartPos->getSize() - 1; ++i) {
    int seqLen = starts[i + 1] - starts[i];
    // Select at most beamSize_ indices from each sequence.
    int k = std::min(static_cast<int>(beamSize_), seqLen);

    indices.resize(seqLen, 0);
    std::iota(begin(indices), end(indices), 0.);
    std::vector<real> tmpScore(scores + starts[i], scores + starts[i + 1]);
    // Partially sort the index vector so that its first k entries are the
    // positions of the k highest scores within this sequence.
    std::partial_sort(
        begin(indices),
        begin(indices) + k,
        end(indices),
        [&](size_t a, size_t b) { return tmpScore[a] > tmpScore[b]; });
    // Write the k indices into row i of the output; the rest of the row
    // keeps the -1 padding set in forward().
    memcpy(sortedIds + (i * beamSize_), indices.data(), k * sizeof(real));
  }
}
void KmaxSeqScoreLayer::forward(PassType passType) {
  Layer::forward(passType);

  const Argument& input = getInput(0);
  const MatrixPtr inputScore = getInputValue(0);

  CHECK(input.hasSeq() || input.hasSubseq())
      << "input of " << getName()
      << " must be a sequence or a nested sequence.";
  CHECK_EQ(input.value->getWidth(), 1UL)
      << "input of " << getName()
      << " is a score over a sequence or a nested sequence, so its width"
      << " must be 1.";

  if (useGpu_) {
    // This layer runs only on the CPU. If the model is running on the GPU,
    // copy this layer's input from GPU memory to CPU memory first.
    Matrix::resizeOrCreate(scores_,
                           inputScore->getHeight(),
                           1,
                           false /* trans */,
                           false /* useGpu */);
    scores_->copyFrom(*inputScore);
  } else {
    scores_ = inputScore;
  }

  Matrix::resizeOrCreate(
      output_.value,
      input.hasSubseq() ? input.getNumSubSequences() : input.getNumSequences(),
      beamSize_,
      false,
      false);
  // Initialize the whole output to -1; rows of sequences shorter than
  // beamSize_ keep -1 as padding after the top-k indices are filled in.
  output_.value->one();
  output_.value->mulScalar(-1.);

  kmaxScorePerSeq(scores_->getData(),
                  output_.value->getData(),
                  input.hasSubseq() ? input.subSequenceStartPositions
                                    : input.sequenceStartPositions);
}
// No-op: this layer produces indices and propagates no gradient
// (see setNeedGradient(false) in init()).
void KmaxSeqScoreLayer::backward(const UpdateCallback& callback) {}
} // namespace paddle
......@@ -66,6 +66,16 @@ add_unittest_without_exec(test_BatchNorm
add_test(NAME test_BatchNorm
         COMMAND test_BatchNorm)

################# test_KmaxSeqScore #######################
add_unittest_without_exec(test_KmaxSeqScore
    test_KmaxSeqScore.cpp
    LayerGradUtil.cpp)

add_test(NAME test_KmaxSeqScore
         COMMAND test_KmaxSeqScore)

################## test_Evaluator #######################
add_unittest(test_Evaluator
    test_Evaluator.cpp)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <algorithm>
#include <string>
#include <vector>
#include "ModelConfig.pb.h"
#include "paddle/gserver/layers/DataLayer.h"
#include "paddle/trainer/Trainer.h"
#include "paddle/utils/GlobalConstants.h"
#include "LayerGradUtil.h"
#include "paddle/testing/TestUtil.h"
using namespace paddle; // NOLINT
using namespace std; // NOLINT
DECLARE_bool(use_gpu);
DECLARE_int32(gpu_id);
DECLARE_bool(thread_local_rand_use_global_seed);
// Randomly sample n distinct integers from [0, range).
vector<int> randSampling(int range, int n) {
  CHECK_GE(range, n);
  vector<int> num(range);
  iota(begin(num), end(num), 0);
  if (range == n) return num;

  random_shuffle(begin(num), end(num));
  num.resize(n);
  return num;
}
void genRandomSeqInfo(vector<int>& seqStartPosition,
                      vector<int>& subSeqStartPosition) {
  const int maxSeqNum = 100;
  // Generate random start-position information. Sequence i is made of
  // subSeqLen subsequences, each of length subSeqLen, so the same data can
  // be read either as a plain sequence or as a nested sequence.
  int seqNum = 1 + (rand() % maxSeqNum);
  seqStartPosition.resize(seqNum + 1, 0);
  subSeqStartPosition.resize(1, 0);
  for (int i = 0; i < seqNum; ++i) {
    int subSeqLen = 1 + (rand() % maxSeqNum);
    for (int j = 0; j < subSeqLen; ++j)
      subSeqStartPosition.push_back(subSeqStartPosition.back() + subSeqLen);
    seqStartPosition[i + 1] = subSeqStartPosition.back();
  }
}
// For each sequence, pick min(beamSize, seqLen) random positions and set
// their scores to 1., so they rank above the uniformly sampled scores and
// form the expected top-k result.
void genRandomGroundTruth(real* values,
                          vector<vector<int>>& groundTruth,
                          vector<int>& startPos,
                          size_t beamSize) {
  groundTruth.resize(startPos.size() - 1, vector<int>(beamSize, -1));
  for (size_t i = 0; i < startPos.size() - 1; ++i) {
    int seqLen = startPos[i + 1] - startPos[i];
    vector<int> pos =
        randSampling(seqLen, min(static_cast<int>(beamSize), seqLen));
    for (size_t j = 0; j < pos.size(); ++j) {
      groundTruth[i][j] = pos[j];
      values[startPos[i] + pos[j]] = 1.;
    }
  }
}
// Compare the layer output against the ground truth. Both sides are sorted
// first because the order of the k returned indices is not fixed.
void checkLayerOut(vector<vector<int>> groundTruth,
                   real* layerOut,
                   size_t beamSize) {
  for (size_t i = 0; i < groundTruth.size(); ++i) {
    int begPos = i * beamSize;
    vector<real> tmp(layerOut + begPos, layerOut + begPos + beamSize);
    sort(begin(tmp), end(tmp));
    sort(begin(groundTruth[i]), end(groundTruth[i]));
    for (size_t j = 0; j < beamSize; ++j) CHECK_EQ(tmp[j], groundTruth[i][j]);
  }
}
TEST(Layer, kmaxSeqScoreLayer) {
  const size_t maxBeamSize = 100;
  int beamSize = 1 + (rand() % maxBeamSize);

  vector<int> seqStartPosition;
  vector<int> subSeqStartPosition;
  genRandomSeqInfo(seqStartPosition, subSeqStartPosition);
  MatrixPtr inValue =
      Matrix::create(subSeqStartPosition.back(), 1, false, false);

  for (auto hasSubseq : {false, true}) {
    vector<vector<int>> groundTruth;
    inValue->randomizeUniform();
    genRandomGroundTruth(inValue->getData(),
                         groundTruth,
                         hasSubseq ? subSeqStartPosition : seqStartPosition,
                         beamSize);

    for (auto useGpu : {false, true}) {
      TestConfig config;
      config.layerConfig.set_type("kmax_seq_score");
      config.layerConfig.set_beam_size(beamSize);

      if (hasSubseq) {
        config.inputDefs.push_back({INPUT_SELF_DEFINE_DATA,
                                    "scores",
                                    inValue,
                                    seqStartPosition,
                                    subSeqStartPosition});
      } else {
        config.inputDefs.push_back(
            {INPUT_SELF_DEFINE_DATA, "scores", inValue, seqStartPosition});
      }
      config.layerConfig.add_inputs();

      // data layer initialization
      std::vector<DataLayerPtr> dataLayers;
      LayerMap layerMap;
      vector<Argument> datas;
      initDataLayer(
          config,
          &dataLayers,
          &datas,
          &layerMap,
          "kmax_seq_score",
          100 /* actually this parameter is unused in self-defined input */,
          false,
          useGpu);
      // test layer initialization
      std::vector<ParameterPtr> parameters;
      LayerPtr kmaxSeqScoreLayer;
      FLAGS_use_gpu = useGpu;
      initTestLayer(config, &layerMap, &parameters, &kmaxSeqScoreLayer);

      kmaxSeqScoreLayer->forward(PASS_TRAIN);

      const MatrixPtr outValue = kmaxSeqScoreLayer->getOutputValue();
      CHECK_EQ(outValue->getHeight(),
               hasSubseq ? subSeqStartPosition.size() - 1
                         : seqStartPosition.size() - 1);
      CHECK_EQ(outValue->getWidth(), beamSize);
      checkLayerOut(groundTruth, outValue->getData(), beamSize);
    }
  }
}
int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  initMain(argc, argv);
  FLAGS_thread_local_rand_use_global_seed = true;
  srand((size_t)(time(NULL)));
  return RUN_ALL_TESTS();
}
......@@ -3248,6 +3248,16 @@ class CTCLayer(LayerBase):
        config_assert(len(self.inputs) == 2, 'CTCLayer must have 2 inputs')


@config_layer('kmax_seq_score')
class KmaxSeqScoreLayer(LayerBase):
    def __init__(self, name, inputs, beam_size, **xargs):
        super(KmaxSeqScoreLayer, self).__init__(
            name, 'kmax_seq_score', 0, inputs=inputs, **xargs)
        config_assert(
            len(self.inputs) == 1, 'KmaxSeqScoreLayer has only one input.')
        self.config.beam_size = beam_size


@config_layer('warp_ctc')
class WarpCTCLayer(LayerBase):
    def __init__(self,
......
......@@ -132,6 +132,7 @@ __all__ = [
    'sub_nested_seq_layer',
    'clip_layer',
    'slice_projection',
    'kmax_sequence_score_layer',
]
......@@ -228,6 +229,8 @@ class LayerType(object):
    SUB_NESTED_SEQ = 'sub_nested_seq'
    CLIP_LAYER = 'clip'
    KMAX_SEQ_SCORE = 'kmax_seq_score'

    @staticmethod
    def is_layer_type(type_name):
        """
......@@ -6158,7 +6161,8 @@ def clip_layer(input, min, max, name=None):
    :type min: double
    :param max: The upper threshold for clipping.
    :type max: double
    :return: LayerOutput object.
    :rtype: LayerOutput
    """
    Layer(
        name=name,
......@@ -6168,3 +6172,41 @@ def clip_layer(input, min, max, name=None):
        max=max)
    return LayerOutput(
        name, LayerType.CLIP_LAYER, parents=[input], size=input.size)
@wrap_name_default()
@layer_support()
def kmax_sequence_score_layer(input, name=None, beam_size=1):
    """
    This layer accepts one input whose values are scores over a sequence or
    a nested sequence, and returns the indices of the beam_size entries with
    the highest scores in each sequence.

    .. code-block:: python

        kmax_indices = kmax_sequence_score_layer(input=input_layer, beam_size=5)

    :param name: The layer name.
    :type name: basestring
    :param input: The input layer. It stores scores over a sequence or a
                  nested sequence and its size must be 1.
    :type input: LayerOutput
    :param beam_size: The indices of the beam_size entries with the highest
                      scores are returned.
    :type beam_size: int
    :return: LayerOutput object.
    :rtype: LayerOutput
    """
    assert isinstance(input, LayerOutput), (
        "kmax_sequence_score_layer accepts only a LayerOutput as its input.")
    assert input.size == 1, (
        "input of kmax_sequence_score_layer is a score "
        "over a sequence or a nested sequence, so its width must be 1.")

    Layer(
        name=name,
        type=LayerType.KMAX_SEQ_SCORE,
        inputs=[input.name],
        beam_size=beam_size)

    return LayerOutput(
        name, LayerType.KMAX_SEQ_SCORE, parents=[input], size=input.size)
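One note on consuming the output: rows for sequences shorter than beam_size keep the -1 padding written by the C++ forward pass, so downstream code should filter it out. A short hedged NumPy sketch, with `kmax_out` standing in for the layer's fetched output matrix:

```python
import numpy as np

# kmax_out: (num_sequences, beam_size) output of kmax_sequence_score_layer;
# entries are within-sequence indices stored as reals, -1 marks padding.
kmax_out = np.array([[2., 0., -1.],
                     [1., -1., -1.]])
valid_ids = [row[row >= 0].astype(int) for row in kmax_out]
print(valid_ids)  # [array([2, 0]), array([1])]
```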
......@@ -8,6 +8,6 @@ test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops
test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer
test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer
test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer
test_kmax_seq_socre_layer test_seq_select_layers)
export whole_configs=(test_split_datasource)
type: "nn"
layers {
  name: "input"
  type: "data"
  size: 300
  active_type: ""
}
layers {
  name: "data"
  type: "data"
  size: 128
  active_type: ""
}
layers {
  name: "__fc_layer_0__"
  type: "fc"
  size: 1
  active_type: "exponential"
  inputs {
    input_layer_name: "data"
    input_parameter_name: "___fc_layer_0__.w0"
  }
  bias_parameter_name: "___fc_layer_0__.wbias"
}
layers {
  name: "__kmax_sequence_score_layer_0__"
  type: "kmax_seq_score"
  active_type: ""
  inputs {
    input_layer_name: "__fc_layer_0__"
  }
  beam_size: 5
}
parameters {
  name: "___fc_layer_0__.w0"
  size: 128
  initial_mean: 0.0
  initial_std: 0.0883883476483
  dims: 128
  dims: 1
  initial_strategy: 0
  initial_smart: true
}
parameters {
  name: "___fc_layer_0__.wbias"
  size: 1
  initial_mean: 0.0
  initial_std: 0.0
  dims: 1
  dims: 1
  initial_strategy: 0
  initial_smart: false
}
input_layer_names: "data"
output_layer_names: "__kmax_sequence_score_layer_0__"
sub_models {
  name: "root"
  layer_names: "input"
  layer_names: "data"
  layer_names: "__fc_layer_0__"
  layer_names: "__kmax_sequence_score_layer_0__"
  input_layer_names: "data"
  output_layer_names: "__kmax_sequence_score_layer_0__"
  is_recurrent_layer_group: false
}
#!/usr/bin/env python
#coding=utf-8
from paddle.trainer_config_helpers import *
data = data_layer(name='input', size=300)
data = data_layer(name="data", size=128)
scores = fc_layer(input=data, size=1, act=ExpActivation())
kmax_seq_id = kmax_sequence_score_layer(input=scores, beam_size=5)
outputs(kmax_seq_id)