#!/usr/bin/env python

import os
import time
from contextlib import closing
from gppylib import gplog
from gppylib.db import dbconn
from gppylib.mainUtils import ExceptionNoStackTraceNeeded
from gppylib.userinput import ask_yesno


# Module-level logger from gppylib's gplog; GpReload uses it for its
# progress (info), diagnostic (debug) and error messages.
logger = gplog.get_default_logger()

class GpReload:
    """Reload tables listed in a table file so their rows are stored sorted.

    Each non-blank line of the table file has the form::

        schema.table: column [asc|desc], column [asc|desc], ...

    For every listed table the rows are copied into a temporary table, the
    table is truncated, and the rows are inserted back (through the partition
    parent when the table is a partition) ordered by the given columns.
    """

    def __init__(self, options, args):
        """Capture command-line options.

        :param options: parsed options; ``table_file``, ``port``, ``database``
            and ``interactive`` are read.
        :param args: positional arguments; accepted but not used here.
        """
        self.table_file = options.table_file
        self.port = options.port
        self.database = options.database
        self.interactive = options.interactive
        # (schema_name, table_name, sort_column_list) tuples from the table file.
        self.table_list = []
        # (schema, table) -> (parent_schema, parent_table); a table with no
        # partition parent maps to itself (see get_parent_partitions).
        self.parent_partition_map = {}

    def _connect(self):
        """Open a new connection to ``self.database`` on ``self.port``."""
        return dbconn.connect(dbconn.DbURL(dbname=self.database, port=self.port))

    def validate_table(self, schema_name, table_name):
        """Ensure ``schema_name.table_name`` exists and is not a view.

        :raises ExceptionNoStackTraceNeeded: if no matching relation is found.
        """
        with closing(self._connect()) as conn:
            c = dbconn.querySingleton(conn,
                                      """SELECT count(*)
                                         FROM pg_class, pg_namespace
                                         WHERE pg_namespace.nspname = '{schema}'
                                         AND pg_class.relname = '{table}'
                                         AND pg_class.relnamespace = pg_namespace.oid
                                         AND pg_class.relkind != 'v'""".format(schema=schema_name, table=table_name))
        if not c:
            raise ExceptionNoStackTraceNeeded('Table {schema}.{table} does not exist'
                                              .format(schema=schema_name, table=table_name))

    def validate_columns(self, schema_name, table_name, sort_column_list):
        """Ensure every sort column exists in ``schema_name.table_name``.

        :param sort_column_list: (column, sort_order) tuples from parse_columns.
        :raises ExceptionNoStackTraceNeeded: naming the first missing column.
        """
        with closing(self._connect()) as conn:
            res = dbconn.query(conn,
                               """SELECT attname
                                  FROM pg_attribute
                                  WHERE attrelid = (SELECT pg_class.oid
                                                    FROM pg_class, pg_namespace
                                                    WHERE pg_class.relname = '{table}' AND pg_namespace.nspname = '{schema}'
                                                    AND pg_class.relnamespace = pg_namespace.oid
                                                    AND pg_class.relkind != 'v')"""
                               .format(table=table_name, schema=schema_name))
            # Fetch while the connection is still open.
            columns = {row[0].strip() for row in res.fetchall()}
        for col, _ in sort_column_list:
            if col not in columns:
                raise ExceptionNoStackTraceNeeded('Table {schema}.{table} does not have column {col}'
                                                  .format(schema=schema_name, table=table_name, col=col))

    def validate_mid_level_partitions(self, schema_name, table_name):
        """Reject partitions that sit in the middle of a partition hierarchy.

        Tables that map to themselves in ``parent_partition_map`` (no partition
        parent) are accepted without touching the database. Otherwise the
        table's partition level is compared with the maximum level of its
        parent's hierarchy. Lookup failures are logged at debug level and
        leave the corresponding level as None.

        :raises KeyError: if the table is missing from parent_partition_map.
        :raises Exception: if the two levels differ.
        """
        parent_schema, parent_table = self.parent_partition_map[(schema_name, table_name)]
        if (parent_schema, parent_table) == (schema_name, table_name):
            return

        partition_level, max_level = None, None
        with closing(self._connect()) as conn:
            try:
                max_level = dbconn.querySingleton(conn,
                                                  """SELECT max(partitionlevel)
                                                     FROM pg_partitions
                                                     WHERE tablename='%s'
                                                     AND schemaname='%s'
                                                  """ % (parent_table, parent_schema))
            except Exception as e:
                logger.debug('Unable to get the maximum partition level for table %s: (%s)' % (table_name, str(e)))

            try:
                partition_level = dbconn.querySingleton(conn,
                                                        """SELECT partitionlevel
                                                           FROM pg_partitions
                                                           WHERE partitiontablename='%s'
                                                           AND partitionschemaname='%s'
                                                        """ % (table_name, schema_name))
            except Exception as e:
                logger.debug('Unable to get the partition level for table %s: (%s)' % (table_name, str(e)))

        if partition_level != max_level:
            logger.error('Partition level of the table = %s, Max partition level = %s' % (partition_level, max_level))
            raise Exception('Mid Level partition %s.%s is not supported by gpreload. Please specify only leaf partitions or parent table name' % (schema_name, table_name))

    def validate_options(self):
        """Check required options; fall back to $PGPORT when no port is given.

        :raises ExceptionNoStackTraceNeeded: on a missing/unreadable table
            file, missing database, or no port from either source.
        """
        if self.table_file is None:
            raise ExceptionNoStackTraceNeeded('Please specify table file')

        if not os.path.exists(self.table_file):
            raise ExceptionNoStackTraceNeeded('Unable to find table file "{table_file}"'.format(table_file=self.table_file))

        if self.database is None:
            raise ExceptionNoStackTraceNeeded('Please specify the correct database')

        if self.port is None:
            if 'PGPORT' not in os.environ:
                raise ExceptionNoStackTraceNeeded('Please specify PGPORT using -p option or set PGPORT in the environment')
            self.port = os.environ['PGPORT']

    def parse_columns(self, columns):
        """Parse ``"col [asc|desc], ..."`` into a list of (column, order) tuples.

        The sort order defaults to ``asc``. It is matched case-insensitively
        and stored in lower case.

        :raises Exception: on an empty entry, more than two tokens in an
            entry, or an unknown sort order.
        """
        sort_column_list = []
        for c in columns.split(','):
            toks = c.split()
            if not toks:
                raise Exception('Empty column')
            if len(toks) > 2:
                raise Exception('Invalid sort order specified')
            col = toks[0]
            sort_order = toks[1].lower() if len(toks) == 2 else 'asc'
            if sort_order not in ('asc', 'desc'):
                raise Exception('Invalid sort order {so}'.format(so=toks[1]))
            sort_column_list.append((col, sort_order))
        return sort_column_list

    def parse_line(self, line):
        """Parse one table-file line into (schema_name, table_name, sort_column_list).

        :raises Exception: with a description of what is malformed.
        """
        parts = line.split(':')
        if len(parts) != 2:
            raise Exception('Expected exactly one ":" between the table name and the sort columns')
        table, sort_columns = parts

        names = [t.strip() for t in table.split('.')]
        if len(names) != 2:
            raise Exception('Table name must be of the form schema.table')
        schema_name, table_name = names
        if not schema_name or not table_name:
            raise Exception('Empty schema or table name')

        sort_column_list = self.parse_columns(sort_columns)
        return schema_name, table_name, sort_column_list

    def validate_table_file(self):
        """Parse every non-blank line of ``self.table_file``.

        :returns: list of (schema_name, table_name, sort_column_list) tuples.
        :raises ExceptionNoStackTraceNeeded: quoting the first malformed line.
        """
        table_list = []
        with open(self.table_file) as fp:
            for line in fp:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. separators) carry no table entry.
                    continue
                try:
                    schema_name, table_name, sort_column_list = self.parse_line(line)
                except Exception as e:
                    raise ExceptionNoStackTraceNeeded("Line '{line}' is not formatted correctly: {ex}".format(line=line, ex=e))
                table_list.append((schema_name, table_name, sort_column_list))
        return table_list

    def validate_tables(self):
        """Run the partition, existence and column checks for every listed table."""
        for schema_name, table_name, sort_column_list in self.table_list:
            self.validate_mid_level_partitions(schema_name, table_name)
            self.validate_table(schema_name, table_name)
            self.validate_columns(schema_name, table_name, sort_column_list)

    def get_row_count(self, table_name):
        """Return ``count(*)`` of ``table_name`` (a possibly schema-qualified name)."""
        with closing(self._connect()) as conn:
            c = dbconn.querySingleton(conn, 'SELECT count(*) FROM {table}'.format(table=table_name))
        return c

    def check_indexes(self, schema_name, table_name):
        """Decide whether to reload a table, given its indexes.

        :returns: True if the table has no indexes or the run is not
            interactive; otherwise the user's answer to a yes/no prompt
            (default 'N').
        """
        with closing(self._connect()) as conn:
            c = dbconn.querySingleton(conn, """SELECT count(*)
                                             FROM pg_index
                                             WHERE indrelid = (SELECT pg_class.oid
                                                               FROM pg_class, pg_namespace
                                                               WHERE pg_class.relname='{table}' AND pg_namespace.nspname='{schema}' AND pg_class.relnamespace = pg_namespace.oid)""".format(table=table_name, schema=schema_name))
        if c != 0 and self.interactive:
            return ask_yesno(None,
                             'Table {schema}.{table} has indexes. This might slow down table reload. Do you still want to continue ?'
                             .format(schema=schema_name, table=table_name),
                             'N')
        return True

    def get_table_size(self, schema_name, table_name):
        """Return ``pg_size_pretty(pg_relation_size(...))`` for the table."""
        with closing(self._connect()) as conn:
            size = dbconn.querySingleton(conn,
                                         """SELECT pg_size_pretty(pg_relation_size('{schema}.{table}'))"""
                                         .format(schema=schema_name, table=table_name))
        return size

    def get_parent_partitions(self):
        """Fill and return ``parent_partition_map`` for every listed table.

        A table found in pg_partitions as a partition maps to its parent
        (schemaname, tablename). Any other table maps to itself.
        """
        with closing(self._connect()) as conn:
            for schema, table, _ in self.table_list:
                parent_partition_sql = """SELECT schemaname, tablename
                                          FROM pg_partitions
                                          WHERE partitiontablename='%s'
                                          AND partitionschemaname='%s'""" % (table, schema)
                for r in dbconn.query(conn, parent_partition_sql):
                    self.parent_partition_map[(schema, table)] = (r[0], r[1])
                self.parent_partition_map.setdefault((schema, table), (schema, table))
        return self.parent_partition_map

    def reload_tables(self):
        """Reload every listed table in its requested sort order.

        All tables share one connection. Each table is reloaded in its own
        transaction by ``_reload_table``. A table is skipped when
        ``check_indexes`` returns False.
        """
        with closing(self._connect()) as conn:
            conn.commit()   # Commit the implicit transaction started by connect
            for schema_name, table_name, sort_column_list in self.table_list:
                logger.info('Starting reload for table {schema}.{table}'.format(schema=schema_name, table=table_name))
                logger.info('Table {schema}.{table} has {rows} rows and {size} size'
                            .format(schema=schema_name, table=table_name,
                                    rows=self.get_row_count('%s.%s' % (schema_name, table_name)),
                                    size=self.get_table_size(schema_name, table_name)))
                if not self.check_indexes(schema_name, table_name):
                    logger.info('Skipping reload for {schema}.{table}'.format(schema=schema_name, table=table_name))
                    continue
                start = time.time()
                self._reload_table(conn, schema_name, table_name, sort_column_list)
                end = time.time()
                logger.info('Finished reload for table {schema}.{table} in time {sec} seconds'
                            .format(schema=schema_name, table=table_name, sec=(end-start)))

    def _reload_table(self, conn, schema_name, table_name, sort_column_list):
        """Copy, truncate and re-insert one table in sorted order on ``conn``.

        :raises Exception: if the temporary copy's row count differs from the
            table's. In that case nothing is committed here.
        """
        dbconn.execSQL(conn, 'BEGIN')
        dbconn.execSQL(conn, """CREATE TEMP TABLE temp_{table} AS SELECT * FROM {schema}.{table}"""
                             .format(schema=schema_name, table=table_name))
        temp_row_count = dbconn.querySingleton(conn, """SELECT count(*) FROM temp_{table}""".format(table=table_name))
        table_row_count = dbconn.querySingleton(conn, """SELECT count(*) from {schema}.{table}"""
                                                .format(table=table_name, schema=schema_name))
        if temp_row_count != table_row_count:
            raise Exception('Row count for temp table(%s) does not match(%s)' % (temp_row_count, table_row_count))
        dbconn.execSQL(conn, 'TRUNCATE TABLE {schema}.{table}'.format(schema=schema_name, table=table_name))
        sort_order = ['%s %s' % (col, order) for col, order in sort_column_list]
        # Insert through the partition parent (or the table itself) so that
        # rows are routed to the correct partition.
        parent_schema_name, parent_table_name = self.parent_partition_map[(schema_name, table_name)]
        dbconn.execSQL(conn, """INSERT INTO {parent_schema}.{parent_table} SELECT * FROM temp_{table} ORDER BY {column_list}"""
                             .format(parent_schema=parent_schema_name, parent_table=parent_table_name,
                                     table=table_name, column_list=','.join(sort_order)))
        # Drop the copy before committing. Otherwise it lives until the session
        # ends, holding a full copy of the data, and a later entry with the
        # same table name (e.g. in another schema) fails because temp_<table>
        # already exists.
        dbconn.execSQL(conn, 'DROP TABLE temp_{table}'.format(table=table_name))
        conn.commit()

    def run(self):
        """Validate options, the table file and the tables, then reload them."""
        self.validate_options()
        logger.info('Validating table file {table_file}'.format(table_file=self.table_file))
        self.table_list = self.validate_table_file()
        logger.info('Obtaining parent partitions')
        self.get_parent_partitions()
        logger.info('Validating tables')
        self.validate_tables()
        logger.info('Reloading tables')
        self.reload_tables()

    def cleanup(self):
        """No-op hook; nothing needs cleaning up after a run."""
        pass