提交 f5e8efae 编写于 作者: W willzhang4a58

basic data loader

上级 c6a7a351
#ifndef ONEFLOW_CORE_COMM_NETWORK_RDMA_RDMA_MEMORY_H
#define ONEFLOW_CORE_COMM_NETWORK_RDMA_RDMA_MEMORY_H
#ifndef ONEFLOW_CORE_COMM_NETWORK_RDMA_RDMA_MEMORY_H_
#define ONEFLOW_CORE_COMM_NETWORK_RDMA_RDMA_MEMORY_H_
#include "oneflow/core/common/util.h"
......@@ -27,4 +27,4 @@ class RdmaMem {
#endif // WITH_RDMA
#endif // ONEFLOW_CORE_COMM_NETWORK_RDMA_RDMA_MEMORY_H
#endif // ONEFLOW_CORE_COMM_NETWORK_RDMA_RDMA_MEMORY_H_
......@@ -9,7 +9,7 @@ class BackwardCompTaskNode : public CompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(BackwardCompTaskNode);
BackwardCompTaskNode() = default;
~BackwardCompTaskNode() = default;
virtual ~BackwardCompTaskNode() = default;
void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() override;
......
......@@ -9,7 +9,7 @@ class ForwardCompTaskNode : public CompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(ForwardCompTaskNode);
ForwardCompTaskNode() = default;
~ForwardCompTaskNode() = default;
virtual ~ForwardCompTaskNode() = default;
void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() override;
......
#include "oneflow/core/kernel/data_loader_kernel.h"
#include "oneflow/core/kernel/basic_data_loader_kernel.h"
#include "oneflow/core/common/str_util.h"
#include "oneflow/core/job/runtime_context.h"
#include "oneflow/core/persistence/cyclic_persistent_in_stream.h"
......@@ -7,7 +7,7 @@
namespace oneflow {
template<typename T>
void DataLoaderKernel<T>::Forward(
void BasicDataLoaderKernel<T>::Forward(
const KernelCtx& kernel_ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
Blob* out_blob = BnInOp2Blob("out");
......@@ -48,9 +48,9 @@ void DataLoaderKernel<T>::Forward(
}
template<typename T>
void DataLoaderKernel<T>::VirtualKernelInit(
void BasicDataLoaderKernel<T>::VirtualKernelInit(
const ParallelContext* parallel_ctx) {
const std::string& data_dir = op_conf().data_loader_conf().data_dir();
const std::string& data_dir = op_conf().basic_data_loader_conf().data_dir();
std::string parallel_id = std::to_string(parallel_ctx->parallel_id());
std::string file_path = JoinPath(data_dir, "part-" + parallel_id);
if (JobDesc::Singleton()->IsTrain()) {
......@@ -60,7 +60,7 @@ void DataLoaderKernel<T>::VirtualKernelInit(
}
}
ADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kDataLoaderConf, DataLoaderKernel,
ARITHMETIC_DATA_TYPE_SEQ);
ADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kBasicDataLoaderConf,
BasicDataLoaderKernel, ARITHMETIC_DATA_TYPE_SEQ);
} // namespace oneflow
#ifndef ONEFLOW_CORE_KERNEL_DATA_LOADER_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_DATA_LOADER_KERNEL_H_
#ifndef ONEFLOW_CORE_KERNEL_BASIC_DATA_LOADER_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_BASIC_DATA_LOADER_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
template<typename T>
class DataLoaderKernel final : public KernelIf<DeviceType::kCPU> {
class BasicDataLoaderKernel final : public KernelIf<DeviceType::kCPU> {
public:
OF_DISALLOW_COPY_AND_MOVE(DataLoaderKernel);
DataLoaderKernel() = default;
~DataLoaderKernel() = default;
OF_DISALLOW_COPY_AND_MOVE(BasicDataLoaderKernel);
BasicDataLoaderKernel() = default;
~BasicDataLoaderKernel() = default;
void Forward(const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
......@@ -23,4 +23,4 @@ class DataLoaderKernel final : public KernelIf<DeviceType::kCPU> {
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_DATA_LOADER_KERNEL_H_
#endif // ONEFLOW_CORE_KERNEL_BASIC_DATA_LOADER_KERNEL_H_
#include "oneflow/core/operator/data_loader_op.h"
#include "oneflow/core/operator/basic_data_loader_op.h"
#include "oneflow/core/job/job_desc.h"
namespace oneflow {
void DataLoaderOp::InitFromOpConf() {
CHECK(op_conf().has_data_loader_conf());
void BasicDataLoaderOp::InitFromOpConf() {
CHECK(op_conf().has_basic_data_loader_conf());
EnrollOutputBn("out", false);
}
const PbMessage& DataLoaderOp::GetSpecialConf() const {
return op_conf().data_loader_conf();
const PbMessage& BasicDataLoaderOp::GetSpecialConf() const {
return op_conf().basic_data_loader_conf();
}
void DataLoaderOp::InferBlobDescs(
void BasicDataLoaderOp::InferBlobDescs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
const DataLoaderOpConf& conf = op_conf().data_loader_conf();
const BasicDataLoaderOpConf& conf = op_conf().basic_data_loader_conf();
BlobDesc* out = GetBlobDesc4BnInOp("out");
std::vector<int64_t> dim_vec(1 + conf.shape().dim_size());
dim_vec[0] = JobDesc::Singleton()->SinglePieceSize();
......@@ -28,6 +28,6 @@ void DataLoaderOp::InferBlobDescs(
out->set_has_data_id(JobDesc::Singleton()->SizeOfOneDataId() > 0);
}
REGISTER_OP(OperatorConf::kDataLoaderConf, DataLoaderOp);
REGISTER_OP(OperatorConf::kBasicDataLoaderConf, BasicDataLoaderOp);
} // namespace oneflow
#ifndef ONEFLOW_CORE_OPERATOR_DATA_LOADER_OP_H_
#define ONEFLOW_CORE_OPERATOR_DATA_LOADER_OP_H_
#ifndef ONEFLOW_CORE_OPERATOR_BASIC_DATA_LOADER_OP_H_
#define ONEFLOW_CORE_OPERATOR_BASIC_DATA_LOADER_OP_H_
#include "oneflow/core/operator/operator.h"
namespace oneflow {
class DataLoaderOp final : public Operator {
class BasicDataLoaderOp final : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(DataLoaderOp);
DataLoaderOp() = default;
~DataLoaderOp() = default;
OF_DISALLOW_COPY_AND_MOVE(BasicDataLoaderOp);
BasicDataLoaderOp() = default;
~BasicDataLoaderOp() = default;
void InitFromOpConf() override;
const PbMessage& GetSpecialConf() const override;
......@@ -22,4 +22,4 @@ class DataLoaderOp final : public Operator {
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_DATA_LOADER_OP_H_
#endif // ONEFLOW_CORE_OPERATOR_BASIC_DATA_LOADER_OP_H_
......@@ -57,7 +57,7 @@ message InnerProductOpConf {
optional FillConf bias_fill = 6;
}
message DataLoaderOpConf {
message BasicDataLoaderOpConf {
required string out = 1;
required DataType data_type = 2;
required string data_dir = 3;
......@@ -217,7 +217,7 @@ message OperatorConf {
oneof op_type {
ConvolutionOpConf convolution_conf = 100;
InnerProductOpConf innerproduct_conf = 101;
DataLoaderOpConf data_loader_conf = 102;
BasicDataLoaderOpConf basic_data_loader_conf = 102;
PoolingOpConf pooling_conf = 103;
ReluOpConf relu_conf = 104;
SoftmaxOpConf softmax_conf = 105;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册