提交 3991dde4 编写于 作者: martianzhang's avatar martianzhang

change mymysql into go-sql-driver/mysql

上级 a60a145f
......@@ -181,7 +181,7 @@ docker:
.PHONY: connect
connect:
mysql -h 127.0.0.1 -u root -p1tIsB1g3rt -c
mysql -h 127.0.0.1 -u root -p1tIsB1g3rt sakila -c
.PHONY: main_test
main_test: install
......
......@@ -1990,7 +1990,7 @@ func (idxAdv *IndexAdvisor) RuleUpdatePrimaryKey() Rule {
if idxMeta == nil {
return rule
}
for _, idx := range idxMeta.IdxRows {
for _, idx := range idxMeta.Rows {
if idx.KeyName == "PRIMARY" {
if col.Name == idx.ColumnName {
rule = HeuristicRules["CLA.016"]
......
......@@ -914,7 +914,7 @@ func (idxAdv *IndexAdvisor) calcCardinality(cols []*common.Column) []*common.Col
// 检查对应列是否为主键或单列唯一索引,如果满足直接返回1,不再重复计算,提高效率
// 多列复合唯一索引不能跳过计算,单列普通索引不能跳过计算
for _, index := range idxAdv.IndexMeta[realDB][col.Table].IdxRows {
for _, index := range idxAdv.IndexMeta[realDB][col.Table].Rows {
// 根据索引的名称判断该索引包含的列数,列数大于1即为复合索引
columnCount := len(idxAdv.IndexMeta[realDB][col.Table].FindIndex(database.IndexKeyName, index.KeyName))
if col.Name == index.ColumnName {
......@@ -1079,7 +1079,7 @@ func DuplicateKeyChecker(conn *database.Connector, databases ...string) map[stri
}
// 枚举所有的索引信息,提取用到的列
for _, idx := range idxInfo.IdxRows {
for _, idx := range idxInfo.Rows {
if _, ok := idxMap[idx.KeyName]; !ok {
idxMap[idx.KeyName] = make([]*common.Column, 0)
for _, col := range idxInfo.FindIndex(database.IndexKeyName, idx.KeyName) {
......
......@@ -102,7 +102,7 @@ type Rule struct {
* SEC Security
* STA Standard
* SUB Subquery
* TBL Table
* TBL TableName
* TRA Trace, 由trace模块给
*/
......
......@@ -67,7 +67,7 @@ func init() {
"SELECT * FROM customer WHERE address_id in (224,510) ORDER BY last_name;", // INDEX(address_id)
"SELECT * FROM film WHERE release_year = 2016 AND length != 1 ORDER BY title;", // INDEX(`release_year`, `length`, `title`)
// "Covering" IdxRows
// "Covering" Rows
"SELECT title FROM film WHERE release_year = 1995;", // INDEX(release_year, title)",
"SELECT title, replacement_cost FROM film WHERE language_id = 5 AND length = 70;", // INDEX(language_id, length, title, replacement_cos film ), title, replacement_cost顺序无关,language_id, length顺序视散粒度情况.
"SELECT title FROM film WHERE language_id > 5 AND length > 70;", // INDEX(language_id, length, title) language_id or length first (that's as far as the Algorithm goes), then the other two fields afterwards.
......
......@@ -59,8 +59,6 @@ type Configuration struct {
Profiling bool `yaml:"profiling"` // 在开启数据采样的情况下,在测试环境执行进行profile
Trace bool `yaml:"trace"` // 在开启数据采样的情况下,在测试环境执行进行Trace
Explain bool `yaml:"explain"` // Explain开关
ConnTimeOut int `yaml:"conn-time-out"` // 数据库连接超时时间,单位秒
QueryTimeOut int `yaml:"query-time-out"` // 数据库SQL执行超时时间,单位秒
Delimiter string `yaml:"delimiter"` // SQL分隔符
// +++++++++++++++日志相关+++++++++++++++++
......@@ -97,8 +95,8 @@ type Configuration struct {
MaxInCount int `yaml:"max-in-count"` // IN()最大数量
MaxIdxBytesPerColumn int `yaml:"max-index-bytes-percolumn"` // 索引中单列最大字节数,默认767
MaxIdxBytes int `yaml:"max-index-bytes"` // 索引总长度限制,默认3072
TableAllowCharsets []string `yaml:"table-allow-charsets"` // Table 允许使用的 DEFAULT CHARSET
TableAllowEngines []string `yaml:"table-allow-engines"` // Table 允许使用的 Engine
TableAllowCharsets []string `yaml:"table-allow-charsets"` // TableName 允许使用的 DEFAULT CHARSET
TableAllowEngines []string `yaml:"table-allow-engines"` // TableName 允许使用的 Engine
MaxIdxCount int `yaml:"max-index-count"` // 单张表允许最多索引数
MaxColCount int `yaml:"max-column-count"` // 单张表允许最大列数
MaxValueCount int `yaml:"max-value-count"` // INSERT/REPLACE 单次允许批量写入的行数
......@@ -135,12 +133,14 @@ type Configuration struct {
// Config 默认设置
var Config = &Configuration{
OnlineDSN: &dsn{
Net: "tcp",
Schema: "information_schema",
Charset: "utf8mb4",
Disable: true,
Version: 99999,
},
TestDSN: &dsn{
Net: "tcp",
Schema: "information_schema",
Charset: "utf8mb4",
Disable: true,
......@@ -156,8 +156,6 @@ var Config = &Configuration{
Profiling: false,
Trace: false,
Explain: true,
ConnTimeOut: 3,
QueryTimeOut: 30,
Delimiter: ";",
MaxJoinTableCount: 5,
......@@ -227,6 +225,7 @@ var Config = &Configuration{
}
type dsn struct {
Net string `yaml:"net"`
Addr string `yaml:"addr"`
Schema string `yaml:"schema"`
......@@ -236,6 +235,10 @@ type dsn struct {
Charset string `yaml:"charset"`
Disable bool `yaml:"disable"`
Timeout int `yaml:"timeout"`
ReadTimeout int `yaml:"read-timeout"`
WriteTimeout int `yaml:"write-timeout"`
Version int `yaml:"-"` // 版本自动检查,不可配置
}
......@@ -502,8 +505,6 @@ func readCmdFlags() error {
explain := flag.Bool("explain", Config.Explain, "Explain, 是否开启Explain执行计划分析")
sampling := flag.Bool("sampling", Config.Sampling, "Sampling, 数据采样开关")
samplingStatisticTarget := flag.Int("sampling-statistic-target", Config.SamplingStatisticTarget, "SamplingStatisticTarget, 数据采样因子,对应 PostgreSQL 的 default_statistics_target")
connTimeOut := flag.Int("conn-time-out", Config.ConnTimeOut, "ConnTimeOut, 数据库连接超时时间,单位秒")
queryTimeOut := flag.Int("query-time-out", Config.QueryTimeOut, "QueryTimeOut, 数据库SQL执行超时时间,单位秒")
delimiter := flag.String("delimiter", Config.Delimiter, "Delimiter, SQL分隔符")
// +++++++++++++++日志相关+++++++++++++++++
logLevel := flag.Int("log-level", Config.LogLevel, "LogLevel, 日志级别, [0:Emergency, 1:Alert, 2:Critical, 3:Error, 4:Warning, 5:Notice, 6:Informational, 7:Debug]")
......@@ -583,8 +584,6 @@ func readCmdFlags() error {
Config.Explain = *explain
Config.Sampling = *sampling
Config.SamplingStatisticTarget = *samplingStatisticTarget
Config.ConnTimeOut = *connTimeOut
Config.QueryTimeOut = *queryTimeOut
Config.LogLevel = *logLevel
if strings.HasPrefix(*logOutput, "/") {
......
......@@ -27,7 +27,7 @@ type Meta map[string]*DB
// DB 数据库相关的结构体
type DB struct {
Name string
Table map[string]*Table // ['table_name']*Table
Table map[string]*Table // ['table_name']*TableName
}
// NewDB 用于初始化*DB
......@@ -38,14 +38,14 @@ func NewDB(db string) *DB {
}
}
// Table 含有表的属性
// TableName 含有表的属性
type Table struct {
TableName string
TableAliases []string
Column map[string]*Column
}
// NewTable 初始化*Table
// NewTable 初始化*TableName
func NewTable(tb string) *Table {
return &Table{
TableName: tb,
......
......@@ -176,9 +176,9 @@ $$
<p>Typora support <a href="http://jekyllrb.com/docs/frontmatter/">YAML Front Matter</a> now. Input <code>---</code> at the top of the article and then press <code>Enter</code> will introduce one. Or insert one metadata block from the menu.</p>
<h3>Table of Contents (TOC)</h3>
<h3>TableName of Contents (TOC)</h3>
<p>Input <code>[toc]</code> then press <code>Return</code> key will create a section for “Table of Contents” extracting all headers from one’s writing, its contents will be updated automatically.</p>
<p>Input <code>[toc]</code> then press <code>Return</code> key will create a section for “TableName of Contents” extracting all headers from one’s writing, its contents will be updated automatically.</p>
<h3>Diagrams (Sequence, Flowchart and Mermaid)</h3>
......
......@@ -186,9 +186,9 @@ Input `***` or `---` on a blank line and press `return` will draw a horizontal l
Typora support [YAML Front Matter](http://jekyllrb.com/docs/frontmatter/) now. Input `---` at the top of the article and then press `Enter` will introduce one. Or insert one metadata block from the menu.
### Table of Contents (TOC)
### TableName of Contents (TOC)
Input `[toc]` then press `Return` key will create a section for “Table of Contents” extracting all headers from one’s writing, its contents will be updated automatically.
Input `[toc]` then press `Return` key will create a section for “TableName of Contents” extracting all headers from one’s writing, its contents will be updated automatically.
### Diagrams (Sequence, Flowchart and Mermaid)
......
......@@ -267,6 +267,7 @@ var ExplainKeyWords = []string{
"using_temporary_table",
}
/*
// ExplainColumnIndent EXPLAIN表头
var ExplainColumnIndent = map[string]string{
"id": "id为SELECT的标识符. 它是在SELECT查询中的顺序编号. 如果这一行表示其他行的union结果, 这个值可以为空. 在这种情况下, table列会显示为形如<union M,N>, 表示它是id为M和N的查询行的联合结果.",
......@@ -281,6 +282,7 @@ var ExplainColumnIndent = map[string]string{
"filtered": "表示返回结果的行占需要读到的行(rows列的值)的百分比.",
"Extra": "该列显示MySQL在查询过程中的一些详细信息, MySQL查询优化器执行查询的过程中对查询计划的重要补充信息.",
}
*/
// ExplainSelectType EXPLAIN中SELECT TYPE会出现的类型
var ExplainSelectType = map[string]string{
......@@ -555,14 +557,16 @@ func (db *Connector) explainAbleSQL(sql string) (string, error) {
}
// 执行explain请求,返回mysql.Result执行结果
func (db *Connector) executeExplain(sql string, explainType int, formatType int) (*QueryResult, error) {
func (db *Connector) executeExplain(sql string, explainType int, formatType int) (QueryResult, error) {
var res QueryResult
var err error
var explainQuery string
sql, err = db.explainAbleSQL(sql)
if sql == "" {
return nil, err
return res, err
}
// 5.6以上支持FORMAT=JSON
// 5.6以上支持 FORMAT=JSON
explainFormat := ""
switch formatType {
case JSONFormatExplain:
......@@ -570,22 +574,23 @@ func (db *Connector) executeExplain(sql string, explainType int, formatType int)
explainFormat = "FORMAT=JSON"
}
}
// 执行explain
var res *QueryResult
// 执行 explain
switch explainType {
case ExtendedExplainType:
// 5.6以上extended关键字已经不推荐使用,8.0废弃了这个关键字
if common.Config.TestDSN.Version >= 50600 {
res, err = db.Query("explain %s", sql)
explainQuery = fmt.Sprintf("explain %s", sql)
} else {
res, err = db.Query("explain extended %s", sql)
explainQuery = fmt.Sprintf("explain extended %s", sql)
}
case PartitionsExplainType:
res, err = db.Query("explain partitions %s", sql)
explainQuery = fmt.Sprintf("explain partitions %s", sql)
default:
res, err = db.Query("explain %s %s", explainFormat, sql)
explainQuery = fmt.Sprintf("explain %s %s", explainFormat, sql)
}
res, err = db.Query(explainQuery)
return res, err
}
......@@ -928,76 +933,75 @@ func parseJSONExplainText(content string) (*ExplainJSON, error) {
}
// ParseExplainResult 分析 mysql 执行 explain 的结果,返回 ExplainInfo 结构化数据
func ParseExplainResult(res *QueryResult, formatType int) (exp *ExplainInfo, err error) {
func ParseExplainResult(res QueryResult, formatType int) (exp *ExplainInfo, err error) {
exp = &ExplainInfo{
ExplainFormat: formatType,
}
// JSON 格式直接调用文本方式解析
if formatType == JSONFormatExplain {
exp.ExplainJSON, err = parseJSONExplainText(res.Rows[0].Str(0))
if res.Rows.Next() {
var explainString string
err = res.Rows.Scan(&explainString)
exp.ExplainJSON, err = parseJSONExplainText(explainString)
}
return exp, err
}
// 生成表头
colIdx := make(map[int]string)
for i, f := range res.Result.Fields() {
colIdx[i] = strings.ToLower(f.Name)
// Different MySQL version has different columns define
var possibleKeys string
expRow := &ExplainRow{}
explainFields := make([]interface{}, 0)
fields := map[string]interface{}{
"id": expRow.ID,
"select_type": expRow.SelectType,
"table": expRow.TableName,
"partitions": expRow.Partitions,
"type": expRow.AccessType,
"possible_keys": &possibleKeys,
"key": expRow.Key,
"key_len": expRow.KeyLen,
"ref": expRow.Ref,
"rows": expRow.Rows,
"filtered": expRow.Filtered,
"extra": expRow.Extra,
}
cols, err := res.Rows.Columns()
common.LogIfError(err, "")
for _, col := range cols {
explainFields = append(explainFields, fields[col])
}
// 补全 ExplainRows
var explainrows []*ExplainRow
for _, row := range res.Rows {
expRow := &ExplainRow{Partitions: "NULL", Filtered: 0.00}
// list 到 map 的转换
for i := range row {
switch colIdx[i] {
case "id":
expRow.ID = row.ForceInt(i)
case "select_type":
expRow.SelectType = row.Str(i)
case "table":
expRow.TableName = row.Str(i)
if expRow.TableName == "" {
expRow.TableName = "NULL"
}
case "type":
expRow.AccessType = row.Str(i)
if expRow.AccessType == "" {
expRow.AccessType = "NULL"
}
expRow.Scalability = ExplainScalability[expRow.AccessType]
case "possible_keys":
expRow.PossibleKeys = strings.Split(row.Str(i), ",")
case "key":
expRow.Key = row.Str(i)
if expRow.Key == "" {
expRow.Key = "NULL"
}
case "key_len":
expRow.KeyLen = row.Str(i)
case "ref":
expRow.Ref = strings.Split(row.Str(i), ",")
case "rows":
expRow.Rows = row.ForceInt(i)
case "extra":
expRow.Extra = row.Str(i)
if expRow.Extra == "" {
expRow.Extra = "NULL"
}
case "filtered":
expRow.Filtered = row.ForceFloat(i)
var explainRows []*ExplainRow
for res.Rows.Next() {
res.Rows.Scan(explainFields...)
expRow.PossibleKeys = strings.Split(possibleKeys, ",")
// MySQL bug: https://bugs.mysql.com/bug.php?id=34124
if expRow.Filtered > 100.00 {
expRow.Filtered = 100.00
}
expRow.Scalability = ExplainScalability[expRow.AccessType]
explainRows = append(explainRows, expRow)
}
}
explainrows = append(explainrows, expRow)
}
exp.ExplainRows = explainrows
for _, w := range res.Warning {
exp.ExplainRows = explainRows
// check explain warning info
if common.Config.ShowWarnings {
for res.Warning.Next() {
var expWarning *ExplainWarning
res.Warning.Scan(
expWarning.Level,
expWarning.Code,
expWarning.Message,
)
// 'EXTENDED' is deprecated and will be removed in a future release.
if w.Int(1) != 1681 {
exp.Warnings = append(exp.Warnings, &ExplainWarning{Level: w.Str(0), Code: w.Int(1), Message: w.Str(2)})
if expWarning.Code != 1681 {
exp.Warnings = append(exp.Warnings, expWarning)
}
}
}
......@@ -1009,7 +1013,6 @@ func ParseExplainResult(res *QueryResult, formatType int) (exp *ExplainInfo, err
// Explain 获取 SQL 的 explain 信息
func (db *Connector) Explain(sql string, explainType int, formatType int) (exp *ExplainInfo, err error) {
exp = &ExplainInfo{SQL: sql}
if explainType != TraditionalExplainType {
formatType = TraditionalFormatExplain
}
......@@ -1025,12 +1028,16 @@ func (db *Connector) Explain(sql string, explainType int, formatType int) (exp *
// 执行EXPLAIN请求
res, err := db.executeExplain(sql, explainType, formatType)
if err != nil || res == nil {
if err != nil {
return exp, err
}
if res.Error != nil {
return exp, res.Error
}
// 解析mysql结果,输出ExplainInfo
exp, err = ParseExplainResult(res, formatType)
exp.SQL = sql
return exp, err
}
......
......@@ -17,8 +17,6 @@
package database
import (
"fmt"
"os"
"testing"
"github.com/XiaoMi/soar/common"
......@@ -26,23 +24,6 @@ import (
"github.com/kr/pretty"
)
var connTest *Connector
func init() {
common.BaseDir = common.DevPath
common.ParseConfig("")
connTest = &Connector{
Addr: common.Config.OnlineDSN.Addr,
User: common.Config.OnlineDSN.User,
Pass: common.Config.OnlineDSN.Password,
Database: common.Config.OnlineDSN.Schema,
}
if _, err := connTest.Version(); err != nil {
common.Log.Critical("Test env Error: %v", err)
os.Exit(0)
}
}
var sqls = []string{
`select * from city where country_id = 44;`,
`select * from address where address2 is not null;`,
......@@ -54,10 +35,10 @@ var sqls = []string{
`select * from city where country_id > 31 and city = 'Aden';`,
`select * from address where address_id > 8 and city_id < 400 and district = 'Nantou';`,
`select * from address where address_id > 8 and city_id < 400;`,
`select * from actor where last_update='2006-02-15 04:34:33' and last_name='CHASE' group by first_name;`,
`select * from address where last_update >='2014-09-25 22:33:47' group by district;`,
`select * from address group by address,district;`,
`select * from address where last_update='2014-09-25 22:30:27' group by district,(address_id+city_id);`,
`select first_name from actor where last_update='2006-02-15 04:34:33' and last_name='CHASE' group by first_name;`,
`select district from address where last_update >='2014-09-25 22:33:47' group by district;`,
`select address from address group by address,district;`,
`select district from address where last_update='2014-09-25 22:30:27' group by district,(address_id+city_id);`,
`select * from customer where active=1 order by last_name limit 10;`,
`select * from customer order by last_name limit 10;`,
`select * from customer where address_id > 224 order by address_id limit 10;`,
......@@ -2351,16 +2332,23 @@ possible_keys: idx_fk_country_id,idx_country_id_city,idx_all,idx_other
}
func TestExplain(t *testing.T) {
for _, sql := range sqls {
// TraditionalFormatExplain
for idx, sql := range sqls {
exp, err := connTest.Explain(sql, TraditionalExplainType, TraditionalFormatExplain)
//exp, err := conn.Explain(sql, TraditionalExplainType, JSONFormatExplain)
fmt.Println("Old: ", sql)
fmt.Println("New: ", exp.SQL)
if err != nil {
fmt.Println(err)
t.Error(err)
}
pretty.Println("No.:", idx, "\nOld: ", sql, "\nNew: ", exp.SQL)
pretty.Println(exp)
}
// JSONFormatExplain
for idx, sql := range sqls {
exp, err := connTest.Explain(sql, TraditionalExplainType, JSONFormatExplain)
if err != nil {
t.Error(err)
}
pretty.Println("No.:", idx, "\nOld: ", sql, "\nNew: ", exp.SQL)
pretty.Println(exp)
fmt.Println()
}
}
......@@ -2400,6 +2388,7 @@ func TestPrintMarkdownExplainTable(t *testing.T) {
if err != nil {
t.Error(err)
}
err = common.GoldenDiff(func() {
PrintMarkdownExplainTable(expInfo)
}, t.Name(), update)
......
......@@ -17,21 +17,17 @@
package database
import (
"database/sql"
"errors"
"fmt"
"io/ioutil"
"os"
"regexp"
"strconv"
"strings"
"time"
"github.com/XiaoMi/soar/ast"
"github.com/XiaoMi/soar/common"
"github.com/ziutek/mymysql/mysql"
// mymysql driver
_ "github.com/ziutek/mymysql/native"
// for database/sql
_ "github.com/go-sql-driver/mysql"
"vitess.io/vitess/go/vt/sqlparser"
)
......@@ -42,81 +38,62 @@ type Connector struct {
Pass string
Database string
Charset string
Net string
}
// QueryResult 数据库查询返回值
type QueryResult struct {
Rows []mysql.Row
Result mysql.Result
Rows *sql.Rows
Error error
Warning []mysql.Row
Warning *sql.Rows
QueryCost float64
}
// NewConnection 创建新连接
func (db *Connector) NewConnection() mysql.Conn {
return mysql.New("tcp", "", db.Addr, db.User, db.Pass, db.Database)
func (db *Connector) NewConnection() (*sql.DB, error) {
dsn := fmt.Sprintf("%s:%s@%s(%s)/%s?parseTime=true&charset=%s", db.User, db.Pass, db.Net, db.Addr, db.Database, db.Charset)
return sql.Open("mysql", dsn)
}
// Query 执行SQL
func (db *Connector) Query(sql string, params ...interface{}) (*QueryResult, error) {
func (db *Connector) Query(sql string, params ...interface{}) (QueryResult, error) {
var res QueryResult
// 测试环境如果检查是关闭的,则SQL不会被执行
if common.Config.TestDSN.Disable {
return nil, errors.New("Dsn Disable")
return res, errors.New("dsn is disable")
}
// 数据库安全性检查:如果 Connector 的 IP 端口与 TEST 环境不一致,则启用SQL白名单
// 不在白名单中的SQL不允许执行
// 执行环境与test环境不相同
if db.Addr != common.Config.TestDSN.Addr && db.dangerousQuery(sql) {
return nil, fmt.Errorf("query execution deny: execute SQL with DSN(%s/%s) '%s'",
return res, fmt.Errorf("query execution deny: execute SQL with DSN(%s/%s) '%s'",
db.Addr, db.Database, fmt.Sprintf(sql, params...))
}
common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, fmt.Sprintf(sql, params...))
conn := db.NewConnection()
// 设置SQL连接超时时间
conn.SetTimeout(time.Duration(common.Config.ConnTimeOut) * time.Second)
conn, err := db.NewConnection()
defer conn.Close()
err := conn.Connect()
if err != nil {
return nil, err
return res, err
}
// 添加SQL执行超时限制
ch := make(chan QueryResult, 1)
go func() {
res := QueryResult{}
res.Rows, res.Result, res.Error = conn.Query(sql, params...)
res.Rows, res.Error = conn.Query(sql, params...)
if common.Config.ShowWarnings {
warning, _, err := conn.Query("SHOW WARNINGS")
if err == nil {
res.Warning = warning
}
res.Warning, err = conn.Query("SHOW WARNINGS")
}
// SHOW WARNINGS 并不会影响 last_query_cost
if common.Config.ShowLastQueryCost {
cost, _, err := conn.Query("SHOW SESSION STATUS LIKE 'last_query_cost'")
cost, err := conn.Query("SHOW SESSION STATUS LIKE 'last_query_cost'")
if err == nil {
if len(cost) > 0 {
res.QueryCost = cost[0].Float(1)
if cost.Next() {
err = cost.Scan(res.QueryCost)
}
}
}
ch <- res
}()
select {
case res := <-ch:
return &res, res.Error
case <-time.After(time.Duration(common.Config.QueryTimeOut) * time.Second):
return nil, errors.New("query execution timeout")
}
return res, err
}
// Version 获取MySQL数据库版本
......@@ -124,77 +101,43 @@ func (db *Connector) Version() (int, error) {
version := 99999
// 从数据库中获取版本信息
res, err := db.Query("select @@version")
if err != nil {
common.Log.Warn("(db *Connector) Version() Error: %v", err)
if err != nil || res.Error != nil {
common.Log.Warn("(db *Connector) Version() Error: %v, MySQL Error: %v", err, res.Error)
return version, err
}
// MariaDB https://mariadb.com/kb/en/library/comment-syntax/
// MySQL https://dev.mysql.com/doc/refman/8.0/en/comments.html
versionStr := strings.Split(res.Rows[0].Str(0), "-")[0]
versionSeg := strings.Split(versionStr, ".")
var versionStr string
var versionSeg []string
for res.Rows.Next() {
err = res.Rows.Scan(&versionStr)
versionStr = strings.Split(versionStr, "-")[0]
versionSeg = strings.Split(versionStr, ".")
if len(versionSeg) == 3 {
versionStr = fmt.Sprintf("%s%02s%02s", versionSeg[0], versionSeg[1], versionSeg[2])
version, err = strconv.Atoi(versionStr)
}
return version, err
}
// Source execute sql from file
func (db *Connector) Source(file string) ([]*QueryResult, error) {
var sqlCounter int // SQL 计数器
var result []*QueryResult
fd, err := os.Open(file)
defer func() {
err = fd.Close()
if err != nil {
common.Log.Error("(db *Connector) Source(%s) fd.Close failed: %s", file, err.Error())
}
}()
if err != nil {
common.Log.Warning("(db *Connector) Source(%s) os.Open failed: %s", file, err.Error())
return nil, err
}
data, err := ioutil.ReadAll(fd)
if err != nil {
common.Log.Critical("ioutil.ReadAll Error: %s", err.Error())
return nil, err
}
sql := strings.TrimSpace(string(data))
buf := strings.TrimSpace(sql)
for ; ; sqlCounter++ {
if buf == "" {
break
}
// 查询请求切分
_, sql, bufBytes := ast.SplitStatement([]byte(buf), []byte(common.Config.Delimiter))
buf = string(bufBytes)
sql = strings.TrimSpace(sql)
common.Log.Debug("Source Query SQL: %s", sql)
res, e := db.Query(sql)
if e != nil {
common.Log.Error("(db *Connector) Source Filename: %s, SQLCounter.: %d", file, sqlCounter)
return result, e
}
result = append(result, res)
}
return result, nil
return version, err
}
// SingleIntValue 获取某个int型变量的值
func (db *Connector) SingleIntValue(option string) (int, error) {
// 从数据库中获取信息
res, err := db.Query("select @@%s", option)
res, err := db.Query("select @@" + option)
if err != nil {
common.Log.Warn("(db *Connector) SingleIntValue() Error: %v", err)
return -1, err
}
return res.Rows[0].Int(0), err
var intVal int
if res.Rows.Next() {
err = res.Rows.Scan(&intVal)
}
return intVal, err
}
// ColumnCardinality 粒度计算
......@@ -228,13 +171,20 @@ func (db *Connector) ColumnCardinality(tb, col string) float64 {
}
// 计算该列散粒度
res, err := db.Query("select count(distinct `%s`) from `%s`.`%s`", col, db.Database, tb)
res, err := db.Query(fmt.Sprintf("select count(distinct `%s`) from `%s`.`%s`", col, db.Database, tb))
if err != nil {
common.Log.Warn("(db *Connector) ColumnCardinality() Query Error: %v", err)
return 0
}
colNum := res.Rows[0].Float(0)
var colNum float64
if res.Rows.Next() {
err = res.Rows.Scan(&colNum)
if err != nil {
common.Log.Warn("(db *Connector) ColumnCardinality() Query Error: %v", err)
return 0
}
}
// 散粒度区间:[0,1]
return colNum / float64(rowTotal)
......@@ -249,13 +199,12 @@ func (db *Connector) IsView(tbName string) bool {
}
if len(tbStatus.Rows) > 0 {
if tbStatus.Rows[0].Comment == "VIEW" {
if string(tbStatus.Rows[0].Comment) == "VIEW" {
return true
}
}
return false
}
// RemoveSQLComments 去除SQL中的注释
......@@ -281,7 +230,7 @@ func (db *Connector) dangerousQuery(query string) bool {
return true
}
for _, sql := range queries {
for _, query := range queries {
dangerous := true
whiteList := []string{
"select",
......@@ -291,7 +240,7 @@ func (db *Connector) dangerousQuery(query string) bool {
}
for _, prefix := range whiteList {
if strings.HasPrefix(sql, prefix) {
if strings.HasPrefix(query, prefix) {
dangerous = false
break
}
......
......@@ -17,26 +17,70 @@
package database
import (
"flag"
"fmt"
"os"
"testing"
"github.com/XiaoMi/soar/common"
"github.com/kr/pretty"
)
var connTest *Connector
var update = flag.Bool("update", false, "update .golden files")
func init() {
common.BaseDir = common.DevPath
common.ParseConfig("")
connTest = &Connector{
Addr: common.Config.OnlineDSN.Addr,
User: common.Config.OnlineDSN.User,
Pass: common.Config.OnlineDSN.Password,
Database: common.Config.OnlineDSN.Schema,
Charset: common.Config.OnlineDSN.Charset,
}
if _, err := connTest.Version(); err != nil {
common.Log.Critical("Test env Error: %v", err)
os.Exit(0)
}
}
func TestNewConnection(t *testing.T) {
_, err := connTest.NewConnection()
if err != nil {
t.Errorf("TestNewConnection, Error: %s", err.Error())
}
}
// TODO: go test -race不通过待解决
func TestQuery(t *testing.T) {
common.Config.QueryTimeOut = 1
_, err := connTest.Query("select sleep(2)")
if err == nil {
t.Error("connTest.Query not timeout")
res, err := connTest.Query("select 0")
if err != nil {
t.Error(err.Error())
}
for res.Rows.Next() {
var val int
err = res.Rows.Scan(&val)
if err != nil {
t.Error(err.Error())
}
if val != 0 {
t.Error("should return 0")
}
}
// TODO: timeout test
}
func TestColumnCardinality(_ *testing.T) {
connTest.Database = "information_schema"
a := connTest.ColumnCardinality("TABLES", "TABLE_SCHEMA")
fmt.Println("TABLES.TABLE_SCHEMA:", a)
func TestColumnCardinality(t *testing.T) {
orgDatabase := connTest.Database
connTest.Database = "sakila"
a := connTest.ColumnCardinality("actor", "first_name")
if a >= 1 || a <= 0 {
t.Error("sakila.actor.first_name cardinality should in (0, 1), now it's", a)
}
connTest.Database = orgDatabase
}
func TestDangerousSQL(t *testing.T) {
......@@ -63,11 +107,15 @@ func TestWarningsAndQueryCost(t *testing.T) {
if err != nil {
t.Error("Query Error: ", err)
} else {
for _, w := range res.Warning {
pretty.Println(w.Str(2))
for res.Warning.Next() {
var str string
err = res.Warning.Scan(str)
if err != nil {
t.Error(err.Error())
}
pretty.Println(str)
}
fmt.Println(res.QueryCost)
pretty.Println(err)
fmt.Println(res.QueryCost, err)
}
}
......@@ -79,16 +127,6 @@ func TestVersion(t *testing.T) {
fmt.Println(version)
}
func TestSource(t *testing.T) {
res, err := connTest.Source("testdata/" + t.Name() + ".sql")
if err != nil {
t.Error("Query Error: ", err)
}
if res[0].Rows[0].Int(0) != 1 || res[1].Rows[0].Int(0) != 1 {
t.Error("Source result not match, expect 1, 1")
}
}
func TestRemoveSQLComments(t *testing.T) {
SQLs := []string{
`-- comment`,
......@@ -109,3 +147,22 @@ comment*/`,
t.Error(err)
}
}
func TestSingleIntValue(t *testing.T) {
val, err := connTest.SingleIntValue("read_only")
if err != nil {
t.Error(err)
}
if val < 0 {
t.Error("SingleIntValue, return should large than zero")
}
}
func TestIsView(t *testing.T) {
originalDatabase := connTest.Database
connTest.Database = "sakila"
if !connTest.IsView("actor_info") {
t.Error("actor_info should be a VIEW")
}
connTest.Database = originalDatabase
}
......@@ -18,6 +18,7 @@ package database
import (
"errors"
"fmt"
"strings"
"github.com/XiaoMi/soar/common"
......@@ -29,8 +30,13 @@ func (db *Connector) CurrentUser() (string, string, error) {
if err != nil {
return "", "", err
}
if len(res.Rows) > 0 {
cols := strings.Split(res.Rows[0].Str(0), "@")
if res.Rows.Next() {
var currentUser string
err = res.Rows.Scan(&currentUser)
if err != nil {
return "", "", err
}
cols := strings.Split(currentUser, "@")
if len(cols) == 2 {
user := strings.Trim(cols[0], "'")
host := strings.Trim(cols[1], "'")
......@@ -51,14 +57,20 @@ func (db *Connector) HasSelectPrivilege() bool {
common.Log.Error("User: %s, HasSelectPrivilege: %s", db.User, err.Error())
return false
}
res, err := db.Query("select Select_priv from mysql.user where user='%s' and host='%s'", user, host)
res, err := db.Query(fmt.Sprintf("select Select_priv from mysql.user where user='%s' and host='%s'", user, host))
if err != nil {
common.Log.Error("HasSelectPrivilege, DSN: %s, Error: %s", db.Addr, err.Error())
return false
}
// Select_priv
if len(res.Rows) > 0 {
if res.Rows[0].Str(0) == "Y" {
if res.Rows.Next() {
var selectPrivilege string
err = res.Rows.Scan(&selectPrivilege)
if err != nil {
common.Log.Error("HasSelectPrivilege, Scan Error: %s", err.Error())
return false
}
if selectPrivilege == "Y" {
return true
}
}
......@@ -79,24 +91,31 @@ func (db *Connector) HasAllPrivilege() bool {
common.Log.Error("HasAllPrivilege, DSN: %s, Error: %s", db.Addr, err.Error())
return false
}
var priv string
if len(res.Rows) > 0 {
priv = res.Rows[0].Str(0)
} else {
common.Log.Error("HasAllPrivilege, DSN: %s, get privilege string error", db.Addr)
if res.Rows.Next() {
err = res.Rows.Scan(&priv)
if err != nil {
common.Log.Error("HasAllPrivilege, DSN: %s, Scan error", db.Addr)
return false
}
}
// get all privilege status
res, err = db.Query("select concat("+priv+") from mysql.user where user='%s' and host='%s'", user, host)
res, err = db.Query(fmt.Sprintf("select concat("+priv+") from mysql.user where user='%s' and host='%s'", user, host))
if err != nil {
common.Log.Error("HasAllPrivilege, DSN: %s, Error: %s", db.Addr, err.Error())
return false
}
// %_priv
if len(res.Rows) > 0 {
if strings.Replace(res.Rows[0].Str(0), "Y", "", -1) == "" {
if res.Rows.Next() {
err = res.Rows.Scan(&priv)
if err != nil {
common.Log.Error("HasAllPrivilege, DSN: %s, Scan error", db.Addr)
return false
}
if strings.Replace(priv, "Y", "", -1) == "" {
return true
}
}
......
......@@ -18,6 +18,16 @@ package database
import "testing"
func TestCurrentUser(t *testing.T) {
user, host, err := connTest.CurrentUser()
if err != nil {
t.Error(err.Error())
}
if user != "root" || host != "%" {
t.Errorf("Want user: root, host: %%. Get user: %s, host: %s", user, host)
}
}
func TestHasSelectPrivilege(t *testing.T) {
if !connTest.HasSelectPrivilege() {
t.Errorf("DSN: %s, User: %s, should has select privilege", connTest.Addr, connTest.User)
......
......@@ -19,9 +19,7 @@ package database
import (
"errors"
"fmt"
"io"
"strings"
"time"
"github.com/XiaoMi/soar/common"
......@@ -40,95 +38,79 @@ type ProfilingRow struct {
// TODO: 支持show profile all,不过目前看all的信息过多有点眼花缭乱
}
// Profiling 执行SQL,并对其Profiling
func (db *Connector) Profiling(sql string, params ...interface{}) (*QueryResult, error) {
// Profiling 执行SQL,并对其 Profile
func (db *Connector) Profiling(sql string, params ...interface{}) ([]ProfilingRow, error) {
var rows []ProfilingRow
// 过滤不需要 profiling 的 SQL
switch sqlparser.Preview(sql) {
case sqlparser.StmtSelect, sqlparser.StmtUpdate, sqlparser.StmtDelete:
default:
return nil, errors.New("no need profiling")
return rows, errors.New("no need profiling")
}
// 测试环境如果检查是关闭的,则SQL不会被执行
if common.Config.TestDSN.Disable {
return nil, errors.New("Dsn Disable")
return rows, errors.New("dsn is disable")
}
// 数据库安全性检查:如果 Connector 的 IP 端口与 TEST 环境不一致,则启用 SQL 白名单
// 不在白名单中的SQL不允许执行
// 不在白名单中的 SQL 不允许执行
// 执行环境与test环境不相同
if db.Addr != common.Config.TestDSN.Addr && db.dangerousQuery(sql) {
return nil, fmt.Errorf("query execution deny: Execute SQL with DSN(%s/%s) '%s'",
return rows, fmt.Errorf("query execution deny: Execute SQL with DSN(%s/%s) '%s'",
db.Addr, db.Database, fmt.Sprintf(sql, params...))
}
common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, sql)
conn := db.NewConnection()
// 设置SQL连接超时时间
conn.SetTimeout(time.Duration(common.Config.ConnTimeOut) * time.Second)
conn, err := db.NewConnection()
if err != nil {
return rows, err
}
defer conn.Close()
err := conn.Connect()
// Keep connection
// https://github.com/go-sql-driver/mysql/issues/208
trx, err := conn.Begin()
if err != nil {
return nil, err
return rows, err
}
defer trx.Rollback()
// 添加SQL执行超时限制
ch := make(chan QueryResult, 1)
go func() {
// 开启Profiling
_, _, err = conn.Query("set @@profiling=1")
// 开启 Profiling
_, err = trx.Query("set @@profiling=1")
common.LogIfError(err, "")
// 执行SQL,抛弃返回结果
result, err := conn.Start(sql, params...)
// 执行 SQL,抛弃返回结果
tmpRes, err := trx.Query(sql, params...)
if err != nil {
ch <- QueryResult{
Error: err,
}
return
}
row := result.MakeRow()
for {
err = result.ScanRow(row)
if err == io.EOF {
break
return rows, err
}
for tmpRes.Next() {
continue
}
// 返回Profiling结果
res := QueryResult{}
res.Rows, res.Result, res.Error = conn.Query("show profile")
_, _, err = conn.Query("set @@profiling=0")
// 返回 Profiling 结果
res, err := trx.Query("show profile")
for res.Next() {
var profileRow ProfilingRow
err := res.Scan(&profileRow.Status, &profileRow.Duration)
if err != nil {
common.LogIfError(err, "")
ch <- res
}()
select {
case res := <-ch:
return &res, res.Error
case <-time.After(time.Duration(common.Config.QueryTimeOut) * time.Second):
return nil, errors.New("query execution timeout")
}
}
func getProfiling(res *QueryResult) Profiling {
var rows []ProfilingRow
for _, row := range res.Rows {
rows = append(rows, ProfilingRow{
Status: row.Str(0),
Duration: row.Float(1),
})
rows = append(rows, profileRow)
}
return Profiling{Rows: rows}
// 关闭 Profiling
_, err = trx.Query("set @@profiling=0")
common.LogIfError(err, "")
return rows, err
}
// FormatProfiling 格式化输出Profiling信息
func FormatProfiling(res *QueryResult) string {
profiling := getProfiling(res)
func FormatProfiling(rows []ProfilingRow) string {
str := []string{"| Status | Duration |"}
str = append(str, "| --- | --- |")
for _, row := range profiling.Rows {
for _, row := range rows {
str = append(str, fmt.Sprintf("| %s | %f |", row.Status, row.Duration))
}
return strings.Join(str, "\n")
......
......@@ -19,35 +19,21 @@ package database
import (
"testing"
"github.com/XiaoMi/soar/common"
"github.com/kr/pretty"
)
func TestProfiling(t *testing.T) {
common.Config.QueryTimeOut = 1
res, err := connTest.Profiling("select 1")
if err == nil {
pretty.Println(res)
} else {
rows, err := connTest.Profiling("select 1")
if err != nil {
t.Error(err)
}
pretty.Println(rows)
}
func TestFormatProfiling(t *testing.T) {
res, err := connTest.Profiling("select 1")
if err == nil {
pretty.Println(FormatProfiling(res))
} else {
t.Error(err)
}
}
func TestGetProfiling(t *testing.T) {
res, err := connTest.Profiling("select 1")
if err == nil {
pretty.Println(getProfiling(res))
} else {
if err != nil {
t.Error(err)
}
pretty.Println(FormatProfiling(res))
}
......@@ -17,14 +17,9 @@
package database
import (
"database/sql"
"fmt"
"io"
"strconv"
"strings"
"time"
"github.com/XiaoMi/soar/common"
"github.com/ziutek/mymysql/mysql"
)
/*--------------------
......@@ -57,21 +52,17 @@ func (db *Connector) SamplingData(remote Connector, tables ...string) error {
maxValCount := 200
// 获取数据库连接对象
conn := remote.NewConnection()
localConn := db.NewConnection()
// 连接数据库
err := conn.Connect()
defer conn.Close()
conn, err := remote.NewConnection()
if err != nil {
return err
}
defer conn.Close()
err = localConn.Connect()
defer localConn.Close()
localConn, err := db.NewConnection()
if err != nil {
return err
}
defer localConn.Close()
for _, table := range tables {
// 表类型检查
......@@ -109,19 +100,27 @@ func (db *Connector) SamplingData(remote Connector, tables ...string) error {
// 开始从环境中泵取数据
// 因为涉及到的数据量问题,所以泵取与插入时同时进行的
// TODO 加 ref link
func startSampling(conn, localConn mysql.Conn, database, table string, factor float64, wants, maxValCount int) error {
func startSampling(conn, localConn *sql.DB, database, table string, factor float64, wants, maxValCount int) error {
return nil
// TODO:
/*
// 从线上数据库获取所需dump的表中所有列的数据类型,备用
// 由于测试库中的库表为刚建立的,所以在information_schema中很可能没有这个表的信息
var dataTypes []string
q := fmt.Sprintf("select DATA_TYPE from information_schema.COLUMNS where TABLE_SCHEMA='%s' and TABLE_NAME = '%s'",
database, table)
common.Log.Debug("Sampling data execute: %s", q)
rs, _, err := localConn.Query(q)
rs, err := localConn.Query(q)
if err != nil {
common.Log.Debug("Sampling data got data type Err: %v", err)
} else {
for _, r := range rs {
dataTypes = append(dataTypes, r.Str(0))
for rs.Next() {
var dataType string
err = rs.Scan(&dataType)
if err != nil {
return err
}
dataTypes = append(dataTypes, dataType)
}
}
......@@ -132,7 +131,7 @@ func startSampling(conn, localConn mysql.Conn, database, table string, factor fl
}
sql := fmt.Sprintf("select * from `%s` %s limit %d;", table, where, wants)
res, err := conn.Start(sql)
res, err := conn.Query(sql)
if err != nil {
return err
}
......@@ -214,14 +213,15 @@ func startSampling(conn, localConn mysql.Conn, database, table string, factor fl
common.Log.Debug("%d rows sampling out", rowCount)
return nil
*/
}
// 将泵取的数据转换成Insert语句并在数据库中执行
func doSampling(conn mysql.Conn, dbName, table, colDef, values string) {
sql := fmt.Sprintf("Insert into `%s`.`%s`(%s) values%s;", dbName, table,
func doSampling(conn *sql.DB, dbName, table, colDef, values string) {
query := fmt.Sprintf("Insert into `%s`.`%s`(%s) values%s;", dbName, table,
colDef, values)
_, _, err := conn.Query(sql)
_, err := conn.Query(query)
if err != nil {
common.Log.Error("doSampling Error from %s.%s: %v", dbName, table, err)
......
......@@ -32,6 +32,8 @@ func TestSamplingData(t *testing.T) {
User: common.Config.OnlineDSN.User,
Pass: common.Config.OnlineDSN.Password,
Database: common.Config.OnlineDSN.Schema,
Charset: common.Config.OnlineDSN.Charset,
Net: common.Config.OnlineDSN.Net,
}
offline := &Connector{
......@@ -39,6 +41,8 @@ func TestSamplingData(t *testing.T) {
User: common.Config.TestDSN.User,
Pass: common.Config.TestDSN.Password,
Database: common.Config.TestDSN.Schema,
Charset: common.Config.TestDSN.Charset,
Net: common.Config.TestDSN.Net,
}
offline.Database = "test"
......
......@@ -18,12 +18,10 @@ package database
import (
"fmt"
"github.com/XiaoMi/soar/common"
"regexp"
"strconv"
"strings"
"time"
"github.com/XiaoMi/soar/common"
)
// SHOW TABLE STATUS Syntax
......@@ -36,11 +34,12 @@ type TableStatInfo struct {
}
// tableStatusRow 用于 show table status value
// use []byte instead of string, because []byte allow to be null, string not
type tableStatusRow struct {
Name string // 表名
Engine string // 该表使用的存储引擎
Version int // 该表的 .frm 文件版本号
RowFormat string // 该表使用的行存储格式
Engine []byte // 该表使用的存储引擎
Version []byte // 该表的 .frm 文件版本号
RowFormat []byte // 该表使用的行存储格式
Rows int64 // 表行数, InnoDB 引擎中为预估值,甚至可能会有40%~50%的数值偏差
AvgRowLength int // 平均行长度
......@@ -60,14 +59,14 @@ type tableStatusRow struct {
IndexLength int
DataFree int // 已分配但未使用的字节数
AutoIncrement int // 下一个自增值
CreateTime time.Time // 创建时间
UpdateTime time.Time // 最近一次更新时间,该值不准确
CheckTime time.Time // 上次检查时间
Collation string // 字符集及排序规则信息
Checksum string // 校验和
CreateOptions string // 创建表的时候的时候一切其他属性
Comment string // 注释
AutoIncrement []byte // 下一个自增值
CreateTime []byte // 创建时间
UpdateTime []byte // 最近一次更新时间,该值不准确
CheckTime []byte // 上次检查时间
Collation []byte // 字符集及排序规则信息
Checksum []byte // 校验和
CreateOptions []byte // 创建表的时候的时候一切其他属性
Comment []byte // 注释
}
// newTableStat 构造 table Stat 对象
......@@ -83,7 +82,7 @@ func (db *Connector) ShowTables() ([]string, error) {
defer func() {
err := recover()
if err != nil {
common.Log.Error("recover ShowTableStatus()", err)
common.Log.Error("recover ShowTables()", err)
}
}()
......@@ -92,79 +91,70 @@ func (db *Connector) ShowTables() ([]string, error) {
if err != nil {
return []string{}, err
}
if res.Error != nil {
return []string{}, res.Error
}
// 获取值
var tables []string
for _, row := range res.Rows {
tables = append(tables, row.Str(0))
for res.Rows.Next() {
var table string
err = res.Rows.Scan(&table)
if err != nil {
return []string{}, err
}
tables = append(tables, table)
}
return tables, err
}
// ShowTableStatus 执行 show table status
func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) {
defer func() {
err := recover()
if err != nil {
common.Log.Error("recover ShowTableStatus()", err)
}
}()
// 初始化struct
ts := newTableStat(tableName)
tbStatus := newTableStat(tableName)
// 执行 show table status
res, err := db.Query("show table status where name = '%s'", ts.Name)
res, err := db.Query(fmt.Sprintf("show table status where name = '%s'", tbStatus.Name))
if err != nil {
return ts, err
}
rs := res.Result.Map("Rows")
name := res.Result.Map("Name")
df := res.Result.Map("Data_free")
sum := res.Result.Map("Checksum")
engine := res.Result.Map("Engine")
version := res.Result.Map("Version")
comment := res.Result.Map("Comment")
ai := res.Result.Map("Auto_increment")
collation := res.Result.Map("Collation")
rowFormat := res.Result.Map("Row_format")
checkTime := res.Result.Map("Check_time")
dataLength := res.Result.Map("Data_length")
idxLength := res.Result.Map("Index_length")
createTime := res.Result.Map("Create_time")
updateTime := res.Result.Map("Update_time")
options := res.Result.Map("Create_options")
avgRowLength := res.Result.Map("Avg_row_length")
maxDataLength := res.Result.Map("Max_data_length")
return tbStatus, err
}
if res.Error != nil {
return tbStatus, res.Error
}
ts := tableStatusRow{}
statusFields := make([]interface{}, 0)
fields := map[string]interface{}{
"Name": &ts.Name,
"Engine": &ts.Engine,
"Version": &ts.Version,
"Row_format": &ts.RowFormat,
"Rows": &ts.Rows,
"Avg_row_length": &ts.AvgRowLength,
"Data_length": &ts.DataLength,
"Max_data_length": &ts.MaxDataLength,
"Index_length": &ts.IndexLength,
"Data_free": &ts.DataFree,
"Auto_increment": &ts.AutoIncrement,
"Create_time": &ts.CreateTime,
"Update_time": &ts.UpdateTime,
"Check_time": &ts.CheckTime,
"Collation": &ts.Collation,
"Checksum": &ts.Checksum,
"Create_options": &ts.CreateOptions,
"Comment": &ts.Comment,
}
cols, err := res.Rows.Columns()
common.LogIfError(err, "")
for _, col := range cols {
statusFields = append(statusFields, fields[col])
}
// 获取值
for _, row := range res.Rows {
value := tableStatusRow{
Name: row.Str(name),
Engine: row.Str(engine),
Version: row.Int(version),
Rows: row.Int64(rs),
RowFormat: row.Str(rowFormat),
AvgRowLength: row.Int(avgRowLength),
DataLength: row.Int(dataLength),
MaxDataLength: row.Int(maxDataLength),
IndexLength: row.Int(idxLength),
DataFree: row.Int(df),
AutoIncrement: row.Int(ai),
CreateTime: row.Time(createTime, time.Local),
UpdateTime: row.Time(updateTime, time.Local),
CheckTime: row.Time(checkTime, time.Local),
Collation: row.Str(collation),
Checksum: row.Str(sum),
CreateOptions: row.Str(options),
Comment: row.Str(comment),
}
ts.Rows = append(ts.Rows, value)
}
return ts, err
for res.Rows.Next() {
res.Rows.Scan(statusFields...)
tbStatus.Rows = append(tbStatus.Rows, ts)
}
return tbStatus, err
}
// https://dev.mysql.com/doc/refman/5.7/en/show-index.html
......@@ -172,7 +162,7 @@ func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) {
// TableIndexInfo 用以保存 show index 之后获取的 index 信息
type TableIndexInfo struct {
TableName string
IdxRows []TableIndexRow
Rows []TableIndexRow
}
// TableIndexRow 用以存放show index之后获取的每一条index信息
......@@ -190,13 +180,14 @@ type TableIndexRow struct {
IndexType string // BTREE, FULLTEXT, HASH, RTREE
Comment string
IndexComment string
Visible string
}
// NewTableIndexInfo 构造 TableIndexInfo
func NewTableIndexInfo(tableName string) *TableIndexInfo {
return &TableIndexInfo{
TableName: tableName,
IdxRows: make([]TableIndexRow, 0),
Rows: make([]TableIndexRow, 0),
}
}
......@@ -205,43 +196,32 @@ func (db *Connector) ShowIndex(tableName string) (*TableIndexInfo, error) {
tbIndex := NewTableIndexInfo(tableName)
// 执行 show create table
res, err := db.Query("show index from `%s`.`%s`", db.Database, tableName)
res, err := db.Query(fmt.Sprintf("show index from `%s`.`%s`", db.Database, tableName))
if err != nil {
return nil, err
}
table := res.Result.Map("Table")
unique := res.Result.Map("Non_unique")
keyName := res.Result.Map("Key_name")
seq := res.Result.Map("Seq_in_index")
cName := res.Result.Map("Column_name")
collation := res.Result.Map("Collation")
cardinality := res.Result.Map("Cardinality")
subPart := res.Result.Map("Sub_part")
packed := res.Result.Map("Packed")
null := res.Result.Map("Null")
idxType := res.Result.Map("Index_type")
comment := res.Result.Map("Comment")
idxComment := res.Result.Map("Index_comment")
if res.Error != nil {
return nil, res.Error
}
// 获取值
for _, row := range res.Rows {
value := TableIndexRow{
Table: row.Str(table),
NonUnique: row.Int(unique),
KeyName: row.Str(keyName),
SeqInIndex: row.Int(seq),
ColumnName: row.Str(cName),
Collation: row.Str(collation),
Cardinality: row.Int(cardinality),
SubPart: row.Int(subPart),
Packed: row.Int(packed),
Null: row.Str(null),
IndexType: row.Str(idxType),
Comment: row.Str(comment),
IndexComment: row.Str(idxComment),
}
tbIndex.IdxRows = append(tbIndex.IdxRows, value)
for res.Rows.Next() {
var ti TableIndexRow
res.Rows.Scan(&ti.Table,
&ti.NonUnique,
&ti.KeyName,
&ti.SeqInIndex,
&ti.ColumnName,
&ti.Collation,
&ti.Cardinality,
&ti.SubPart,
&ti.Packed,
&ti.Null,
&ti.IndexType,
&ti.Comment,
&ti.IndexComment,
&ti.Visible)
tbIndex.Rows = append(tbIndex.Rows, ti)
}
return tbIndex, err
}
......@@ -257,7 +237,7 @@ const (
IndexNonUnique = IndexSelectKey("NonUnique") // 唯一索引
)
// FindIndex 获取TableIndexInfo中需要的索引
// FindIndex 获取 TableIndexInfo 中需要的索引
func (tbIndex *TableIndexInfo) FindIndex(arg IndexSelectKey, value string) []TableIndexRow {
var result []TableIndexRow
if tbIndex == nil {
......@@ -268,28 +248,28 @@ func (tbIndex *TableIndexInfo) FindIndex(arg IndexSelectKey, value string) []Tab
switch arg {
case IndexKeyName:
for _, index := range tbIndex.IdxRows {
for _, index := range tbIndex.Rows {
if strings.ToLower(index.KeyName) == value {
result = append(result, index)
}
}
case IndexColumnName:
for _, index := range tbIndex.IdxRows {
for _, index := range tbIndex.Rows {
if strings.ToLower(index.ColumnName) == value {
result = append(result, index)
}
}
case IndexIndexType:
for _, index := range tbIndex.IdxRows {
for _, index := range tbIndex.Rows {
if strings.ToLower(index.IndexType) == value {
result = append(result, index)
}
}
case IndexNonUnique:
for _, index := range tbIndex.IdxRows {
for _, index := range tbIndex.Rows {
unique := strconv.Itoa(index.NonUnique)
if unique == value {
result = append(result, index)
......@@ -316,12 +296,12 @@ type TableDesc struct {
type TableDescValue struct {
Field string // 列名
Type string // 数据类型
Collation []byte // 字符集
Null string // 是否有NULL(NO、YES)
Collation string // 字符集
Privileges string // 权限s
Key string // 键类型
Default string // 默认值
Default []byte // 默认值
Extra string // 其他
Privileges string // 权限
Comment string // 备注
}
......@@ -338,35 +318,27 @@ func (db *Connector) ShowColumns(tableName string) (*TableDesc, error) {
tbDesc := NewTableDesc(tableName)
// 执行 show create table
res, err := db.Query("show full columns from `%s`.`%s`", db.Database, tableName)
res, err := db.Query(fmt.Sprintf("show full columns from `%s`.`%s`", db.Database, tableName))
if err != nil {
return nil, err
}
field := res.Result.Map("Field")
tp := res.Result.Map("Type")
null := res.Result.Map("Null")
key := res.Result.Map("Key")
def := res.Result.Map("Default")
extra := res.Result.Map("Extra")
collation := res.Result.Map("Collation")
privileges := res.Result.Map("Privileges")
comm := res.Result.Map("Comment")
if res.Error != nil {
return nil, res.Error
}
// 获取值
for _, row := range res.Rows {
value := TableDescValue{
Field: row.Str(field),
Type: row.Str(tp),
Null: row.Str(null),
Key: row.Str(key),
Default: row.Str(def),
Extra: row.Str(extra),
Privileges: row.Str(privileges),
Collation: row.Str(collation),
Comment: row.Str(comm),
}
tbDesc.DescValues = append(tbDesc.DescValues, value)
for res.Rows.Next() {
var tc TableDescValue
res.Rows.Scan(&tc.Field,
&tc.Type,
&tc.Collation,
&tc.Null,
&tc.Key,
&tc.Default,
&tc.Extra,
&tc.Privileges,
&tc.Comment)
tbDesc.DescValues = append(tbDesc.DescValues, tc)
}
return tbDesc, err
}
......@@ -383,18 +355,21 @@ func (td TableDesc) Columns() []string {
// showCreate show create
func (db *Connector) showCreate(createType, name string) (string, error) {
// 执行 show create table
res, err := db.Query("show create %s `%s`", createType, name)
res, err := db.Query(fmt.Sprintf("show create %s `%s`", createType, name))
if err != nil {
return "", err
}
if res.Error != nil {
return "", res.Error
}
// 获取ddl
var ddl string
for _, row := range res.Rows {
ddl = row.Str(1)
// 获取 CREATE TABLE 语句
var tableName, createTable string
for res.Rows.Next() {
res.Rows.Scan(&tableName, &createTable)
}
return ddl, err
return createTable, err
}
// ShowCreateDatabase show create database
......@@ -451,6 +426,10 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
"c.TABLE_NAME,c.TABLE_SCHEMA,c.COLUMN_TYPE,c.CHARACTER_SET_NAME, c.COLLATION_NAME "+
"FROM `INFORMATION_SCHEMA`.`COLUMNS` as c where c.COLUMN_NAME = '%s' ", name)
if dbName != "" {
sql += fmt.Sprintf(" and c.table_schema = '%s'", dbName)
}
if len(tables) > 0 {
var tmp []string
for _, table := range tables {
......@@ -459,33 +438,25 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
sql += fmt.Sprintf(" and c.table_name in (%s)", strings.Join(tmp, ","))
}
if dbName != "" {
sql += fmt.Sprintf(" and c.table_schema = '%s'", dbName)
}
common.Log.Debug("FindColumn, execute SQL: %s", sql)
res, err := db.Query(sql)
if err != nil {
common.Log.Error("(db *Connector) FindColumn Error : ", err)
return columns, err
}
tbName := res.Result.Map("TABLE_NAME")
schema := res.Result.Map("TABLE_SCHEMA")
colTyp := res.Result.Map("COLUMN_TYPE")
colCharset := res.Result.Map("CHARACTER_SET_NAME")
collation := res.Result.Map("COLLATION_NAME")
// 获取ddl
for _, row := range res.Rows {
col := &common.Column{
Name: name,
Table: row.Str(tbName),
DB: row.Str(schema),
DataType: row.Str(colTyp),
Character: row.Str(colCharset),
Collation: row.Str(collation),
if res.Error != nil {
common.Log.Error("(db *Connector) FindColumn Error : ", res.Error)
return columns, res.Error
}
var col common.Column
for res.Rows.Next() {
res.Rows.Scan(&col.Table,
&col.DB,
&col.DataType,
&col.Character,
&col.Collation)
// 填充字符集和排序规则
if col.Character == "" {
// 当从`INFORMATION_SCHEMA`.`COLUMNS`表中查询不到相关列的character和collation的信息时
......@@ -494,40 +465,56 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
sql = fmt.Sprintf("SELECT `t`.`TABLE_COLLATION` FROM `INFORMATION_SCHEMA`.`TABLES` AS `t` "+
"WHERE `t`.`TABLE_NAME`='%s' AND `t`.`TABLE_SCHEMA` = '%s'", col.Table, col.DB)
var newRes *QueryResult
common.Log.Debug("FindColumn, execute SQL: %s", sql)
var newRes QueryResult
newRes, err = db.Query(sql)
if err != nil {
common.Log.Error("(db *Connector) FindColumn Error : ", err)
return columns, err
}
if res.Error != nil {
common.Log.Error("(db *Connector) FindColumn Error : ", res.Error)
return columns, res.Error
}
tbCollation := newRes.Rows[0].Str(0)
var tbCollation string
if newRes.Rows.Next() {
newRes.Rows.Scan(&tbCollation)
}
if tbCollation != "" {
col.Character = strings.Split(tbCollation, "_")[0]
col.Collation = tbCollation
}
}
columns = append(columns, col)
columns = append(columns, &col)
}
return columns, err
}
// IsFKey 判断列是否是外键
func (db *Connector) IsFKey(dbName, tbName, column string) bool {
// IsForeignKey 判断列是否是外键
func (db *Connector) IsForeignKey(dbName, tbName, column string) bool {
sql := fmt.Sprintf("SELECT REFERENCED_COLUMN_NAME FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE C "+
"WHERE REFERENCED_TABLE_SCHEMA <> 'NULL' AND"+
" TABLE_NAME='%s' AND"+
" TABLE_SCHEMA='%s' AND"+
" COLUMN_NAME='%s'", tbName, dbName, column)
common.Log.Debug("IsForeignKey, execute SQL: %s", sql)
res, err := db.Query(sql)
if err == nil && len(res.Rows) == 0 {
if err != nil {
common.Log.Error("IsForeignKey, Error: %s", err.Error())
return false
}
if res.Error != nil {
common.Log.Error("IsForeignKey, Error: %s", res.Error.Error())
return false
}
if res.Rows.Next() {
return true
}
return false
}
// Reference 用于存储关系
......@@ -535,10 +522,10 @@ type Reference map[string][]ReferenceValue
// ReferenceValue 用于处理表之间的关系
type ReferenceValue struct {
RefDBName string // 夫表所属数据库
RefTable string // 父表
DBName string // 子表所属数据库
Table string // 子表
ReferencedTableSchema string // 夫表所属数据库
ReferencedTableName string // 父表
TableSchema string // 子表所属数据库
TableName string // 子表
ConstraintName string // 关系名称
}
......@@ -555,30 +542,26 @@ WHERE C.REFERENCED_TABLE_NAME IS NOT NULL`
sql = sql + extra
}
common.Log.Debug("ShowReference, execute SQL: %s", sql)
// 执行SQL查找外键关联关系
res, err := db.Query(sql)
if err != nil {
return referenceValues, err
}
refDb := res.Result.Map("REFERENCED_TABLE_SCHEMA")
refTb := res.Result.Map("REFERENCED_TABLE_NAME")
schema := res.Result.Map("TABLE_SCHEMA")
tb := res.Result.Map("TABLE_NAME")
cName := res.Result.Map("CONSTRAINT_NAME")
if res.Error != nil {
return referenceValues, res.Error
}
// 获取值
for _, row := range res.Rows {
value := ReferenceValue{
RefDBName: row.Str(refDb),
RefTable: row.Str(refTb),
DBName: row.Str(schema),
Table: row.Str(tb),
ConstraintName: row.Str(cName),
}
referenceValues = append(referenceValues, value)
for res.Rows.Next() {
var rv ReferenceValue
res.Rows.Scan(&rv.ReferencedTableSchema,
&rv.ReferencedTableName,
&rv.TableSchema,
&rv.TableName,
&rv.ConstraintName)
referenceValues = append(referenceValues, rv)
}
return referenceValues, err
}
......@@ -20,75 +20,144 @@ import (
"fmt"
"testing"
"github.com/XiaoMi/soar/common"
"github.com/kr/pretty"
"vitess.io/vitess/go/vt/sqlparser"
)
func TestShowTableStatus(t *testing.T) {
connTest.Database = "information_schema"
ts, err := connTest.ShowTableStatus("TABLES")
orgDatabase := connTest.Database
connTest.Database = "sakila"
ts, err := connTest.ShowTableStatus("film")
if err != nil {
t.Error("ShowTableStatus Error: ", err)
}
if string(ts.Rows[0].Engine) != "InnoDB" {
t.Error("film table should be InnoDB engine")
}
pretty.Println(ts)
connTest.Database = "sakila"
ts, err = connTest.ShowTableStatus("actor_info")
if err != nil {
t.Error("ShowTableStatus Error: ", err)
}
if string(ts.Rows[0].Comment) != "VIEW" {
t.Error("actor_info should be VIEW")
}
pretty.Println(ts)
connTest.Database = orgDatabase
}
func TestShowTables(t *testing.T) {
connTest.Database = "information_schema"
orgDatabase := connTest.Database
connTest.Database = "sakila"
ts, err := connTest.ShowTables()
if err != nil {
t.Error("ShowTableStatus Error: ", err)
}
pretty.Println(ts)
err = common.GoldenDiff(func() {
for _, table := range ts {
fmt.Println(table)
}
}, t.Name(), update)
if err != nil {
t.Error(err)
}
connTest.Database = orgDatabase
}
func TestShowCreateTable(t *testing.T) {
connTest.Database = "information_schema"
ts, err := connTest.ShowCreateTable("TABLES")
orgDatabase := connTest.Database
connTest.Database = "sakila"
ts, err := connTest.ShowCreateTable("film")
if err != nil {
t.Error("ShowCreateTable Error: ", err)
}
err = common.GoldenDiff(func() {
fmt.Println(ts)
stmt, err := sqlparser.Parse(ts)
if err != nil {
t.Error(err.Error())
}
pretty.Println(stmt, err)
}, t.Name(), update)
if err != nil {
t.Error(err)
}
connTest.Database = orgDatabase
}
func TestShowIndex(t *testing.T) {
connTest.Database = "information_schema"
ti, err := connTest.ShowIndex("TABLES")
orgDatabase := connTest.Database
connTest.Database = "sakila"
ti, err := connTest.ShowIndex("film")
if err != nil {
t.Error("ShowIndex Error: ", err)
}
pretty.Println(ti.FindIndex(IndexKeyName, "idx_store_id_film_id"))
err = common.GoldenDiff(func() {
pretty.Println(ti)
pretty.Println(ti.FindIndex(IndexKeyName, "idx_title"))
}, t.Name(), update)
if err != nil {
t.Error(err)
}
connTest.Database = orgDatabase
}
func TestShowColumns(t *testing.T) {
connTest.Database = "information_schema"
ti, err := connTest.ShowColumns("TABLES")
orgDatabase := connTest.Database
connTest.Database = "sakila"
ti, err := connTest.ShowColumns("film")
if err != nil {
t.Error("ShowColumns Error: ", err)
}
err = common.GoldenDiff(func() {
pretty.Println(ti)
}, t.Name(), update)
if err != nil {
t.Error(err)
}
connTest.Database = orgDatabase
}
func TestFindColumn(t *testing.T) {
ti, err := connTest.FindColumn("id", "")
ti, err := connTest.FindColumn("film_id", "sakila", "film")
if err != nil {
t.Error("FindColumn Error: ", err)
}
err = common.GoldenDiff(func() {
pretty.Println(ti)
}, t.Name(), update)
if err != nil {
t.Error(err)
}
}
func TestIsFKey(t *testing.T) {
if !connTest.IsForeignKey("sakila", "film", "language_id") {
t.Error("want True. got false")
}
}
func TestShowReference(t *testing.T) {
rv, err := connTest.ShowReference("test2", "homeImg")
rv, err := connTest.ShowReference("sakila", "film")
if err != nil {
t.Error("ShowReference Error: ", err)
}
pretty.Println(rv)
}
func TestIsFKey(t *testing.T) {
if !connTest.IsFKey("sakila", "film", "language_id") {
t.Error("want True. got false")
err = common.GoldenDiff(func() {
pretty.Println(rv)
}, t.Name(), update)
if err != nil {
t.Error(err)
}
}
[]*common.Column{
&common.Column{
Name: "",
Alias: nil,
Table: "film",
DB: "sakila",
DataType: "smallint(5) unsigned",
Character: "utf8",
Collation: "utf8_general_ci",
Cardinality: 0,
Null: "",
Key: "",
Default: "",
Extra: "",
Comment: "",
Privileges: "",
},
}
```sql
select 1
```
```json
{
"steps": [
{
"join_preparation": {
"select#": 1,
"steps": [
{
"expanded_query": "/* select#1 */ select 1 AS `1`"
}
]
}
},
{
"join_optimization": {
"select#": 1,
"steps": [
]
}
},
{
"join_explain": {
"select#": 1,
"steps": [
]
}
}
]
}
```
&database.TableDesc{
Name: "film",
DescValues: {
{
Field: "film_id",
Type: "smallint(5) unsigned",
Collation: nil,
Null: "NO",
Key: "PRI",
Default: nil,
Extra: "auto_increment",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "title",
Type: "varchar(255)",
Collation: {0x75, 0x74, 0x66, 0x38, 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, 0x5f, 0x63, 0x69},
Null: "NO",
Key: "MUL",
Default: nil,
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "description",
Type: "text",
Collation: {0x75, 0x74, 0x66, 0x38, 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, 0x5f, 0x63, 0x69},
Null: "YES",
Key: "",
Default: nil,
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "release_year",
Type: "year(4)",
Collation: nil,
Null: "YES",
Key: "",
Default: nil,
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "language_id",
Type: "tinyint(3) unsigned",
Collation: nil,
Null: "NO",
Key: "MUL",
Default: nil,
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "original_language_id",
Type: "tinyint(3) unsigned",
Collation: nil,
Null: "YES",
Key: "MUL",
Default: nil,
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "rental_duration",
Type: "tinyint(3) unsigned",
Collation: nil,
Null: "NO",
Key: "",
Default: {0x33},
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "rental_rate",
Type: "decimal(4,2)",
Collation: nil,
Null: "NO",
Key: "",
Default: {0x34, 0x2e, 0x39, 0x39},
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "length",
Type: "smallint(5) unsigned",
Collation: nil,
Null: "YES",
Key: "",
Default: nil,
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "replacement_cost",
Type: "decimal(5,2)",
Collation: nil,
Null: "NO",
Key: "",
Default: {0x31, 0x39, 0x2e, 0x39, 0x39},
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "rating",
Type: "enum('G','PG','PG-13','R','NC-17')",
Collation: {0x75, 0x74, 0x66, 0x38, 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, 0x5f, 0x63, 0x69},
Null: "YES",
Key: "",
Default: {0x47},
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "special_features",
Type: "set('Trailers','Commentaries','Deleted Scenes','Behind the Scenes')",
Collation: {0x75, 0x74, 0x66, 0x38, 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, 0x5f, 0x63, 0x69},
Null: "YES",
Key: "",
Default: nil,
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "last_update",
Type: "timestamp",
Collation: nil,
Null: "NO",
Key: "",
Default: {0x43, 0x55, 0x52, 0x52, 0x45, 0x4e, 0x54, 0x5f, 0x54, 0x49, 0x4d, 0x45, 0x53, 0x54, 0x41, 0x4d, 0x50},
Extra: "on update CURRENT_TIMESTAMP",
Privileges: "select,insert,update,references",
Comment: "",
},
},
}
CREATE TABLE `film` (
`film_id` smallint(5) unsigned NOT NULL AUTO_INCREMENT,
`title` varchar(255) NOT NULL,
`description` text,
`release_year` year(4) DEFAULT NULL,
`language_id` tinyint(3) unsigned NOT NULL,
`original_language_id` tinyint(3) unsigned DEFAULT NULL,
`rental_duration` tinyint(3) unsigned NOT NULL DEFAULT '3',
`rental_rate` decimal(4,2) NOT NULL DEFAULT '4.99',
`length` smallint(5) unsigned DEFAULT NULL,
`replacement_cost` decimal(5,2) NOT NULL DEFAULT '19.99',
`rating` enum('G','PG','PG-13','R','NC-17') DEFAULT 'G',
`special_features` set('Trailers','Commentaries','Deleted Scenes','Behind the Scenes') DEFAULT NULL,
`last_update` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`film_id`),
KEY `idx_title` (`title`),
KEY `idx_fk_language_id` (`language_id`),
KEY `idx_fk_original_language_id` (`original_language_id`)
) ENGINE=InnoDB AUTO_INCREMENT=1001 DEFAULT CHARSET=utf8
&sqlparser.DDL{
Action: "create",
FromTables: nil,
ToTables: nil,
Table: sqlparser.TableName{
Name: sqlparser.TableIdent{v:"film"},
Qualifier: sqlparser.TableIdent{},
},
IfExists: false,
TableSpec: (*sqlparser.TableSpec)(nil),
OptLike: (*sqlparser.OptLike)(nil),
PartitionSpec: (*sqlparser.PartitionSpec)(nil),
VindexSpec: (*sqlparser.VindexSpec)(nil),
VindexCols: nil,
} nil
&database.TableIndexInfo{
TableName: "film",
Rows: {
{Table:"film", NonUnique:0, KeyName:"PRIMARY", SeqInIndex:1, ColumnName:"film_id", Collation:"A", Cardinality:1000, SubPart:0, Packed:0, Null:"", IndexType:"", Comment:"", IndexComment:"", Visible:""},
{Table:"film", NonUnique:1, KeyName:"idx_title", SeqInIndex:1, ColumnName:"title", Collation:"A", Cardinality:1000, SubPart:0, Packed:0, Null:"", IndexType:"", Comment:"", IndexComment:"", Visible:""},
{Table:"film", NonUnique:1, KeyName:"idx_fk_language_id", SeqInIndex:1, ColumnName:"language_id", Collation:"A", Cardinality:1, SubPart:0, Packed:0, Null:"", IndexType:"", Comment:"", IndexComment:"", Visible:""},
{Table:"film", NonUnique:1, KeyName:"idx_fk_original_language_id", SeqInIndex:1, ColumnName:"original_language_id", Collation:"A", Cardinality:1, SubPart:0, Packed:0, Null:"", IndexType:"", Comment:"", IndexComment:"", Visible:""},
},
}
[]database.TableIndexRow{
{Table:"film", NonUnique:1, KeyName:"idx_title", SeqInIndex:1, ColumnName:"title", Collation:"A", Cardinality:1000, SubPart:0, Packed:0, Null:"", IndexType:"", Comment:"", IndexComment:"", Visible:""},
}
[]database.ReferenceValue{
{ReferencedTableSchema:"sakila", ReferencedTableName:"language", TableSchema:"sakila", TableName:"film", ConstraintName:"fk_film_language"},
{ReferencedTableSchema:"sakila", ReferencedTableName:"language", TableSchema:"sakila", TableName:"film", ConstraintName:"fk_film_language_original"},
}
actor
actor_info
address
category
city
country
customer
customer_list
film
film_actor
film_category
film_list
film_text
inventory
language
nicer_but_slower_film_list
payment
rental
sales_by_film_category
sales_by_store
staff
staff_list
store
[]database.TraceRow{
{Query:"explain select 1", Trace:"{\n \"steps\": [\n {\n \"join_preparation\": {\n \"select#\": 1,\n \"steps\": [\n {\n \"expanded_query\": \"/* select#1 */ select 1 AS `1`\"\n }\n ]\n }\n },\n {\n \"join_optimization\": {\n \"select#\": 1,\n \"steps\": [\n ]\n }\n },\n {\n \"join_explain\": {\n \"select#\": 1,\n \"steps\": [\n ]\n }\n }\n ]\n}", MissingBytesBeyondMaxMemSize:0, InsufficientPrivileges:0},
}
......@@ -19,12 +19,9 @@ package database
import (
"errors"
"fmt"
"io"
"github.com/XiaoMi/soar/common"
"regexp"
"strings"
"time"
"github.com/XiaoMi/soar/common"
"vitess.io/vitess/go/vt/sqlparser"
)
......@@ -43,10 +40,11 @@ type TraceRow struct {
}
// Trace 执行SQL,并对其Trace
func (db *Connector) Trace(sql string, params ...interface{}) (*QueryResult, error) {
func (db *Connector) Trace(sql string, params ...interface{}) ([]TraceRow, error) {
common.Log.Debug("Trace SQL: %s", sql)
var rows []TraceRow
if common.Config.TestDSN.Version < 50600 {
return nil, errors.New("version < 5.6, not support trace")
return rows, errors.New("version < 5.6, not support trace")
}
// 过滤不需要 Trace 的 SQL
......@@ -55,98 +53,71 @@ func (db *Connector) Trace(sql string, params ...interface{}) (*QueryResult, err
sql = "explain " + sql
case sqlparser.EXPLAIN:
default:
return nil, errors.New("no need trace")
return rows, errors.New("no need trace")
}
// 测试环境如果检查是关闭的,则SQL不会被执行
if common.Config.TestDSN.Disable {
return nil, errors.New("Dsn Disable")
return rows, errors.New("dsn is disable")
}
// 数据库安全性检查:如果 Connector 的 IP 端口与 TEST 环境不一致,则启用SQL白名单
// 不在白名单中的SQL不允许执行
// 执行环境与test环境不相同
if db.Addr != common.Config.TestDSN.Addr && db.dangerousQuery(sql) {
return nil, fmt.Errorf("query Execution Deny: Execute SQL with DSN(%s/%s) '%s'",
return rows, fmt.Errorf("query Execution Deny: Execute SQL with DSN(%s/%s) '%s'",
db.Addr, db.Database, fmt.Sprintf(sql, params...))
}
common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, sql)
conn := db.NewConnection()
// 设置SQL连接超时时间
conn.SetTimeout(time.Duration(common.Config.ConnTimeOut) * time.Second)
defer conn.Close()
err := conn.Connect()
conn, err := db.NewConnection()
if err != nil {
return nil, err
return rows, err
}
defer conn.Close()
// 添加SQL执行超时限制
ch := make(chan QueryResult, 1)
go func() {
// 开启Trace
common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=on'")
_, _, err = conn.Query("SET SESSION OPTIMIZER_TRACE='enabled=on'")
trx, err := conn.Begin()
if err != nil {
return rows, err
}
defer trx.Rollback()
_, err = trx.Query("SET SESSION OPTIMIZER_TRACE='enabled=on'")
common.LogIfError(err, "")
// 执行SQL,抛弃返回结果
result, err := conn.Start(sql, params...)
tmpRes, err := trx.Query(sql, params...)
if err != nil {
ch <- QueryResult{
Error: err,
}
return
}
row := result.MakeRow()
for {
err = result.ScanRow(row)
if err == io.EOF {
break
return rows, err
}
for tmpRes.Next() {
continue
}
// 返回Trace结果
res := QueryResult{}
res.Rows, res.Result, res.Error = conn.Query("SELECT * FROM information_schema.OPTIMIZER_TRACE")
// 关闭Trace
common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=off'")
_, _, err = conn.Query("SET SESSION OPTIMIZER_TRACE='enabled=off'")
res, err := trx.Query("SELECT * FROM information_schema.OPTIMIZER_TRACE")
for res.Next() {
var traceRow TraceRow
err = res.Scan(&traceRow.Query, &traceRow.Trace, &traceRow.MissingBytesBeyondMaxMemSize, &traceRow.InsufficientPrivileges)
if err != nil {
fmt.Println(err.Error())
common.LogIfError(err, "")
}
ch <- res
}()
select {
case res := <-ch:
return &res, res.Error
case <-time.After(time.Duration(common.Config.QueryTimeOut) * time.Second):
return nil, errors.New("query execution timeout")
rows = append(rows, traceRow)
}
}
// getTrace 获取trace信息
func getTrace(res *QueryResult) Trace {
var rows []TraceRow
for _, row := range res.Rows {
rows = append(rows, TraceRow{
Query: row.Str(0),
Trace: row.Str(1),
MissingBytesBeyondMaxMemSize: row.Int(2),
InsufficientPrivileges: row.Int(3),
})
}
return Trace{Rows: rows}
// 关闭Trace
common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=off'")
_, err = trx.Query("SET SESSION OPTIMIZER_TRACE='enabled=off'")
common.LogIfError(err, "")
return rows, err
}
// FormatTrace 格式化输出Trace信息
func FormatTrace(res *QueryResult) string {
func FormatTrace(rows []TraceRow) string {
explainReg := regexp.MustCompile(`(?i)^explain\s+`)
trace := getTrace(res)
str := []string{""}
for _, row := range trace.Rows {
for _, row := range rows {
str = append(str, "```sql")
sql := explainReg.ReplaceAllString(row.Query, "")
str = append(str, sql)
......
......@@ -17,7 +17,6 @@
package database
import (
"flag"
"testing"
"github.com/XiaoMi/soar/common"
......@@ -25,34 +24,30 @@ import (
"github.com/kr/pretty"
)
var update = flag.Bool("update", false, "update .golden files")
func TestTrace(t *testing.T) {
common.Config.QueryTimeOut = 1
res, err := connTest.Trace("select 1")
if err == nil {
common.GoldenDiff(func() {
if err != nil {
t.Error(err)
}
err = common.GoldenDiff(func() {
pretty.Println(res)
}, t.Name(), update)
} else {
if err != nil {
t.Error(err)
}
}
func TestFormatTrace(t *testing.T) {
res, err := connTest.Trace("select 1")
if err == nil {
pretty.Println(FormatTrace(res))
} else {
if err != nil {
t.Error(err)
}
}
func TestGetTrace(t *testing.T) {
res, err := connTest.Trace("select 1")
if err == nil {
pretty.Println(getTrace(res))
} else {
err = common.GoldenDiff(func() {
pretty.Println(FormatTrace(res))
}, t.Name(), update)
if err != nil {
t.Error(err)
}
}
......@@ -161,8 +161,9 @@ func (ve *VirtualEnv) CleanupTestDatabase() {
// TODO: 1 hour should be config-able
minHour := 1
for _, row := range dbs.Rows {
testDatabase := row.Str(0)
for dbs.Rows.Next() {
var testDatabase string
dbs.Rows.Scan(&testDatabase)
// test temporary database format `optimizer_YYMMDDHHmmss_randomString(16)`
if len(testDatabase) != 39 {
common.Log.Debug("CleanupTestDatabase by pass %s", testDatabase)
......@@ -218,7 +219,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
meta := make(map[string]*common.DB)
for _, sql := range SQLs {
common.Log.Debug("BuildVirtualEnv Database&Table Mapping, SQL: %s", sql)
common.Log.Debug("BuildVirtualEnv Database&TableName Mapping, SQL: %s", sql)
stmt, err = sqlparser.Parse(sql)
if err != nil {
......@@ -320,7 +321,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
}
// 如果是视图,解析语句
if len(tbStatus.Rows) > 0 && tbStatus.Rows[0].Comment == "VIEW" {
if len(tbStatus.Rows) > 0 && string(tbStatus.Rows[0].Comment) == "VIEW" {
tmpEnv.Database = db
var viewDDL string
viewDDL, err = tmpEnv.ShowCreateTable(tb.TableName)
......@@ -418,7 +419,7 @@ func (ve VirtualEnv) createTable(rEnv database.Connector, dbName, tbName string)
return nil
}
common.Log.Debug("createTable, Database: %s, Table: %s", dbName, tbName)
common.Log.Debug("createTable, Database: %s, TableName: %s", dbName, tbName)
// TODO:查看是否有外键关联(done),对外键的支持 (未解决循环依赖的问题)
......@@ -506,9 +507,9 @@ func (ve *VirtualEnv) GenTableColumns(meta common.Meta) common.TableColumns {
DB: dbName,
Table: tb.TableName,
DataType: colInfo.Type,
Character: colInfo.Collation,
Character: string(colInfo.Collation),
Key: colInfo.Key,
Default: colInfo.Default,
Default: string(colInfo.Default),
Extra: colInfo.Extra,
Comment: colInfo.Comment,
Privileges: colInfo.Privileges,
......@@ -525,9 +526,9 @@ func (ve *VirtualEnv) GenTableColumns(meta common.Meta) common.TableColumns {
col.DB = dbName
col.Table = tb.TableName
col.DataType = colInfo.Type
col.Character = colInfo.Collation
col.Character = string(colInfo.Collation)
col.Key = colInfo.Key
col.Default = colInfo.Default
col.Default = string(colInfo.Default)
col.Extra = colInfo.Extra
col.Comment = colInfo.Comment
col.Privileges = colInfo.Privileges
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册