提交 f5e8efae 编写于 作者: W willzhang4a58

basic data loader

上级 c6a7a351
#ifndef ONEFLOW_CORE_COMM_NETWORK_RDMA_RDMA_MEMORY_H #ifndef ONEFLOW_CORE_COMM_NETWORK_RDMA_RDMA_MEMORY_H_
#define ONEFLOW_CORE_COMM_NETWORK_RDMA_RDMA_MEMORY_H #define ONEFLOW_CORE_COMM_NETWORK_RDMA_RDMA_MEMORY_H_
#include "oneflow/core/common/util.h" #include "oneflow/core/common/util.h"
...@@ -27,4 +27,4 @@ class RdmaMem { ...@@ -27,4 +27,4 @@ class RdmaMem {
#endif // WITH_RDMA #endif // WITH_RDMA
#endif // ONEFLOW_CORE_COMM_NETWORK_RDMA_RDMA_MEMORY_H #endif // ONEFLOW_CORE_COMM_NETWORK_RDMA_RDMA_MEMORY_H_
...@@ -9,7 +9,7 @@ class BackwardCompTaskNode : public CompTaskNode { ...@@ -9,7 +9,7 @@ class BackwardCompTaskNode : public CompTaskNode {
public: public:
OF_DISALLOW_COPY_AND_MOVE(BackwardCompTaskNode); OF_DISALLOW_COPY_AND_MOVE(BackwardCompTaskNode);
BackwardCompTaskNode() = default; BackwardCompTaskNode() = default;
~BackwardCompTaskNode() = default; virtual ~BackwardCompTaskNode() = default;
void ProduceAllRegstsAndBindEdges() override; void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() override; void ConsumeAllRegsts() override;
......
...@@ -9,7 +9,7 @@ class ForwardCompTaskNode : public CompTaskNode { ...@@ -9,7 +9,7 @@ class ForwardCompTaskNode : public CompTaskNode {
public: public:
OF_DISALLOW_COPY_AND_MOVE(ForwardCompTaskNode); OF_DISALLOW_COPY_AND_MOVE(ForwardCompTaskNode);
ForwardCompTaskNode() = default; ForwardCompTaskNode() = default;
~ForwardCompTaskNode() = default; virtual ~ForwardCompTaskNode() = default;
void ProduceAllRegstsAndBindEdges() override; void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() override; void ConsumeAllRegsts() override;
......
#include "oneflow/core/kernel/data_loader_kernel.h" #include "oneflow/core/kernel/basic_data_loader_kernel.h"
#include "oneflow/core/common/str_util.h" #include "oneflow/core/common/str_util.h"
#include "oneflow/core/job/runtime_context.h" #include "oneflow/core/job/runtime_context.h"
#include "oneflow/core/persistence/cyclic_persistent_in_stream.h" #include "oneflow/core/persistence/cyclic_persistent_in_stream.h"
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
namespace oneflow { namespace oneflow {
template<typename T> template<typename T>
void DataLoaderKernel<T>::Forward( void BasicDataLoaderKernel<T>::Forward(
const KernelCtx& kernel_ctx, const KernelCtx& kernel_ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const { std::function<Blob*(const std::string&)> BnInOp2Blob) const {
Blob* out_blob = BnInOp2Blob("out"); Blob* out_blob = BnInOp2Blob("out");
...@@ -48,9 +48,9 @@ void DataLoaderKernel<T>::Forward( ...@@ -48,9 +48,9 @@ void DataLoaderKernel<T>::Forward(
} }
template<typename T> template<typename T>
void DataLoaderKernel<T>::VirtualKernelInit( void BasicDataLoaderKernel<T>::VirtualKernelInit(
const ParallelContext* parallel_ctx) { const ParallelContext* parallel_ctx) {
const std::string& data_dir = op_conf().data_loader_conf().data_dir(); const std::string& data_dir = op_conf().basic_data_loader_conf().data_dir();
std::string parallel_id = std::to_string(parallel_ctx->parallel_id()); std::string parallel_id = std::to_string(parallel_ctx->parallel_id());
std::string file_path = JoinPath(data_dir, "part-" + parallel_id); std::string file_path = JoinPath(data_dir, "part-" + parallel_id);
if (JobDesc::Singleton()->IsTrain()) { if (JobDesc::Singleton()->IsTrain()) {
...@@ -60,7 +60,7 @@ void DataLoaderKernel<T>::VirtualKernelInit( ...@@ -60,7 +60,7 @@ void DataLoaderKernel<T>::VirtualKernelInit(
} }
} }
ADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kDataLoaderConf, DataLoaderKernel, ADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kBasicDataLoaderConf,
ARITHMETIC_DATA_TYPE_SEQ); BasicDataLoaderKernel, ARITHMETIC_DATA_TYPE_SEQ);
} // namespace oneflow } // namespace oneflow
#ifndef ONEFLOW_CORE_KERNEL_DATA_LOADER_KERNEL_H_ #ifndef ONEFLOW_CORE_KERNEL_BASIC_DATA_LOADER_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_DATA_LOADER_KERNEL_H_ #define ONEFLOW_CORE_KERNEL_BASIC_DATA_LOADER_KERNEL_H_
#include "oneflow/core/kernel/kernel.h" #include "oneflow/core/kernel/kernel.h"
namespace oneflow { namespace oneflow {
template<typename T> template<typename T>
class DataLoaderKernel final : public KernelIf<DeviceType::kCPU> { class BasicDataLoaderKernel final : public KernelIf<DeviceType::kCPU> {
public: public:
OF_DISALLOW_COPY_AND_MOVE(DataLoaderKernel); OF_DISALLOW_COPY_AND_MOVE(BasicDataLoaderKernel);
DataLoaderKernel() = default; BasicDataLoaderKernel() = default;
~DataLoaderKernel() = default; ~BasicDataLoaderKernel() = default;
void Forward(const KernelCtx&, void Forward(const KernelCtx&,
std::function<Blob*(const std::string&)>) const override; std::function<Blob*(const std::string&)>) const override;
...@@ -23,4 +23,4 @@ class DataLoaderKernel final : public KernelIf<DeviceType::kCPU> { ...@@ -23,4 +23,4 @@ class DataLoaderKernel final : public KernelIf<DeviceType::kCPU> {
} // namespace oneflow } // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_DATA_LOADER_KERNEL_H_ #endif // ONEFLOW_CORE_KERNEL_BASIC_DATA_LOADER_KERNEL_H_
#include "oneflow/core/operator/data_loader_op.h" #include "oneflow/core/operator/basic_data_loader_op.h"
#include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/job_desc.h"
namespace oneflow { namespace oneflow {
void DataLoaderOp::InitFromOpConf() { void BasicDataLoaderOp::InitFromOpConf() {
CHECK(op_conf().has_data_loader_conf()); CHECK(op_conf().has_basic_data_loader_conf());
EnrollOutputBn("out", false); EnrollOutputBn("out", false);
} }
const PbMessage& DataLoaderOp::GetSpecialConf() const { const PbMessage& BasicDataLoaderOp::GetSpecialConf() const {
return op_conf().data_loader_conf(); return op_conf().basic_data_loader_conf();
} }
void DataLoaderOp::InferBlobDescs( void BasicDataLoaderOp::InferBlobDescs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp, std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const { const ParallelContext* parallel_ctx) const {
const DataLoaderOpConf& conf = op_conf().data_loader_conf(); const BasicDataLoaderOpConf& conf = op_conf().basic_data_loader_conf();
BlobDesc* out = GetBlobDesc4BnInOp("out"); BlobDesc* out = GetBlobDesc4BnInOp("out");
std::vector<int64_t> dim_vec(1 + conf.shape().dim_size()); std::vector<int64_t> dim_vec(1 + conf.shape().dim_size());
dim_vec[0] = JobDesc::Singleton()->SinglePieceSize(); dim_vec[0] = JobDesc::Singleton()->SinglePieceSize();
...@@ -28,6 +28,6 @@ void DataLoaderOp::InferBlobDescs( ...@@ -28,6 +28,6 @@ void DataLoaderOp::InferBlobDescs(
out->set_has_data_id(JobDesc::Singleton()->SizeOfOneDataId() > 0); out->set_has_data_id(JobDesc::Singleton()->SizeOfOneDataId() > 0);
} }
REGISTER_OP(OperatorConf::kDataLoaderConf, DataLoaderOp); REGISTER_OP(OperatorConf::kBasicDataLoaderConf, BasicDataLoaderOp);
} // namespace oneflow } // namespace oneflow
#ifndef ONEFLOW_CORE_OPERATOR_DATA_LOADER_OP_H_ #ifndef ONEFLOW_CORE_OPERATOR_BASIC_DATA_LOADER_OP_H_
#define ONEFLOW_CORE_OPERATOR_DATA_LOADER_OP_H_ #define ONEFLOW_CORE_OPERATOR_BASIC_DATA_LOADER_OP_H_
#include "oneflow/core/operator/operator.h" #include "oneflow/core/operator/operator.h"
namespace oneflow { namespace oneflow {
class DataLoaderOp final : public Operator { class BasicDataLoaderOp final : public Operator {
public: public:
OF_DISALLOW_COPY_AND_MOVE(DataLoaderOp); OF_DISALLOW_COPY_AND_MOVE(BasicDataLoaderOp);
DataLoaderOp() = default; BasicDataLoaderOp() = default;
~DataLoaderOp() = default; ~BasicDataLoaderOp() = default;
void InitFromOpConf() override; void InitFromOpConf() override;
const PbMessage& GetSpecialConf() const override; const PbMessage& GetSpecialConf() const override;
...@@ -22,4 +22,4 @@ class DataLoaderOp final : public Operator { ...@@ -22,4 +22,4 @@ class DataLoaderOp final : public Operator {
} // namespace oneflow } // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_DATA_LOADER_OP_H_ #endif // ONEFLOW_CORE_OPERATOR_BASIC_DATA_LOADER_OP_H_
...@@ -57,7 +57,7 @@ message InnerProductOpConf { ...@@ -57,7 +57,7 @@ message InnerProductOpConf {
optional FillConf bias_fill = 6; optional FillConf bias_fill = 6;
} }
message DataLoaderOpConf { message BasicDataLoaderOpConf {
required string out = 1; required string out = 1;
required DataType data_type = 2; required DataType data_type = 2;
required string data_dir = 3; required string data_dir = 3;
...@@ -217,7 +217,7 @@ message OperatorConf { ...@@ -217,7 +217,7 @@ message OperatorConf {
oneof op_type { oneof op_type {
ConvolutionOpConf convolution_conf = 100; ConvolutionOpConf convolution_conf = 100;
InnerProductOpConf innerproduct_conf = 101; InnerProductOpConf innerproduct_conf = 101;
DataLoaderOpConf data_loader_conf = 102; BasicDataLoaderOpConf basic_data_loader_conf = 102;
PoolingOpConf pooling_conf = 103; PoolingOpConf pooling_conf = 103;
ReluOpConf relu_conf = 104; ReluOpConf relu_conf = 104;
SoftmaxOpConf softmax_conf = 105; SoftmaxOpConf softmax_conf = 105;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册