提交 0385b0a1 编写于 作者: M minqiyang

Accelerate SequencePool Op on SUM mode

test=develop
上级 e1904ac2
......@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/sequence_pooling.h"
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence_pooling.h"
namespace paddle {
namespace operators {
......@@ -180,6 +182,7 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
}
auto lod = input.lod()[0];
auto& place = *context.eigen_device();
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
Tensor in_t =
input.Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
......@@ -191,7 +194,14 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
if (pooltype == "AVERAGE") {
out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
} else if (pooltype == "SUM") {
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}}));
if (h > 0) {
const T* in_data = in_t.data<T>();
T* out_data = out_t.mutable_data<T>(context.GetPlace());
blas.VCOPY(w, in_data, out_data);
for (int64_t r = 1; r != h; ++r) {
blas.AXPY(w, 1., in_data + r * w, out_data);
}
}
} else if (pooltype == "SQRT") {
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
std::sqrt(static_cast<T>(h));
......@@ -223,6 +233,7 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
}
auto lod = in_grad->lod()[0];
auto& place = *context.eigen_device();
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
auto in_g_t = in_grad->Slice(static_cast<int>(lod[i]),
static_cast<int>(lod[i + 1]));
......@@ -237,7 +248,11 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
if (pooltype == "AVERAGE") {
in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
} else if (pooltype == "SUM") {
in_g_e.device(place) = (out_g_e).broadcast(bcast);
const T* out_g_data = out_g_t.data<T>();
T* in_g_data = in_g_t.mutable_data<T>(context.GetPlace());
for (int r = 0; r != h; ++r) {
blas.VCOPY(w, out_g_data, in_g_data + r * w);
}
} else if (pooltype == "SQRT") {
in_g_e.device(place) =
(out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册