diff --git a/database/sampling.go b/database/sampling.go index 3e280db78a44eeb7e484d35338fc54ca8b6ba4d0..fab82402c78742e241f2d011198a74f0cdd605b8 100644 --- a/database/sampling.go +++ b/database/sampling.go @@ -21,6 +21,7 @@ import ( "fmt" "github.com/XiaoMi/soar/common" + "strings" ) /*-------------------- @@ -85,134 +86,57 @@ func (db *Connector) SamplingData(remote *Connector, tables ...string) error { return nil } -// 开始从环境中泵取数据 +// startSampling sampling data from OnlineDSN to TestDSN // 因为涉及到的数据量问题,所以泵取与插入时同时进行的 -// TODO 加 ref link +// TODO: 加 ref link func startSampling(conn, localConn *sql.DB, database, table string, factor float64, wants, maxValCount int) error { - return nil - // TODO: - /* - // 从线上数据库获取所需dump的表中所有列的数据类型,备用 - // 由于测试库中的库表为刚建立的,所以在information_schema中很可能没有这个表的信息 - var dataTypes []string - q := fmt.Sprintf("select DATA_TYPE from information_schema.COLUMNS where TABLE_SCHEMA='%s' and TABLE_NAME = '%s'", - database, table) - common.Log.Debug("Sampling data execute: %s", q) - rs, err := localConn.Query(q) - if err != nil { - common.Log.Debug("Sampling data got data type Err: %v", err) - } else { - for rs.Next() { - var dataType string - err = rs.Scan(&dataType) - if err != nil { - return err - } - dataTypes = append(dataTypes, dataType) - } - } + // generate where condition + where := fmt.Sprintf("WHERE RAND() <= %f", factor) + if factor >= 1 { + where = "" + } - // 生成where条件 - where := fmt.Sprintf("where RAND()<=%f", factor) - if factor >= 1 { - where = "" - } + res, err := conn.Query(fmt.Sprintf("SELECT * FROM `%s`.`%s` %s LIMIT %d;", database, table, where, wants)) + if err != nil { + return err + } - sql := fmt.Sprintf("select * from `%s` %s limit %d;", table, where, wants) - res, err := conn.Query(sql) - if err != nil { - return err + // column info + columns, err := res.Columns() + if err != nil { + return err + } + row := make(map[string][]byte, len(columns)) + tableFields := make([]interface{}, 0) + for _, col := range columns { + if _, ok := row[col]; ok { + tableFields = append(tableFields, row[col]) } + } - // GetRow method allocates a new chunk of memory for every received row. - row := res.MakeRow() - rowCount := 0 - valCount := 0 - - // 获取所有的列名 - columns := make([]string, len(res.Fields())) - for i, filed := range res.Fields() { - columns[i] = filed.Name + // sampling data + var valuesStr string + var values []string + columnsStr := "`" + strings.Join(columns, "`,`") + "`" + for res.Next() { + res.Scan(tableFields...) + for _, val := range row { + values = append(values, fmt.Sprintf(`unhex("%s")`, fmt.Sprintf("%x", val))) } - colDef := strings.Join(columns, ",") - - // 开始填充数据 - var valList []string - for { - err := res.ScanRow(row) - if err == io.EOF { - // 扫描结束 - if len(valList) > 0 { - // 如果缓存中还存在未插入的数据,则把缓存中的数据刷新到DB中 - doSampling(localConn, database, table, colDef, strings.Join(valList, ",")) - } - break - } - - if err != nil { - return err - } - - values := make([]string, len(columns)) - for i := range row { - // TODO 不支持坐标类型的导出 - switch data := row[i].(type) { - case nil: - // str = "" - case []byte: - // 先尝试转成数字,如果报错则转换成string - if v, err := row.Int64Err(i); err != nil { - values[i] = string(data) - } else { - values[i] = strconv.FormatInt(v, 10) - } - case time.Time: - values[i] = mysql.TimeString(data) - case time.Duration: - values[i] = mysql.DurationString(data) - default: - values[i] = fmt.Sprint(data) - } - - // 非text/varchar类的数据类型,如果dump出的数据为空,则说明该值为null值 - // 应转换其 value 为 null,如果用空('')进行替代,会导致出现语法错误。 - if len(dataTypes) == len(res.Fields()) && values[i] == "" && - (!strings.Contains(dataTypes[i], "char") || - !strings.Contains(dataTypes[i], "text")) { - values[i] = "null" - } else { - values[i] = "'" + values[i] + "'" - } - } - - valuesStr := fmt.Sprintf(`(%s)`, strings.Join(values, `,`)) - valList = append(valList, valuesStr) - - rowCount++ - valCount++ - - if rowCount%maxValCount == 0 { - doSampling(localConn, database, table, colDef, strings.Join(valList, ",")) - valCount = 0 - valList = make([]string, 0) - - } - } - - common.Log.Debug("%d rows sampling out", rowCount) - return nil - */ + valuesStr = fmt.Sprintf(`(%s)`, strings.Join(values, `,`)) + doSampling(localConn, database, table, columnsStr, valuesStr) + } + res.Close() + return nil } // 将泵取的数据转换成Insert语句并在数据库中执行 func doSampling(conn *sql.DB, dbName, table, colDef, values string) { - query := fmt.Sprintf("Insert into `%s`.`%s`(%s) values%s;", dbName, table, + query := fmt.Sprintf("INSERT INTO `%s`.`%s` (%s) VALUES %s;", dbName, table, colDef, values) - _, err := conn.Query(query) - + _, err := conn.Exec(query) if err != nil { common.Log.Error("doSampling Error from %s.%s: %v", dbName, table, err) } - }