提交 53185fde 编写于 作者: B bingyanghuang

Rewrite sequence pooling last and first mode with memcpy and clean code

上级 a5576083
...@@ -103,6 +103,39 @@ class MaxSeqPoolGradFunctor { ...@@ -103,6 +103,39 @@ class MaxSeqPoolGradFunctor {
} }
}; };
template <typename T>
class LastFirstSeqPoolFunctor {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::LoDTensor& input, framework::Tensor* output,
const std::string pooltype) {
auto* in_data = input.data<T>();
auto* out_data = output->data<T>();
int64_t word_len = input.numel() / input.dims()[0];
auto lod = input.lod()[0];
auto dims = input.dims();
if (pooltype == "LAST"){
for (int i=0; i < static_cast<int>(lod.size()) - 1; ++i ){
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
in_data += seq_len* word_len;
std::memcpy(out_data,(in_data-word_len),word_len*sizeof(int));
out_data += word_len;
}
}
else if(pooltype == "FIRST"){
for (int i=0; i < static_cast<int>(lod.size()) - 1; ++i ){
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
std::memcpy(out_data,in_data,word_len*sizeof(int));
in_data += seq_len * word_len;
out_data += word_len;
}
}
}
};
template <typename T> template <typename T>
class SequencePoolFunctor<platform::CPUDeviceContext, T> { class SequencePoolFunctor<platform::CPUDeviceContext, T> {
public: public:
...@@ -116,6 +149,12 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> { ...@@ -116,6 +149,12 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
max_pool(context, input, output, index); max_pool(context, input, output, index);
return; return;
} }
if (pooltype == "LAST" || pooltype == "FIRST") {
math::LastFirstSeqPoolFunctor<T> lastfirst_pool;
lastfirst_pool(context, input, output, pooltype);
return;
}
auto lod = input.lod()[0]; auto lod = input.lod()[0];
auto& place = *context.eigen_device(); auto& place = *context.eigen_device();
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
...@@ -133,10 +172,6 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> { ...@@ -133,10 +172,6 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
} else if (pooltype == "SQRT") { } else if (pooltype == "SQRT") {
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) / out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
std::sqrt(static_cast<T>(h)); std::sqrt(static_cast<T>(h));
} else if (pooltype == "LAST") {
out_e.device(place) = in_e.chip(h - 1, 0);
} else if (pooltype == "FIRST") {
out_e.device(place) = in_e.chip(0, 0);
} else { } else {
PADDLE_THROW("unsupported pooling pooltype"); PADDLE_THROW("unsupported pooling pooltype");
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册