Chapter 9. Object-Oriented Programming


Many programmers believe that object-oriented programming (OOP) makes for clearer, more reusable code. Though very different from the familiar OOP languages like C++, Java, and Python, R is very much OOP in outlook.

The following themes are key to R:

  • Everything you touch in R, ranging from numbers to character strings to matrices, is an object.

  • R promotes encapsulation, the packaging of separate but related data items into one class instance. Encapsulation helps you keep track of related variables, enhancing clarity.

  • R classes are polymorphic, meaning that the same function call leads to different operations for objects of different classes. For instance, a call to print() on an object of a given class triggers a call to a print function tailored to that class. Polymorphism promotes reusability.

  • R allows inheritance, which lets a general-purpose class be extended to a more specialized class.

This chapter covers OOP in R. We’ll discuss programming with the two types of classes, S3 and S4, and then present a few useful OOP-related R utilities.

The original R structure for classes, known as S3, is still the dominant class paradigm in R use today. Indeed, most of R’s own built-in classes are of the S3 type.

An S3 class consists of a list, with a class name attribute and dispatch capability added. The latter enables the use of generic functions, as we saw in Chapter 1. S4 classes were developed later, with the goal of adding safety, meaning that you cannot accidentally access a class component that does not already exist.
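
To see that safety distinction concretely, here is a minimal sketch (the variable names and the toy S4 class "emp4" are made up for illustration; S4 classes are covered in detail later in the chapter). With S3, a misspelled component name silently yields NULL, while S4 signals an error:

> emp <- list(name="Joe", salary=55000)
> class(emp) <- "employee"   # S3: just a list carrying a class attribute
> emp$salry                  # misspelled component goes unnoticed
NULL
> setClass("emp4", representation(name="character", salary="numeric"))
> e4 <- new("emp4", name="Joe", salary=55000)
> e4@salry                   # misspelled slot name is caught
Error: no slot of name "salry" for this object of class "emp4"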

As an example, let’s look at a simple regression analysis run via R’s lm() function. First, let’s see what lm() does:

> ?lm

The output of this help query will tell you, among other things, that this function returns an object of class "lm".

Let’s try creating an instance of this object and then printing it:

> x <- c(1,2,3)
> y <- c(1,3,8)
> lmout <- lm(y ~ x)
> class(lmout)
[1] "lm"
> lmout

Call:
lm(formula = y ~ x)

Coefficients:
(Intercept)            x
       -3.0          3.5

Here, we printed out the object lmout. (Remember that by simply typing the name of an object in interactive mode, the object is printed.) The R interpreter then saw that lmout was an object of class "lm" and thus called print.lm(), a special print method for the "lm" class. In R terminology, the call to the generic function print() was dispatched to the method print.lm() associated with the class "lm".

Let’s take a look at the generic function and the class method in this case:

> print
function(x, ...) UseMethod("print")
<environment: namespace:base>
> print.lm
function (x, digits = max(3, getOption("digits") - 3), ...)
{
    cat("\nCall:\n", deparse(x$call), "\n\n", sep = "")
    if (length(coef(x))) {
        cat("Coefficients:\n")
        print.default(format(coef(x), digits = digits), print.gap = 2,
            quote = FALSE)
    }
    else cat("No coefficients\n")
    cat("\n")
    invisible(x)
}
<environment: namespace:stats>

You may be surprised to see that print() consists solely of a call to UseMethod(). But UseMethod() is the dispatcher itself, so in view of print()’s role as a generic function, you should not be surprised after all.
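
To see the dispatch mechanism in isolation, here is a minimal sketch in which we write our own generic function (the name summarize() is just an illustration, not a standard R function):

> summarize <- function(x, ...) UseMethod("summarize")
> summarize.lm <- function(x, ...) cat("a linear model with", length(coef(x)), "coefficients\n")
> summarize.default <- function(x, ...) cat("no class-specific summary\n")
> summarize(lmout)
a linear model with 2 coefficients
> summarize(28)
no class-specific summary

The call summarize(lmout) is dispatched to summarize.lm(), just as print(lmout) is dispatched to print.lm().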

Don’t worry about the details of print.lm(). The main point is that the printing depends on context, with a special print function called for the "lm" class. Now let’s see what happens when we print this object with its class attribute removed:

> unclass(lmout)
$coefficients
(Intercept)           x
       -3.0         3.5

$residuals
   1    2    3
 0.5 -1.0  0.5

$effects
(Intercept)           x
  -6.928203   -4.949747    1.224745

$rank
[1] 2
...

I’ve shown only the first few lines here—there’s a lot more. (Try running this on your own!) But you can see that the author of lm() decided to make print.lm() much more concise, limiting it to printing a few key quantities.

You can find all the implementations of a given generic method by calling methods(), like this:

> methods(print)
  [1] print.acf*
  [2] print.anova
  [3] print.aov*
  [4] print.aovlist*
  [5] print.ar*
  [6] print.Arima*
  [7] print.arima0*
  [8] print.AsIs
  [9] print.aspell*
 [10] print.Bibtex*
 [11] print.browseVignettes*
 [12] print.by
 [13] print.check_code_usage_in_package*
 [14] print.check_demo_index*
 [15] print.checkDocFiles*
 [16] print.checkDocStyle*
 [17] print.check_dotInternal*
 [18] print.checkFF*
 [19] print.check_make_vars*
 [20] print.check_package_code_syntax*
...

Asterisks denote nonvisible functions, meaning ones that are not in the default namespaces. You can find these functions via getAnywhere() and then access them by using a namespace qualifier. An example is print.aspell(). The aspell() function itself does a spellcheck on the file specified in its argument. For example, suppose the file wrds consists of this line:

Which word is mispelled?

In this case, this function will catch the misspelled word, as follows:

> aspell("wrds")
mispelled
  wrds:1:15

The output says that there is the indicated spelling error in line 1, character 15 of the input file. But what concerns us here is the mechanism by which that output was printed.

The aspell() function returns an object of class "aspell", which has its own print method, print.aspell(). In fact, that method was invoked in our example: after the call to aspell(), the return value was printed, at which point R called UseMethod() on the object of class "aspell" and dispatched to print.aspell(). But if we call that print method directly, R won’t recognize it:

> aspout <- aspell("wrds")
> print.aspell(aspout)
Error: could not find function "print.aspell"

However, we can find it by calling getAnywhere():

> getAnywhere(print.aspell)
A single object matching 'print.aspell' was found
It was found in the following places
  registered S3 method for print from namespace utils
  namespace:utils
with value

function (x, sort = TRUE, verbose = FALSE, indent = 2L, ...)
{
    if (!(nr <- nrow(x)))
...

So, the function is in the utils namespace, and we can execute it by adding such a qualifier:

> utils:::print.aspell(aspout)
mispelled
  wrds:1:15

You can see all the generic methods this way:

> methods(class="default")
...

S3 classes have a rather cobbled-together structure. A class instance is created by forming a list, with the components of the list being the member variables of the class. (Readers who know Perl may recognize this ad hoc nature in Perl’s own OOP system.) The "class" attribute is set by hand by using the attr() or class() function, and then various implementations of generic functions are defined. We can see this in the case of lm() by inspecting the function:

> lm
...
z <- list(coefficients = if (is.matrix(y))
                    matrix(,0,3) else numeric(0L), residuals = y,
          fitted.values = 0 * y, weights = w, rank = 0L,
          df.residual = if (is.matrix(y)) nrow(y) else length(y))
}
...
class(z) <- c(if(is.matrix(y)) "mlm", "lm")
...

Again, don’t mind the details; the basic process is there. A list was created and assigned to z, which will serve as the framework for the "lm" class instance (and which will eventually be the value returned by the function). Some components of that list, such as residuals, were already assigned when the list was created. In addition, the class attribute was set to "lm" (and possibly to "mlm", as will be explained in the next section).

As an example of how to write an S3 class, let’s switch to something simpler. Continuing our employee example from Section 4.1, we could write this:

> j <- list(name="Joe", salary=55000, union=T)
> class(j) <- "employee"
> attributes(j)  # let's check
$names
[1] "name"   "salary" "union"

$class
[1] "employee"

Before we write a print method for this class, let’s see what happens when we call the default print():

> j
$name
[1] "Joe"

$salary
[1] 55000

$union
[1] TRUE

attr(,"class")
[1] "employee"

Essentially, j was treated as a list for printing purposes.

Now let’s write our own print method:

print.employee <- function(wrkr) {
   cat(wrkr$name,"\n")
   cat("salary",wrkr$salary,"\n")
   cat("union member",wrkr$union,"\n")
}

So, any call to print() on an object of class "employee" should now be referred to print.employee(). We can check that formally:

> methods(,"employee")
[1] print.employee

Or, of course, we can simply try it out:

> j
Joe
salary 55000
union member TRUE

Now it’s time for a more involved example, in which we will write an R class "ut" for upper-triangular matrices. These are square matrices whose elements below the diagonal are zeros, such as shown in Equation 9-1.

Our motivation here is to save storage space (though at the expense of a little extra access time) by storing only the nonzero portion of the matrix.

The component mat of this class will store the matrix. As mentioned, to save on storage space, only the diagonal and above-diagonal elements will be stored, in column-major order. Storage for the matrix in Equation 9-1, for instance, consists of the vector (1,5,6,12,9,2), and the component mat has that value.

We will include a component ix in this class, to show where in mat the various columns begin. For the preceding case, ix is c(1,2,4), meaning that column 1 begins at mat[1], column 2 begins at mat[2], and column 3 begins at mat[4]. This allows for handy access to individual elements or columns of the matrix.
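
To make this concrete, here is a session sketch using the matrix described by that storage vector (this is the matrix of Equation 9-1), run after sourcing the class code given below:

> m1 <- rbind(c(1,5,12), c(0,6,9), c(0,0,2))
> m1
     [,1] [,2] [,3]
[1,]    1    5   12
[2,]    0    6    9
[3,]    0    0    2
> u <- ut(m1)
> u$mat
[1]  1  5  6 12  9  2
> u$ix
[1] 1 2 4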

The following is the code for our class.

1    # class "ut", compact storage of upper-triangular matrices
2
3    # utility function, returns 1+...+i
4    sum1toi <- function(i) return(i*(i+1)/2)
5
6    # create an object of class "ut" from the full matrix inmat (0s included)
7    ut <- function(inmat) {
8       n <- nrow(inmat)
9       rtrn <- list()  # start to build the object
10       class(rtrn) <- "ut"
11       rtrn$mat <- vector(length=sum1toi(n))
12       rtrn$ix <- sum1toi(0:(n-1)) + 1
13       for (i in 1:n) {
14          # store column i
15          ixi <- rtrn$ix[i]
16          rtrn$mat[ixi:(ixi+i-1)] <- inmat[1:i,i]
17       }
18       return(rtrn)
19    }
20
21    # uncompress utmat to a full matrix
22    expandut <- function(utmat) {
23       n <- length(utmat$ix)  # numbers of rows and cols of matrix
24       fullmat <- matrix(nrow=n,ncol=n)
25       for (j in 1:n) {
26          # fill jth column
27          start <- utmat$ix[j]
28          fin <- start + j - 1
29       abovediagj <- utmat$mat[start:fin] # col j, diagonal and above
30          fullmat[,j] <- c(abovediagj,rep(0,n-j))
31       }
32       return(fullmat)
33    }
34
35    # print matrix
36    print.ut <- function(utmat)
37       print(expandut(utmat))
38
39    # multiply one ut matrix by another, returning another ut instance;
40    # implement as a binary operation
41    "%mut%" <- function(utmat1,utmat2) {
42       n <- length(utmat1$ix)  # numbers of rows and cols of matrix
43       utprod <- ut(matrix(0,nrow=n,ncol=n))
44       for (i in 1:n) {  # compute col i of product
45          # let a[j] and bj denote columns j of utmat1 and utmat2, respectively,
46          # so that, e.g. b2[1] means element 1 of column 2 of utmat2
47          # then column i of product is equal to
48          #    bi[1]*a[1] + ... + bi[i]*a[i]
49          # find index of start of column i in utmat2
50          startbi <- utmat2$ix[i]
51          # initialize vector that will become bi[1]*a[1] + ... + bi[i]*a[i]
52          prodcoli <- rep(0,i)
53          for (j in 1:i) {  # find bi[j]*a[j], add to prodcoli
54             startaj <- utmat1$ix[j]
55             bielement <- utmat2$mat[startbi+j-1]
56             prodcoli[1:j] <- prodcoli[1:j] +
57                bielement * utmat1$mat[startaj:(startaj+j-1)]
58            }
59          # store column i of the product in its compressed slot
60          startprodcoli <- sum1toi(i-1)+1
61          utprod$mat[startprodcoli:(startprodcoli+i-1)] <- prodcoli
62       }
63       return(utprod)
64    }

Let’s test it.

> test
function() {
   utm1 <- ut(rbind(1:2,c(0,2)))
   utm2 <- ut(rbind(3:2,c(0,1)))
   utp <- utm1 %mut% utm2
   print(utm1)
   print(utm2)
   print(utp)
   utm1 <- ut(rbind(1:3,0:2,c(0,0,5)))
   utm2 <- ut(rbind(4:2,0:2,c(0,0,1)))
   utp <- utm1 %mut% utm2
   print(utm1)
   print(utm2)
   print(utp)
}
> test()
     [,1] [,2]
[1,]    1    2
[2,]    0    2
     [,1] [,2]
[1,]    3    2
[2,]    0    1
     [,1] [,2]
[1,]    3    4
[2,]    0    2
     [,1] [,2] [,3]
[1,]    1    2    3
[2,]    0    1    2
[3,]    0    0    5
     [,1] [,2] [,3]
[1,]    4    3    2
[2,]    0    1    2
[3,]    0    0    1
     [,1] [,2] [,3]
[1,]    4    5    9
[2,]    0    1    4
[3,]    0    0    5

Throughout the code, we take into account the fact that the matrices involved have a lot of zeros. For example, we avoid multiplying by zeros simply by not adding terms to sums when the terms include a 0 factor.

The ut() function is fairly straightforward. This function is a constructor, that is, a function whose job is to create an instance of the given class and eventually return that instance. So in line 9, we create a list that will serve as the body of the class object, naming it rtrn as a reminder that this will be the class instance to be constructed and returned.

As noted earlier, the main member variables of our class will be mat and ix, implemented as components of the list. Memory for these two components is allocated in lines 11 and 12.

The loop that follows then fills in rtrn$mat column by column. (The companion component rtrn$ix was already computed, in a single vectorized operation, on line 12.) A slicker way to do this for loop would be to use the rather obscure row() and col() functions. The row() function takes a matrix input and returns a new matrix of the same size, but with each element replaced by its row number. Here’s an example:

> m
     [,1] [,2]
[1,]    1    4
[2,]    2    5
[3,]    3    6
> row(m)
     [,1] [,2]
[1,]    1    1
[2,]    2    2
[3,]    3    3

The col() function works similarly.
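
For instance, on the same matrix m:

> col(m)
     [,1] [,2]
[1,]    1    2
[2,]    1    2
[3,]    1    2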

Using this idea, we could replace the for loop in ut() with a one-liner:

rtrn$mat <- inmat[row(inmat) <= col(inmat)]
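
As a quick check, applying the same expression to the Equation 9-1 matrix m1 from earlier recovers exactly the storage vector we want:

> m1[row(m1) <= col(m1)]
[1]  1  5  6 12  9  2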

Whenever possible, we should exploit vectorization. Take a look at line 12, for example:

rtrn$ix <- sum1toi(0:(n-1)) + 1

Since sum1toi() (which we defined on line 4) is based only on the vectorized functions "*"() and "+"(), sum1toi() itself is also vectorized. This allows us to apply sum1toi() to the vector 0:(n-1) in line 12. Note that recycling comes into play as well, when the scalar 1 is added to the resulting vector.
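
For instance, with n equal to 3 (assuming the listing above has been sourced):

> sum1toi(0:2) + 1   # columns 1, 2, and 3 begin at mat[1], mat[2], and mat[4]
[1] 1 2 4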

We want our "ut" class to include some methods, not just variables. To this end, we have included three methods:

  • The expandut() function converts from a compressed matrix to an ordinary one. In expandut(), the key lines are 27 and 28, where we use utmat$ix to determine where in utmat$mat the jth column of our matrix is stored. That data is then copied to the jth column of fullmat in line 30. Note the use of rep() to generate the zeros in the lower portion of this column.

  • The print.ut() function is for printing. This function is quick and easy, using expandut(). Recall that any call to print() on an object of type "ut" will be dispatched to print.ut(), as in our test cases earlier.

  • The "%mut%"() function is for multiplying two compressed matrices (without uncompressing them). This function starts in line 39. Since this is a binary operation, we take advantage of the fact that R accommodates user-defined binary operations, as described in Section 7.12, and implement our matrix-multiply function as %mut%.

Let’s look at the details of the "%mut%"() function. First, in line 43, we allocate space for the product matrix. Note the use of recycling in an unusual context. The first argument of matrix() is required to be a vector of a length compatible with the number of specified rows and columns, so the 0 we provide is recycled to a vector of length n^2. Of course, rep() could be used instead, but exploiting recycling makes for a bit shorter, more elegant code.

For both clarity and fast execution, the code here has been written around the fact that R stores matrices in column-major order. As mentioned in the comments, our code then makes use of the fact that column i of the product can be expressed as a linear combination of the columns of the first factor. It will help to see a specific example of this property, shown in Equation 9-2.

The comments say that, for instance, column 3 of the product is equal to

   b3[1]*a[1] + b3[2]*a[2] + b3[3]*a[3]

that is, a linear combination of the columns a[1], a[2], and a[3] of the first factor, with the coefficients taken from column 3 of the second factor. Inspection of Equation 9-2 confirms the relation.
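
Here is a quick numerical check of that relation, using the 3-by-3 matrices from the second test case above as a stand-in for Equation 9-2:

> a <- rbind(1:3, 0:2, c(0,0,5))   # first factor, as in test()
> b <- rbind(4:2, 0:2, c(0,0,1))   # second factor; b[,3] is c(2,2,1)
> 2*a[,1] + 2*a[,2] + 1*a[,3]      # b3[1]*a[1] + b3[2]*a[2] + b3[3]*a[3]
[1] 9 4 5
> (a %*% b)[,3]                    # column 3 of the ordinary matrix product
[1] 9 4 5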

Couching the multiplication problem in terms of columns of the two input matrices enables us to compact the code and to likely increase the speed. The latter again stems from vectorization, a benefit discussed in detail in Chapter 14. This approach is used in the loop beginning at line 53. (Arguably, in this case, the increase in speed comes at the expense of readability of the code.)

As another example, consider a statistical regression setting with one predictor variable. Since any statistical model is merely an approximation, in principle, you can get better and better models by fitting polynomials of higher and higher degrees. However, at some point, this becomes overfitting, so that the prediction of new, future data actually deteriorates for degrees higher than some value.

The class "polyreg" aims to deal with this issue. It fits polynomials of various degrees but assesses fits via cross-validation to reduce the risk of overfitting. In this form of cross-validation, known as the leaving-one-out method, for each point we fit the regression to all the data except this observation, and then we predict that observation from the fit. An object of this class consists of outputs from the various regression models, plus the original data.

The following is the code for the "polyreg" class.

1    # "polyreg," S3 class for polynomial regression in one predictor variable
2
3    # polyfit(y,x,maxdeg) fits all polynomials up to degree maxdeg; y is
4    # vector for response variable, x for predictor; creates an object of
5    # class "polyreg"
6    polyfit <- function(y,x,maxdeg) {
7       # form powers of predictor variable, ith power in ith column
8       pwrs <- powers(x,maxdeg)  # could use orthog polys for greater accuracy
9       lmout <- list()  # start to build class
10       class(lmout) <- "polyreg"  # create a new class
11       for (i in 1:maxdeg) {
12          lmo <- lm(y ~ pwrs[,1:i])
13          # extend the lm class here, with the cross-validated predictions
14          lmo$fitted.cvvalues <- lvoneout(y,pwrs[,1:i,drop=F])
15          lmout[[i]] <- lmo
16       }
17       lmout$x <- x
18       lmout$y <- y
19       return(lmout)
20    }
21
22    # print() for an object fits of class "polyreg":  print
23    # cross-validated mean-squared prediction errors
24    print.polyreg <- function(fits) {
25       maxdeg <- length(fits) - 2
26       n <- length(fits$y)
27       tbl <- matrix(nrow=maxdeg,ncol=1)
28       colnames(tbl) <- "MSPE"
29       for (i in 1:maxdeg) {
30          fi <- fits[[i]]
31          errs <- fits$y - fi$fitted.cvvalues
32          spe <- crossprod(errs,errs)  # sum of squared prediction errors
33          tbl[i,1] <- spe/n
34       }
35       cat("mean squared prediction errors, by degree\n")
36       print(tbl)
37    }
38
39    # forms matrix of powers of the vector x, through degree dg
40    powers <- function(x,dg) {
41       pw <- matrix(x,nrow=length(x))
42       prod <- x
43       for (i in 2:dg) {
44          prod <- prod * x
45          pw <- cbind(pw,prod)
46       }
47       return(pw)
48    }
49
50    # finds cross-validated predicted values; could be made much faster via
51    # matrix-update methods
52    lvoneout <- function(y,xmat) {
53       n <- length(y)
54       predy <- vector(length=n)
55       for (i in 1:n) {
56          # regress, leaving out ith observation
57          lmo <- lm(y[-i] ~ xmat[-i,])
58          betahat <- as.vector(lmo$coef)
59          # the 1 accommodates the constant term
60          predy[i] <- betahat %*% c(1,xmat[i,])
61       }
62       return(predy)
63    }
64
65    # polynomial function of x, coefficients cfs
66    poly <- function(x,cfs) {
67       val <- cfs[1]
68       prod <- 1
69       dg <- length(cfs) - 1
70       for (i in 1:dg) {
71          prod <- prod * x
72          val <- val + cfs[i+1] * prod
73       }
74       return(val)
75    }

As you can see, "polyreg" consists of polyfit(), the constructor function, and print.polyreg(), a print function tailored to this class. It also contains several utility functions to evaluate powers and polynomials and to perform cross-validation. (Note that in some cases here, efficiency has been sacrificed for clarity.)

As an example of using the class, we’ll generate some artificial data and create an object of class "polyreg" from it, printing out the results.

> n <- 60
> x <- (1:n)/n
> y <- vector(length=n)
> for (i in 1:n) y[i] <- sin((3*pi/2)*x[i]) + x[i]^2 + rnorm(1,mean=0,sd=0.5)
> dg <- 15
> (lmo <- polyfit(y,x,dg))
mean squared prediction errors, by degree
           MSPE
 [1,] 0.4200127
 [2,] 0.3212241
 [3,] 0.2977433
 [4,] 0.2998716
 [5,] 0.3102032
 [6,] 0.3247325
 [7,] 0.3120066
 [8,] 0.3246087
 [9,] 0.3463628
[10,] 0.4502341
[11,] 0.6089814
[12,] 0.4499055
[13,]        NA
[14,]        NA
[15,]        NA

Note first that we used a common R trick in this command:

> (lmo <- polyfit(y,x,dg))

By surrounding the entire assignment statement in parentheses, we get the printout and form lmo at the same time, in case we need the latter for other things.

The function polyfit() fits polynomial models up through a specified degree, in this case 15, calculating the cross-validated mean squared prediction error for each model. The last few values in the output were NA, because roundoff error considerations led R to refuse to fit polynomials of degrees that high.

So, how is it all done? The main work is handled by the function polyfit(), which creates an object of class "polyreg". That object consists mainly of the objects returned by the R regression fitter lm() for each degree.

In forming those objects, note line 14:

lmo$fitted.cvvalues <- lvoneout(y,pwrs[,1:i,drop=F])

Here, lmo is an object returned by lm(), but we are adding an extra component to it: fitted.cvvalues. Since S3 class instances are simply lists, and we can add a new component to a list at any time, this works.
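
As a tiny illustration of that point:

> lst <- list(a=1)
> lst$b <- 2   # add a brand-new component on the fly
> names(lst)
[1] "a" "b"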

We also have a method for the generic function print(), print.polyreg() in line 24. In Section 12.1.5, we will add a method for the plot() generic function, plot.polyreg().

In computing prediction errors, we used cross-validation, or the leaving-one-out method, in a form that predicts each observation from all the others. To implement this, we take advantage of R’s use of negative subscripts in line 57:

lmo <- lm(y[-i] ~ xmat[-i,])

So, we are fitting the model with the ith observation deleted from our data set.
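
As a reminder of how negative subscripts work:

> z <- c(5,12,13,8)
> z[-2]   # all of z except the second element
[1]  5 13  8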