# test_cond_op.py (3.1 KB)
# Commit "cond op", committed by zchen0211
import logging
import paddle.v2.framework.core as core
import unittest
import numpy as np
from paddle.v2.framework.op import Operator, CondOp


class PySimpleCond(object):
    '''
    A simple implementation of dynamic if-else based on numpy.

    Rows of ``x`` whose ``cond`` entry is 1 are multiplied by 2, rows whose
    entry is 0 are multiplied by -2. Used as the numpy reference that the
    CondOp tests compare against.
    '''

    def __init__(self, size=10):
        """Build the reference inputs.

        Args:
            size: number of rows; defaults to 10 (the original fixed size).
        """
        # cond alternates 1, 0, 1, 0, ...: even rows take the true branch.
        self.cond = (np.arange(size) % 2 == 0).astype(int)
        self.x = np.ones(shape=(size, 1)).astype("float32")

    def forward(self):
        """Return x with true rows scaled by 2. and false rows by -2.

        Side effect: stores the row indices of each branch in
        ``self.index_t`` / ``self.index_f`` (tuples from ``np.where``).
        """
        self.index_t = np.where(self.cond == 1)
        self.index_f = np.where(self.cond == 0)
        # Size and type the result from x itself instead of hard-coding a
        # (10, 1) float64 buffer, so the reference keeps x's float32 dtype.
        output = np.zeros_like(self.x)
        output[self.index_t] = self.x[self.index_t] * 2.
        output[self.index_f] = self.x[self.index_f] * (-2.)
        return output


class PySimpleCondTest(unittest.TestCase):
    '''Sanity checks for the numpy reference implementation itself.'''

    def setUp(self):
        self.condnn = PySimpleCond()

    def test_forward(self):
        """True rows must be scaled by 2 and false rows by -2."""
        output = self.condnn.forward()
        # Independent formulation via np.where, so the check is not a copy of
        # the index-scatter logic under test.
        factor = np.where(self.condnn.cond.reshape(-1, 1) == 1, 2., -2.)
        expected = factor * self.condnn.x
        self.assertEqual(output.shape, expected.shape)
        np.testing.assert_allclose(output, expected)


def create_tensor(scope, name, shape, np_data):
    """Create variable ``name`` in ``scope`` and fill its tensor.

    The tensor's dims are set to ``shape`` and its contents copied from
    ``np_data`` on the CPU place. Returns the tensor.
    """
    var = scope.var(name)
    t = var.get_tensor()
    t.set_dims(shape)
    cpu = core.CPUPlace()
    t.set(np_data, cpu)
    return t


class TestCondOp(unittest.TestCase):
    '''
    Test CondOp

    equation:
        cond = [True, False, True, False, ...]
        y[index_t] = x[index_t] * 2.
        y[index_f] = x[index_f] * -2.
    outputs:
        y
    '''

    def setUp(self):
        self.py_cond = PySimpleCond()

    def forward(self):
        """Build the CondOp graph in a fresh scope, run it on CPU and
        return the "Out" tensor as a numpy array."""
        self.scope = core.Scope()
        self.create_global_variables()
        self.create_cond_op()
        self.create_sub_net()
        ctx = core.DeviceContext.create(core.CPUPlace())
        self.condop.run(self.scope, ctx)
        return np.array(self.scope.find_var("Out").get_tensor())

    def create_global_variables(self):
        """Create the X and cond input tensors plus the SubScopes,
        IndexTensors and Out variables referenced by the op."""
        x_np_data = self.py_cond.x
        # Derive dims from the reference data rather than hard-coding them.
        create_tensor(self.scope, "X", list(x_np_data.shape), x_np_data)
        cond_np_data = self.py_cond.cond.astype("int32")
        create_tensor(self.scope, "cond", [cond_np_data.shape[0], 1],
                      cond_np_data)
        self.scope.var("SubScopes")
        self.scope.var("IndexTensors")
        self.scope.var("Out")

    def create_cond_op(self):
        """Instantiate the CondOp wired to the variables created above."""
        self.condop = CondOp(
            Cond="cond",
            Xs=["X"],
            Outs=["Out"],
            SubScopes="SubScopes",
            IndexTensors="IndexTensors")

    def create_sub_net(self):
        """Attach the branch nets: true scales X by 2., false by -2."""
        truenet = core.Net.create()
        scale_op_t = Operator("scale", X='X', Out='Out', scale=2.)
        truenet.append_op(scale_op_t)
        truenet.complete_add_op(True)
        self.condop.set_truenet(truenet)

        falsenet = core.Net.create()
        scale_op_f = Operator("scale", X='X', Out='Out', scale=-2.)
        falsenet.append_op(scale_op_f)
        falsenet.complete_add_op(True)
        self.condop.set_falsenet(falsenet)

    def test_forward(self):
        """CondOp output must match the numpy reference element-wise."""
        pd_output = self.forward()
        py_output = self.py_cond.forward()
        logging.info('pd_output %s', pd_output)
        logging.info('py_output %s', py_output)
        self.assertEqual(pd_output.shape, py_output.shape)
        # A shape check alone would pass even if the branches were swapped
        # or never ran, so compare the values as well.
        np.testing.assert_allclose(pd_output, py_output, rtol=1e-5)


# Run every TestCase in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()