mysql_test.go 5.9 KB
Newer Older
martianzhang's avatar
martianzhang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Copyright 2018 Xiaomi, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package database

import (
20
	"flag"
martianzhang's avatar
martianzhang 已提交
21
	"fmt"
22
	"os"
martianzhang's avatar
martianzhang 已提交
23 24
	"path/filepath"
	"runtime"
martianzhang's avatar
martianzhang 已提交
25 26 27
	"testing"

	"github.com/XiaoMi/soar/common"
28

martianzhang's avatar
martianzhang 已提交
29 30 31
	"github.com/kr/pretty"
)

32 33 34
var (
	// connTest is the shared connector that TestMain opens from
	// common.Config.TestDSN; the database-backed tests in this file use it.
	connTest *Connector

	// update is passed to common.GoldenDiff by the golden-file tests; per its
	// help text it requests that the .golden files be updated (go test -update).
	update = flag.Bool("update", false, "update .golden files")
)

35 36
func TestMain(m *testing.M) {
	// 初始化 init
martianzhang's avatar
martianzhang 已提交
37 38 39 40
	if common.DevPath == "" {
		_, file, _, _ := runtime.Caller(0)
		common.DevPath, _ = filepath.Abs(filepath.Dir(filepath.Join(file, ".."+string(filepath.Separator))))
	}
41
	common.BaseDir = common.DevPath
martianzhang's avatar
martianzhang 已提交
42 43 44 45 46
	err := common.ParseConfig("")
	common.LogIfError(err, "init ParseConfig")
	common.Log.Debug("mysql_test init")
	connTest, err = NewConnector(common.Config.TestDSN)
	if err != nil {
47 48 49 50
		common.Log.Critical("Test env Error: %v", err)
		os.Exit(0)
	}

martianzhang's avatar
martianzhang 已提交
51 52 53
	if _, err := connTest.Version(); err != nil {
		common.Log.Critical("Test env Error: %v", err)
		os.Exit(0)
54
	}
55 56 57 58 59 60 61

	// 分割线
	flag.Parse()
	m.Run()

	// 环境清理
	//
62 63
}

martianzhang's avatar
martianzhang 已提交
64
func TestQuery(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
65
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
66 67 68 69 70 71 72 73 74 75 76 77 78
	res, err := connTest.Query("select 0")
	if err != nil {
		t.Error(err.Error())
	}
	for res.Rows.Next() {
		var val int
		err = res.Rows.Scan(&val)
		if err != nil {
			t.Error(err.Error())
		}
		if val != 0 {
			t.Error("should return 0")
		}
martianzhang's avatar
martianzhang 已提交
79
	}
martianzhang's avatar
martianzhang 已提交
80
	res.Rows.Close()
81
	// TODO: timeout test
martianzhang's avatar
martianzhang 已提交
82
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
83 84
}

85
func TestColumnCardinality(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
86
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
87 88 89
	orgDatabase := connTest.Database
	connTest.Database = "sakila"
	a := connTest.ColumnCardinality("actor", "first_name")
90 91
	if a > 1 || a <= 0 {
		t.Error("sakila.actor.first_name cardinality should in [0, 1], now it's", a)
92 93
	}
	connTest.Database = orgDatabase
martianzhang's avatar
martianzhang 已提交
94
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
95 96 97
}

func TestDangerousSQL(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
98
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
99 100 101 102 103 104 105 106 107 108 109 110 111 112
	testCase := map[string]bool{
		"select * from tb;delete from tb;": true,
		"show database;":                   false,
		"select * from t;":                 false,
		"explain delete from t;":           false,
	}

	db := Connector{}
	for sql, want := range testCase {
		got := db.dangerousQuery(sql)
		if got != want {
			t.Errorf("SQL:%s got:%v want:%v", sql, got, want)
		}
	}
martianzhang's avatar
martianzhang 已提交
113
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
114 115 116
}

func TestWarningsAndQueryCost(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
117
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
118 119 120 121 122 123
	common.Config.ShowWarnings = true
	common.Config.ShowLastQueryCost = true
	res, err := connTest.Query("explain select * from sakila.film")
	if err != nil {
		t.Error("Query Error: ", err)
	} else {
124 125 126 127 128 129 130
		for res.Warning.Next() {
			var str string
			err = res.Warning.Scan(str)
			if err != nil {
				t.Error(err.Error())
			}
			pretty.Println(str)
martianzhang's avatar
martianzhang 已提交
131
		}
martianzhang's avatar
martianzhang 已提交
132
		res.Warning.Close()
133
		fmt.Println(res.QueryCost, err)
martianzhang's avatar
martianzhang 已提交
134
	}
martianzhang's avatar
martianzhang 已提交
135
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
136 137 138
}

func TestVersion(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
139
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
140 141 142 143 144
	version, err := connTest.Version()
	if err != nil {
		t.Error(err.Error())
	}
	fmt.Println(version)
martianzhang's avatar
martianzhang 已提交
145
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
146 147
}

martianzhang's avatar
martianzhang 已提交
148
func TestRemoveSQLComments(t *testing.T) {
149
	// Notice: double dash without space not comment, eg. `--not comment`
martianzhang's avatar
martianzhang 已提交
150
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
151
	SQLs := []string{
martianzhang's avatar
martianzhang 已提交
152 153
		`select 'c#\'#not comment'`,
		`select "c#\"#not comment"`,
martianzhang's avatar
martianzhang 已提交
154 155 156
		`-- comment`,
		`--`,
		`# comment`,
157
		`#comment`,
martianzhang's avatar
martianzhang 已提交
158 159 160 161 162 163 164 165 166 167 168 169 170
		`/* multi-line
comment*/`,
		`--
-- comment`,
	}
	err := common.GoldenDiff(func() {
		for _, sql := range SQLs {
			fmt.Println(RemoveSQLComments(sql))
		}
	}, t.Name(), update)
	if err != nil {
		t.Error(err)
	}
martianzhang's avatar
martianzhang 已提交
171
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
172
}
173 174

func TestSingleIntValue(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
175
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
176 177 178 179 180 181 182
	val, err := connTest.SingleIntValue("read_only")
	if err != nil {
		t.Error(err)
	}
	if val < 0 {
		t.Error("SingleIntValue, return should large than zero")
	}
martianzhang's avatar
martianzhang 已提交
183
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
184 185 186
}

func TestIsView(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
187
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
188 189 190 191 192 193
	originalDatabase := connTest.Database
	connTest.Database = "sakila"
	if !connTest.IsView("actor_info") {
		t.Error("actor_info should be a VIEW")
	}
	connTest.Database = originalDatabase
martianzhang's avatar
martianzhang 已提交
194
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
195
}
196 197 198 199 200 201 202 203 204 205 206 207 208 209

// TestNullString checks that NullString returns "NULL" both for a nil buffer
// and for the literal bytes "NULL".
func TestNullString(t *testing.T) {
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
	for _, input := range [][]byte{nil, []byte("NULL")} {
		if rendered := NullString(input); rendered != "NULL" {
			t.Errorf("%s want NULL", string(input))
		}
	}
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
210

martianzhang's avatar
martianzhang 已提交
211
func TestEscape(t *testing.T) {
212 213 214 215 216 217 218 219 220 221 222 223 224 225
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
	cases := []string{
		"",
		"hello world",
		"hello' world",
		`hello" world`,
		"hello\000world",
		`hello\ world`,
		"hello\032world",
		"hello\rworld",
		"hello\nworld",
	}
	err := common.GoldenDiff(func() {
		for _, str := range cases {
martianzhang's avatar
martianzhang 已提交
226 227
			fmt.Println(Escape(str, false))
			fmt.Println(Escape(str, true))
228 229 230 231 232 233 234
		}
	}, t.Name(), update)
	if err != nil {
		t.Error(err)
	}
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}