diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index 69318a6598c8c69eceab7216df6382537153d34f..235b5405fb7d016f4bd8c738f75b303522183116 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/math/sequence_pooling.h" #include + +#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/sequence_pooling.h" namespace paddle { namespace operators { @@ -180,6 +182,7 @@ class SequencePoolFunctor { } auto lod = input.lod()[0]; auto& place = *context.eigen_device(); + auto blas = math::GetBlas(context); for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { Tensor in_t = input.Slice(static_cast(lod[i]), static_cast(lod[i + 1])); @@ -191,7 +194,14 @@ class SequencePoolFunctor { if (pooltype == "AVERAGE") { out_e.device(place) = in_e.mean(Eigen::array({{0}})); } else if (pooltype == "SUM") { - out_e.device(place) = in_e.sum(Eigen::array({{0}})); + if (h > 0) { + const T* in_data = in_t.data(); + T* out_data = out_t.mutable_data(context.GetPlace()); + blas.VCOPY(w, in_data, out_data); + for (int64_t r = 1; r != h; ++r) { + blas.AXPY(w, 1., in_data + r * w, out_data); + } + } } else if (pooltype == "SQRT") { out_e.device(place) = in_e.sum(Eigen::array({{0}})) / std::sqrt(static_cast(h)); @@ -223,6 +233,7 @@ class SequencePoolGradFunctor { } auto lod = in_grad->lod()[0]; auto& place = *context.eigen_device(); + auto blas = math::GetBlas(context); for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { auto in_g_t = in_grad->Slice(static_cast(lod[i]), static_cast(lod[i + 1])); @@ -237,7 +248,11 @@ class SequencePoolGradFunctor { if (pooltype == "AVERAGE") { in_g_e.device(place) = (out_g_e / static_cast(h)).broadcast(bcast); } else if (pooltype == "SUM") { - in_g_e.device(place) = (out_g_e).broadcast(bcast); + const T* out_g_data = out_g_t.data(); + T* in_g_data = in_g_t.mutable_data(context.GetPlace()); + for (int r = 0; r != h; ++r) { + blas.VCOPY(w, out_g_data, in_g_data + r * w); + } } else if (pooltype == "SQRT") { in_g_e.device(place) = (out_g_e / std::sqrt(static_cast(h))).broadcast(bcast);