test_cond_op.py 3.8 KB
Newer Older
D
dzhwinter 已提交
1
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Z
cond op  
zchen0211 已提交
15
import logging
Q
Qiao Longfei 已提交
16
import paddle.v2.fluid.core as core
Z
cond op  
zchen0211 已提交
17 18
import unittest
import numpy as np
Q
Qiao Longfei 已提交
19
from paddle.v2.fluid.op import Operator, CondOp
Z
cond op  
zchen0211 已提交
20 21 22 23 24 25 26 27


class PySimpleCond(object):
    '''
    A simple implementation of dynamic if-else based on numpy.

    Rows whose ``cond`` entry is 1 are multiplied by 2; rows whose entry is
    0 are multiplied by -2. Used as the reference result for the CondOp test.
    '''

    def __init__(self, size=10):
        '''
        Build the condition vector and the all-ones input.

        Args:
            size: number of rows (default 10).

        Sets:
            self.cond: int array of shape (size,), alternating 1, 0, 1, 0, ...
            self.x: float32 array of shape (size, 1) filled with ones.
        '''
        cond = np.ones(size, dtype=int)
        # Odd positions take the "false" branch.
        cond[1::2] = 0
        self.cond = cond
        self.x = np.ones(shape=(size, 1)).astype("float32")

    def forward(self):
        '''
        Run the reference if-else.

        Also stores the row indices of each branch in ``self.index_t`` and
        ``self.index_f``.

        Returns:
            Array with the same shape and dtype as ``self.x`` where row i is
            ``2 * x[i]`` if ``cond[i] == 1`` and ``-2 * x[i]`` if
            ``cond[i] == 0``.
        '''
        self.index_t = np.where(self.cond == 1)
        self.index_f = np.where(self.cond == 0)
        # Keep the input's shape and dtype (float32) instead of a hard-coded
        # (10, 1) float64 buffer.
        output = np.zeros_like(self.x)
        output[self.index_t] = self.x[self.index_t] * 2.
        output[self.index_f] = self.x[self.index_f] * (-2.)
        return output


class PySimpleCondTest(unittest.TestCase):
    '''Sanity checks for the NumPy reference implementation ``PySimpleCond``.'''

    def setUp(self):
        self.condnn = PySimpleCond()

    def test_forward(self):
        '''The reference must scale "true" rows by 2 and "false" rows by -2.'''
        output = self.condnn.forward()
        self.assertEqual(output.shape, (10, 1))
        # PySimpleCond uses cond = 1, 0, 1, 0, ... over an input of ones.
        expected = np.array([[2.], [-2.]] * 5)
        np.testing.assert_allclose(output, expected)


def create_tensor(scope, name, shape, np_data):
    '''
    Fill the tensor of variable `name` in `scope` with `np_data`.

    The variable is obtained through `scope.var(name)`. Its tensor is resized
    to `shape`, and `np_data` is written into it with `set(..., CPUPlace())`.
    Returns that tensor object.
    '''
    holder = scope.var(name)
    tensor_obj = holder.get_tensor()
    tensor_obj.set_dims(shape)
    cpu = core.CPUPlace()
    tensor_obj.set(np_data, cpu)
    return tensor_obj


class TestCondOp(unittest.TestCase):
    '''
    Test CondOp

    equation:
        cond = [True, False, True, False, ...]
        y[index_t] = x[index_t] * 2.
        y[index_f] = x[index_f] * -2.
    outputs:
        y

    The output of the Paddle CondOp is compared element-wise with the NumPy
    reference ``PySimpleCond``.
    '''

    def setUp(self):
        self.py_cond = PySimpleCond()

    def forward(self):
        '''Build scope and CondOp, run it on CPU, and return "Out" as ndarray.'''
        self.scope = core.Scope()
        self.create_global_variables()
        self.create_cond_op()
        self.create_sub_net()
        self.condop.run(self.scope, core.CPUPlace())
        return np.array(self.scope.find_var("Out").get_tensor())

    def create_global_variables(self):
        '''Create inputs "X" and "cond" plus the variables CondOp refers to.'''
        x_np_data = self.py_cond.x
        # Derive dims from the reference data rather than hard-coding them.
        create_tensor(self.scope, "X", list(x_np_data.shape), x_np_data)
        cond_np_data = self.py_cond.cond.astype("int32")
        create_tensor(self.scope, "cond", [len(cond_np_data), 1],
                      cond_np_data)
        self.scope.var("SubScopes")
        self.scope.var("IndexTensors")
        self.scope.var("Out")

    def create_cond_op(self):
        '''Create the CondOp reading "cond"/"X" and writing "Out".'''
        self.condop = CondOp(
            Cond="cond",
            Xs=["X"],
            Outs=["Out"],
            SubScopes="SubScopes",
            IndexTensors="IndexTensors")

    def create_sub_net(self):
        '''Attach the true branch (scale by 2) and false branch (scale by -2).'''
        truenet = core.Net.create()
        scale_op_t = Operator("scale", X='X', Out='Out', scale=2.)
        truenet.append_op(scale_op_t)
        truenet.complete_add_op(True)
        self.condop.set_truenet(truenet)

        falsenet = core.Net.create()
        scale_op_f = Operator("scale", X='X', Out='Out', scale=-2.)
        falsenet.append_op(scale_op_f)
        falsenet.complete_add_op(True)
        self.condop.set_falsenet(falsenet)

    def test_forward(self):
        '''CondOp must match the NumPy reference in shape and values.'''
        # Python-2 print statements replaced by logging (also valid on Py3).
        logger = logging.getLogger(__name__)
        logger.debug('test cond op forward')
        pd_output = self.forward()
        py_output = self.py_cond.forward()
        logger.debug('pd_output %s', pd_output)
        logger.debug('py_output %s', py_output)
        self.assertEqual(pd_output.shape, py_output.shape)
        # A shape check alone would accept an op that ignores `cond`.
        np.testing.assert_allclose(pd_output, py_output, rtol=1e-6)
Z
cond op  
zchen0211 已提交
125 126 127


if __name__ == "__main__":
    import sys

    # FIXME(qijun): https://github.com/PaddlePaddle/Paddle/issues/5101#issuecomment-339814957
    # Running this file as a script deliberately exits with success before
    # unittest.main() is reached, i.e. the tests are disabled here. Use
    # sys.exit: the bare `exit` builtin is injected by the `site` module for
    # interactive use and is absent when Python runs with -S.
    sys.exit(0)
    unittest.main()