# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""Repeat the same layer definition."""
import paddle


class MultiSequential(paddle.nn.Sequential):
    """Multi-input multi-output paddle.nn.Sequential.

    Each contained layer is called with the previous layer's outputs
    unpacked as positional arguments, so layers may take and return
    several values.
    """

    def forward(self, *args):
        """Run the inputs through every contained layer in order.

        Parameters
        ----------
        *args
            Positional inputs for the first layer.

        Returns
        -------
        object
            Whatever the last layer returns. If the container holds no
            layers, the input ``args`` tuple is returned unchanged.

        Notes
        -----
        Every layer except the last must return an iterable (normally a
        tuple), because its result is unpacked with ``*`` into the next
        layer's call.
        """
        for m in self:
            # Feed the previous layer's outputs in as this layer's inputs.
            args = m(*args)
        return args


def repeat(N, fn):
    """Repeat module N times.

    Parameters
    ----------
    N : int
        Number of times to repeat the module. ``N <= 0`` yields an empty
        container.
    fn : Callable
        Factory that builds one module. It is called once per repetition
        with the zero-based repetition index (``0 .. N-1``), so it can build
        a fresh module for each position (and may use the index if needed).

    Returns
    -------
    MultiSequential
        Container holding ``fn(0), fn(1), ..., fn(N-1)`` in that order.
    """
    return MultiSequential(*(fn(n) for n in range(N)))