From 84835784173cb7a6bf79fc86665372dfbca69768 Mon Sep 17 00:00:00 2001 From: Dong Zhihong Date: Fri, 10 Nov 2017 19:32:02 -0800 Subject: [PATCH] fix shape bug --- paddle/operators/reduce_op.h | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/paddle/operators/reduce_op.h b/paddle/operators/reduce_op.h index 45043c440b..12ec1fcf44 100644 --- a/paddle/operators/reduce_op.h +++ b/paddle/operators/reduce_op.h @@ -14,6 +14,7 @@ #pragma once +#include "glog/logging.h" #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" @@ -26,6 +27,10 @@ template using EigenTensor = framework::EigenTensor; +template +using EigenScalar = framework::EigenScalar; + struct SumFunctor { template void operator()(const Place& place, X& x, Y& y, const Dim& dim) { @@ -133,10 +138,21 @@ class ReduceKernel : public framework::OpKernel { dims_vector.erase(dims_vector.begin() + dim); dims = framework::make_ddim(dims_vector); } - auto out = EigenTensor < T, D == 1 ? 1 : (D - 1) > ::From(*output, dims); + auto& place = context.GetEigenDevice(); Functor functor; - functor(place, x, out, reduce_dim); + + if (D == 1) { + auto out = EigenScalar::From(*output); + // auto out = EigenTensor::From(*output, dims); + VLOG(0) << "x dims : " << x.rank() << " out dims : " << out.rank(); + functor(place, x, out, reduce_dim); + } else { + auto out = EigenTensor::From(*output, dims); + // VLOG(0) << "x dims : "<< x.dimensions().size() << " out dims : " + // << out.dimensions().size(); + functor(place, x, out, reduce_dim); + } } }; @@ -186,13 +202,13 @@ class ReduceGradKernel : public framework::OpKernel { auto x_reduce = EigenTensor::From(*input1, dims); auto x_reduce_grad = EigenTensor::From(*input2, dims); - Eigen::array braodcast_dim; - for (size_t i = 0; i < D; ++i) braodcast_dim[i] = 1; - braodcast_dim[dim] = input0->dims()[dim]; + Eigen::array broadcast_dim; + for (size_t i = 0; i < D; ++i) broadcast_dim[i] = 1; + broadcast_dim[dim] = input0->dims()[dim]; auto& place = context.GetEigenDevice(); Functor functor; - functor(place, x, x_reduce, x_grad, x_reduce_grad, braodcast_dim, - braodcast_dim[dim]); + functor(place, x, x_reduce, x_grad, x_reduce_grad, broadcast_dim, + broadcast_dim[dim]); } }; -- GitLab