# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest

from superset import db
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.connectors.connector_registry import ConnectorRegistry
from superset.extensions import cache_manager
from superset.models.cache import CacheKey
from superset.utils.core import (
    AdhocMetricExpressionType,
    ChartDataResultFormat,
    ChartDataResultType,
    FilterOperator,
    TimeRangeEndpoint,
)
from tests.base_tests import SupersetTestCase
from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices
from tests.fixtures.query_context import get_query_context


class TestQueryContext(SupersetTestCase):
    def test_schema_deserialization(self):
        """
        Ensure that the deserialized QueryContext contains all required fields.
        """

        payload = get_query_context("birth_names", add_postprocessing_operations=True)
        query_context = ChartDataQueryContextSchema().load(payload)
        self.assertEqual(len(query_context.queries), len(payload["queries"]))

        for query_idx, query in enumerate(query_context.queries):
            payload_query = payload["queries"][query_idx]

            # check basic properties
            self.assertEqual(query.extras, payload_query["extras"])
            self.assertEqual(query.filter, payload_query["filters"])
            self.assertEqual(query.groupby, payload_query["groupby"])

            # metrics are mutated during creation
            for metric_idx, metric in enumerate(query.metrics):
                payload_metric = payload_query["metrics"][metric_idx]
                payload_metric = (
                    payload_metric
                    if "expressionType" in payload_metric
                    else payload_metric["label"]
                )
                self.assertEqual(metric, payload_metric)

            self.assertEqual(query.orderby, payload_query["orderby"])
            self.assertEqual(query.time_range, payload_query["time_range"])

            # check post processing operation properties
            for post_proc_idx, post_proc in enumerate(query.post_processing):
                payload_post_proc = payload_query["post_processing"][post_proc_idx]
                self.assertEqual(post_proc["operation"], payload_post_proc["operation"])
                self.assertEqual(post_proc["options"], payload_post_proc["options"])

    def test_cache(self):
        table_name = "birth_names"
        table = self.get_table_by_name(table_name)
        payload = get_query_context(table.name, table.id)
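        # bypass any cached result so the query is executed against the datasource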
        payload["force"] = True

        query_context = ChartDataQueryContextSchema().load(payload)
        query_object = query_context.queries[0]
        query_cache_key = query_context.query_cache_key(query_object)

        response = query_context.get_payload(cache_query_context=True)
        cache_key = response["cache_key"]
        assert cache_key is not None

        cached = cache_manager.cache.get(cache_key)
        assert cached is not None

        rehydrated_qc = ChartDataQueryContextSchema().load(cached["data"])
        rehydrated_qo = rehydrated_qc.queries[0]
        rehydrated_query_cache_key = rehydrated_qc.query_cache_key(rehydrated_qo)

        self.assertEqual(rehydrated_qc.datasource, query_context.datasource)
        self.assertEqual(len(rehydrated_qc.queries), 1)
        self.assertEqual(query_cache_key, rehydrated_query_cache_key)
        self.assertEqual(rehydrated_qc.result_type, query_context.result_type)
        self.assertEqual(rehydrated_qc.result_format, query_context.result_format)
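        # the force flag should not be carried over into the cached query context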
        self.assertFalse(rehydrated_qc.force)

    def test_query_cache_key_changes_when_datasource_is_updated(self):
        self.login(username="admin")
        payload = get_query_context("birth_names")

        # construct baseline query_cache_key
        query_context = ChartDataQueryContextSchema().load(payload)
        query_object = query_context.queries[0]
        cache_key_original = query_context.query_cache_key(query_object)

        # make a temporary change and revert it to refresh the changed_on property
        datasource = ConnectorRegistry.get_datasource(
            datasource_type=payload["datasource"]["type"],
            datasource_id=payload["datasource"]["id"],
            session=db.session,
        )
        description_original = datasource.description
        datasource.description = "temporary description"
        db.session.commit()
        datasource.description = description_original
        db.session.commit()

        # create a new QueryContext with unchanged attributes and extract its query_cache_key
        query_context = ChartDataQueryContextSchema().load(payload)
        query_object = query_context.queries[0]
        cache_key_new = query_context.query_cache_key(query_object)

        # the new cache_key should be different due to the updated datasource
        self.assertNotEqual(cache_key_original, cache_key_new)

    def test_query_cache_key_changes_when_post_processing_is_updated(self):
        self.login(username="admin")
        payload = get_query_context("birth_names", add_postprocessing_operations=True)

        # construct baseline query_cache_key from a query_context with a post processing operation
        query_context = ChartDataQueryContextSchema().load(payload)
        query_object = query_context.queries[0]
        cache_key_original = query_context.query_cache_key(query_object)

        # ensure that adding a None post_processing operation doesn't change the query_cache_key
        payload["queries"][0]["post_processing"].append(None)
        query_context = ChartDataQueryContextSchema().load(payload)
        query_object = query_context.queries[0]
        cache_key = query_context.query_cache_key(query_object)
        self.assertEqual(cache_key_original, cache_key)

        # ensure that a query without post processing operations has a different cache key
        payload["queries"][0].pop("post_processing")
        query_context = ChartDataQueryContextSchema().load(payload)
        query_object = query_context.queries[0]
        cache_key = query_context.query_cache_key(query_object)
        self.assertNotEqual(cache_key_original, cache_key)

    def test_query_context_time_range_endpoints(self):
        """
        Ensure that time_range_endpoints are populated automatically when missing
        from the payload.
        """
        self.login(username="admin")
        payload = get_query_context("birth_names")
        del payload["queries"][0]["extras"]["time_range_endpoints"]
        query_context = ChartDataQueryContextSchema().load(payload)
        query_object = query_context.queries[0]
        extras = query_object.to_dict()["extras"]
        assert "time_range_endpoints" in extras
        self.assertEqual(
            extras["time_range_endpoints"],
            (TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE),
        )

    def test_handle_metrics_field(self):
        """
        Should support both predefined and adhoc metrics.
        """
        self.login(username="admin")
        adhoc_metric = {
            "expressionType": "SIMPLE",
            "column": {"column_name": "num_boys", "type": "BIGINT(20)"},
            "aggregate": "SUM",
            "label": "Boys",
            "optionName": "metric_11",
        }
        payload = get_query_context("birth_names")
        payload["queries"][0]["metrics"] = ["sum__num", {"label": "abc"}, adhoc_metric]
        query_context = ChartDataQueryContextSchema().load(payload)
        query_object = query_context.queries[0]
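        # a metric dict containing only a label is flattened to its label string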
        self.assertEqual(query_object.metrics, ["sum__num", "abc", adhoc_metric])

    def test_convert_deprecated_fields(self):
        """
        Ensure that deprecated fields are converted correctly
        """
        self.login(username="admin")
        payload = get_query_context("birth_names")
        payload["queries"][0]["granularity_sqla"] = "timecol"
        payload["queries"][0]["having_filters"] = [{"col": "a", "op": "==", "val": "b"}]
        query_context = ChartDataQueryContextSchema().load(payload)
        self.assertEqual(len(query_context.queries), 1)
        query_object = query_context.queries[0]
        self.assertEqual(query_object.granularity, "timecol")
        self.assertIn("having_druid", query_object.extras)

    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
    def test_csv_response_format(self):
        """
        Ensure that CSV result format works
        """
        self.login(username="admin")
        payload = get_query_context("birth_names")
        payload["result_format"] = ChartDataResultFormat.CSV.value
        payload["queries"][0]["row_limit"] = 10
        query_context = ChartDataQueryContextSchema().load(payload)
        responses = query_context.get_payload()
        self.assertEqual(len(responses), 1)
        data = responses["queries"][0]["data"]
        self.assertIn("name,sum__num\n", data)
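        # header row + 10 data rows + trailing newline -> 12 newline-separated parts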
        self.assertEqual(len(data.split("\n")), 12)

    def test_sql_injection_via_groupby(self):
        """
        Ensure that invalid column names in groupby are caught
        """
        self.login(username="admin")
        payload = get_query_context("birth_names")
        payload["queries"][0]["groupby"] = ["currentDatabase()"]
        query_context = ChartDataQueryContextSchema().load(payload)
        query_payload = query_context.get_payload()
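        # the injected expression should be rejected and surfaced as a query error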
        assert query_payload["queries"][0].get("error") is not None

    def test_sql_injection_via_columns(self):
        """
        Ensure that invalid column names in columns are caught
        """
        self.login(username="admin")
        payload = get_query_context("birth_names")
        payload["queries"][0]["groupby"] = []
        payload["queries"][0]["metrics"] = []
        payload["queries"][0]["columns"] = ["*, 'extra'"]
        query_context = ChartDataQueryContextSchema().load(payload)
        query_payload = query_context.get_payload()
        assert query_payload["queries"][0].get("error") is not None

    def test_sql_injection_via_metrics(self):
        """
        Ensure that invalid column names in metrics are caught
        """
        self.login(username="admin")
        payload = get_query_context("birth_names")
        payload["queries"][0]["groupby"] = ["name"]
        payload["queries"][0]["metrics"] = [
            {
                "expressionType": AdhocMetricExpressionType.SIMPLE.value,
                "column": {"column_name": "invalid_col"},
                "aggregate": "SUM",
                "label": "My Simple Label",
            }
        ]
        query_context = ChartDataQueryContextSchema().load(payload)
        query_payload = query_context.get_payload()
        assert query_payload["queries"][0].get("error") is not None

    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
    def test_samples_response_type(self):
        """
        Ensure that samples result type works
        """
        self.login(username="admin")
        payload = get_query_context("birth_names")
        payload["result_type"] = ChartDataResultType.SAMPLES.value
        payload["queries"][0]["row_limit"] = 5
        query_context = ChartDataQueryContextSchema().load(payload)
        responses = query_context.get_payload()
        self.assertEqual(len(responses), 1)
        data = responses["queries"][0]["data"]
        self.assertIsInstance(data, list)
        self.assertEqual(len(data), 5)
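        # samples return raw rows, so the aggregated metric column is not present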
        self.assertNotIn("sum__num", data[0])

    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
    def test_query_response_type(self):
        """
        Ensure that query result type works
        """
        self.login(username="admin")
        payload = get_query_context("birth_names")
        payload["result_type"] = ChartDataResultType.QUERY.value
        query_context = ChartDataQueryContextSchema().load(payload)
        responses = query_context.get_payload()
        self.assertEqual(len(responses), 1)
        response = responses["queries"][0]
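        # the QUERY result type returns only the generated SQL and its language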
        self.assertEqual(len(response), 2)
        self.assertEqual(response["language"], "sql")
        self.assertIn("SELECT", response["query"])

    def test_query_object_unknown_fields(self):
        """
        Ensure that query objects with unknown fields don't raise an Exception and
        have the same cache key as one without the unknown field
        """
        self.maxDiff = None
        self.login(username="admin")
        payload = get_query_context("birth_names")
        query_context = ChartDataQueryContextSchema().load(payload)
        responses = query_context.get_payload()
        orig_cache_key = responses["queries"][0]["cache_key"]
        payload["queries"][0]["foo"] = "bar"
        query_context = ChartDataQueryContextSchema().load(payload)
        responses = query_context.get_payload()
        new_cache_key = responses["queries"][0]["cache_key"]
        self.assertEqual(orig_cache_key, new_cache_key)