From c1fccc29c1dbd3ee7f9629372b98b26fbeeb7e10 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Thu, 8 Nov 2018 11:35:21 +0100 Subject: [PATCH] - Noise adding removed for Test phase of softmax --- paddle/fluid/operators/math/softmax.cc | 6 ++- paddle/fluid/operators/math/softmax.h | 2 +- paddle/fluid/operators/math/softmax_impl.h | 40 ++++++++++++++++++- paddle/fluid/operators/softmax_op.h | 10 ++++- .../operators/softmax_with_cross_entropy_op.h | 2 +- 5 files changed, 52 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/math/softmax.cc b/paddle/fluid/operators/math/softmax.cc index 78c65af24a..6300836e50 100644 --- a/paddle/fluid/operators/math/softmax.cc +++ b/paddle/fluid/operators/math/softmax.cc @@ -19,8 +19,10 @@ namespace paddle { namespace operators { namespace math { -template class SoftmaxFunctor; -template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; template class SoftmaxGradFunctor; template class SoftmaxGradFunctor; diff --git a/paddle/fluid/operators/math/softmax.h b/paddle/fluid/operators/math/softmax.h index da1f0b672d..bf698dc2f7 100644 --- a/paddle/fluid/operators/math/softmax.h +++ b/paddle/fluid/operators/math/softmax.h @@ -19,7 +19,7 @@ namespace paddle { namespace operators { namespace math { -template +template class SoftmaxFunctor { public: void operator()(const DeviceContext& context, const framework::Tensor* X, diff --git a/paddle/fluid/operators/math/softmax_impl.h b/paddle/fluid/operators/math/softmax_impl.h index dd9971ba09..6a0a6c2e46 100644 --- a/paddle/fluid/operators/math/softmax_impl.h +++ b/paddle/fluid/operators/math/softmax_impl.h @@ -32,8 +32,8 @@ struct ValueClip { } }; -template -void SoftmaxFunctor::operator()(const DeviceContext& context, +template +void SoftmaxFunctor::operator()(const DeviceContext& context, const framework::Tensor* X, framework::Tensor* Y) { auto logits = EigenMatrix::From(*X); @@ -65,6 +65,42 @@ void SoftmaxFunctor::operator()(const DeviceContext& context, .broadcast(one_by_class)); } +template +class SoftmaxFunctor { +void operator()(const DeviceContext& context, + const framework::Tensor* X, + framework::Tensor* Y) { + auto logits = EigenMatrix::From(*X); + auto softmax = EigenMatrix::From(*Y); + + const int kBatchDim = 0; + const int kClassDim = 1; + + const int batch_size = logits.dimension(kBatchDim); + const int num_classes = logits.dimension(kClassDim); + + Eigen::DSizes along_class(kClassDim); + Eigen::DSizes batch_by_one(batch_size, 1); + Eigen::DSizes one_by_class(1, num_classes); + + auto shifted_logits = (logits - + logits.maximum(along_class) + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class)); + + softmax.device(*context.eigen_device()) = shifted_logits.exp(); + softmax.device(*context.eigen_device()) = (softmax * + softmax.sum(along_class) + .inverse() + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class)); +} +}; + + + template void SoftmaxGradFunctor::operator()( const DeviceContext& context, const framework::Tensor* y, diff --git a/paddle/fluid/operators/softmax_op.h b/paddle/fluid/operators/softmax_op.h index cf1eeb017d..5bc72aac48 100644 --- a/paddle/fluid/operators/softmax_op.h +++ b/paddle/fluid/operators/softmax_op.h @@ -35,8 +35,14 @@ class SoftmaxKernel : public framework::OpKernel { Tensor X_2d = framework::ReshapeToMatrix(*X, rank - 1); Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1); - math::SoftmaxFunctor()( - context.template device_context(), &X_2d, &Out_2d); + const bool is_test = context.Attr("is_test"); + if( is_test == true) { + math::SoftmaxFunctor()( + context.template device_context(), &X_2d, &Out_2d); + } else { + math::SoftmaxFunctor()( + context.template device_context(), &X_2d, &Out_2d); + } } }; diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.h b/paddle/fluid/operators/softmax_with_cross_entropy_op.h index e9aba3b37b..2eec8541c8 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.h +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.h @@ -42,7 +42,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { auto& dev_ctx = context.template device_context(); - math::SoftmaxFunctor()(dev_ctx, logits, + math::SoftmaxFunctor()(dev_ctx, logits, softmax); math::CrossEntropyFunctor()( dev_ctx, loss, softmax, labels, context.Attr("soft_label"), -- GitLab