#!/usr/bin/env python3

# Copyright (c) 2022 CINN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os, sys
import numpy as np
import unittest
import logging, argparse

import paddle

from cinn.frontend import PaddleModelConvertor
from cinn.common import is_compiled_with_cuda, DefaultHostTarget, DefaultNVGPUTarget
from cinn.runtime import seed as cinn_seed
from ops.op_test import OpTestTool
from op_mappers.op_mapper_test import OpMapperTest

logging.basicConfig(level=os.environ.get('LOG_LEVEL', 'INFO').upper())
logger = logging.getLogger(name="paddle_model_convertor")

parser = argparse.ArgumentParser(
    description='Load a Paddle model file and run it with CINN'
)
parser.add_argument(
    "--path", help="The path to load the paddle model", type=str, required=True
)
parser.add_argument(
    "-m",
    "--model_filename",
    help="The filename of model file, default \"__model__\"",
    type=str,
    default="__model__",
)
parser.add_argument(
    "-p",
    "--params_filename",
    help="The filename of model parameter file, default None, in which each parameter will saved in each file",
    type=str,
    default=None,
)
parser.add_argument(
    "-cuda",
    "--enable_cuda",
    help="Whether enable CUDA, default True",
    type=bool,
    default=True,
)
args = parser.parse_args()

np.random.seed(1234)
paddle.seed(1234)
cinn_seed(1234)

paddle.enable_static()

# First, save a Paddle model, for example:
# ```
# import paddle
# paddle.enable_static()

# x = paddle.static.data(name='x', shape=[10, 12, 128, 128], dtype='float32')
# y = paddle.static.data(name='y', shape=[10, 12, 128, 128], dtype='float32')
# prediction = paddle.stack([x, y], 1)

# place = paddle.CUDAPlace(0)

# exe = paddle.static.Executor(place)
# exe.run(paddle.static.default_startup_program())
# prog = paddle.static.default_main_program()

# paddle.fluid.io.save_inference_model("./stack", [x.name, y.name], [prediction], exe, prog)
# ```
# Then load and run the model with this script, for example:
# ```
# python test_paddle_model_convertor.py --path build/thirds/resnet_model -m "__model__" -p "params"
# ```


class TestPaddleModel(OpMapperTest):
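    """Load a saved Paddle inference model, run it with both the Paddle
    executor and CINN, and compare the two sets of outputs."""
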
    def setUp(self):
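        # Select the execution target according to --enable_cuda, then load
        # the model and generate random feed data.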
        if args.enable_cuda and is_compiled_with_cuda():
            self.target = DefaultNVGPUTarget()
            self.place = paddle.CUDAPlace(0)
        else:
            self.target = DefaultHostTarget()
            self.place = paddle.CPUPlace()

        self.model_dir = args.path
        self.model_filename = args.model_filename
        self.params_filename = args.params_filename

        logger.info(
            "Run Model From \"{}\", which model filename is \"{}\", and parameter filename is \"{}\"".format(
                self.model_dir, self.model_filename, self.params_filename
            )
        )

        self.load_paddle_program()
        self.init_case()

    @staticmethod
    def eliminate_unknown_shape(shape):
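        # Dynamic dimensions are marked as -1 in the Paddle program; replace
        # them with 1 so that concrete random input data can be generated.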
        return [1 if dim == -1 else dim for dim in shape]

    def get_paddle_op_attrs(self, op):
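        # Collect all attributes of a Paddle operator into a plain dict.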
        attr_map = {}
        for n in op.attr_names:
            attr_map[n] = op.attr(n)

        return attr_map

    def init_case(self):
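        # Generate random input data for every feed variable of the program.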
        self.feed_data = dict()
        for i in range(len(self.feed_names)):
            # check no repeat variable
            self.assertNotIn(
                self.feed_names[i],
                self.feed_data,
                msg="Repeat feed name: " + self.feed_names[i],
            )

            dtype = self.paddleddtype2nptype(self.feed_dtypes[i])
            # random integer data should not be limited to [0, 1)
            high = 1 if ("int" not in dtype) else self.feed_shapes[i][0]

            # Paddle's executor expects the feed as a dict, not a list
            self.feed_data[self.feed_names[i]] = self.random(
                self.eliminate_unknown_shape(self.feed_shapes[i]),
                dtype,
                high=high,
            )

    def load_paddle_program(self):
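        # Load the inference program, the feed/fetch lists and the parameters
        # from disk via the Paddle static-graph (fluid) I/O API.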
        self.exe = paddle.static.Executor(self.place)

        [
            self.inference_program,
            self.feed_names,
            self.fetch_targets,
        ] = paddle.fluid.io.load_inference_model(
            dirname=self.model_dir,
            executor=self.exe,
            model_filename=self.model_filename,
            params_filename=self.params_filename,
        )

        self.param_vars = paddle.load(
            self.model_dir,
            model_filename=self.model_filename,
            params_filename=self.params_filename,
            return_numpy=True,
        )

        logger.debug(msg="Program:\n{}".format(self.inference_program))
        logger.debug(msg="Param List: {}".format(self.param_vars.keys()))
        logger.debug(msg="Feed List: {}".format(self.feed_names))
        logger.debug(
            msg="Fetch List: {}".format(
                [var.name for var in self.fetch_targets]
            )
        )

        self.feed_shapes = []
        self.feed_dtypes = []

        for var in self.inference_program.list_vars():
            if var.name in self.feed_names:
                self.feed_shapes.append(var.shape)
                self.feed_dtypes.append(var.dtype)

        self.assertEqual(
            len(self.feed_names),
            len(self.feed_shapes),
            msg="Cannot found some feed var in program!",
        )

    def build_paddle_program(self, target):
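        # Run the original Paddle program with the random feed data to
        # produce the reference outputs.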
        self.paddle_outputs = self.exe.run(
            self.inference_program,
            feed=self.feed_data,
            fetch_list=self.fetch_targets,
            return_numpy=True,
        )
        logger.debug("Paddle Result:\n{}".format(self.paddle_outputs))

    def build_cinn_program(self, target):
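        # Rebuild the Paddle program op by op with PaddleModelConvertor,
        # feed both the input data and the parameters, then run it on CINN.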
        self.assertEqual(
            1,
            self.inference_program.num_blocks,
            msg="CINN only support single block now",
        )

        feed_with_param = list()

        convertor = PaddleModelConvertor(target)
        for i in range(len(self.feed_names)):
            convertor.create_input(
                dtype=self.paddleddtype2nptype(self.feed_dtypes[i]),
                shape=self.feed_data[self.feed_names[i]].shape,
                name=self.feed_names[i],
            )
            feed_with_param.append(self.feed_names[i])

        for param_name, param_value in self.param_vars.items():
            convertor.create_input(
                dtype=str(param_value.dtype),
                shape=param_value.shape,
                name=param_name,
            )
            feed_with_param.append(param_name)

        for op in self.inference_program.global_block().ops:
            if op.desc.type() == "feed" or op.desc.type() == "fetch":
                continue
            convertor.append_op(
                op.desc.type(),
                op.desc.inputs(),
                op.desc.outputs(),
                self.get_paddle_op_attrs(op),
            )

        prog = convertor()

        # get cinn input list
        inputs = prog.get_inputs()
        logger.debug(
            "CINN Input List: {}".format([var.name() for var in inputs])
        )
        self.assertEqual(
            len(feed_with_param),
            len(inputs),
            msg="The paddle's input list not equal to cinn's input list!",
        )

        # map each name to its variable
        input_dict = {var.name(): var for var in inputs}

        cinn_inputs = []
        cinn_feed_datas = []
        for name in feed_with_param:
            cinn_name = convertor.get_cinn_name(name)

            self.assertIn(
                cinn_name,
                input_dict,
                msg="Cannot find variable "
                + cinn_name
                + " in cinn program's input, which are "
                + str(input_dict.items()),
            )
            cinn_inputs.append(input_dict[cinn_name])

            if name in self.feed_data:
                cinn_feed_datas.append(self.feed_data[name])
            else:
                self.assertIn(
                    name,
                    self.param_vars,
                    msg="The input variable should in feed list or parameter list",
                )
                cinn_feed_datas.append(self.param_vars[name])

        # get cinn output list
        fetch_names = {var.name for var in self.fetch_targets}
        output_dict = convertor.get_fetch_list(fetch_names)
        cinn_output = [output_dict[var.name] for var in self.fetch_targets]

        # run and get result
        self.cinn_outputs = self.get_cinn_output(
            prog, target, cinn_inputs, cinn_feed_datas, cinn_output, passes=[]
        )

        logger.debug("CINN Result:\n{}".format(self.cinn_outputs))

    def test_check_results(self):
        # TODO(6clc): There is a random accuracy problem,
        #             temporarily adjust max_absolute_error from 1e-6 to 1e-3
        self.check_outputs_and_grads(
            max_relative_error=1e-2, max_absolute_error=1e-3
        )


if __name__ == "__main__":
    tester = unittest.defaultTestLoader.loadTestsFromTestCase(TestPaddleModel)
    test_runner = unittest.TextTestRunner()
    res = test_runner.run(tester)
    sys.exit(not res.wasSuccessful())