소스 검색

some fixed for mssql support

xormplus 9 년 전
부모
커밋
5570203c86
4개의 변경된 파일47개의 추가작업 그리고 14개의 파일을 삭제
  1. 2 2
      mssql_dialect.go
  2. 22 4
      session.go
  3. 11 5
      sessionplus.go
  4. 12 3
      statement.go

+ 2 - 2
mssql_dialect.go

@@ -242,8 +242,8 @@ func (db *mssql) SqlType(c *core.Column) string {
 		c.Length = 7
 		c.Length = 7
 	case core.MediumInt:
 	case core.MediumInt:
 		res = core.Int
 		res = core.Int
-	case core.MediumText, core.TinyText, core.LongText, core.Json:
-		res = core.Text
+	case core.Text, core.MediumText, core.TinyText, core.LongText, core.Json:
+		res = core.Varchar + "(MAX)"
 	case core.Double:
 	case core.Double:
 		res = core.Real
 		res = core.Real
 	case core.Uuid:
 	case core.Uuid:

+ 22 - 4
session.go

@@ -691,6 +691,7 @@ func (session *Session) canCache() bool {
 		session.Statement.JoinStr != "" ||
 		session.Statement.JoinStr != "" ||
 		session.Statement.RawSQL != "" ||
 		session.Statement.RawSQL != "" ||
 		!session.Statement.UseCache ||
 		!session.Statement.UseCache ||
+		session.Statement.IsForUpdate ||
 		session.Tx != nil ||
 		session.Tx != nil ||
 		len(session.Statement.selectStr) > 0 {
 		len(session.Statement.selectStr) > 0 {
 		return false
 		return false
@@ -1339,7 +1340,11 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
 				}
 				}
 				colName = session.Engine.Quote(nm) + "." + colName
 				colName = session.Engine.Quote(nm) + "." + colName
 			}
 			}
-			autoCond = builder.IsNull{colName}.Or(builder.Eq{colName: "0001-01-01 00:00:00"})
+			if session.Engine.dialect.DBType() == core.MSSQL {
+				autoCond = builder.IsNull{colName}
+			} else {
+				autoCond = builder.IsNull{colName}.Or(builder.Eq{colName: "0001-01-01 00:00:00"})
+			}
 		}
 		}
 	}
 	}
 
 
@@ -1865,15 +1870,22 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
 
 
 						t := vv.Convert(core.TimeType).Interface().(time.Time)
 						t := vv.Convert(core.TimeType).Interface().(time.Time)
 						z, _ := t.Zone()
 						z, _ := t.Zone()
-						if len(z) == 0 || t.Year() == 0 { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location
-							dbTZ := session.Engine.DatabaseTZ
-							if dbTZ == nil {
+						dbTZ := session.Engine.DatabaseTZ
+						if dbTZ == nil {
+							if session.Engine.dialect.DBType() == core.SQLITE {
+								dbTZ = time.UTC
+							} else {
 								dbTZ = time.Local
 								dbTZ = time.Local
 							}
 							}
+						}
+
+						// set new location if database don't save timezone or give an incorrect timezone
+						if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location
 							session.Engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location())
 							session.Engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location())
 							t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
 							t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
 								t.Minute(), t.Second(), t.Nanosecond(), dbTZ)
 								t.Minute(), t.Second(), t.Nanosecond(), dbTZ)
 						}
 						}
+
 						// !nashtsai! convert to engine location
 						// !nashtsai! convert to engine location
 						if col.TimeZone == nil {
 						if col.TimeZone == nil {
 							t = t.In(session.Engine.TZLocation)
 							t = t.In(session.Engine.TZLocation)
@@ -3036,6 +3048,9 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
 			if err != nil {
 			if err != nil {
 				return 0, err
 				return 0, err
 			}
 			}
+			if col.SQLType.IsBlob() {
+				return data, nil
+			}
 			return string(data), nil
 			return string(data), nil
 		}
 		}
 	}
 	}
@@ -3045,6 +3060,9 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
 		if err != nil {
 		if err != nil {
 			return 0, err
 			return 0, err
 		}
 		}
+		if col.SQLType.IsBlob() {
+			return data, nil
+		}
 		return string(data), nil
 		return string(data), nil
 	}
 	}
 
 

+ 11 - 5
sessionplus.go

@@ -815,15 +815,22 @@ func (session *Session) _row2BeanWithDateFormat(dateFormat string, rows *core.Ro
 
 
 						t := vv.Convert(core.TimeType).Interface().(time.Time)
 						t := vv.Convert(core.TimeType).Interface().(time.Time)
 						z, _ := t.Zone()
 						z, _ := t.Zone()
-						if len(z) == 0 || t.Year() == 0 { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location
-							dbTZ := session.Engine.DatabaseTZ
-							if dbTZ == nil {
+						dbTZ := session.Engine.DatabaseTZ
+						if dbTZ == nil {
+							if session.Engine.dialect.DBType() == core.SQLITE {
+								dbTZ = time.UTC
+							} else {
 								dbTZ = time.Local
 								dbTZ = time.Local
 							}
 							}
+						}
+
+						// set new location if database don't save timezone or give an incorrect timezone
+						if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location
 							session.Engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location())
 							session.Engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location())
 							t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
 							t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
 								t.Minute(), t.Second(), t.Nanosecond(), dbTZ)
 								t.Minute(), t.Second(), t.Nanosecond(), dbTZ)
 						}
 						}
+
 						// !nashtsai! convert to engine location
 						// !nashtsai! convert to engine location
 						var tz *time.Location
 						var tz *time.Location
 						if col.TimeZone == nil {
 						if col.TimeZone == nil {
@@ -877,7 +884,6 @@ func (session *Session) _row2BeanWithDateFormat(dateFormat string, rows *core.Ro
 					// !<winxxp>! 增加支持sql.Scanner接口的结构,如sql.NullString
 					// !<winxxp>! 增加支持sql.Scanner接口的结构,如sql.NullString
 					hasAssigned = true
 					hasAssigned = true
 					if err := nulVal.Scan(vv.Interface()); err != nil {
 					if err := nulVal.Scan(vv.Interface()); err != nil {
-						//fmt.Println("sql.Sanner error:", err.Error())
 						session.Engine.logger.Error("sql.Sanner error:", err.Error())
 						session.Engine.logger.Error("sql.Sanner error:", err.Error())
 						hasAssigned = false
 						hasAssigned = false
 					}
 					}
@@ -959,7 +965,7 @@ func (session *Session) _row2BeanWithDateFormat(dateFormat string, rows *core.Ro
 								//fieldValue.Set(reflect.ValueOf(v))
 								//fieldValue.Set(reflect.ValueOf(v))
 								fieldValue.Set(structInter.Elem())
 								fieldValue.Set(structInter.Elem())
 							} else {
 							} else {
-								return errors.New("cascade obj is not exist!")
+								return errors.New("cascade obj is not exist")
 							}
 							}
 						}
 						}
 					} else {
 					} else {

+ 12 - 3
statement.go

@@ -497,7 +497,7 @@ func buildConds(engine *Engine, table *core.Table, bean interface{},
 			continue
 			continue
 		}
 		}
 
 
-		if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text {
+		if engine.dialect.DBType() == core.MSSQL && (col.SQLType.Name == core.Text || col.SQLType.IsBlob() || col.SQLType.Name == core.TimeStampz) {
 			continue
 			continue
 		}
 		}
 		if col.SQLType.IsJson() {
 		if col.SQLType.IsJson() {
@@ -522,7 +522,11 @@ func buildConds(engine *Engine, table *core.Table, bean interface{},
 		}
 		}
 
 
 		if col.IsDeleted && !unscoped { // tag "deleted" is enabled
 		if col.IsDeleted && !unscoped { // tag "deleted" is enabled
-			conds = append(conds, builder.IsNull{colName}.Or(builder.Eq{colName: "0001-01-01 00:00:00"}))
+			if engine.dialect.DBType() == core.MSSQL {
+				conds = append(conds, builder.IsNull{colName})
+			} else {
+				conds = append(conds, builder.IsNull{colName}.Or(builder.Eq{colName: "0001-01-01 00:00:00"}))
+			}
 		}
 		}
 
 
 		fieldValue := *fieldValuePtr
 		fieldValue := *fieldValuePtr
@@ -1316,7 +1320,12 @@ func (statement *Statement) convertIDSQL(sqlStr string) string {
 			return ""
 			return ""
 		}
 		}
 
 
-		return fmt.Sprintf("SELECT %s FROM %v", colstrs, sqls[1])
+		var top string
+		if statement.LimitN > 0 && statement.Engine.dialect.DBType() == core.MSSQL {
+			top = fmt.Sprintf("TOP %d ", statement.LimitN)
+		}
+
+		return fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
 	}
 	}
 	return ""
 	return ""
 }
 }