未验证 提交 38d2d13d 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2stat] Add CycleGAN model for unittest (#25072)

* add cycle_gan_model

* align train test=develop

* modify image_size into 64 to avoid TimeOut test=develop

* TODO in GPU test=develop
上级 ff5be2fb
...@@ -110,11 +110,13 @@ class PartialProgramLayer(layers.Layer): ...@@ -110,11 +110,13 @@ class PartialProgramLayer(layers.Layer):
self._inputs = NestSequence(inputs) self._inputs = NestSequence(inputs)
self._outputs = NestSequence(outputs, need_check=True) self._outputs = NestSequence(outputs, need_check=True)
self._params = parameters if parameters is not None else [] self._params = parameters if parameters is not None else []
# Check all params from main program can be found in self._params: # Check all params from main program can be found in self._params:
# 1. parameter in self._params should be type `framework.ParamBase` which are created in dygraph. # 1. parameter in self._params should be type `framework.ParamBase` which are created in dygraph.
# 2. parameter from transformed program shall be found in self._params. # 2. parameter from transformed program shall be found in self._params.
# Because they share same data with ParamBase of original dygraph. # Because they share same data with ParamBase of original dygraph.
self._check_params_all_inited(main_program) self._check_params_all_inited(main_program)
self._prune_unused_params(main_program)
self._infer_program = main_program self._infer_program = main_program
self._train_program = self._append_backward_desc() self._train_program = self._append_backward_desc()
...@@ -138,6 +140,23 @@ class PartialProgramLayer(layers.Layer): ...@@ -138,6 +140,23 @@ class PartialProgramLayer(layers.Layer):
return program return program
def _prune_unused_params(self, program):
"""
Prune the parameters not used anywhere in the program.
The `@declarative` may only decorated a sub function which
contains some unused parameters created in `__init__`.
So prune these parameters to avoid unnecessary operations in
`run_program_op`.
"""
required_params = []
for param in self._params:
for block in program.blocks:
if param.name in block.vars:
required_params.append(param)
break
self._params = required_params
def train(self): def train(self):
# self.training is inherited from layers.Layer # self.training is inherited from layers.Layer
self.training = True self.training = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册