Case Study: Google NGrams

The Google Books NGrams dataset is a rich trove of data containing the count of every word that occurs in the millions of books scanned by Google Books since 2005. The dataset is publicly available,[17] quite large, and relatively simple to work with. All the files together consist of hundreds of gigabytes of data, and just the counts of individual words take up more than 50GB. Just so we don’t get totally overwhelmed, we’ll cut it down and only take the 2GB file of words that start with the letter A. In the Docker build environment, you can get it by running the /get_ngram_data.sh script, which will download a file called googlebooks-eng-all-1gram-20120701-a.

(If you’re working on your own machine, you can download the data yourself; if you’re using the supplied Docker image, it’s already there. Just run xz -d googlebooks-eng-all-1gram-20120701-a.xz in the /root home directory and wait, patiently, for it to decompress.)

What we want to do is to read through one of these files and find the most frequent word/year combination. Fortunately the file format makes this relatively easy; every line is of this form:

 ngram TAB year TAB match_count TAB volume_count NEWLINE

Such as:

 A'Aang_NOUN 1879 45 5

We’re most interested in the word itself, the year, and the match_count (that is, the number of occurrences); the volume_count, which tells us how many distinct books the word appears in, is less interesting for our purposes. To find the word/year combination with the largest count, we don’t need to keep all the words in memory. All we need to do is track the largest count we’ve seen so far, compare each line we read against it, and, if the new count is greater, update the maximum along with the corresponding word and year.

Now, if we were to do this in regular, idiomatic Scala, it would look something like this:

InputAndOutput/max_ngram_naive/main.scala
 object main {
   def main(args: Array[String]): Unit = {
     var max = 0
     var max_word = ""
     var max_year = 0

     println("reading from STDIN")
     val read_start = System.currentTimeMillis()
     var lines_read = 0
     for (line <- scala.io.Source.stdin.getLines) {
       val split_fields = line.split("\\s+")

       if (split_fields.size != 4) {
         throw new Exception("Parse Error")
       }
       val word = split_fields(0)
       val year = split_fields(1).toInt
       val count = split_fields(2).toInt

       if (count > max) {
         println(s"found new max: $word $count $year")
         max = count
         max_word = word
         max_year = year
       }
       lines_read += 1
       if (lines_read % 5000000 == 0) {
         val elapsed_now = System.currentTimeMillis() - read_start
         println(s"read $lines_read lines in $elapsed_now ms")
       }
     }
     val read_done = System.currentTimeMillis() - read_start
     println(s"max count: ${max_word}, ${max_year}; ${max} occurrences")
     println(s"$read_done ms elapsed total.")
   }
 }

This code works and produces the correct result. But when we’re handling this much data, we also have to worry about how long it takes. With the UNIX time command, we can measure the execution time of any program, like this:

 $ time sbt run < ../../googlebooks-eng-all-1gram-20120701-a
 ...
 [info] Running main
 found new max: A'Aang_NOUN 45 1879
 found new max: A.E.U._DET 65 1975
 found new max: A.J.B._NOUN 72 1960
 found new max: A.J.B._NOUN 300 1963
 found new max: A.J.B._NOUN 393 1995
 ...
 found new max: and_CONJ 380846175 2007
 found new max: and_CONJ 470334485 2008
 found new max: and 470825580 2008
 max count: and, 2008; 470825580 occurrences
 [success] Total time: 231 s, completed May 2, 2018 7:35:05 PM
 real 3m 57.89s
 user 1m 32.26s
 sys 0m 20.53s

As we can see, the code runs and finds that the most frequent word is (unsurprisingly) “and.” However, it takes about four minutes to complete on my machine. Just from inspecting the code, we can see a lot of unnecessary allocation. In particular, we allocate a String for every line we read, then split it into four more Strings (plus the Array that holds them), before parsing the numeric fields into Ints and updating our maximum. All those short-lived allocations add up and make the JVM garbage collector work pretty hard. So we have every reason to believe that if we can cut out the allocation and GC overhead, we could improve performance substantially.
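
To make this concrete, here’s the hot path of the loop above with the per-line allocations called out in comments. The numbers are approximate, and the JIT may eliminate some of them, but the overall pattern holds:

 for (line <- scala.io.Source.stdin.getLines) { // one String per line
   val split_fields = line.split("\\s+")        // one Array[String] plus four more Strings
   val word = split_fields(0)                   // just a reference, no new allocation
   val year = split_fields(1).toInt             // parsing allocates nothing further...
   val count = split_fields(2).toInt            // ...but every String above is already garbage
   // roughly six short-lived objects per line, over more than 86 million lines
 }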

Now, let’s implement an equivalent program using all of the techniques we’ve looked at in this chapter. The code will be structurally quite similar to the sscanf example project we just created, but with three major changes: we’ll read each line with stdio.fgets into a stack-allocated buffer rather than going through scala.io.Source, we’ll parse the fields with stdio.sscanf rather than String.split, and we’ll keep the running maximum in stack-allocated storage rather than in JVM variables.

By reading directly into temporary storage like this, we can dramatically reduce the memory usage of our program and eliminate garbage-collection overhead entirely, although the code becomes a little ungainly. The main() function should look familiar:

InputAndOutput/max_ngram/main.scala
 def main(args: Array[String]): Unit = {
   var max_word: Ptr[Byte] = stackalloc[Byte](1024)
   val max_count = stackalloc[Int]
   val max_year = stackalloc[Int]

   val line_buffer = stackalloc[Byte](1024)
   val temp_word = stackalloc[Byte](1024)
   val temp_count = stackalloc[Int]
   val temp_year = stackalloc[Int]
   val temp_doc_count = stackalloc[Int]

   var lines_read = 0
   !max_count = 0
   !max_year = 0

   while (stdio.fgets(line_buffer, 1024, stdio.stdin) != null) {
     lines_read += 1
     parse_and_compare(line_buffer, max_word, temp_word, 1024,
       max_count, temp_count, max_year, temp_year, temp_doc_count)
   }

   stdio.printf(c"done. read %d lines\n", lines_read)
   stdio.printf(c"maximum word count: %d for '%s' @ %d\n", !max_count,
     max_word, !max_year)
 }
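
One design note before we look at the parsing function: memory returned by stackalloc is valid only for the lifetime of the enclosing stack frame, which is why we allocate every buffer once in main() and pass pointers down into the helper. Here’s a minimal sketch of that pattern on its own, assuming the usual Scala Native imports for this chapter; the object and function names here are purely illustrative:

 import scala.scalanative.native._ // import paths differ on newer Scala Native versions

 object stack_scope_demo {
   // The helper writes into memory owned by the caller's stack frame.
   def fill(word: Ptr[Byte], count: Ptr[Int]): Unit = {
     string.strcpy(word, c"and")
     !count = 470825580
   }

   def main(args: Array[String]): Unit = {
     val word = stackalloc[Byte](1024) // valid until main returns
     val count = stackalloc[Int]
     fill(word, count)
     stdio.printf(c"%s: %d occurrences\n", word, !count)
   }
 }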

But the actual parsing code now has more logic in it and takes more arguments:

InputAndOutput/max_ngram/main.scala
 def parse_and_compare(line_buffer: CString, max_word: CString,
     temp_word: CString, max_word_buffer_size: Int,
     max_count: Ptr[Int], temp_count: Ptr[Int], max_year: Ptr[Int],
     temp_year: Ptr[Int], temp_doc_count: Ptr[Int]): Unit = {
   val scan_result = stdio.sscanf(line_buffer, c"%1023s %d %d %d\n",
     temp_word, temp_year, temp_count, temp_doc_count)
   if (scan_result < 4) {
     throw new Exception("bad input")
   }
   if (!temp_count <= !max_count) {
     return
   } else {
     stdio.printf(c"saw new max: %s %d occurrences at year %d\n", temp_word,
       !temp_count, !temp_year)
     val word_length = string.strlen(temp_word)
     if (word_length >= (max_word_buffer_size - 1)) {
       throw new Exception(
         s"length $word_length exceeded buffer size $max_word_buffer_size")
     }
     string.strncpy(max_word, temp_word, max_word_buffer_size)
     !max_count = !temp_count
     !max_year = !temp_year
   }
 }
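
If you want to exercise the parsing path in isolation, a small helper like the following could be dropped into the same object and pointed at the sample line from earlier. This is just a sketch for experimentation; the project itself only ever calls parse_and_compare from the read loop in main():

 def parse_single_line_demo(): Unit = {
   // Stand-alone buffers, mirroring the ones main() allocates.
   val max_word = stackalloc[Byte](1024)
   val temp_word = stackalloc[Byte](1024)
   val max_count = stackalloc[Int]
   val max_year = stackalloc[Int]
   val temp_count = stackalloc[Int]
   val temp_year = stackalloc[Int]
   val temp_doc_count = stackalloc[Int]
   !max_count = 0
   !max_year = 0

   // 45 occurrences beats the initial max of 0, so the max_* values are updated.
   parse_and_compare(c"A'Aang_NOUN 1879 45 5\n", max_word, temp_word, 1024,
     max_count, temp_count, max_year, temp_year, temp_doc_count)
   stdio.printf(c"%s: %d occurrences in %d\n", max_word, !max_count, !max_year)
 }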

We’ll see the impact of these optimizations when we run the program and collect timing information:

 $ time ./target/scala-2.11/sort_alphabetically-out < \
   /code/googlebooks-eng-all-1gram-20120701-a
 reading from STDIN...
 found new max: A'Aang_NOUN 1879 45
 found new max: A.E.U._DET 1975 65
 found new max: A.J.B._NOUN 1960 72
 found new max: A.J.B._NOUN 1963 300
 found new max: A.J.B._NOUN 1995 393
 ...
 found new max: and_CONJ 2007 380846175
 found new max: and_CONJ 2008 470334485
 found new max: and 2008 470825580
 done. read 86618505 lines
 maximum word count: 470825580 for 'and' @ 2008
 real 0m 33.94s
 user 0m 30.14s
 sys 0m 2.50s

Impressive, no?

With a moderate amount of effort, we can improve the performance of our program by roughly a factor of seven: our JVM version took almost four minutes to complete, whereas the Scala Native version finished in about thirty-four seconds. We only had to write a few more lines of code than in our “vanilla” Scala implementation, and the results are strikingly better. This is a pattern we’ll continue to see: Scala Native makes high-performance, low-level patterns more concise, legible, and safe than ever before.