提交 e05c2bc2 编写于 作者: fantasy_cs's avatar fantasy_cs

feat: The function exporting SQL supports specifies a specific database type,

eg, mysql\oracle\sqlserver.
上级 b3b9ec8c
......@@ -6,6 +6,7 @@ Parameters:
-c --config The current config file for data format, and it can override the config in the default file.
-o --output The file name of the data generated. You can specify the output format by the extension name.
For example json, xml, sql, csv and xlsx. The text data in the original format is output by default.
Note: For SQL files, you can use --server to specify a specific database type in Mysql\Oracle\SqlServer.
-n --lines The number of lines of data to be generated. The default is 10.
-F --field This parameter can be used to specify the fields, separated by commas. The default is all fields.
......
......@@ -5,6 +5,7 @@ ZenData是一款通用的数据生成工具,您可以使用yaml文件来定义
-d --default 默认的数据格式配置文件。
-c --config 当前场景的数据格式配置文件,可以覆盖默认文件里面的设置。
-o --output 生成的数据的文件名。可通过扩展名指定输出json|xml|sql|csv|xlsx格式的数据。默认输出原始格式的文本数据。
注意:对于 SQL 文件,您可以使用 --server 指定Mysql\Oracle\SqlServer中具体的数据库类型。
-n --lines 要生成的记录条数,默认为10条。
-F --field 可通过该参数指定要输出的字段列表,用逗号分隔。 默认是所有的字段。
......
......@@ -16,15 +16,14 @@ import (
func Print(rows [][]string, format string, table string, colIsNumArr []bool,
fields []string) (lines []interface{}) {
var sqlHeader string
if format == constant.FormatText {
printTextHeader(fields)
} else if format == constant.FormatSql {
sqlHeader := getInsertSqlHeader(fields, table)
sqlHeader = getInsertSqlHeader(fields, table)
if vari.DBDsn != "" {
lines = append(lines, sqlHeader)
} else {
sqlHeader = "INSERT INTO " + sqlHeader
logUtils.PrintLine(sqlHeader)
}
} else if format == constant.FormatJson {
printJsonHeader()
......@@ -58,12 +57,20 @@ func Print(rows [][]string, format string, table string, colIsNumArr []bool,
row = append(row, col)
rowMap[field.Field] = col
colVal := stringUtils.ConvertForSql(col)
colVal := col
if !colIsNumArr[j] {
colVal = "'" + colVal + "'"
switch vari.Server {
case constant.DBTypeSqlServer:
colVal = "'" + stringUtils.EscapeValueOfSqlServer(colVal) + "'"
case constant.DBTypeOracle:
colVal = "'" + stringUtils.EscapeValueOfOracle(colVal) + "'"
// case constant.DBTypeMysql:
default:
colVal = "'" + stringUtils.EscapeValueOfMysql(colVal) + "'"
}
}
valuesForSql = append(valuesForSql, colVal)
}
} // for cols
if format == constant.FormatText && vari.Def.Type == constant.ConfigTypeArticle { // article need to write to more than one files
lines = append(lines, lineForText)
......@@ -72,15 +79,14 @@ func Print(rows [][]string, format string, table string, colIsNumArr []bool,
logUtils.PrintLine(lineForText)
} else if format == constant.FormatSql {
if vari.DBDsn != "" { // add to return array for exec sql
sql := strings.Join(valuesForSql, ", ")
lines = append(lines, sql)
} else {
sql := genSqlLine(strings.Join(valuesForSql, ", "), i, len(rows))
sql := genSqlLine(sqlHeader, valuesForSql, vari.Server)
logUtils.PrintLine(sql)
}
} else if format == constant.FormatJson {
logUtils.PrintLine(genJsonLine(i, row, len(rows), fields))
} else if format == constant.FormatXml {
......@@ -108,15 +114,17 @@ func printTextHeader(fields []string) {
logUtils.PrintLine(headerLine)
}
// return "Table> (<column1, column2,...)"
func getInsertSqlHeader(fields []string, table string) string {
fieldNames := make([]string, 0)
for _, f := range fields {
if vari.Server == "mysql" {
f = "`" + f + "`"
} else if vari.Server == "sqlserver" {
f = "[" + f + "]"
} else if vari.Server == "oracle" {
if vari.Server == constant.DBTypeSqlServer {
f = "[" + stringUtils.EscapeColumnOfSqlServer(f) + "]"
} else if vari.Server == constant.DBTypeOracle {
f = `"` + f + `"`
} else {
f = "`" + stringUtils.EscapeColumnOfMysql(f) + "`"
//vari.Server == constant.DBTypeMysql {
}
fieldNames = append(fieldNames, f)
......@@ -124,12 +132,13 @@ func getInsertSqlHeader(fields []string, table string) string {
var ret string
switch vari.Server {
case "mysql":
ret = fmt.Sprintf("`%s` (%s)", table, strings.Join(fieldNames, ", "))
case "sqlserver":
case constant.DBTypeSqlServer:
ret = fmt.Sprintf("[%s] (%s)", table, strings.Join(fieldNames, ", "))
case "oracle":
case constant.DBTypeOracle:
ret = fmt.Sprintf(`"%s" (%s)`, table, strings.Join(fieldNames, ", "))
// case constant.DBTypeMysql:
default:
ret = fmt.Sprintf("`%s` (%s)", table, strings.Join(fieldNames, ", "))
}
return ret
......@@ -155,22 +164,19 @@ func RowToJson(cols []string, fieldsToExport []string) string {
return respJson
}
func genSqlLine(valuesForSql string, i int, length int) string {
temp := ""
if i == 0 {
temp = fmt.Sprintf(" VALUES (%s)", valuesForSql)
} else {
temp = fmt.Sprintf(" (%s)", valuesForSql)
}
if i < length-1 {
temp = temp + ", "
} else {
temp = temp + "; "
// @return ""
func genSqlLine(sqlheader string, values []string, dbtype string) string {
var tmp string
switch dbtype {
case constant.DBTypeSqlServer:
tmp = "INSERT INTO " + sqlheader + " VALUES (" + strings.Join(values, ",") + "); GO"
default:
// constant.DBTypeMysql
// constant.DBTypeOracle:
tmp = "INSERT INTO " + sqlheader + " VALUES (" + strings.Join(values, ",") + ");"
}
return temp
return tmp
}
func genJsonLine(i int, row []string, length int, fields []string) string {
......
......@@ -95,4 +95,9 @@ var (
TablePrefix = "zd_"
PageSize = 15
// database type [added by leo 2022/05/10]
DBTypeMysql = "mysql"
DBTypeSqlServer = "sqlserver"
DBTypeOracle = "oracle"
)
......@@ -8,17 +8,18 @@ import (
"encoding/binary"
"encoding/hex"
"fmt"
"net/url"
"regexp"
"strconv"
"strings"
"unicode"
"github.com/Chain-Zhang/pinyin"
"github.com/easysoft/zendata/src/model"
constant "github.com/easysoft/zendata/src/utils/const"
"github.com/mattn/go-runewidth"
uuid "github.com/satori/go.uuid"
"gopkg.in/yaml.v2"
"net/url"
"regexp"
"strconv"
"strings"
"unicode"
)
func TrimAll(str string) string {
......@@ -205,6 +206,50 @@ func ConvertForSql(str string) (ret string) {
return
}
func Escape(in string, shouldEscape []rune) (out string) {
out = ""
escapeChar := shouldEscape[0]
for _, v := range in {
for _, se := range shouldEscape {
if v == se {
out += string(escapeChar)
break
}
}
out += string(v)
}
return
}
func EscapeValueOfMysql(in string) string {
return Escape(in, []rune{'\\', '\'', '"'})
}
func EscapeValueOfSqlServer(in string) string {
return Escape(in, []rune{'\''})
}
func EscapeValueOfOracle(in string) string {
return Escape(in, []rune{'\''})
}
func EscapeColumnOfMysql(in string) string {
return Escape(in, []rune{'`'})
}
func EscapeColumnOfSqlServer(in string) string {
return Escape(in, []rune{']'})
}
// oracle limit
//func EscapeColumnOfOracle(in string) string
func GetPinyin(word string) string {
p, _ := pinyin.New(word).Split("").Mode(pinyin.WithoutTone).Convert()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册