1 Introduction

This document showcases how to expose user-defined Stan functions using CmdStanR. To run the examples below, you need to have CmdStanR, Rcpp and RcppEigen packages installed and an installation of CmdStan. The minimum CmdStan version is 2.26.2, though I recommend running with 2.29.2 or newer. If you experience issues using this function or running the tutorial below or have a feature request, please report it on the issue tracker here.

2 Preparing the environment

We first need to load the function expose_cmdstanr_functions(), which does all the hard work for us. The function is available on GitHub here. You can also copy and source the code below.

# Compile the user-defined functions (UDFs) in a Stan file with stanc3's
# `--standalone-functions` mode and expose them to R through Rcpp.
#
# Arguments:
#   model_path           Path to a .stan or .stanfunctions file that contains
#                        a `functions` block.
#   include_paths        Optional character vector of directories searched for
#                        `#include` directives. Relative paths are resolved
#                        against the current R working directory.
#   expose_to_global_env If TRUE, the compiled functions are defined in the
#                        global environment; otherwise in a new environment.
#
# Returns:
#   A new environment holding the exposed functions plus the helper
#   `stan_rng__(seed)`, or invisible NULL when `expose_to_global_env` is TRUE.
#
# Argument defaults of every exposed function:
#   - `pstream__` defaults to a pointer to Rcpp::Rcout.
#   - `base_rng__` (RNG functions only) defaults to a generator created with
#     seed 0 when the functions are exposed. It is shared by all calls, so it
#     advances between calls.
#   - `lp__` (if present) defaults to 0.
expose_cmdstanr_functions <- function(model_path, include_paths = NULL,
                                      expose_to_global_env = FALSE) {
  # processx and withr are called below too, so check them up front as well.
  required_pkgs <- c("Rcpp", "RcppEigen", "cmdstanr", "processx", "withr")
  found_pkgs <- vapply(
    required_pkgs, requireNamespace, logical(1), quietly = TRUE
  )
  if (!all(found_pkgs)) {
    stop(
      "The following required packages are missing: ",
      paste0(required_pkgs[!found_pkgs], collapse = ", "),
      ".",
      call. = FALSE
    )
  }
  # Compare as versions rather than strings ("2.9.0" < "2.26.0" is FALSE for
  # strings). A pre-release suffix such as "-rc1" is dropped before parsing.
  cmdstan_ver <- package_version(sub("-.*$", "", cmdstanr::cmdstan_version()))
  if (cmdstan_ver < "2.26.0") {
    stop("Please install CmdStan version 2.26 or newer.", call. = FALSE)
  }
  # Read a make variable from CmdStan (`make print-<flag_name>`), strip the
  # "<flag_name> =" prefix and newlines, and turn the relative stan/, lib/ and
  # src include paths into absolute ones so they work outside CmdStan's dir.
  get_cmdstan_flags <- function(flag_name) {
    cmdstan_path <- cmdstanr::cmdstan_path()
    flags <- processx::run(
      "make",
      args = c(paste0("print-", flag_name)),
      wd = cmdstan_path
    )$stdout
    flags <- gsub(
      pattern = paste0(flag_name, " ="),
      replacement = "", x = flags, fixed = TRUE
    )
    flags <- gsub(
      pattern = " stan/", replacement = paste0(" ", cmdstan_path, "/stan/"),
      x = flags, fixed = TRUE
    )
    flags <- gsub(
      pattern = "-I lib/", replacement = paste0("-I ", cmdstan_path, "/lib/"),
      x = flags, fixed = TRUE
    )
    flags <- gsub(
      pattern = "-I src", replacement = paste0("-I ", cmdstan_path, "/src"),
      x = flags, fixed = TRUE
    )
    gsub("\n", "", flags)
  }
  temp_stan_file <- tempfile(pattern = "model-", fileext = ".stan")
  temp_cpp_file <- paste0(tools::file_path_sans_ext(temp_stan_file), ".cpp")
  # The generated C++ is read into memory below, so both temp files can go.
  on.exit(unlink(c(temp_stan_file, temp_cpp_file)), add = TRUE)
  if (!isTRUE(file.copy(model_path, temp_stan_file, overwrite = TRUE))) {
    stop("Could not copy the Stan file '", model_path, "'.", call. = FALSE)
  }
  if (isTRUE(.Platform$OS.type == "windows")) {
    stanc3 <- "./bin/stanc.exe"
  } else {
    stanc3 <- "./bin/stanc"
  }
  stanc_args <- c(temp_stan_file, "--standalone-functions")
  if (length(include_paths) > 0) {
    # stanc runs with the CmdStan directory as its working directory, so
    # relative include paths must be made absolute first. stanc3 expects a
    # comma-separated list of directories.
    include_paths <- normalizePath(include_paths, winslash = "/",
                                   mustWork = FALSE)
    stanc_args <- c(
      stanc_args,
      paste0("--include-paths=", paste(include_paths, collapse = ","))
    )
  }
  stanc_args <- c(stanc_args, paste0("--o=", temp_cpp_file))
  processx::run(stanc3, args = stanc_args, wd = cmdstanr::cmdstan_path())
  code <- paste(readLines(temp_cpp_file), collapse = "\n")
  # Prepend Rcpp glue so that RNG objects (boost::ecuyer1988) and output
  # streams can travel between R and C++ as external pointers. Also define
  # get_stream_() and get_rng_(), which create those pointers.
  # NOTE(review): the wrap() overloads return XPtrs to the address of a
  # parameter rather than heap memory. They appear unused when the exposed
  # functions return plain values; confirm before relying on them.
  code <- paste(
    "// [[Rcpp::depends(RcppEigen)]]",
    "#include <stan/math/prim/fun/Eigen.hpp>",
    "#include <RcppCommon.h>
    #include <boost/random/additive_combine.hpp>
    #include <iostream>

    namespace Rcpp {
      SEXP wrap(boost::ecuyer1988 RNG);
      SEXP wrap(boost::ecuyer1988& RNG);
      SEXP wrap(std::ostream stream);
      template <> boost::ecuyer1988 as(SEXP ptr_RNG);
      template <> boost::ecuyer1988& as(SEXP ptr_RNG);
      template <> std::ostream* as(SEXP ptr_stream);
      namespace traits {
        template <> class Exporter<boost::ecuyer1988&>;
        template <> struct input_parameter<boost::ecuyer1988&>;
      }
    }

    #include <Rcpp.h>

    namespace Rcpp {
      SEXP wrap(boost::ecuyer1988 RNG){
        boost::ecuyer1988* ptr_RNG = &RNG;
        Rcpp::XPtr<boost::ecuyer1988> Xptr_RNG(ptr_RNG);
        return Xptr_RNG;
      }

      SEXP wrap(boost::ecuyer1988& RNG){
        boost::ecuyer1988* ptr_RNG = &RNG;
        Rcpp::XPtr<boost::ecuyer1988> Xptr_RNG(ptr_RNG);
        return Xptr_RNG;
      }

      SEXP wrap(std::ostream stream) {
        std::ostream* ptr_stream = &stream;
        Rcpp::XPtr<std::ostream> Xptr_stream(ptr_stream);
        return Xptr_stream;
      }

      template <> boost::ecuyer1988 as(SEXP ptr_RNG) {
        Rcpp::XPtr<boost::ecuyer1988> ptr(ptr_RNG);
        boost::ecuyer1988& RNG = *ptr;
        return RNG;
      }

      template <> boost::ecuyer1988& as(SEXP ptr_RNG) {
        Rcpp::XPtr<boost::ecuyer1988> ptr(ptr_RNG);
        boost::ecuyer1988& RNG = *ptr;
        return RNG;
      }

      template <> std::ostream* as(SEXP ptr_stream) {
        Rcpp::XPtr<std::ostream> ptr(ptr_stream);
        return ptr;
      }

      namespace traits {
        template <> class Exporter<boost::ecuyer1988&> {
        public:
          Exporter( SEXP x ) : t(Rcpp::as<boost::ecuyer1988&>(x)) {}
          inline boost::ecuyer1988& get() { return t ; }
        private:
          boost::ecuyer1988& t ;
        } ;

        template <>
        struct input_parameter<boost::ecuyer1988&> {
          typedef
          typename Rcpp::ConstReferenceInputParameter<boost::ecuyer1988&> type ;
          //typedef typename boost::ecuyer1988& type ;
        };
      }
    }

    RcppExport SEXP get_stream_() {
      std::ostream* pstream(&Rcpp::Rcout);
      Rcpp::XPtr<std::ostream> ptr(pstream, false);
      return ptr;
    }

    RcppExport SEXP get_rng_(SEXP seed) {
      int seed_ = Rcpp::as<int>(seed);
      boost::ecuyer1988* rng = new boost::ecuyer1988(seed_);
      Rcpp::XPtr<boost::ecuyer1988> ptr(rng, true);
      return ptr;
    }
    ",
    "#include <RcppEigen.h>",
    code,
    sep = "\n"
  )
  # Turn every Stan UDF into an Rcpp export.
  code <- gsub("// [[stan::function]]",
               "// [[Rcpp::export]]", code, fixed = TRUE)
  # Rcpp cannot supply an accumulator reference from R, so replace that
  # parameter with a local accumulator inside the function body.
  code <- gsub(
    "stan::math::accumulator<double>& lp_accum__, std::ostream* pstream__ = nullptr){",
    "std::ostream* pstream__ = nullptr){\nstan::math::accumulator<double> lp_accum__;",
    code,
    fixed = TRUE
  )
  code <- gsub("__ = nullptr", "__ = 0", code, fixed = TRUE)

  get_stream <- function() {
    return(.Call('get_stream_'))
  }
  get_rng <- function(seed = 0L) {
    if (length(seed) != 1) {
      stop("Seed must be a length-1 integer vector.")
    }
    .Call('get_rng_', seed)
  }
  env <- if (expose_to_global_env) globalenv() else new.env()
  compiled <- withr::with_makevars(
    c(
      USE_CXX14 = 1,
      PKG_CPPFLAGS = "",
      PKG_CXXFLAGS = get_cmdstan_flags("CXXFLAGS"),
      # Join with spaces so adjacent flag groups can never run together.
      PKG_LIBS = paste(
        get_cmdstan_flags("LDLIBS"),
        get_cmdstan_flags("LIBSUNDIALS"),
        get_cmdstan_flags("TBB_TARGETS"),
        get_cmdstan_flags("LDFLAGS_TBB")
      )
    ),
    Rcpp::sourceCpp(code = code, env = env)
  )
  # Give the Stan-specific arguments usable R defaults so callers can omit them.
  for (fn_name in compiled$functions) {
    fn <- get(fn_name, envir = env)
    fn_args <- formals(fn)
    fn_args$pstream__ <- get_stream()
    if ("lp__" %in% names(fn_args)) fn_args$lp__ <- 0
    if ("base_rng__" %in% names(fn_args)) fn_args$base_rng__ <- get_rng()
    formals(fn) <- fn_args
    assign(fn_name, fn, envir = env)
  }
  assign("stan_rng__", get_rng, envir = env)
  if (expose_to_global_env) {
    invisible(NULL)
  } else {
    return(env)
  }
}

3 Simple example

Once we load the above expose_cmdstanr_functions() function into the R environment, we can try it out on a Stan function that applies softmax() to each row of an input matrix. We can write the function(s) in a file with a .stan or .stanfunctions suffix, or in a string inside an R script. Using separate files is recommended, as developing functions inside strings is not user-friendly. To make this tutorial easier to copy and run on your own, this example uses Stan functions written in strings.

# Stan program consisting only of a `functions` block. rows_softmax() applies
# softmax() to each row of its input matrix; each row is transposed to a
# vector for softmax() and the result is transposed back into the row.
model_code <- "
functions {
  matrix rows_softmax(matrix x) {
    matrix[rows(x), cols(x)] y;
    for(i in 1:rows(x)) {
      y[i, :] = softmax(x[i, ]')';
    }
    return y;
  }
}
"

We then use the write_stan_file() utility function from cmdstanr to store the model code in a file. If you prefer, you can also use the write() function in base R or any other function that writes a string to a file. Make sure the file has a .stan or .stanfunctions extension.

# Write the Stan code to a file; the resulting path is passed on below.
stan_file <- cmdstanr::write_stan_file(code = model_code)

Finally, we supply the written stan_file to expose_cmdstanr_functions():

# Compile the Stan functions and collect them in a new environment.
udfs <- expose_cmdstanr_functions(model_path = stan_file)

The function returns a new environment containing the exposed Stan UDFs, so we can call udfs$rows_softmax() directly from R. We first create a matrix of random values,

# 5 x 5 matrix of uniform draws on (0, 2). No seed is set, so the values (and
# the outputs shown below) differ between runs.
input_matrix <- matrix(runif(25, 0, 2), nrow = 5)
input_matrix
##           [,1]      [,2]      [,3]     [,4]      [,5]
## [1,] 0.3372246 0.8277205 1.9614013 1.454146 0.7730535
## [2,] 1.3449728 1.4893030 1.7413037 1.377557 0.4452663
## [3,] 1.7938820 1.3789612 0.4099089 0.279518 0.7293895
## [4,] 1.7489683 1.2871712 1.0768524 1.586360 1.7205503
## [5,] 0.3748405 1.1560285 0.8086070 1.200402 0.5067908

pass the input matrix to the Stan UDF

# Call the exposed Stan function directly from R.
res <- udfs$rows_softmax(input_matrix)
res
##            [,1]      [,2]      [,3]       [,4]       [,5]
## [1,] 0.08124108 0.1326769 0.4122367 0.24822677 0.12561854
## [2,] 0.19679475 0.2273502 0.2925081 0.20331275 0.08003426
## [3,] 0.40390587 0.2667364 0.1012111 0.08883829 0.13930827
## [4,] 0.25235426 0.1590211 0.1288589 0.21448195 0.24528380
## [5,] 0.12267925 0.2679395 0.1893013 0.28009665 0.13998335

and finally validate that rows do actually sum to 1.

# Each row of a softmax result should sum to 1; rowSums() computes all row
# sums in one vectorized call (instead of apply(res, 1, sum)).
rowSums(res)
## [1] 1 1 1 1 1

And that is it! Well, at least if you do not plan on working with random number generators. If you do, continue with the next section.

4 RNG user-defined functions

When working with RNG functions, we need a way to specify the seed for the random number generator to make things reproducible. In this example, we use the gpareto_rng function from Aki Vehtari’s case study. You can also find it in Sean Pinkney’s helpful Stan functions repo. We can proceed to expose the function:

# gpareto_rng() draws from a generalized Pareto distribution with lower bound
# ymin, shape k and scale sigma. It inverts the CDF at a uniform draw and uses
# the exponential limit when k is (numerically) zero.
# abs() replaces the original fabs(), which is deprecated in Stan and rejected
# by newer stanc releases.
gpareto_code <- "
functions {
  real gpareto_rng(real ymin, real k, real sigma) {
    if (sigma <= 0) 
      reject(\"sigma <= 0; found sigma = \", sigma);
    
    if (abs(k) > 1e-15) 
      return ymin + (uniform_rng(0, 1) ^ -k - 1) * sigma / k;
    else 
      return ymin - sigma * log(uniform_rng(0, 1)); // limit k->0
  }
}
"
stan_file_rng <- cmdstanr::write_stan_file(code = gpareto_code)
udfs_rng <- expose_cmdstanr_functions(model_path = stan_file_rng)

We define some inputs and generate a few values:

# Random inputs (no seed is set). Without an explicit base_rng__, each call
# uses the function's default generator, which advances between calls, so the
# three draws differ.
ymin <- rexp(1)
k <- rexp(1,5)
sigma <- rexp(1)
udfs_rng$gpareto_rng(ymin, k, sigma)
## [1] 1.285279
udfs_rng$gpareto_rng(ymin, k, sigma)
## [1] 1.293365
udfs_rng$gpareto_rng(ymin, k, sigma)
## [1] 1.42966

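If we do not specify a seed, the function uses its default generator, which is created with seed 0 when the functions are exposed. That generator is shared between calls and advances with each one, which is why the three calls above return different values.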

If we examine the names in the returned environment, we will find the gpareto_rng function we expected, but also a function called stan_rng__. This function creates a random number generator object with a specified seed. Its name has a double-underscore suffix to avoid clashing with a Stan UDF of the same name.

# The environment holds the exposed UDF plus the stan_rng__() helper.
names(udfs_rng)
## [1] "stan_rng__"  "gpareto_rng"

If we want to make our code reproducible, we create an RNG object with a seed and pass it as the base_rng__ argument of the exposed RNG function. The generator object advances with each call to the Stan RNG user-defined function.

ymin <- rexp(1)
k <- rexp(1,5)
sigma <- rexp(1)
# `rng` is a generator object (not a seed) created with seed 1.
rng <- udfs_rng$stan_rng__(1)
# Every call draws from, and advances, the same generator. vapply() collects
# the ten draws in order without growing a vector inside a loop.
vals <- vapply(
  seq_len(10),
  function(i) udfs_rng$gpareto_rng(ymin, k, sigma, base_rng__ = rng),
  numeric(1)
)
vals
##  [1] 1.256861 1.270420 1.489992 1.866879 3.310240 2.306312 1.475671 2.727928
##  [9] 2.161117 2.284900

If we then create a separate generator object with the same seed, we will get the same stream of generated values.

# A fresh generator with the same seed reproduces the same sequence of draws.
rng <- udfs_rng$stan_rng__(1)
vals <- vapply(
  seq_len(10),
  function(i) udfs_rng$gpareto_rng(ymin, k, sigma, base_rng__ = rng),
  numeric(1)
)
vals
##  [1] 1.256861 1.270420 1.489992 1.866879 3.310240 2.306312 1.475671 2.727928
##  [9] 2.161117 2.284900

5 Exposing UDFs to the global namespace

If you want to expose the functions directly in the global R environment, set the expose_to_global_env argument to TRUE. In that case, expose_cmdstanr_functions() returns NULL invisibly instead of a new environment.

# With expose_to_global_env = TRUE, rows_softmax() is defined in the global
# environment and can be called without the `udfs$` prefix.
expose_cmdstanr_functions(model_path = stan_file, expose_to_global_env = TRUE)
rows_softmax(input_matrix)
##            [,1]      [,2]      [,3]       [,4]       [,5]
## [1,] 0.08124108 0.1326769 0.4122367 0.24822677 0.12561854
## [2,] 0.19679475 0.2273502 0.2925081 0.20331275 0.08003426
## [3,] 0.40390587 0.2667364 0.1012111 0.08883829 0.13930827
## [4,] 0.25235426 0.1590211 0.1288589 0.21448195 0.24528380
## [5,] 0.12267925 0.2679395 0.1893013 0.28009665 0.13998335