Commit 56e032bf authored by Shiyuan Shang-Guan, committed by Jinhui Yuan

fix decode_random and refine synthetic_data (#1278)

* fix decode_random and refine synthetic_data

* add example

* initialize only once


Former-commit-id: a1b44c05
Parent: ff1de4cd
# NOTE(review): this section appears to be several example .prototxt files
# (job conf, net, other, placement, resource) concatenated by the commit-diff
# view — confirm the file boundaries against the original commit.
# Top-level job conf: paths to the remaining configuration pieces.
net: "./net.prototxt"
resource: "./resource.prototxt"
placement: "./placement.prototxt"
other: "./other.prototxt"
# Synthetic image input: decode_random generates a 1x28x28 float blob,
# initialized with a constant value of 0 (no real dataset needed).
op {
name: "decode_random1"
decode_random_conf {
out: "data"
data_type: kFloat
shape {
dim: 1
dim: 28
dim: 28
}
data_initializer {
constant_conf {
value: 0
}
}
}
}
# Synthetic label input: scalar (empty shape) int32, constant 0.
op {
name: "decode_random2"
decode_random_conf {
out: "label"
data_type: kInt32
shape {
}
data_initializer {
constant_int_conf {
value: 0
}
}
}
}
# LeNet-style network: conv -> relu -> pool, twice, then two FC layers.
# conv1: 32 filters, 5x5 kernel, stride 1, SAME padding, NCHW layout.
op {
name: "conv1"
conv_2d_conf {
in: "decode_random1/data"
out: "out"
filters: 32
padding: "SAME"
data_format: "channels_first"
kernel_size: 5
kernel_size: 5
strides: 1
strides: 1
dilation_rate: 1
dilation_rate: 1
use_bias: true
}
}
op {
name: "relu1"
relu_conf {
in: "conv1/out"
out: "out"
}
}
# pool1: 2x2 max pooling with stride 2 (halves spatial dims).
op {
name: "pool1"
max_pooling_2d_conf {
in: "relu1/out"
out: "out"
padding: "SAME"
data_format: "channels_first"
pool_size: 2
pool_size: 2
strides: 2
strides: 2
}
}
# conv2: 64 filters, otherwise same hyper-parameters as conv1.
op {
name: "conv2"
conv_2d_conf {
in: "pool1/out"
out: "out"
filters: 64
padding: "SAME"
data_format: "channels_first"
kernel_size: 5
kernel_size: 5
strides: 1
strides: 1
dilation_rate: 1
dilation_rate: 1
use_bias: true
}
}
op {
name: "relu2"
relu_conf {
in: "conv2/out"
out: "out"
}
}
op {
name: "pool2"
max_pooling_2d_conf {
in: "relu2/out"
out: "out"
padding: "SAME"
data_format: "channels_first"
pool_size: 2
pool_size: 2
strides: 2
strides: 2
}
}
# Fully-connected head: 1024-unit hidden layer with relu + dropout,
# then a 10-unit output layer (10 classes, MNIST-sized).
op {
name: "ip1024"
fully_connected_conf {
in: "pool2/out"
out: "out"
units: 1024
}
}
op {
name: "relu3"
relu_conf {
in: "ip1024/out"
out: "out"
}
}
op {
name: "dropout"
dropout_conf {
in: "relu3/out"
out: "out"
rate: 0.5
}
}
op {
name: "ip10"
fully_connected_conf {
in: "dropout/out"
out: "out"
units: 10
}
}
# Loss: sparse softmax cross-entropy against the synthetic labels.
op {
name: "softmax_loss"
sparse_softmax_cross_entropy_loss_conf {
prediction: "ip10/out"
label: "decode_random2/label"
loss: "loss"
}
}
# Filesystem configuration: both data and snapshots live on the local FS.
data_fs_conf {
localfs_conf {
}
}
snapshot_fs_conf {
localfs_conf {
}
}
piece_size: 600
data_part_num: 6
# Training configuration: naive SGD, lr 0.01, snapshot every 10 batches.
train_conf {
batch_size: 6000
total_batch_num: 100
primary_lr: 0.01
model_update_conf {
naive_conf {
}
}
model_save_snapshots_path: "/tmp/snapshot"
num_of_batches_in_snapshot: 10
# Default weight initializer: N(0, 0.1) for all model blobs.
default_initializer_conf {
random_normal_conf {
mean: 0.0
std: 0.1
}
}
}
# Placement: every op runs data-parallel on GPUs 0-1 of machine 0.
placement_group {
op_set {
op_name: "decode_random1"
op_name: "decode_random2"
op_name: "conv1"
op_name: "relu1"
op_name: "pool1"
op_name: "conv2"
op_name: "relu2"
op_name: "pool2"
op_name: "ip1024"
op_name: "relu3"
op_name: "dropout"
op_name: "ip10"
op_name: "softmax_loss"
}
parallel_conf {
policy: kDataParallel
device_name: "0:gpu:0-1"
}
}
# Resource: a single machine with 2 GPU devices.
machine {
addr: "127.0.0.1"
port: 7099
id: 0
}
gpu_device_num: 2
......@@ -5,7 +5,7 @@
namespace oneflow {
// Produces the "out" regst and binds it to every outgoing task edge.
// NOTE(review): the scraped diff kept both the removed call
// (`ProduceRegst("out", true)`) and the added one; only the post-commit
// call with `false` is retained here.
void DecodeRandomCompTaskNode::ProduceAllRegstsAndBindEdges() {
  ProduceRegst("out", false);
  for (TaskEdge* edge : out_edges()) { BindEdgeWithProducedRegst(edge, "out"); }
}
......
......@@ -66,7 +66,6 @@ message OtherConf {
optional int64 reduce_group_size = 112 [default = 20];
optional bool collect_act_event = 113 [default = false];
optional bool enable_mem_sharing = 114 [default = true];
optional bool use_synthetic_data = 116 [default = false];
optional bool enable_write_snapshot = 130 [default = true];
optional bool enable_blob_mem_sharing = 140 [default = true];
optional bool enable_nccl = 142 [default = true];
......
......@@ -25,7 +25,6 @@ class JobDesc final {
DataType DefaultDataType() const { return job_conf_.other().default_data_type(); }
size_t SizeOfOneDataId() const { return job_conf_.other().max_data_id_length() * sizeof(char); }
bool use_rdma() const { return job_conf_.other().use_rdma(); }
bool use_synthetic_data() const { return job_conf_.other().use_synthetic_data(); }
bool EnableCudnn() const { return job_conf_.other().enable_cudnn(); }
int64_t TotalMachineNum() const { return job_conf_.resource().machine().size(); }
int32_t CpuDeviceNum() const { return job_conf_.resource().cpu_device_num(); }
......
......@@ -4,38 +4,46 @@ namespace oneflow {
namespace {
void RandomFillBlob(DeviceCtx* ctx, const InitializerConf& initializer_conf, uint32_t random_seed,
Blob* blob) {
static const HashMap<int, void (*)(DeviceCtx * ctx, const InitializerConf& initializer_conf,
uint32_t random_seed, Blob* blob)>
void RandomFillBlob(DeviceCtx* ctx, DeviceType device_type, const InitializerConf& initializer_conf,
uint32_t random_seed, Blob* blob) {
static const HashMap<std::string,
void (*)(DeviceCtx * ctx, const InitializerConf& initializer_conf,
uint32_t random_seed, Blob* blob)>
fill_funcs = {
#define RANDOM_FILL_ENTRY(type_cpp, type_proto) \
{type_proto, &KernelUtil<DeviceType::kCPU, type_cpp>::InitializeWithConf},
OF_PP_FOR_EACH_TUPLE(RANDOM_FILL_ENTRY, ARITHMETIC_DATA_TYPE_SEQ)};
fill_funcs.at(blob->data_type())(ctx, initializer_conf, random_seed, blob);
#define RANDOM_FILL_ENTRY(type_dev, data_type_pair) \
{GetHashKey(type_dev, OF_PP_PAIR_SECOND(data_type_pair)), \
&KernelUtil<type_dev, OF_PP_PAIR_FIRST(data_type_pair)>::InitializeWithConf},
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(RANDOM_FILL_ENTRY, DEVICE_TYPE_SEQ,
ARITHMETIC_DATA_TYPE_SEQ)};
fill_funcs.at(GetHashKey(device_type, blob->data_type()))(ctx, initializer_conf, random_seed,
blob);
}
} // namespace
void DecodeRandomKernel::VirtualKernelInit(const ParallelContext*) {
gen_.reset(new std::mt19937(kernel_conf().decode_random_conf().random_seed()));
template<DeviceType device_type>
void DecodeRandomKernel<device_type>::VirtualKernelInit(const ParallelContext*) {
gen_.reset(new std::mt19937(this->kernel_conf().decode_random_conf().random_seed()));
dis_.reset(new std::uniform_int_distribution<uint32_t>());
is_init_ = false;
}
// Draws the next uniformly-distributed 32-bit seed from the kernel's RNG.
// NOTE(review): post-commit templated form; the scrape also retained the
// removed non-template one-liner, dropped here.
template<DeviceType device_type>
uint32_t DecodeRandomKernel<device_type>::GenNextRandomSeed() const {
  return (*dis_)(*gen_);
}
void DecodeRandomKernel::Forward(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const DecodeRandomOpConf& conf = op_conf().decode_random_conf();
CHECK(ctx.other);
auto status = static_cast<DecodeStatus*>(ctx.other);
if (conf.max_sequence_size() > 1 && status->max_col_id_ == 0 && status->cur_col_id_ == 0) {
status->max_col_id_ = GenNextRandomSeed() % conf.max_sequence_size();
template<DeviceType device_type>
void DecodeRandomKernel<device_type>::Forward(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const DecodeRandomOpConf& conf = this->op_conf().decode_random_conf();
if (is_init_ == false) {
RandomFillBlob(ctx.device_ctx, device_type, conf.data_initializer(), this->GenNextRandomSeed(),
BnInOp2Blob("out"));
is_init_ = true;
}
RandomFillBlob(ctx.device_ctx, conf.distribution(), GenNextRandomSeed(),
BnInOp2Blob(op_attribute().output_bns(0)));
}
// Registers DecodeRandomKernel for every device type.
// NOTE(review): the scrape also kept the removed CPU-only REGISTER_KERNEL
// line; only the post-commit device-generic registration is retained.
ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kDecodeRandomConf, DecodeRandomKernel);
} // namespace oneflow
#ifndef ONEFLOW_CORE_KERNEL_DECODE_RANDOM_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_DECODE_RANDOM_KERNEL_H_
#include "oneflow/core/kernel/decode_ofrecord_kernel.h"
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
class DecodeRandomKernel final : public KernelIf<DeviceType::kCPU> {
template<DeviceType device_type>
class DecodeRandomKernel final : public KernelIf<device_type> {
public:
OF_DISALLOW_COPY_AND_MOVE(DecodeRandomKernel);
DecodeRandomKernel() = default;
......@@ -19,6 +20,8 @@ class DecodeRandomKernel final : public KernelIf<DeviceType::kCPU> {
std::unique_ptr<std::mt19937> gen_;
std::unique_ptr<std::uniform_int_distribution<uint32_t>> dis_;
mutable bool is_init_;
};
} // namespace oneflow
......
......@@ -30,7 +30,6 @@ void RecordLoadKernel::VirtualKernelInit(const ParallelContext* parallel_ctx) {
int64_t global_piece_size = Global<JobDesc>::Get()->PieceSize();
CHECK_EQ(global_piece_size % parallel_ctx->parallel_num(), 0);
piece_size_in_one_loader_ = global_piece_size / parallel_ctx->parallel_num();
loaded_cnt_ = 0;
}
void RecordLoadKernel::Forward(const KernelCtx& ctx,
......@@ -38,10 +37,7 @@ void RecordLoadKernel::Forward(const KernelCtx& ctx,
auto status = static_cast<RecordLoadStatus*>(ctx.other);
Blob* out_blob = BnInOp2Blob("out");
RecordBlob<OFRecord> record_blob(out_blob);
if (!Global<JobDesc>::Get()->use_synthetic_data() || loaded_cnt_ < 2) {
record_blob.ReadFrom(in_stream_.get());
++loaded_cnt_;
}
record_blob.ReadFrom(in_stream_.get());
status->record_num = record_blob.record_num();
if (status->record_num < piece_size_in_one_loader_) { status->is_eof = true; }
}
......
......@@ -24,7 +24,6 @@ class RecordLoadKernel final : public KernelIf<DeviceType::kCPU> {
std::unique_ptr<PersistentInStream> in_stream_;
int64_t piece_size_in_one_loader_;
mutable int64_t loaded_cnt_;
};
} // namespace oneflow
......
......@@ -29,8 +29,6 @@ void DecodeRandomOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)>
out_blob_desc->mut_shape() = Shape(dim_vec);
out_blob_desc->set_data_type(conf.data_type());
out_blob_desc->set_has_data_id_field(Global<JobDesc>::Get()->SizeOfOneDataId() > 0);
out_blob_desc->set_has_col_num_field(conf.max_sequence_size() > 1);
out_blob_desc->set_max_col_num(conf.max_sequence_size());
}
REGISTER_OP(OperatorConf::kDecodeRandomConf, DecodeRandomOp);
......
......@@ -14,7 +14,7 @@ class DecodeRandomOp final : public Operator {
void InitFromOpConf() override;
const PbMessage& GetCustomizedConf() const override;
LogicalNode* NewProperLogicalNode() override { return new DecodeLogicalNode; }
LogicalNode* NewProperLogicalNode() override { return new DecodeRandomLogicalNode; }
bool IsDecodeOp() const override { return true; }
......
......@@ -555,8 +555,7 @@ message DecodeRandomOpConf {
required string out = 1;
required ShapeProto shape = 2;
required DataType data_type = 3;
optional int32 max_sequence_size = 4 [default = 1];
required InitializerConf distribution = 7;
required InitializerConf data_initializer = 4;
}
message DefineTestBlobConf {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.