提交 cdbc5e73 编写于 作者: B bingyanghuang

Add some comments

上级 53185fde
...@@ -109,24 +109,32 @@ class LastFirstSeqPoolFunctor { ...@@ -109,24 +109,32 @@ class LastFirstSeqPoolFunctor {
void operator()(const platform::CPUDeviceContext& context, void operator()(const platform::CPUDeviceContext& context,
const framework::LoDTensor& input, framework::Tensor* output, const framework::LoDTensor& input, framework::Tensor* output,
const std::string pooltype) { const std::string pooltype) {
//Create pointers to input and output data
auto* in_data = input.data<T>(); auto* in_data = input.data<T>();
auto* out_data = output->data<T>(); auto* out_data = output->data<T>();
//Calculate length of each word
int64_t word_len = input.numel() / input.dims()[0]; int64_t word_len = input.numel() / input.dims()[0];
auto lod = input.lod()[0]; auto lod = input.lod()[0];
auto dims = input.dims();
if (pooltype == "LAST"){ if (pooltype == "LAST"){
for (int i=0; i < static_cast<int>(lod.size()) - 1; ++i ){ for (int i=0; i < static_cast<int>(lod.size()) - 1; ++i ){
//Calculate the length of each sequence
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]); int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
//Point to the begin of next sequence
in_data += seq_len* word_len; in_data += seq_len* word_len;
std::memcpy(out_data,(in_data-word_len),word_len*sizeof(int)); //Copy the last words to output
std::memcpy(out_data,(in_data-word_len),word_len*sizeof(T));
out_data += word_len; out_data += word_len;
} }
} }
else if(pooltype == "FIRST"){ else if(pooltype == "FIRST"){
for (int i=0; i < static_cast<int>(lod.size()) - 1; ++i ){ for (int i=0; i < static_cast<int>(lod.size()) - 1; ++i ){
//Calculate the length of each sequence
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]); int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
std::memcpy(out_data,in_data,word_len*sizeof(int)); //Copy the first words of sequence to output
std::memcpy(out_data,in_data,word_len*sizeof(T));
//Point to the next sequence
in_data += seq_len * word_len; in_data += seq_len * word_len;
out_data += word_len; out_data += word_len;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册