提交 84ff7e97 编写于 作者: Q qijun

refine SoftmaxFunctor

上级 79def5e6
......@@ -36,7 +36,7 @@ struct ValueClip {
template <typename Place, typename T>
class SoftmaxFunctor {
public:
void operator()(const framework::ExecutionContext& context,
void operator()(const platform::DeviceContext& context,
const framework::Tensor* X, framework::Tensor* Y) {
auto logits = EigenMatrix<T>::From(*X);
auto softmax = EigenMatrix<T>::From(*Y);
......@@ -58,8 +58,8 @@ class SoftmaxFunctor {
.broadcast(one_by_class))
.unaryExpr(ValueClip<T>());
softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();
softmax.device(context.GetEigenDevice<Place>()) =
softmax.device(*context.GetEigenDevice<Place>()) = shifted_logits.exp();
softmax.device(*context.GetEigenDevice<Place>()) =
(softmax *
softmax.sum(along_class)
.inverse()
......
......@@ -35,7 +35,7 @@ class SoftmaxKernel : public framework::OpKernel<T> {
// allocate memory on device.
Y->mutable_data<T>(context.GetPlace());
math::SoftmaxFunctor<Place, T>()(context, X, Y);
math::SoftmaxFunctor<Place, T>()(context.device_context(), X, Y);
}
};
......
......@@ -66,7 +66,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
math::SoftmaxFunctor<platform::GPUPlace, T>()(context, logits, softmax);
math::SoftmaxFunctor<platform::GPUPlace, T>()(context.device_context(),
logits, softmax);
math::CrossEntropyFunctor<platform::GPUPlace, T>()(
context, loss, softmax, labels, context.Attr<bool>("softLabel"));
}
......
......@@ -40,7 +40,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
math::SoftmaxFunctor<platform::CPUPlace, T>()(context, logits, softmax);
math::SoftmaxFunctor<platform::CPUPlace, T>()(context.device_context(),
logits, softmax);
math::CrossEntropyFunctor<platform::CPUPlace, T>()(
context.device_context(), loss, softmax, labels,
context.Attr<bool>("softLabel"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册