# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast


class TrainTaskConfig(object):
    """Training-time settings, stored as class attributes.

    The attributes are read directly from the class and can be overridden
    in place with ``merge_cfg_from_list``.
    """

    # support both CPU and GPU now.
    use_gpu = True
    # the epoch number to train.
    pass_num = 30
    # the number of sequences contained in a mini-batch.
    # deprecated, set batch_size in args.
    batch_size = 32
    # the hyper parameters for Adam optimizer.
    # This static learning_rate will be multiplied by the learning rate
    # derived from the LearningRateScheduler to get the final learning rate.
    learning_rate = 2.0
    beta1 = 0.9
    beta2 = 0.997
    eps = 1e-9
    # the parameters for learning rate scheduling.
    warmup_steps = 8000
    # the weight used to mix up the ground-truth distribution and the fixed
    # uniform distribution in label smoothing when training.
    # Set this as zero if label smoothing is not wanted.
    label_smooth_eps = 0.1
    # the directory for saving trained models.
    model_dir = "trained_models"
    # the directory for saving checkpoints.
    ckpt_dir = "trained_ckpts"
    # the directory for loading checkpoint.
    # If provided, continue training from the checkpoint.
    ckpt_path = None
    # the parameter to initialize the learning rate scheduler.
    # It should be provided when using checkpoints, since the checkpoint
    # doesn't include the training step counter currently.
    start_step = 0
    # the frequency to save trained models.
    save_freq = 10000


class InferTaskConfig(object):
    """Inference-time (sequence generation) settings, as class attributes."""

    use_gpu = True
    # the number of examples in one run for sequence generation.
    batch_size = 10
    # the parameters for beam search.
    beam_size = 5
    max_out_len = 256
    # the number of decoded sentences to output.
    n_best = 1
    # the flags indicating whether to output the special tokens.
    output_bos = False
    output_eos = False
    output_unk = True
    # the directory for loading the trained model.
    model_path = "trained_models/pass_1.infer.model"


class ModelHyperParams(object):
    """Model architecture hyper-parameters, stored as class attributes."""

    # The following vocabulary-related configurations will be set
    # automatically according to the passed vocabulary path and special
    # tokens.
    # size of source word dictionary.
    src_vocab_size = 10000
    # size of target word dictionary.
    trg_vocab_size = 10000
    # size of phone dictionary.
    phone_vocab_size = 1000
    # ratio of phoneme embeddings.
    beta = 0.0
    # index for <bos> token
    bos_idx = 0
    # index for <eos> token
    eos_idx = 1
    # index for <unk> token
    unk_idx = 2
    # index for <pad> in phonemes
    phone_pad_idx = 0
    # max length of sequences deciding the size of position encoding table.
    max_length = 256
    # the dimension for word embeddings, which is also the last dimension of
    # the input and output of multi-head attention, position-wise feed-forward
    # networks, encoder and decoder.
    d_model = 512
    # size of the hidden layer in position-wise feed-forward networks.
    d_inner_hid = 2048
    # the dimension that keys are projected to for dot-product attention.
    d_key = 64
    # the dimension that values are projected to for dot-product attention.
    d_value = 64
    # number of heads used in multi-head attention.
    n_head = 8
    # number of sub-layers to be stacked in the encoder and decoder.
    n_layer = 6
    # dropout rates of different modules.
    prepostprocess_dropout = 0.1
    attention_dropout = 0.1
    relu_dropout = 0.1
    # to process before each sub-layer
    preprocess_cmd = "n"  # layer normalization
    # to process after each sub-layer
    postprocess_cmd = "da"  # dropout + residual connection
    # random seed used in dropout for CE.
    dropout_seed = None
    # the flag indicating whether to share embedding and softmax weights.
    # vocabularies in source and target should be same for weight sharing.
    weight_sharing = True


def _parse_cfg_value(raw):
    """Return ``raw`` converted to a Python literal, or ``raw`` unchanged.

    ``ast.literal_eval`` is used instead of ``eval`` for two reasons:

    * ``eval`` resolves bare names against the surrounding scope and the
      builtins, so a path such as ``dir`` or ``key`` silently became a
      function or another variable's content instead of staying a string;
    * ``eval`` executes arbitrary expressions supplied on the command line.

    NOTE: as a consequence only literals (numbers, strings, tuples, lists,
    dicts, sets, booleans, ``None``) are converted. Arithmetic such as
    ``"2*3"`` is no longer computed and is kept as the string it was given.
    """
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Not a literal (typically a file path or a command string such as
        # "da"): keep the value exactly as it was passed in.
        return raw


def merge_cfg_from_list(cfg_list, g_cfgs):
    """Set the above global configurations using the cfg_list.

    Args:
        cfg_list: flat sequence of alternating names and values, e.g.
            ``["batch_size", "64", "model_dir", "out"]``.
        g_cfgs: iterable of configuration objects (the classes above) that
            are searched in order for each name.

    Each value is converted with ``_parse_cfg_value`` and assigned to the
    FIRST object in ``g_cfgs`` that already has an attribute of that name;
    later objects are left untouched. Names that no object defines are
    ignored.

    Raises:
        AssertionError: if ``cfg_list`` has an odd number of items.
    """
    assert len(cfg_list) % 2 == 0, (
        "cfg_list must contain an even number of items (key/value pairs)")
    for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
        for g_cfg in g_cfgs:
            if hasattr(g_cfg, key):
                setattr(g_cfg, key, _parse_cfg_value(value))
                # only the first configuration owning the key is updated.
                break