diff --git a/python/paddle/v2/framework/tests/test_cond_op.py b/python/paddle/v2/framework/tests/test_cond_op.py index 3698ce9c8ed5c021826af622a53ee742e9b22552..76323b5e10c59822b4de82a70ebd57b3e57c8392 100644 --- a/python/paddle/v2/framework/tests/test_cond_op.py +++ b/python/paddle/v2/framework/tests/test_cond_op.py @@ -15,7 +15,7 @@ class PySimpleCond(object): for i in range(1, 10, 2): array[i] = 0 self.cond = np.array(array) - self.x = np.ones(shape=(10, 1)) + self.x = np.ones(shape=(10, 1)).astype("float32") def forward(self): self.index_t = np.where(self.cond == 1)