未验证 提交 9cdb6039 编写于 作者: W wangna11BD 提交者: GitHub

[to_static train]fix to static train bug for cyclegan model (#741)

* fix to static train bug for cyclegan model

* fix _ in log name
上级 ab31b18e
......@@ -122,6 +122,7 @@ class CycleGANModel(BaseModel):
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
if hasattr(self, 'real_A'):
self.real_A.stop_gradient = False
self.fake_B = self.nets['netG_A'](self.real_A) # G_A(A)
self.rec_A = self.nets['netG_B'](self.fake_B) # G_B(G_A(A))
......@@ -229,14 +230,13 @@ class CycleGANModel(BaseModel):
# forward
# compute fake images and reconstruction images.
self.forward()
# G_A and G_B
# Ds require no gradients when optimizing Gs
self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']],
False)
# set G_A and G_B's gradients to zero
optimizers['optimG'].clear_grad()
# calculate gradients for G_A and G_B
self.backward_G()
# G_A and G_B
# Ds require no gradients when optimizing Gs
self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']], False)
# update G_A and G_B's weights
self.optimizers['optimG'].step()
# D_A and D_B
......
......@@ -73,7 +73,7 @@ REST_ARGS=$4
to_static=""
# parse "to_static" options and modify trainer into "to_static_trainer"
if [ $REST_ARGS = "to_static" ] || [ $PARAMS = "to_static" ] ;then
to_static="d2sT"
to_static="d2sT_"
sed -i 's/trainer:norm_train/trainer:to_static_train/g' $FILENAME
# clear PARAM contents
if [ $PARAMS = "to_static" ] ;then
......@@ -175,7 +175,7 @@ for batch_size in ${batch_size_list[*]}; do
if [ ${#gpu_id} -le 1 ];then
log_path="$SAVE_LOG/profiling_log"
mkdir -p $log_path
log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_${to_static}_profiling"
log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_${to_static}profiling"
func_sed_params "$FILENAME" "${line_gpuid}" "0" # sed used gpu_id
# set profile_option params
tmp=`sed -i "${line_profile}s/.*/${profile_option}/" "${FILENAME}"`
......@@ -191,8 +191,8 @@ for batch_size in ${batch_size_list[*]}; do
speed_log_path="$SAVE_LOG/index"
mkdir -p $log_path
mkdir -p $speed_log_path
log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_${to_static}_log"
speed_log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_${to_static}_speed"
log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_${to_static}log"
speed_log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_${to_static}speed"
func_sed_params "$FILENAME" "${line_profile}" "null" # sed profile_id as null
cmd="bash test_tipc/test_train_inference_python.sh ${FILENAME} benchmark_train > ${log_path}/${log_name} 2>&1 "
echo $cmd
......@@ -226,8 +226,8 @@ for batch_size in ${batch_size_list[*]}; do
speed_log_path="$SAVE_LOG/index"
mkdir -p $log_path
mkdir -p $speed_log_path
log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_${to_static}_log"
speed_log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_${to_static}_speed"
log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_${to_static}log"
speed_log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_${to_static}speed"
func_sed_params "$FILENAME" "${line_gpuid}" "$gpu_id" # sed used gpu_id
func_sed_params "$FILENAME" "${line_profile}" "null" # sed --profile_option as null
cmd="bash test_tipc/test_train_inference_python.sh ${FILENAME} benchmark_train > ${log_path}/${log_name} 2>&1 "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册