提交 be80bb4f 编写于 作者: J Jacek Czaja

- Fix to GPU

test=develop
上级 513bb6c1
......@@ -36,7 +36,9 @@ class SoftmaxKernel : public framework::OpKernel<T> {
Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
#ifdef PADDLE_ON_INFERENCE
math::SoftmaxFunctor<DeviceContext, T, true>()(
math::SoftmaxFunctor<
DeviceContext, T,
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>()(
context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
#else
math::SoftmaxFunctor<DeviceContext, T, false>()(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册