提交 393c748c 编写于 作者: L Luo Tao

add seqlastin/seqfirstin for seq_pool op

上级 e69a565a
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -81,6 +82,12 @@ class SequencePoolKernel : public framework::OpKernel<T> { ...@@ -81,6 +82,12 @@ class SequencePoolKernel : public framework::OpKernel<T> {
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) / out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
std::sqrt(static_cast<T>(h)); std::sqrt(static_cast<T>(h));
break; break;
case LAST:
out_e.device(place) = in_e.chip(h - 1, 0);
break;
case FIRST:
out_e.device(place) = in_e.chip(0, 0);
break;
default: default:
PADDLE_THROW("unsupported pooling strategy"); PADDLE_THROW("unsupported pooling strategy");
} }
...@@ -102,6 +109,10 @@ class SequencePoolGradKernel : public framework::OpKernel<T> { ...@@ -102,6 +109,10 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
int64_t w = in->numel() / dims[0]; int64_t w = in->numel() / dims[0];
in_g->mutable_data<T>(context.GetPlace()); in_g->mutable_data<T>(context.GetPlace());
if (strategy > 2) {
// set X@Grad to zero first when the strategy is LAST/FIRST/MAX
math::SetConstant<Place, T>(context.device_context(), in_g, 0);
}
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
auto in_g_t = in_g->Slice<T>(static_cast<int>(lod[i]), auto in_g_t = in_g->Slice<T>(static_cast<int>(lod[i]),
...@@ -123,6 +134,12 @@ class SequencePoolGradKernel : public framework::OpKernel<T> { ...@@ -123,6 +134,12 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
in_g_e.device(place) = in_g_e.device(place) =
(out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast); (out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
break; break;
case LAST:
in_g_e.chip(h - 1, 0).device(place) = out_g_e;
break;
case FIRST:
in_g_e.chip(0, 0).device(place) = out_g_e;
break;
default: default:
PADDLE_THROW("unsupported pooling strategy"); PADDLE_THROW("unsupported pooling strategy");
} }
......
...@@ -107,5 +107,45 @@ class TestSeqSqrtPool2D(TestSeqAvgPool2D): ...@@ -107,5 +107,45 @@ class TestSeqSqrtPool2D(TestSeqAvgPool2D):
self.check_grad(["X"], "Out", max_relative_error=0.06) self.check_grad(["X"], "Out", max_relative_error=0.06)
class TestSeqLastPool(TestSeqAvgPool):
    """Sequence-pool test case that selects the LAST pooling strategy."""

    def compute(self):
        """Set the op attrs to LAST and fill the expected output.

        For each of the four segments delimited by consecutive offsets in
        ``lod[0]``, the final row of the segment is written to ``Out``.
        """
        self.attrs = {'strategy': SeqPoolType.LAST}
        data, lod = self.inputs['X']
        expected = self.outputs['Out']
        for seq_idx in range(4):
            begin = lod[0][seq_idx]
            end = lod[0][seq_idx + 1]
            segment = data[begin:end, :]
            # Index -1 picks the last time step of this segment.
            expected[seq_idx] = segment[-1, :]
class TestSeqLastPool2D(TestSeqAvgPool2D):
    """Sequence-pool test case for the LAST strategy on (3, 17) steps."""

    def compute(self):
        """Set the op attrs to LAST and fill the expected output.

        Each of the four segments delimited by consecutive ``lod[0]``
        offsets is flattened to rows of ``3 * 17`` values; the final row is
        reshaped back to ``(3, 17)`` and written to ``Out``.
        """
        self.attrs = {'strategy': SeqPoolType.LAST}
        data, lod = self.inputs['X']
        expected = self.outputs['Out']
        for seq_idx in range(4):
            begin = lod[0][seq_idx]
            end = lod[0][seq_idx + 1]
            segment = data[begin:end, :]
            flat_rows = np.reshape(segment, (-1, 3 * 17))
            # Index -1 picks the last time step of this segment.
            last_row = flat_rows[-1, :]
            expected[seq_idx] = np.reshape(last_row, (3, 17))
class TestSeqFirstPool(TestSeqAvgPool):
    """Sequence-pool test case that selects the FIRST pooling strategy."""

    def compute(self):
        """Set the op attrs to FIRST and fill the expected output.

        For each of the four segments delimited by consecutive offsets in
        ``lod[0]``, the first row of the segment is written to ``Out``.
        """
        self.attrs = {'strategy': SeqPoolType.FIRST}
        data, lod = self.inputs['X']
        expected = self.outputs['Out']
        for seq_idx in range(4):
            begin = lod[0][seq_idx]
            end = lod[0][seq_idx + 1]
            segment = data[begin:end, :]
            # Index 0 picks the first time step of this segment.
            expected[seq_idx] = segment[0, :]
class TestSeqFirstPool2D(TestSeqAvgPool2D):
    """Sequence-pool test case for the FIRST strategy on (3, 17) steps."""

    def compute(self):
        """Set the op attrs to FIRST and fill the expected output.

        Each of the four segments delimited by consecutive ``lod[0]``
        offsets is flattened to rows of ``3 * 17`` values; the first row is
        reshaped back to ``(3, 17)`` and written to ``Out``.
        """
        self.attrs = {'strategy': SeqPoolType.FIRST}
        data, lod = self.inputs['X']
        expected = self.outputs['Out']
        for seq_idx in range(4):
            begin = lod[0][seq_idx]
            end = lod[0][seq_idx + 1]
            segment = data[begin:end, :]
            flat_rows = np.reshape(segment, (-1, 3 * 17))
            # Index 0 picks the first time step of this segment.
            first_row = flat_rows[0, :]
            expected[seq_idx] = np.reshape(first_row, (3, 17))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册