提交 d303f7ae 编写于 作者: Q qiaolongfei

fix int overflow

上级 229c2e78
...@@ -28,7 +28,8 @@ using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; ...@@ -28,7 +28,8 @@ using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using DataLayout = platform::DataLayout; using DataLayout = platform::DataLayout;
static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 1024 * 1024 * 1024; static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES =
static_cast<size_t>(1024) * 1024 * 1024;
template <typename T> template <typename T>
class CudnnConvOpKernel : public framework::OpKernel<T> { class CudnnConvOpKernel : public framework::OpKernel<T> {
...@@ -44,7 +45,8 @@ class CudnnConvOpKernel : public framework::OpKernel<T> { ...@@ -44,7 +45,8 @@ class CudnnConvOpKernel : public framework::OpKernel<T> {
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations"); std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups"); int groups = ctx.Attr<int>("groups");
int user_workspace_size = ctx.Attr<int>("workspace_size_MB"); int64_t user_workspace_size =
static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>(); const T* filter_data = filter->data<T>();
...@@ -163,7 +165,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> { ...@@ -163,7 +165,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations"); std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups"); int groups = ctx.Attr<int>("groups");
int user_workspace_size = ctx.Attr<int>("workspace_size_MB"); int64_t user_workspace_size =
static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
// ------------------- cudnn descriptors --------------------- // ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc; ScopedTensorDescriptor input_desc;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册