提交 83394bab 编写于 作者: B bingyanghuang

modified by luotao's suggestion

上级 1454cd54
...@@ -104,11 +104,10 @@ class MaxSeqPoolGradFunctor { ...@@ -104,11 +104,10 @@ class MaxSeqPoolGradFunctor {
}; };
template <typename T> template <typename T>
class LastFirstSeqPoolFunctor { class LastSeqPoolFunctor {
public: public:
void operator()(const platform::CPUDeviceContext& context, void operator()(const platform::CPUDeviceContext& context,
const framework::LoDTensor& input, framework::Tensor* output, const framework::LoDTensor& input, framework::Tensor* output) {
const std::string pooltype) {
// Create pointers to input and output data // Create pointers to input and output data
auto* in_data = input.data<T>(); auto* in_data = input.data<T>();
auto* out_data = output->data<T>(); auto* out_data = output->data<T>();
...@@ -117,29 +116,40 @@ class LastFirstSeqPoolFunctor { ...@@ -117,29 +116,40 @@ class LastFirstSeqPoolFunctor {
int64_t item_size = input.numel() / input.dims()[0]; int64_t item_size = input.numel() / input.dims()[0];
auto lod = input.lod()[0]; auto lod = input.lod()[0];
int seq_num = static_cast<int>(lod.size()) - 1; int seq_num = static_cast<int>(lod.size()) - 1;
if (pooltype == "LAST") { for (int i = 0; i < seq_num; ++i) {
for (int i = 0; i < seq_num; ++i) { // Calculate the length of each sequence
// Calculate the length of each sequence int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]); // Point to the begin of next sequence
// Point to the begin of next sequence in_data += seq_len * item_size;
in_data += seq_len * item_size; // Copy the last item of sequence to output
// Copy the last item of sequence to output std::memcpy(out_data, (in_data - item_size), item_size * sizeof(T));
std::memcpy(out_data, (in_data - item_size), item_size * sizeof(T)); out_data += item_size;
out_data += item_size; }
} }
} else if (pooltype == "FIRST") { };
for (int i = 0; i < seq_num; ++i) {
// Calculate the length of each sequence template <typename T>
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]); class FirstSeqPoolFunctor {
// Copy the first item of sequence to output public:
std::memcpy(out_data, in_data, item_size * sizeof(T)); void operator()(const platform::CPUDeviceContext& context,
// Point to the next sequence const framework::LoDTensor& input, framework::Tensor* output) {
in_data += seq_len * item_size; // Create pointers to input and output data
out_data += item_size; auto* in_data = input.data<T>();
auto* out_data = output->data<T>();
// Calculate the size of each item in sequence
int64_t item_size = input.numel() / input.dims()[0];
auto lod = input.lod()[0];
int seq_num = static_cast<int>(lod.size()) - 1;
for (int i = 0; i < seq_num; ++i) {
// Calculate the length of each sequence
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
// Copy the first item of sequence to output
std::memcpy(out_data, in_data, item_size * sizeof(T));
// Point to the next sequence
in_data += seq_len * item_size;
out_data += item_size;
} }
} else {
PADDLE_THROW("it's not LAST or FIRST pool type");
}
} }
}; };
...@@ -156,11 +166,17 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> { ...@@ -156,11 +166,17 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
max_pool(context, input, output, index); max_pool(context, input, output, index);
return; return;
} }
if (pooltype == "LAST" || pooltype == "FIRST") { if (pooltype == "LAST") {
math::LastFirstSeqPoolFunctor<T> lastfirst_pool; math::LastSeqPoolFunctor<T> last_pool;
lastfirst_pool(context, input, output, pooltype); last_pool(context, input, output);
return; return;
} }
if (pooltype == "FIRST") {
math::FirstSeqPoolFunctor<T> first_pool;
first_pool(context, input, output);
return;
}
auto lod = input.lod()[0]; auto lod = input.lod()[0];
auto& place = *context.eigen_device(); auto& place = *context.eigen_device();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册