提交 dc111d34 编写于 作者: F fengjiayi

update softmax_cudnn_op

上级 f7bd0b22
......@@ -30,8 +30,16 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
// allocate memory on device.
Out->mutable_data<T>(context.GetPlace());
auto dims = X->dims();
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
framework::LoDTensor flattened_x;
framework::LoDTensor flattened_out;
flattened_x.ShareDataWith(*X).Resize(flattened_dims);
flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
math::SoftmaxCUDNNFunctor<T>()(
context.template device_context<platform::CUDADeviceContext>(), X, Out);
context.template device_context<platform::CUDADeviceContext>(),
&flattened_x, &flattened_out);
}
};
......@@ -46,9 +54,18 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
// allocate memory on device.
dX->mutable_data<T>(context.GetPlace());
auto dims = Out->dims();
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
framework::LoDTensor flattened_out;
framework::LoDTensor flattened_d_out;
framework::LoDTensor flattened_d_x;
flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims);
flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims);
math::SoftmaxGradCUDNNFunctor<T>()(
context.template device_context<platform::CUDADeviceContext>(), Out,
dOut, dX);
context.template device_context<platform::CUDADeviceContext>(),
&flattened_out, &flattened_d_out, &flattened_d_x);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册