test_cond_op.py 3.2 KB
Newer Older
Z
cond op  
zchen0211 已提交
1
import logging
Q
Qiao Longfei 已提交
2
import paddle.v2.fluid.core as core
Z
cond op  
zchen0211 已提交
3 4
import unittest
import numpy as np
Q
Qiao Longfei 已提交
5
from paddle.v2.fluid.op import Operator, CondOp
Z
cond op  
zchen0211 已提交
6 7 8 9 10 11 12 13


class PySimpleCond(object):
    '''
    A simple NumPy reference implementation of a dynamic if-else.

    ``cond`` alternates 1, 0, 1, 0, ...  Rows of ``x`` whose condition is 1
    take the "true" branch (multiplied by 2); rows whose condition is 0 take
    the "false" branch (multiplied by -2).
    '''

    def __init__(self, size=10):
        '''
        :param size: number of rows in ``x`` and entries in ``cond``.
            Defaults to 10, the size the tests in this file use.
        '''
        # Even positions are 1 (true branch), odd positions are 0 (false).
        self.cond = 1 - np.arange(size) % 2
        self.x = np.ones(shape=(size, 1)).astype("float32")

    def forward(self):
        '''
        Compute the reference if-else output.

        Side effect: stores the row indices of each branch (as returned by
        ``np.where``) on ``self.index_t`` and ``self.index_f``.

        :return: array with the same shape and dtype as ``self.x``.
        '''
        self.index_t = np.where(self.cond == 1)
        self.index_f = np.where(self.cond == 0)
        # Gather each branch's rows, apply that branch's scale, scatter back.
        y_t = self.x[self.index_t] * 2.
        y_f = self.x[self.index_f] * (-2.)
        # Match the input dtype (float32) rather than defaulting to float64.
        output = np.zeros_like(self.x)
        output[self.index_t] = y_t
        output[self.index_f] = y_f
        return output


class PySimpleCondTest(unittest.TestCase):
    '''
    Sanity check for the NumPy reference implementation ``PySimpleCond``.
    '''

    def setUp(self):
        self.condnn = PySimpleCond()

    def test_forward(self):
        output = self.condnn.forward()
        # Recompute the expected result independently with a vectorized
        # select: rows with cond == 1 are scaled by 2, the others by -2.
        x = self.condnn.x
        is_true = (self.condnn.cond == 1).reshape(-1, 1)
        expected = np.where(is_true, x * 2., x * -2.)
        self.assertEqual(output.shape, x.shape)
        np.testing.assert_allclose(output, expected)


def create_tensor(scope, name, shape, np_data):
    '''
    Obtain variable ``name`` via ``scope.var`` and fill its tensor.

    :param scope: scope that owns the variable.
    :param name: variable name.
    :param shape: dimensions passed to the tensor's ``set_dims``.
    :param np_data: numpy array copied into the tensor on ``core.CPUPlace()``.
    :return: the tensor obtained from the variable.
    '''
    variable = scope.var(name)
    holder = variable.get_tensor()
    holder.set_dims(shape)
    cpu_place = core.CPUPlace()
    holder.set(np_data, cpu_place)
    return holder


class TestCondOp(unittest.TestCase):
    '''
    Test CondOp

    equation:
        cond = [True, False, True, False, ...]
        y[index_t] = x[index_t] * 2.
        y[index_f] = x[index_f] * -2.
    outputs:
        y
    '''

    def setUp(self):
        self.py_cond = PySimpleCond()

    def forward(self):
        '''
        Build a fresh scope, wire up CondOp with its two sub-nets, run it on
        the CPU and return the "Out" variable as a numpy array.
        '''
        self.scope = core.Scope()
        self.create_global_variables()
        self.create_cond_op()
        self.create_sub_net()
        self.condop.run(self.scope, core.CPUPlace())
        return np.array(self.scope.find_var("Out").get_tensor())

    def create_global_variables(self):
        '''
        Feed X and cond from the NumPy reference, and declare the variables
        CondOp is given as SubScopes, IndexTensors and Out.
        '''
        x_np_data = self.py_cond.x
        create_tensor(self.scope, "X", list(x_np_data.shape), x_np_data)
        # cond is 1-D in the reference; it is fed as a column with one row
        # per row of X.
        cond_np_data = self.py_cond.cond.astype("int32")
        create_tensor(self.scope, "cond", [len(cond_np_data), 1],
                      cond_np_data)
        self.scope.var("SubScopes")
        self.scope.var("IndexTensors")
        self.scope.var("Out")

    def create_cond_op(self):
        self.condop = CondOp(
            Cond="cond",
            Xs=["X"],
            Outs=["Out"],
            SubScopes="SubScopes",
            IndexTensors="IndexTensors")

    @staticmethod
    def _make_scale_net(scale):
        '''
        Return a completed Net holding a single ``scale`` op that reads X
        and writes Out with the given ``scale`` attribute.
        '''
        net = core.Net.create()
        net.append_op(Operator("scale", X='X', Out='Out', scale=scale))
        net.complete_add_op(True)
        return net

    def create_sub_net(self):
        # True branch: y = x * 2.  False branch: y = x * -2.
        self.condop.set_truenet(self._make_scale_net(2.))
        self.condop.set_falsenet(self._make_scale_net(-2.))

    def test_forward(self):
        logging.info('test cond op forward')
        pd_output = self.forward()
        py_output = self.py_cond.forward()
        logging.info('pd_output %s', pd_output)
        logging.info('py_output %s', py_output)
        self.assertEqual(pd_output.shape, py_output.shape)
        # Matching shapes alone cannot show that each branch ran on the
        # right rows, so compare the values as well.
        np.testing.assert_allclose(pd_output, py_output, rtol=1e-6)
Z
cond op  
zchen0211 已提交
111 112 113


if __name__ == "__main__":
    import sys
    # FIXME(qijun): https://github.com/PaddlePaddle/Paddle/issues/5101#issuecomment-339814957
    # Running this file as a script is disabled until the issue above is
    # resolved; the unittest.main() call below is intentionally unreachable.
    # sys.exit is used because the builtin exit() is injected by the `site`
    # module for interactive use and is unavailable under `python -S`.
    sys.exit(0)
    unittest.main()