提交 0970de7c 编写于 作者: L LielinJiang

fix save error

上级 abd3250d
...@@ -93,7 +93,7 @@ class BaseModel(ABC): ...@@ -93,7 +93,7 @@ class BaseModel(ABC):
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
It also calls <compute_visuals> to produce additional visualization results It also calls <compute_visuals> to produce additional visualization results
""" """
with paddle.imperative.no_grad(): with paddle.no_grad():
self.forward() self.forward()
self.compute_visuals() self.compute_visuals()
......
...@@ -86,7 +86,7 @@ class Pix2PixModel(BaseModel): ...@@ -86,7 +86,7 @@ class Pix2PixModel(BaseModel):
self.fake_B = self.netG(self.real_A) # G(A) self.fake_B = self.netG(self.real_A) # G(A)
def forward_test(self, input): def forward_test(self, input):
input = paddle.imperative.to_variable(input) input = paddle.to_tensor(input)
return self.netG(input) return self.netG(input)
def backward_D(self): def backward_D(self):
......
...@@ -3,17 +3,19 @@ import six ...@@ -3,17 +3,19 @@ import six
import pickle import pickle
import paddle import paddle
def makedirs(dir): def makedirs(dir):
if not os.path.exists(dir): if not os.path.exists(dir):
os.makedirs(dir) os.makedirs(dir)
def save(state_dicts, file_name):
def save(state_dicts, file_name):
def convert(state_dict): def convert(state_dict):
model_dict = {} model_dict = {}
for k, v in state_dict.items(): for k, v in state_dict.items():
if isinstance(v, (paddle.framework.Variable, paddle.imperative.core.VarBase)): if isinstance(
v, (paddle.framework.Variable, paddle.fluid.core.VarBase)):
model_dict[k] = v.numpy() model_dict[k] = v.numpy()
else: else:
model_dict[k] = v model_dict[k] = v
...@@ -22,14 +24,15 @@ def save(state_dicts, file_name): ...@@ -22,14 +24,15 @@ def save(state_dicts, file_name):
final_dict = {} final_dict = {}
for k, v in state_dicts.items(): for k, v in state_dicts.items():
if isinstance(v, (paddle.framework.Variable, paddle.imperative.core.VarBase)): if isinstance(v,
(paddle.framework.Variable, paddle.fluid.core.VarBase)):
final_dict = convert(state_dicts) final_dict = convert(state_dicts)
break break
elif isinstance(v, dict): elif isinstance(v, dict):
final_dict[k] = convert(v) final_dict[k] = convert(v)
else: else:
final_dict[k] = v final_dict[k] = v
with open(file_name, 'wb') as f: with open(file_name, 'wb') as f:
pickle.dump(final_dict, f, protocol=2) pickle.dump(final_dict, f, protocol=2)
...@@ -39,7 +42,3 @@ def load(file_name): ...@@ -39,7 +42,3 @@ def load(file_name):
state_dicts = pickle.load(f) if six.PY2 else pickle.load( state_dicts = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1') f, encoding='latin1')
return state_dicts return state_dicts
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册