# test_registry.py (700 bytes)
# Last committed by dzhwinter
import unittest
import warnings

import numpy as np
import paddle.v2.fluid as fluid
import paddle.v2.fluid.framework as framework
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.registry as registry


class TestRegistry(unittest.TestCase):
    """Unit test exercising the ``mean`` layer from ``paddle.v2.fluid.layers``."""

    def test_registry_layer(self):
        """Run ``layers.mean`` on random data and compare against ``np.mean``.

        A 10x10 float32 matrix is fed through a program containing only a data
        layer and a mean layer, executed on CPU; the fetched value must match
        NumPy's mean of the same matrix.
        """
        self.layer_type = "mean"
        program = framework.Program()

        # Layers are appended to whichever program is current when they are
        # created, so the graph has to be built under the guard for it to end
        # up in ``program`` -- otherwise ``program`` stays empty and the fetch
        # target does not exist in the program being executed.
        with framework.program_guard(program):
            x = fluid.layers.data(name='X', shape=[10, 10], dtype='float32')
            output = layers.mean(x)

        place = fluid.CPUPlace()
        exe = fluid.Executor(place)

        X = np.random.random((10, 10)).astype("float32")
        mean_out = exe.run(program, feed={"X": X}, fetch_list=[output])

        # One fetch target was requested, so take the first fetched entry and
        # reduce it to a plain Python float before comparing.
        # NOTE(review): assumes the fetched entry converts to a single-element
        # ndarray -- confirm against the Executor version in use.
        actual = np.asarray(mean_out[0]).item()

        # Both means are accumulated in float32, so compare to 5 decimal
        # places instead of assertAlmostEqual's default of 7.
        self.assertAlmostEqual(float(np.mean(X)), actual, places=5)