Ver código fonte

Avoid phantom nil values

Julien Schmidt 12 anos atrás
pai
commit
ad44b8c0b9
2 arquivos alterados com 20 adições e 21 exclusões
  1. 10 2
      driver_test.go
  2. 10 19
      utils.go

+ 10 - 2
driver_test.go

@@ -526,14 +526,22 @@ func TestNULL(t *testing.T) {
 		if !success {
 			dbt.Error("Inserting []byte(nil) as NULL failed")
 		}
-		// Check input==output (==nil)
+		// Check input==output with input==nil
 		b = nil
-		if err = dbt.db.QueryRow("SELECT ?", nil).Scan(&b); err != nil {
+		if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil {
 			dbt.Fatal(err)
 		}
 		if b != nil {
 			dbt.Error("Non-nil echo from nil input")
 		}
+		// Check input==output with input!=nil
+		b = []byte("")
+		if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil {
+			dbt.Fatal(err)
+		}
+		if b == nil {
+			dbt.Error("nil echo from non-nil input")
+		}
 
 		// Insert NULL
 		dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)")

+ 10 - 19
utils.go

@@ -604,7 +604,7 @@ func readLengthEnodedString(b []byte) ([]byte, bool, int, error) {
 	// Get length
 	num, isNull, n := readLengthEncodedInteger(b)
 	if num < 1 {
-		return nil, isNull, n, nil
+		return b[n:n], isNull, n, nil
 	}
 
 	n += int(num)
@@ -632,40 +632,31 @@ func skipLengthEnodedString(b []byte) (int, error) {
 	return n, io.EOF
 }
 
-func readLengthEncodedInteger(b []byte) (num uint64, isNull bool, n int) {
+func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
 	switch b[0] {
 
 	// 251: NULL
 	case 0xfb:
-		n = 1
-		isNull = true
-		return
+		return 0, true, 1
 
 	// 252: value of following 2
 	case 0xfc:
-		num = uint64(b[1]) | uint64(b[2])<<8
-		n = 3
-		return
+		return uint64(b[1]) | uint64(b[2])<<8, false, 3
 
 	// 253: value of following 3
 	case 0xfd:
-		num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16
-		n = 4
-		return
+		return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4
 
 	// 254: value of following 8
 	case 0xfe:
-		num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
-			uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
-			uint64(b[7])<<48 | uint64(b[8])<<54
-		n = 9
-		return
+		return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
+				uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
+				uint64(b[7])<<48 | uint64(b[8])<<54,
+			false, 9
 	}
 
 	// 0-250: value of first byte
-	num = uint64(b[0])
-	n = 1
-	return
+	return uint64(b[0]), false, 1
 }
 
 func appendLengthEncodedInteger(b []byte, n uint64) []byte {