Oneflow-Inc / oneflow

Commit 10e08d66
Authored Aug 16, 2018 by Niu Chong
Committed by Jinhui Yuan on Aug 16, 2018
fix: rename UseCudnnOnGpu&UseCudnn to EnableCudnn&DevIsGpuAndEnableCudnn (#1121)
Parent: 02629a55

Showing 10 changed files with 23 additions and 23 deletions.
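In short, the old pair UseCudnnOnGpu()/UseCudnn() becomes EnableCudnn()/DevIsGpuAndEnableCudnn(): the first helper reads only the renamed config flag, while the second additionally requires GPU placement, as the operator.h diff below shows. A minimal stand-alone sketch of that semantics (stub types, not the real OneFlow classes):

// Stand-alone sketch of the renamed helpers. OperatorConfStub and OperatorStub
// are hypothetical stand-ins, not the real OneFlow proto/Operator classes.
#include <iostream>

enum class DeviceType { kCPU, kGPU };

struct OperatorConfStub {    // stand-in for the OperatorConf proto
  bool enable_cudnn = true;  // renamed from use_cudnn_on_gpu in this commit
};

class OperatorStub {         // stand-in for oneflow::Operator
 public:
  OperatorStub(DeviceType dev, OperatorConfStub conf) : device_type_(dev), conf_(conf) {}
  // Formerly UseCudnnOnGpu(): the configuration flag alone.
  bool EnableCudnn() const { return conf_.enable_cudnn; }
  // Formerly UseCudnn(): the flag AND GPU placement.
  bool DevIsGpuAndEnableCudnn() const {
    return device_type_ == DeviceType::kGPU && EnableCudnn();
  }

 private:
  DeviceType device_type_;
  OperatorConfStub conf_;
};

int main() {
  OperatorStub cpu_op(DeviceType::kCPU, OperatorConfStub{true});
  OperatorStub gpu_op(DeviceType::kGPU, OperatorConfStub{true});
  std::cout << cpu_op.DevIsGpuAndEnableCudnn() << "\n";  // 0: never cuDNN on CPU
  std::cout << gpu_op.DevIsGpuAndEnableCudnn() << "\n";  // 1: cuDNN path on GPU
}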
Changed files:
  oneflow/core/job/job_conf.proto               +1 -1
  oneflow/core/job/job_desc.h                   +1 -1
  oneflow/core/kernel/conv_kernel.cpp           +2 -1
  oneflow/core/kernel/conv_kernel.cu            +4 -4
  oneflow/core/kernel/kernel.h                  +1 -2
  oneflow/core/kernel/reduce_gather_kernel.cpp  +1 -1
  oneflow/core/operator/conv_op.cpp             +8 -8
  oneflow/core/operator/op_conf.proto           +1 -1
  oneflow/core/operator/operator.cpp            +2 -2
  oneflow/core/operator/operator.h              +2 -2
oneflow/core/job/job_conf.proto

@@ -46,7 +46,7 @@ message OtherConf {
   optional bool use_rdma = 100 [default = false];
   optional string model_load_snapshot_path = 101 [default = ""];
   optional int32 max_data_id_length = 102 [default = 0];
-  optional bool use_cudnn_on_gpu = 103 [default = true];
+  optional bool enable_cudnn = 103 [default = true];
   optional DataType default_data_type = 104 [default = kFloat];  // kFloat or kDouble
   optional int64 piece_num_of_experiment_phase = 105 [default = -1];
   optional uint64 persistence_buf_mbyte = 106 [default = 64];
oneflow/core/job/job_desc.h

@@ -26,7 +26,7 @@ class JobDesc final {
   size_t SizeOfOneDataId() const { return job_conf_.other().max_data_id_length() * sizeof(char); }
   bool use_rdma() const { return job_conf_.other().use_rdma(); }
   bool use_synthetic_data() const { return job_conf_.other().use_synthetic_data(); }
-  bool UseCudnnOnGpu() const { return job_conf_.other().use_cudnn_on_gpu(); }
+  bool EnableCudnn() const { return job_conf_.other().enable_cudnn(); }
   int64_t TotalMachineNum() const { return job_conf_.resource().machine().size(); }
   int32_t CpuDeviceNum() const { return job_conf_.resource().cpu_device_num(); }
   void SetCpuDeviceNum(int32_t val) { job_conf_.mutable_resource()->set_cpu_device_num(val); }
oneflow/core/kernel/conv_kernel.cpp

@@ -26,7 +26,8 @@ void ConvKernelIf<device_type, T>::BackwardDataContent(
 template<DeviceType device_type, typename T>
 void ConvKernelIf<device_type, T>::InitConstBufBlobs(
     DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
-  if (this->template GetValFromCustomizedOpConf<bool>("use_bias") && !this->UseCudnnOnGpu()) {
+  if (this->template GetValFromCustomizedOpConf<bool>("use_bias")
+      && (device_type == DeviceType::kCPU || this->EnableCudnn() == false)) {
     InitializerConf bias_multiplier_initializer_conf;
     bias_multiplier_initializer_conf.mutable_constant_conf()->set_value(1.0f);
     KernelUtil<device_type, T>::InitializeWithConf(ctx, bias_multiplier_initializer_conf, 0,
oneflow/core/kernel/conv_kernel.cu

@@ -5,7 +5,7 @@ namespace oneflow {
 template<typename T>
 void ConvKernel<DeviceType::kGPU, T>::VirtualKernelInit(const ParallelContext* parallel_ctx) {
-  if (this->UseCudnnOnGpu()) {
+  if (this->EnableCudnn()) {
     KernelInitWithCudnn(parallel_ctx);
   } else {
     ConvKernelImplByIm2Col<DeviceType::kGPU, T>::VirtualKernelInit(parallel_ctx);

@@ -16,7 +16,7 @@ template<typename T>
 void ConvKernel<DeviceType::kGPU, T>::DoForwardDataContent(
     DeviceCtx* device_ctx, const Blob* in_blob, const Blob* weight_blob, Blob* out_blob,
     std::function<Blob*(const std::string&)> BnInOp2Blob) const {
-  if (this->UseCudnnOnGpu()) {
+  if (this->EnableCudnn()) {
     DoForwardDataContentWithCudnn(device_ctx, in_blob, weight_blob, out_blob, BnInOp2Blob);
   } else {
     ConvKernelImplByIm2Col<DeviceType::kGPU, T>::DoForwardDataContent(

@@ -28,7 +28,7 @@ template<typename T>
 void ConvKernel<DeviceType::kGPU, T>::WeightBackward(
     DeviceCtx* device_ctx, const Blob* out_diff_blob, const Blob* in_blob, Blob* weight_diff_blob,
     Blob* in_diff_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
-  if (this->UseCudnnOnGpu()) {
+  if (this->EnableCudnn()) {
     WeightBackwardWithCudnn(device_ctx, out_diff_blob, in_blob, weight_diff_blob, in_diff_blob,
                             BnInOp2Blob);
   } else {

@@ -41,7 +41,7 @@ template<typename T>
 void ConvKernel<DeviceType::kGPU, T>::BiasBackward(
     DeviceCtx* device_ctx, const Blob* out_diff_blob, Blob* bias_diff_blob,
     std::function<Blob*(const std::string&)> BnInOp2Blob) const {
-  if (this->UseCudnnOnGpu()) {
+  if (this->EnableCudnn()) {
     BiasBackwardWithCudnn(device_ctx, out_diff_blob, bias_diff_blob, BnInOp2Blob);
   } else {
     ConvKernelImplByIm2Col<DeviceType::kGPU, T>::BiasBackward(device_ctx, out_diff_blob,
oneflow/core/kernel/kernel.h

@@ -136,8 +136,7 @@ class KernelIf : public Kernel {
                 const PbRpf<std::string>& from_bns, const PbRpf<std::string>& to_bns,
                 void (Blob::*Copy)(DeviceCtx*, const Blob*)) const;
-  bool UseCudnn() const { return device_type == DeviceType::kGPU && UseCudnnOnGpu(); }
-  bool UseCudnnOnGpu() const { return op_conf().use_cudnn_on_gpu(); }
+  bool EnableCudnn() const { return op_conf().enable_cudnn(); }
 };

 template<DeviceType device_type, typename ModelType>
oneflow/core/kernel/reduce_gather_kernel.cpp

@@ -13,7 +13,7 @@ void ReduceGatherKernel<device_type>::ForwardDataContent(
   dst_cur_dptr += this->kernel_conf().reduce_gather_conf().data_offset().Get(in_bn_id);
   Blob* in_blob = BnInOp2Blob(this->op_attribute().input_bns().Get(in_bn_id));
   size_t in_byte_size = in_blob->ByteSizeOfDataContentField();
-  Memcpy<DeviceType::kGPU>(ctx.device_ctx, dst_cur_dptr, in_blob->dptr<char>(), in_byte_size);
+  Memcpy<DeviceType::kCPU>(ctx.device_ctx, dst_cur_dptr, in_blob->dptr<char>(), in_byte_size);
 }

 ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kReduceGatherConf, ReduceGatherKernel);
oneflow/core/operator/conv_op.cpp

@@ -132,7 +132,7 @@ void ConvOp<NDims>::InferBlobDescs(std::function<BlobDesc*(const std::string&)>
   if (GetValFromCustomizedConf<bool>("use_bias")) {
     // bias and bias_multiplier
     GetBlobDesc4BnInOp("bias")->mut_shape() = Shape({filters, 1});
-    if (!UseCudnnOnGpu()) {
+    if (DevIsGpuAndEnableCudnn() == false) {
       std::vector<int64_t> bias_mul_shape(NDims + 1, 1);
       for (size_t i = 0; i != NDims; ++i) { bias_mul_shape[i + 1] = out_shape[dhw_offset + i]; }
       GetBlobDesc4BnInOp("bias_multiplier")->mut_shape() = Shape(bias_mul_shape);

@@ -142,7 +142,7 @@ void ConvOp<NDims>::InferBlobDescs(std::function<BlobDesc*(const std::string&)>
   ConvOpCtx* conv_op_ctx = new ConvOpCtx();
   EnrollOpCtx(conv_op_ctx);
-  if (device_type() == DeviceType::kCPU || !UseCudnnOnGpu()) {
+  if (DevIsGpuAndEnableCudnn() == false) {
     // col_buf
     int64_t col_buf_elem_cnt = 1;
     for (size_t i = 0; i != NDims + 1; ++i) { col_buf_elem_cnt *= weight_shape[i + 1]; }

@@ -154,7 +154,7 @@ void ConvOp<NDims>::InferBlobDescs(std::function<BlobDesc*(const std::string&)>
   }
 #ifdef WITH_CUDA
-  if (device_type() == DeviceType::kGPU && UseCudnnOnGpu()) {
+  if (DevIsGpuAndEnableCudnn()) {
     // cudnn_buf
     InferCudnnAlgo(GetBlobDesc4BnInOp, &(conv_op_ctx->cudnn_conv_algo_ctx), 0);
     BlobDesc* fw_cudnn_buf = GetBlobDesc4BnInOp("fw_cudnn_buf");

@@ -170,7 +170,7 @@ void ConvOp<NDims>::InferBwBufBlobDescs(
     std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*,
     const OpContext* op_ctx) const {
   const ConvOpCtx* conv_op_ctx = static_cast<const ConvOpCtx*>(op_ctx);
-  if (device_type() == DeviceType::kCPU || !UseCudnnOnGpu()) {
+  if (DevIsGpuAndEnableCudnn() == false) {
     // col_buf
     BlobDesc* bw_col_buf = GetBlobDesc4BnInOp("bw_col_buf");
     bw_col_buf->mut_shape() = Shape({conv_op_ctx->col_buf_size});

@@ -178,7 +178,7 @@ void ConvOp<NDims>::InferBwBufBlobDescs(
   }
 #ifdef WITH_CUDA
-  if (device_type() == DeviceType::kGPU && UseCudnnOnGpu()) {
+  if (DevIsGpuAndEnableCudnn()) {
     // cudnn_buf
     BlobDesc* bw_cudnn_buf = GetBlobDesc4BnInOp("bw_cudnn_buf");
     bw_cudnn_buf->mut_shape() =

@@ -275,10 +275,10 @@ void ConvOp<NDims>::VirtualGenKernelConf(
     const ParallelContext* parallel_ctx, KernelConf* kernel_conf, const OpContext* op_ctx) const {
   ConvKernelConf* conv_conf = kernel_conf->mutable_conv_conf();
   conv_conf->set_dim(NDims);
-  if (!UseCudnnOnGpu()) {
-    GenKernelConfWithoutCudnn(GetBlobDesc4BnInOp, conv_conf);
-  } else {
+  if (DevIsGpuAndEnableCudnn()) {
     GenKernelConfWithCudnn(GetBlobDesc4BnInOp, kernel_conf, conv_conf, op_ctx);
+  } else {
+    GenKernelConfWithoutCudnn(GetBlobDesc4BnInOp, conv_conf);
   }
 }
oneflow/core/operator/op_conf.proto

@@ -607,7 +607,7 @@ message OperatorConf {
   optional string model_load_dir = 2;
   optional bool trainable = 3 [default = true];
   optional DeviceType device_type = 4 [default = kInvalidDevice];
-  optional bool use_cudnn_on_gpu = 5;
+  optional bool enable_cudnn = 5;
   optional int64 cudnn_buf_limit_mbyte = 6 [default = 1024];  // 1GByte
   oneof op_type {
     FullyConnectedOpConf fully_connected_conf = 101;
oneflow/core/operator/operator.cpp

@@ -20,8 +20,8 @@ DataType GetDataTypeFromBnInOpVec(
 void Operator::InitFromOpConf(const OperatorConf& op_conf) {
   OperatorConf* this_op_conf = op_attribute_.mutable_op_conf();
   *this_op_conf = op_conf;
-  if (this_op_conf->has_use_cudnn_on_gpu() == false) {
-    this_op_conf->set_use_cudnn_on_gpu(Global<JobDesc>::Get()->UseCudnnOnGpu());
+  if (this_op_conf->has_enable_cudnn() == false) {
+    this_op_conf->set_enable_cudnn(Global<JobDesc>::Get()->EnableCudnn());
   }
   if (GetActivationType() != ActivationType::kNone) { EnrollBwBufBn("bw_activation"); }
   InitFromOpConf();
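The hunk above keeps the existing fallback: an op conf that never set the flag inherits the job-level value. A self-contained sketch of that default-filling step, with hypothetical stand-in structs instead of the real OperatorConf and JobDesc:

// Stand-alone sketch of the default-filling step in Operator::InitFromOpConf.
// JobDescStub and OpConfStub are illustrative stand-ins, not the OneFlow types.
#include <iostream>
#include <optional>

struct JobDescStub { bool enable_cudnn = true; };          // job-level default
struct OpConfStub { std::optional<bool> enable_cudnn; };   // unset until written

// Mirrors the has_enable_cudnn() == false check shown in the diff above.
void FillEnableCudnnDefault(OpConfStub& op_conf, const JobDescStub& job_desc) {
  if (!op_conf.enable_cudnn.has_value()) { op_conf.enable_cudnn = job_desc.enable_cudnn; }
}

int main() {
  JobDescStub job;
  OpConfStub op;                          // user did not set enable_cudnn
  FillEnableCudnnDefault(op, job);
  std::cout << *op.enable_cudnn << "\n";  // 1: inherited from the job conf
}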
oneflow/core/operator/operator.h

@@ -53,8 +53,8 @@ class Operator {
   // Getters
   const std::string& op_name() const { return op_conf().name(); }
   DeviceType device_type() const { return op_attribute_.op_conf().device_type(); }
-  bool UseCudnn() const { return device_type() == DeviceType::kGPU && UseCudnnOnGpu(); }
-  bool UseCudnnOnGpu() const { return op_conf().use_cudnn_on_gpu(); }
+  bool EnableCudnn() const { return op_conf().enable_cudnn(); }
+  bool DevIsGpuAndEnableCudnn() const { return device_type() == DeviceType::kGPU && EnableCudnn(); }
   const OperatorConf& op_conf() const { return op_attribute_.op_conf(); }
   virtual const PbMessage& GetCustomizedConf() const { UNIMPLEMENTED(); }
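At the conv_op.cpp call sites above, conditions of the form device_type() == DeviceType::kCPU || !UseCudnnOnGpu() become DevIsGpuAndEnableCudnn() == false. For a device that is either kCPU or kGPU the two forms are logically equivalent by De Morgan's law, assuming the flag value itself is unchanged by the rename. A small check of that equivalence, again with stand-in types rather than the real OneFlow API:

// Equivalence check for the rewritten conditions (stand-alone sketch; the enum
// and the local booleans below are stand-ins, not the OneFlow helpers).
#include <cassert>
#include <iostream>

enum class DeviceType { kCPU, kGPU };

int main() {
  for (DeviceType dev : {DeviceType::kCPU, DeviceType::kGPU}) {
    for (bool enable_cudnn : {false, true}) {
      // Old-style condition: explicit device check plus the (renamed) flag.
      bool old_cond = (dev == DeviceType::kCPU) || !enable_cudnn;
      // New-style condition: the single DevIsGpuAndEnableCudnn() helper, negated.
      bool dev_is_gpu_and_enable_cudnn = (dev == DeviceType::kGPU) && enable_cudnn;
      bool new_cond = (dev_is_gpu_and_enable_cudnn == false);
      assert(old_cond == new_cond);  // identical for all kCPU/kGPU cases
    }
  }
  std::cout << "conditions agree for all CPU/GPU cases\n";
}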