Loading multiple CSVs for RNN data pipelines

Before wrapping up this chapter, here are a few notes on loading multiple CSV files, each containing one sequence, as RNN training and test data. We assume a dataset made of multiple CSV files stored in a cluster (on HDFS or in an object store such as Amazon S3 or Minio), where each file represents one sequence, each row of a file holds the values for a single time step, the number of rows may differ from file to file, and the header row is either present in every file or absent from every file.

With reference to CSV files saved in an S3-based object storage (refer to Chapter 3, Extract, Transform, Load, Data Ingestion from S3, for more details), the Spark context has been created as follows:

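import org.apache.spark.SparkConf
import org.apache.spark.api.java.JavaSparkContext

// master is assumed to hold the Spark master URL, defined elsewhere
val conf = new SparkConf
conf.setMaster(master)
conf.setAppName("DataVec S3 Example")
val sparkContext = new JavaSparkContext(conf)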
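For readers who do not have Chapter 3 at hand, access to an S3-compatible store is commonly configured through the Hadoop configuration attached to the context. The following is only a minimal sketch, assuming the Hadoop S3A connector is on the classpath and that the credentials are available in the environment variables AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY. The helper name configureS3a and its endpoint parameter are hypothetical and are not part of the original example:

```scala
import org.apache.spark.api.java.JavaSparkContext

// Hypothetical helper: applies S3A connector settings to an existing context.
def configureS3a(sc: JavaSparkContext, endpoint: Option[String]): Unit = {
  val hadoopConf = sc.hadoopConfiguration()
  // Credentials are read from the environment (an assumption of this sketch);
  // sys.env(...) throws NoSuchElementException if a variable is missing.
  hadoopConf.set("fs.s3a.access.key", sys.env("AWS_ACCESS_KEY_ID"))
  hadoopConf.set("fs.s3a.secret.key", sys.env("AWS_SECRET_ACCESS_KEY"))
  // A custom endpoint is only needed for non-AWS stores such as Minio.
  endpoint.foreach { url =>
    hadoopConf.set("fs.s3a.endpoint", url)
    hadoopConf.set("fs.s3a.path.style.access", "true")
  }
}
```

With such a helper, a call like configureS3a(sparkContext, None) would target Amazon S3, while passing Some(url) would point the connector at a self-hosted store.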

The Spark job configuration has been set up to access the object storage (as explained in Chapter 3, Extract, Transform, Load), and we can get the data as follows:

val origData = sparkContext.binaryFiles("s3a://dl4j-bucket")

Here, dl4j-bucket is the bucket containing the CSV files. Next, we create a DataVec CSVSequenceRecordReader, specifying the number of header lines to skip in each file (0 if the files have no header row, 1 if they have one) and the value separator, as follows:

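import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader

val numHeaderLinesEachFile = 0  // 0: no header row; 1: skip one header line
val delimiter = ","
val seqRR = new CSVSequenceRecordReader(numHeaderLinesEachFile, delimiter)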

Finally, we obtain the sequences by applying a map transformation to the original data, passing seqRR to a SequenceRecordReaderFunction, as follows:

import org.datavec.spark.functions.SequenceRecordReaderFunction
val sequencesRdd = origData.map(new SequenceRecordReaderFunction(seqRR))
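Each element of the resulting RDD is one sequence, that is, a list of time steps in which every time step is itself a list of values. To make that shape concrete, here is a pure-Scala stand-in for the parsing step. It is an illustration only and not DataVec's implementation: the object CsvSequenceSketch, the function parseSequence, and the sample file contents are all hypothetical, and plain strings are used in place of DataVec's Writable values.

```scala
// Illustrative only: mimics the shape of the data a CSV sequence reader yields.
object CsvSequenceSketch {
  // One time step = one CSV row; one sequence = all rows of one file.
  type TimeStep = Vector[String]
  type Sequence = Vector[TimeStep]

  /** Parses the full text of one CSV file into a sequence.
    * @param content        text of the file (hypothetical input)
    * @param numHeaderLines number of leading lines to skip
    * @param delimiter      literal value separator
    */
  def parseSequence(content: String, numHeaderLines: Int, delimiter: String): Sequence = {
    val separator = java.util.regex.Pattern.quote(delimiter)
    content
      .split("\\r?\\n")
      .toVector
      .drop(numHeaderLines)            // skip the header row(s), if any
      .map(_.trim)
      .filter(_.nonEmpty)              // ignore blank lines
      .map(line => line.split(separator, -1).map(_.trim).toVector)
  }
}

// Two hypothetical files with different numbers of rows (time steps).
val fileA = "0.1,2,0.3\n0.4,0,0.6\n0.7,1,0.9"
val fileB = "1.0,3,1.2\n1.3,1,1.5"
val seqA = CsvSequenceSketch.parseSequence(fileA, 0, ",")
val seqB = CsvSequenceSketch.parseSequence(fileB, 0, ",")
```

Here seqA holds three time steps and seqB holds two, which mirrors the assumption that the number of rows can differ across files.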

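Converting the sequences into DataSet objects for RNN training works much as it does for non-sequence CSV files, where the DataVecDataSetFunction class of dl4j-spark is used. For sequences, we use its counterpart, DataVecSequenceDataSetFunction, specifying the index of the label column, the number of classes, and false as the last argument to indicate classification rather than regression, as follows: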

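import org.deeplearning4j.spark.datavec.DataVecSequenceDataSetFunction

val labelIndex = 1  // index of the label column in each row
val numClasses = 4  // number of classes for classification
// The last argument is false because this is classification, not regression
val dataSetRdd = sequencesRdd.map(new DataVecSequenceDataSetFunction(labelIndex, numClasses, false))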
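To illustrate how the label index and the number of classes are interpreted in the classification case, the following pure-Scala sketch splits each time step into a feature vector and a one-hot label. It is a conceptual stand-in under stated assumptions, not the dl4j-spark implementation: the names SequenceLabelSketch, Step, toStep, and toSteps are hypothetical, and the label column is assumed to contain integer class indices starting at 0.

```scala
// Illustrative only: conceptual view of the features/label split per time step.
object SequenceLabelSketch {
  /** Features and one-hot label for a single time step. */
  final case class Step(features: Vector[Double], label: Vector[Double])

  def toStep(row: Vector[String], labelIndex: Int, numClasses: Int): Step = {
    val classId = row(labelIndex).toInt
    require(classId >= 0 && classId < numClasses,
      s"class $classId outside [0, $numClasses)")
    // Every column except the label column becomes a feature.
    val features = row.indices
      .filter(_ != labelIndex)
      .map(i => row(i).toDouble)
      .toVector
    // One-hot encoding: 1.0 at the class index, 0.0 elsewhere.
    val oneHot = Vector.tabulate(numClasses)(i => if (i == classId) 1.0 else 0.0)
    Step(features, oneHot)
  }

  def toSteps(sequence: Vector[Vector[String]], labelIndex: Int, numClasses: Int): Vector[Step] =
    sequence.map(toStep(_, labelIndex, numClasses))
}

// A hypothetical two-step sequence with the label in column 1 and 4 classes.
val steps = SequenceLabelSketch.toSteps(
  Vector(Vector("0.5", "2", "0.7"), Vector("0.1", "0", "0.3")),
  labelIndex = 1,
  numClasses = 4)
```

With a label index of 1 and four classes, the first row above yields the features (0.5, 0.7) and a label vector with its single 1.0 at position 2.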