conn.go 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854
  1. package pq
  2. import (
  3. "bufio"
  4. "crypto/md5"
  5. "database/sql"
  6. "database/sql/driver"
  7. "encoding/binary"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "net"
  12. "os"
  13. "os/user"
  14. "path"
  15. "path/filepath"
  16. "strconv"
  17. "strings"
  18. "time"
  19. "unicode"
  20. "github.com/lib/pq/oid"
  21. )
  // Exported errors. These values are part of the package API; callers may
  // compare against them directly.
  var (
  	// ErrNotSupported signals a command the driver does not support.
  	ErrNotSupported = errors.New("pq: Unsupported command")
  	// ErrInFailedTransaction is returned (for example by Commit) when the
  	// server reports that the current transaction has already failed.
  	ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction")
  	// ErrSSLNotSupported reports that the server does not have SSL enabled.
  	ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
  	// ErrSSLKeyHasWorldPermissions reports a client private key file that is
  	// readable or writable by group or others.
  	ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less")
  	// ErrCouldNotDetectUsername reports that no default user name could be
  	// determined and none was supplied.
  	ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly")
  )

  // Internal errors.
  var (
  	// errUnexpectedReady: ReadyForQuery arrived without any result or error.
  	errUnexpectedReady = errors.New("unexpected ReadyForQuery")
  	// errNoRowsAffected / errNoLastInsertID are reported by results of an
  	// empty statement, which carry neither value.
  	errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
  	errNoLastInsertID = errors.New("no LastInsertId available after the empty statement")
  )
  33. // Driver is the Postgres database driver.
  34. type Driver struct{}
  35. // Open opens a new connection to the database. name is a connection string.
  36. // Most users should only use it through database/sql package from the standard
  37. // library.
  38. func (d *Driver) Open(name string) (driver.Conn, error) {
  39. return Open(name)
  40. }
  41. func init() {
  42. sql.Register("postgres", &Driver{})
  43. }
  // parameterStatus holds per-session values reported by the server
  // (presumably via ParameterStatus messages, handled outside this block).
  type parameterStatus struct {
  	// serverVersion uses the server_version_num encoding (e.g. 90600 for
  	// 9.6.0); it is 0 when the version is not known.
  	serverVersion int

  	// currentLocation is derived from the session's TimeZone setting when
  	// one is available, and nil otherwise.
  	currentLocation *time.Location
  }
  // transactionStatus is the backend transaction-status indicator byte, as
  // carried by the protocol's ReadyForQuery message.
  type transactionStatus byte

  const (
  	txnStatusIdle                transactionStatus = 'I' // not in a transaction block
  	txnStatusIdleInTransaction   transactionStatus = 'T' // inside a transaction block
  	txnStatusInFailedTransaction transactionStatus = 'E' // inside a failed transaction block
  )

  // String implements fmt.Stringer.
  //
  // An unrecognized status byte is rendered as "unknown transactionStatus N"
  // rather than panicking. The value is formatted into error messages such as
  // "unexpected transaction status %v", and a String method that panics would
  // garble exactly the diagnostic it is meant to produce.
  func (s transactionStatus) String() string {
  	switch s {
  	case txnStatusIdle:
  		return "idle"
  	case txnStatusIdleInTransaction:
  		return "idle in transaction"
  	case txnStatusInFailedTransaction:
  		return "in a failed transaction"
  	default:
  		return fmt.Sprintf("unknown transactionStatus %d", byte(s))
  	}
  }
  // Dialer is the dialer interface. It can be used to obtain more control over
  // how pq creates network connections.
  type Dialer interface {
  	// Dial connects to address on the named network.
  	Dial(network, address string) (net.Conn, error)
  	// DialTimeout is like Dial but gives up after timeout.
  	DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
  }

  // defaultDialer is the Dialer used by Open. It forwards directly to the net
  // package's dial functions.
  type defaultDialer struct{}

  // Dial forwards to net.Dial.
  func (defaultDialer) Dial(network, address string) (net.Conn, error) {
  	return net.Dial(network, address)
  }

  // DialTimeout forwards to net.DialTimeout.
  func (defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
  	return net.DialTimeout(network, address, timeout)
  }
  84. type conn struct {
  85. c net.Conn
  86. buf *bufio.Reader
  87. namei int
  88. scratch [512]byte
  89. txnStatus transactionStatus
  90. txnFinish func()
  91. // Save connection arguments to use during CancelRequest.
  92. dialer Dialer
  93. opts values
  94. // Cancellation key data for use with CancelRequest messages.
  95. processID int
  96. secretKey int
  97. parameterStatus parameterStatus
  98. saveMessageType byte
  99. saveMessageBuffer []byte
  100. // If true, this connection is bad and all public-facing functions should
  101. // return ErrBadConn.
  102. bad bool
  103. // If set, this connection should never use the binary format when
  104. // receiving query results from prepared statements. Only provided for
  105. // debugging.
  106. disablePreparedBinaryResult bool
  107. // Whether to always send []byte parameters over as binary. Enables single
  108. // round-trip mode for non-prepared Query calls.
  109. binaryParameters bool
  110. // If true this connection is in the middle of a COPY
  111. inCopy bool
  112. }
  113. // Handle driver-side settings in parsed connection string.
  114. func (cn *conn) handleDriverSettings(o values) (err error) {
  115. boolSetting := func(key string, val *bool) error {
  116. if value, ok := o[key]; ok {
  117. if value == "yes" {
  118. *val = true
  119. } else if value == "no" {
  120. *val = false
  121. } else {
  122. return fmt.Errorf("unrecognized value %q for %s", value, key)
  123. }
  124. }
  125. return nil
  126. }
  127. err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult)
  128. if err != nil {
  129. return err
  130. }
  131. return boolSetting("binary_parameters", &cn.binaryParameters)
  132. }
  133. func (cn *conn) handlePgpass(o values) {
  134. // if a password was supplied, do not process .pgpass
  135. if _, ok := o["password"]; ok {
  136. return
  137. }
  138. filename := os.Getenv("PGPASSFILE")
  139. if filename == "" {
  140. // XXX this code doesn't work on Windows where the default filename is
  141. // XXX %APPDATA%\postgresql\pgpass.conf
  142. // Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470
  143. userHome := os.Getenv("HOME")
  144. if userHome == "" {
  145. user, err := user.Current()
  146. if err != nil {
  147. return
  148. }
  149. userHome = user.HomeDir
  150. }
  151. filename = filepath.Join(userHome, ".pgpass")
  152. }
  153. fileinfo, err := os.Stat(filename)
  154. if err != nil {
  155. return
  156. }
  157. mode := fileinfo.Mode()
  158. if mode&(0x77) != 0 {
  159. // XXX should warn about incorrect .pgpass permissions as psql does
  160. return
  161. }
  162. file, err := os.Open(filename)
  163. if err != nil {
  164. return
  165. }
  166. defer file.Close()
  167. scanner := bufio.NewScanner(io.Reader(file))
  168. hostname := o["host"]
  169. ntw, _ := network(o)
  170. port := o["port"]
  171. db := o["dbname"]
  172. username := o["user"]
  173. // From: https://github.com/tg/pgpass/blob/master/reader.go
  174. getFields := func(s string) []string {
  175. fs := make([]string, 0, 5)
  176. f := make([]rune, 0, len(s))
  177. var esc bool
  178. for _, c := range s {
  179. switch {
  180. case esc:
  181. f = append(f, c)
  182. esc = false
  183. case c == '\\':
  184. esc = true
  185. case c == ':':
  186. fs = append(fs, string(f))
  187. f = f[:0]
  188. default:
  189. f = append(f, c)
  190. }
  191. }
  192. return append(fs, string(f))
  193. }
  194. for scanner.Scan() {
  195. line := scanner.Text()
  196. if len(line) == 0 || line[0] == '#' {
  197. continue
  198. }
  199. split := getFields(line)
  200. if len(split) != 5 {
  201. continue
  202. }
  203. if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) {
  204. o["password"] = split[4]
  205. return
  206. }
  207. }
  208. }
  209. func (cn *conn) writeBuf(b byte) *writeBuf {
  210. cn.scratch[0] = b
  211. return &writeBuf{
  212. buf: cn.scratch[:5],
  213. pos: 1,
  214. }
  215. }
  216. // Open opens a new connection to the database. name is a connection string.
  217. // Most users should only use it through database/sql package from the standard
  218. // library.
  219. func Open(name string) (_ driver.Conn, err error) {
  220. return DialOpen(defaultDialer{}, name)
  221. }
  222. // DialOpen opens a new connection to the database using a dialer.
  223. func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
  224. // Handle any panics during connection initialization. Note that we
  225. // specifically do *not* want to use errRecover(), as that would turn any
  226. // connection errors into ErrBadConns, hiding the real error message from
  227. // the user.
  228. defer errRecoverNoErrBadConn(&err)
  229. o := make(values)
  230. // A number of defaults are applied here, in this order:
  231. //
  232. // * Very low precedence defaults applied in every situation
  233. // * Environment variables
  234. // * Explicitly passed connection information
  235. o["host"] = "localhost"
  236. o["port"] = "5432"
  237. // N.B.: Extra float digits should be set to 3, but that breaks
  238. // Postgres 8.4 and older, where the max is 2.
  239. o["extra_float_digits"] = "2"
  240. for k, v := range parseEnviron(os.Environ()) {
  241. o[k] = v
  242. }
  243. if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") {
  244. name, err = ParseURL(name)
  245. if err != nil {
  246. return nil, err
  247. }
  248. }
  249. if err := parseOpts(name, o); err != nil {
  250. return nil, err
  251. }
  252. // Use the "fallback" application name if necessary
  253. if fallback, ok := o["fallback_application_name"]; ok {
  254. if _, ok := o["application_name"]; !ok {
  255. o["application_name"] = fallback
  256. }
  257. }
  258. // We can't work with any client_encoding other than UTF-8 currently.
  259. // However, we have historically allowed the user to set it to UTF-8
  260. // explicitly, and there's no reason to break such programs, so allow that.
  261. // Note that the "options" setting could also set client_encoding, but
  262. // parsing its value is not worth it. Instead, we always explicitly send
  263. // client_encoding as a separate run-time parameter, which should override
  264. // anything set in options.
  265. if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) {
  266. return nil, errors.New("client_encoding must be absent or 'UTF8'")
  267. }
  268. o["client_encoding"] = "UTF8"
  269. // DateStyle needs a similar treatment.
  270. if datestyle, ok := o["datestyle"]; ok {
  271. if datestyle != "ISO, MDY" {
  272. panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v",
  273. "ISO, MDY", datestyle))
  274. }
  275. } else {
  276. o["datestyle"] = "ISO, MDY"
  277. }
  278. // If a user is not provided by any other means, the last
  279. // resort is to use the current operating system provided user
  280. // name.
  281. if _, ok := o["user"]; !ok {
  282. u, err := userCurrent()
  283. if err != nil {
  284. return nil, err
  285. }
  286. o["user"] = u
  287. }
  288. cn := &conn{
  289. opts: o,
  290. dialer: d,
  291. }
  292. err = cn.handleDriverSettings(o)
  293. if err != nil {
  294. return nil, err
  295. }
  296. cn.handlePgpass(o)
  297. cn.c, err = dial(d, o)
  298. if err != nil {
  299. return nil, err
  300. }
  301. err = cn.ssl(o)
  302. if err != nil {
  303. return nil, err
  304. }
  305. // cn.startup panics on error. Make sure we don't leak cn.c.
  306. panicking := true
  307. defer func() {
  308. if panicking {
  309. cn.c.Close()
  310. }
  311. }()
  312. cn.buf = bufio.NewReader(cn.c)
  313. cn.startup(o)
  314. // reset the deadline, in case one was set (see dial)
  315. if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
  316. err = cn.c.SetDeadline(time.Time{})
  317. }
  318. panicking = false
  319. return cn, err
  320. }
  321. func dial(d Dialer, o values) (net.Conn, error) {
  322. ntw, addr := network(o)
  323. // SSL is not necessary or supported over UNIX domain sockets
  324. if ntw == "unix" {
  325. o["sslmode"] = "disable"
  326. }
  327. // Zero or not specified means wait indefinitely.
  328. if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
  329. seconds, err := strconv.ParseInt(timeout, 10, 0)
  330. if err != nil {
  331. return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
  332. }
  333. duration := time.Duration(seconds) * time.Second
  334. // connect_timeout should apply to the entire connection establishment
  335. // procedure, so we both use a timeout for the TCP connection
  336. // establishment and set a deadline for doing the initial handshake.
  337. // The deadline is then reset after startup() is done.
  338. deadline := time.Now().Add(duration)
  339. conn, err := d.DialTimeout(ntw, addr, duration)
  340. if err != nil {
  341. return nil, err
  342. }
  343. err = conn.SetDeadline(deadline)
  344. return conn, err
  345. }
  346. return d.Dial(ntw, addr)
  347. }
  // network returns the network type and address to dial for the connection
  // options o.
  //
  // A host beginning with "/" names a directory that holds a Unix-domain
  // socket called .s.PGSQL.<port>. Any other host is dialed over TCP at
  // host:port, with IPv6 literals bracketed by net.JoinHostPort.
  func network(o values) (string, string) {
  	host, port := o["host"], o["port"]
  	if !strings.HasPrefix(host, "/") {
  		return "tcp", net.JoinHostPort(host, port)
  	}
  	return "unix", path.Join(host, ".s.PGSQL."+port)
  }

  // values maps connection option names (for example "host", "port",
  // "dbname") to their string values.
  type values map[string]string
  // scanner tokenizes libpq-style option strings one rune at a time.
  type scanner struct {
  	s []rune // the whole input, decoded into runes
  	i int    // index of the next rune to hand out
  }

  // newScanner returns a scanner positioned at the start of s.
  func newScanner(s string) *scanner {
  	return &scanner{s: []rune(s)}
  }

  // Next consumes and returns the next rune.
  // At the end of the input it returns 0, false.
  func (s *scanner) Next() (rune, bool) {
  	if s.i < len(s.s) {
  		r := s.s[s.i]
  		s.i++
  		return r, true
  	}
  	return 0, false
  }

  // SkipSpaces consumes whitespace and then returns (and consumes) the first
  // non-whitespace rune. At the end of the input it returns 0, false.
  func (s *scanner) SkipSpaces() (rune, bool) {
  	for {
  		r, ok := s.Next()
  		if !ok || !unicode.IsSpace(r) {
  			return r, ok
  		}
  	}
  }
  385. // parseOpts parses the options from name and adds them to the values.
  386. //
  387. // The parsing code is based on conninfo_parse from libpq's fe-connect.c
  388. func parseOpts(name string, o values) error {
  389. s := newScanner(name)
  390. for {
  391. var (
  392. keyRunes, valRunes []rune
  393. r rune
  394. ok bool
  395. )
  396. if r, ok = s.SkipSpaces(); !ok {
  397. break
  398. }
  399. // Scan the key
  400. for !unicode.IsSpace(r) && r != '=' {
  401. keyRunes = append(keyRunes, r)
  402. if r, ok = s.Next(); !ok {
  403. break
  404. }
  405. }
  406. // Skip any whitespace if we're not at the = yet
  407. if r != '=' {
  408. r, ok = s.SkipSpaces()
  409. }
  410. // The current character should be =
  411. if r != '=' || !ok {
  412. return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
  413. }
  414. // Skip any whitespace after the =
  415. if r, ok = s.SkipSpaces(); !ok {
  416. // If we reach the end here, the last value is just an empty string as per libpq.
  417. o[string(keyRunes)] = ""
  418. break
  419. }
  420. if r != '\'' {
  421. for !unicode.IsSpace(r) {
  422. if r == '\\' {
  423. if r, ok = s.Next(); !ok {
  424. return fmt.Errorf(`missing character after backslash`)
  425. }
  426. }
  427. valRunes = append(valRunes, r)
  428. if r, ok = s.Next(); !ok {
  429. break
  430. }
  431. }
  432. } else {
  433. quote:
  434. for {
  435. if r, ok = s.Next(); !ok {
  436. return fmt.Errorf(`unterminated quoted string literal in connection string`)
  437. }
  438. switch r {
  439. case '\'':
  440. break quote
  441. case '\\':
  442. r, _ = s.Next()
  443. fallthrough
  444. default:
  445. valRunes = append(valRunes, r)
  446. }
  447. }
  448. }
  449. o[string(keyRunes)] = string(valRunes)
  450. }
  451. return nil
  452. }
  453. func (cn *conn) isInTransaction() bool {
  454. return cn.txnStatus == txnStatusIdleInTransaction ||
  455. cn.txnStatus == txnStatusInFailedTransaction
  456. }
  457. func (cn *conn) checkIsInTransaction(intxn bool) {
  458. if cn.isInTransaction() != intxn {
  459. cn.bad = true
  460. errorf("unexpected transaction status %v", cn.txnStatus)
  461. }
  462. }
  463. func (cn *conn) Begin() (_ driver.Tx, err error) {
  464. return cn.begin("")
  465. }
  466. func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
  467. if cn.bad {
  468. return nil, driver.ErrBadConn
  469. }
  470. defer cn.errRecover(&err)
  471. cn.checkIsInTransaction(false)
  472. _, commandTag, err := cn.simpleExec("BEGIN" + mode)
  473. if err != nil {
  474. return nil, err
  475. }
  476. if commandTag != "BEGIN" {
  477. cn.bad = true
  478. return nil, fmt.Errorf("unexpected command tag %s", commandTag)
  479. }
  480. if cn.txnStatus != txnStatusIdleInTransaction {
  481. cn.bad = true
  482. return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
  483. }
  484. return cn, nil
  485. }
  486. func (cn *conn) closeTxn() {
  487. if finish := cn.txnFinish; finish != nil {
  488. finish()
  489. }
  490. }
  491. func (cn *conn) Commit() (err error) {
  492. defer cn.closeTxn()
  493. if cn.bad {
  494. return driver.ErrBadConn
  495. }
  496. defer cn.errRecover(&err)
  497. cn.checkIsInTransaction(true)
  498. // We don't want the client to think that everything is okay if it tries
  499. // to commit a failed transaction. However, no matter what we return,
  500. // database/sql will release this connection back into the free connection
  501. // pool so we have to abort the current transaction here. Note that you
  502. // would get the same behaviour if you issued a COMMIT in a failed
  503. // transaction, so it's also the least surprising thing to do here.
  504. if cn.txnStatus == txnStatusInFailedTransaction {
  505. if err := cn.Rollback(); err != nil {
  506. return err
  507. }
  508. return ErrInFailedTransaction
  509. }
  510. _, commandTag, err := cn.simpleExec("COMMIT")
  511. if err != nil {
  512. if cn.isInTransaction() {
  513. cn.bad = true
  514. }
  515. return err
  516. }
  517. if commandTag != "COMMIT" {
  518. cn.bad = true
  519. return fmt.Errorf("unexpected command tag %s", commandTag)
  520. }
  521. cn.checkIsInTransaction(false)
  522. return nil
  523. }
  524. func (cn *conn) Rollback() (err error) {
  525. defer cn.closeTxn()
  526. if cn.bad {
  527. return driver.ErrBadConn
  528. }
  529. defer cn.errRecover(&err)
  530. cn.checkIsInTransaction(true)
  531. _, commandTag, err := cn.simpleExec("ROLLBACK")
  532. if err != nil {
  533. if cn.isInTransaction() {
  534. cn.bad = true
  535. }
  536. return err
  537. }
  538. if commandTag != "ROLLBACK" {
  539. return fmt.Errorf("unexpected command tag %s", commandTag)
  540. }
  541. cn.checkIsInTransaction(false)
  542. return nil
  543. }
  544. func (cn *conn) gname() string {
  545. cn.namei++
  546. return strconv.FormatInt(int64(cn.namei), 10)
  547. }
  548. func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
  549. b := cn.writeBuf('Q')
  550. b.string(q)
  551. cn.send(b)
  552. for {
  553. t, r := cn.recv1()
  554. switch t {
  555. case 'C':
  556. res, commandTag = cn.parseComplete(r.string())
  557. case 'Z':
  558. cn.processReadyForQuery(r)
  559. if res == nil && err == nil {
  560. err = errUnexpectedReady
  561. }
  562. // done
  563. return
  564. case 'E':
  565. err = parseError(r)
  566. case 'I':
  567. res = emptyRows
  568. case 'T', 'D':
  569. // ignore any results
  570. default:
  571. cn.bad = true
  572. errorf("unknown response for simple query: %q", t)
  573. }
  574. }
  575. }
  576. func (cn *conn) simpleQuery(q string) (res *rows, err error) {
  577. defer cn.errRecover(&err)
  578. b := cn.writeBuf('Q')
  579. b.string(q)
  580. cn.send(b)
  581. for {
  582. t, r := cn.recv1()
  583. switch t {
  584. case 'C', 'I':
  585. // We allow queries which don't return any results through Query as
  586. // well as Exec. We still have to give database/sql a rows object
  587. // the user can close, though, to avoid connections from being
  588. // leaked. A "rows" with done=true works fine for that purpose.
  589. if err != nil {
  590. cn.bad = true
  591. errorf("unexpected message %q in simple query execution", t)
  592. }
  593. if res == nil {
  594. res = &rows{
  595. cn: cn,
  596. }
  597. }
  598. // Set the result and tag to the last command complete if there wasn't a
  599. // query already run. Although queries usually return from here and cede
  600. // control to Next, a query with zero results does not.
  601. if t == 'C' && res.colNames == nil {
  602. res.result, res.tag = cn.parseComplete(r.string())
  603. }
  604. res.done = true
  605. case 'Z':
  606. cn.processReadyForQuery(r)
  607. // done
  608. return
  609. case 'E':
  610. res = nil
  611. err = parseError(r)
  612. case 'D':
  613. if res == nil {
  614. cn.bad = true
  615. errorf("unexpected DataRow in simple query execution")
  616. }
  617. // the query didn't fail; kick off to Next
  618. cn.saveMessage(t, r)
  619. return
  620. case 'T':
  621. // res might be non-nil here if we received a previous
  622. // CommandComplete, but that's fine; just overwrite it
  623. res = &rows{cn: cn}
  624. res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)
  625. // To work around a bug in QueryRow in Go 1.2 and earlier, wait
  626. // until the first DataRow has been received.
  627. default:
  628. cn.bad = true
  629. errorf("unknown response for simple query: %q", t)
  630. }
  631. }
  632. }
  633. type noRows struct{}
  634. var emptyRows noRows
  635. var _ driver.Result = noRows{}
  636. func (noRows) LastInsertId() (int64, error) {
  637. return 0, errNoLastInsertID
  638. }
  639. func (noRows) RowsAffected() (int64, error) {
  640. return 0, errNoRowsAffected
  641. }
  642. // Decides which column formats to use for a prepared statement. The input is
  643. // an array of type oids, one element per result column.
  644. func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) {
  645. if len(colTyps) == 0 {
  646. return nil, colFmtDataAllText
  647. }
  648. colFmts = make([]format, len(colTyps))
  649. if forceText {
  650. return colFmts, colFmtDataAllText
  651. }
  652. allBinary := true
  653. allText := true
  654. for i, t := range colTyps {
  655. switch t.OID {
  656. // This is the list of types to use binary mode for when receiving them
  657. // through a prepared statement. If a type appears in this list, it
  658. // must also be implemented in binaryDecode in encode.go.
  659. case oid.T_bytea:
  660. fallthrough
  661. case oid.T_int8:
  662. fallthrough
  663. case oid.T_int4:
  664. fallthrough
  665. case oid.T_int2:
  666. fallthrough
  667. case oid.T_uuid:
  668. colFmts[i] = formatBinary
  669. allText = false
  670. default:
  671. allBinary = false
  672. }
  673. }
  674. if allBinary {
  675. return colFmts, colFmtDataAllBinary
  676. } else if allText {
  677. return colFmts, colFmtDataAllText
  678. } else {
  679. colFmtData = make([]byte, 2+len(colFmts)*2)
  680. binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts)))
  681. for i, v := range colFmts {
  682. binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v))
  683. }
  684. return colFmts, colFmtData
  685. }
  686. }
  687. func (cn *conn) prepareTo(q, stmtName string) *stmt {
  688. st := &stmt{cn: cn, name: stmtName}
  689. b := cn.writeBuf('P')
  690. b.string(st.name)
  691. b.string(q)
  692. b.int16(0)
  693. b.next('D')
  694. b.byte('S')
  695. b.string(st.name)
  696. b.next('S')
  697. cn.send(b)
  698. cn.readParseResponse()
  699. st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
  700. st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult)
  701. cn.readReadyForQuery()
  702. return st
  703. }
  704. func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
  705. if cn.bad {
  706. return nil, driver.ErrBadConn
  707. }
  708. defer cn.errRecover(&err)
  709. if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") {
  710. s, err := cn.prepareCopyIn(q)
  711. if err == nil {
  712. cn.inCopy = true
  713. }
  714. return s, err
  715. }
  716. return cn.prepareTo(q, cn.gname()), nil
  717. }
  718. func (cn *conn) Close() (err error) {
  719. // Skip cn.bad return here because we always want to close a connection.
  720. defer cn.errRecover(&err)
  721. // Ensure that cn.c.Close is always run. Since error handling is done with
  722. // panics and cn.errRecover, the Close must be in a defer.
  723. defer func() {
  724. cerr := cn.c.Close()
  725. if err == nil {
  726. err = cerr
  727. }
  728. }()
  729. // Don't go through send(); ListenerConn relies on us not scribbling on the
  730. // scratch buffer of this connection.
  731. return cn.sendSimpleMessage('X')
  732. }
  733. // Implement the "Queryer" interface
  734. func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
  735. return cn.query(query, args)
  736. }
  737. func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
  738. if cn.bad {
  739. return nil, driver.ErrBadConn
  740. }
  741. if cn.inCopy {
  742. return nil, errCopyInProgress
  743. }
  744. defer cn.errRecover(&err)
  745. // Check to see if we can use the "simpleQuery" interface, which is
  746. // *much* faster than going through prepare/exec
  747. if len(args) == 0 {
  748. return cn.simpleQuery(query)
  749. }
  750. if cn.binaryParameters {
  751. cn.sendBinaryModeQuery(query, args)
  752. cn.readParseResponse()
  753. cn.readBindResponse()
  754. rows := &rows{cn: cn}
  755. rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse()
  756. cn.postExecuteWorkaround()
  757. return rows, nil
  758. }
  759. st := cn.prepareTo(query, "")
  760. st.exec(args)
  761. return &rows{
  762. cn: cn,
  763. colNames: st.colNames,
  764. colTyps: st.colTyps,
  765. colFmts: st.colFmts,
  766. }, nil
  767. }
  768. // Implement the optional "Execer" interface for one-shot queries
  769. func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
  770. if cn.bad {
  771. return nil, driver.ErrBadConn
  772. }
  773. defer cn.errRecover(&err)
  774. // Check to see if we can use the "simpleExec" interface, which is
  775. // *much* faster than going through prepare/exec
  776. if len(args) == 0 {
  777. // ignore commandTag, our caller doesn't care
  778. r, _, err := cn.simpleExec(query)
  779. return r, err
  780. }
  781. if cn.binaryParameters {
  782. cn.sendBinaryModeQuery(query, args)
  783. cn.readParseResponse()
  784. cn.readBindResponse()
  785. cn.readPortalDescribeResponse()
  786. cn.postExecuteWorkaround()
  787. res, _, err = cn.readExecuteResponse("Execute")
  788. return res, err
  789. }
  790. // Use the unnamed statement to defer planning until bind
  791. // time, or else value-based selectivity estimates cannot be
  792. // used.
  793. st := cn.prepareTo(query, "")
  794. r, err := st.Exec(args)
  795. if err != nil {
  796. panic(err)
  797. }
  798. return r, err
  799. }
  800. func (cn *conn) send(m *writeBuf) {
  801. _, err := cn.c.Write(m.wrap())
  802. if err != nil {
  803. panic(err)
  804. }
  805. }
  806. func (cn *conn) sendStartupPacket(m *writeBuf) error {
  807. _, err := cn.c.Write((m.wrap())[1:])
  808. return err
  809. }
  810. // Send a message of type typ to the server on the other end of cn. The
  811. // message should have no payload. This method does not use the scratch
  812. // buffer.
  813. func (cn *conn) sendSimpleMessage(typ byte) (err error) {
  814. _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'})
  815. return err
  816. }
  817. // saveMessage memorizes a message and its buffer in the conn struct.
  818. // recvMessage will then return these values on the next call to it. This
  819. // method is useful in cases where you have to see what the next message is
  820. // going to be (e.g. to see whether it's an error or not) but you can't handle
  821. // the message yourself.
  822. func (cn *conn) saveMessage(typ byte, buf *readBuf) {
  823. if cn.saveMessageType != 0 {
  824. cn.bad = true
  825. errorf("unexpected saveMessageType %d", cn.saveMessageType)
  826. }
  827. cn.saveMessageType = typ
  828. cn.saveMessageBuffer = *buf
  829. }
  830. // recvMessage receives any message from the backend, or returns an error if
  831. // a problem occurred while reading the message.
  832. func (cn *conn) recvMessage(r *readBuf) (byte, error) {
  833. // workaround for a QueryRow bug, see exec
  834. if cn.saveMessageType != 0 {
  835. t := cn.saveMessageType
  836. *r = cn.saveMessageBuffer
  837. cn.saveMessageType = 0
  838. cn.saveMessageBuffer = nil
  839. return t, nil
  840. }
  841. x := cn.scratch[:5]
  842. _, err := io.ReadFull(cn.buf, x)
  843. if err != nil {
  844. return 0, err
  845. }
  846. // read the type and length of the message that follows
  847. t := x[0]
  848. n := int(binary.BigEndian.Uint32(x[1:])) - 4
  849. var y []byte
  850. if n <= len(cn.scratch) {
  851. y = cn.scratch[:n]
  852. } else {
  853. y = make([]byte, n)
  854. }
  855. _, err = io.ReadFull(cn.buf, y)
  856. if err != nil {
  857. return 0, err
  858. }
  859. *r = y
  860. return t, nil
  861. }
  862. // recv receives a message from the backend, but if an error happened while
  863. // reading the message or the received message was an ErrorResponse, it panics.
  864. // NoticeResponses are ignored. This function should generally be used only
  865. // during the startup sequence.
  866. func (cn *conn) recv() (t byte, r *readBuf) {
  867. for {
  868. var err error
  869. r = &readBuf{}
  870. t, err = cn.recvMessage(r)
  871. if err != nil {
  872. panic(err)
  873. }
  874. switch t {
  875. case 'E':
  876. panic(parseError(r))
  877. case 'N':
  878. // ignore
  879. default:
  880. return
  881. }
  882. }
  883. }
  884. // recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by
  885. // the caller to avoid an allocation.
  886. func (cn *conn) recv1Buf(r *readBuf) byte {
  887. for {
  888. t, err := cn.recvMessage(r)
  889. if err != nil {
  890. panic(err)
  891. }
  892. switch t {
  893. case 'A', 'N':
  894. // ignore
  895. case 'S':
  896. cn.processParameterStatus(r)
  897. default:
  898. return t
  899. }
  900. }
  901. }
  902. // recv1 receives a message from the backend, panicking if an error occurs
  903. // while attempting to read it. All asynchronous messages are ignored, with
  904. // the exception of ErrorResponse.
  905. func (cn *conn) recv1() (t byte, r *readBuf) {
  906. r = &readBuf{}
  907. t = cn.recv1Buf(r)
  908. return t, r
  909. }
  910. func (cn *conn) ssl(o values) error {
  911. upgrade, err := ssl(o)
  912. if err != nil {
  913. return err
  914. }
  915. if upgrade == nil {
  916. // Nothing to do
  917. return nil
  918. }
  919. w := cn.writeBuf(0)
  920. w.int32(80877103)
  921. if err = cn.sendStartupPacket(w); err != nil {
  922. return err
  923. }
  924. b := cn.scratch[:1]
  925. _, err = io.ReadFull(cn.c, b)
  926. if err != nil {
  927. return err
  928. }
  929. if b[0] != 'S' {
  930. return ErrSSLNotSupported
  931. }
  932. cn.c, err = upgrade(cn.c)
  933. return err
  934. }
  935. // isDriverSetting returns true iff a setting is purely for configuring the
  936. // driver's options and should not be sent to the server in the connection
  937. // startup packet.
  938. func isDriverSetting(key string) bool {
  939. switch key {
  940. case "host", "port":
  941. return true
  942. case "password":
  943. return true
  944. case "sslmode", "sslcert", "sslkey", "sslrootcert":
  945. return true
  946. case "fallback_application_name":
  947. return true
  948. case "connect_timeout":
  949. return true
  950. case "disable_prepared_binary_result":
  951. return true
  952. case "binary_parameters":
  953. return true
  954. default:
  955. return false
  956. }
  957. }
  958. func (cn *conn) startup(o values) {
  959. w := cn.writeBuf(0)
  960. w.int32(196608)
  961. // Send the backend the name of the database we want to connect to, and the
  962. // user we want to connect as. Additionally, we send over any run-time
  963. // parameters potentially included in the connection string. If the server
  964. // doesn't recognize any of them, it will reply with an error.
  965. for k, v := range o {
  966. if isDriverSetting(k) {
  967. // skip options which can't be run-time parameters
  968. continue
  969. }
  970. // The protocol requires us to supply the database name as "database"
  971. // instead of "dbname".
  972. if k == "dbname" {
  973. k = "database"
  974. }
  975. w.string(k)
  976. w.string(v)
  977. }
  978. w.string("")
  979. if err := cn.sendStartupPacket(w); err != nil {
  980. panic(err)
  981. }
  982. for {
  983. t, r := cn.recv()
  984. switch t {
  985. case 'K':
  986. cn.processBackendKeyData(r)
  987. case 'S':
  988. cn.processParameterStatus(r)
  989. case 'R':
  990. cn.auth(r, o)
  991. case 'Z':
  992. cn.processReadyForQuery(r)
  993. return
  994. default:
  995. errorf("unknown response for startup: %q", t)
  996. }
  997. }
  998. }
  999. func (cn *conn) auth(r *readBuf, o values) {
  1000. switch code := r.int32(); code {
  1001. case 0:
  1002. // OK
  1003. case 3:
  1004. w := cn.writeBuf('p')
  1005. w.string(o["password"])
  1006. cn.send(w)
  1007. t, r := cn.recv()
  1008. if t != 'R' {
  1009. errorf("unexpected password response: %q", t)
  1010. }
  1011. if r.int32() != 0 {
  1012. errorf("unexpected authentication response: %q", t)
  1013. }
  1014. case 5:
  1015. s := string(r.next(4))
  1016. w := cn.writeBuf('p')
  1017. w.string("md5" + md5s(md5s(o["password"]+o["user"])+s))
  1018. cn.send(w)
  1019. t, r := cn.recv()
  1020. if t != 'R' {
  1021. errorf("unexpected password response: %q", t)
  1022. }
  1023. if r.int32() != 0 {
  1024. errorf("unexpected authentication response: %q", t)
  1025. }
  1026. default:
  1027. errorf("unknown authentication response: %d", code)
  1028. }
  1029. }
  1030. type format int
  1031. const formatText format = 0
  1032. const formatBinary format = 1
  1033. // One result-column format code with the value 1 (i.e. all binary).
  1034. var colFmtDataAllBinary = []byte{0, 1, 0, 1}
  1035. // No result-column format codes (i.e. all text).
  1036. var colFmtDataAllText = []byte{0, 0}
// stmt is a statement prepared on cn. It provides Close, NumInput, Exec and
// Query.
type stmt struct {
	cn *conn
	// name is the server-side statement name; "" is the unnamed statement
	// (used by conn.query and conn.Exec via prepareTo(query, "")).
	name string
	// colNames and colFmts describe the result columns; stmt.Query copies
	// them into the rows it returns.
	colNames []string
	colFmts  []format
	// colFmtData is the encoded result-column format-code list that exec
	// appends to each Bind message.
	colFmtData []byte
	// colTyps holds the result column type descriptions.
	colTyps []fieldDesc
	// paramTyps holds the parameter type OIDs; its length is NumInput and
	// must match the number of arguments passed to exec.
	paramTyps []oid.Oid
	// closed is set once Close has received CloseComplete ('3') for this
	// statement.
	closed bool
}
  1047. func (st *stmt) Close() (err error) {
  1048. if st.closed {
  1049. return nil
  1050. }
  1051. if st.cn.bad {
  1052. return driver.ErrBadConn
  1053. }
  1054. defer st.cn.errRecover(&err)
  1055. w := st.cn.writeBuf('C')
  1056. w.byte('S')
  1057. w.string(st.name)
  1058. st.cn.send(w)
  1059. st.cn.send(st.cn.writeBuf('S'))
  1060. t, _ := st.cn.recv1()
  1061. if t != '3' {
  1062. st.cn.bad = true
  1063. errorf("unexpected close response: %q", t)
  1064. }
  1065. st.closed = true
  1066. t, r := st.cn.recv1()
  1067. if t != 'Z' {
  1068. st.cn.bad = true
  1069. errorf("expected ready for query, but got: %q", t)
  1070. }
  1071. st.cn.processReadyForQuery(r)
  1072. return nil
  1073. }
  1074. func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
  1075. if st.cn.bad {
  1076. return nil, driver.ErrBadConn
  1077. }
  1078. defer st.cn.errRecover(&err)
  1079. st.exec(v)
  1080. return &rows{
  1081. cn: st.cn,
  1082. colNames: st.colNames,
  1083. colTyps: st.colTyps,
  1084. colFmts: st.colFmts,
  1085. }, nil
  1086. }
  1087. func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
  1088. if st.cn.bad {
  1089. return nil, driver.ErrBadConn
  1090. }
  1091. defer st.cn.errRecover(&err)
  1092. st.exec(v)
  1093. res, _, err = st.cn.readExecuteResponse("simple query")
  1094. return res, err
  1095. }
  1096. func (st *stmt) exec(v []driver.Value) {
  1097. if len(v) >= 65536 {
  1098. errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v))
  1099. }
  1100. if len(v) != len(st.paramTyps) {
  1101. errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps))
  1102. }
  1103. cn := st.cn
  1104. w := cn.writeBuf('B')
  1105. w.byte(0) // unnamed portal
  1106. w.string(st.name)
  1107. if cn.binaryParameters {
  1108. cn.sendBinaryParameters(w, v)
  1109. } else {
  1110. w.int16(0)
  1111. w.int16(len(v))
  1112. for i, x := range v {
  1113. if x == nil {
  1114. w.int32(-1)
  1115. } else {
  1116. b := encode(&cn.parameterStatus, x, st.paramTyps[i])
  1117. w.int32(len(b))
  1118. w.bytes(b)
  1119. }
  1120. }
  1121. }
  1122. w.bytes(st.colFmtData)
  1123. w.next('E')
  1124. w.byte(0)
  1125. w.int32(0)
  1126. w.next('S')
  1127. cn.send(w)
  1128. cn.readBindResponse()
  1129. cn.postExecuteWorkaround()
  1130. }
  1131. func (st *stmt) NumInput() int {
  1132. return len(st.paramTyps)
  1133. }
  1134. // parseComplete parses the "command tag" from a CommandComplete message, and
  1135. // returns the number of rows affected (if applicable) and a string
  1136. // identifying only the command that was executed, e.g. "ALTER TABLE". If the
  1137. // command tag could not be parsed, parseComplete panics.
  1138. func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
  1139. commandsWithAffectedRows := []string{
  1140. "SELECT ",
  1141. // INSERT is handled below
  1142. "UPDATE ",
  1143. "DELETE ",
  1144. "FETCH ",
  1145. "MOVE ",
  1146. "COPY ",
  1147. }
  1148. var affectedRows *string
  1149. for _, tag := range commandsWithAffectedRows {
  1150. if strings.HasPrefix(commandTag, tag) {
  1151. t := commandTag[len(tag):]
  1152. affectedRows = &t
  1153. commandTag = tag[:len(tag)-1]
  1154. break
  1155. }
  1156. }
  1157. // INSERT also includes the oid of the inserted row in its command tag.
  1158. // Oids in user tables are deprecated, and the oid is only returned when
  1159. // exactly one row is inserted, so it's unlikely to be of value to any
  1160. // real-world application and we can ignore it.
  1161. if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") {
  1162. parts := strings.Split(commandTag, " ")
  1163. if len(parts) != 3 {
  1164. cn.bad = true
  1165. errorf("unexpected INSERT command tag %s", commandTag)
  1166. }
  1167. affectedRows = &parts[len(parts)-1]
  1168. commandTag = "INSERT"
  1169. }
  1170. // There should be no affected rows attached to the tag, just return it
  1171. if affectedRows == nil {
  1172. return driver.RowsAffected(0), commandTag
  1173. }
  1174. n, err := strconv.ParseInt(*affectedRows, 10, 64)
  1175. if err != nil {
  1176. cn.bad = true
  1177. errorf("could not parse commandTag: %s", err)
  1178. }
  1179. return driver.RowsAffected(n), commandTag
  1180. }
// rows streams query results from cn; it is the driver.Rows value produced by
// conn.query and stmt.Query. Messages are read on demand by Next.
type rows struct {
	cn *conn
	// finish, if non-nil, is called (deferred) when Close returns.
	finish func()
	// colNames, colTyps and colFmts describe the current result set's
	// columns; Next replaces them when a RowDescription ('T') arrives.
	colNames []string
	colTyps  []fieldDesc
	colFmts  []format
	// done is set once Next has consumed ReadyForQuery ('Z').
	done bool
	// rb is the receive buffer Next reuses for every message.
	rb readBuf
	// result and tag are parsed from the most recent CommandComplete ('C').
	result driver.Result
	tag    string
}
  1192. func (rs *rows) Close() error {
  1193. if finish := rs.finish; finish != nil {
  1194. defer finish()
  1195. }
  1196. // no need to look at cn.bad as Next() will
  1197. for {
  1198. err := rs.Next(nil)
  1199. switch err {
  1200. case nil:
  1201. case io.EOF:
  1202. // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row
  1203. // description, used with HasNextResultSet). We need to fetch messages until
  1204. // we hit a 'Z', which is done by waiting for done to be set.
  1205. if rs.done {
  1206. return nil
  1207. }
  1208. default:
  1209. return err
  1210. }
  1211. }
  1212. }
  1213. func (rs *rows) Columns() []string {
  1214. return rs.colNames
  1215. }
  1216. func (rs *rows) Result() driver.Result {
  1217. if rs.result == nil {
  1218. return emptyRows
  1219. }
  1220. return rs.result
  1221. }
  1222. func (rs *rows) Tag() string {
  1223. return rs.tag
  1224. }
// Next reads the next row of the current result set into dest.
//
// Messages are consumed until one of the following arrives:
//   - DataRow ('D'): up to len(dest) columns are decoded into dest (NULL
//     becomes nil) and Next returns nil.
//   - RowDescription ('T'): the column metadata is replaced and io.EOF is
//     returned, while rs.done stays false (a further result set follows).
//   - ReadyForQuery ('Z'): rs.done is set and io.EOF is returned, or the
//     error from an earlier ErrorResponse if one was seen.
//
// CommandComplete ('C') updates rs.result and rs.tag; EmptyQueryResponse
// ('I') is skipped. An ErrorResponse ('E') is remembered and reading continues
// until 'Z', so the connection is left ready for the next query.
func (rs *rows) Next(dest []driver.Value) (err error) {
	if rs.done {
		return io.EOF
	}

	conn := rs.cn
	if conn.bad {
		return driver.ErrBadConn
	}
	defer conn.errRecover(&err)

	for {
		t := conn.recv1Buf(&rs.rb)
		switch t {
		case 'E':
			err = parseError(&rs.rb)
		case 'C', 'I':
			if t == 'C' {
				rs.result, rs.tag = conn.parseComplete(rs.rb.string())
			}
			continue
		case 'Z':
			conn.processReadyForQuery(&rs.rb)
			rs.done = true
			if err != nil {
				return err
			}
			return io.EOF
		case 'D':
			n := rs.rb.int16()
			// A DataRow after an ErrorResponse means the message stream is out
			// of sync; mark the connection bad.
			if err != nil {
				conn.bad = true
				errorf("unexpected DataRow after error %s", err)
			}
			// Decode at most n columns; entries of dest beyond n are untouched.
			if n < len(dest) {
				dest = dest[:n]
			}
			for i := range dest {
				l := rs.rb.int32()
				if l == -1 {
					dest[i] = nil
					continue
				}
				dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i])
			}
			return
		case 'T':
			rs.colNames, rs.colFmts, rs.colTyps = parsePortalRowDescribe(&rs.rb)
			return io.EOF
		default:
			errorf("unexpected message after execute: %q", t)
		}
	}
}
  1277. func (rs *rows) HasNextResultSet() bool {
  1278. return !rs.done
  1279. }
  1280. func (rs *rows) NextResultSet() error {
  1281. return nil
  1282. }
  1283. // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
  1284. // used as part of an SQL statement. For example:
  1285. //
  1286. // tblname := "my_table"
  1287. // data := "my_data"
  1288. // quoted := pq.QuoteIdentifier(tblname)
  1289. // err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
  1290. //
  1291. // Any double quotes in name will be escaped. The quoted identifier will be
  1292. // case sensitive when used in a query. If the input string contains a zero
  1293. // byte, the result will be truncated immediately before it.
  1294. func QuoteIdentifier(name string) string {
  1295. end := strings.IndexRune(name, 0)
  1296. if end > -1 {
  1297. name = name[:end]
  1298. }
  1299. return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
  1300. }
  1301. func md5s(s string) string {
  1302. h := md5.New()
  1303. h.Write([]byte(s))
  1304. return fmt.Sprintf("%x", h.Sum(nil))
  1305. }
  1306. func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
  1307. // Do one pass over the parameters to see if we're going to send any of
  1308. // them over in binary. If we are, create a paramFormats array at the
  1309. // same time.
  1310. var paramFormats []int
  1311. for i, x := range args {
  1312. _, ok := x.([]byte)
  1313. if ok {
  1314. if paramFormats == nil {
  1315. paramFormats = make([]int, len(args))
  1316. }
  1317. paramFormats[i] = 1
  1318. }
  1319. }
  1320. if paramFormats == nil {
  1321. b.int16(0)
  1322. } else {
  1323. b.int16(len(paramFormats))
  1324. for _, x := range paramFormats {
  1325. b.int16(x)
  1326. }
  1327. }
  1328. b.int16(len(args))
  1329. for _, x := range args {
  1330. if x == nil {
  1331. b.int32(-1)
  1332. } else {
  1333. datum := binaryEncode(&cn.parameterStatus, x)
  1334. b.int32(len(datum))
  1335. b.bytes(datum)
  1336. }
  1337. }
  1338. }
// sendBinaryModeQuery sends query with args as one batch of extended-protocol
// messages in a single write: Parse ('P'), Bind ('B'), Describe ('D'),
// Execute ('E') and Sync ('S'). Both the statement and the portal are the
// unnamed ones. Parameters are encoded by sendBinaryParameters, and every
// result column is requested in text format. At most 65535 arguments are
// allowed.
func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
	if len(args) >= 65536 {
		errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
	}

	// Parse: unnamed statement, query text, no parameter type OIDs.
	b := cn.writeBuf('P')
	b.byte(0) // unnamed statement
	b.string(query)
	b.int16(0)

	// Bind: the int16 zero writes two empty strings at once, naming the
	// unnamed portal and the unnamed statement.
	b.next('B')
	b.int16(0) // unnamed portal and statement
	cn.sendBinaryParameters(b, args)
	b.bytes(colFmtDataAllText)

	// Describe the portal ('P').
	b.next('D')
	b.byte('P')
	b.byte(0) // unnamed portal

	// Execute the unnamed portal with a row limit of 0 (no limit).
	b.next('E')
	b.byte(0)
	b.int32(0)

	b.next('S')
	cn.send(b)
}
  1360. func (cn *conn) processParameterStatus(r *readBuf) {
  1361. var err error
  1362. param := r.string()
  1363. switch param {
  1364. case "server_version":
  1365. var major1 int
  1366. var major2 int
  1367. var minor int
  1368. _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor)
  1369. if err == nil {
  1370. cn.parameterStatus.serverVersion = major1*10000 + major2*100 + minor
  1371. }
  1372. case "TimeZone":
  1373. cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string())
  1374. if err != nil {
  1375. cn.parameterStatus.currentLocation = nil
  1376. }
  1377. default:
  1378. // ignore
  1379. }
  1380. }
  1381. func (cn *conn) processReadyForQuery(r *readBuf) {
  1382. cn.txnStatus = transactionStatus(r.byte())
  1383. }
  1384. func (cn *conn) readReadyForQuery() {
  1385. t, r := cn.recv1()
  1386. switch t {
  1387. case 'Z':
  1388. cn.processReadyForQuery(r)
  1389. return
  1390. default:
  1391. cn.bad = true
  1392. errorf("unexpected message %q; expected ReadyForQuery", t)
  1393. }
  1394. }
  1395. func (cn *conn) processBackendKeyData(r *readBuf) {
  1396. cn.processID = r.int32()
  1397. cn.secretKey = r.int32()
  1398. }
  1399. func (cn *conn) readParseResponse() {
  1400. t, r := cn.recv1()
  1401. switch t {
  1402. case '1':
  1403. return
  1404. case 'E':
  1405. err := parseError(r)
  1406. cn.readReadyForQuery()
  1407. panic(err)
  1408. default:
  1409. cn.bad = true
  1410. errorf("unexpected Parse response %q", t)
  1411. }
  1412. }
  1413. func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) {
  1414. for {
  1415. t, r := cn.recv1()
  1416. switch t {
  1417. case 't':
  1418. nparams := r.int16()
  1419. paramTyps = make([]oid.Oid, nparams)
  1420. for i := range paramTyps {
  1421. paramTyps[i] = r.oid()
  1422. }
  1423. case 'n':
  1424. return paramTyps, nil, nil
  1425. case 'T':
  1426. colNames, colTyps = parseStatementRowDescribe(r)
  1427. return paramTyps, colNames, colTyps
  1428. case 'E':
  1429. err := parseError(r)
  1430. cn.readReadyForQuery()
  1431. panic(err)
  1432. default:
  1433. cn.bad = true
  1434. errorf("unexpected Describe statement response %q", t)
  1435. }
  1436. }
  1437. }
  1438. func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []fieldDesc) {
  1439. t, r := cn.recv1()
  1440. switch t {
  1441. case 'T':
  1442. return parsePortalRowDescribe(r)
  1443. case 'n':
  1444. return nil, nil, nil
  1445. case 'E':
  1446. err := parseError(r)
  1447. cn.readReadyForQuery()
  1448. panic(err)
  1449. default:
  1450. cn.bad = true
  1451. errorf("unexpected Describe response %q", t)
  1452. }
  1453. panic("not reached")
  1454. }
  1455. func (cn *conn) readBindResponse() {
  1456. t, r := cn.recv1()
  1457. switch t {
  1458. case '2':
  1459. return
  1460. case 'E':
  1461. err := parseError(r)
  1462. cn.readReadyForQuery()
  1463. panic(err)
  1464. default:
  1465. cn.bad = true
  1466. errorf("unexpected Bind response %q", t)
  1467. }
  1468. }
  1469. func (cn *conn) postExecuteWorkaround() {
  1470. // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores
  1471. // any errors from rows.Next, which masks errors that happened during the
  1472. // execution of the query. To avoid the problem in common cases, we wait
  1473. // here for one more message from the database. If it's not an error the
  1474. // query will likely succeed (or perhaps has already, if it's a
  1475. // CommandComplete), so we push the message into the conn struct; recv1
  1476. // will return it as the next message for rows.Next or rows.Close.
  1477. // However, if it's an error, we wait until ReadyForQuery and then return
  1478. // the error to our caller.
  1479. for {
  1480. t, r := cn.recv1()
  1481. switch t {
  1482. case 'E':
  1483. err := parseError(r)
  1484. cn.readReadyForQuery()
  1485. panic(err)
  1486. case 'C', 'D', 'I':
  1487. // the query didn't fail, but we can't process this message
  1488. cn.saveMessage(t, r)
  1489. return
  1490. default:
  1491. cn.bad = true
  1492. errorf("unexpected message during extended query execution: %q", t)
  1493. }
  1494. }
  1495. }
// readExecuteResponse reads the responses to an executed statement up to and
// including ReadyForQuery ('Z'). It is only for Exec-style callers, because
// row data is ignored.
//
// It returns the result and command tag from CommandComplete ('C'),
// emptyRows for an EmptyQueryResponse ('I'), and any ErrorResponse ('E') as
// err. If 'Z' arrives with neither a result nor an error, err is
// errUnexpectedReady. A 'C', 'T', 'D' or 'I' message after an error marks the
// connection bad. protocolState names the caller's phase in the error
// reported for an unknown message type.
func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) {
	for {
		t, r := cn.recv1()
		switch t {
		case 'C':
			if err != nil {
				cn.bad = true
				errorf("unexpected CommandComplete after error %s", err)
			}
			res, commandTag = cn.parseComplete(r.string())
		case 'Z':
			cn.processReadyForQuery(r)
			if res == nil && err == nil {
				err = errUnexpectedReady
			}
			return res, commandTag, err
		case 'E':
			err = parseError(r)
		case 'T', 'D', 'I':
			if err != nil {
				cn.bad = true
				errorf("unexpected %q after error %s", t, err)
			}
			if t == 'I' {
				res = emptyRows
			}
			// ignore any results
		default:
			cn.bad = true
			errorf("unknown %s response: %q", protocolState, t)
		}
	}
}
  1530. func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) {
  1531. n := r.int16()
  1532. colNames = make([]string, n)
  1533. colTyps = make([]fieldDesc, n)
  1534. for i := range colNames {
  1535. colNames[i] = r.string()
  1536. r.next(6)
  1537. colTyps[i].OID = r.oid()
  1538. colTyps[i].Len = r.int16()
  1539. colTyps[i].Mod = r.int32()
  1540. // format code not known when describing a statement; always 0
  1541. r.next(2)
  1542. }
  1543. return
  1544. }
  1545. func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []fieldDesc) {
  1546. n := r.int16()
  1547. colNames = make([]string, n)
  1548. colFmts = make([]format, n)
  1549. colTyps = make([]fieldDesc, n)
  1550. for i := range colNames {
  1551. colNames[i] = r.string()
  1552. r.next(6)
  1553. colTyps[i].OID = r.oid()
  1554. colTyps[i].Len = r.int16()
  1555. colTyps[i].Mod = r.int32()
  1556. colFmts[i] = format(r.int16())
  1557. }
  1558. return
  1559. }
  1560. // parseEnviron tries to mimic some of libpq's environment handling
  1561. //
  1562. // To ease testing, it does not directly reference os.Environ, but is
  1563. // designed to accept its output.
  1564. //
  1565. // Environment-set connection information is intended to have a higher
  1566. // precedence than a library default but lower than any explicitly
  1567. // passed information (such as in the URL or connection string).
  1568. func parseEnviron(env []string) (out map[string]string) {
  1569. out = make(map[string]string)
  1570. for _, v := range env {
  1571. parts := strings.SplitN(v, "=", 2)
  1572. accrue := func(keyname string) {
  1573. out[keyname] = parts[1]
  1574. }
  1575. unsupported := func() {
  1576. panic(fmt.Sprintf("setting %v not supported", parts[0]))
  1577. }
  1578. // The order of these is the same as is seen in the
  1579. // PostgreSQL 9.1 manual. Unsupported but well-defined
  1580. // keys cause a panic; these should be unset prior to
  1581. // execution. Options which pq expects to be set to a
  1582. // certain value are allowed, but must be set to that
  1583. // value if present (they can, of course, be absent).
  1584. switch parts[0] {
  1585. case "PGHOST":
  1586. accrue("host")
  1587. case "PGHOSTADDR":
  1588. unsupported()
  1589. case "PGPORT":
  1590. accrue("port")
  1591. case "PGDATABASE":
  1592. accrue("dbname")
  1593. case "PGUSER":
  1594. accrue("user")
  1595. case "PGPASSWORD":
  1596. accrue("password")
  1597. case "PGSERVICE", "PGSERVICEFILE", "PGREALM":
  1598. unsupported()
  1599. case "PGOPTIONS":
  1600. accrue("options")
  1601. case "PGAPPNAME":
  1602. accrue("application_name")
  1603. case "PGSSLMODE":
  1604. accrue("sslmode")
  1605. case "PGSSLCERT":
  1606. accrue("sslcert")
  1607. case "PGSSLKEY":
  1608. accrue("sslkey")
  1609. case "PGSSLROOTCERT":
  1610. accrue("sslrootcert")
  1611. case "PGREQUIRESSL", "PGSSLCRL":
  1612. unsupported()
  1613. case "PGREQUIREPEER":
  1614. unsupported()
  1615. case "PGKRBSRVNAME", "PGGSSLIB":
  1616. unsupported()
  1617. case "PGCONNECT_TIMEOUT":
  1618. accrue("connect_timeout")
  1619. case "PGCLIENTENCODING":
  1620. accrue("client_encoding")
  1621. case "PGDATESTYLE":
  1622. accrue("datestyle")
  1623. case "PGTZ":
  1624. accrue("timezone")
  1625. case "PGGEQO":
  1626. accrue("geqo")
  1627. case "PGSYSCONFDIR", "PGLOCALEDIR":
  1628. unsupported()
  1629. }
  1630. }
  1631. return out
  1632. }
  1633. // isUTF8 returns whether name is a fuzzy variation of the string "UTF-8".
  1634. func isUTF8(name string) bool {
  1635. // Recognize all sorts of silly things as "UTF-8", like Postgres does
  1636. s := strings.Map(alnumLowerASCII, name)
  1637. return s == "utf8" || s == "unicode"
  1638. }
  1639. func alnumLowerASCII(ch rune) rune {
  1640. if 'A' <= ch && ch <= 'Z' {
  1641. return ch + ('a' - 'A')
  1642. }
  1643. if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' {
  1644. return ch
  1645. }
  1646. return -1 // discard
  1647. }