Unverified commit 7ca836d3, authored by zhongpu, committed by GitHub

support if logic for Variable in dygraph (#22892)

* support if logic for Variable in dygraph, test=develop

* fix test_learning_rate_scheduler.py, test=develop

* fix optest, test=develop

* fix error message, test=develop

* fix optest,test=develop

* fix comment, test=develop
Parent 166a1ae9
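This commit monkey-patches `__bool__` (Python 3) and `__nonzero__` (Python 2) onto `core.VarBase`, so a single-element dygraph Variable can be used directly as an `if`/`while` condition; as a consequence, every place that previously relied on a Variable's default object truthiness is rewritten as an explicit `is not None` check. A minimal sketch of the newly supported usage, assuming a build with this patch applied:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    flag = fluid.dygraph.to_variable(np.array([1.0], dtype='float32'))
    if flag:  # now legal: numel == 1 and the value is positive
        print('condition held')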
@@ -1661,7 +1661,7 @@ class GRUUnit(layers.Layer):
             'HiddenPrev': [hidden],
             'Weight': [self.weight]
         }
-        if self.bias:
+        if self.bias is not None:
             inputs['Bias'] = [self.bias]
         attrs = {
             'activation': self.activation,
@@ -2114,7 +2114,7 @@ class BilinearTensorProduct(layers.Layer):
 
     def forward(self, x, y):
         self._inputs = {"X": x, "Y": y, "Weight": self.weight}
-        if self.bias:
+        if self.bias is not None:
             self._inputs["Bias"] = self.bias
         if self._name is not None:
             out = self._helper.create_variable(
......
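In the two layer changes above, `self.bias` is an optional multi-element Parameter, so the old `if self.bias:` would now truth-test the tensor and trip the new single-element assert; `is not None` expresses the intended presence check. A hedged sketch, assuming the patched build:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    bias = fluid.dygraph.to_variable(np.zeros([4], dtype='float32'))
    if bias is not None:  # identity test: never calls __bool__
        print('bias present')
    # `if bias:` would call the new __bool__ and fail its numel == 1 assert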
@@ -420,7 +420,7 @@ class DataParallel(layers.Layer):
         grad_vars = []
         for param in self._layers.parameters():
             # NOTE(zcd): The grad_ivar maybe no generated.
-            if param.trainable and param._grad_ivar():
+            if param.trainable and (param._grad_ivar() is not None):
                 g_var = param._grad_ivar()
                 grad_vars.append(g_var)
                 assert g_var not in grad_var_set
......
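The same reasoning applies in DataParallel: `param._grad_ivar()` returns None when no gradient has been generated, and a real gradient tensor usually holds many elements, so an implicit truthiness test would trip the single-element assert; the parenthesized `is not None` keeps the original presence check.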
@@ -204,9 +204,20 @@ def monkey_patch_varbase():
         return 'name %s, shape: %s, not inited' % (self.name,
                                                    self.shape)
 
-    for method_name, method in (("set_value", set_value), ("block", block),
-                                ("backward", backward), ("gradient", gradient),
-                                ("__str__", __str__), ("to_string", to_string)):
+    def __nonzero__(self):
+        numel = np.prod(self.shape)
+        assert numel == 1, "When Variable is used as the condition of if/while , Variable can only contain one element."
+        tensor = self.value().get_tensor()
+        assert tensor._is_initialized(), "tensor not initialized"
+        return bool(np.all(tensor.__array__() > 0))
+
+    def __bool__(self):
+        return self.__nonzero__()
+
+    for method_name, method in (
+        ("__bool__", __bool__), ("__nonzero__", __nonzero__),
+        ("set_value", set_value), ("block", block), ("backward", backward),
+        ("gradient", gradient), ("__str__", __str__), ("to_string", to_string)):
         setattr(core.VarBase, method_name, method)
     # patch math methods for varbase
......
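For readers unfamiliar with the pattern above: the truth-test hooks are defined as plain functions and then attached to the existing `core.VarBase` class with `setattr`; `__nonzero__` is the Python 2 hook and `__bool__` the Python 3 one, so both are registered. A minimal, self-contained sketch with a hypothetical `Tensor` stand-in for `core.VarBase`:

import numpy as np

class Tensor(object):  # hypothetical stand-in for core.VarBase
    def __init__(self, data):
        self.data = np.asarray(data)

def __nonzero__(self):
    assert self.data.size == 1, "condition tensors must hold exactly one element"
    return bool(np.all(self.data > 0))

def __bool__(self):  # Python 3 delegates to the Python 2 hook
    return self.__nonzero__()

# attach the hooks to the existing class, as the patch does for VarBase
for name, method in (("__bool__", __bool__), ("__nonzero__", __nonzero__)):
    setattr(Tensor, name, method)

print(bool(Tensor([2.0])))  # True
print(bool(Tensor([0.0])))  # False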
@@ -55,7 +55,7 @@ def _append_bias_in_dygraph(input, bias=None, axis=1):
         Return the Variable after bias operation
     """
-    if not bias:
+    if bias is None:
         return input
     return core.ops.elementwise_add(input, bias, 'axis', axis)
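Here `bias is None` matters even for one-element tensors: under the new `__bool__`, `not bias` inspects the value, so a legitimately all-zero bias would evaluate falsy and the add would be silently skipped. A hedged sketch, assuming the patched build:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    zero_bias = fluid.dygraph.to_variable(np.array([0.0], dtype='float32'))
    print(not zero_bias)      # True under the new __bool__: the old check would skip the add
    print(zero_bias is None)  # False: the new check correctly keeps the add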
@@ -182,8 +182,8 @@ class NetWithControlFlowIf(fluid.dygraph.Layer):
 
 def if_with_and_or(x_v, label=None):
     batch_size = fluid.layers.shape(x_v)
-    if x_v and (fluid.layers.mean(x_v).numpy()[0] > 0 or
-                label is not None) and batch_size[0] > 1 and True:
+    if x_v is not None and (fluid.layers.mean(x_v).numpy()[0] > 0 or
+                            label is not None) and batch_size[0] > 1 and True:
         x_v = x_v - 1
     else:
         x_v = x_v + 1
@@ -198,16 +198,16 @@ def if_with_and_or_1(x, y=None):
     batch_size = fluid.layers.shape(x)
     if batch_size[0] > 1 and y is not None:
         x = x + 1
-    if y or batch_size[0] > 1:
+    if y is not None or batch_size[0] > 1:
         x = x - 1
     return x
 
 
 def if_with_and_or_2(x, y=None):
     batch_size = fluid.layers.shape(x)
-    if x and batch_size[0] > 1 and y is not None:
+    if x is not None and batch_size[0] > 1 and y is not None:
         x = x + 1
-    if batch_size[0] > 1 or y or x is not None:
+    if batch_size[0] > 1 or y is not None or x is not None:
         x = x - 1
     return x
@@ -215,9 +215,10 @@ def if_with_and_or_2(x, y=None):
 def if_with_and_or_3(x, y=None):
     batch_size = fluid.layers.shape(x)
     mean_res = fluid.layers.mean(x)
-    if x and batch_size[0] > 1 and y is not None and mean_res.numpy()[0] > 0:
+    if x is not None and batch_size[0] > 1 and y is not None and mean_res.numpy(
+    )[0] > 0:
         x = x + 1
-    if mean_res.numpy()[0] > 0 and (x and batch_size[0] > 1) and y:
+    if mean_res.numpy()[0] > 0 and (x is not None and batch_size[0] > 1) and y:
         x = x - 1
     return x
@@ -225,8 +226,10 @@ def if_with_and_or_3(x, y=None):
 def if_with_and_or_4(x, y=None):
     batch_size = fluid.layers.shape(x)
     mean_res = fluid.layers.mean(x)
-    if (x and batch_size[0] > 1) or (y is not None and mean_res.numpy()[0] > 0):
+    if (x is not None and batch_size[0] > 1) or (y is not None and
+                                                 mean_res.numpy()[0] > 0):
         x = x + 1
-    if (x or batch_size[0] > 1) and (y is not None or mean_res.numpy()[0] > 0):
+    if (x is not None or batch_size[0] > 1) and (y is not None or
+                                                 mean_res.numpy()[0] > 0):
         x = x - 1
     return x
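Each bare Variable in the `and`/`or` chains above becomes an explicit `is not None` test because Python's short-circuit operators truth-test their operands, which now invokes the single-element assert on batch-shaped tensors. A pure-Python sketch of the failure mode, with a hypothetical `FakeVar` mimicking the patched behavior:

import numpy as np

class FakeVar(object):  # hypothetical mimic of the patched VarBase
    def __init__(self, arr):
        self.arr = np.asarray(arr)

    def __bool__(self):
        assert self.arr.size == 1, "Variable can only contain one element"
        return bool(np.all(self.arr > 0))

x = FakeVar([[1.0, 2.0]])  # batch-shaped: more than one element
try:
    result = x and True  # `and` truth-tests x, raising AssertionError
except AssertionError as err:
    print("rejected:", err)
print(x is not None and True)  # identity test never calls __bool__; prints True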
@@ -100,7 +100,7 @@ def dyfunc_ifExp_with_while(x):
     def body(i, ten, y):
         # It will be converted into `layers.cond` as followed.
         # map_func(lambda x: fluid.layers.cond(i==0, lambda: x, lambda: add_fn(x), y)
-        y = map_func(lambda x: x if i == 0 else add_fn(x), y)
+        y = map_func(lambda x: x if (i == 0) is not None else add_fn(x), y)
         i += 1
         return [i, ten, y]
......
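The comment in the hunk above refers to the `layers.cond` form that the dygraph-to-static converter produces. A hedged sketch of that form, assuming `fluid.layers.cond(pred, true_fn, false_fn)` and that `i` is an int64 scalar Variable (names here are illustrative, not from the commit):

import paddle.fluid as fluid

def add_fn(x):
    return x + 1

def cond_select(i, x):
    zero = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
    pred = fluid.layers.equal(i, zero)  # one-element bool Variable
    return fluid.layers.cond(pred, lambda: x, lambda: add_fn(x))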
@@ -50,7 +50,7 @@ class PrePostProcessLayer(Layer):
         self.functors = []
         for cmd in self.process_cmd:
             if cmd == "a":  # add residual connection
-                self.functors.append(lambda x, y: x + y if y else x)
+                self.functors.append(lambda x, y: x + y if y is not None else x)
             elif cmd == "n":  # add layer normalization
                 self.functors.append(
                     self.add_sublayer(
@@ -118,7 +118,7 @@ class MultiHeadAttention(Layer):
         # scale dot product attention
         product = layers.matmul(
             x=q, y=k, transpose_y=True, alpha=self.d_model**-0.5)
-        if attn_bias:
+        if attn_bias is not None:
             product += attn_bias
         weights = layers.softmax(product)
         if self.dropout_rate:
......
@@ -364,7 +364,7 @@ class PrePostProcessLayer(Layer):
     def forward(self, prev_out, out, process_cmd, dropout_rate=0.):
         for cmd in process_cmd:
             if cmd == "a":  # add residual connection
-                out = out + prev_out if prev_out else out
+                out = out + prev_out if prev_out is not None else out
             elif cmd == "n":  # add layer normalization
                 out = self._layer_norm(out)
             elif cmd == "d":  # add dropout
@@ -443,7 +443,7 @@ class MultiHeadAttentionLayer(Layer):
             y=transpose_k,
             transpose_y=True,
             alpha=self._d_model**-0.5)
-        if attn_bias:
+        if attn_bias is not None:
             product += attn_bias
         weights = fluid.layers.softmax(product)
         if self._dropout_rate:
......
@@ -16,6 +16,7 @@ from __future__ import print_function
 
 import copy
 import math
 import numpy as np
+import unittest
 
 import paddle.fluid as fluid
@@ -319,7 +320,8 @@ class TestLinearWamrupLearningRateDecayDygraphMode(unittest.TestCase):
             t = lr()
-            self.assertEqual(t[0], right_result[i])
+            self.assertTrue(
+                np.allclose((t.numpy())[0].item(), right_result[i]))
 
 
 class TestLinearWamrupLearningRateDecayDygraphModeTypeCheck(unittest.TestCase):
......
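The scheduler test now extracts a plain Python float with `t.numpy()[0].item()` and compares it with `np.allclose` rather than calling `assertEqual` on `t[0]`: indexing a Variable yields another Variable, and exact equality on floats is fragile anyway. A tiny sketch of the comparison idiom, with a hypothetical value:

import numpy as np

t = np.array([0.0999999978], dtype='float32')  # hypothetical lr() output
assert np.allclose(t[0].item(), 0.1)  # tolerant scalar comparison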
@@ -189,6 +189,25 @@ class TestVarBase(unittest.TestCase):
             np.array_equal(var.numpy(),
                            fluid.framework._var_base_to_np(var)))
 
+    def test_if(self):
+        with fluid.dygraph.guard():
+            var1 = fluid.dygraph.to_variable(np.array([[[0]]]))
+            var2 = fluid.dygraph.to_variable(np.array([[[1]]]))
+
+            var1_bool = False
+            var2_bool = False
+
+            if var1:
+                var1_bool = True
+
+            if var2:
+                var2_bool = True
+
+            assert var1_bool == False, "if var1 should be false"
+            assert var2_bool == True, "if var2 should be true"
+            assert bool(var1) == False, "bool(var1) is False"
+            assert bool(var2) == True, "bool(var2) is True"
+
 
 if __name__ == '__main__':
     unittest.main()
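The test exercises the supported single-element path; a hedged companion sketch of the guarded failure mode, a multi-element Variable used as a condition, assuming the patched build:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    multi = fluid.dygraph.to_variable(np.array([1, 2]))
    try:
        if multi:  # numel != 1 triggers the AssertionError from __nonzero__
            pass
    except AssertionError as err:
        print(err)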