提交 53185fde 编写于 作者: B bingyanghuang

Rewrite sequence pooling last and first mode with memcpy and clean code

上级 a5576083
......@@ -103,6 +103,39 @@ class MaxSeqPoolGradFunctor {
}
};
template <typename T>
class LastFirstSeqPoolFunctor {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::LoDTensor& input, framework::Tensor* output,
const std::string pooltype) {
auto* in_data = input.data<T>();
auto* out_data = output->data<T>();
int64_t word_len = input.numel() / input.dims()[0];
auto lod = input.lod()[0];
auto dims = input.dims();
if (pooltype == "LAST"){
for (int i=0; i < static_cast<int>(lod.size()) - 1; ++i ){
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
in_data += seq_len* word_len;
std::memcpy(out_data,(in_data-word_len),word_len*sizeof(int));
out_data += word_len;
}
}
else if(pooltype == "FIRST"){
for (int i=0; i < static_cast<int>(lod.size()) - 1; ++i ){
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
std::memcpy(out_data,in_data,word_len*sizeof(int));
in_data += seq_len * word_len;
out_data += word_len;
}
}
}
};
template <typename T>
class SequencePoolFunctor<platform::CPUDeviceContext, T> {
public:
......@@ -116,6 +149,12 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
max_pool(context, input, output, index);
return;
}
if (pooltype == "LAST" || pooltype == "FIRST") {
math::LastFirstSeqPoolFunctor<T> lastfirst_pool;
lastfirst_pool(context, input, output, pooltype);
return;
}
auto lod = input.lod()[0];
auto& place = *context.eigen_device();
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
......@@ -133,10 +172,6 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
} else if (pooltype == "SQRT") {
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
std::sqrt(static_cast<T>(h));
} else if (pooltype == "LAST") {
out_e.device(place) = in_e.chip(h - 1, 0);
} else if (pooltype == "FIRST") {
out_e.device(place) = in_e.chip(0, 0);
} else {
PADDLE_THROW("unsupported pooling pooltype");
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册