diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.cc b/paddle/fluid/operators/softmax_cudnn_op.cu.cc index 5596fa0648ccc151bc0d11de9c556599428a8d71..2bdb23e999621b10799b5163f326bc4b66a437e6 100644 --- a/paddle/fluid/operators/softmax_cudnn_op.cu.cc +++ b/paddle/fluid/operators/softmax_cudnn_op.cu.cc @@ -30,8 +30,16 @@ class SoftmaxCUDNNKernel : public framework::OpKernel { // allocate memory on device. Out->mutable_data(context.GetPlace()); + auto dims = X->dims(); + auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); + framework::LoDTensor flattened_x; + framework::LoDTensor flattened_out; + flattened_x.ShareDataWith(*X).Resize(flattened_dims); + flattened_out.ShareDataWith(*Out).Resize(flattened_dims); + math::SoftmaxCUDNNFunctor()( - context.template device_context(), X, Out); + context.template device_context(), + &flattened_x, &flattened_out); } }; @@ -46,9 +54,18 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel { // allocate memory on device. dX->mutable_data(context.GetPlace()); + auto dims = Out->dims(); + auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); + framework::LoDTensor flattened_out; + framework::LoDTensor flattened_d_out; + framework::LoDTensor flattened_d_x; + flattened_out.ShareDataWith(*Out).Resize(flattened_dims); + flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims); + flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims); + math::SoftmaxGradCUDNNFunctor()( - context.template device_context(), Out, - dOut, dX); + context.template device_context(), + &flattened_out, &flattened_d_out, &flattened_d_x); } };