未验证 · 提交 a82c9df4 · 作者: Bai Yifan · 提交者: GitHub

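fix assign: replace direct tensor writes (`x.value().get_tensor().set(y.numpy(), self.place)`) with `fluid.layers.assign(y.detach(), x)` and remove the now-unused `place` argument of `Architect` (#265)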

上级 db7329bc
...@@ -21,8 +21,7 @@ from paddle.fluid.dygraph.base import to_variable ...@@ -21,8 +21,7 @@ from paddle.fluid.dygraph.base import to_variable
class Architect(object): class Architect(object):
def __init__(self, model, eta, arch_learning_rate, place, unrolled, def __init__(self, model, eta, arch_learning_rate, unrolled, parallel):
parallel):
self.network_momentum = 0.9 self.network_momentum = 0.9
self.network_weight_decay = 3e-4 self.network_weight_decay = 3e-4
self.eta = eta self.eta = eta
...@@ -33,7 +32,6 @@ class Architect(object): ...@@ -33,7 +32,6 @@ class Architect(object):
0.999, 0.999,
regularization=fluid.regularizer.L2Decay(1e-3), regularization=fluid.regularizer.L2Decay(1e-3),
parameter_list=self.model.arch_parameters()) parameter_list=self.model.arch_parameters())
self.place = place
self.unrolled = unrolled self.unrolled = unrolled
self.parallel = parallel self.parallel = parallel
if self.unrolled: if self.unrolled:
...@@ -110,13 +108,14 @@ class Architect(object): ...@@ -110,13 +108,14 @@ class Architect(object):
target_train) target_train)
for (p, g), ig in zip(arch_params_grads, implicit_grads): for (p, g), ig in zip(arch_params_grads, implicit_grads):
new_g = g - (ig * self.unrolled_optimizer.current_step_lr()) new_g = g - (ig * self.unrolled_optimizer.current_step_lr())
g.value().get_tensor().set(new_g.numpy(), self.place) fluid.layers.assign(new_g.detach(), g)
return arch_params_grads return arch_params_grads
def _compute_unrolled_model(self, input, target): def _compute_unrolled_model(self, input, target):
for x, y in zip(self.unrolled_model.parameters(), for x, y in zip(self.unrolled_model.parameters(),
self.model.parameters()): self.model.parameters()):
x.value().get_tensor().set(y.numpy(), self.place) fluid.layers.assign(y.detach(), x)
loss = self.unrolled_model._loss(input, target) loss = self.unrolled_model._loss(input, target)
if self.parallel: if self.parallel:
loss = self.parallel_unrolled_model.scale_loss(loss) loss = self.parallel_unrolled_model.scale_loss(loss)
...@@ -141,7 +140,7 @@ class Architect(object): ...@@ -141,7 +140,7 @@ class Architect(object):
] ]
for param, grad in zip(model_params, vector): for param, grad in zip(model_params, vector):
param_p = param + grad * R param_p = param + grad * R
param.value().get_tensor().set(param_p.numpy(), self.place) fluid.layers.assign(param_p.detach(), param)
loss = self.model._loss(input, target) loss = self.model._loss(input, target)
if self.parallel: if self.parallel:
loss = self.parallel_model.scale_loss(loss) loss = self.parallel_model.scale_loss(loss)
...@@ -157,7 +156,7 @@ class Architect(object): ...@@ -157,7 +156,7 @@ class Architect(object):
for param, grad in zip(model_params, vector): for param, grad in zip(model_params, vector):
param_n = param - grad * R * 2 param_n = param - grad * R * 2
param.value().get_tensor().set(param_n.numpy(), self.place) fluid.layers.assign(param_n.detach(), param)
self.model.clear_gradients() self.model.clear_gradients()
loss = self.model._loss(input, target) loss = self.model._loss(input, target)
...@@ -174,7 +173,7 @@ class Architect(object): ...@@ -174,7 +173,7 @@ class Architect(object):
] ]
for param, grad in zip(model_params, vector): for param, grad in zip(model_params, vector):
param_o = param + grad * R param_o = param + grad * R
param.value().get_tensor().set(param_o.numpy(), self.place) fluid.layers.assign(param_o.detach(), param)
self.model.clear_gradients() self.model.clear_gradients()
arch_grad = [(p - n) / (2 * R) for p, n in zip(grads_p, grads_n)] arch_grad = [(p - n) / (2 * R) for p, n in zip(grads_p, grads_n)]
return arch_grad return arch_grad
...@@ -227,7 +227,6 @@ class DARTSearch(object): ...@@ -227,7 +227,6 @@ class DARTSearch(object):
model=self.model, model=self.model,
eta=learning_rate, eta=learning_rate,
arch_learning_rate=self.arch_learning_rate, arch_learning_rate=self.arch_learning_rate,
place=self.place,
unrolled=self.unrolled, unrolled=self.unrolled,
parallel=self.use_data_parallel) parallel=self.use_data_parallel)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请先注册