test_cond_op.py 3.2 KB
Newer Older
Z
cond op  
zchen0211 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
import logging
import paddle.v2.framework.core as core
import unittest
import numpy as np
from paddle.v2.framework.op import Operator, CondOp


class PySimpleCond(object):
    '''
    A simple implementation of dynamic if-else based on numpy.

    Rows of ``x`` whose flag in ``cond`` is 1 go through the "true" branch
    (scaled by 2.) and rows whose flag is 0 go through the "false" branch
    (scaled by -2.).
    '''

    def __init__(self):
        # Alternating flags: 1 at even positions, 0 at odd positions.
        self.cond = np.array([1, 0] * 5)
        self.x = np.ones(shape=(10, 1))

    def forward(self):
        '''
        Run both branches and merge their results.

        Stores the ``np.where`` index tuples of the two branches in
        ``self.index_t`` / ``self.index_f`` and returns an array with the
        same shape as ``self.x`` holding ``x * 2.`` on the true rows and
        ``x * -2.`` on the false rows.
        '''
        self.index_t = np.where(self.cond == 1)
        self.index_f = np.where(self.cond == 0)
        y_t = self.x[self.index_t] * 2.
        y_f = self.x[self.index_f] * (-2.)
        # Scatter each branch result back to the rows it was gathered from.
        output = np.zeros_like(self.x)
        output[self.index_t] = y_t
        output[self.index_f] = y_f
        return output


class PySimpleCondTest(unittest.TestCase):
    '''
    Sanity check for the numpy reference implementation PySimpleCond.
    '''

    def setUp(self):
        self.condnn = PySimpleCond()

    def test_forward(self):
        output = self.condnn.forward()
        # x is all ones and the flags alternate 1, 0, ... so the true rows
        # (even) become 2. and the false rows (odd) become -2.
        expected = np.array([2., -2.] * 5).reshape(10, 1)
        self.assertEqual(output.shape, expected.shape)
        self.assertTrue(np.array_equal(output, expected))


def create_tensor(scope, name, shape, np_data):
    '''
    Create variable `name` in `scope`, size its tensor to `shape`, fill it
    with `np_data` on the CPU place, and return the tensor.
    '''
    variable = scope.new_var(name)
    result = variable.get_tensor()
    result.set_dims(shape)
    cpu_place = core.CPUPlace()
    result.set(np_data, cpu_place)
    return result


class TestCondOp(unittest.TestCase):
    '''
    Test CondOp

    equation:
        cond = [True, False, True, False, ...]
        y[index_t] = x[index_t] * 2.
        y[index_f] = x[index_f] * -2.
    outputs:
        y
    '''

    def setUp(self):
        # numpy reference implementation that also supplies the input data.
        self.py_cond = PySimpleCond()

    def forward(self):
        '''
        Build a fresh scope, the cond op and its two sub-nets, run the op on
        the CPU place and return the "Out" tensor as a numpy array.
        '''
        self.scope = core.Scope()
        self.create_global_variables()
        self.create_cond_op()
        self.create_sub_net()
        ctx = core.DeviceContext.create(core.CPUPlace())
        self.condop.infer_shape(self.scope)
        self.condop.run(self.scope, ctx)
        return np.array(self.scope.find_var("Out").get_tensor())

    def create_global_variables(self):
        '''
        Create the variables the op refers to by name: inputs "X" and "cond"
        (filled from the numpy reference) plus "SubScopes", "IndexTensors"
        and "Out".
        '''
        x_np_data = self.py_cond.x
        create_tensor(self.scope, "X", [10, 1], x_np_data)
        cond_np_data = self.py_cond.cond.astype("int32")
        create_tensor(self.scope, "cond", [10, 1], cond_np_data)
        self.scope.new_var("SubScopes")
        self.scope.new_var("IndexTensors")
        self.scope.new_var("Out")

    def create_cond_op(self):
        '''
        Create the CondOp wired to the variable names created above.
        '''
        self.condop = CondOp(
            Cond="cond",
            Xs=["X"],
            Outs=["Out"],
            SubScopes="SubScopes",
            IndexTensors="IndexTensors")

    def create_sub_net(self):
        '''
        Attach the two branch nets: the true net scales X by 2., the false
        net scales X by -2.
        '''
        truenet = core.Net.create()
        scale_op_t = Operator("scale", X='X', Out='Out', scale=2.)
        truenet.append_op(scale_op_t)
        truenet.complete_add_op(True)
        self.condop.set_truenet(truenet)

        falsenet = core.Net.create()
        scale_op_f = Operator("scale", X='X', Out='Out', scale=-2.)
        falsenet.append_op(scale_op_f)
        falsenet.complete_add_op(True)
        self.condop.set_falsenet(falsenet)

    def test_forward(self):
        # Single-argument print calls emit the same text under Python 2
        # (where they parse as print statements) and Python 3.
        print('test cond op forward')
        pd_output = self.forward()
        py_output = self.py_cond.forward()
        print('pd_output %s' % (pd_output, ))
        print('')
        print('py_output %s' % (py_output, ))
        # NOTE(review): only the shapes are compared here; the values from
        # the equation in the class docstring are not asserted. Consider
        # adding an np.allclose check once the op output is confirmed.
        self.assertEqual(pd_output.shape, py_output.shape)
        print('test passed')
Z
cond op  
zchen0211 已提交
113 114 115 116


# Run every TestCase in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()