Przeglądaj źródła

murmur: fix hash to match cassandra sign bug (#1035)

Cassandra has a bug in the murmur hash which causes it to use signed
shifts. Reproduce that behaviour here so our hashes match, and refactor
the code to make it easier to reason about. All the extracted functions
are inlined so there is no performance change.
Chris Bannister 8 lat temu
rodzic
commit
56a164ee9f

+ 53 - 57
internal/murmur/murmur.go

@@ -1,44 +1,57 @@
-// +build !appengine
-
 package murmur
 
-import (
-	"unsafe"
+// Mixing constants: c1 and c2 are used by the block and tail rounds, fmix1
+// and fmix2 by fmix. Each is the int64 reinterpretation of the unsigned
+// 64-bit value given in its trailing comment.
+const (
+	c1    int64 = -8663945395140668459 // 0x87c37b91114253d5
+	c2    int64 = 5545529020109919103  // 0x4cf5ad432745937f
+	fmix1 int64 = -49064778989728563   // 0xff51afd7ed558ccd
+	fmix2 int64 = -4265267296055464877 // 0xc4ceb9fe1a85ec53
 )
 
-func Murmur3H1(data []byte) uint64 {
-	length := len(data)
+// fmix is the finalization mix applied to each half of the hash state: three
+// xor-shift steps interleaved with multiplications by fmix1 and fmix2.
+func fmix(n int64) int64 {
+	// Each shift is done on the uint64 view of n so that it is a logical
+	// (zero-filling) shift, matching the C* MM3 implementation.
+	n = (n ^ int64(uint64(n)>>33)) * fmix1
+	n = (n ^ int64(uint64(n)>>33)) * fmix2
+	return n ^ int64(uint64(n)>>33)
+}
+
+// block widens one key byte to int64 with sign extension, so bytes with the
+// high bit set (0x80-0xff) become negative words.
+func block(p byte) int64 {
+	signed := int8(p)
+	return int64(signed)
+}
 
-	var h1, h2, k1, k2 uint64
+// rotl rotates the 64-bit pattern of x left by r bits.
+func rotl(x int64, r uint8) int64 {
+	// Work on the unsigned view so the right shift is logical and the bits
+	// that fall off the top re-enter at the bottom without sign extension
+	// (to match the C* MM3 implementation).
+	u := uint64(x)
+	return int64(u<<r | u>>(64-r))
+}
 
-	const (
-		c1 = 0x87c37b91114253d5
-		c2 = 0x4cf5ad432745937f
-	)
+func Murmur3H1(data []byte) int64 {
+	length := len(data)
+
+	var h1, h2, k1, k2 int64
 
 	// body
 	nBlocks := length / 16
 	for i := 0; i < nBlocks; i++ {
-		block := (*[2]uint64)(unsafe.Pointer(&data[i*16]))
-
-		k1 = block[0]
-		k2 = block[1]
+		k1, k2 = getBlock(data, i)
 
 		k1 *= c1
-		k1 = (k1 << 31) | (k1 >> 33) // ROTL64(k1, 31)
+		k1 = rotl(k1, 31)
 		k1 *= c2
 		h1 ^= k1
 
-		h1 = (h1 << 27) | (h1 >> 37) // ROTL64(h1, 27)
+		h1 = rotl(h1, 27)
 		h1 += h2
 		h1 = h1*5 + 0x52dce729
 
 		k2 *= c2
-		k2 = (k2 << 33) | (k2 >> 31) // ROTL64(k2, 33)
+		k2 = rotl(k2, 33)
 		k2 *= c1
 		h2 ^= k2
 
-		h2 = (h2 << 31) | (h2 >> 33) // ROTL64(h2, 31)
+		h2 = rotl(h2, 31)
 		h2 += h1
 		h2 = h2*5 + 0x38495ab5
 	}
@@ -49,87 +62,70 @@ func Murmur3H1(data []byte) uint64 {
 	k2 = 0
 	switch length & 15 {
 	case 15:
-		k2 ^= uint64(tail[14]) << 48
+		k2 ^= block(tail[14]) << 48
 		fallthrough
 	case 14:
-		k2 ^= uint64(tail[13]) << 40
+		k2 ^= block(tail[13]) << 40
 		fallthrough
 	case 13:
-		k2 ^= uint64(tail[12]) << 32
+		k2 ^= block(tail[12]) << 32
 		fallthrough
 	case 12:
-		k2 ^= uint64(tail[11]) << 24
+		k2 ^= block(tail[11]) << 24
 		fallthrough
 	case 11:
-		k2 ^= uint64(tail[10]) << 16
+		k2 ^= block(tail[10]) << 16
 		fallthrough
 	case 10:
-		k2 ^= uint64(tail[9]) << 8
+		k2 ^= block(tail[9]) << 8
 		fallthrough
 	case 9:
-		k2 ^= uint64(tail[8])
+		k2 ^= block(tail[8])
 
 		k2 *= c2
-		k2 = (k2 << 33) | (k2 >> 31) // ROTL64(k2, 33)
+		k2 = rotl(k2, 33)
 		k2 *= c1
 		h2 ^= k2
 
 		fallthrough
 	case 8:
-		k1 ^= uint64(tail[7]) << 56
+		k1 ^= block(tail[7]) << 56
 		fallthrough
 	case 7:
-		k1 ^= uint64(tail[6]) << 48
+		k1 ^= block(tail[6]) << 48
 		fallthrough
 	case 6:
-		k1 ^= uint64(tail[5]) << 40
+		k1 ^= block(tail[5]) << 40
 		fallthrough
 	case 5:
-		k1 ^= uint64(tail[4]) << 32
+		k1 ^= block(tail[4]) << 32
 		fallthrough
 	case 4:
-		k1 ^= uint64(tail[3]) << 24
+		k1 ^= block(tail[3]) << 24
 		fallthrough
 	case 3:
-		k1 ^= uint64(tail[2]) << 16
+		k1 ^= block(tail[2]) << 16
 		fallthrough
 	case 2:
-		k1 ^= uint64(tail[1]) << 8
+		k1 ^= block(tail[1]) << 8
 		fallthrough
 	case 1:
-		k1 ^= uint64(tail[0])
+		k1 ^= block(tail[0])
 
 		k1 *= c1
-		k1 = (k1 << 31) | (k1 >> 33) // ROTL64(k1, 31)
+		k1 = rotl(k1, 31)
 		k1 *= c2
 		h1 ^= k1
 	}
 
-	h1 ^= uint64(length)
-	h2 ^= uint64(length)
+	h1 ^= int64(length)
+	h2 ^= int64(length)
 
 	h1 += h2
 	h2 += h1
 
-	// finalizer
-	const (
-		fmix1 = 0xff51afd7ed558ccd
-		fmix2 = 0xc4ceb9fe1a85ec53
-	)
-
-	// fmix64(h1)
-	h1 ^= h1 >> 33
-	h1 *= fmix1
-	h1 ^= h1 >> 33
-	h1 *= fmix2
-	h1 ^= h1 >> 33
-
-	// fmix64(h2)
-	h2 ^= h2 >> 33
-	h2 *= fmix1
-	h2 ^= h2 >> 33
-	h2 *= fmix2
-	h2 ^= h2 >> 33
+	h1 = fmix(h1)
+	h2 = fmix(h2)
 
 	h1 += h2
 	// the following is extraneous since h2 is discarded

+ 4 - 130
internal/murmur/murmur_appengine.go

@@ -4,134 +4,8 @@ package murmur
 
 import "encoding/binary"
 
-func Murmur3H1(data []byte) uint64 {
-	length := len(data)
-
-	var h1, h2, k1, k2 uint64
-
-	const (
-		c1 = 0x87c37b91114253d5
-		c2 = 0x4cf5ad432745937f
-	)
-
-	// body
-	nBlocks := length / 16
-	for i := 0; i < nBlocks; i++ {
-		// block := (*[2]uint64)(unsafe.Pointer(&data[i*16]))
-
-		k1 = binary.LittleEndian.Uint64(data[i*16:])
-		k2 = binary.LittleEndian.Uint64(data[(i*16)+8:])
-
-		k1 *= c1
-		k1 = (k1 << 31) | (k1 >> 33) // ROTL64(k1, 31)
-		k1 *= c2
-		h1 ^= k1
-
-		h1 = (h1 << 27) | (h1 >> 37) // ROTL64(h1, 27)
-		h1 += h2
-		h1 = h1*5 + 0x52dce729
-
-		k2 *= c2
-		k2 = (k2 << 33) | (k2 >> 31) // ROTL64(k2, 33)
-		k2 *= c1
-		h2 ^= k2
-
-		h2 = (h2 << 31) | (h2 >> 33) // ROTL64(h2, 31)
-		h2 += h1
-		h2 = h2*5 + 0x38495ab5
-	}
-
-	// tail
-	tail := data[nBlocks*16:]
-	k1 = 0
-	k2 = 0
-	switch length & 15 {
-	case 15:
-		k2 ^= uint64(tail[14]) << 48
-		fallthrough
-	case 14:
-		k2 ^= uint64(tail[13]) << 40
-		fallthrough
-	case 13:
-		k2 ^= uint64(tail[12]) << 32
-		fallthrough
-	case 12:
-		k2 ^= uint64(tail[11]) << 24
-		fallthrough
-	case 11:
-		k2 ^= uint64(tail[10]) << 16
-		fallthrough
-	case 10:
-		k2 ^= uint64(tail[9]) << 8
-		fallthrough
-	case 9:
-		k2 ^= uint64(tail[8])
-
-		k2 *= c2
-		k2 = (k2 << 33) | (k2 >> 31) // ROTL64(k2, 33)
-		k2 *= c1
-		h2 ^= k2
-
-		fallthrough
-	case 8:
-		k1 ^= uint64(tail[7]) << 56
-		fallthrough
-	case 7:
-		k1 ^= uint64(tail[6]) << 48
-		fallthrough
-	case 6:
-		k1 ^= uint64(tail[5]) << 40
-		fallthrough
-	case 5:
-		k1 ^= uint64(tail[4]) << 32
-		fallthrough
-	case 4:
-		k1 ^= uint64(tail[3]) << 24
-		fallthrough
-	case 3:
-		k1 ^= uint64(tail[2]) << 16
-		fallthrough
-	case 2:
-		k1 ^= uint64(tail[1]) << 8
-		fallthrough
-	case 1:
-		k1 ^= uint64(tail[0])
-
-		k1 *= c1
-		k1 = (k1 << 31) | (k1 >> 33) // ROTL64(k1, 31)
-		k1 *= c2
-		h1 ^= k1
-	}
-
-	h1 ^= uint64(length)
-	h2 ^= uint64(length)
-
-	h1 += h2
-	h2 += h1
-
-	// finalizer
-	const (
-		fmix1 = 0xff51afd7ed558ccd
-		fmix2 = 0xc4ceb9fe1a85ec53
-	)
-
-	// fmix64(h1)
-	h1 ^= h1 >> 33
-	h1 *= fmix1
-	h1 ^= h1 >> 33
-	h1 *= fmix2
-	h1 ^= h1 >> 33
-
-	// fmix64(h2)
-	h2 ^= h2 >> 33
-	h2 *= fmix1
-	h2 ^= h2 >> 33
-	h2 *= fmix2
-	h2 ^= h2 >> 33
-
-	h1 += h2
-	// the following is extraneous since h2 is discarded
-	// h2 += h1
-
-	return h1
+// getBlock returns the n-th 16-byte block of data as two little-endian words:
+// k1 from bytes 0-7 of the block and k2 from bytes 8-15.
+func getBlock(data []byte, n int) (int64, int64) {
+	// binary.ByteOrder only offers unsigned readers (there is no Int64
+	// method), so read a uint64 and reinterpret its bits as int64.
+	k1 := int64(binary.LittleEndian.Uint64(data[n*16:]))
+	k2 := int64(binary.LittleEndian.Uint64(data[(n*16)+8:]))
+	return k1, k2
 }

+ 65 - 4
internal/murmur/murmur_test.go

@@ -1,10 +1,71 @@
 package murmur
 
 import (
+	"encoding/hex"
+	"fmt"
 	"strconv"
 	"testing"
 )
 
+// TestRotl checks rotl against a table of known rotations, covering both
+// positive and negative inputs.
+func TestRotl(t *testing.T) {
+	tests := []struct {
+		in     int64
+		rotate uint8 // same type as rotl's r parameter, so no narrowing at the call
+		exp    int64
+	}{
+		{123456789, 33, 1060485742448345088},
+		{-123456789, 33, -1060485733858410497},
+		{-12345678987654, 33, 1756681988166642059},
+
+		{7210216203459776512, 31, -4287945813905642825},
+		{2453826951392495049, 27, -2013042863942636044},
+		{270400184080946339, 33, -3553153987756601583},
+		{2060965185473694757, 31, 6290866853133484661},
+		{3075794793055692309, 33, -3158909918919076318},
+		{-6486402271863858009, 31, 405973038345868736},
+	}
+
+	for _, test := range tests {
+		// Name the subtest after the call under test; the operation is a
+		// left rotation, not a right shift.
+		t.Run(fmt.Sprintf("rotl(%d, %d)", test.in, test.rotate), func(t *testing.T) {
+			if v := rotl(test.in, test.rotate); v != test.exp {
+				t.Fatalf("expected %d got %d", test.exp, v)
+			}
+		})
+	}
+}
+
+// TestFmix checks fmix against a table of known input/output pairs.
+func TestFmix(t *testing.T) {
+	tests := []struct {
+		in, exp int64
+	}{
+		{123456789, -8107560010088384378},
+		{-123456789, -5252787026298255965},
+		{-12345678987654, -1122383578793231303},
+		{-1241537367799374202, 3388197556095096266},
+		{-7566534940689533355, 4729783097411765989},
+	}
+
+	for _, test := range tests {
+		// FormatInt keeps the full 64-bit input in the subtest name; going
+		// through int first would truncate it where int is 32 bits wide.
+		t.Run(strconv.FormatInt(test.in, 10), func(t *testing.T) {
+			if v := fmix(test.in); v != test.exp {
+				t.Fatalf("expected %d got %d", test.exp, v)
+			}
+		})
+	}
+}
+
+// TestMurmur3H1_CassandraSign hashes a fixed 26-byte key that contains bytes
+// with the high bit set and compares the result with a known signed value.
+func TestMurmur3H1_CassandraSign(t *testing.T) {
+	const want int64 = -9223371632693506265
+
+	key, err := hex.DecodeString("00104327529fb645dd00b883ec39ae448bb800000400066a6b00")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if got := Murmur3H1(key); got != want {
+		t.Fatalf("expected %d got %d", want, got)
+	}
+}
+
 // Test the implementation of murmur3
 func TestMurmur3H1(t *testing.T) {
 	// these examples are based on adding a index number to a sample string in
@@ -50,8 +111,8 @@ func TestMurmur3H1(t *testing.T) {
 // helper function for testing the murmur3 implementation
 func assertMurmur3H1(t *testing.T, data []byte, expected uint64) {
 	actual := Murmur3H1(data)
-	if actual != expected {
-		t.Errorf("Expected h1 = %x for data = %x, but was %x", expected, data, actual)
+	// Murmur3H1 returns int64 while expected is written as an unsigned
+	// pattern, so compare after reinterpreting expected as int64.
+	if actual != int64(expected) {
+		// Report both values as unsigned hex: %x on a negative int64 prints a
+		// minus sign and magnitude rather than the 64-bit pattern.
+		t.Errorf("Expected h1 = %x for data = %x, but was %x", expected, data, uint64(actual))
 	}
 }
 
@@ -66,8 +127,8 @@ func BenchmarkMurmur3H1(b *testing.B) {
 	b.RunParallel(func(pb *testing.PB) {
 		for pb.Next() {
 			h1 := Murmur3H1(data)
-			if h1 != uint64(7627370222079200297) {
-				b.Fatalf("expected %d got %d", uint64(7627370222079200297), h1)
+			if h1 != 7627370222079200297 {
+				b.Fatalf("expected %d got %d", 7627370222079200297, h1)
 			}
 		}
 	})

+ 15 - 0
internal/murmur/murmur_unsafe.go

@@ -0,0 +1,15 @@
+// +build !appengine
+
+package murmur
+
+import (
+	"unsafe"
+)
+
+// getBlock returns the n-th 16-byte block of data as two int64 words by
+// reinterpreting the bytes in place, i.e. in the host's native byte order.
+// Only index n*16 is bounds-checked here; Murmur3H1 calls it with
+// n < len(data)/16, so a full block is available.
+func getBlock(data []byte, n int) (int64, int64) {
+	words := (*[2]int64)(unsafe.Pointer(&data[n*16]))
+	return words[0], words[1]
+}

+ 1 - 1
token.go

@@ -39,7 +39,7 @@ func (p murmur3Partitioner) Name() string {
 
 func (p murmur3Partitioner) Hash(partitionKey []byte) token {
 	h1 := murmur.Murmur3H1(partitionKey)
-	return murmur3Token(int64(h1))
+	return murmur3Token(h1)
 }
 
 // murmur3 little-endian, 128-bit hash, but returns only h1