浏览代码

Add context support

* add context support

* improve pingcontext tests
xormplus 6 年之前
父节点
当前提交
c4443b8168
共有 11 个文件被更改,包括 124 次插入25 次删除
  1. 3 0
      engine.go
  2. 28 0
      engine_context.go
  3. 26 0
      engine_context_test.go
  4. 2 0
      interface.go
  5. 6 1
      session.go
  6. 5 8
      session_context.go
  7. 36 0
      session_context_test.go
  8. 6 6
      session_raw.go
  9. 1 1
      session_schema.go
  10. 1 1
      session_tx.go
  11. 10 8
      xorm.go

+ 3 - 0
engine.go

@@ -7,6 +7,7 @@ package xorm
 import (
 	"bufio"
 	"bytes"
+	"context"
 	"database/sql"
 	"encoding/gob"
 	"errors"
@@ -55,6 +56,8 @@ type Engine struct {
 
 	cachers    map[string]core.Cacher
 	cacherLock sync.RWMutex
+
+	defaultContext context.Context
 }
 
 func (engine *Engine) setCacher(tableName string, cacher core.Cacher) {

+ 28 - 0
engine_context.go

@@ -0,0 +1,28 @@
+// Copyright 2019 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.8
+
+package xorm
+
+import "context"
+
+// Context creates a session with the context
+func (engine *Engine) Context(ctx context.Context) *Session {
+	session := engine.NewSession()
+	session.isAutoClose = true
+	return session.Context(ctx)
+}
+
+// SetDefaultContext set the default context
+func (engine *Engine) SetDefaultContext(ctx context.Context) {
+	engine.defaultContext = ctx
+}
+
+// PingContext tests if database is alive
+func (engine *Engine) PingContext(ctx context.Context) error {
+	session := engine.NewSession()
+	defer session.Close()
+	return session.PingContext(ctx)
+}

+ 26 - 0
engine_context_test.go

@@ -0,0 +1,26 @@
+// Copyright 2017 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.8
+
+package xorm
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestPingContext(t *testing.T) {
+	assert.NoError(t, prepareEngine())
+
+	ctx, canceled := context.WithTimeout(context.Background(), time.Nanosecond)
+	defer canceled()
+
+	err := testEngine.(*Engine).PingContext(ctx)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "context deadline exceeded")
+}

+ 2 - 0
interface.go

@@ -5,6 +5,7 @@
 package xorm
 
 import (
+	"context"
 	"database/sql"
 	"reflect"
 	"time"
@@ -75,6 +76,7 @@ type EngineInterface interface {
 	Before(func(interface{})) *Session
 	Charset(charset string) *Session
 	ClearCache(...interface{}) error
+	Context(context.Context) *Session
 	CreateTables(...interface{}) error
 	DBMetas() ([]*core.Table, error)
 	Dialect() core.Dialect

+ 6 - 1
session.go

@@ -5,6 +5,7 @@
 package xorm
 
 import (
+	"context"
 	"database/sql"
 	"encoding/json"
 	"errors"
@@ -55,6 +56,8 @@ type Session struct {
 
 	rollbackSavePointID string
 
+	ctx context.Context
+
 	err error
 }
 
@@ -87,6 +90,8 @@ func (session *Session) Init() {
 
 	session.lastSQL = ""
 	session.lastSQLArgs = []interface{}{}
+
+	session.ctx = session.engine.defaultContext
 }
 
 // Close release the connection from pool
@@ -281,7 +286,7 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt,
 	var has bool
 	stmt, has = session.stmtCache[crc]
 	if !has {
-		stmt, err = db.Prepare(sqlStr)
+		stmt, err = db.PrepareContext(session.ctx, sqlStr)
 		if err != nil {
 			return nil, err
 		}

+ 5 - 8
context.go → session_context.go

@@ -1,18 +1,15 @@
-// Copyright 2017 The Xorm Authors. All rights reserved.
+// Copyright 2019 The Xorm Authors. All rights reserved.
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// +build go1.8
-
 package xorm
 
 import "context"
 
-// PingContext tests if database is alive
-func (engine *Engine) PingContext(ctx context.Context) error {
-	session := engine.NewSession()
-	defer session.Close()
-	return session.PingContext(ctx)
+// Context sets the context on this session
+func (session *Session) Context(ctx context.Context) *Session {
+	session.ctx = ctx
+	return session
 }
 
 // PingContext test if database is ok

+ 36 - 0
session_context_test.go

@@ -0,0 +1,36 @@
+// Copyright 2019 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package xorm
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestQueryContext(t *testing.T) {
+	type ContextQueryStruct struct {
+		Id   int64
+		Name string
+	}
+
+	assert.NoError(t, prepareEngine())
+	assertSync(t, new(ContextQueryStruct))
+
+	_, err := testEngine.Insert(&ContextQueryStruct{Name: "1"})
+	assert.NoError(t, err)
+
+	sess := testEngine.NewSession()
+	defer sess.Close()
+
+	ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
+	defer cancel()
+	has, err := testEngine.Context(ctx).Exist(&ContextQueryStruct{Name: "1"})
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "context deadline exceeded")
+	assert.False(t, has)
+}

+ 6 - 6
session_raw.go

@@ -62,21 +62,21 @@ func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Row
 				return nil, err
 			}
 
-			rows, err := stmt.Query(args...)
+			rows, err := stmt.QueryContext(session.ctx, args...)
 			if err != nil {
 				return nil, err
 			}
 			return rows, nil
 		}
 
-		rows, err := db.Query(sqlStr, args...)
+		rows, err := db.QueryContext(session.ctx, sqlStr, args...)
 		if err != nil {
 			return nil, err
 		}
 		return rows, nil
 	}
 
-	rows, err := session.tx.Query(sqlStr, args...)
+	rows, err := session.tx.QueryContext(session.ctx, sqlStr, args...)
 	if err != nil {
 		return nil, err
 	}
@@ -290,7 +290,7 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er
 	}
 
 	if !session.isAutoCommit {
-		return session.tx.Exec(sqlStr, args...)
+		return session.tx.ExecContext(session.ctx, sqlStr, args...)
 	}
 
 	if session.prepareStmt {
@@ -299,14 +299,14 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er
 			return nil, err
 		}
 
-		res, err := stmt.Exec(args...)
+		res, err := stmt.ExecContext(session.ctx, args...)
 		if err != nil {
 			return nil, err
 		}
 		return res, nil
 	}
 
-	return session.DB().Exec(sqlStr, args...)
+	return session.DB().ExecContext(session.ctx, sqlStr, args...)
 }
 
 func convertSQLOrArgs(sqlorArgs ...interface{}) (string, []interface{}, error) {

+ 1 - 1
session_schema.go

@@ -19,7 +19,7 @@ func (session *Session) Ping() error {
 	}
 
 	session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName())
-	return session.DB().Ping()
+	return session.DB().PingContext(session.ctx)
 }
 
 // CreateTable create a table according a bean

+ 1 - 1
session_tx.go

@@ -7,7 +7,7 @@ package xorm
 // Begin a transaction
 func (session *Session) Begin() error {
 	if session.isAutoCommit {
-		tx, err := session.DB().Begin()
+		tx, err := session.DB().BeginTx(session.ctx, nil)
 		if err != nil {
 			return err
 		}

+ 10 - 8
xorm.go

@@ -7,6 +7,7 @@
 package xorm
 
 import (
+	"context"
 	"fmt"
 	"os"
 	"reflect"
@@ -85,14 +86,15 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
 	}
 
 	engine := &Engine{
-		db:            db,
-		dialect:       dialect,
-		Tables:        make(map[reflect.Type]*core.Table),
-		mutex:         &sync.RWMutex{},
-		TagIdentifier: "xorm",
-		TZLocation:    time.Local,
-		tagHandlers:   defaultTagHandlers,
-		cachers:       make(map[string]core.Cacher),
+		db:             db,
+		dialect:        dialect,
+		Tables:         make(map[reflect.Type]*core.Table),
+		mutex:          &sync.RWMutex{},
+		TagIdentifier:  "xorm",
+		TZLocation:     time.Local,
+		tagHandlers:    defaultTagHandlers,
+		cachers:        make(map[string]core.Cacher),
+		defaultContext: context.Background(),
 	}
 
 	if uri.DbType == core.SQLITE {