pread.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. package main
  2. import (
  3. "bufio"
  4. "errors"
  5. "flag"
  6. "fmt"
  7. "log"
  8. "os"
  9. "runtime"
  10. "strconv"
  11. "strings"
  12. "time"
  13. "github.com/tal-tech/go-zero/core/filex"
  14. "github.com/tal-tech/go-zero/core/fx"
  15. "github.com/tal-tech/go-zero/core/logx"
  16. "gopkg.in/cheggaaa/pb.v1"
  17. )
  18. var (
  19. file = flag.String("f", "", "the input file")
  20. concurrent = flag.Int("c", runtime.NumCPU(), "concurrent goroutines")
  21. wordVecDic TXDictionary
  22. )
  23. type (
  24. Vector []float64
  25. TXDictionary struct {
  26. EmbeddingCount int64
  27. Dim int64
  28. Dict map[string]Vector
  29. }
  30. pair struct {
  31. key string
  32. vec Vector
  33. }
  34. )
  35. func FastLoad(filename string) error {
  36. if filename == "" {
  37. return errors.New("no available dictionary")
  38. }
  39. now := time.Now()
  40. defer func() {
  41. logx.Infof("article2vec init dictionary end used %v", time.Since(now))
  42. }()
  43. dicFile, err := os.Open(filename)
  44. if err != nil {
  45. return err
  46. }
  47. defer dicFile.Close()
  48. header, err := filex.FirstLine(filename)
  49. if err != nil {
  50. return err
  51. }
  52. total := strings.Split(header, " ")
  53. wordVecDic.EmbeddingCount, err = strconv.ParseInt(total[0], 10, 64)
  54. if err != nil {
  55. return err
  56. }
  57. wordVecDic.Dim, err = strconv.ParseInt(total[1], 10, 64)
  58. if err != nil {
  59. return err
  60. }
  61. wordVecDic.Dict = make(map[string]Vector, wordVecDic.EmbeddingCount)
  62. ranges, err := filex.SplitLineChunks(filename, *concurrent)
  63. if err != nil {
  64. return err
  65. }
  66. info, err := os.Stat(filename)
  67. if err != nil {
  68. return err
  69. }
  70. bar := pb.New64(info.Size()).SetUnits(pb.U_BYTES).Start()
  71. fx.From(func(source chan<- interface{}) {
  72. for _, each := range ranges {
  73. source <- each
  74. }
  75. }).Walk(func(item interface{}, pipe chan<- interface{}) {
  76. offsetRange := item.(filex.OffsetRange)
  77. scanner := bufio.NewScanner(filex.NewRangeReader(dicFile, offsetRange.Start, offsetRange.Stop))
  78. scanner.Buffer([]byte{}, 1<<20)
  79. reader := filex.NewProgressScanner(scanner, bar)
  80. if offsetRange.Start == 0 {
  81. // skip header
  82. reader.Scan()
  83. }
  84. for reader.Scan() {
  85. text := reader.Text()
  86. elements := strings.Split(text, " ")
  87. vec := make(Vector, wordVecDic.Dim)
  88. for i, ele := range elements {
  89. if i == 0 {
  90. continue
  91. }
  92. v, err := strconv.ParseFloat(ele, 64)
  93. if err != nil {
  94. return
  95. }
  96. vec[i-1] = v
  97. }
  98. pipe <- pair{
  99. key: elements[0],
  100. vec: vec,
  101. }
  102. }
  103. }).ForEach(func(item interface{}) {
  104. p := item.(pair)
  105. wordVecDic.Dict[p.key] = p.vec
  106. })
  107. return nil
  108. }
  109. func main() {
  110. flag.Parse()
  111. start := time.Now()
  112. if err := FastLoad(*file); err != nil {
  113. log.Fatal(err)
  114. }
  115. fmt.Println(len(wordVecDic.Dict))
  116. fmt.Println(time.Since(start))
  117. }