udt_test.go 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. // +build all integration
  2. package gocql
  3. import (
  4. "fmt"
  5. "testing"
  6. )
  7. type position struct {
  8. Lat int
  9. Lon int
  10. }
  11. // NOTE: due to current implementation details it is not currently possible to use
  12. // a pointer receiver type for the UDTMarshaler interface to handle UDT's
  13. func (p position) MarshalUDT(name string, info TypeInfo) ([]byte, error) {
  14. switch name {
  15. case "lat":
  16. return Marshal(info, p.Lat)
  17. case "lon":
  18. return Marshal(info, p.Lon)
  19. default:
  20. return nil, fmt.Errorf("unknown column for position: %q", name)
  21. }
  22. }
  23. func (p *position) UnmarshalUDT(name string, info TypeInfo, data []byte) error {
  24. switch name {
  25. case "lat":
  26. return Unmarshal(info, data, &p.Lat)
  27. case "lon":
  28. return Unmarshal(info, data, &p.Lon)
  29. default:
  30. return fmt.Errorf("unknown column for position: %q", name)
  31. }
  32. }
  33. func TestUDT_Marshaler(t *testing.T) {
  34. if *flagProto < protoVersion3 {
  35. t.Skip("UDT are only available on protocol >= 3")
  36. }
  37. session := createSession(t)
  38. defer session.Close()
  39. err := createTable(session, `CREATE TYPE position(
  40. lat int,
  41. lon int);`)
  42. if err != nil {
  43. t.Fatal(err)
  44. }
  45. err = createTable(session, `CREATE TABLE houses(
  46. id int,
  47. name text,
  48. loc frozen<position>,
  49. primary key(id)
  50. );`)
  51. if err != nil {
  52. t.Fatal(err)
  53. }
  54. const (
  55. expLat = -1
  56. expLon = 2
  57. )
  58. err = session.Query("INSERT INTO houses(id, name, loc) VALUES(?, ?, ?)", 1, "test", &position{expLat, expLon}).Exec()
  59. if err != nil {
  60. t.Fatal(err)
  61. }
  62. pos := &position{}
  63. err = session.Query("SELECT loc FROM houses WHERE id = ?", 1).Scan(pos)
  64. if err != nil {
  65. t.Fatal(err)
  66. }
  67. if pos.Lat != expLat {
  68. t.Errorf("expeceted lat to be be %d got %d", expLat, pos.Lat)
  69. }
  70. if pos.Lon != expLon {
  71. t.Errorf("expeceted lon to be be %d got %d", expLon, pos.Lon)
  72. }
  73. }