prepared.go 3.5 KB
Newer Older
martianzhang's avatar
martianzhang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
package native

import (
	"github.com/ziutek/mymysql/mysql"
	"log"
)

// Stmt holds the client-side state of a prepared statement: the identifier
// read from the prepare OK packet, the parameter bindings written by
// sendCmdExec, and the counters updated while the prepare response is read.
type Stmt struct {
	// Connection the statement was prepared on; sendCmdExec writes the
	// execute packet through it.
	my *Conn

	// Statement identifier read from the prepare OK packet and written into
	// every execute packet.
	id  uint32
	sql string // For reprepare during reconnect

	params []paramValue // Parameters binding
	// When true, sendCmdExec includes the type of every parameter in the
	// execute packet and then resets the flag to false.
	rebind bool
	// NOTE(review): not referenced in this part of the file; presumably
	// records whether parameters have been bound - confirm against the
	// binding code.
	binded bool

	// One slot per result field, sized from the prepare OK packet and filled
	// in by getPrepareResult as field packets arrive.
	fields []*mysql.Field

	// Number of field packets stored in fields so far.
	field_count   int
	// Number of parameter packets consumed so far by getPrepareResult; this
	// is the value NumParam returns.
	param_count   int
	// Warning count taken from the prepare OK packet and later from the EOF
	// packet.
	warning_count int
	// Status taken from the EOF packet; getPrepareResult also copies it to
	// the connection.
	status        mysql.ConnStatus

	// One bit per parameter, (len(params)+7)/8 bytes long. sendCmdExec
	// clears it and then sets the bit of every parameter whose Len() is 0.
	null_bitmap []byte
}

// Fields returns the statement's field descriptors. The returned slice is
// the statement's own slice, not a copy.
func (stmt *Stmt) Fields() []*mysql.Field {
	fields := stmt.fields
	return fields
}

// NumParam returns the statement's param_count, i.e. the number of
// parameter packets counted while the prepare response was read.
func (stmt *Stmt) NumParam() int {
	count := stmt.param_count
	return count
}

// WarnCount returns the warning count most recently stored on the
// statement.
func (stmt *Stmt) WarnCount() int {
	warnings := stmt.warning_count
	return warnings
}

// sendCmdExec writes the execute command packet for stmt to its connection.
//
// The packet consists of the command byte, the statement id, a flags byte,
// an iteration count, the NULL bitmap, a one-byte marker saying whether
// parameter types follow, the optional types, and finally the parameter
// values. Parameter types are sent only while stmt.rebind is set; the flag
// is cleared before returning.
func (stmt *Stmt) sendCmdExec() {
	// Start from an all-zero NULL bitmap; the bits are set again below.
	bitmap := stmt.null_bitmap
	for i := range bitmap {
		bitmap[i] = 0
	}

	// Fixed part: command (1) + id (4) + flags (1) + iteration count (4) +
	// types marker (1), followed by the bitmap itself.
	size := 1 + 4 + 1 + 4 + 1 + len(bitmap)
	for idx, p := range stmt.params {
		n := p.Len()
		size += n
		if n != 0 {
			continue
		}
		// Zero-length parameter: flag it in the NULL bitmap
		// (byte idx/8, bit idx%8).
		bitmap[idx>>3] |= byte(1) << uint(idx&7)
	}
	if stmt.rebind {
		// Two bytes of type information per parameter.
		size += stmt.param_count * 2
	}

	conn := stmt.my
	// A new command starts a new packet sequence.
	conn.seq = 0

	pw := conn.newPktWriter(size)
	pw.writeByte(_COM_STMT_EXECUTE)
	pw.writeU32(stmt.id)
	pw.writeByte(0) // flags = CURSOR_TYPE_NO_CURSOR
	pw.writeU32(1)  // iteration_count
	pw.write(bitmap)
	if !stmt.rebind {
		pw.writeByte(0)
	} else {
		pw.writeByte(1)
		// Type of every parameter, in binding order.
		for i := range stmt.params {
			pw.writeU16(stmt.params[i].typ)
		}
	}
	// Parameter values, in binding order.
	for i := range stmt.params {
		pw.writeValue(&stmt.params[i])
	}

	if conn.Debug {
		log.Printf("[%2d <-] Exec command packet: len=%d, null_bitmap=%v, rebind=%t",
			conn.seq-1, size, bitmap, stmt.rebind)
	}

	// The types have now been sent; skip them on later executions until
	// rebind is set again.
	stmt.rebind = false
}

// getPrepareResult reads packets of a prepare response from the connection.
//
// With a nil stmt it expects a packet whose first byte is 0 and returns the
// *Stmt built from it by getPrepareOkPacket. With a non-nil stmt it consumes
// parameter and field packets, one per iteration, until a packet starting
// with 254 (EOF) arrives; it then stores the warning count and status on
// stmt, copies the status to stmt.my, and returns stmt.
//
// A packet starting with 255 is handed to getErrorPacket.
// NOTE(review): getErrorPacket presumably panics; if it ever returns, the
// packet falls through to the panic below.
//
// Any packet that fits none of the cases above causes a panic with
// mysql.ErrUnkResultPkt.
func (my *Conn) getPrepareResult(stmt *Stmt) interface{} {
	for {
		pr := my.newPktReader() // One reader per packet
		first := pr.readByte()

		if first == 255 {
			// Error packet
			my.getErrorPacket(pr)
		}

		if stmt == nil {
			if first != 0 {
				break
			}
			// OK packet
			return my.getPrepareOkPacket(pr)
		}

		if first == 254 {
			// EOF packet
			stmt.warning_count, stmt.status = my.getEofPacket(pr)
			stmt.my.status = stmt.status
			return stmt
		}

		paramsPending := stmt.param_count < len(stmt.params)
		fieldsPending := stmt.field_count < len(stmt.fields)
		isDefinition := first > 0 && first < 251
		if !isDefinition || !(fieldsPending || paramsPending) {
			break
		}

		if paramsPending {
			// Parameter definitions are skipped rather than parsed (MySQL's
			// own client does the same: "skip parameters data: we don't
			// support it yet"); only the counter is advanced.
			pr.skipAll()
			stmt.param_count++
		} else {
			// Store the next field definition in its slot.
			stmt.fields[stmt.field_count] = my.getFieldPacket(pr)
			stmt.field_count++
		}
		// Loop around to read the next packet.
	}
	panic(mysql.ErrUnkResultPkt)
}

// getPrepareOkPacket parses the remainder of a prepare OK packet from pr and
// returns a new Stmt bound to my.
//
// The packet's first byte must already have been consumed by the caller.
// The rest is read in order: statement id (4 bytes), field count (2 bytes),
// parameter count (2 bytes), one skipped byte, and warning count (2 bytes);
// pr.checkEof() is called afterwards. The fields slice is sized from the
// field count; params and null_bitmap are allocated only when the parameter
// count is non-zero.
func (my *Conn) getPrepareOkPacket(pr *pktReader) (stmt *Stmt) {
	if my.Debug {
		log.Printf("[%2d ->] Perpared OK packet:", my.seq-1)
	}

	stmt = &Stmt{my: my}

	stmt.id = pr.readU32()
	numFields := int(pr.readU16())
	numParams := int(pr.readU16())

	stmt.fields = make([]*mysql.Field, numFields)
	if numParams > 0 {
		stmt.params = make([]paramValue, numParams)
		// One bit per parameter, rounded up to whole bytes.
		stmt.null_bitmap = make([]byte, (numParams+7)>>3)
	}

	pr.skipN(1)
	stmt.warning_count = int(pr.readU16())
	pr.checkEof()

	if my.Debug {
		log.Printf(tab8s+"ID=0x%x ParamCount=%d FieldsCount=%d WarnCount=%d",
			stmt.id, numParams, numFields, stmt.warning_count,
		)
	}
	return stmt
}