From 746f2a2e3616f8b9b5736b67c759be89bbd3e52d Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Thu, 26 Oct 2017 18:32:28 +0800 Subject: [PATCH] only compute the first max value in backward --- paddle/operators/sequence_pool_op.h | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/paddle/operators/sequence_pool_op.h b/paddle/operators/sequence_pool_op.h index b5835dad5b0..ead30e8e90b 100644 --- a/paddle/operators/sequence_pool_op.h +++ b/paddle/operators/sequence_pool_op.h @@ -103,7 +103,6 @@ class SequencePoolGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* in = context.Input("X"); - auto* out = context.Input("Out"); auto* in_g = context.Output(framework::GradVarName("X")); auto* out_g = context.Input(framework::GradVarName("Out")); int strategy = context.Attr("strategy"); @@ -140,16 +139,19 @@ class SequencePoolGradKernel : public framework::OpKernel { (out_g_e / std::sqrt(static_cast(h))).broadcast(bcast); break; case MAX: { - auto in_t = in->Slice(static_cast(lod[i]), - static_cast(lod[i + 1])); - auto out_t = out->Slice(i, i + 1); - auto in_e = EigenMatrix::From(in_t, {h, w}); - auto out_e = EigenMatrix::From(out_t, {1, w}); - auto equals = in_e == out_e.broadcast(bcast); - auto ones = in_g_e.constant(1); - auto zeros = in_g_e.constant(0); - in_g_e.device(place) = - out_g_e.broadcast(bcast) * equals.select(ones, zeros); + auto in_t = + in->Slice(static_cast(lod[i]), static_cast(lod[i + 1])); + Eigen::Map> + in_t_map(in_t.data(), h, w); + int row_id; + Eigen::array extents = {1, 1}; + for (int col_id = 0; col_id < w; col_id++) { + in_t_map.col(col_id).maxCoeff(&row_id); + Eigen::array in_offsets = {row_id, col_id}; + Eigen::array out_offsets = {0, col_id}; + in_g_e.slice(in_offsets, extents).device(place) = + out_g_e.slice(out_offsets, extents); + } break; } case LAST: -- GitLab