Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
01cb542f
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
1 年多 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
01cb542f
编写于
2月 22, 2023
作者:
W
wangna11BD
提交者:
GitHub
2月 22, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support @to_static traing for edvr pix2pix and esrgan (#750)
上级
461bc8cd
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
28 addition
and
9 deletion
+28
-9
configs/edvr_m_wo_tsa.yaml
configs/edvr_m_wo_tsa.yaml
+2
-0
configs/esrgan_psnr_x4_div2k.yaml
configs/esrgan_psnr_x4_div2k.yaml
+2
-0
configs/pix2pix_facades.yaml
configs/pix2pix_facades.yaml
+2
-0
ppgan/models/edvr_model.py
ppgan/models/edvr_model.py
+6
-2
ppgan/models/pix2pix_model.py
ppgan/models/pix2pix_model.py
+8
-2
ppgan/models/sr_model.py
ppgan/models/sr_model.py
+5
-2
test_tipc/configs/Pix2pix/train_infer_python.txt
test_tipc/configs/Pix2pix/train_infer_python.txt
+1
-1
test_tipc/configs/edvr/train_infer_python.txt
test_tipc/configs/edvr/train_infer_python.txt
+1
-1
test_tipc/configs/esrgan/train_infer_python.txt
test_tipc/configs/esrgan/train_infer_python.txt
+1
-1
未找到文件。
configs/edvr_m_wo_tsa.yaml
浏览文件 @
01cb542f
...
...
@@ -24,6 +24,8 @@ model:
w_TSA
:
False
pixel_criterion
:
name
:
CharbonnierLoss
# training model under @to_static
to_static
:
False
export_model
:
-
{
name
:
'
generator'
,
inputs_num
:
1
}
...
...
configs/esrgan_psnr_x4_div2k.yaml
浏览文件 @
01cb542f
...
...
@@ -14,6 +14,8 @@ model:
nb
:
23
pixel_criterion
:
name
:
L1Loss
# training model under @to_static
to_static
:
False
export_model
:
-
{
name
:
'
generator'
,
inputs_num
:
1
}
...
...
configs/pix2pix_facades.yaml
浏览文件 @
01cb542f
...
...
@@ -24,6 +24,8 @@ model:
gan_criterion
:
name
:
GANLoss
gan_mode
:
vanilla
# training model under @to_static
to_static
:
False
dataset
:
train
:
...
...
ppgan/models/edvr_model.py
浏览文件 @
01cb542f
...
...
@@ -15,6 +15,7 @@
import
paddle
import
paddle.nn
as
nn
from
.base_model
import
apply_to_static
from
.builder
import
MODELS
from
.sr_model
import
BaseSRModel
from
.generators.edvr
import
ResidualBlockNoBN
,
DCNPack
...
...
@@ -28,7 +29,8 @@ class EDVRModel(BaseSRModel):
Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks.
"""
def
__init__
(
self
,
generator
,
tsa_iter
,
pixel_criterion
=
None
):
def
__init__
(
self
,
generator
,
tsa_iter
,
pixel_criterion
=
None
,
to_static
=
False
,
image_shape
=
None
):
"""Initialize the EDVR class.
Args:
...
...
@@ -36,7 +38,9 @@ class EDVRModel(BaseSRModel):
tsa_iter (dict): config of tsa_iter.
pixel_criterion (dict): config of pixel criterion.
"""
super
(
EDVRModel
,
self
).
__init__
(
generator
,
pixel_criterion
)
super
(
EDVRModel
,
self
).
__init__
(
generator
,
pixel_criterion
,
to_static
=
to_static
,
image_shape
=
image_shape
)
self
.
tsa_iter
=
tsa_iter
self
.
current_iter
=
1
init_edvr_weight
(
self
.
nets
[
'generator'
])
...
...
ppgan/models/pix2pix_model.py
浏览文件 @
01cb542f
...
...
@@ -13,7 +13,7 @@
# limitations under the License.
import
paddle
from
.base_model
import
BaseModel
from
.base_model
import
BaseModel
,
apply_to_static
from
.builder
import
MODELS
from
.generators.builder
import
build_generator
...
...
@@ -36,7 +36,9 @@ class Pix2PixModel(BaseModel):
discriminator
=
None
,
pixel_criterion
=
None
,
gan_criterion
=
None
,
direction
=
'a2b'
):
direction
=
'a2b'
,
to_static
=
False
,
image_shape
=
None
):
"""Initialize the pix2pix class.
Args:
...
...
@@ -51,11 +53,15 @@ class Pix2PixModel(BaseModel):
# define networks (both generator and discriminator)
self
.
nets
[
'netG'
]
=
build_generator
(
generator
)
init_weights
(
self
.
nets
[
'netG'
])
# set @to_static for benchmark, skip this by default.
apply_to_static
(
to_static
,
image_shape
,
self
.
nets
[
'netG'
])
# define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
if
discriminator
:
self
.
nets
[
'netD'
]
=
build_discriminator
(
discriminator
)
init_weights
(
self
.
nets
[
'netD'
])
# set @to_static for benchmark, skip this by default.
apply_to_static
(
to_static
,
image_shape
,
self
.
nets
[
'netD'
])
if
pixel_criterion
:
self
.
pixel_criterion
=
build_criterion
(
pixel_criterion
)
...
...
ppgan/models/sr_model.py
浏览文件 @
01cb542f
...
...
@@ -17,7 +17,7 @@ import paddle.nn as nn
from
.generators.builder
import
build_generator
from
.criterions.builder
import
build_criterion
from
.base_model
import
BaseModel
from
.base_model
import
BaseModel
,
apply_to_static
from
.builder
import
MODELS
from
..utils.visual
import
tensor2img
from
..modules.init
import
reset_parameters
...
...
@@ -28,7 +28,8 @@ class BaseSRModel(BaseModel):
"""Base SR model for single image super-resolution.
"""
def
__init__
(
self
,
generator
,
pixel_criterion
=
None
,
use_init_weight
=
False
):
def
__init__
(
self
,
generator
,
pixel_criterion
=
None
,
use_init_weight
=
False
,
to_static
=
False
,
image_shape
=
None
):
"""
Args:
generator (dict): config of generator.
...
...
@@ -37,6 +38,8 @@ class BaseSRModel(BaseModel):
super
(
BaseSRModel
,
self
).
__init__
()
self
.
nets
[
'generator'
]
=
build_generator
(
generator
)
# set @to_static for benchmark, skip this by default.
apply_to_static
(
to_static
,
image_shape
,
self
.
nets
[
'generator'
])
if
pixel_criterion
:
self
.
pixel_criterion
=
build_criterion
(
pixel_criterion
)
...
...
test_tipc/configs/Pix2pix/train_infer_python.txt
浏览文件 @
01cb542f
...
...
@@ -17,7 +17,7 @@ norm_train:tools/main.py -c configs/pix2pix_facades.yaml --seed 123 -o log_confi
pact_train:null
fpgm_train:null
distill_train:null
null:null
to_static_train:model.to_static=True
null:null
##
===========================eval_params===========================
...
...
test_tipc/configs/edvr/train_infer_python.txt
浏览文件 @
01cb542f
...
...
@@ -17,7 +17,7 @@ norm_train:tools/main.py -c configs/edvr_m_wo_tsa.yaml --seed 123 -o log_config.
pact_train:null
fpgm_train:null
distill_train:null
null:null
to_static_train:model.to_static=True
null:null
##
===========================eval_params===========================
...
...
test_tipc/configs/esrgan/train_infer_python.txt
浏览文件 @
01cb542f
...
...
@@ -17,7 +17,7 @@ norm_train:tools/main.py -c configs/esrgan_psnr_x4_div2k.yaml --seed 123 -o log_
pact_train:null
fpgm_train:null
distill_train:null
null:null
to_static_train:model.to_static=True
null:null
##
===========================eval_params===========================
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录