// db.go — last committed by slene.

package orm

import (
	"database/sql"
	"errors"
	"fmt"
	"reflect"
	"strings"
	"time"
)

// Layouts in Go reference-time notation, used with timeFormat/timeParse to
// render and parse the text of date and datetime column values.
const (
	format_Date     = "2006-01-02"
	format_DateTime = "2006-01-02 15:04:05"
)

var (
	// ErrMissPK is returned by Read, Update and Delete when the model's
	// primary key is not set (not positive for integer keys, empty for
	// string keys).
	ErrMissPK = errors.New("missed pk value")
)

// operators is the set of lookup suffixes (the part after `__` in a filter
// expression) that the query builder accepts.
//
// Recognised but not yet supported: regex, iregex, range, year, month, day,
// week_day and search.
var (
	operators = map[string]bool{
		// equality and null checks
		"exact":  true,
		"iexact": true,
		"isnull": true,
		"in":     true,

		// ordering comparisons
		"gt":  true,
		"gte": true,
		"lt":  true,
		"lte": true,

		// pattern matching
		"contains":    true,
		"icontains":   true,
		"startswith":  true,
		"istartswith": true,
		"endswith":    true,
		"iendswith":   true,
	}
)

// dbBase holds the SQL building and execution logic for model reads and
// writes.
//
// Hooks such as TableQuote, ReplaceMarks, SupportUpdateJoin,
// GenerateOperatorSql and OperatorSql are always called through ins, never
// on dbBase directly. Presumably ins is the concrete driver that wraps this
// dbBase, so each database can override those hooks — TODO confirm.
type dbBase struct {
	ins dbBaser
}

S
slene 已提交
52 53 54 55 56 57 58 59 60 61 62 63 64
func (d *dbBase) existPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {

	fi := mi.fields.pk

	v := ind.Field(fi.fieldIndex)
	if fi.fieldType&IsIntegerField > 0 {
		vu := v.Int()
		exist = vu > 0
		value = vu
	} else {
		vu := v.String()
		exist = vu != ""
		value = vu
S
slene 已提交
65
	}
S
slene 已提交
66 67 68 69

	column = fi.column

	return
S
slene 已提交
70 71 72
}

func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) {
S
slene 已提交
73
	_, pkValue, _ := d.existPk(mi, ind)
S
slene 已提交
74 75 76 77 78 79
	for _, column := range mi.fields.orders {
		fi := mi.fields.columns[column]
		if fi.dbcol == false || fi.auto && skipAuto {
			continue
		}
		var value interface{}
S
slene 已提交
80 81
		if fi.pk {
			value = pkValue
S
slene 已提交
82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
		} else {
			field := ind.Field(fi.fieldIndex)
			if fi.isFielder {
				f := field.Addr().Interface().(Fielder)
				value = f.RawValue()
			} else {
				switch fi.fieldType {
				case TypeBooleanField:
					value = field.Bool()
				case TypeCharField, TypeTextField:
					value = field.String()
				case TypeFloatField, TypeDecimalField:
					value = field.Float()
				case TypeDateField, TypeDateTimeField:
					value = field.Interface()
				default:
					switch {
					case fi.fieldType&IsPostiveIntegerField > 0:
						value = field.Uint()
					case fi.fieldType&IsIntegerField > 0:
						value = field.Int()
					case fi.fieldType&IsRelField > 0:
						if field.IsNil() {
							value = nil
						} else {
S
slene 已提交
107 108
							if _, vu, ok := d.existPk(fi.relModelInfo, reflect.Indirect(field)); ok {
								value = vu
S
slene 已提交
109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
							} else {
								value = nil
							}
						}
						if fi.null == false && value == nil {
							return nil, nil, errors.New(fmt.Sprintf("field `%s` cannot be NULL", fi.fullName))
						}
					}
				}
			}
			switch fi.fieldType {
			case TypeDateField, TypeDateTimeField:
				if fi.auto_now || fi.auto_now_add && insert {
					tnow := time.Now()
					if fi.fieldType == TypeDateField {
						value = timeFormat(tnow, format_Date)
					} else {
						value = timeFormat(tnow, format_DateTime)
					}
					if fi.isFielder {
						f := field.Addr().Interface().(Fielder)
						f.SetRaw(tnow)
					} else {
						field.Set(reflect.ValueOf(tnow))
					}
				}
			}
		}
		columns = append(columns, column)
		values = append(values, value)
	}
	return
}

S
slene 已提交
143
func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
144 145
	Q := d.ins.TableQuote()

S
slene 已提交
146 147 148 149 150 151 152 153 154
	dbcols := make([]string, 0, len(mi.fields.dbcols))
	marks := make([]string, 0, len(mi.fields.dbcols))
	for _, fi := range mi.fields.fieldsDB {
		if fi.auto == false {
			dbcols = append(dbcols, fi.column)
			marks = append(marks, "?")
		}
	}
	qmarks := strings.Join(marks, ", ")
155 156 157 158 159 160
	sep := fmt.Sprintf("%s, %s", Q, Q)
	columns := strings.Join(dbcols, sep)

	query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)

	d.ins.ReplaceMarks(&query)
S
slene 已提交
161

S
slene 已提交
162 163
	stmt, err := q.Prepare(query)
	return stmt, query, err
S
slene 已提交
164 165
}

S
slene 已提交
166
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
S
slene 已提交
167 168 169 170 171 172 173 174 175 176 177 178
	_, values, err := d.collectValues(mi, ind, true, true)
	if err != nil {
		return 0, err
	}

	if res, err := stmt.Exec(values...); err == nil {
		return res.LastInsertId()
	} else {
		return 0, err
	}
}

S
slene 已提交
179
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
S
slene 已提交
180
	pkColumn, pkValue, ok := d.existPk(mi, ind)
S
slene 已提交
181 182 183 184
	if ok == false {
		return ErrMissPK
	}

185 186 187 188
	Q := d.ins.TableQuote()

	sep := fmt.Sprintf("%s, %s", Q, Q)
	sels := strings.Join(mi.fields.dbcols, sep)
S
slene 已提交
189 190
	colsNum := len(mi.fields.dbcols)

191
	query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, pkColumn, Q)
S
slene 已提交
192 193 194 195 196 197 198

	refs := make([]interface{}, colsNum)
	for i, _ := range refs {
		var ref interface{}
		refs[i] = &ref
	}

199 200
	d.ins.ReplaceMarks(&query)

S
slene 已提交
201
	row := q.QueryRow(query, pkValue)
S
slene 已提交
202
	if err := row.Scan(refs...); err != nil {
S
slene 已提交
203 204 205
		if err == sql.ErrNoRows {
			return ErrNoRows
		}
S
slene 已提交
206 207 208 209 210 211 212 213 214 215 216 217 218
		return err
	} else {
		elm := reflect.New(mi.addrField.Elem().Type())
		mind := reflect.Indirect(elm)

		d.setColsValues(mi, &mind, mi.fields.dbcols, refs)

		ind.Set(mind)
	}

	return nil
}

S
slene 已提交
219 220 221 222 223 224
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
	names, values, err := d.collectValues(mi, ind, true, true)
	if err != nil {
		return 0, err
	}

225 226
	Q := d.ins.TableQuote()

S
slene 已提交
227 228 229 230
	marks := make([]string, len(names))
	for i, _ := range marks {
		marks[i] = "?"
	}
231 232

	sep := fmt.Sprintf("%s, %s", Q, Q)
S
slene 已提交
233
	qmarks := strings.Join(marks, ", ")
234 235 236
	columns := strings.Join(names, sep)

	query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
S
slene 已提交
237

238
	d.ins.ReplaceMarks(&query)
S
slene 已提交
239 240 241 242 243 244 245 246 247

	if res, err := q.Exec(query, values...); err == nil {
		return res.LastInsertId()
	} else {
		return 0, err
	}
}

func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
S
slene 已提交
248
	pkName, pkValue, ok := d.existPk(mi, ind)
S
slene 已提交
249 250 251 252 253 254 255 256
	if ok == false {
		return 0, ErrMissPK
	}
	setNames, setValues, err := d.collectValues(mi, ind, true, false)
	if err != nil {
		return 0, err
	}

257 258 259
	setValues = append(setValues, pkValue)

	Q := d.ins.TableQuote()
S
slene 已提交
260

261 262
	sep := fmt.Sprintf("%s = ?, %s", Q, Q)
	setColumns := strings.Join(setNames, sep)
S
slene 已提交
263

264 265 266
	query := fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s = ?", Q, mi.table, Q, Q, setColumns, Q, Q, pkName, Q)

	d.ins.ReplaceMarks(&query)
S
slene 已提交
267 268 269 270 271 272 273 274 275 276

	if res, err := q.Exec(query, setValues...); err == nil {
		return res.RowsAffected()
	} else {
		return 0, err
	}
	return 0, nil
}

func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
S
slene 已提交
277
	pkName, pkValue, ok := d.existPk(mi, ind)
S
slene 已提交
278 279 280 281
	if ok == false {
		return 0, ErrMissPK
	}

282 283 284 285 286
	Q := d.ins.TableQuote()

	query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, pkName, Q)

	d.ins.ReplaceMarks(&query)
S
slene 已提交
287

S
slene 已提交
288
	if res, err := q.Exec(query, pkValue); err == nil {
S
slene 已提交
289 290 291 292 293 294 295

		num, err := res.RowsAffected()
		if err != nil {
			return 0, err
		}

		if num > 0 {
S
slene 已提交
296 297
			if mi.fields.pk.auto {
				ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
S
slene 已提交
298 299
			}

S
slene 已提交
300 301 302
			err := d.deleteRels(q, mi, []interface{}{pkValue})
			if err != nil {
				return num, err
S
slene 已提交
303 304 305 306 307 308 309 310 311 312 313 314 315 316
			}
		}

		return num, err
	} else {
		return 0, err
	}
	return 0, nil
}

func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params) (int64, error) {
	columns := make([]string, 0, len(params))
	values := make([]interface{}, 0, len(params))
	for col, val := range params {
S
slene 已提交
317 318 319 320 321
		if fi, ok := mi.fields.GetByAny(col); ok == false || fi.dbcol == false {
			panic(fmt.Sprintf("wrong field/column name `%s`", col))
		} else {
			columns = append(columns, fi.column)
			values = append(values, val)
S
slene 已提交
322 323 324 325 326 327 328 329 330 331 332 333 334 335
		}
	}

	if len(columns) == 0 {
		panic("update params cannot empty")
	}

	tables := newDbTables(mi, d.ins)
	if qs != nil {
		tables.parseRelated(qs.related, qs.relDepth)
	}

	where, args := tables.getCondSql(cond, false)

336 337
	values = append(values, args...)

S
slene 已提交
338 339
	join := tables.getJoinSql()

340
	var query string
S
slene 已提交
341

342 343 344 345 346 347 348 349 350 351 352 353
	Q := d.ins.TableQuote()

	if d.ins.SupportUpdateJoin() {
		cols := strings.Join(columns, fmt.Sprintf("%s = ?, T0.%s", Q, Q))
		query = fmt.Sprintf("UPDATE %s%s%s T0 %sSET T0.%s%s%s = ? %s", Q, mi.table, Q, join, Q, cols, Q, where)
	} else {
		cols := strings.Join(columns, fmt.Sprintf("%s = ?, %s", Q, Q))
		supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s", Q, mi.fields.pk.column, Q, Q, mi.table, Q, join, where)
		query = fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s IN ( %s )", Q, mi.table, Q, Q, cols, Q, Q, mi.fields.pk.column, Q, supQuery)
	}

	d.ins.ReplaceMarks(&query)
S
slene 已提交
354 355 356 357 358 359 360 361 362 363 364 365 366 367

	if res, err := q.Exec(query, values...); err == nil {
		return res.RowsAffected()
	} else {
		return 0, err
	}
	return 0, nil
}

func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) error {
	for _, fi := range mi.fields.fieldsReverse {
		fi = fi.reverseFieldInfo
		switch fi.onDelete {
		case od_CASCADE:
S
slene 已提交
368
			cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
S
slene 已提交
369 370 371 372 373
			_, err := d.DeleteBatch(q, nil, fi.mi, cond)
			if err != nil {
				return err
			}
		case od_SET_DEFAULT, od_SET_NULL:
S
slene 已提交
374
			cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
S
slene 已提交
375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398
			params := Params{fi.column: nil}
			if fi.onDelete == od_SET_DEFAULT {
				params[fi.column] = fi.initial.String()
			}
			_, err := d.UpdateBatch(q, nil, fi.mi, cond, params)
			if err != nil {
				return err
			}
		case od_DO_NOTHING:
		}
	}
	return nil
}

func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition) (int64, error) {
	tables := newDbTables(mi, d.ins)
	if qs != nil {
		tables.parseRelated(qs.related, qs.relDepth)
	}

	if cond == nil || cond.IsEmpty() {
		panic("delete operation cannot execute without condition")
	}

399 400
	Q := d.ins.TableQuote()

S
slene 已提交
401 402 403
	where, args := tables.getCondSql(cond, false)
	join := tables.getJoinSql()

404 405 406 407
	cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q)
	query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where)

	d.ins.ReplaceMarks(&query)
S
slene 已提交
408 409 410 411 412 413 414 415

	var rs *sql.Rows
	if r, err := q.Query(query, args...); err != nil {
		return 0, err
	} else {
		rs = r
	}

S
slene 已提交
416
	var ref interface{}
S
slene 已提交
417 418 419 420

	args = make([]interface{}, 0)
	cnt := 0
	for rs.Next() {
S
slene 已提交
421
		if err := rs.Scan(&ref); err != nil {
S
slene 已提交
422 423
			return 0, err
		}
S
slene 已提交
424
		args = append(args, reflect.ValueOf(ref).Interface())
S
slene 已提交
425 426 427 428 429 430 431
		cnt++
	}

	if cnt == 0 {
		return 0, nil
	}

432 433 434 435
	sql, args := d.ins.GenerateOperatorSql(mi, "in", args)
	query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql)

	d.ins.ReplaceMarks(&query)
S
slene 已提交
436 437 438 439 440 441 442

	if res, err := q.Exec(query, args...); err == nil {
		num, err := res.RowsAffected()
		if err != nil {
			return 0, err
		}

S
slene 已提交
443
		if num > 0 {
S
slene 已提交
444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467
			err := d.deleteRels(q, mi, args)
			if err != nil {
				return num, err
			}
		}

		return num, nil
	} else {
		return 0, err
	}

	return 0, nil
}

func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}) (int64, error) {

	val := reflect.ValueOf(container)
	ind := reflect.Indirect(val)

	errTyp := true

	one := true

	if val.Kind() == reflect.Ptr {
468
		fn := ""
S
slene 已提交
469 470 471
		if ind.Kind() == reflect.Slice {
			one = false
			if ind.Type().Elem().Kind() == reflect.Ptr {
472 473
				typ := ind.Type().Elem().Elem()
				fn = getFullName(typ)
S
slene 已提交
474
			}
475 476
		} else {
			fn = getFullName(ind.Type())
S
slene 已提交
477
		}
478
		errTyp = fn != mi.fullName
S
slene 已提交
479 480 481
	}

	if errTyp {
482
		panic(fmt.Sprintf("wrong object type `%s` for rows scan, need *[]*%s or *%s", ind.Type(), mi.fullName, mi.fullName))
S
slene 已提交
483 484 485 486 487 488 489 490 491
	}

	rlimit := qs.limit
	offset := qs.offset
	if one {
		rlimit = 0
		offset = 0
	}

492 493
	Q := d.ins.TableQuote()

S
slene 已提交
494 495 496 497 498
	tables := newDbTables(mi, d.ins)
	tables.parseRelated(qs.related, qs.relDepth)

	where, args := tables.getCondSql(cond, false)
	orderBy := tables.getOrderSql(qs.orders)
499
	limit := tables.getLimitSql(mi, offset, rlimit)
S
slene 已提交
500 501 502
	join := tables.getJoinSql()

	colsNum := len(mi.fields.dbcols)
503 504
	sep := fmt.Sprintf("%s, T0.%s", Q, Q)
	cols := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(mi.fields.dbcols, sep), Q)
S
slene 已提交
505 506 507
	for _, tbl := range tables.tables {
		if tbl.sel {
			colsNum += len(tbl.mi.fields.dbcols)
508 509
			sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q)
			cols += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q)
S
slene 已提交
510 511 512
		}
	}

513 514 515
	query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", cols, Q, mi.table, Q, join, where, orderBy, limit)

	d.ins.ReplaceMarks(&query)
S
slene 已提交
516 517 518 519 520 521 522 523 524 525

	var rs *sql.Rows
	if r, err := q.Query(query, args...); err != nil {
		return 0, err
	} else {
		rs = r
	}

	refs := make([]interface{}, colsNum)
	for i, _ := range refs {
S
slene 已提交
526
		var ref interface{}
S
slene 已提交
527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563
		refs[i] = &ref
	}

	slice := ind

	var cnt int64
	for rs.Next() {
		if one && cnt == 0 || one == false {
			if err := rs.Scan(refs...); err != nil {
				return 0, err
			}

			elm := reflect.New(mi.addrField.Elem().Type())
			mind := reflect.Indirect(elm)

			cacheV := make(map[string]*reflect.Value)
			cacheM := make(map[string]*modelInfo)
			trefs := refs

			d.setColsValues(mi, &mind, mi.fields.dbcols, refs[:len(mi.fields.dbcols)])
			trefs = refs[len(mi.fields.dbcols):]

			for _, tbl := range tables.tables {
				if tbl.sel {
					last := mind
					names := ""
					mmi := mi
					for _, name := range tbl.names {
						names += name
						if val, ok := cacheV[names]; ok {
							last = *val
							mmi = cacheM[names]
						} else {
							fi := mmi.fields.GetByName(name)
							lastm := mmi
							mmi := fi.relModelInfo
							field := reflect.Indirect(last.Field(fi.fieldIndex))
S
slene 已提交
564 565 566 567 568 569 570
							if field.IsValid() {
								d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)])
								for _, fi := range mmi.fields.fieldsReverse {
									if fi.reverseFieldInfo.mi == lastm {
										if fi.reverseFieldInfo != nil {
											field.Field(fi.fieldIndex).Set(last.Addr())
										}
S
slene 已提交
571 572
									}
								}
S
slene 已提交
573 574 575
								cacheV[names] = &field
								cacheM[names] = mmi
								last = field
S
slene 已提交
576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606
							}
							trefs = trefs[len(mmi.fields.dbcols):]
						}
					}
				}
			}

			if one {
				ind.Set(mind)
			} else {
				slice = reflect.Append(slice, mind.Addr())
			}
		}
		cnt++
	}

	if one == false {
		ind.Set(slice)
	}

	return cnt, nil
}

func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition) (cnt int64, err error) {
	tables := newDbTables(mi, d.ins)
	tables.parseRelated(qs.related, qs.relDepth)

	where, args := tables.getCondSql(cond, false)
	tables.getOrderSql(qs.orders)
	join := tables.getJoinSql()

607 608 609 610 611
	Q := d.ins.TableQuote()

	query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s", Q, mi.table, Q, join, where)

	d.ins.ReplaceMarks(&query)
S
slene 已提交
612 613 614 615 616 617 618

	row := q.QueryRow(query, args...)

	err = row.Scan(&cnt)
	return
}

619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657
func (d *dbBase) getOperatorParams(operator string, args []interface{}) (params []interface{}) {
	for _, arg := range args {
		val := reflect.ValueOf(arg)

		if arg == nil {
			params = append(params, arg)
			continue
		}

		kind := val.Kind()

		switch kind {
		case reflect.Slice, reflect.Array:
			var args []interface{}
			for i := 0; i < val.Len(); i++ {
				v := val.Index(i)

				var vu interface{}
				if v.CanInterface() {
					vu = v.Interface()
				}

				if vu == nil {
					continue
				}

				args = append(args, vu)
			}

			if len(args) > 0 {
				p := d.getOperatorParams(operator, args)
				params = append(params, p...)
			}

		case reflect.Ptr, reflect.Struct:
			ind := reflect.Indirect(val)

			if ind.Kind() == reflect.Struct {
				typ := ind.Type()
658
				name := getFullName(typ)
659
				var value interface{}
660
				if mmi, ok := modelCache.getByFN(name); ok {
661 662 663 664 665 666 667
					if _, vu, exist := d.existPk(mmi, ind); exist {
						value = vu
					}
				}
				arg = value

				if arg == nil {
668
					panic(fmt.Sprintf("`%s` operator need a valid args value, unknown table or value `%s`", operator, name))
669
				}
S
slene 已提交
670
			} else {
671
				arg = ind.Interface()
S
slene 已提交
672
			}
673 674 675 676 677

			params = append(params, arg)

		default:
			params = append(params, arg)
S
slene 已提交
678
		}
679

S
slene 已提交
680
	}
681 682 683 684

	return
}

685
func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []interface{}) (string, []interface{}) {
686 687 688
	sql := ""
	params := d.getOperatorParams(operator, args)

S
slene 已提交
689 690 691 692 693 694 695 696 697 698
	if operator == "in" {
		marks := make([]string, len(params))
		for i, _ := range marks {
			marks[i] = "?"
		}
		sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
	} else {
		if len(params) > 1 {
			panic(fmt.Sprintf("operator `%s` need 1 args not %d", operator, len(params)))
		}
699
		sql = d.ins.OperatorSql(operator)
S
slene 已提交
700 701
		arg := params[0]
		switch operator {
S
slene 已提交
702 703 704 705
		case "exact":
			if arg == nil {
				params[0] = "IS NULL"
			}
S
slene 已提交
706 707 708
		case "iexact", "contains", "icontains", "startswith", "endswith", "istartswith", "iendswith":
			param := strings.Replace(ToStr(arg), `%`, `\%`, -1)
			switch operator {
S
slene 已提交
709 710
			case "iexact":
			case "contains", "icontains":
S
slene 已提交
711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743
				param = fmt.Sprintf("%%%s%%", param)
			case "startswith", "istartswith":
				param = fmt.Sprintf("%s%%", param)
			case "endswith", "iendswith":
				param = fmt.Sprintf("%%%s", param)
			}
			params[0] = param
		case "isnull":
			if b, ok := arg.(bool); ok {
				if b {
					sql = "IS NULL"
				} else {
					sql = "IS NOT NULL"
				}
				params = nil
			} else {
				panic(fmt.Sprintf("operator `%s` need a bool value not `%T`", operator, arg))
			}
		}
	}
	return sql, params
}

func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}) {
	for i, column := range cols {
		val := reflect.Indirect(reflect.ValueOf(values[i])).Interface()

		fi := mi.fields.GetByColumn(column)

		field := ind.Field(fi.fieldIndex)

		value, err := d.getValue(fi, val)
		if err != nil {
744
			panic(fmt.Sprintf("Raw value: `%v` %s", val, err.Error()))
S
slene 已提交
745 746 747 748 749
		}

		_, err = d.setValue(fi, value, &field)

		if err != nil {
750
			panic(fmt.Sprintf("Raw value: `%v` %s", val, err.Error()))
S
slene 已提交
751 752 753 754 755 756 757 758 759 760
		}
	}
}

func (d *dbBase) getValue(fi *fieldInfo, val interface{}) (interface{}, error) {
	if val == nil {
		return nil, nil
	}

	var value interface{}
761
	var tErr error
S
slene 已提交
762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790

	var str *StrTo
	switch v := val.(type) {
	case []byte:
		s := StrTo(string(v))
		str = &s
	case string:
		s := StrTo(v)
		str = &s
	}

	fieldType := fi.fieldType

setValue:
	switch {
	case fieldType == TypeBooleanField:
		if str == nil {
			switch v := val.(type) {
			case int64:
				b := v == 1
				value = b
			default:
				s := StrTo(ToStr(v))
				str = &s
			}
		}
		if str != nil {
			b, err := str.Bool()
			if err != nil {
791 792
				tErr = err
				goto end
S
slene 已提交
793 794 795 796 797
			}
			value = b
		}
	case fieldType == TypeCharField || fieldType == TypeTextField:
		if str == nil {
798 799 800
			value = ToStr(val)
		} else {
			value = str.String()
S
slene 已提交
801 802 803 804 805 806 807 808 809 810 811 812
		}
	case fieldType == TypeDateField || fieldType == TypeDateTimeField:
		if str == nil {
			switch v := val.(type) {
			case time.Time:
				value = v
			default:
				s := StrTo(ToStr(v))
				str = &s
			}
		}
		if str != nil {
813 814
			s := str.String()
			var format string
S
slene 已提交
815 816
			if fi.fieldType == TypeDateField {
				format = format_Date
817 818 819 820 821 822 823 824
				if len(s) > 10 {
					s = s[:10]
				}
			} else {
				format = format_DateTime
				if len(s) > 19 {
					s = s[:19]
				}
S
slene 已提交
825 826 827
			}
			t, err := timeParse(s, format)
			if err != nil && s != "0000-00-00" && s != "0000-00-00 00:00:00" {
828 829
				tErr = err
				goto end
S
slene 已提交
830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854
			}
			value = t
		}
	case fieldType&IsIntegerField > 0:
		if str == nil {
			s := StrTo(ToStr(val))
			str = &s
		}
		if str != nil {
			var err error
			switch fieldType {
			case TypeSmallIntegerField:
				_, err = str.Int16()
			case TypeIntegerField:
				_, err = str.Int32()
			case TypeBigIntegerField:
				_, err = str.Int64()
			case TypePositiveSmallIntegerField:
				_, err = str.Uint16()
			case TypePositiveIntegerField:
				_, err = str.Uint32()
			case TypePositiveBigIntegerField:
				_, err = str.Uint64()
			}
			if err != nil {
855 856
				tErr = err
				goto end
S
slene 已提交
857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878
			}
			if fieldType&IsPostiveIntegerField > 0 {
				v, _ := str.Uint64()
				value = v
			} else {
				v, _ := str.Int64()
				value = v
			}
		}
	case fieldType == TypeFloatField || fieldType == TypeDecimalField:
		if str == nil {
			switch v := val.(type) {
			case float64:
				value = v
			default:
				s := StrTo(ToStr(v))
				str = &s
			}
		}
		if str != nil {
			v, err := str.Float64()
			if err != nil {
879 880
				tErr = err
				goto end
S
slene 已提交
881 882 883 884
			}
			value = v
		}
	case fieldType&IsRelField > 0:
885 886
		fi = fi.relModelInfo.fields.pk
		fieldType = fi.fieldType
S
slene 已提交
887 888 889
		goto setValue
	}

890 891 892 893 894 895
end:
	if tErr != nil {
		err := fmt.Errorf("convert to `%s` failed, field: %s err: %s", fi.addrValue.Type(), fi.fullName, tErr)
		return nil, err
	}

S
slene 已提交
896 897 898 899 900 901 902 903 904 905 906 907 908
	return value, nil

}

func (d *dbBase) setValue(fi *fieldInfo, value interface{}, field *reflect.Value) (interface{}, error) {

	fieldType := fi.fieldType
	isNative := fi.isFielder == false

setValue:
	switch {
	case fieldType == TypeBooleanField:
		if isNative {
S
slene 已提交
909 910 911
			if value == nil {
				value = false
			}
S
slene 已提交
912 913 914 915
			field.SetBool(value.(bool))
		}
	case fieldType == TypeCharField || fieldType == TypeTextField:
		if isNative {
S
slene 已提交
916 917 918
			if value == nil {
				value = ""
			}
S
slene 已提交
919 920 921 922
			field.SetString(value.(string))
		}
	case fieldType == TypeDateField || fieldType == TypeDateTimeField:
		if isNative {
S
slene 已提交
923 924 925
			if value == nil {
				value = time.Time{}
			}
S
slene 已提交
926 927 928 929 930
			field.Set(reflect.ValueOf(value))
		}
	case fieldType&IsIntegerField > 0:
		if fieldType&IsPostiveIntegerField > 0 {
			if isNative {
S
slene 已提交
931 932 933
				if value == nil {
					value = uint64(0)
				}
S
slene 已提交
934 935 936 937
				field.SetUint(value.(uint64))
			}
		} else {
			if isNative {
S
slene 已提交
938 939 940
				if value == nil {
					value = int64(0)
				}
S
slene 已提交
941 942 943 944 945
				field.SetInt(value.(int64))
			}
		}
	case fieldType == TypeFloatField || fieldType == TypeDecimalField:
		if isNative {
S
slene 已提交
946 947 948
			if value == nil {
				value = float64(0)
			}
S
slene 已提交
949 950 951
			field.SetFloat(value.(float64))
		}
	case fieldType&IsRelField > 0:
S
slene 已提交
952
		if value != nil {
S
slene 已提交
953
			fieldType = fi.relModelInfo.fields.pk.fieldType
S
slene 已提交
954 955
			mf := reflect.New(fi.relModelInfo.addrField.Elem().Type())
			field.Set(mf)
S
slene 已提交
956
			f := mf.Elem().Field(fi.relModelInfo.fields.pk.fieldIndex)
S
slene 已提交
957 958
			field = &f
			goto setValue
S
slene 已提交
959 960 961 962 963 964 965
		}
	}

	if isNative == false {
		fd := field.Addr().Interface().(Fielder)
		err := fd.SetRaw(value)
		if err != nil {
966
			err = fmt.Errorf("converted value `%v` set to Fielder `%s` failed, err: %s", value, fi.fullName, err)
S
slene 已提交
967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002
			return nil, err
		}
	}

	return value, nil
}

func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}) (int64, error) {

	var (
		maps  []Params
		lists []ParamsList
		list  ParamsList
	)

	typ := 0
	switch container.(type) {
	case *[]Params:
		typ = 1
	case *[]ParamsList:
		typ = 2
	case *ParamsList:
		typ = 3
	default:
		panic(fmt.Sprintf("unsupport read values type `%T`", container))
	}

	tables := newDbTables(mi, d.ins)

	var (
		cols  []string
		infos []*fieldInfo
	)

	hasExprs := len(exprs) > 0

1003 1004
	Q := d.ins.TableQuote()

S
slene 已提交
1005 1006 1007 1008
	if hasExprs {
		cols = make([]string, 0, len(exprs))
		infos = make([]*fieldInfo, 0, len(exprs))
		for _, ex := range exprs {
S
slene 已提交
1009
			index, col, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep))
S
slene 已提交
1010 1011 1012
			if suc == false {
				panic(fmt.Errorf("unknown field/column name `%s`", ex))
			}
1013
			cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, col, Q, Q, name, Q))
S
slene 已提交
1014 1015 1016 1017 1018 1019
			infos = append(infos, fi)
		}
	} else {
		cols = make([]string, 0, len(mi.fields.dbcols))
		infos = make([]*fieldInfo, 0, len(exprs))
		for _, fi := range mi.fields.fieldsDB {
1020
			cols = append(cols, fmt.Sprintf("T0.%s%s%s %s%s%s", Q, fi.column, Q, Q, fi.name, Q))
S
slene 已提交
1021 1022 1023 1024 1025 1026
			infos = append(infos, fi)
		}
	}

	where, args := tables.getCondSql(cond, false)
	orderBy := tables.getOrderSql(qs.orders)
1027
	limit := tables.getLimitSql(mi, qs.offset, qs.limit)
S
slene 已提交
1028 1029 1030 1031
	join := tables.getJoinSql()

	sels := strings.Join(cols, ", ")

1032
	query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit)
S
slene 已提交
1033 1034 1035 1036 1037 1038 1039 1040 1041 1042

	var rs *sql.Rows
	if r, err := q.Query(query, args...); err != nil {
		return 0, err
	} else {
		rs = r
	}

	refs := make([]interface{}, len(cols))
	for i, _ := range refs {
S
slene 已提交
1043
		var ref interface{}
S
slene 已提交
1044 1045 1046
		refs[i] = &ref
	}

S
slene 已提交
1047 1048 1049 1050
	var (
		cnt     int64
		columns []string
	)
S
slene 已提交
1051
	for rs.Next() {
S
slene 已提交
1052 1053 1054 1055 1056 1057 1058 1059
		if cnt == 0 {
			if cols, err := rs.Columns(); err != nil {
				return 0, err
			} else {
				columns = cols
			}
		}

S
slene 已提交
1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076
		if err := rs.Scan(refs...); err != nil {
			return 0, err
		}

		switch typ {
		case 1:
			params := make(Params, len(cols))
			for i, ref := range refs {
				fi := infos[i]

				val := reflect.Indirect(reflect.ValueOf(ref)).Interface()

				value, err := d.getValue(fi, val)
				if err != nil {
					panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
				}

S
slene 已提交
1077
				params[columns[i]] = value
S
slene 已提交
1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123
			}
			maps = append(maps, params)
		case 2:
			params := make(ParamsList, 0, len(cols))
			for i, ref := range refs {
				fi := infos[i]

				val := reflect.Indirect(reflect.ValueOf(ref)).Interface()

				value, err := d.getValue(fi, val)
				if err != nil {
					panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
				}

				params = append(params, value)
			}
			lists = append(lists, params)
		case 3:
			for i, ref := range refs {
				fi := infos[i]

				val := reflect.Indirect(reflect.ValueOf(ref)).Interface()

				value, err := d.getValue(fi, val)
				if err != nil {
					panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
				}

				list = append(list, value)
			}
		}

		cnt++
	}

	switch v := container.(type) {
	case *[]Params:
		*v = maps
	case *[]ParamsList:
		*v = lists
	case *ParamsList:
		*v = list
	}

	return cnt, nil
}
1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139

// SupportUpdateJoin reports whether UpdateBatch may emit a joined
// "UPDATE table T0 <joins> SET ..." statement. The base implementation
// answers yes. When a driver answers no, UpdateBatch falls back to
// "WHERE pk IN (subquery)".
func (d *dbBase) SupportUpdateJoin() bool {
	const joinedUpdate = true
	return joinedUpdate
}

// MaxLimit returns the largest row limit the base implementation accepts:
// the maximum uint64 value (18446744073709551615).
func (d *dbBase) MaxLimit() uint64 {
	return ^uint64(0)
}

// TableQuote returns the identifier quote character wrapped around table and
// column names in generated SQL. The base implementation uses a backtick.
func (d *dbBase) TableQuote() string {
	const backtick = "`"
	return backtick
}

// ReplaceMarks rewrites, in place, the `?` placeholders of *query into the
// driver's placeholder syntax. The base implementation keeps `?` and leaves
// the query unchanged.
func (d *dbBase) ReplaceMarks(query *string) {
	// default use `?` as mark, do nothing
}