diff --git a/research/object_detection/builders/dataset_builder.py b/research/object_detection/builders/dataset_builder.py index 78874818a8b2608ba67946dbe1d600be5b2ce3fa..f0dd3002b407b6de900a042dc1a81f4b7878e022 100644 --- a/research/object_detection/builders/dataset_builder.py +++ b/research/object_detection/builders/dataset_builder.py @@ -242,7 +242,8 @@ def build(input_reader_config, batch_size=None, transform_input_data_fn=None, dataset = dataset_map_fn(dataset, transform_input_data_fn, batch_size, input_reader_config) if batch_size: - dataset = dataset.batch(batch_size, drop_remainder=True) + dataset = dataset.batch(batch_size, + drop_remainder=input_reader_config.drop_remainder) dataset = dataset.prefetch(input_reader_config.num_prefetch_batches) return dataset diff --git a/research/object_detection/protos/input_reader.proto b/research/object_detection/protos/input_reader.proto index 745453813269adeeece07b1ba2e2313bec87bec9..50c68ddf64d2af549ac92808bd3dbb0014dd0362 100644 --- a/research/object_detection/protos/input_reader.proto +++ b/research/object_detection/protos/input_reader.proto @@ -30,7 +30,7 @@ enum InputType { TF_SEQUENCE_EXAMPLE = 2; // TfSequenceExample Input } -// Next id: 35 +// Next id: 36 message InputReader { // Name of input reader. Typically used to describe the dataset that is read // by this input reader. @@ -93,6 +93,9 @@ message InputReader { // Number of parallel decode ops to apply. optional uint32 num_parallel_map_calls = 14 [default = 64, deprecated = true]; + // Drop remainder when batch size does not divide dataset size. + optional bool drop_remainder = 35 [default = true]; + // If positive, TfExampleDecoder will try to decode rasters of additional // channels from tf.Examples. optional int32 num_additional_channels = 18 [default = 0];