# op_test.py

import unittest
import numpy as np
import random
import itertools
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
from paddle.v2.framework.executor import Executor
from paddle.v2.framework.framework import Program, OpProtoHolder


def randomize_probability(batch_size, class_num, dtype='float32'):
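    '''Return a (batch_size, class_num) matrix of positive values whose rows
    each sum to 1, i.e. a random probability distribution per sample.'''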
    prob = np.random.uniform(
        0.1, 1.0, size=(batch_size, class_num)).astype(dtype)
    prob_sum = prob.sum(axis=1)
    for i in xrange(len(prob)):
        prob[i] /= prob_sum[i]
    return prob


def grad_var_name(var_name):
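    '''Name of the gradient variable that the framework pairs with `var_name`.'''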
    return var_name + "@GRAD"


def create_op(scope, op_type, inputs, outputs, attrs):
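    '''Create scope variables for every given input and output (including
    duplicable sub-variables) and build an Operator over those names plus the
    requested attributes.'''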
    kwargs = dict()

    def __create_var__(name, var_name):
        scope.var(var_name)
        kwargs[name].append(var_name)

    for in_name, in_dup in Operator.get_op_inputs(op_type):
        if in_name in inputs:
            kwargs[in_name] = []
            if in_dup:
                sub_in = inputs[in_name]
                for sub_in_name, _ in sub_in:
                    __create_var__(in_name, sub_in_name)
            else:
                __create_var__(in_name, in_name)

    for out_name, out_dup in Operator.get_op_outputs(op_type):
        if out_name in outputs:
            kwargs[out_name] = []
            if out_dup:
                sub_out = outputs[out_name]
                for sub_out_name, _ in sub_out:
                    __create_var__(out_name, sub_out_name)
            else:
                __create_var__(out_name, out_name)

    for attr_name in Operator.get_op_attr_names(op_type):
        if attr_name in attrs:
            kwargs[attr_name] = attrs[attr_name]

    return Operator(op_type, **kwargs)


def set_input(scope, op, inputs, place):
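    '''Copy the numpy inputs (optionally carrying LoD information) into the
    scope's tensors on `place`; plain float and int values are set directly
    on the variable.'''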
    def __set_input__(var_name, var):
        if isinstance(var, tuple) or isinstance(var, np.ndarray):
            tensor = scope.find_var(var_name).get_tensor()
            if isinstance(var, tuple):
                tensor.set_lod(var[1])
                var = var[0]
            tensor.set_dims(var.shape)
            tensor.set(var, place)
        elif isinstance(var, float):
            scope.find_var(var_name).set_float(var)
        elif isinstance(var, int):
            scope.find_var(var_name).set_int(var)

    for in_name, in_dup in Operator.get_op_inputs(op.type()):
        if in_name in inputs:
            if in_dup:
                sub_in = inputs[in_name]
                for sub_in_name, sub_in_val in sub_in:
                    __set_input__(sub_in_name, sub_in_val)
            else:
                __set_input__(in_name, inputs[in_name])


def set_output_grad(scope, op, outputs, place):
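    '''Create the @GRAD variable for every checked output and fill it with
    ones, giving the backward op an upstream gradient to start from.'''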
    def __set_tensor__(name):
        out_tensor = scope.find_var(name).get_tensor()
        grad_tensor = scope.var(grad_var_name(name)).get_tensor()
        out_dtype = out_tensor.dtype()
        if out_dtype == core.DataType.FP64:
            data = np.ones(out_tensor.shape(), dtype=np.float64)
        elif out_dtype == core.DataType.FP32:
            data = np.ones(out_tensor.shape(), dtype=np.float32)
        else:
            raise ValueError("Not supported data type " + str(out_dtype))

        grad_tensor.set(data, place)

    for out_name, out_dup in Operator.get_op_outputs(op.type()):
        if out_name in outputs:
            if out_dup:
                sub_out = outputs[out_name]
                for sub_out_name, _ in sub_out:
                    __set_tensor__(sub_out_name)
            else:
                __set_tensor__(out_name)


def get_numeric_gradient(scope,
                         op,
                         inputs,
                         input_to_check,
                         output_names,
                         delta=0.005,
                         in_place=False):
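    '''Estimate d(sum of output_names) / d(input_to_check) element by element
    with a central difference: (f(x + delta) - f(x - delta)) / (2 * delta).'''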
    set_input(scope, op, inputs, core.CPUPlace())

    tensor_to_check = scope.find_var(input_to_check).get_tensor()

    def product(dim):
        return reduce(lambda a, b: a * b, dim, 1)

    ctx = core.DeviceContext.create(core.CPUPlace())

    def get_output():
        sum = 0.0
        for output_name in output_names:
            op.run(scope, ctx)
            sum += np.array(scope.find_var(output_name).get_tensor()).sum()
        return sum

    tensor_to_check = scope.find_var(input_to_check).get_tensor()
    tensor_size = product(tensor_to_check.get_dims())
    tensor_to_check_dtype = tensor_to_check.dtype()
    if tensor_to_check_dtype == core.DataType.FP32:
        tensor_to_check_dtype = np.float32
    elif tensor_to_check_dtype == core.DataType.FP64:
        tensor_to_check_dtype = np.float64
    else:
        raise ValueError("Not supported data type " + str(
            tensor_to_check_dtype))

    gradient_flat = np.zeros(shape=(tensor_size, ), dtype=tensor_to_check_dtype)

    def __get_elem__(tensor, i):
        if tensor_to_check_dtype == np.float32:
            return tensor.get_float_element(i)
        else:
            return tensor.get_double_element(i)

    def __set_elem__(tensor, i, e):
        if tensor_to_check_dtype == np.float32:
            tensor.set_float_element(i, e)
        else:
            tensor.set_double_element(i, e)

    # We compute the gradient of one element at a time; the loop below visits
    # every element of the tensor being checked.
    for i in xrange(tensor_size):
        if in_place:
            set_input(scope, op, inputs, core.CPUPlace())

        # Get one input element by its index i.
        origin = __get_elem__(tensor_to_check, i)
        # Add delta to it, run the op, and then get the sum of the result tensor.
        x_pos = origin + delta
        __set_elem__(tensor_to_check, i, x_pos)
        y_pos = get_output()

        if in_place:
            set_input(scope, op, inputs, core.CPUPlace())

        x_neg = origin - delta
        __set_elem__(tensor_to_check, i, x_neg)
        y_neg = get_output()

        __set_elem__(tensor_to_check, i, origin)
        gradient_flat[i] = (y_pos - y_neg) / delta / 2

    return gradient_flat.reshape(tensor_to_check.get_dims())


def get_backward_op(scope, op, no_grad_set):
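    '''Create the backward operator of `op` and make sure every variable it
    reads or writes exists in `scope`.'''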
    backward_op = core.Operator.backward(op, no_grad_set)
    for input in backward_op.input_vars():
        var = scope.var(input)
        var.get_tensor()
    for output in backward_op.output_vars():
        var = scope.var(output)
        var.get_tensor()
    return backward_op


def get_gradient(scope,
                 op,
                 inputs,
                 outputs,
                 grad_names,
                 place,
                 no_grad_set=None):
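    '''Run the forward op and its backward op, then return the analytic
    gradients named in `grad_names` as numpy arrays.'''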
    ctx = core.DeviceContext.create(place)

    set_input(scope, op, inputs, place)

    op.run(scope, ctx)

    if no_grad_set is None:
        no_grad_set = set()

    backward_op = get_backward_op(scope, op, no_grad_set)
    set_output_grad(scope, op, outputs, place)

    backward_op.run(scope, ctx)

    return [
        np.array(scope.find_var(grad_name).get_tensor())
        for grad_name in grad_names
    ]


def append_input_output(block, op_proto, np_list, is_input):
    '''Insert VarDesc and generate Python variable instance'''
    proto_list = op_proto.inputs if is_input else op_proto.outputs

    def create_var(block, name, np_list, var_proto):
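        # Shape and LoD level are taken from the supplied numpy value; an
        # output the test does not provide must be marked intermediate in the
        # op proto and gets no shape here.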
        if name not in np_list:
            assert var_proto.intermediate, "{} not found".format(name)
            shape = None
            lod_level = None
        else:
            np_value = np_list[name]
            if isinstance(np_value, tuple):
                shape = list(np_value[0].shape)
                lod_level = len(np_value[1])
            else:
                shape = list(np_value.shape)
                lod_level = 0
        return block.create_var(
            dtype="float32", shape=shape, lod_level=lod_level, name=name)

    var_dict = {}
    for var_proto in proto_list:
        var_name = str(var_proto.name)
        if is_input:
            if (var_name not in np_list) and var_proto.dispensable:
                continue
            assert (var_name in np_list) or (var_proto.dispensable), \
                "Missing {} as input".format(var_name)
        if var_proto.duplicable:
            assert isinstance(np_list[var_name], list), \
                "Duplicable {} should be set as list".format(var_name)
            var_list = []
            for (name, np_value) in np_list[var_name]:
                var_list.append(
                    create_var(block, name, {name: np_value}, var_proto))
            var_dict[var_name] = var_list
        else:
            var_dict[var_name] = create_var(block, var_name, np_list, var_proto)

    return var_dict


class OpTest(unittest.TestCase):
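    '''Base class for operator unit tests.

    A subclass sets `self.op_type`, `self.inputs`, `self.outputs` (and
    optionally `self.attrs`) and then calls `check_output()` and/or
    `check_grad()`.
    '''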
    @classmethod
    def setUpClass(cls):
        '''Fix random seeds to remove randomness from tests'''
        cls._np_rand_state = np.random.get_state()
        cls._py_rand_state = random.getstate()

        np.random.seed(123)
        random.seed(124)

    @classmethod
    def tearDownClass(cls):
        '''Restore random seeds'''
        np.random.set_state(cls._np_rand_state)
        random.setstate(cls._py_rand_state)

    def feed_var(self, input_vars, place):
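        '''Convert the test's numpy inputs into a {name: LoDTensor} feed map
        on `place`, handling duplicable inputs (lists of (name, value) pairs)
        and (value, lod) tuples.'''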
        feed_map = {}
        for var_name in input_vars:
            if isinstance(input_vars[var_name], list):
                for name, np_value in self.inputs[var_name]:
                    tensor = core.LoDTensor()
                    tensor.set(np_value, place)
                    feed_map[name] = tensor
            else:
                tensor = core.LoDTensor()
                if isinstance(self.inputs[var_name], tuple):
                    tensor.set(self.inputs[var_name][0], place)
                    tensor.set_lod(self.inputs[var_name][1])
                else:
                    tensor.set(self.inputs[var_name], place)
                feed_map[var_name] = tensor

        return feed_map

    def check_output_with_place(self, place, atol):
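        '''Build a Program containing only this operator, run it with the
        Executor on `place`, and compare every fetched output with
        `self.outputs` under absolute tolerance `atol`.'''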
        op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)

        program = Program()
        block = program.global_block()

        inputs = append_input_output(block, op_proto, self.inputs, True)
        outputs = append_input_output(block, op_proto, self.outputs, False)

        op = block.append_op(
            type=self.op_type,
            inputs=inputs,
            outputs=outputs,
            attrs=self.attrs if hasattr(self, "attrs") else dict())

        fetch_list = []
        for var_name, var in outputs.iteritems():
            if var_name in self.outputs:
                if isinstance(var, list):
                    for v in var:
                        fetch_list.append(v)
                else:
                    fetch_list.append(var)

        feed_map = self.feed_var(inputs, place)

        exe = Executor(place)
        outs = exe.run(program, feed=feed_map, fetch_list=fetch_list)

        for out_name, out_dup in Operator.get_op_outputs(self.op_type):
            if out_name not in self.outputs:
                continue

            def find_actual(target_name, fetch_list):
                found = [
                    i for i, var in enumerate(fetch_list)
                    if var.name == target_name
                ]
                self.assertTrue(
                    len(found) == 1, "Found {} {}".format(
                        len(found), target_name))
                return found[0]

            if out_dup:
                sub_out = self.outputs[out_name]
                if not isinstance(sub_out, list):
                    raise AssertionError("sub_out type %s is not list" %
                                         type(sub_out))
                for sub_out_name, expect in sub_out:
                    idx = find_actual(sub_out_name, fetch_list)
                    actual_t = np.array(outs[idx])
                    expect_t = expect[0] \
                        if isinstance(expect, tuple) else expect
                    self.assertTrue(
                        np.allclose(
                            actual_t, expect_t, atol=atol),
                        "Output (" + sub_out_name + ") has diff at " +
                        str(place))
                    if isinstance(expect, tuple):
                        self.assertListEqual(
                            actual_t.lod(), expect[1], "Output (" + sub_out_name
                            + ") has different lod at " + str(place))
            else:
                idx = find_actual(out_name, fetch_list)
                actual_t = outs[idx]
                expect = self.outputs[out_name]
                expect_t = expect[0] if isinstance(expect, tuple) else expect
                self.assertTrue(
                    np.allclose(
                        actual_t, expect_t, atol=atol),
                    "Output (" + out_name + ") has diff at " + str(place))
                if isinstance(expect, tuple):
                    self.assertListEqual(actual_t.lod(), expect[1],
                                         "Output (" + out_name +
                                         ") has different lod at " + str(place))

    def check_output(self, atol=1e-5):
        places = [core.CPUPlace()]
        if core.is_compile_gpu() and core.op_support_gpu(self.op_type):
            places.append(core.GPUPlace(0))
        for place in places:
            self.check_output_with_place(place, atol)

    def __assert_is_close(self, numeric_grads, analytic_grads, names,
                          max_relative_error, msg_prefix):
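        # Element-wise relative error check; reference values below 1e-3 are
        # clamped to 1 so that tiny denominators do not inflate the ratio.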

        for a, b, name in itertools.izip(numeric_grads, analytic_grads, names):
            abs_a = np.abs(a)
            abs_a[abs_a < 1e-3] = 1

            diff_mat = np.abs(a - b) / abs_a
            max_diff = np.max(diff_mat)

            def err_msg():
                offset = np.argmax(diff_mat > max_relative_error)
                return ("%s Variable %s max gradient diff %f over limit %f, "
                        "the first error element is %d, %f, %f") % (
                            msg_prefix, name, max_diff, max_relative_error,
                            offset, a.flatten()[offset], b.flatten()[offset])

            self.assertLessEqual(max_diff, max_relative_error, err_msg())

    def check_grad(self,
                   inputs_to_check,
                   output_names,
                   no_grad_set=None,
                   numeric_grad_delta=0.005,
                   in_place=False,
                   max_relative_error=0.005,
                   user_defined_grads=None):
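        '''Compare numerically estimated gradients (or `user_defined_grads`,
        when given) with the analytic gradients produced by the backward op,
        on CPU and, when supported, on GPU as well.'''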
        self.scope = core.Scope()
        op_inputs = self.inputs if hasattr(self, "inputs") else dict()
        op_outputs = self.outputs if hasattr(self, "outputs") else dict()
        op_attrs = self.attrs if hasattr(self, "attrs") else dict()
        self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs,
                            op_attrs)
        if no_grad_set is None:
            no_grad_set = set()

        if not type(output_names) is list:
            output_names = [output_names]

        numeric_grads = user_defined_grads or [
            get_numeric_gradient(
                self.scope,
                self.op,
                self.inputs,
                input_to_check,
                output_names,
                delta=numeric_grad_delta,
                in_place=in_place) for input_to_check in inputs_to_check
        ]
        grad_names = [
            grad_var_name(input_to_check) for input_to_check in inputs_to_check
        ]

        cpu_place = core.CPUPlace()
        cpu_analytic_grads = get_gradient(self.scope, self.op, self.inputs,
                                          self.outputs, grad_names, cpu_place,
                                          no_grad_set)

        self.__assert_is_close(numeric_grads, cpu_analytic_grads, grad_names,
                               max_relative_error,
                               "Gradient Check On %s" % str(cpu_place))

        if core.is_compile_gpu() and self.op.support_gpu():
            gpu_place = core.GPUPlace(0)
            gpu_analytic_grads = get_gradient(self.scope, self.op, self.inputs,
                                              self.outputs, grad_names,
                                              gpu_place, no_grad_set)

            self.__assert_is_close(numeric_grads, gpu_analytic_grads,
                                   grad_names, max_relative_error,
                                   "Gradient Check On %s" % str(gpu_place))

            for c_grad, g_grad, name in itertools.izip(
                    cpu_analytic_grads, gpu_analytic_grads, grad_names):
                self.assertTrue(
                    np.allclose(
                        c_grad, g_grad, atol=1e-4),
                    "output name: " + name + " has diff")