udt_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. // +build all integration
  2. package gocql
  3. import (
  4. "fmt"
  5. "strings"
  6. "testing"
  7. "time"
  8. )
  // position mirrors the gocql_test.position UDT (lat int, lon int, padding text).
  // The cql tags name the matching UDT fields so that every field is described the
  // same way; encoding and decoding go through its MarshalUDT/UnmarshalUDT methods.
  type position struct {
      Lat     int    `cql:"lat"`
      Lon     int    `cql:"lon"`
      Padding string `cql:"padding"` // was a json tag, inconsistent with its siblings
  }
  14. // NOTE: due to current implementation details it is not currently possible to use
  15. // a pointer receiver type for the UDTMarshaler interface to handle UDT's
  16. func (p position) MarshalUDT(name string, info TypeInfo) ([]byte, error) {
  17. switch name {
  18. case "lat":
  19. return Marshal(info, p.Lat)
  20. case "lon":
  21. return Marshal(info, p.Lon)
  22. case "padding":
  23. return Marshal(info, p.Padding)
  24. default:
  25. return nil, fmt.Errorf("unknown column for position: %q", name)
  26. }
  27. }
  28. func (p *position) UnmarshalUDT(name string, info TypeInfo, data []byte) error {
  29. switch name {
  30. case "lat":
  31. return Unmarshal(info, data, &p.Lat)
  32. case "lon":
  33. return Unmarshal(info, data, &p.Lon)
  34. case "padding":
  35. return Unmarshal(info, data, &p.Padding)
  36. default:
  37. return fmt.Errorf("unknown column for position: %q", name)
  38. }
  39. }
  40. func TestUDT_Marshaler(t *testing.T) {
  41. if *flagProto < protoVersion3 {
  42. t.Skip("UDT are only available on protocol >= 3")
  43. }
  44. session := createSession(t)
  45. defer session.Close()
  46. err := createTable(session, `CREATE TYPE gocql_test.position(
  47. lat int,
  48. lon int,
  49. padding text);`)
  50. if err != nil {
  51. t.Fatal(err)
  52. }
  53. err = createTable(session, `CREATE TABLE gocql_test.houses(
  54. id int,
  55. name text,
  56. loc frozen<position>,
  57. primary key(id)
  58. );`)
  59. if err != nil {
  60. t.Fatal(err)
  61. }
  62. const (
  63. expLat = -1
  64. expLon = 2
  65. )
  66. pad := strings.Repeat("X", 1000)
  67. err = session.Query("INSERT INTO houses(id, name, loc) VALUES(?, ?, ?)", 1, "test", &position{expLat, expLon, pad}).Exec()
  68. if err != nil {
  69. t.Fatal(err)
  70. }
  71. pos := &position{}
  72. err = session.Query("SELECT loc FROM houses WHERE id = ?", 1).Scan(pos)
  73. if err != nil {
  74. t.Fatal(err)
  75. }
  76. if pos.Lat != expLat {
  77. t.Errorf("expeceted lat to be be %d got %d", expLat, pos.Lat)
  78. }
  79. if pos.Lon != expLon {
  80. t.Errorf("expeceted lon to be be %d got %d", expLon, pos.Lon)
  81. }
  82. if pos.Padding != pad {
  83. t.Errorf("expected to get padding %q got %q\n", pad, pos.Padding)
  84. }
  85. }
  86. func TestUDT_Reflect(t *testing.T) {
  87. if *flagProto < protoVersion3 {
  88. t.Skip("UDT are only available on protocol >= 3")
  89. }
  90. // Uses reflection instead of implementing the marshaling type
  91. session := createSession(t)
  92. defer session.Close()
  93. err := createTable(session, `CREATE TYPE gocql_test.horse(
  94. name text,
  95. owner text);`)
  96. if err != nil {
  97. t.Fatal(err)
  98. }
  99. err = createTable(session, `CREATE TABLE gocql_test.horse_race(
  100. position int,
  101. horse frozen<horse>,
  102. primary key(position)
  103. );`)
  104. if err != nil {
  105. t.Fatal(err)
  106. }
  107. type horse struct {
  108. Name string `cql:"name"`
  109. Owner string `cql:"owner"`
  110. }
  111. insertedHorse := &horse{
  112. Name: "pony",
  113. Owner: "jim",
  114. }
  115. err = session.Query("INSERT INTO horse_race(position, horse) VALUES(?, ?)", 1, insertedHorse).Exec()
  116. if err != nil {
  117. t.Fatal(err)
  118. }
  119. retrievedHorse := &horse{}
  120. err = session.Query("SELECT horse FROM horse_race WHERE position = ?", 1).Scan(retrievedHorse)
  121. if err != nil {
  122. t.Fatal(err)
  123. }
  124. if *retrievedHorse != *insertedHorse {
  125. t.Fatalf("expected to get %+v got %+v", insertedHorse, retrievedHorse)
  126. }
  127. }
  128. func TestUDT_Proto2error(t *testing.T) {
  129. // TODO(zariel): move this to marshal test?
  130. _, err := Marshal(NativeType{custom: "org.apache.cassandra.db.marshal.UserType.Type", proto: 2}, 1)
  131. if err != ErrorUDTUnavailable {
  132. t.Fatalf("expected %v got %v", ErrUnavailable, err)
  133. }
  134. }
  135. func TestUDT_NullObject(t *testing.T) {
  136. if *flagProto < protoVersion3 {
  137. t.Skip("UDT are only available on protocol >= 3")
  138. }
  139. session := createSession(t)
  140. defer session.Close()
  141. err := createTable(session, `CREATE TYPE gocql_test.udt_null_type(
  142. name text,
  143. owner text);`)
  144. if err != nil {
  145. t.Fatal(err)
  146. }
  147. err = createTable(session, `CREATE TABLE gocql_test.udt_null_table(
  148. id uuid,
  149. udt_col frozen<udt_null_type>,
  150. primary key(id)
  151. );`)
  152. if err != nil {
  153. t.Fatal(err)
  154. }
  155. type col struct {
  156. Name string `cql:"name"`
  157. Owner string `cql:"owner"`
  158. }
  159. id := TimeUUID()
  160. err = session.Query("INSERT INTO udt_null_table(id) VALUES(?)", id).Exec()
  161. if err != nil {
  162. t.Fatal(err)
  163. }
  164. readCol := &col{
  165. Name: "temp",
  166. Owner: "temp",
  167. }
  168. err = session.Query("SELECT udt_col FROM udt_null_table WHERE id = ?", id).Scan(readCol)
  169. if err != nil {
  170. t.Fatal(err)
  171. }
  172. if readCol.Name != "" {
  173. t.Errorf("expected empty string to be returned for null udt: got %q", readCol.Name)
  174. }
  175. if readCol.Owner != "" {
  176. t.Errorf("expected empty string to be returned for null udt: got %q", readCol.Owner)
  177. }
  178. }
  179. func TestMapScanUDT(t *testing.T) {
  180. if *flagProto < protoVersion3 {
  181. t.Skip("UDT are only available on protocol >= 3")
  182. }
  183. session := createSession(t)
  184. defer session.Close()
  185. err := createTable(session, `CREATE TYPE gocql_test.log_entry (
  186. created_timestamp timestamp,
  187. message text
  188. );`)
  189. if err != nil {
  190. t.Fatal(err)
  191. }
  192. err = createTable(session, `CREATE TABLE gocql_test.requests_by_id (
  193. id uuid PRIMARY KEY,
  194. type int,
  195. log_entries list<frozen <log_entry>>
  196. );`)
  197. if err != nil {
  198. t.Fatal(err)
  199. }
  200. entry := []struct {
  201. CreatedTimestamp time.Time `cql:"created_timestamp"`
  202. Message string `cql:"message"`
  203. }{
  204. {
  205. CreatedTimestamp: time.Now().Truncate(time.Millisecond),
  206. Message: "test time now",
  207. },
  208. }
  209. id, _ := RandomUUID()
  210. const typ = 1
  211. err = session.Query("INSERT INTO requests_by_id(id, type, log_entries) VALUES (?, ?, ?)", id, typ, entry).Exec()
  212. if err != nil {
  213. t.Fatal(err)
  214. }
  215. rawResult := map[string]interface{}{}
  216. err = session.Query(`SELECT * FROM requests_by_id WHERE id = ?`, id).MapScan(rawResult)
  217. if err != nil {
  218. t.Fatal(err)
  219. }
  220. logEntries, ok := rawResult["log_entries"].([]map[string]interface{})
  221. if !ok {
  222. t.Fatal("log_entries not in scanned map")
  223. }
  224. if len(logEntries) != 1 {
  225. t.Fatalf("expected to get 1 log_entry got %d", len(logEntries))
  226. }
  227. logEntry := logEntries[0]
  228. timestamp, ok := logEntry["created_timestamp"]
  229. if !ok {
  230. t.Error("created_timestamp not unmarshalled into map")
  231. } else {
  232. if ts, ok := timestamp.(time.Time); ok {
  233. if !ts.In(time.UTC).Equal(entry[0].CreatedTimestamp.In(time.UTC)) {
  234. t.Errorf("created_timestamp not equal to stored: got %v expected %v", ts.In(time.UTC), entry[0].CreatedTimestamp.In(time.UTC))
  235. }
  236. } else {
  237. t.Errorf("created_timestamp was not time.Time got: %T", timestamp)
  238. }
  239. }
  240. message, ok := logEntry["message"]
  241. if !ok {
  242. t.Error("message not unmarshalled into map")
  243. } else {
  244. if ts, ok := message.(string); ok {
  245. if ts != message {
  246. t.Errorf("message not equal to stored: got %v expected %v", ts, entry[0].Message)
  247. }
  248. } else {
  249. t.Errorf("message was not string got: %T", message)
  250. }
  251. }
  252. }
  253. func TestUDT_MissingField(t *testing.T) {
  254. if *flagProto < protoVersion3 {
  255. t.Skip("UDT are only available on protocol >= 3")
  256. }
  257. session := createSession(t)
  258. defer session.Close()
  259. err := createTable(session, `CREATE TYPE gocql_test.missing_field(
  260. name text,
  261. owner text);`)
  262. if err != nil {
  263. t.Fatal(err)
  264. }
  265. err = createTable(session, `CREATE TABLE gocql_test.missing_field(
  266. id uuid,
  267. udt_col frozen<udt_null_type>,
  268. primary key(id)
  269. );`)
  270. if err != nil {
  271. t.Fatal(err)
  272. }
  273. type col struct {
  274. Name string `cql:"name"`
  275. }
  276. writeCol := &col{
  277. Name: "test",
  278. }
  279. id := TimeUUID()
  280. err = session.Query("INSERT INTO missing_field(id, udt_col) VALUES(?, ?)", id, writeCol).Exec()
  281. if err != nil {
  282. t.Fatal(err)
  283. }
  284. readCol := &col{}
  285. err = session.Query("SELECT udt_col FROM missing_field WHERE id = ?", id).Scan(readCol)
  286. if err != nil {
  287. t.Fatal(err)
  288. }
  289. if readCol.Name != writeCol.Name {
  290. t.Errorf("expected %q: got %q", writeCol.Name, readCol.Name)
  291. }
  292. }
  293. func TestUDT_EmptyCollections(t *testing.T) {
  294. if *flagProto < protoVersion3 {
  295. t.Skip("UDT are only available on protocol >= 3")
  296. }
  297. session := createSession(t)
  298. defer session.Close()
  299. err := createTable(session, `CREATE TYPE gocql_test.nil_collections(
  300. a list<text>,
  301. b map<text, text>,
  302. c set<text>
  303. );`)
  304. if err != nil {
  305. t.Fatal(err)
  306. }
  307. err = createTable(session, `CREATE TABLE gocql_test.nil_collections(
  308. id uuid,
  309. udt_col frozen<nil_collections>,
  310. primary key(id)
  311. );`)
  312. if err != nil {
  313. t.Fatal(err)
  314. }
  315. type udt struct {
  316. A []string `cql:"a"`
  317. B map[string]string `cql:"b"`
  318. C []string `cql:"c"`
  319. }
  320. id := TimeUUID()
  321. err = session.Query("INSERT INTO nil_collections(id, udt_col) VALUES(?, ?)", id, &udt{}).Exec()
  322. if err != nil {
  323. t.Fatal(err)
  324. }
  325. var val udt
  326. err = session.Query("SELECT udt_col FROM nil_collections WHERE id=?", id).Scan(&val)
  327. if err != nil {
  328. t.Fatal(err)
  329. }
  330. if val.A != nil {
  331. t.Errorf("expected to get nil got %#+v", val.A)
  332. }
  333. if val.B != nil {
  334. t.Errorf("expected to get nil got %#+v", val.B)
  335. }
  336. if val.C != nil {
  337. t.Errorf("expected to get nil got %#+v", val.C)
  338. }
  339. }
  340. func TestUDT_UpdateField(t *testing.T) {
  341. if *flagProto < protoVersion3 {
  342. t.Skip("UDT are only available on protocol >= 3")
  343. }
  344. session := createSession(t)
  345. defer session.Close()
  346. err := createTable(session, `CREATE TYPE gocql_test.update_field_udt(
  347. name text,
  348. owner text);`)
  349. if err != nil {
  350. t.Fatal(err)
  351. }
  352. err = createTable(session, `CREATE TABLE gocql_test.update_field(
  353. id uuid,
  354. udt_col frozen<update_field_udt>,
  355. primary key(id)
  356. );`)
  357. if err != nil {
  358. t.Fatal(err)
  359. }
  360. type col struct {
  361. Name string `cql:"name"`
  362. Owner string `cql:"owner"`
  363. Data string `cql:"data"`
  364. }
  365. writeCol := &col{
  366. Name: "test-name",
  367. Owner: "test-owner",
  368. }
  369. id := TimeUUID()
  370. err = session.Query("INSERT INTO update_field(id, udt_col) VALUES(?, ?)", id, writeCol).Exec()
  371. if err != nil {
  372. t.Fatal(err)
  373. }
  374. if err := createTable(session, `ALTER TYPE gocql_test.update_field_udt ADD data text;`); err != nil {
  375. t.Fatal(err)
  376. }
  377. readCol := &col{}
  378. err = session.Query("SELECT udt_col FROM update_field WHERE id = ?", id).Scan(readCol)
  379. if err != nil {
  380. t.Fatal(err)
  381. }
  382. if *readCol != *writeCol {
  383. t.Errorf("expected %+v: got %+v", *writeCol, *readCol)
  384. }
  385. }
  386. func TestUDT_ScanNullUDT(t *testing.T) {
  387. if *flagProto < protoVersion3 {
  388. t.Skip("UDT are only available on protocol >= 3")
  389. }
  390. session := createSession(t)
  391. defer session.Close()
  392. err := createTable(session, `CREATE TYPE gocql_test.scan_null_udt_position(
  393. lat int,
  394. lon int,
  395. padding text);`)
  396. if err != nil {
  397. t.Fatal(err)
  398. }
  399. err = createTable(session, `CREATE TABLE gocql_test.scan_null_udt_houses(
  400. id int,
  401. name text,
  402. loc frozen<position>,
  403. primary key(id)
  404. );`)
  405. if err != nil {
  406. t.Fatal(err)
  407. }
  408. err = session.Query("INSERT INTO scan_null_udt_houses(id, name) VALUES(?, ?)", 1, "test").Exec()
  409. if err != nil {
  410. t.Fatal(err)
  411. }
  412. pos := &position{}
  413. err = session.Query("SELECT loc FROM scan_null_udt_houses WHERE id = ?", 1).Scan(pos)
  414. if err != nil {
  415. t.Fatal(err)
  416. }
  417. }