提交 12aa35f4 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix ones

上级 c7c324d8
......@@ -48,13 +48,17 @@ class GANLoss(nn.Layer):
"""
if target_is_real:
if not hasattr(self, 'target_real_tensor'):
self.target_real_tensor = self.target_real_label * paddle.ones(
paddle.shape(prediction), dtype='float32')
self.target_real_tensor = paddle.fill_constant(
shape=paddle.shape(prediction),
value=self.target_real_label,
dtype='float32')
target_tensor = self.target_real_tensor
else:
if not hasattr(self, 'target_fake_tensor'):
self.target_fake_tensor = self.target_fake_label * paddle.ones(
paddle.shape(prediction), dtype='float32')
self.target_fake_tensor = paddle.fill_constant(
shape=paddle.shape(prediction),
value=self.target_fake_label,
dtype='float32')
target_tensor = self.target_fake_tensor
# target_tensor.stop_gradient = True
......
......@@ -80,7 +80,7 @@ def calculate_gain(nonlinearity, param=None):
@paddle.no_grad()
def constant_(x, value):
temp_value = value * paddle.ones(x.shape, x.dtype)
temp_value = paddle.full(shape=x.shape, fill_value=value, dtype=x.dtype)
x.set_value(temp_value)
return x
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册