reload.py
#!/usr/bin/env python
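"""
Support module for gpreload: reloads tables with their rows in sorted order.

For each table listed in the table file, the data is copied into a temporary
table, the row counts are verified, the original table is truncated, and the
rows are re-inserted through the (parent) table ordered by the requested sort
columns.  Each line of the table file has the form

    schema.table: column [asc|desc] [, column [asc|desc] ...]
"""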

import os
import time

from gppylib import gplog
from gppylib.db import dbconn
from gppylib.db.dbconn import execSQL, execSQLForSingleton
from gppylib.mainUtils import ExceptionNoStackTraceNeeded
from gppylib.userinput import ask_yesno


logger = gplog.get_default_logger()

class GpReload:
    def __init__(self, options, args):
        self.table_file = options.table_file
        self.port = options.port
        self.database = options.database
        self.interactive = options.interactive
        self.table_list = []
        self.parent_partition_map = {}

    def validate_table(self, schema_name, table_name):
        with dbconn.connect(dbconn.DbURL(dbname=self.database, port=self.port)) as conn:
            c = execSQLForSingleton(conn,
                                    """SELECT count(*)
                                       FROM pg_class, pg_namespace
                                       WHERE pg_class.relnamespace = pg_namespace.oid
                                       AND pg_namespace.nspname = '{schema}'
                                       AND pg_class.relname = '{table}'""".format(schema=schema_name, table=table_name))
            if not c:
                raise ExceptionNoStackTraceNeeded('Table {schema}.{table} does not exist'
                                                  .format(schema=schema_name, table=table_name))

    def validate_columns(self, schema_name, table_name, sort_column_list):
        columns = []
        with dbconn.connect(dbconn.DbURL(dbname=self.database, port=self.port)) as conn:
            res = execSQL(conn,
                          """SELECT attname
                             FROM pg_attribute
                             WHERE attrelid = (SELECT pg_class.oid
                                               FROM pg_class, pg_namespace
                                               WHERE pg_class.relname = '{table}' AND pg_namespace.nspname = '{schema}'
                                               AND pg_class.relnamespace = pg_namespace.oid)"""
                              .format(table=table_name, schema=schema_name))
            for cols in res.fetchall():
                columns.append(cols[0].strip())
            for c in sort_column_list:
                if c[0] not in columns:
                    raise ExceptionNoStackTraceNeeded('Table {schema}.{table} does not have column {col}'
                                                       .format(schema=schema_name, table=table_name, col=c[0]))

    def validate_mid_level_partitions(self, schema_name, table_name):
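        """
        Reject mid-level partitions: if the table is a child partition, its
        partition level must equal the deepest partition level of its root
        table, i.e. it must be a leaf partition.  Non-partition tables (which
        map to themselves in parent_partition_map) are accepted as-is.
        """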
        partition_level, max_level = None, None
        with dbconn.connect(dbconn.DbURL(dbname=self.database, port=self.port)) as conn:
            parent_schema, parent_table = self.parent_partition_map[(schema_name, table_name)]
            if (parent_schema, parent_table) == (schema_name, table_name):
                return
            try:
                max_level = dbconn.execSQLForSingleton(conn, 
                                                   """SELECT max(partitionlevel)
                                                      FROM pg_partitions
                                                      WHERE tablename='%s'
                                                      AND schemaname='%s'
                                                   """ % (parent_table, parent_schema))
            except Exception as e:
                logger.debug('Unable to get the maximum partition level for table %s: (%s)' % (table_name, str(e)))

            try:
                partition_level = dbconn.execSQLForSingleton(conn,
                                                         """SELECT partitionlevel
                                                            FROM pg_partitions
                                                            WHERE partitiontablename='%s'
                                                            AND partitionschemaname='%s'
                                                         """ % (table_name, schema_name))
            except Exception as e:
                logger.debug('Unable to get the partition level for table %s: (%s)' % (table_name, str(e)))

            if partition_level != max_level:
                logger.error('Partition level of the table = %s, Max partition level = %s' % (partition_level, max_level))
                raise Exception('Mid-level partition %s.%s is not supported by gpreload. Please specify only leaf partitions or the parent table name' % (schema_name, table_name))

    def validate_options(self):
        if self.table_file is None:
            raise ExceptionNoStackTraceNeeded('Please specify table file')

        if not os.path.exists(self.table_file):
            raise ExceptionNoStackTraceNeeded('Unable to find table file "{table_file}"'.format(table_file=self.table_file))

        if self.database is None:
            raise ExceptionNoStackTraceNeeded('Please specify the correct database')

        if self.port is None:
            if 'PGPORT' not in os.environ:
                raise ExceptionNoStackTraceNeeded('Please specify PGPORT using -p option or set PGPORT in the environment')
            self.port = os.environ['PGPORT']

    def parse_columns(self, columns):
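        """
        Parse a comma-separated column specification such as
        'col1 asc, col2 desc, col3' into a list of (column, sort_order)
        tuples.  The sort order defaults to 'asc' when omitted.
        """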
        sort_column_list = []
        for c in columns.split(','):
            toks = c.strip().split()
            if not toks:
                raise Exception('Empty column')
            col = toks[0].strip()
            if len(toks) == 1:
                sort_order = 'asc'
            elif len(toks) == 2:
                sort_order = toks[1].strip()
            else:
                raise Exception('Invalid sort order specified')

            if sort_order not in ('asc', 'desc'):
                raise Exception('Invalid sort order {so}'.format(so=sort_order))
            sort_column_list.append((col, sort_order))
        return sort_column_list

    def parse_line(self, line):
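        """
        Parse one line of the table file, e.g.
        'public.sales: region asc, id desc', into
        (schema_name, table_name, sort_column_list).
        """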
        table, sort_columns = line.split(':')
        schema_name, table_name = [t.strip() for t in table.split('.')]
        sort_column_list = self.parse_columns(sort_columns)

        if not schema_name or not table_name:
            raise Exception('Empty schema or table name')

        return schema_name, table_name, sort_column_list

    def validate_table_file(self):
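        """
        Read the table file and return a list of
        (schema_name, table_name, sort_column_list) tuples, turning any parse
        error into a user-facing error that names the offending line.
        """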
        table_list = []
        with open(self.table_file) as fp:
            for line in fp:
                line = line.strip()
                try:
                    schema_name, table_name, sort_column_list = self.parse_line(line)
                except Exception as e:
                    raise ExceptionNoStackTraceNeeded("Line '{line}' is not formatted correctly: {ex}".format(line=line, ex=e))
                table_list.append((schema_name, table_name, sort_column_list))
        return table_list

    def validate_tables(self):
        for schema_name, table_name, sort_column_list in self.table_list:
            self.validate_mid_level_partitions(schema_name, table_name)
            self.validate_table(schema_name, table_name)
            self.validate_columns(schema_name, table_name, sort_column_list)

    def get_row_count(self, table_name):
        with dbconn.connect(dbconn.DbURL(dbname=self.database, port=self.port)) as conn:
            c = execSQLForSingleton(conn, 'SELECT count(*) FROM {table}'.format(table=table_name))
        return c

    def check_indexes(self, schema_name, table_name):
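        """
        Check whether the table has indexes.  In interactive mode the user is
        prompted (defaulting to 'N') before reloading an indexed table;
        otherwise the reload proceeds regardless.
        """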
        with dbconn.connect(dbconn.DbURL(dbname=self.database, port=self.port)) as conn:
            c = execSQLForSingleton(conn, """SELECT count(*)
                                             FROM pg_index
                                             WHERE indrelid = (SELECT pg_class.oid
                                                               FROM pg_class, pg_namespace
                                                               WHERE pg_class.relname='{table}' AND pg_namespace.nspname='{schema}')
                                          """.format(table=table_name, schema=schema_name))
            if c != 0:
                if self.interactive:
                    return ask_yesno(None,
                                     'Table {schema}.{table} has indexes. This might slow down table reload. Do you still want to continue?'
                                     .format(schema=schema_name, table=table_name),
                                     'N')
        return True

    def get_table_size(self, schema_name, table_name):
        with dbconn.connect(dbconn.DbURL(dbname=self.database, port=self.port)) as conn:
            size = execSQLForSingleton(conn,
                                       """SELECT pg_size_pretty(pg_relation_size('{schema}.{table}'))"""
                                       .format(schema=schema_name, table=table_name))
        return size

    def get_parent_partitions(self):
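        """
        Map each (schema, table) in the table list to the schema and name of
        its root partitioned table, as reported by pg_partitions.  Tables that
        are not partitions map to themselves.  reload_tables later re-inserts
        rows through this parent so that they are routed to the correct leaf
        partitions.
        """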
        with dbconn.connect(dbconn.DbURL(dbname=self.database, port=self.port)) as conn:
            for schema, table, col_list in self.table_list:
                parent_partition_sql = """SELECT schemaname, tablename
                                          FROM pg_partitions
                                          WHERE partitiontablename='%s'
                                          AND partitionschemaname='%s'""" % (table, schema)
                res = execSQL(conn, parent_partition_sql)
                for r in res:
                    self.parent_partition_map[(schema, table)] = (r[0], r[1])

                if (schema, table) not in self.parent_partition_map:
                    self.parent_partition_map[(schema, table)] = (schema, table)

            return self.parent_partition_map 

    def reload_tables(self):
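        """
        Reload each table in its own transaction: copy the rows into a
        temporary table, verify that the row counts match, truncate the
        original table, then re-insert the rows through the parent table
        ordered by the requested sort columns.
        """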
        with dbconn.connect(dbconn.DbURL(dbname=self.database, port=self.port)) as conn:
            conn.commit()   #Commit the implicit transaction started by connect
            for schema_name, table_name, sort_column_list in self.table_list:
                logger.info('Starting reload for table {schema}.{table}'.format(schema=schema_name, table=table_name))
                logger.info('Table {schema}.{table} has {rows} rows and size {size}'
                            .format(schema=schema_name, table=table_name,
                                    rows=self.get_row_count('%s.%s' % (schema_name, table_name)),
                                    size=self.get_table_size(schema_name, table_name)))
                if not self.check_indexes(schema_name, table_name):
                    logger.info('Skipping reload for {schema}.{table}'.format(schema=schema_name, table=table_name))
                    continue
                start = time.time()
                dbconn.execSQL(conn, 'BEGIN')
                dbconn.execSQL(conn, """CREATE TEMP TABLE temp_{table} AS SELECT * FROM {schema}.{table}"""
                                     .format(schema=schema_name, table=table_name))
                temp_row_count = dbconn.execSQLForSingleton(conn, """SELECT count(*) FROM temp_{table}""".format(table=table_name))
                table_row_count = dbconn.execSQLForSingleton(conn, """SELECT count(*) from {schema}.{table}"""
                                                                    .format(table=table_name, schema=schema_name))
                if temp_row_count != table_row_count:
                    raise Exception('Row count of temp table (%s) does not match the row count of %s.%s (%s)' % (temp_row_count, schema_name, table_name, table_row_count))
                dbconn.execSQL(conn, 'TRUNCATE TABLE {schema}.{table}'.format(schema=schema_name, table=table_name))
                sort_order = ['%s %s' % (c[0], c[1]) for c in sort_column_list]
                parent_schema_name, parent_table_name = self.parent_partition_map[(schema_name, table_name)]
                dbconn.execSQL(conn, """INSERT INTO {parent_schema}.{parent_table} SELECT * FROM temp_{table} ORDER BY {column_list}"""
                                     .format(parent_schema=parent_schema_name, parent_table=parent_table_name, 
                                             table=table_name, column_list=','.join(sort_order)))
                conn.commit()
                end = time.time()
                logger.info('Finished reload for table {schema}.{table} in {sec:.2f} seconds'
                            .format(schema=schema_name, table=table_name, sec=(end - start)))

    def run(self):
        self.validate_options()
        logger.info('Validating table file {table_file}'.format(table_file=self.table_file))
        self.table_list = self.validate_table_file()
        logger.info('Obtaining parent partitions')
        self.get_parent_partitions()
        logger.info('Validating tables')
        self.validate_tables()
        logger.info('Reloading tables')
        self.reload_tables()

    def cleanup(self):
        pass
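
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): GpReload is normally driven by the
# gpreload command-line utility, which builds the options object from its
# option parser.  The Namespace below is a hypothetical stand-in that only
# supplies the attributes read in GpReload.__init__ (table_file, port,
# database, interactive); the file path and database name are made up, and
# running this requires a reachable Greenplum instance.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    # Each line of the table file: 'schema.table: column [asc|desc], ...'
    options = Namespace(table_file='/tmp/reload_tables.txt',  # hypothetical path
                        port=None,           # falls back to PGPORT when unset
                        database='gptest',   # hypothetical database name
                        interactive=False)
    reloader = GpReload(options, args=[])
    reloader.run()
    reloader.cleanup()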