提交 9c687090 编写于 作者: M minqiyang

Accelerate sequence_pool functor

上级 14ebc424
...@@ -231,9 +231,30 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> { ...@@ -231,9 +231,30 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
math::SetConstant<platform::CPUDeviceContext, T> functor; math::SetConstant<platform::CPUDeviceContext, T> functor;
functor(context, in_grad, 0); functor(context, in_grad, 0);
} }
if (pooltype == "SUM") {
auto lod = in_grad->lod()[0]; auto lod = in_grad->lod()[0];
auto& place = *context.eigen_device(); int64_t out_w = out_grad.numel() / out_grad.dims()[0];
int64_t in_w = in_grad->numel() / in_grad->dims()[0];
PADDLE_ENFORCE(in_w == out_w);
const T* out_g_data = out_grad.data<T>();
T* in_g_data = in_grad->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context); auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
int64_t in_offset = lod[i];
const T* out_pos = out_g_data + i * out_w;
T* in_pos = in_g_data + in_offset;
for (int r = 0; r != h; ++r) {
blas.VCOPY(in_w, out_pos, in_pos + r * in_w);
}
}
return;
}
auto lod = in_grad->lod()[0];
auto& place = *context.eigen_device();
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
auto in_g_t = in_grad->Slice(static_cast<int>(lod[i]), auto in_g_t = in_grad->Slice(static_cast<int>(lod[i]),
static_cast<int>(lod[i + 1])); static_cast<int>(lod[i + 1]));
...@@ -247,12 +268,6 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> { ...@@ -247,12 +268,6 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
if (pooltype == "AVERAGE") { if (pooltype == "AVERAGE") {
in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast); in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
} else if (pooltype == "SUM") {
const T* out_g_data = out_g_t.data<T>();
T* in_g_data = in_g_t.mutable_data<T>(context.GetPlace());
for (int r = 0; r != h; ++r) {
blas.VCOPY(w, out_g_data, in_g_data + r * w);
}
} else if (pooltype == "SQRT") { } else if (pooltype == "SQRT") {
in_g_e.device(place) = in_g_e.device(place) =
(out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast); (out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册