提交 1454cd54 编写于 作者: B bingyanghuang

pre-commit check

上级 7429067a
......@@ -109,40 +109,38 @@ class LastFirstSeqPoolFunctor {
void operator()(const platform::CPUDeviceContext& context,
const framework::LoDTensor& input, framework::Tensor* output,
const std::string pooltype) {
//Create pointers to input and output data
auto* in_data = input.data<T>();
auto* out_data = output->data<T>();
//Calculate the size of each item in sequence
int64_t item_size = input.numel() / input.dims()[0];
auto lod = input.lod()[0];
int seq_num = static_cast<int>(lod.size()) - 1;
if (pooltype == "LAST"){
for (int i=0; i < seq_num; ++i ){
//Calculate the length of each sequence
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
//Point to the begin of next sequence
in_data += seq_len* item_size;
//Copy the last item to output
std::memcpy(out_data,(in_data-item_size),item_size*sizeof(T));
out_data += item_size;
}
}
else if(pooltype == "FIRST"){
for (int i=0; i < seq_num; ++i ){
//Calculate the length of each sequence
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
//Copy the first item of sequence to output
std::memcpy(out_data,in_data,item_size*sizeof(T));
//Point to the next sequence
in_data += seq_len * item_size;
out_data += item_size;
}
}
else {
PADDLE_THROW("it's not LAST or FIRST pool type");
// Create pointers to input and output data
auto* in_data = input.data<T>();
auto* out_data = output->data<T>();
// Calculate the size of each item in sequence
int64_t item_size = input.numel() / input.dims()[0];
auto lod = input.lod()[0];
int seq_num = static_cast<int>(lod.size()) - 1;
if (pooltype == "LAST") {
for (int i = 0; i < seq_num; ++i) {
// Calculate the length of each sequence
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
// Point to the begin of next sequence
in_data += seq_len * item_size;
// Copy the last item of sequence to output
std::memcpy(out_data, (in_data - item_size), item_size * sizeof(T));
out_data += item_size;
}
}
} else if (pooltype == "FIRST") {
for (int i = 0; i < seq_num; ++i) {
// Calculate the length of each sequence
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
// Copy the first item of sequence to output
std::memcpy(out_data, in_data, item_size * sizeof(T));
// Point to the next sequence
in_data += seq_len * item_size;
out_data += item_size;
}
} else {
PADDLE_THROW("it's not LAST or FIRST pool type");
}
}
};
template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册