提交 0970de7c 编写于 作者: L LielinJiang

fix save error

上级 abd3250d
......@@ -93,7 +93,7 @@ class BaseModel(ABC):
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
It also calls <compute_visuals> to produce additional visualization results
"""
with paddle.imperative.no_grad():
with paddle.no_grad():
self.forward()
self.compute_visuals()
......
......@@ -86,7 +86,7 @@ class Pix2PixModel(BaseModel):
self.fake_B = self.netG(self.real_A) # G(A)
def forward_test(self, input):
input = paddle.imperative.to_variable(input)
input = paddle.to_tensor(input)
return self.netG(input)
def backward_D(self):
......
......@@ -3,17 +3,19 @@ import six
import pickle
import paddle
def makedirs(dir):
if not os.path.exists(dir):
os.makedirs(dir)
def save(state_dicts, file_name):
def save(state_dicts, file_name):
def convert(state_dict):
model_dict = {}
for k, v in state_dict.items():
if isinstance(v, (paddle.framework.Variable, paddle.imperative.core.VarBase)):
if isinstance(
v, (paddle.framework.Variable, paddle.fluid.core.VarBase)):
model_dict[k] = v.numpy()
else:
model_dict[k] = v
......@@ -22,14 +24,15 @@ def save(state_dicts, file_name):
final_dict = {}
for k, v in state_dicts.items():
if isinstance(v, (paddle.framework.Variable, paddle.imperative.core.VarBase)):
if isinstance(v,
(paddle.framework.Variable, paddle.fluid.core.VarBase)):
final_dict = convert(state_dicts)
break
elif isinstance(v, dict):
final_dict[k] = convert(v)
else:
final_dict[k] = v
with open(file_name, 'wb') as f:
pickle.dump(final_dict, f, protocol=2)
......@@ -39,7 +42,3 @@ def load(file_name):
state_dicts = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
return state_dicts
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册