diff --git a/configs/animeganv2.yaml b/configs/animeganv2.yaml index af082ed6bfe8db0fdb4353228fde0bf655cca1da..5e4b650d2a7ecd575b82766d5abb0ff1ed9ebf67 100644 --- a/configs/animeganv2.yaml +++ b/configs/animeganv2.yaml @@ -10,6 +10,7 @@ model: gan_criterion: name: GANLoss gan_mode: lsgan + # use your trained path pretrain_ckpt: output_dir/AnimeGANV2PreTrainModel-2020-11-29-17-02/epoch_2_checkpoint.pdparams g_adv_weight: 300. d_adv_weight: 300. @@ -47,21 +48,21 @@ dataset: test: name: SingleDataset dataroot: data/animedataset/test/HR_photo - max_dataset_size: inf - direction: BtoA - input_nc: 3 - output_nc: 3 - serial_batches: False - pool_size: 50 - transforms: - - name: ResizeToScale - size: [256, 256] - scale: 32 - interpolation: bilinear - - name: Transpose - - name: Normalize - mean: [127.5, 127.5, 127.5] - std: [127.5, 127.5, 127.5] + preprocess: + - name: LoadImageFromFile + key: A + - name: Transforms + input_keys: [A] + pipeline: + - name: ResizeToScale + size: [256, 256] + scale: 32 + interpolation: bilinear + - name: Transpose + - name: Normalize + mean: [127.5, 127.5, 127.5] + std: [127.5, 127.5, 127.5] + keys: [image, image] lr_scheduler: name: LinearDecay diff --git a/configs/dcgan_mnist.yaml b/configs/dcgan_mnist.yaml index a423dd03657b6fb157cbc9a1d6d11edee06aaa6a..931c679907eb0e01bbad5ae17dcef53cf124c3df 100644 --- a/configs/dcgan_mnist.yaml +++ b/configs/dcgan_mnist.yaml @@ -21,44 +21,33 @@ model: dataset: train: - name: SingleDataset - dataroot: data/mnist/train + name: CommonVisionDataset + dataset_name: MNIST + num_workers: 0 batch_size: 128 - preprocess: - - name: LoadImageFromFile - key: A - - name: Transfroms - input_keys: [A] - pipeline: - - name: Resize - size: [64, 64] - interpolation: 'bicubic' #cv2.INTER_CUBIC - keys: [image, image] - - name: Transpose - keys: [image, image] - - name: Normalize - mean: [127.5, 127.5, 127.5] - std: [127.5, 127.5, 127.5] - keys: [image, image] + return_label: False + transforms: + - name: Resize + size: [64, 64] + interpolation: 'bicubic' #cv2.INTER_CUBIC + - name: Normalize + mean: [127.5] + std: [127.5] + keys: [image] test: - name: SingleDataset - dataroot: data/mnist/test - preprocess: - - name: LoadImageFromFile - key: A - - name: Transforms - input_keys: [A] - pipeline: - - name: Resize - size: [64, 64] - interpolation: 'bicubic' #cv2.INTER_CUBIC - keys: [image, image] - - name: Transpose - keys: [image, image] - - name: Normalize - mean: [127.5, 127.5, 127.5] - std: [127.5, 127.5, 127.5] - keys: [image, image] + name: CommonVisionDataset + dataset_name: MNIST + num_workers: 0 + batch_size: 128 + return_label: False + transforms: + - name: Resize + size: [64, 64] + interpolation: 'bicubic' #cv2.INTER_CUBIC + - name: Normalize + mean: [127.5] + std: [127.5] + keys: [image] lr_scheduler: name: LinearDecay diff --git a/configs/pix2pix_facades.yaml b/configs/pix2pix_facades.yaml index c7ed53754992052ecf708be7ed2346d6ba070e3a..0123bf9afc4e17108055c9c6c7f23a9e5acc6ddf 100644 --- a/configs/pix2pix_facades.yaml +++ b/configs/pix2pix_facades.yaml @@ -64,7 +64,7 @@ dataset: preprocess: - name: LoadImageFromFile key: pair - - name: Transforms + - name: Transforms input_keys: [A, B] pipeline: - name: Resize diff --git a/docs/en_US/tutorials/styleganv2.md b/docs/en_US/tutorials/styleganv2.md index 4b151261ed9b3540d6b1a4ccca4d5a227ac7744a..ed54e77c3297b5322703dc0d4a7df0414749f3c2 100644 --- a/docs/en_US/tutorials/styleganv2.md +++ b/docs/en_US/tutorials/styleganv2.md @@ -92,6 +92,21 @@ train model python tools/main.py -c configs/stylegan_v2_256_ffhq.yaml ``` +### Inference + +When you finish training, you need to use ``tools/extract_weight.py`` to extract the corresponding weights. +``` +python tools/extract_weight.py output_dir/YOUR_TRAINED_WEIGHT.pdparams --net-name gen_ema --output YOUR_WEIGHT_PATH.pdparams +``` + +Then use ``applications/tools/styleganv2.py`` to get results +``` +python tools/styleganv2.py --output_path stylegan01 --weight_path YOUR_WEIGHT_PATH.pdparams --size 256 +``` + +Note: ``--size`` should be same with your config file. + + ## Results Random Samples: diff --git a/docs/zh_CN/tutorials/styleganv2.md b/docs/zh_CN/tutorials/styleganv2.md index 7ebab5e1ff14af2fdca8769515b40736491a6029..91f94593e0944e640816bbf7aeb51024b2c2f421 100644 --- a/docs/zh_CN/tutorials/styleganv2.md +++ b/docs/zh_CN/tutorials/styleganv2.md @@ -54,9 +54,56 @@ python -u tools/styleganv2.py \ - n_col: 采样的图片的列数 - cpu: 是否使用cpu推理,若不使用,请在命令中去除 -### 训练(TODO) +### 训练 + +#### 准备数据集 +你可以从[这里](https://drive.google.com/drive/folders/1u2xu7bSrWxrbUxk-dT-UvEJq8IjdmNTP)下载对应的数据集 + +为了方便,我们提供了[images256x256.tar](https://paddlegan.bj.bcebos.com/datasets/images256x256.tar) + +目前的配置文件默认数据集的结构如下: + ``` + PaddleGAN + ├── data + ├── ffhq + ├──images1024x1024 + ├── 00000.png + ├── 00001.png + ├── 00002.png + ├── 00003.png + ├── 00004.png + ├──images256x256 + ├── 00000.png + ├── 00001.png + ├── 00002.png + ├── 00003.png + ├── 00004.png + ├──custom_data + ├── img0.png + ├── img1.png + ├── img2.png + ├── img3.png + ├── img4.png + ... + ``` + +启动训练 +``` +python tools/main.py -c configs/stylegan_v2_256_ffhq.yaml +``` + +### 推理 + +训练结束后,需要使用 ``tools/extract_weight.py`` 来提取对应的权重给``applications/tools/styleganv2.py``来进行推理. +``` +python tools/extract_weight.py output_dir/YOUR_TRAINED_WEIGHT.pdparams --net-name gen_ema --output stylegan_config_f.pdparams +``` + +``` +python tools/styleganv2.py --output_path stylegan01 --weight_path YOUR_WEIGHT_PATH.pdparams --size 256 +``` -未来还将添加训练脚本方便用户训练出更多类型的 StyleGAN V2 图像生成器。 +注意: ``--size`` 这个参数要和配置文件中的参数保持一致. ## 生成结果展示 diff --git a/ppgan/datasets/animeganv2_dataset.py b/ppgan/datasets/animeganv2_dataset.py index d30e8dfd5ab55dbcdb506d51940920689a969ec4..66a58ff7778c5a3e627297e8af88dac83cb0b4fd 100644 --- a/ppgan/datasets/animeganv2_dataset.py +++ b/ppgan/datasets/animeganv2_dataset.py @@ -20,7 +20,7 @@ from .base_dataset import BaseDataset from .image_folder import ImageFolder from .builder import DATASETS -from .transforms.builder import build_transforms +from .preprocess.builder import build_transforms @DATASETS.register() diff --git a/ppgan/datasets/common_vision_dataset.py b/ppgan/datasets/common_vision_dataset.py index 2e69104603defab1c03705b02043bcb535f18079..8b03926594eae35242b1fad31984f413d260eede 100644 --- a/ppgan/datasets/common_vision_dataset.py +++ b/ppgan/datasets/common_vision_dataset.py @@ -17,7 +17,7 @@ import paddle from .builder import DATASETS from .base_dataset import BaseDataset -from .transforms.builder import build_transforms +from .preprocess.builder import build_transforms @DATASETS.register() diff --git a/ppgan/datasets/preprocess/builder.py b/ppgan/datasets/preprocess/builder.py index e25147c8c6bfb0d6206aa93a2a905ee411183040..bb6c7dec4958194b1984b8930bce15b7535facb8 100644 --- a/ppgan/datasets/preprocess/builder.py +++ b/ppgan/datasets/preprocess/builder.py @@ -62,3 +62,15 @@ def build_preprocess(cfg): preproccess = Compose(preproccess) return preproccess + + +def build_transforms(cfg): + transforms = [] + + for trans_cfg in cfg: + temp_trans_cfg = copy.deepcopy(trans_cfg) + name = temp_trans_cfg.pop('name') + transforms.append(TRANSFORMS.get(name)(**temp_trans_cfg)) + + transforms = Compose(transforms) + return transforms diff --git a/ppgan/datasets/preprocess/transforms.py b/ppgan/datasets/preprocess/transforms.py index 45901932d96499cb5520f5bc3b5f4cb705162a0c..73481e84ae7b220b28c244ba738ff21c96a40f17 100644 --- a/ppgan/datasets/preprocess/transforms.py +++ b/ppgan/datasets/preprocess/transforms.py @@ -264,3 +264,74 @@ class SRNoise(T.BaseTransform): image = image + normed_noise image = np.clip(image, 0., 1.) return image + + +@TRANSFORMS.register() +class Add(T.BaseTransform): + def __init__(self, value, keys=None): + """Initialize Add Transform + + Parameters: + value (List[int]) -- the [r,g,b] value will add to image by pixel wise. + """ + super().__init__(keys=keys) + self.value = value + + def _get_params(self, inputs): + params = {} + params['value'] = self.value + return params + + def _apply_image(self, image): + return np.clip(image + self.params['value'], 0, 255).astype('uint8') + # return custom_F.add(image, self.params['value']) + + +@TRANSFORMS.register() +class ResizeToScale(T.BaseTransform): + def __init__(self, + size: int, + scale: int, + interpolation='bilinear', + keys=None): + """Initialize ResizeToScale Transform + + Parameters: + size (List[int]) -- the minimum target size + scale (List[int]) -- the stride scale + interpolation (Optional[str]) -- interpolation method + """ + super().__init__(keys=keys) + if isinstance(size, int): + self.size = (size, size) + else: + self.size = size + self.scale = scale + self.interpolation = interpolation + + def _get_params(self, inputs): + image = inputs[self.keys.index('image')] + hw = image.shape[:2] + params = {} + params['taget_size'] = self.reduce_to_scale(hw, self.size[::-1], + self.scale) + return params + + @staticmethod + def reduce_to_scale(img_hw, min_hw, scale): + im_h, im_w = img_hw + if im_h <= min_hw[0]: + im_h = min_hw[0] + else: + x = im_h % scale + im_h = im_h - x + + if im_w < min_hw[1]: + im_w = min_hw[1] + else: + y = im_w % scale + im_w = im_w - y + return (im_h, im_w) + + def _apply_image(self, image): + return F.resize(image, self.params['taget_size'], self.interpolation) diff --git a/ppgan/datasets/transforms/__init__.py b/ppgan/datasets/transforms/__init__.py deleted file mode 100644 index acb1b770db0c05f74cce8e0350be8d0ef4e96b89..0000000000000000000000000000000000000000 --- a/ppgan/datasets/transforms/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .transforms import ResizeToScale, PairedRandomCrop, PairedRandomHorizontalFlip, Add diff --git a/ppgan/datasets/transforms/builder.py b/ppgan/datasets/transforms/builder.py deleted file mode 100644 index 12b05a6c0524274e0711938c51e77ed855a056b2..0000000000000000000000000000000000000000 --- a/ppgan/datasets/transforms/builder.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import traceback -import paddle -from ...utils.registry import Registry - -TRANSFORMS = Registry("TRANSFORMS") - - -class Compose(object): - """ - Composes several transforms together use for composing list of transforms - together for a dataset transform. - Args: - transforms (list): List of transforms to compose. - Returns: - A compose object which is callable, __call__ for this Compose - object will call each given :attr:`transforms` sequencely. - """ - def __init__(self, transforms): - self.transforms = transforms - - def __call__(self, data): - for f in self.transforms: - try: - data = f(data) - except Exception as e: - print(f) - stack_info = traceback.format_exc() - print("fail to perform transform [{}] with error: " - "{} and stack:\n{}".format(f, e, str(stack_info))) - raise e - return data - - -def build_transforms(cfg): - transforms = [] - - for trans_cfg in cfg: - temp_trans_cfg = copy.deepcopy(trans_cfg) - name = temp_trans_cfg.pop('name') - transforms.append(TRANSFORMS.get(name)(**temp_trans_cfg)) - - transforms = Compose(transforms) - return transforms diff --git a/ppgan/datasets/transforms/functional.py b/ppgan/datasets/transforms/functional.py deleted file mode 100644 index 83350f58618e1fa60e613098900774acb9ae285e..0000000000000000000000000000000000000000 --- a/ppgan/datasets/transforms/functional.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. -from __future__ import division - -from . import functional_cv2 as F_cv2 -from paddle.vision.transforms.functional import _is_numpy_image, _is_pil_image - -__all__ = ['add'] - - -def add(pic, value): - if not (_is_pil_image(pic) or _is_numpy_image(pic)): - raise TypeError('pic should be PIL Image or ndarray. Got {}'.format( - type(pic))) - - if _is_pil_image(pic): - raise NotImplementedError('add not support pil image') - else: - return F_cv2.add(pic, value) diff --git a/ppgan/datasets/transforms/functional_cv2.py b/ppgan/datasets/transforms/functional_cv2.py deleted file mode 100644 index e688a974ae00f49e4be1099aa01a43326d347156..0000000000000000000000000000000000000000 --- a/ppgan/datasets/transforms/functional_cv2.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. -from __future__ import division -import numpy as np - - -def add(image, value): - return np.clip(image + value, 0, 255).astype('uint8') diff --git a/ppgan/datasets/transforms/transforms.py b/ppgan/datasets/transforms/transforms.py deleted file mode 100644 index 540644acce336df4a77cdf1c207b20ad4650d1df..0000000000000000000000000000000000000000 --- a/ppgan/datasets/transforms/transforms.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys -import random -import numbers -import collections -import numpy as np - -import paddle.vision.transforms as T -import paddle.vision.transforms.functional as F - -from . import functional as custom_F -from .builder import TRANSFORMS - -if sys.version_info < (3, 3): - Sequence = collections.Sequence - Iterable = collections.Iterable -else: - Sequence = collections.abc.Sequence - Iterable = collections.abc.Iterable - -TRANSFORMS.register(T.Resize) -TRANSFORMS.register(T.RandomCrop) -TRANSFORMS.register(T.RandomHorizontalFlip) -TRANSFORMS.register(T.Normalize) -TRANSFORMS.register(T.Transpose) -TRANSFORMS.register(T.Grayscale) - - -@TRANSFORMS.register() -class PairedRandomCrop(T.RandomCrop): - def __init__(self, size, keys=None): - super().__init__(size, keys=keys) - - if isinstance(size, int): - self.size = (size, size) - else: - self.size = size - - def _get_params(self, inputs): - image = inputs[self.keys.index('image')] - params = {} - params['crop_prams'] = self._get_param(image, self.size) - return params - - def _apply_image(self, img): - i, j, h, w = self.params['crop_prams'] - return F.crop(img, i, j, h, w) - - -@TRANSFORMS.register() -class PairedRandomHorizontalFlip(T.RandomHorizontalFlip): - def __init__(self, prob=0.5, keys=None): - super().__init__(prob, keys=keys) - - def _get_params(self, inputs): - params = {} - params['flip'] = random.random() < self.prob - return params - - def _apply_image(self, image): - if self.params['flip']: - return F.hflip(image) - return image - - -@TRANSFORMS.register() -class Add(T.BaseTransform): - def __init__(self, value, keys=None): - """Initialize Add Transform - - Parameters: - value (List[int]) -- the [r,g,b] value will add to image by pixel wise. - """ - super().__init__(keys=keys) - self.value = value - - def _get_params(self, inputs): - params = {} - params['value'] = self.value - return params - - def _apply_image(self, image): - return custom_F.add(image, self.params['value']) - - -@TRANSFORMS.register() -class ResizeToScale(T.BaseTransform): - def __init__(self, - size: int, - scale: int, - interpolation='bilinear', - keys=None): - """Initialize ResizeToScale Transform - - Parameters: - size (List[int]) -- the minimum target size - scale (List[int]) -- the stride scale - interpolation (Optional[str]) -- interpolation method - """ - super().__init__(keys=keys) - if isinstance(size, int): - self.size = (size, size) - else: - self.size = size - self.scale = scale - self.interpolation = interpolation - - def _get_params(self, inputs): - image = inputs[self.keys.index('image')] - hw = image.shape[:2] - params = {} - params['taget_size'] = self.reduce_to_scale(hw, self.size[::-1], - self.scale) - return params - - @staticmethod - def reduce_to_scale(img_hw, min_hw, scale): - im_h, im_w = img_hw - if im_h <= min_hw[0]: - im_h = min_hw[0] - else: - x = im_h % scale - im_h = im_h - x - - if im_w < min_hw[1]: - im_w = min_hw[1] - else: - y = im_w % scale - im_w = im_w - y - return (im_h, im_w) - - def _apply_image(self, image): - return F.resize(image, self.params['taget_size'], self.interpolation) diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py index d171b80c7239e991941fdb189960f512d5ac1d45..4b0b5ea9de65a96852e1b4d684a9824d55f0e454 100644 --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -165,6 +165,8 @@ class Trainer: iter_loader = IterLoader(self.train_dataloader) + # set model.is_train = True + self.model.setup_train_mode(is_train=True) while self.current_iter < (self.total_iters + 1): self.current_epoch = iter_loader.epoch self.inner_iter = self.current_iter % self.iters_per_epoch @@ -219,6 +221,9 @@ class Trainer: for metric in self.metrics.values(): metric.reset() + # set model.is_train = False + self.model.setup_train_mode(is_train=False) + for i in range(self.max_eval_steps): data = next(iter_loader) self.model.setup_input(data) @@ -289,7 +294,9 @@ class Trainer: message += 'ips: %.5f images/s ' % self.ips if hasattr(self, 'step_time'): - eta = self.step_time * (self.total_iters - self.current_iter - 1) + eta = self.step_time * (self.total_iters - self.current_iter) + eta = eta if eta > 0 else 0 + eta_str = str(datetime.timedelta(seconds=int(eta))) message += f'eta: {eta_str}' diff --git a/ppgan/models/animeganv2_model.py b/ppgan/models/animeganv2_model.py index 79914328d8c623a040333fdb0dbfd84638a0025b..7bceb36c70d0eff6daad7690b8418d94514e6b58 100644 --- a/ppgan/models/animeganv2_model.py +++ b/ppgan/models/animeganv2_model.py @@ -83,7 +83,7 @@ class AnimeGANV2Model(BaseModel): self.smooth_gray = paddle.to_tensor(input['smooth_gray']) else: self.real = paddle.to_tensor(input['A']) - self.image_paths = input['A_paths'] + self.image_paths = input['A_path'] def forward(self): """Run forward pass; called by both functions and .""" diff --git a/ppgan/models/dc_gan_model.py b/ppgan/models/dc_gan_model.py index b13e494af2d83d34aecc878ccaa8e505d7327796..220e05c0d0cf4ec690f4fabbbb3dd305c2b2f9ff 100644 --- a/ppgan/models/dc_gan_model.py +++ b/ppgan/models/dc_gan_model.py @@ -56,8 +56,9 @@ class DCGANModel(BaseModel): input (dict): include the data itself and its metadata information. """ # get 1-channel gray image, or 3-channel color image - self.real = paddle.to_tensor(input['A']) - self.image_paths = input['A_path'] + self.real = paddle.to_tensor(input['img']) + if 'img_path' in input: + self.image_paths = input['A_path'] def forward(self): """Run forward pass; called by both functions and .""" diff --git a/ppgan/models/pix2pix_model.py b/ppgan/models/pix2pix_model.py index 6aec7e30a512484cc86ba9b6896bd8d542c6848c..d2ec0de4e1437ecb4211f257f15997ff44cf3bf8 100644 --- a/ppgan/models/pix2pix_model.py +++ b/ppgan/models/pix2pix_model.py @@ -74,10 +74,8 @@ class Pix2PixModel(BaseModel): AtoB = self.direction == 'AtoB' - self.real_A = paddle.to_tensor( - input['A' if AtoB else 'B']) - self.real_B = paddle.to_tensor( - input['B' if AtoB else 'A']) + self.real_A = paddle.to_tensor(input['A' if AtoB else 'B']) + self.real_B = paddle.to_tensor(input['B' if AtoB else 'A']) self.image_paths = input['A_path' if AtoB else 'B_path'] @@ -141,3 +139,7 @@ class Pix2PixModel(BaseModel): optimizers['optimG'].clear_grad() self.backward_G() optimizers['optimG'].step() + + def test_iter(self, metrics=None): + with paddle.no_grad(): + self.forward() diff --git a/tools/extract_weight.py b/tools/extract_weight.py new file mode 100644 index 0000000000000000000000000000000000000000..56dffa9f1570fe8bed6bf9e510d9fec66a5de60c --- /dev/null +++ b/tools/extract_weight.py @@ -0,0 +1,40 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import argparse + + +def parse_args(): + parser = argparse.ArgumentParser( + description='This script extracts weights from a checkpoint') + parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument('--net-name', + type=str, + help='net name in checkpoint dict') + parser.add_argument('--output', type=str, help='destination file name') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + assert args.output.endswith(".pdparams") + ckpt = paddle.load(args.checkpoint) + state_dict = ckpt[args.net_name] + paddle.save(state_dict, args.output) + + +if __name__ == '__main__': + main()