diff --git a/paddle/operators/math/softmax.h b/paddle/operators/math/softmax.h index 3d2f0d0aecffcd0fe51166c3d863aa8b91bba196..225323f05ac9aacce55dfe4795315741ee2c8795 100644 --- a/paddle/operators/math/softmax.h +++ b/paddle/operators/math/softmax.h @@ -36,7 +36,7 @@ struct ValueClip { template class SoftmaxFunctor { public: - void operator()(const framework::ExecutionContext& context, + void operator()(const platform::DeviceContext& context, const framework::Tensor* X, framework::Tensor* Y) { auto logits = EigenMatrix::From(*X); auto softmax = EigenMatrix::From(*Y); @@ -58,8 +58,8 @@ class SoftmaxFunctor { .broadcast(one_by_class)) .unaryExpr(ValueClip()); - softmax.device(context.GetEigenDevice()) = shifted_logits.exp(); - softmax.device(context.GetEigenDevice()) = + softmax.device(*context.GetEigenDevice()) = shifted_logits.exp(); + softmax.device(*context.GetEigenDevice()) = (softmax * softmax.sum(along_class) .inverse() diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h index 9996536454b1b6b992385787301faa6d66a4cd20..8fdda8b1dfc5dd40315682388dabe0bf2f2be555 100644 --- a/paddle/operators/softmax_op.h +++ b/paddle/operators/softmax_op.h @@ -35,7 +35,7 @@ class SoftmaxKernel : public framework::OpKernel { // allocate memory on device. Y->mutable_data(context.GetPlace()); - math::SoftmaxFunctor()(context, X, Y); + math::SoftmaxFunctor()(context.device_context(), X, Y); } }; diff --git a/paddle/operators/softmax_with_cross_entropy_op.cu b/paddle/operators/softmax_with_cross_entropy_op.cu index c3086e729e493228e06a176e1a64e5e95fad148b..b5a7cda734019e88df1a64ab87dbf5563211a168 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/operators/softmax_with_cross_entropy_op.cu @@ -66,7 +66,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { softmax->mutable_data(context.GetPlace()); loss->mutable_data(context.GetPlace()); - math::SoftmaxFunctor()(context, logits, softmax); + math::SoftmaxFunctor()(context.device_context(), + logits, softmax); math::CrossEntropyFunctor()( context, loss, softmax, labels, context.Attr("softLabel")); } diff --git a/paddle/operators/softmax_with_cross_entropy_op.h b/paddle/operators/softmax_with_cross_entropy_op.h index 7dcb6ad9b4d1ecdcc10261eba5d6a48d56fb844b..cffd422f1827b646a8abcd881fdcb5455e6a663a 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.h +++ b/paddle/operators/softmax_with_cross_entropy_op.h @@ -40,7 +40,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { softmax->mutable_data(context.GetPlace()); loss->mutable_data(context.GetPlace()); - math::SoftmaxFunctor()(context, logits, softmax); + math::SoftmaxFunctor()(context.device_context(), + logits, softmax); math::CrossEntropyFunctor()( context.device_context(), loss, softmax, labels, context.Attr("softLabel"));