From 84ff7e97842890e70f1baf6bf41ef54513d1a4a3 Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 28 Sep 2017 20:05:15 -0700 Subject: [PATCH] refine SoftmaxFunctor --- paddle/operators/math/softmax.h | 6 +++--- paddle/operators/softmax_op.h | 2 +- paddle/operators/softmax_with_cross_entropy_op.cu | 3 ++- paddle/operators/softmax_with_cross_entropy_op.h | 3 ++- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/paddle/operators/math/softmax.h b/paddle/operators/math/softmax.h index 3d2f0d0ae..225323f05 100644 --- a/paddle/operators/math/softmax.h +++ b/paddle/operators/math/softmax.h @@ -36,7 +36,7 @@ struct ValueClip { template class SoftmaxFunctor { public: - void operator()(const framework::ExecutionContext& context, + void operator()(const platform::DeviceContext& context, const framework::Tensor* X, framework::Tensor* Y) { auto logits = EigenMatrix::From(*X); auto softmax = EigenMatrix::From(*Y); @@ -58,8 +58,8 @@ class SoftmaxFunctor { .broadcast(one_by_class)) .unaryExpr(ValueClip()); - softmax.device(context.GetEigenDevice()) = shifted_logits.exp(); - softmax.device(context.GetEigenDevice()) = + softmax.device(*context.GetEigenDevice()) = shifted_logits.exp(); + softmax.device(*context.GetEigenDevice()) = (softmax * softmax.sum(along_class) .inverse() diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h index 999653645..8fdda8b1d 100644 --- a/paddle/operators/softmax_op.h +++ b/paddle/operators/softmax_op.h @@ -35,7 +35,7 @@ class SoftmaxKernel : public framework::OpKernel { // allocate memory on device. Y->mutable_data(context.GetPlace()); - math::SoftmaxFunctor()(context, X, Y); + math::SoftmaxFunctor()(context.device_context(), X, Y); } }; diff --git a/paddle/operators/softmax_with_cross_entropy_op.cu b/paddle/operators/softmax_with_cross_entropy_op.cu index c3086e729..b5a7cda73 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/operators/softmax_with_cross_entropy_op.cu @@ -66,7 +66,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { softmax->mutable_data(context.GetPlace()); loss->mutable_data(context.GetPlace()); - math::SoftmaxFunctor()(context, logits, softmax); + math::SoftmaxFunctor()(context.device_context(), + logits, softmax); math::CrossEntropyFunctor()( context, loss, softmax, labels, context.Attr("softLabel")); } diff --git a/paddle/operators/softmax_with_cross_entropy_op.h b/paddle/operators/softmax_with_cross_entropy_op.h index 7dcb6ad9b..cffd422f1 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.h +++ b/paddle/operators/softmax_with_cross_entropy_op.h @@ -40,7 +40,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { softmax->mutable_data(context.GetPlace()); loss->mutable_data(context.GetPlace()); - math::SoftmaxFunctor()(context, logits, softmax); + math::SoftmaxFunctor()(context.device_context(), + logits, softmax); math::CrossEntropyFunctor()( context.device_context(), loss, softmax, labels, context.Attr("softLabel")); -- GitLab