提交 7ee8afb6 编写于 作者: V Ville Brofeldt 提交者: Maxime Beauchemin

Improve support for BigQuery, Redshift, Oracle, Db2, Snowflake (#5827)

* Conditionally mutate and quote sqla labels decouple sqla logic from viz.py

* Prefix hashed label with underscore if bigquery label exceeds 128 chars

* Add comments for label cache

* Rename to mutated_labels and simply

* Rename mutated_label to get_label and simplify make_label_compatible in db_engine_specs

* Add note about deterministic and unique mutated labels

* add hash to label that has been prefixed with underscore

* Fix PEP8 escape warning

* Fix DeckPathViz get_metric_label call
上级 055467de
......@@ -112,8 +112,8 @@ class TableColumn(Model, BaseColumn):
export_parent = 'table'
def get_sqla_col(self, label=None):
db_engine_spec = self.table.database.db_engine_spec
label = db_engine_spec.make_label_compatible(label if label else self.column_name)
label = label if label else self.column_name
label = self.table.get_label(label)
if not self.expression:
col = column(self.column_name).label(label)
else:
......@@ -135,10 +135,12 @@ class TableColumn(Model, BaseColumn):
def get_timestamp_expression(self, time_grain):
"""Getting the time component of the query"""
label = self.table.get_label(utils.DTTM_ALIAS)
pdf = self.python_date_format
is_epoch = pdf in ('epoch_s', 'epoch_ms')
if not self.expression and not time_grain and not is_epoch:
return column(self.column_name, type_=DateTime).label(utils.DTTM_ALIAS)
return column(self.column_name, type_=DateTime).label(label)
expr = self.expression or self.column_name
if is_epoch:
......@@ -152,7 +154,7 @@ class TableColumn(Model, BaseColumn):
grain = self.table.database.grains_dict().get(time_grain)
if grain:
expr = grain.function.format(col=expr)
return literal_column(expr, type_=DateTime).label(utils.DTTM_ALIAS)
return literal_column(expr, type_=DateTime).label(label)
@classmethod
def import_obj(cls, i_column):
......@@ -207,8 +209,8 @@ class SqlMetric(Model, BaseMetric):
export_parent = 'table'
def get_sqla_col(self, label=None):
db_engine_spec = self.table.database.db_engine_spec
label = db_engine_spec.make_label_compatible(label if label else self.metric_name)
label = label if label else self.metric_name
label = self.table.get_label(label)
return literal_column(self.expression).label(label)
@property
......@@ -287,6 +289,21 @@ class SqlaTable(Model, BaseDatasource):
'MAX': sa.func.MAX,
}
def get_label(self, label):
"""Conditionally mutate a label to conform to db engine requirements
and store mapping from mutated label to original label
:param label: original label
:return: Either a string or sqlalchemy.sql.elements.quoted_name if required
by db engine
"""
db_engine_spec = self.database.db_engine_spec
sqla_label = db_engine_spec.make_label_compatible(label)
mutated_label = str(sqla_label)
if label != mutated_label:
self.mutated_labels[mutated_label] = label
return sqla_label
def __repr__(self):
return self.name
......@@ -486,8 +503,8 @@ class SqlaTable(Model, BaseDatasource):
:rtype: sqlalchemy.sql.column
"""
expression_type = metric.get('expressionType')
db_engine_spec = self.database.db_engine_spec
label = db_engine_spec.make_label_compatible(metric.get('label'))
label = utils.get_metric_name(metric)
label = self.get_label(label)
if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
column_name = metric.get('column').get('column_name')
......@@ -540,6 +557,9 @@ class SqlaTable(Model, BaseDatasource):
template_processor = self.get_template_processor(**template_kwargs)
db_engine_spec = self.database.db_engine_spec
# Initialize empty cache to store mutated labels
self.mutated_labels = {}
orderby = orderby or []
# For backward compatibility
......@@ -569,8 +589,8 @@ class SqlaTable(Model, BaseDatasource):
if metrics_exprs:
main_metric_expr = metrics_exprs[0]
else:
main_metric_expr = literal_column('COUNT(*)').label(
db_engine_spec.make_label_compatible('count'))
label = self.get_label('ccount')
main_metric_expr = literal_column('COUNT(*)').label(label)
select_exprs = []
groupby_exprs = []
......@@ -695,7 +715,8 @@ class SqlaTable(Model, BaseDatasource):
# some sql dialects require for order by expressions
# to also be in the select clause -- others, e.g. vertica,
# require a unique inner alias
inner_main_metric_expr = main_metric_expr.label('mme_inner__')
label = self.get_label('mme_inner__')
inner_main_metric_expr = main_metric_expr.label(label)
inner_select_exprs += [inner_main_metric_expr]
subq = select(inner_select_exprs)
subq = subq.select_from(tbl)
......@@ -723,8 +744,11 @@ class SqlaTable(Model, BaseDatasource):
on_clause = []
for i, gb in enumerate(groupby):
on_clause.append(
groupby_exprs[i] == column(gb + '__'))
# in this case the column name, not the alias, needs to be
# conditionally mutated, as it refers to the column alias in
# the inner query
col_name = self.get_label(gb + '__')
on_clause.append(groupby_exprs[i] == column(col_name))
tbl = tbl.join(subq.alias(), and_(*on_clause))
else:
......@@ -776,6 +800,8 @@ class SqlaTable(Model, BaseDatasource):
df = None
try:
df = self.database.get_df(sql, self.schema)
if self.mutated_labels:
df = df.rename(index=str, columns=self.mutated_labels)
except Exception as e:
status = utils.QueryStatus.FAILED
logging.exception(e)
......@@ -818,7 +844,6 @@ class SqlaTable(Model, BaseDatasource):
.filter(or_(TableColumn.column_name == col.name
for col in table.columns)))
dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}
db_engine_spec = self.database.db_engine_spec
for col in table.columns:
try:
......@@ -850,9 +875,6 @@ class SqlaTable(Model, BaseDatasource):
))
if not self.main_dttm_col:
self.main_dttm_col = any_date_col
for metric in metrics:
metric.metric_name = db_engine_spec.mutate_expression_label(
metric.metric_name)
self.add_missing_metrics(metrics)
db.session.merge(self)
db.session.commit()
......
......@@ -29,6 +29,7 @@ at all. The classes here will use a common interface to specify all this.
The general idea is to use static classes and an inheritance scheme.
"""
from collections import namedtuple
import hashlib
import inspect
import logging
import os
......@@ -392,16 +393,26 @@ class BaseEngineSpec(object):
@classmethod
def make_label_compatible(cls, label):
"""
Return a sqlalchemy.sql.elements.quoted_name if the engine requires
quoting of aliases to ensure that select query and query results
have same case.
Conditionally mutate and/or quote a sql column/expression label. If
force_column_alias_quotes is set to True, return the label as a
sqlalchemy.sql.elements.quoted_name object to ensure that the select query
and query results have same case. Otherwise return the mutated label as a
regular string.
"""
if cls.force_column_alias_quotes is True:
return quoted_name(label, True)
return label
label = cls.mutate_label(label)
return quoted_name(label, True) if cls.force_column_alias_quotes else label
@staticmethod
def mutate_expression_label(label):
def mutate_label(label):
"""
Most engines support mixed case aliases that can include numbers
and special characters, like commas, parentheses etc. For engines that
have restrictions on what types of aliases are supported, this method
can be overridden to ensure that labels conform to the engine's
limitations. Mutated labels should be deterministic (input label A always
yields output label X) and unique (input labels A and B don't yield the same
output label X).
"""
return label
......@@ -490,7 +501,15 @@ class VerticaEngineSpec(PostgresBaseEngineSpec):
class RedshiftEngineSpec(PostgresBaseEngineSpec):
engine = 'redshift'
force_column_alias_quotes = True
@staticmethod
def mutate_label(label):
"""
Redshift only supports lowercase column names and aliases.
:param str label: Original label which might include uppercase letters
:return: String that is supported by the database
"""
return label.lower()
class OracleEngineSpec(PostgresBaseEngineSpec):
......@@ -516,11 +535,26 @@ class OracleEngineSpec(PostgresBaseEngineSpec):
"""TO_TIMESTAMP('{}', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')"""
).format(dttm.isoformat())
@staticmethod
def mutate_label(label):
"""
Oracle 12.1 and earlier support a maximum of 30 byte length object names, which
usually means 30 characters.
:param str label: Original label which might include unsupported characters
:return: String that is supported by the database
"""
if len(label) > 30:
hashed_label = hashlib.md5(label.encode('utf-8')).hexdigest()
# truncate the hash to first 30 characters
return hashed_label[:30]
return label
class Db2EngineSpec(BaseEngineSpec):
engine = 'ibm_db_sa'
limit_method = LimitMethod.WRAP_SQL
force_column_alias_quotes = True
time_grain_functions = {
None: '{col}',
'PT1S': 'CAST({col} as TIMESTAMP)'
......@@ -554,6 +588,20 @@ class Db2EngineSpec(BaseEngineSpec):
def convert_dttm(cls, target_type, dttm):
return "'{}'".format(dttm.strftime('%Y-%m-%d-%H.%M.%S'))
@staticmethod
def mutate_label(label):
"""
Db2 for z/OS supports a maximum of 30 byte length object names, which usually
means 30 characters.
:param str label: Original label which might include unsupported characters
:return: String that is supported by the database
"""
if len(label) > 30:
hashed_label = hashlib.md5(label.encode('utf-8')).hexdigest()
# truncate the hash to first 30 characters
return hashed_label[:30]
return label
class SqliteEngineSpec(BaseEngineSpec):
engine = 'sqlite'
......@@ -1424,16 +1472,30 @@ class BQEngineSpec(BaseEngineSpec):
return data
@staticmethod
def mutate_expression_label(label):
mutated_label = re.sub('[^\w]+', '_', label)
if not re.match('^[a-zA-Z_]+.*', mutated_label):
raise SupersetTemplateException('BigQuery field_name used is invalid {}, '
'should start with a letter or '
'underscore'.format(mutated_label))
if len(mutated_label) > 128:
raise SupersetTemplateException('BigQuery field_name {}, should be atmost '
'128 characters'.format(mutated_label))
return mutated_label
def mutate_label(label):
"""
BigQuery field_name should start with a letter or underscore, contain only
alphanumeric characters and be at most 128 characters long. Labels that start
with a number are prefixed with an underscore. Any unsupported characters are
replaced with underscores and an md5 hash is added to the end of the label to
avoid possible collisions. If the resulting label exceeds 128 characters, only
the md5 sum is returned.
:param str label: the original label which might include unsupported characters
:return: String that is supported by the database
"""
hashed_label = '_' + hashlib.md5(label.encode('utf-8')).hexdigest()
# if label starts with number, add underscore as first character
mutated_label = '_' + label if re.match(r'^\d', label) else label
# replace non-alphanumeric characters with underscores
mutated_label = re.sub(r'[^\w]+', '_', mutated_label)
if mutated_label != label:
# add md5 hash to label to avoid possible collisions
mutated_label += hashed_label
# return only hash if length of final label exceeds 128 chars
return mutated_label if len(mutated_label) <= 128 else hashed_label
@classmethod
def extra_table_metadata(cls, database, table_name, schema_name):
......
......@@ -121,27 +121,13 @@ class BaseViz(object):
if not isinstance(val, list):
val = [val]
for o in val:
label = self.get_metric_label(o)
if isinstance(o, dict):
o['label'] = label
label = utils.get_metric_name(o)
self.metric_dict[label] = o
# Cast to list needed to return serializable object in py3
self.all_metrics = list(self.metric_dict.values())
self.metric_labels = list(self.metric_dict.keys())
def get_metric_label(self, metric):
if isinstance(metric, str):
return metric
if isinstance(metric, dict):
metric = metric.get('label')
if self.datasource.type == 'table':
db_engine_spec = self.datasource.database.db_engine_spec
metric = db_engine_spec.mutate_expression_label(metric)
return metric
@staticmethod
def handle_js_int_overflow(data):
for d in data.get('records', dict()):
......@@ -577,7 +563,7 @@ class TableViz(BaseViz):
# Sum up and compute percentages for all percent metrics
percent_metrics = fd.get('percent_metrics') or []
percent_metrics = [self.get_metric_label(m) for m in percent_metrics]
percent_metrics = [utils.get_metric_name(m) for m in percent_metrics]
if len(percent_metrics):
percent_metrics = list(filter(lambda m: m in df, percent_metrics))
......@@ -595,7 +581,7 @@ class TableViz(BaseViz):
df[m_name] = pd.Series(metric_percents[m], name=m_name)
# Remove metrics that are not in the main metrics list
metrics = fd.get('metrics') or []
metrics = [self.get_metric_label(m) for m in metrics]
metrics = [utils.get_metric_name(m) for m in metrics]
for m in filter(
lambda m: m not in metrics and m in df.columns,
percent_metrics,
......@@ -695,7 +681,7 @@ class PivotTableViz(BaseViz):
df = df.pivot_table(
index=self.form_data.get('groupby'),
columns=self.form_data.get('columns'),
values=[self.get_metric_label(m) for m in self.form_data.get('metrics')],
values=[utils.get_metric_name(m) for m in self.form_data.get('metrics')],
aggfunc=self.form_data.get('pandas_aggfunc'),
margins=self.form_data.get('pivot_margins'),
)
......@@ -1030,7 +1016,7 @@ class BulletViz(NVD3Viz):
def get_data(self, df):
df = df.fillna(0)
df['metric'] = df[[self.get_metric_label(self.metric)]]
df['metric'] = df[[utils.get_metric_name(self.metric)]]
values = df['metric'].values
return {
'measures': values.tolist(),
......@@ -1150,6 +1136,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
df = df.fillna(0)
if fd.get('granularity') == 'all':
raise Exception(_('Pick a time granularity for your time series'))
if not aggregate:
df = df.pivot_table(
index=DTTM_ALIAS,
......@@ -1365,8 +1352,8 @@ class NVD3DualLineViz(NVD3Viz):
if self.form_data.get('granularity') == 'all':
raise Exception(_('Pick a time granularity for your time series'))
metric = self.get_metric_label(fd.get('metric'))
metric_2 = self.get_metric_label(fd.get('metric_2'))
metric = utils.get_metric_name(fd.get('metric'))
metric_2 = utils.get_metric_name(fd.get('metric_2'))
df = df.pivot_table(
index=DTTM_ALIAS,
values=[metric, metric_2])
......@@ -1417,7 +1404,7 @@ class NVD3TimePivotViz(NVD3TimeSeriesViz):
df = df.pivot_table(
index=DTTM_ALIAS,
columns='series',
values=self.get_metric_label(fd.get('metric')))
values=utils.get_metric_name(fd.get('metric')))
chart_data = self.to_series(df)
for serie in chart_data:
serie['rank'] = rank_lookup[serie['key']]
......@@ -1589,8 +1576,8 @@ class SunburstViz(BaseViz):
def get_data(self, df):
fd = self.form_data
cols = fd.get('groupby')
metric = self.get_metric_label(fd.get('metric'))
secondary_metric = self.get_metric_label(fd.get('secondary_metric'))
metric = utils.get_metric_name(fd.get('metric'))
secondary_metric = utils.get_metric_name(fd.get('secondary_metric'))
if metric == secondary_metric or secondary_metric is None:
df.columns = cols + ['m1']
df['m2'] = df['m1']
......@@ -1691,7 +1678,7 @@ class ChordViz(BaseViz):
qry = super(ChordViz, self).query_obj()
fd = self.form_data
qry['groupby'] = [fd.get('groupby'), fd.get('columns')]
qry['metrics'] = [self.get_metric_label(fd.get('metric'))]
qry['metrics'] = [utils.get_metric_name(fd.get('metric'))]
return qry
def get_data(self, df):
......@@ -1757,8 +1744,8 @@ class WorldMapViz(BaseViz):
from superset.data import countries
fd = self.form_data
cols = [fd.get('entity')]
metric = self.get_metric_label(fd.get('metric'))
secondary_metric = self.get_metric_label(fd.get('secondary_metric'))
metric = utils.get_metric_name(fd.get('metric'))
secondary_metric = utils.get_metric_name(fd.get('secondary_metric'))
columns = ['country', 'm1', 'm2']
if metric == secondary_metric:
ndf = df[cols]
......@@ -2289,7 +2276,7 @@ class DeckScatterViz(BaseDeckGLViz):
def get_data(self, df):
fd = self.form_data
self.metric_label = \
self.get_metric_label(self.metric) if self.metric else None
utils.get_metric_name(self.metric) if self.metric else None
self.point_radius_fixed = fd.get('point_radius_fixed')
self.fixed_value = None
self.dim = self.form_data.get('dimension')
......@@ -2320,7 +2307,7 @@ class DeckScreengrid(BaseDeckGLViz):
}
def get_data(self, df):
self.metric_label = self.get_metric_label(self.metric)
self.metric_label = utils.get_metric_name(self.metric)
return super(DeckScreengrid, self).get_data(df)
......@@ -2339,7 +2326,7 @@ class DeckGrid(BaseDeckGLViz):
}
def get_data(self, df):
self.metric_label = self.get_metric_label(self.metric)
self.metric_label = utils.get_metric_name(self.metric)
return super(DeckGrid, self).get_data(df)
......@@ -2397,7 +2384,7 @@ class DeckPathViz(BaseDeckGLViz):
return d
def get_data(self, df):
self.metric_label = self.get_metric_label(self.metric)
self.metric_label = utils.get_metric_name(self.metric)
return super(DeckPathViz, self).get_data(df)
......@@ -2445,7 +2432,7 @@ class DeckHex(BaseDeckGLViz):
}
def get_data(self, df):
self.metric_label = self.get_metric_label(self.metric)
self.metric_label = utils.get_metric_name(self.metric)
return super(DeckHex, self).get_data(df)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册