From f550f78c9b83e935dfab044578784f400bf83ff7 Mon Sep 17 00:00:00 2001
From: baiyfbupt
Date: Tue, 7 Jul 2020 19:32:54 +0800
Subject: [PATCH] code update

---
 demo/bert/train_distill.py                   | 184 +++++++++++++++
 .../{train_cell_base.py => train_search.py}  | 212 ++++++------------
 .../nas/darts/search_space/conv_bert/cls.py  |  15 +-
 .../conv_bert/model/transformer_encoder.py   |  58 +++--
 paddleslim/teachers/bert/reader/cls.py       |   2 +-
 5 files changed, 277 insertions(+), 194 deletions(-)
 create mode 100755 demo/bert/train_distill.py
 rename demo/bert/{train_cell_base.py => train_search.py} (58%)

diff --git a/demo/bert/train_distill.py b/demo/bert/train_distill.py
new file mode 100755
index 00000000..c9e042f5
--- /dev/null
+++ b/demo/bert/train_distill.py
@@ -0,0 +1,184 @@
+import numpy as np
+from itertools import izip
+import paddle.fluid as fluid
+from paddleslim.teachers.bert.reader.cls import *
+from paddleslim.nas.darts.search_space import AdaBERTClassifier
+from paddle.fluid.dygraph.base import to_variable
+from tqdm import tqdm
+import os
+import pickle
+
+import logging
+from paddleslim.common import AvgrageMeter, get_logger
+logger = get_logger(__name__, level=logging.INFO)
+
+
+def valid_one_epoch(model, valid_loader, epoch, log_freq):
+    accs = AvgrageMeter()
+    ce_losses = AvgrageMeter()
+    model.student.eval()
+
+    step_id = 0
+    for valid_data in valid_loader():
+        try:
+            loss, acc, ce_loss, _, _ = model._layers.loss(valid_data, epoch)
+        except:
+            loss, acc, ce_loss, _, _ = model.loss(valid_data, epoch)
+
+        batch_size = valid_data[0].shape[0]
+        ce_losses.update(ce_loss.numpy(), batch_size)
+        accs.update(acc.numpy(), batch_size)
+        step_id += 1
+    return ce_losses.avg[0], accs.avg[0]
+
+
+def train_one_epoch(model, train_loader, optimizer, epoch, use_data_parallel,
+                    log_freq):
+    total_losses = AvgrageMeter()
+    accs = AvgrageMeter()
+    ce_losses = AvgrageMeter()
+    kd_losses = AvgrageMeter()
+    model.student.train()
+
+    step_id = 0
+    for train_data in train_loader():
+        batch_size = train_data[0].shape[0]
+
+        if use_data_parallel:
+            total_loss, acc, ce_loss, kd_loss, _ = model._layers.loss(
+                train_data, epoch)
+        else:
+            total_loss, acc, ce_loss, kd_loss, _ = model.loss(train_data,
+                                                              epoch)
+
+        if use_data_parallel:
+            total_loss = model.scale_loss(total_loss)
+            total_loss.backward()
+            model.apply_collective_grads()
+        else:
+            total_loss.backward()
+        optimizer.minimize(total_loss)
+        model.clear_gradients()
+        total_losses.update(total_loss.numpy(), batch_size)
+        accs.update(acc.numpy(), batch_size)
+        ce_losses.update(ce_loss.numpy(), batch_size)
+        kd_losses.update(kd_loss.numpy(), batch_size)
+
+        if step_id % log_freq == 0:
+            logger.info(
+                "Train Epoch {}, Step {}, Lr {:.6f} total_loss {:.6f}; ce_loss {:.6f}, kd_loss {:.6f}, train_acc {:.6f};".
+                format(epoch, step_id,
+                       optimizer.current_step_lr(), total_losses.avg[0],
+                       ce_losses.avg[0], kd_losses.avg[0], accs.avg[0]))
+        step_id += 1
+
+
+def main():
+    # whether to use multiple GPUs
+    use_data_parallel = False
+    place = fluid.CUDAPlace(fluid.dygraph.parallel.Env(
+    ).dev_id) if use_data_parallel else fluid.CUDAPlace(0)
+
+    BERT_BASE_PATH = "./data/pretrained_models/uncased_L-12_H-768_A-12"
+    vocab_path = BERT_BASE_PATH + "/vocab.txt"
+    data_dir = "./data/glue_data/MNLI/"
+    teacher_model_dir = "./data/teacher_model/steps_23000"
+    do_lower_case = True
+    num_samples = 392702
+    # augmented dataset nums
+    # num_samples = 8016987
+    max_seq_len = 128
+    batch_size = 192
+    hidden_size = 768
+    emb_size = 768
+    max_layer = 8
+    epoch = 80
+    log_freq = 10
+
+    device_num = fluid.dygraph.parallel.Env().nranks
+    use_fixed_gumbel = True
+    train_phase = "train"
+    val_phase = "dev"
+    step_per_epoch = int(num_samples / (batch_size * device_num))
+
+    with fluid.dygraph.guard(place):
+        if use_fixed_gumbel:
+            # make sure gumbel arch is constant
+            np.random.seed(1)
+            fluid.default_main_program().random_seed = 1
+        model = AdaBERTClassifier(
+            3,
+            n_layer=max_layer,
+            hidden_size=hidden_size,
+            emb_size=emb_size,
+            teacher_model=teacher_model_dir,
+            data_dir=data_dir,
+            use_fixed_gumbel=use_fixed_gumbel)
+
+        learning_rate = fluid.dygraph.CosineDecay(2e-2, step_per_epoch, epoch)
+
+        model_parameters = []
+        for p in model.parameters():
+            if (p.name not in [a.name for a in model.arch_parameters()] and
+                    p.name not in
+                [a.name for a in model.teacher.parameters()]):
+                model_parameters.append(p)
+
+        optimizer = fluid.optimizer.MomentumOptimizer(
+            learning_rate,
+            0.9,
+            regularization=fluid.regularizer.L2DecayRegularizer(3e-4),
+            parameter_list=model_parameters)
+
+        processor = MnliProcessor(
+            data_dir=data_dir,
+            vocab_path=vocab_path,
+            max_seq_len=max_seq_len,
+            do_lower_case=do_lower_case,
+            in_tokens=False)
+
+        train_reader = processor.data_generator(
+            batch_size=batch_size,
+            phase=train_phase,
+            epoch=1,
+            dev_count=1,
+            shuffle=True)
+        dev_reader = processor.data_generator(
+            batch_size=batch_size,
+            phase=val_phase,
+            epoch=1,
+            dev_count=1,
+            shuffle=False)
+
+        if use_data_parallel:
+            train_reader = fluid.contrib.reader.distributed_batch_reader(
+                train_reader)
+
+        train_loader = fluid.io.DataLoader.from_generator(
+            capacity=128,
+            use_double_buffer=True,
+            iterable=True,
+            return_list=True)
+        dev_loader = fluid.io.DataLoader.from_generator(
+            capacity=128,
+            use_double_buffer=True,
+            iterable=True,
+            return_list=True)
+
+        train_loader.set_batch_generator(train_reader, places=place)
+        dev_loader.set_batch_generator(dev_reader, places=place)
+
+        if use_data_parallel:
+            strategy = fluid.dygraph.parallel.prepare_context()
+            model = fluid.dygraph.parallel.DataParallel(model, strategy)
+
+        for epoch_id in range(epoch):
+            train_one_epoch(model, train_loader, optimizer, epoch_id,
+                            use_data_parallel, log_freq)
+            loss, acc = valid_one_epoch(model, dev_loader, epoch_id, log_freq)
+            logger.info("dev set, ce_loss {:.6f}; acc: {:.6f};".format(loss,
+                                                                       acc))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/demo/bert/train_cell_base.py b/demo/bert/train_search.py
similarity index 58%
rename from demo/bert/train_cell_base.py
rename to demo/bert/train_search.py
index 983662be..00953965 100755
--- a/demo/bert/train_cell_base.py
+++ b/demo/bert/train_search.py
@@ -13,41 +13,23 @@ from paddleslim.common import AvgrageMeter, get_logger
 logger = get_logger(__name__, level=logging.INFO)
 
 
-def count_parameters_in_MB(all_params):
-    parameters_number = 0
-    for param in all_params:
-        if param.trainable:
-            parameters_number += np.prod(param.shape)
-    return parameters_number / 1e6
-
-
-def preprocess_data(data_generator, data_nums, phase, cached_data):
-    t = tqdm(total=data_nums)
-    data_list = []
-    for data in tqdm(data_generator()):
-        # data_var = []
-        # for d in data:
-        #     tmp = fluid.core.LoDTensor()
-        #     tmp.set(d, fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id))
-        #     data_var.append(tmp)
-        data_list.append(data)
-        t.update(data[0].shape[0])
-    t.close()
-
-    logger.info("Saving {} data to {}".format(phase, cached_data + phase))
-    f = open(cached_data + phase, 'wb')
-    pickle.dump(data_list, f)
-    f.close()
-
-    return data_list
-
+def valid_one_epoch(model, valid_loader, epoch, log_freq):
+    accs = AvgrageMeter()
+    ce_losses = AvgrageMeter()
+    model.student.eval()
 
-def generator_reader(data_list):
-    def wrapper():
-        for d in data_list:
-            yield d
+    step_id = 0
+    for valid_data in valid_loader():
+        try:
+            loss, acc, ce_loss, _, _ = model._layers.loss(valid_data, epoch)
+        except:
+            loss, acc, ce_loss, _, _ = model.loss(valid_data, epoch)
 
-    return wrapper
+        batch_size = valid_data[0].shape[0]
+        ce_losses.update(ce_loss.numpy(), batch_size)
+        accs.update(acc.numpy(), batch_size)
+        step_id += 1
+    return ce_losses.avg[0], accs.avg[0]
 
 
 def train_one_epoch(model, train_loader, valid_loader, optimizer,
@@ -57,19 +39,17 @@ def train_one_epoch(model, train_loader, valid_loader, optimizer,
     ce_losses = AvgrageMeter()
     kd_losses = AvgrageMeter()
     val_accs = AvgrageMeter()
-    model.train()
+    model.student.train()
 
     step_id = 0
     for train_data, valid_data in izip(train_loader(), valid_loader()):
-        #for train_data in train_loader():
         batch_size = train_data[0].shape[0]
-
-        # make sure arch on every gpu is same
-        np.random.seed(step_id * 2)
-        try:
+        # make sure the arch on every gpu is the same, otherwise an error will occur
+        np.random.seed(step_id * 2 * (epoch + 1))
+        if use_data_parallel:
             total_loss, acc, ce_loss, kd_loss, _ = model._layers.loss(
                 train_data, epoch)
-        except:
+        else:
             total_loss, acc, ce_loss, kd_loss, _ = model.loss(train_data,
                                                               epoch)
 
@@ -86,12 +66,12 @@ def train_one_epoch(model, train_loader, valid_loader, optimizer,
         ce_losses.update(ce_loss.numpy(), batch_size)
         kd_losses.update(kd_loss.numpy(), batch_size)
 
-        # make sure arch on every gpu is same
-        np.random.seed(step_id * 2 + 1)
-        try:
+        # make sure the arch on every gpu is the same, otherwise an error will occur
+        np.random.seed(step_id * 2 * (epoch + 1) + 1)
+        if use_data_parallel:
             arch_loss, _, _, _, arch_logits = model._layers.loss(valid_data,
                                                                  epoch)
-        except:
+        else:
             arch_loss, _, _, _, arch_logits = model.loss(valid_data, epoch)
 
         if use_data_parallel:
@@ -101,39 +81,20 @@ def train_one_epoch(model, train_loader, valid_loader, optimizer,
         else:
             arch_loss.backward()
         arch_optimizer.minimize(arch_loss)
-        arch_optimizer.clear_gradients()
+        model.clear_gradients()
 
         probs = fluid.layers.softmax(arch_logits[-1])
         val_acc = fluid.layers.accuracy(input=probs, label=valid_data[4])
         val_accs.update(val_acc.numpy(), batch_size)
 
         if step_id % log_freq == 0:
             logger.info(
-                "Train Epoch {}, Step {}, Lr {:.6f} total_loss {:.6f}; ce_loss {:.6f}, kd_loss {:.6f}, train_acc {:.6f}, valid_acc {:.6f};".
+                "Train Epoch {}, Step {}, Lr {:.6f} total_loss {:.6f}; ce_loss {:.6f}, kd_loss {:.6f}, train_acc {:.6f}, search_valid_acc {:.6f};".
                 format(epoch, step_id,
                        optimizer.current_step_lr(), total_losses.avg[
                            0], ce_losses.avg[0], kd_losses.avg[0], accs.avg[0],
                        val_accs.avg[0]))
-        step_id += 1
-
-
-def valid_one_epoch(model, valid_loader, epoch, log_freq):
-    accs = AvgrageMeter()
-    ce_losses = AvgrageMeter()
-    model.eval()
-
-    step_id = 0
-    for valid_data in valid_loader():
-        try:
-            loss, acc, ce_loss, _, _ = model._layers.loss(valid_data, epoch)
-        except:
-            loss, acc, ce_loss, _, _ = model.loss(valid_data, epoch)
-
-        batch_size = valid_data[0].shape[0]
-        ce_losses.update(ce_loss.numpy(), batch_size)
-        accs.update(acc.numpy(), batch_size)
         step_id += 1
-    return ce_losses.avg[0], accs.avg[0]
 
 
 def main():
@@ -145,33 +106,24 @@ def main():
     BERT_BASE_PATH = "./data/pretrained_models/uncased_L-12_H-768_A-12"
     vocab_path = BERT_BASE_PATH + "/vocab.txt"
     data_dir = "./data/glue_data/MNLI/"
-    cached_data = "./data/glue_data/MNLI/cached_data_"
     teacher_model_dir = "./data/teacher_model/steps_23000"
     do_lower_case = True
-    #num_samples = 392702
-    num_samples = 8016987
+    num_samples = 392702
+    # augmented dataset nums
+    # num_samples = 8016987
     max_seq_len = 128
-    # any modify of vocab/do_lower_case/max_seq_len requires update cached data
    batch_size = 128
     hidden_size = 768
     emb_size = 768
     max_layer = 8
     epoch = 80
     log_freq = 10
-
     device_num = fluid.dygraph.parallel.Env().nranks
-    search = True
-
-    if search:
-        use_fixed_gumbel = False
-        train_phase = "search_train"
-        val_phase = "search_valid"
-        step_per_epoch = int(num_samples / ((batch_size * 0.5) * device_num))
-    else:
-        use_fixed_gumbel = True
-        train_phase = "train"
-        val_phase = "dev"
-        step_per_epoch = int(num_samples / (batch_size * device_num))
+
+    use_fixed_gumbel = False
+    train_phase = "search_train"
+    val_phase = "search_valid"
+    step_per_epoch = int(num_samples * 0.5 / ((batch_size) * device_num))
 
     with fluid.dygraph.guard(place):
         model = AdaBERTClassifier(
@@ -205,69 +157,31 @@ def main():
             regularization=fluid.regularizer.L2Decay(1e-3),
             parameter_list=model.arch_parameters())
 
-        if os.path.exists(cached_data + "train") and os.path.exists(
-                cached_data + "valid") + os.path.exists(cached_data + "dev"):
-            f = open(cached_data + "train", 'rb')
-            logger.info("loading preprocessed train data from {}".format(
-                cached_data + "train"))
-            train_data_list = pickle.load(f)
-            f.close()
-
-            f = open(cached_data + "valid", 'rb')
-            logger.info("loading preprocessed valid data from {}".format(
-                cached_data + "valid"))
-            valid_data_list = pickle.load(f)
-            f.close()
-
-            f = open(cached_data + "dev", 'rb')
-            logger.info("loading preprocessed dev data from {}".format(
-                cached_data + "dev"))
-            dev_data_list = pickle.load(f)
-            f.close()
-        else:
-            processor = MnliProcessor(
-                data_dir=data_dir,
-                vocab_path=vocab_path,
-                max_seq_len=max_seq_len,
-                do_lower_case=do_lower_case,
-                in_tokens=False)
-
-            train_reader = processor.data_generator(
-                batch_size=batch_size,
-                phase=train_phase,
-                epoch=1,
-                dev_count=1,
-                shuffle=True)
-            valid_reader = processor.data_generator(
-                batch_size=batch_size,
-                phase=val_phase,
-                epoch=1,
-                dev_count=1,
-                shuffle=True)
-            dev_reader = processor.data_generator(
-                batch_size=batch_size,
-                phase="dev",
-                epoch=1,
-                dev_count=1,
-                shuffle=False)
-
-            train_data_nums = processor.get_num_examples(train_phase)
-            valid_data_nums = processor.get_num_examples(val_phase)
-            dev_data_nums = processor.get_num_examples("dev")
-
-            logger.info("Preprocessing train data")
-            train_data_list = preprocess_data(train_reader, train_data_nums,
-                                              "train", cached_data)
-            logger.info("Preprocessing valid data")
-            valid_data_list = preprocess_data(valid_reader, valid_data_nums,
-                                              "valid", cached_data)
-            logger.info("Preprocessing dev data")
-            dev_data_list = preprocess_data(dev_reader, dev_data_nums, "dev",
-                                            cached_data)
-
-        train_reader = generator_reader(train_data_list)
-        valid_reader = generator_reader(valid_data_list)
-        dev_reader = generator_reader(dev_data_list)
+        processor = MnliProcessor(
+            data_dir=data_dir,
+            vocab_path=vocab_path,
+            max_seq_len=max_seq_len,
+            do_lower_case=do_lower_case,
+            in_tokens=False)
+
+        train_reader = processor.data_generator(
+            batch_size=batch_size,
+            phase=train_phase,
+            epoch=1,
+            dev_count=1,
+            shuffle=True)
+        valid_reader = processor.data_generator(
+            batch_size=batch_size,
+            phase=val_phase,
+            epoch=1,
+            dev_count=1,
+            shuffle=True)
+        dev_reader = processor.data_generator(
+            batch_size=batch_size,
+            phase="dev",
+            epoch=1,
+            dev_count=1,
+            shuffle=False)
 
         if use_data_parallel:
             train_reader = fluid.contrib.reader.distributed_batch_reader(
@@ -304,12 +218,12 @@ def main():
                             arch_optimizer, epoch_id, use_data_parallel,
                             log_freq)
             loss, acc = valid_one_epoch(model, dev_loader, epoch_id, log_freq)
-            logger.info("Valid set2, ce_loss {:.6f}; acc: {:.6f};".format(loss,
-                                                                          acc))
+            logger.info("dev set, ce_loss {:.6f}; acc: {:.6f};".format(loss,
+                                                                       acc))
 
-        try:
+        if use_data_parallel:
             print(model.student._encoder.alphas.numpy())
-        except:
+        else:
             print(model._layers.student._encoder.alphas.numpy())
         print("=" * 100)
 
diff --git a/paddleslim/nas/darts/search_space/conv_bert/cls.py b/paddleslim/nas/darts/search_space/conv_bert/cls.py
index 6c8e526b..fa7e1fd2 100644
--- a/paddleslim/nas/darts/search_space/conv_bert/cls.py
+++ b/paddleslim/nas/darts/search_space/conv_bert/cls.py
@@ -72,6 +72,7 @@ class AdaBERTClassifier(Layer):
         )
         self.teacher = BERTClassifier(
             num_labels, model_path=self._teacher_model)
+        # global setting, will be overwritten when training (about 1% acc loss)
         self.teacher.eval()
         #self.teacher.test(self._data_dir)
         print(
@@ -100,24 +101,12 @@ class AdaBERTClassifier(Layer):
                 format(t_emb.name, s_emb.name))
 
     def forward(self, data_ids, epoch):
-        src_ids = data_ids[0]
-        position_ids = data_ids[1]
-        sentence_ids = data_ids[2]
-        return self.student(src_ids, position_ids, sentence_ids, epoch)
+        return self.student(data_ids, epoch)
 
     def arch_parameters(self):
         return self.student.arch_parameters()
 
-    def ce(self, logits):
-        logits = np.exp(logits - np.max(logits))
-        logits = logits / logits.sum(axis=0)
-        return logits
-
     def loss(self, data_ids, epoch):
-        # src_ids = data_ids[0]
-        # position_ids = data_ids[1]
-        # sentence_ids = data_ids[2]
-        # input_mask = data_ids[3]
         labels = data_ids[4]
 
         s_logits = self.student(data_ids, epoch)
diff --git a/paddleslim/nas/darts/search_space/conv_bert/model/transformer_encoder.py b/paddleslim/nas/darts/search_space/conv_bert/model/transformer_encoder.py
index b8d47a28..ccdeb7e7 100755
--- a/paddleslim/nas/darts/search_space/conv_bert/model/transformer_encoder.py
+++ b/paddleslim/nas/darts/search_space/conv_bert/model/transformer_encoder.py
@@ -22,14 +22,15 @@ import numpy as np
 import paddle
 import paddle.fluid as fluid
 from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer, Conv2D, BatchNorm, Pool2D, to_variable
+from paddle.fluid.dygraph import to_variable
 from paddle.fluid.initializer import NormalInitializer
 from paddle.fluid import ParamAttr
 from paddle.fluid.initializer import MSRA, ConstantInitializer
 
 ConvBN_PRIMITIVES = [
     'std_conv_bn_3', 'std_conv_bn_5', 'std_conv_bn_7', 'dil_conv_bn_3',
-    'dil_conv_bn_5', 'dil_conv_bn_7', 'avg_pool_3', 'max_pool_3', 'none',
-    'skip_connect'
+    'dil_conv_bn_5', 'dil_conv_bn_7', 'avg_pool_3', 'max_pool_3',
+    'skip_connect', 'none'
 ]
 
 
@@ -69,12 +70,12 @@
 
         self._ops = fluid.dygraph.LayerList(ops)
 
-    def forward(self, x, weights, index):
+    def forward(self, x, weights):
         # out = fluid.layers.sums(
         #     [weights[i] * op(x) for i, op in enumerate(self._ops)])
         # return out
 
-        for i in range(len(self._ops)):
+        for i in range(len(weights.numpy())):
             if weights[i].numpy() != 0:
                 return self._ops[i](x) * weights[i]
 
@@ -90,13 +91,13 @@ def gumbel_softmax(logits, epoch, temperature=1.0, hard=True, eps=1e-10):
     if hard:
         maxes = fluid.layers.reduce_max(logits, dim=1, keep_dim=True)
         hard = fluid.layers.cast((logits == maxes), logits.dtype)
-        index = np.argmax(hard.numpy(), axis=1)
         out = hard - logits.detach() + logits
         # tmp.stop_gradient = True
         # out = tmp + logits
     else:
         out = logits
-    return out, index
+
+    return out
 
 
 class Zero(fluid.dygraph.Layer):
@@ -174,7 +175,7 @@ class Cell(fluid.dygraph.Layer):
             ops.append(op)
         self._ops = fluid.dygraph.LayerList(ops)
 
-    def forward(self, s0, s1, weights, index):
+    def forward(self, s0, s1, weights):
         s0 = self.preprocess0(s0)
         s1 = self.preprocess1(s1)
 
@@ -182,8 +183,7 @@ class Cell(fluid.dygraph.Layer):
         offset = 0
         for i in range(self._steps):
             s = fluid.layers.sums([
-                self._ops[offset + j](h, weights[offset + j],
-                                      index[offset + j])
+                self._ops[offset + j](h, weights[offset + j])
                 for j, h in enumerate(states)
             ])
             offset += len(states)
@@ -262,15 +262,6 @@ class EncoderLayer(Layer):
         self.pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
 
-        self.BN = BatchNorm(
-            num_channels=self._n_channel,
-            param_attr=fluid.ParamAttr(
-                initializer=fluid.initializer.Constant(value=1),
-                trainable=False),
-            bias_attr=fluid.ParamAttr(
-                initializer=fluid.initializer.Constant(value=0),
-                trainable=False))
-
         self.bns = []
         self.outs = []
         for i in range(self._n_layer):
@@ -292,22 +283,27 @@ class EncoderLayer(Layer):
         self._bns = fluid.dygraph.LayerList(self.bns)
         self._outs = fluid.dygraph.LayerList(self.outs)
 
-        self.pooled_fc = Linear(
-            input_dim=self._n_channel,
-            output_dim=self._hidden_size,
-            param_attr=fluid.ParamAttr(
-                name=self.full_name() + "pooled_fc.w_0",
-                initializer=fluid.initializer.TruncatedNormal(scale=1.0)),
-            bias_attr=fluid.ParamAttr(name=self.full_name() + "pooled_fc.b_0"),
-            act="tanh")
-
         self.use_fixed_gumbel = use_fixed_gumbel
-        self.gumbel_alphas = gumbel_softmax(self.alphas, 0)[0].detach()
-        #print("gumbel_alphas: \n", self.gumbel_alphas.numpy())
+        #self.gumbel_alphas = gumbel_softmax(self.alphas, 0).detach()
+
+        mrpc_arch = [
+            [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],  # std_conv7 0 # node 0
+            [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],  # dil_conv5 1
+            [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],  # std_conv7 0 # node 1
+            [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],  # dil_conv5 1
+            [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],  # zero 2
+            [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],  # zero 0 # node2
+            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # std_conv3 1
+            [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],  # zero 2
+            [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]  # dil_conv3 3
+        ]
+        self.gumbel_alphas = to_variable(
+            np.array(mrpc_arch).astype(np.float32))
+        print("gumbel_alphas: \n", self.gumbel_alphas.numpy())
 
     def forward(self, enc_input_0, enc_input_1, epoch, flops=[],
                 model_size=[]):
-        alphas, index = self.gumbel_alphas if self.use_fixed_gumbel else gumbel_softmax(
+        alphas = self.gumbel_alphas if self.use_fixed_gumbel else gumbel_softmax(
             self.alphas, epoch)
 
         s0 = fluid.layers.unsqueeze(enc_input_0, [1])
@@ -317,7 +313,7 @@ class EncoderLayer(Layer):
 
         enc_outputs = []
         for i in range(self._n_layer):
-            s0, s1 = s1, self._cells[i](s0, s1, alphas, index)
+            s0, s1 = s1, self._cells[i](s0, s1, alphas)
             # (bs, n_channel, seq_len, 1)
             tmp = self._bns[i](s1)
             tmp = self.pool2d_avg(tmp)
diff --git a/paddleslim/teachers/bert/reader/cls.py b/paddleslim/teachers/bert/reader/cls.py
index 540efd46..e384bdce 100644
--- a/paddleslim/teachers/bert/reader/cls.py
+++ b/paddleslim/teachers/bert/reader/cls.py
@@ -380,7 +380,7 @@ class MnliProcessor(DataProcessor):
     def get_train_examples(self, data_dir):
         """See base class."""
         return self._create_examples(
-            self._read_tsv(os.path.join(data_dir, "train_aug.tsv")), "train")
+            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
 
     def get_dev_examples(self, data_dir):
         """See base class."""
-- 
GitLab
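
Reader's note (appended after the patch signature; not part of the commit): the search code in train_search.py relies on a hard, "straight-through" Gumbel-Softmax. gumbel_softmax() in transformer_encoder.py turns the architecture logits into a one-hot choice per edge for the forward pass while letting gradients follow the soft distribution (in the patch this appears as `out = hard - logits.detach() + logits`), and MixedOp.forward then evaluates only the op whose weight is non-zero; train_distill.py instead fixes the choices via the hard-coded mrpc_arch one-hot matrix. The sketch below is a minimal NumPy-only illustration of that sampling step; the function and variable names are made up for the example and this is not the Paddle implementation used by the patch.

# Illustrative sketch only -- not part of the patch, not the Paddle API.
import numpy as np

def hard_gumbel_softmax(logits, temperature=1.0, eps=1e-10):
    """Sample a one-hot choice per row of `logits` (num_edges x num_ops)."""
    u = np.random.uniform(size=logits.shape)
    gumbel = -np.log(-np.log(u + eps) + eps)              # Gumbel(0, 1) noise
    y = (logits + gumbel) / temperature
    y = np.exp(y - y.max(axis=-1, keepdims=True))
    soft = y / y.sum(axis=-1, keepdims=True)              # differentiable sample
    hard = (soft == soft.max(axis=-1, keepdims=True)).astype(soft.dtype)
    # Straight-through: the forward value is `hard`, gradients flow through `soft`
    # (the dygraph code expresses this as hard - soft.detach() + soft).
    return hard, soft

if __name__ == "__main__":
    np.random.seed(1)
    alphas = np.zeros((9, 10), dtype=np.float32)          # 9 edges x 10 candidate ops
    one_hot, _ = hard_gumbel_softmax(alphas)
    # With a one-hot row per edge, a mixed op only needs to run the selected
    # candidate, which is what the `if weights[i].numpy() != 0` check does.
    print(one_hot.argmax(axis=-1))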