提交 96f5e6e4 编写于 作者: S Shoaib Lari

unset_search_path: backend fix

For commands called directly by the user, we provide the fix.

Since Behave and unit tests are supposed to behave as a normal user,
we do not provide the fix.  The fix is supposed to be done by the
commands themselves, and we want to test with an unmodified search_path
in the actual tests.
Co-authored-by: NJamie McAtamney <jmcatamney@pivotal.io>
Co-authored-by: NKalen Krempely <kkrempely@pivotal.io>
Co-authored-by: NNikolaos Kalampalikis <nkalampalikis@pivotal.io>
Co-authored-by: NShoaib Lari <slari@pivotal.io>
Co-authored-by: NDavid Krieger <dkrieger@pivotal.io>
上级 452a463f
......@@ -366,7 +366,7 @@ class ClusterConfiguration():
print '%s: fetched cluster configuration' % (datetime.datetime.now())
try:
with dbconn.connect(dburl, utility=True) as conn:
with dbconn.connect(dburl, utility=True, unsetSearchPath=False) as conn:
resultsets = dbconn.execSQL(conn, query).fetchall()
except Exception, e:
print e
......
......@@ -222,7 +222,10 @@ def do_list(skipvalidation):
def get_gucs_from_database(gucname):
try:
dburl = dbconn.DbURL()
conn = dbconn.connect(dburl, False)
# we always want to unset search path except when getting the
# 'search_path' GUC itself
unsetSearchPath = gucname != 'search_path'
conn = dbconn.connect(dburl, False, unsetSearchPath=unsetSearchPath)
query = ToolkitQuery(gucname).query
cursor = dbconn.execSQL(conn, query)
# we assume that all roles are primary due to the query.
......
......@@ -155,7 +155,7 @@ class DbURL:
def connect(dburl, utility=False, verbose=False,
encoding=None, allowSystemTableMods=False, logConn=True):
encoding=None, allowSystemTableMods=False, logConn=True, unsetSearchPath=True):
if utility:
options = '-c gp_session_role=utility'
......@@ -225,8 +225,14 @@ def connect(dburl, utility=False, verbose=False,
if cnx is None:
raise ConnectionError('Failed to connect to %s' % dbbase)
# NOTE: the code to set ALWAYS_SECURE_SEARCH_PATH_SQL below assumes it is not part of an existing transaction
conn = pgdb.pgdbCnx(cnx)
# unset search path due to CVE-2018-1058
if unsetSearchPath:
ALWAYS_SECURE_SEARCH_PATH_SQL = "SELECT pg_catalog.set_config('search_path', '', false)"
execSQL(conn, ALWAYS_SECURE_SEARCH_PATH_SQL).close()
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
......
#!/usr/bin/env python
#
# Copyright (c) Greenplum Inc 2008. All Rights Reserved.
# Copyright (c) Greenplum Inc 2008. All Rights Reserved.
#
# Unit Testing of catalog module.
#
......@@ -20,16 +20,16 @@ logger=gplog.get_default_logger()
@skipIfDatabaseDown()
class catalogTestCase(unittest.TestCase):
def setUp(self):
self.dburl=dbconn.DbURL()
self.conn = dbconn.connect(self.dburl)
self.conn = dbconn.connect(self.dburl, unsetSearchPath=False)
def tearDown(self):
self.conn.close()
pass
#------------------------------- Mainline --------------------------------
if __name__ == '__main__':
unittest.main()
unittest.main()
#!/usr/bin/env python
#
# Copyright (c) Greenplum Inc 2008. All Rights Reserved.
# Copyright (c) Greenplum Inc 2008. All Rights Reserved.
#
# Unit Testing of catalog module.
#
......@@ -20,16 +20,16 @@ logger=gplog.get_default_logger()
@skipIfDatabaseDown()
class catalogTestCase(unittest.TestCase):
def setUp(self):
self.dburl=dbconn.DbURL()
self.conn = dbconn.connect(self.dburl)
self.conn = dbconn.connect(self.dburl, unsetSearchPath=False)
def tearDown(self):
self.conn.close()
pass
#------------------------------- Mainline --------------------------------
if __name__ == '__main__':
unittest.main()
unittest.main()
......@@ -58,6 +58,19 @@ class ConnectTestCase(unittest.TestCase):
self.assertEqual(actual, encoding)
def test_secure_search_path_set(self):
with dbconn.connect(self.url) as conn:
result = dbconn.execSQLForSingleton(conn, "SELECT setting FROM pg_settings WHERE name='search_path'")
self.assertEqual(result, '')
def test_secure_search_path_not_set(self):
with dbconn.connect(self.url, unsetSearchPath=False) as conn:
result = dbconn.execSQLForSingleton(conn, "SELECT setting FROM pg_settings WHERE name='search_path'")
self.assertEqual(result, '"$user",public')
if __name__ == '__main__':
unittest.main()
......@@ -8,8 +8,8 @@ from gppylib.commands.gp import GpStart
from gppylib.db import dbconn
class GpExpandTestCase(unittest.TestCase):
EXPANSION_INPUT_FILE = 'test_expand.input'
EXPANSION_INPUT_FILE = 'test_expand.input'
GP_COMMAND_FAULT_POINT = 'GP_COMMAND_FAULT_POINT'
GPMGMT_FAULT_POINT = 'GPMGMT_FAULT_POINT'
MASTER_DATA_DIRECTORY = os.environ['MASTER_DATA_DIRECTORY']
......@@ -29,13 +29,13 @@ class GpExpandTestCase(unittest.TestCase):
os.remove(self.EXPANSION_INPUT_FILE)
if self.GP_COMMAND_FAULT_POINT in os.environ:
del os.environ[self.GP_COMMAND_FAULT_POINT]
def _create_expansion_input_file(self):
"""This code has been taken from system_management utilities
test suite.
creates a expansion input file"""
with dbconn.connect(dbconn.DbURL()) as conn:
with dbconn.connect(dbconn.DbURL(), unsetSearchPath=False) as conn:
next_dbid = dbconn.execSQLForSingletonRow(conn,
"select max(dbid)+1 \
from pg_catalog.gp_segment_configuration")[0]
......@@ -56,9 +56,9 @@ class GpExpandTestCase(unittest.TestCase):
where role='m'").fetchall()[0][0]
if next_mir_port is None or next_mir_port == ' ' or next_mir_port == 0:
mirroring_on = False
mirroring_on = False
else:
mirroring_on = True
mirroring_on = True
next_pri_replication_port = dbconn.execSQL(conn,
"select max(replication_port)+1 \
from pg_catalog.gp_segment_configuration \
......@@ -81,13 +81,13 @@ class GpExpandTestCase(unittest.TestCase):
with open(self.EXPANSION_INPUT_FILE, 'w') as outfile:
for i in range(self.SEGMENTS):
pri_datadir = os.path.join(os.getcwd(), 'new_pri_seg%d' % i)
mir_datadir = os.path.join(os.getcwd(), 'new_mir_seg%d' % i)
pri_datadir = os.path.join(os.getcwd(), 'new_pri_seg%d' % i)
mir_datadir = os.path.join(os.getcwd(), 'new_mir_seg%d' % i)
temp_str = "%s:%s:%d:%s:%d:%d:%s" % (self.primary_host_name, self.primary_host_address, next_pri_port, pri_datadir, next_dbid, next_content, 'p')
if mirroring_on:
temp_str = temp_str + ":" + str(next_pri_replication_port)
temp_str = temp_str + "\n"
temp_str = temp_str + "\n"
outfile.write(temp_str)
if mirroring_on: # The content number for mirror is same as the primary segment's content number
......@@ -96,26 +96,26 @@ class GpExpandTestCase(unittest.TestCase):
next_mir_port += 1
next_pri_replication_port += 1
next_mir_replication_port += 1
next_pri_port += 1
next_dbid += 1
next_content += 1
def _create_tables(self):
with dbconn.connect(dbconn.DbURL()) as conn:
with dbconn.connect(dbconn.DbURL(), unsetSearchPath=False) as conn:
for i in range(self.NUM_TABLES):
dbconn.execSQL(conn, 'create table tab%d(i integer)' % i)
conn.commit()
def _drop_tables(self):
with dbconn.connect(dbconn.DbURL()) as conn:
with dbconn.connect(dbconn.DbURL(), unsetSearchPath=False) as conn:
for i in range(self.NUM_TABLES):
dbconn.execSQL(conn, 'drop table tab%d' % i)
dbconn.execSQL(conn, 'drop table tab%d' % i)
conn.commit()
def _get_dist_policies(self):
policies = []
with dbconn.connect(dbconn.DbURL()) as conn:
with dbconn.connect(dbconn.DbURL(), unsetSearchPath=False) as conn:
cursor = dbconn.execSQL(conn, 'select * from gp_distribution_policy;').fetchall()
for row in cursor:
policies.append(row)
......@@ -128,16 +128,16 @@ class GpExpandTestCase(unittest.TestCase):
cmd = Command(name='run gpexpand', cmdStr='gpexpand -i %s' % (self.EXPANSION_INPUT_FILE))
with self.assertRaisesRegexp(ExecutionError, 'Fault Injection'):
cmd.run(validateAfter=True)
#Read from the pg_hba.conf file and ensure that
#Read from the pg_hba.conf file and ensure that
#The address of the new hosts is present.
cmd = Command(name='get the temp pg_hba.conf file',
cmd = Command(name='get the temp pg_hba.conf file',
cmdStr="ls %s" % os.path.join(os.path.dirname(self.MASTER_DATA_DIRECTORY),
'gpexpand*',
'pg_hba.conf'))
cmd.run(validateAfter=True)
results = cmd.get_results()
temp_pg_hba_conf = results.stdout.strip()
temp_pg_hba_conf = results.stdout.strip()
actual_values = set()
expected_values = set([self.primary_host_address, self.mirror_host_address])
......@@ -156,9 +156,9 @@ class GpExpandTestCase(unittest.TestCase):
GpStart(name='start the database').run(validateAfter=True)
def test01_distribution_policy(self):
self._create_tables()
try:
os.environ[self.GPMGMT_FAULT_POINT] = 'gpexpand MPP-14620 fault injection'
original_dist_policies = self._get_dist_policies()
......@@ -168,10 +168,9 @@ class GpExpandTestCase(unittest.TestCase):
rollback = Command(name='rollback expansion', cmdStr='gpexpand -r')
rollback.run(validateAfter=True)
dist_policies = self._get_dist_policies()
dist_policies = self._get_dist_policies()
self.assertEqual(original_dist_policies, dist_policies)
finally:
self._drop_tables()
......@@ -32,7 +32,7 @@ def before_feature(context, feature):
create_database(context, 'incr_analyze')
drop_database_if_exists(context, 'incr_analyze_2')
create_database(context, 'incr_analyze_2')
context.conn = dbconn.connect(dbconn.DbURL(dbname='incr_analyze'))
context.conn = dbconn.connect(dbconn.DbURL(dbname='incr_analyze'), unsetSearchPath=False)
context.dbname = 'incr_analyze'
# setting up the tables that will be used
......@@ -48,7 +48,7 @@ def before_feature(context, feature):
minirepro_db = 'minireprodb'
drop_database_if_exists(context, minirepro_db)
create_database(context, minirepro_db)
context.conn = dbconn.connect(dbconn.DbURL(dbname=minirepro_db))
context.conn = dbconn.connect(dbconn.DbURL(dbname=minirepro_db), unsetSearchPath=False)
context.dbname = minirepro_db
dbconn.execSQL(context.conn, 'create table t1(a integer, b integer)')
dbconn.execSQL(context.conn, 'create table t2(c integer, d integer)')
......
......@@ -213,7 +213,7 @@ def impl(context, mod_count, table, schema, dbname):
@then('root stats are populated for partition table "{tablename}" for database "{dbname}"')
def impl(context, tablename, dbname):
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
query = "select count(*) from pg_statistic where starelid='%s'::regclass;" % tablename
num_tuples = dbconn.execSQLForSingleton(conn, query)
if num_tuples == 0:
......@@ -230,7 +230,7 @@ def get_mod_count_in_state_file(dbname, schema, table):
def create_long_lived_conn(context, dbname):
context.long_lived_conn = dbconn.connect(dbconn.DbURL(dbname=dbname))
context.long_lived_conn = dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False)
def table_found_in_state_file(dbname, qualified_table):
......
......@@ -174,12 +174,12 @@ def impl(context, dbname, cname):
if cname in context.named_conns:
context.named_conns[cname].close()
del context.named_conns[cname]
context.named_conns[cname] = dbconn.connect(dbconn.DbURL(dbname=dbname))
context.named_conns[cname] = dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False)
@given('the user create a writable external table with name "{tabname}"')
def impl(conetxt, tabname):
dbname = 'gptest'
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
sql = ("create writable external table {tabname}(a int) location "
"('gpfdist://host.invalid:8000/file') format 'text'").format(tabname=tabname)
dbconn.execSQL(conn, sql)
......@@ -527,7 +527,7 @@ def impl(context, table_type, tablename, dbname):
def impl(context, table_type, tablename, dbname, numrows):
if not check_table_exists(context, dbname=dbname, table_name=tablename, table_type=table_type):
raise Exception("Table '%s' of type '%s' does not exist when expected" % (tablename, table_type))
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
rowcount = dbconn.execSQLForSingleton(conn, "SELECT count(*) FROM %s" % tablename)
if rowcount != numrows:
raise Exception("Expected to find %d rows in table %s, found %d" % (numrows, tablename, rowcount))
......@@ -578,7 +578,7 @@ def impl(context, row_values, table, dbname):
@then('verify that database "{dbname}" does not exist')
def impl(context, dbname):
with dbconn.connect(dbconn.DbURL(dbname='template1')) as conn:
with dbconn.connect(dbconn.DbURL(dbname='template1'), unsetSearchPath=False) as conn:
sql = """SELECT datname FROM pg_database"""
dbs = dbconn.execSQL(conn, sql)
if dbname in dbs:
......@@ -613,7 +613,7 @@ def impl(context, filepath):
def impl(context, sql, dbname):
context.stored_sql_results = []
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
curs = dbconn.execSQL(conn, sql)
context.stored_sql_results = curs.fetchall()
......@@ -697,7 +697,7 @@ def impl(context):
@given('the user runs gpinitstandby with options "{options}"')
def impl(context, options):
dbname = 'postgres'
with dbconn.connect(dbconn.DbURL(port=os.environ.get("PGPORT"), dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(port=os.environ.get("PGPORT"), dbname=dbname), unsetSearchPath=False) as conn:
query = """select distinct content, hostname from gp_segment_configuration order by content limit 2;"""
cursor = dbconn.execSQL(conn, query)
......@@ -884,7 +884,7 @@ def impl(context, sql, dbname):
def impl(context, dbname):
context.stored_rows = []
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
curs = dbconn.execSQL(conn, context.text)
context.stored_rows = curs.fetchall()
......@@ -893,7 +893,7 @@ def impl(context, dbname):
def impl(context, sql, dbname):
context.stored_rows = []
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
curs = dbconn.execSQL(conn, sql)
context.stored_rows = curs.fetchall()
......@@ -1181,7 +1181,7 @@ def impl(context):
def impl(context):
check_segment_config_query = "SELECT * FROM gp_segment_configuration WHERE content = -1 AND role = 'm'"
check_stat_replication_query = "SELECT * FROM pg_stat_replication"
with dbconn.connect(dbconn.DbURL(dbname='postgres')) as conn:
with dbconn.connect(dbconn.DbURL(dbname='postgres'), unsetSearchPath=False) as conn:
segconfig = dbconn.execSQL(conn, check_segment_config_query).fetchall()
statrep = dbconn.execSQL(conn, check_stat_replication_query).fetchall()
......@@ -1196,7 +1196,7 @@ def impl(context):
@then('verify the standby master is now acting as master')
def impl(context):
check_segment_config_query = "SELECT * FROM gp_segment_configuration WHERE content = -1 AND role = 'p' AND preferred_role = 'p' AND dbid = %s" % context.standby_dbid
with dbconn.connect(dbconn.DbURL(hostname=context.standby_hostname, dbname='postgres', port=context.standby_port)) as conn:
with dbconn.connect(dbconn.DbURL(hostname=context.standby_hostname, dbname='postgres', port=context.standby_port), unsetSearchPath=False) as conn:
segconfig = dbconn.execSQL(conn, check_segment_config_query).fetchall()
if len(segconfig) != 1:
......@@ -1374,7 +1374,7 @@ def impl(context, filename, some, output):
@then('verify that the file "{filename}" in each segment data directory has "{some}" line starting with "{output}"')
def impl(context, filename, some, output):
try:
with dbconn.connect(dbconn.DbURL(dbname='template1')) as conn:
with dbconn.connect(dbconn.DbURL(dbname='template1'), unsetSearchPath=False) as conn:
curs = dbconn.execSQL(conn, "SELECT hostname, datadir FROM gp_segment_configuration WHERE role='p' AND content > -1;")
result = curs.fetchall()
segment_info = [(result[s][0], result[s][1]) for s in range(len(result))]
......@@ -1412,7 +1412,7 @@ def impl(context, filename, some, output):
def impl(context, filename, output):
segment_info = []
try:
with dbconn.connect(dbconn.DbURL(dbname='template1')) as conn:
with dbconn.connect(dbconn.DbURL(dbname='template1'), unsetSearchPath=False) as conn:
curs = dbconn.execSQL(conn, "SELECT hostname, datadir FROM gp_segment_configuration WHERE role='p' AND content > -1;")
result = curs.fetchall()
segment_info = [(result[s][0], result[s][1]) for s in range(len(result))]
......@@ -1654,7 +1654,7 @@ def impl(context, dir):
'the entry for the table "{user_table}" is removed from "{catalog_table}" with key "{primary_key}" in the database "{db_name}"')
def impl(context, user_table, catalog_table, primary_key, db_name):
delete_qry = "delete from %s where %s='%s'::regclass::oid;" % (catalog_table, primary_key, user_table)
with dbconn.connect(dbconn.DbURL(dbname=db_name)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=db_name), unsetSearchPath=False) as conn:
for qry in ["set allow_system_table_mods=true;", "set allow_segment_dml=true;", delete_qry]:
dbconn.execSQL(conn, qry)
conn.commit()
......@@ -1667,7 +1667,7 @@ def impl(context, user_table, catalog_table, primary_key, db_name):
delete_qry = "delete from %s where %s='%s'::regclass::oid;" % (catalog_table, primary_key, user_table)
with dbconn.connect(dbconn.DbURL(dbname=db_name, port=port, hostname=host), utility=True,
allowSystemTableMods=True) as conn:
allowSystemTableMods=True, unsetSearchPath=False) as conn:
for qry in [delete_qry]:
dbconn.execSQL(conn, qry)
conn.commit()
......@@ -1761,7 +1761,7 @@ def impl(context, table_name, db_name):
index_qry = "create table {0}(i int primary key, j varchar); create index test_index on index_table using bitmap(j)".format(
table_name)
with dbconn.connect(dbconn.DbURL(dbname=db_name)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=db_name), unsetSearchPath=False) as conn:
dbconn.execSQL(conn, index_qry)
conn.commit()
......@@ -2029,7 +2029,7 @@ def impl(context, command, target):
@then('verify that a role "{role_name}" exists in database "{dbname}"')
def impl(context, role_name, dbname):
query = "select rolname from pg_roles where rolname = '%s'" % role_name
conn = dbconn.connect(dbconn.DbURL(dbname=dbname))
conn = dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False)
try:
result = getRows(dbname, query)[0][0]
if result != role_name:
......@@ -2114,7 +2114,7 @@ def _create_cluster(context, master_host, segment_host_list, with_mirrors=False,
os.environ['MASTER_DATA_DIRECTORY'] = os.path.join(context.working_directory,
'data/master/gpseg-1')
try:
with dbconn.connect(dbconn.DbURL(dbname='template1')) as conn:
with dbconn.connect(dbconn.DbURL(dbname='template1'), unsetSearchPath=False) as conn:
curs = dbconn.execSQL(conn, "select count(*) from gp_segment_configuration where role='m';")
count = curs.fetchall()[0][0]
if not with_mirrors and count == 0:
......@@ -2267,7 +2267,7 @@ def impl(context):
@then('the numsegments of table "{tabname}" is {numsegments}')
def impl(context, tabname, numsegments):
dbname = 'gptest'
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
query = "select numsegments from gp_distribution_policy where localoid = '{tabname}'::regclass".format(tabname=tabname)
ns = dbconn.execSQLForSingleton(conn, query)
......@@ -2282,7 +2282,7 @@ def impl(context, tabname, numsegments):
@then('the number of segments have been saved')
def impl(context):
dbname = 'gptest'
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
query = """SELECT count(*) from gp_segment_configuration where -1 < content"""
context.start_data_segments = dbconn.execSQLForSingleton(conn, query)
......@@ -2292,7 +2292,7 @@ def impl(context):
def impl(context):
dbname = 'gptest'
gp_segment_conf_backup = {}
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
query = """SELECT count(*) from gp_segment_configuration where -1 < content"""
segment_count = int(dbconn.execSQLForSingleton(conn, query))
query = """SELECT * from gp_segment_configuration where -1 < content order by dbid"""
......@@ -2318,7 +2318,7 @@ def impl(context):
def impl(context):
dbname = 'gptest'
gp_segment_conf_backup = {}
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
query = """SELECT count(*) from gp_segment_configuration where -1 < content"""
segment_count = int(dbconn.execSQLForSingleton(conn, query))
query = """SELECT * from gp_segment_configuration where -1 < content order by dbid"""
......@@ -2342,7 +2342,7 @@ def impl(context):
@given('user has created {table_name} table')
def impl(context, table_name):
dbname = 'gptest'
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
query = """CREATE TABLE %s(a INT)""" % table_name
dbconn.execSQL(conn, query)
conn.commit()
......@@ -2350,7 +2350,7 @@ def impl(context, table_name):
@given('a long-run read-only transaction exists on {table_name}')
def impl(context, table_name):
dbname = 'gptest'
conn = dbconn.connect(dbconn.DbURL(dbname=dbname))
conn = dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False)
context.long_run_select_only_conn = conn
query = """SELECT gp_segment_id, * from %s order by 1, 2""" % table_name
......@@ -2381,7 +2381,7 @@ def impl(context, table_name):
@given('a long-run transaction starts')
def impl(context):
dbname = 'gptest'
conn = dbconn.connect(dbconn.DbURL(dbname=dbname))
conn = dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False)
context.long_run_conn = conn
query = """SELECT txid_current()"""
......@@ -2413,7 +2413,7 @@ def impl(context, table_name):
@then('verify that the cluster has {num_of_segments} new segments')
def impl(context, num_of_segments):
dbname = 'gptest'
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
query = """SELECT dbid, content, role, preferred_role, mode, status, port, hostname, address, datadir from gp_segment_configuration;"""
rows = dbconn.execSQL(conn, query).fetchall()
end_data_segments = 0
......@@ -2470,7 +2470,7 @@ def impl(context, hostnames):
@then('user has created expansiontest tables')
def impl(context):
dbname = 'gptest'
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
for i in range(3):
query = """drop table if exists expansiontest%s""" % (i)
dbconn.execSQL(conn, query)
......@@ -2481,7 +2481,7 @@ def impl(context):
@then('the tables have finished expanding')
def impl(context):
dbname = 'postgres'
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
query = """select fq_name from gpexpand.status_detail WHERE expansion_finished IS NULL"""
cursor = dbconn.execSQL(conn, query)
......@@ -2492,7 +2492,7 @@ def impl(context):
@given('an FTS probe is triggered')
@when('an FTS probe is triggered')
def impl(context):
with dbconn.connect(dbconn.DbURL(dbname='postgres')) as conn:
with dbconn.connect(dbconn.DbURL(dbname='postgres'), unsetSearchPath=False) as conn:
dbconn.execSQLForSingleton(conn, "SELECT gp_request_fts_probe_scan()")
@then('verify that gpstart on original master fails due to lower Timeline ID')
......@@ -2530,7 +2530,7 @@ def step_impl(context, options):
break ## down segments comes after up segments, so we can break here
elif '-m' in options:
dbname = 'postgres'
with dbconn.connect(dbconn.DbURL(hostname=context.standby_hostname, port=context.standby_port, dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(hostname=context.standby_hostname, port=context.standby_port, dbname=dbname), unsetSearchPath=False) as conn:
query = """select datadir, port from pg_catalog.gp_segment_configuration where role='m' and content <> -1;"""
cursor = dbconn.execSQL(conn, query)
......@@ -2566,7 +2566,7 @@ def impl(context, config_file):
@then('check segment conf: postgresql.conf')
def step_impl(context):
query = "select dbid, port, hostname, datadir from gp_segment_configuration where content >= 0"
conn = dbconn.connect(dbconn.DbURL(dbname='postgres'))
conn = dbconn.connect(dbconn.DbURL(dbname='postgres'), unsetSearchPath=False)
segments = dbconn.execSQL(conn, query).fetchall()
for segment in segments:
dbid = "'%s'" % segment[0]
......@@ -2607,7 +2607,7 @@ def impl(context):
@then('verify the dml results again in a new transaction')
def impl(context):
dbname = 'gptest'
conn = dbconn.connect(dbconn.DbURL(dbname=dbname))
conn = dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False)
for dml, job in context.dml_jobs:
code, message = job.reverify(conn)
......@@ -2626,7 +2626,7 @@ def impl(context, table, dbname):
raise Exception("Failed to redistribute table. Expected to have more than %d segments, got %d segments" % (len(pre_distribution_row_count), len(post_distribution_row_count)))
post_distribution_num_segments = 0
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
query = "SELECT count(DISTINCT content) FROM gp_segment_configuration WHERE content != -1;"
cursor = dbconn.execSQL(conn, query)
post_distribution_num_segments = cursor.fetchone()[0]
......@@ -2648,7 +2648,7 @@ def impl(context, table, dbname):
(table, relative_std_error, tolerance))
def _get_row_count_per_segment(table, dbname):
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
query = "SELECT gp_segment_id,COUNT(i) FROM %s GROUP BY gp_segment_id;" % table
cursor = dbconn.execSQL(conn, query)
rows = cursor.fetchall()
......@@ -2683,7 +2683,7 @@ def impl(context):
createdb_cmd = "createdb \"%s\"" % escape_dbname
run_command(context, createdb_cmd)
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
#special char table
query = 'create table " a b.""\'\\\\"(c1 int);'
dbconn.execSQL(conn, query)
......@@ -2763,7 +2763,7 @@ CREATE TABLE mismatch_two (a text);
DROP TABLE IF EXISTS mismatch_three;
CREATE TABLE mismatch_three (a text);
-- fixed -> 1 -> 2 -> 3 -> 1
UPDATE pg_class SET reltoastrelid = (SELECT reltoastrelid FROM pg_class WHERE relname = 'mismatch_one') WHERE relname = 'mismatch_fixed'; -- "save" the reltoastrelid
UPDATE pg_class SET reltoastrelid = (SELECT reltoastrelid FROM pg_class WHERE relname = 'mismatch_two') WHERE relname = 'mismatch_one';
......@@ -2836,7 +2836,7 @@ UPDATE pg_class SET reltoastrelid = 0 WHERE relname = 'double_orphan_invalid_par
for dbURL in dbURLs:
utility = True if contentIDs else False
with dbconn.connect(dbURL, allowSystemTableMods=True, utility=utility) as conn:
with dbconn.connect(dbURL, allowSystemTableMods=True, utility=utility, unsetSearchPath=False) as conn:
dbconn.execSQL(conn, sql)
conn.commit()
......@@ -2858,20 +2858,20 @@ def impl(context, dbname):
seg0 = dbconn.DbURL(dbname=dbname, hostname=primary0.hostname, port=primary0.port)
seg1 = dbconn.DbURL(dbname=dbname, hostname=primary1.hostname, port=primary1.port)
with dbconn.connect(master, allowSystemTableMods=True) as conn:
with dbconn.connect(master, allowSystemTableMods=True, unsetSearchPath=False) as conn:
dbconn.execSQL(conn, """
DROP TABLE IF EXISTS borked;
CREATE TABLE borked (a text);
""")
conn.commit()
with dbconn.connect(seg0, utility=True, allowSystemTableMods=True) as conn:
with dbconn.connect(seg0, utility=True, allowSystemTableMods=True, unsetSearchPath=False) as conn:
dbconn.execSQL(conn, """
DELETE FROM pg_depend WHERE refobjid = 'borked'::regclass;
""")
conn.commit()
with dbconn.connect(seg1, utility=True, allowSystemTableMods=True) as conn:
with dbconn.connect(seg1, utility=True, allowSystemTableMods=True, unsetSearchPath=False) as conn:
dbconn.execSQL(conn, """
UPDATE pg_class SET reltoastrelid = 0 WHERE oid = 'borked'::regclass;
""")
......@@ -2883,7 +2883,7 @@ def impl(context):
gp_segment_configuration_backup = 'gpexpand.gp_segment_configuration'
query = "select hostname, datadir from gp_segment_configuration where content = -1 order by dbid"
conn = dbconn.connect(dbconn.DbURL(dbname='postgres'))
conn = dbconn.connect(dbconn.DbURL(dbname='postgres'), unsetSearchPath=False)
res = dbconn.execSQL(conn, query).fetchall()
master = res[0]
standby = res[1]
......
......@@ -86,7 +86,7 @@ def make_data_directory_called(data_directory_name):
def _get_mirror_count():
with dbconn.connect(dbconn.DbURL(dbname='template1')) as conn:
with dbconn.connect(dbconn.DbURL(dbname='template1'), unsetSearchPath=False) as conn:
sql = """SELECT count(*) FROM gp_segment_configuration WHERE role='m'"""
count_row = dbconn.execSQL(conn, sql).fetchone()
return count_row[0]
......
......@@ -20,12 +20,12 @@ class Tablespace:
for host in gparray.getHostList():
run_cmd('ssh %s mkdir -p %s' % (pipes.quote(host), pipes.quote(self.path)))
with dbconn.connect(dbconn.DbURL()) as conn:
with dbconn.connect(dbconn.DbURL(), unsetSearchPath=False) as conn:
db = pg.DB(conn)
db.query("CREATE TABLESPACE %s LOCATION '%s'" % (self.name, self.path))
db.query("CREATE DATABASE %s TABLESPACE %s" % (self.dbname, self.name))
with dbconn.connect(dbconn.DbURL(dbname=self.dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=self.dbname), unsetSearchPath=False) as conn:
db = pg.DB(conn)
db.query("CREATE TABLE tbl (i int) DISTRIBUTED RANDOMLY")
db.query("INSERT INTO tbl VALUES (GENERATE_SERIES(0, 25))")
......@@ -33,7 +33,7 @@ class Tablespace:
self.initial_data = db.query("SELECT gp_segment_id, i FROM tbl").getresult()
def cleanup(self):
with dbconn.connect(dbconn.DbURL(dbname="postgres")) as conn:
with dbconn.connect(dbconn.DbURL(dbname="postgres"), unsetSearchPath=False) as conn:
db = pg.DB(conn)
db.query("DROP DATABASE IF EXISTS %s" % self.dbname)
db.query("DROP TABLESPACE IF EXISTS %s" % self.name)
......@@ -56,7 +56,7 @@ class Tablespace:
distributed.
"""
url = dbconn.DbURL(hostname=hostname, port=port, dbname=self.dbname)
with dbconn.connect(url) as conn:
with dbconn.connect(url, unsetSearchPath=False) as conn:
db = pg.DB(conn)
data = db.query("SELECT gp_segment_id, i FROM tbl").getresult()
......@@ -76,7 +76,7 @@ class Tablespace:
2. the table's numsegments is enlarged to the new cluster size
"""
url = dbconn.DbURL(hostname=hostname, port=port, dbname=self.dbname)
with dbconn.connect(url) as conn:
with dbconn.connect(url, unsetSearchPath=False) as conn:
db = pg.DB(conn)
data = db.query("SELECT gp_segment_id, i FROM tbl").getresult()
tbl_numsegments = dbconn.execSQLForSingleton(conn,
......
......@@ -73,7 +73,7 @@ class Gpexpand:
def get_redistribute_status(self):
sql = 'select status from gpexpand.status order by updated desc limit 1'
dburl = dbconn.DbURL(dbname=self.database)
conn = dbconn.connect(dburl, encoding='UTF8')
conn = dbconn.connect(dburl, encoding='UTF8', unsetSearchPath=False)
status = dbconn.execSQLForSingleton(conn, sql)
if status == 'EXPANSION COMPLETE':
rc = 0
......
......@@ -29,7 +29,7 @@ class TestDML(threading.Thread):
self.prepare()
def run(self):
conn = dbconn.connect(dbconn.DbURL(dbname=self.dbname))
conn = dbconn.connect(dbconn.DbURL(dbname=self.dbname), unsetSearchPath=False)
self.loop(conn)
self.verify(conn)
......@@ -46,7 +46,7 @@ class TestDML(threading.Thread):
) DISTRIBUTED BY (c1);
'''.format(tablename=self.tablename)
conn = dbconn.connect(dbconn.DbURL(dbname=self.dbname))
conn = dbconn.connect(dbconn.DbURL(dbname=self.dbname), unsetSearchPath=False)
dbconn.execSQL(conn, sql)
self.prepare_extra(conn)
......
......@@ -31,7 +31,7 @@ if master_data_dir is None:
def execute_sql(dbname, sql):
result = None
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
result = dbconn.execSQL(conn, sql)
conn.commit()
......@@ -39,7 +39,7 @@ def execute_sql(dbname, sql):
def execute_sql_singleton(dbname, sql):
result = None
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
result = dbconn.execSQLForSingleton(conn, sql)
if result is None:
......@@ -204,7 +204,7 @@ def stop_database(context):
def stop_primary(context, content_id):
get_psegment_sql = 'select datadir, hostname from gp_segment_configuration where content=%i and role=\'p\';' % content_id
with dbconn.connect(dbconn.DbURL(dbname='template1')) as conn:
with dbconn.connect(dbconn.DbURL(dbname='template1'), unsetSearchPath=False) as conn:
cur = dbconn.execSQL(conn, get_psegment_sql)
rows = cur.fetchall()
seg_data_dir = rows[0][0]
......@@ -227,14 +227,14 @@ def run_gprecoverseg():
def getRows(dbname, exec_sql):
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
curs = dbconn.execSQL(conn, exec_sql)
results = curs.fetchall()
return results
def getRow(dbname, exec_sql):
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
curs = dbconn.execSQL(conn, exec_sql)
result = curs.fetchone()
return result
......@@ -244,7 +244,7 @@ def check_db_exists(dbname, host=None, port=0, user=None):
LIST_DATABASE_SQL = 'SELECT datname FROM pg_database'
results = []
with dbconn.connect(dbconn.DbURL(hostname=host, username=user, port=port, dbname='template1')) as conn:
with dbconn.connect(dbconn.DbURL(hostname=host, username=user, port=port, dbname='template1'), unsetSearchPath=False) as conn:
curs = dbconn.execSQL(conn, LIST_DATABASE_SQL)
results = curs.fetchall()
......@@ -259,7 +259,7 @@ def create_database_if_not_exists(context, dbname, host=None, port=0, user=None)
if not check_db_exists(dbname, host, port, user):
create_database(context, dbname, host, port, user)
context.dbname = dbname
context.conn = dbconn.connect(dbconn.DbURL(dbname=context.dbname))
context.conn = dbconn.connect(dbconn.DbURL(dbname=context.dbname), unsetSearchPath=False)
def create_database(context, dbname=None, host=None, port=0, user=None):
LOOPS = 10
......@@ -294,7 +294,7 @@ def get_segment_hostnames(context, dbname):
def check_table_exists(context, dbname, table_name, table_type=None, host=None, port=0, user=None):
with dbconn.connect(dbconn.DbURL(hostname=host, port=port, username=user, dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(hostname=host, port=port, username=user, dbname=dbname), unsetSearchPath=False) as conn:
if '.' in table_name:
schemaname, tablename = table_name.split('.')
SQL_format = """
......@@ -347,14 +347,14 @@ def drop_external_table_if_exists(context, table_name, dbname):
def drop_table_if_exists(context, table_name, dbname, host=None, port=0, user=None):
SQL = 'drop table if exists %s' % table_name
with dbconn.connect(dbconn.DbURL(hostname=host, port=port, username=user, dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(hostname=host, port=port, username=user, dbname=dbname), unsetSearchPath=False) as conn:
dbconn.execSQL(conn, SQL)
conn.commit()
def drop_external_table(context, table_name, dbname, host=None, port=0, user=None):
SQL = 'drop external table %s' % table_name
with dbconn.connect(dbconn.DbURL(hostname=host, port=port, username=user, dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(hostname=host, port=port, username=user, dbname=dbname), unsetSearchPath=False) as conn:
dbconn.execSQL(conn, SQL)
conn.commit()
......@@ -365,7 +365,7 @@ def drop_external_table(context, table_name, dbname, host=None, port=0, user=Non
def drop_table(context, table_name, dbname, host=None, port=0, user=None):
SQL = 'drop table %s' % table_name
with dbconn.connect(dbconn.DbURL(hostname=host, username=user, port=port, dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(hostname=host, username=user, port=port, dbname=dbname), unsetSearchPath=False) as conn:
dbconn.execSQL(conn, SQL)
conn.commit()
......@@ -387,7 +387,7 @@ def drop_schema_if_exists(context, schema_name, dbname):
def drop_schema(context, schema_name, dbname):
SQL = 'drop schema %s cascade' % schema_name
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
dbconn.execSQL(conn, SQL)
conn.commit()
if check_schema_exists(context, schema_name, dbname):
......@@ -445,7 +445,7 @@ def create_external_partition(context, tablename, dbname, port, filename):
drop_table_str = "Drop table %s_ret;" % (tablename)
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
dbconn.execSQL(conn, create_table_str)
dbconn.execSQL(conn, create_ext_table_str)
dbconn.execSQL(conn, alter_table_str)
......@@ -484,7 +484,7 @@ def create_partition(context, tablename, storage_type, dbname, compression_type=
create_table_str = create_table_str + ";"
with dbconn.connect(dbconn.DbURL(hostname=host, port=port, username=user, dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(hostname=host, port=port, username=user, dbname=dbname), unsetSearchPath=False) as conn:
dbconn.execSQL(conn, create_table_str)
conn.commit()
......@@ -499,7 +499,7 @@ def populate_partition(tablename, start_date, dbname, data_offset, rowcount=1094
insert_sql_str += "; insert into %s select i+%d, 'restore', i + date '%s' from generate_series(0,%d) as i" % (
tablename, data_offset, start_date, rowcount)
with dbconn.connect(dbconn.DbURL(hostname=host, port=port, username=user, dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(hostname=host, port=port, username=user, dbname=dbname), unsetSearchPath=False) as conn:
dbconn.execSQL(conn, insert_sql_str)
conn.commit()
......@@ -508,7 +508,7 @@ def create_indexes(context, table_name, indexname, dbname):
btree_index_sql = "create index btree_%s on %s using btree(column1);" % (indexname, table_name)
bitmap_index_sql = "create index bitmap_%s on %s using bitmap(column3);" % (indexname, table_name)
index_sql = btree_index_sql + bitmap_index_sql
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
dbconn.execSQL(conn, index_sql)
conn.commit()
validate_index(context, table_name, dbname)
......@@ -524,7 +524,7 @@ def validate_index(context, table_name, dbname):
def create_schema(context, schema_name, dbname):
if not check_schema_exists(context, schema_name, dbname):
schema_sql = "create schema %s" % schema_name
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
dbconn.execSQL(conn, schema_sql)
conn.commit()
......@@ -547,7 +547,7 @@ def create_int_table(context, table_name, table_type='heap', dbname='testdb'):
raise Exception('Invalid table type specified')
SELECT_TABLE_SQL = 'select count(*) from %s' % table_name
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
dbconn.execSQL(conn, CREATE_TABLE_SQL)
conn.commit()
......@@ -605,7 +605,7 @@ def check_row_count(context, tablename, dbname, nrows):
else:
dburl = dbconn.DbURL(dbname=dbname)
with dbconn.connect(dburl) as conn:
with dbconn.connect(dburl, unsetSearchPath=False) as conn:
result = dbconn.execSQLForSingleton(conn, NUM_ROWS_QUERY)
if result != nrows:
raise Exception('%d rows in table %s.%s, expected row count = %d' % (result, dbname, tablename, nrows))
......@@ -697,7 +697,7 @@ def create_dir(host, directory):
def check_count_for_specific_query(dbname, query, nrows):
NUM_ROWS_QUERY = '%s' % query
# We want to bubble up the exception so that if table does not exist, the test fails
with dbconn.connect(dbconn.DbURL(dbname=dbname)) as conn:
with dbconn.connect(dbconn.DbURL(dbname=dbname), unsetSearchPath=False) as conn:
result = dbconn.execSQLForSingleton(conn, NUM_ROWS_QUERY)
if result != nrows:
raise Exception('%d rows in query: %s. Expected row count = %d' % (result, query, nrows))
......@@ -709,7 +709,7 @@ def get_primary_segment_host_port():
"""
FIRST_PRIMARY_DBID = 2
get_psegment_sql = 'select hostname, port from gp_segment_configuration where dbid=%i;' % FIRST_PRIMARY_DBID
with dbconn.connect(dbconn.DbURL(dbname='template1')) as conn:
with dbconn.connect(dbconn.DbURL(dbname='template1'), unsetSearchPath=False) as conn:
cur = dbconn.execSQL(conn, get_psegment_sql)
rows = cur.fetchall()
primary_seg_host = rows[0][0]
......@@ -779,7 +779,7 @@ def wait_for_unblocked_transactions(context, num_retries=150):
attempt = 0
while attempt < num_retries:
try:
with dbconn.connect(dbconn.DbURL()) as conn:
with dbconn.connect(dbconn.DbURL(), unsetSearchPath=False) as conn:
# Cursor.execute() will issue an implicit BEGIN for us.
# Empty block of 'BEGIN' and 'END' won't start a distributed transaction,
# execute a DDL query to start a distributed transaction.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册