提交 8e8cdab2 编写于 作者: W willzhang4a58

use_cudnn_on_gpu

上级 0ed0a2df
......@@ -163,6 +163,12 @@ inline double GetCurTime() {
size_t GetAvailableCpuMemSize();
inline void CheckUseCudnn(bool val) {
#ifndef WITH_CUDNN
CHECK_EQ(val, false) << "Please compile ONEFLOW with CUDNN";
#endif
}
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_UTIL_H_
......@@ -44,6 +44,7 @@ message JobConf {
optional string model_load_snapshot_path = 7 [default = ""];
optional int32 max_data_id_length = 8 [default = 0];
optional bool use_rdma = 9 [default = false];
optional bool use_cudnn_on_gpu = 10;
optional DataType default_data_type = 100 [default = kFloat]; // kFloat or kDouble
optional int64 piece_num_of_experiment_phase = 101 [default = 100];
......
......@@ -73,8 +73,16 @@ JobDesc::JobDesc(const JobDescProto& job_desc) {
resource_ = job_desc.resource();
placement_ = job_desc.placement();
#ifndef WITH_RDMA
CHECK_EQ(job_conf_.use_rdma(), false) << "Please compile oneflow with rdma";
CHECK_EQ(job_conf_.use_rdma(), false) << "Please compile ONEFLOW with RDMA";
#endif
if (job_conf_.has_use_cudnn_on_gpu() == false) {
#ifdef WITH_CUDNN
job_conf_.set_use_cudnn_on_gpu(true);
#else
job_conf_.set_use_cudnn_on_gpu(false);
#endif
}
CheckUseCudnn(job_conf_.use_cudnn_on_gpu());
int64_t piece_experiment = job_conf_.piece_num_of_experiment_phase();
if (job_conf_.has_train_conf()) {
const TrainConf& train_conf = job_conf_.train_conf();
......
......@@ -27,6 +27,7 @@ class JobDesc final {
DataType DefaultDataType() const { return job_conf_.default_data_type(); }
size_t SizeOfOneDataId() const;
bool use_rdma() const { return job_conf_.use_rdma(); }
bool UseCudnn() const { return job_conf_.use_cudnn_on_gpu(); }
int64_t TotalMachineNum() const { return resource_.machine().size(); }
int32_t CpuDeviceNum() const { return resource_.cpu_device_num(); }
int32_t GpuDeviceNum() const { return resource_.gpu_device_num(); }
......
......@@ -256,6 +256,7 @@ message RecurrentOpConf {
message OperatorConf {
required string name = 1;
optional string model_load_dir = 2;
optional bool use_cudnn_on_gpu = 3;
oneof op_type {
ConvolutionOpConf convolution_conf = 100;
FullyConnectedOpConf fully_connected_conf = 101;
......
......@@ -18,6 +18,10 @@ DataType GetDataTypeFromBnInOpVec(
void Operator::InitFromOpConf(const OperatorConf& op_conf) {
op_conf_ = op_conf;
if (op_conf_.has_use_cudnn_on_gpu() == false) {
op_conf_.set_use_cudnn_on_gpu(JobDesc::Singleton()->UseCudnn());
}
CheckUseCudnn(op_conf_.use_cudnn_on_gpu());
InitFromOpConf();
}
......
......@@ -46,6 +46,7 @@ class Operator {
// Getters
const std::string& op_name() const { return op_conf_.name(); }
bool UseCudnn() const { return op_conf_.use_cudnn_on_gpu(); }
const OperatorConf& op_conf() const { return op_conf_; }
virtual const PbMessage& GetSpecialConf() const { UNEXPECTED_RUN(); }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册