「R」用purrr实现迭代

2020-07-03 17:48:51 浏览数 (2)

本文来源于 2018 年学习《R for Data Science》写的笔记。一起复习一下吧~

函数有3个好处:

  • 更容易看清代码意图
  • 更容易对需求变化做出反应(改变)
  • 更容易减少程序bug

除了函数,减少重复代码的另一种工具是迭代,它的作用在于可以对多个输入执行同一种处理,比如对多个列或多个数据集进行同样的操作。

迭代方式主要有两种:

  • 命令式编程 - for和while
  • 函数式编程 - purrr
准备工作

purrr是tidyverse的核心r包之一,提供了一些更加强大的编程工具。(读者可以点击原文获取小抄

代码语言:javascript复制
library(tidyverse)
#> ─ Attaching packages ─────────────────────────────────────────────────── tidyverse 1.2.1 ─
#> ✔ ggplot2 3.0.0     ✔ purrr   0.2.5
#> ✔ tibble  1.4.2     ✔ dplyr   0.7.6
#> ✔ tidyr   0.8.1     ✔ stringr 1.3.1
#> ✔ readr   1.1.1     ✔ forcats 0.3.0
#> ─ Conflicts ──────────────────────────────────────────────────── tidyverse_conflicts() ─
#> ✖ dplyr::filter() masks stats::filter()
#> ✖ dplyr::lag()    masks stats::lag()

for循环与函数式编程

因为R是一门函数式编程语言,我们可以先将for循环包装在函数中,然后再调用函数,而不是使用for循环,因此for循环在R中不像在其他编程语言中那么重要。

为了说明函数式编程,我们先利用下面简单的数据框进行一些思考:

代码语言:javascript复制
df = tibble(
    a = rnorm(10),
    b = rnorm(10),
    c = rnorm(10),
    d = rnorm(10)
)

如果想要计算每列的均值,我们使用for循环完成任务:

代码语言:javascript复制
output = vector("double", length(df))

for (i in seq_along(df)) {
    output[[i]] = mean(df[[i]])
}

output
#> [1]  0.45635 -0.17938  0.32879  0.00263

然后我们可能意识到需要频繁地计算每列的均值,因此将代码提取出来,转换为一个函数:

代码语言:javascript复制
col_mean = function(df) {
    output = vector("double", length(df))
    for ( i in seq_along(df)) {
        output[i] = mean(df[[i]])
    }
    
    output
}

然后我们觉得可能还需要这样计算每列的中位数和标准差,因此复制粘贴了col_mean(),并使用相应的median()sd()函数替换了mean()函数:

代码语言:javascript复制
col_median = function(df) {
    output = vector("double", length(df))
    for ( i in seq_along(df)) {
        output[i] = median(df[[i]])
    }
    
    output
}

col_sd = function(df) {
    output = vector("double", length(df))
    for ( i in seq_along(df)) {
        output[i] = sd(df[[i]])
    }
    
    output
}

(有时候我还真这么干的。)

哎呀,我们又复制粘贴了2次代码,因此是不是该思考下如何扩展一个代码让它同时发挥几个函数的功能呢?这段代码的大部分是一个for循环,而且如果不仔细很难看出3个函数有什么差别。

通过添加支持函数到每列的参数,我们可以使用同一个函数解决3个问题:

代码语言:javascript复制
col_summary = function(df, fun){
    out = vector("double", length(df))
    for (i in seq_along(df)) {
        out[i] = fun(df[[i]])
    }
    out
}

col_summary(df, median)
#> [1] 0.4666 0.0269 0.6161 0.0573
col_summary(df, mean)
#> [1]  0.45635 -0.17938  0.32879  0.00263

将函数作为参数传入另一个函数的做法是一种非常强大的功能,我们需要花些时间理解这种思想,但绝对是值得的。接下来我们将学习和使用purrr包,它提供的函数可以替代很多常见的for循环应用。R基础包中的apply应用函数族也可以完成类似的任务,但purrr包的函数更一致,也更容易学习。

使用purrr函数替代for循环的目的是将常见的列表问题分解为独立的几部分

  • 对于列表的单个元素,我们能找到解决办法吗?如果可以,我们就能使用purrr将该方法扩展到列表的所有元素。
  • 如果我们面临的是一个复杂的问题,那么将其分解为可行的子问题,然后依次解决。使用purrr,我们可以解决子问题,然后用管道将其组合起来。

映射函数

先对向量进行循环,然后对其每一个元素进行一番处理,最后保存结果。这种模式太普遍了,因而purrr包提供了一个函数族替我们完成这种操作。每种类型的输出都有一个相应的函数:

  • map()用于输出列表
  • map_lgl()用于输出逻辑型向量
  • map_dbl()用于输出双精度型向量
  • map_chr()用于输出字符型向量

每个函数都使用一个向量(注意列表可以作为递归向量看待)作为输入,并对向量的每个元素应用一个函数,然后返回和输入向量同样长度的一个新向量。向量的类型由映射函数的后缀决定。

使用map()函数族的优势不是速度,而是简洁:它可以让我们的代码更易编写,也更易阅读。

下面是进行上一节一样的操作:

代码语言:javascript复制
library(purrr)

map_dbl(df, mean)
#>        a        b        c        d 
#>  0.45635 -0.17938  0.32879  0.00263
map_dbl(df, median)
#>      a      b      c      d 
#> 0.4666 0.0269 0.6161 0.0573
map_dbl(df, sd)
#>     a     b     c     d 
#> 0.608 1.086 0.797 0.873

**与for循环相比,映射函数的重点在于需要执行的操作(即mean()median()sd()),而不是在所有元素中循环所需的跟踪记录以及保存结果。使用管道时这一点尤为突出:

代码语言:javascript复制
df %>% map_dbl(mean)
#>        a        b        c        d 
#>  0.45635 -0.17938  0.32879  0.00263
df %>% map_dbl(median)
#>      a      b      c      d 
#> 0.4666 0.0269 0.6161 0.0573
df %>% map_dbl(sd)
#>     a     b     c     d 
#> 0.608 1.086 0.797 0.873

map_*()col_summary()具有以下几点区别:

  • 所有的purrr函数都是用C实现的,这让它们的速度非常快,但牺牲了一些可读性。
  • 第二个参数可以是一个公式、一个字符向量或一个整型向量。
  • map_*()使用....f传递一些附加参数,供每次调用时使用
  • 映射函数还保留名称

快捷方式

对于第二个参数.f,我们可以使用几种快捷方式来减少输入量。比如我们现在想对某个数据集中的每一个分组都拟合一个线性模型,下面示例将mtcars数据集拆分为3个部分(按照气缸值分类),并对每个部分拟合一个线性模型:

代码语言:javascript复制
models = mtcars %>% 
    split(.$cyl) %>% 
    map(function(df) lm(mpg ~ wt, data = df))

因为在R中创建匿名函数的语法比较复杂,所以purrr提供了一种更方便的快捷方式——单侧公式:

代码语言:javascript复制
models = mtcars %>% 
    split(.$cyl) %>% 
    map(~lm(mpg ~ wt, data = .))

上面.作为一个代词:它表示当前列表元素(与for循环中用i表示当前索引是一样的)。

当检查多个模型时,有时候我们需要提取像R方这样的摘要统计量,要想完成这个任务,我们需要先运行summary()函数,然后提取结果中的r.squared:

代码语言:javascript复制
models %>% 
    map(summary) %>% 
    map_dbl(~.$r.squared)
#>     4     6     8 
#> 0.509 0.465 0.423

因为提取命名成分操作非常普遍,所以purrr提供了一种更简单的快捷方式:使用字符串。

代码语言:javascript复制
models %>% 
    map(summary) %>% 
    map_dbl("r.squared")
#>     4     6     8 
#> 0.509 0.465 0.423

对操作失败的处理

当使用映射函数重复多次操作时,某次操作失败的概率大大增加。这个时候我们会收到一条错误信息,但得不到任何结果。这让人很恼火!我们怎么保证不会出现一条鱼腥了一锅汤?

safely()是一个修饰函数(副词),它接收一个函数(动词),对其进行修改并返回修改后的函数。这样,修改后的函数就不会抛出错误,相反,它总是返回由下面两个元素组成的列表:

  • result - 原始结果。如果出现错误,那么它就是NULL
  • error - 错误对象。如果操作成功,那么它就是NULL

下面用log()函数进行说明:

代码语言:javascript复制
safe_log = safely(log)
str(safe_log(10))
#> List of 2
#>  $ result: num 2.3
#>  $ error : NULL

str(safe_log("a"))
#> List of 2
#>  $ result: NULL
#>  $ error :List of 2
#>   ..$ message: chr "数学函数中用了非数值参数"
#>   ..$ call   : language log(x = x, base = base)
#>   ..- attr(*, "class")= chr [1:3] "simpleError" "error" "condition"

safely()函数也可以与map()共同使用:

代码语言:javascript复制
x = list(1, 10, "a")
y = x %>% map(safely(log))
str(y)
#> List of 3
#>  $ :List of 2
#>   ..$ result: num 0
#>   ..$ error : NULL
#>  $ :List of 2
#>   ..$ result: num 2.3
#>   ..$ error : NULL
#>  $ :List of 2
#>   ..$ result: NULL
#>   ..$ error :List of 2
#>   .. ..$ message: chr "数学函数中用了非数值参数"
#>   .. ..$ call   : language log(x = x, base = base)
#>   .. ..- attr(*, "class")= chr [1:3] "simpleError" "error" "condition"

如果将以上结果转换为2个列表,一个列表包含所有错误对象,另一个列表包含所有原始结果,那么处理起来就会更容易。我们可以使用purrr::transpose()函数轻松完成该任务

代码语言:javascript复制
y = y %>% transpose()
str(y)
#> List of 2
#>  $ result:List of 3
#>   ..$ : num 0
#>   ..$ : num 2.3
#>   ..$ : NULL
#>  $ error :List of 3
#>   ..$ : NULL
#>   ..$ : NULL
#>   ..$ :List of 2
#>   .. ..$ message: chr "数学函数中用了非数值参数"
#>   .. ..$ call   : language log(x = x, base = base)
#>   .. ..- attr(*, "class")= chr [1:3] "simpleError" "error" "condition"

我们可以自行决定如何处理错误对象,一般来说,我们应该检查一下y中错误对象所对应的x值,或者使用y中的正常结果进行一些处理:

代码语言:javascript复制
is_ok = y$error %>% map_lgl(is_null)
x[!is_ok]
#> [[1]]
#> [1] "a"

y$result[is_ok] %>% flatten_dbl()
#> [1] 0.0 2.3

purrr还提供了两个有用的修饰函数:

  • safely()类似,possibly()函数总是会成功返回。它比safely()还要简单一些,因为可以设定出现错误时返回一个默认值:
代码语言:javascript复制
x = list(1, 10, "a")
x %>% map_dbl(possibly(log, NA_real_))
#> [1] 0.0 2.3  NA
  • quietly()函数与safely()的作用基本相同,但前者结果不包含错误对象,而是包含输出、消息和警告:
代码语言:javascript复制
x = list(1, -1)
x %>% map(quietly(log)) %>% str()
#> List of 2
#>  $ :List of 4
#>   ..$ result  : num 0
#>   ..$ output  : chr ""
#>   ..$ warnings: chr(0) 
#>   ..$ messages: chr(0) 
#>  $ :List of 4
#>   ..$ result  : num NaN
#>   ..$ output  : chr ""
#>   ..$ warnings: chr "产生了NaNs"
#>   ..$ messages: chr(0)

x %>% map(safely(log)) %>% str()
#> Warning in .f(...): 产生了NaNs
#> List of 2
#>  $ :List of 2
#>   ..$ result: num 0
#>   ..$ error : NULL
#>  $ :List of 2
#>   ..$ result: num NaN
#>   ..$ error : NULL

多参数映射

前面我们提到的映射函数都是对单个输入进行映射,但有时候我们需要多个相关输入同步迭代,这就是map2()和pmap()函数的用武之地

例如我们想模拟几个均值不同的随机正态分布,我们可以使用map完成这个任务:

代码语言:javascript复制
mu = list(5, 10, -3)
mu %>% 
    map(rnorm, n = 5) %>% 
    str()
#> List of 3
#>  $ : num [1:5] 5.65 6.48 6.35 4.61 4.74
#>  $ : num [1:5] 8.93 8.93 10.67 10.98 8.72
#>  $ : num [1:5] -4.04 -3.25 -2.16 -3.02 -2.53

如果我们想让标准差也不同,一种方法是使用均值向量和标准差向量的索引进行迭代:

代码语言:javascript复制
sigma = list(1, 5, 10)
seq_along(mu) %>% 
    map(~rnorm(5, mu[[.]], sigma[[.]])) %>% 
    str()
#> List of 3
#>  $ : num [1:5] 4.5 4.73 4.43 6.19 5.47
#>  $ : num [1:5] 8.71 8.59 18.26 7.93 4.93
#>  $ : num [1:5] -21.46 -7.94 -21.41 5.66 2.38

但这种方式比较难理解,我们使用map2()进行同步迭代:

代码语言:javascript复制
map2(mu, sigma, rnorm, n = 5) %>% str()
#> List of 3
#>  $ : num [1:5] 6.08 6.72 7.59 5.21 3.99
#>  $ : num [1:5] 13.44 6.81 3.61 22.29 14.29
#>  $ : num [1:5] 4.05 -1.77 -2.77 0.69 -23.91

注意这里每次调用时值发生变换的参数要放在映射函数前面,值不变的参数要放在映射函数后面。

map()函数一样,map2()函数也是对for循环的包装:

代码语言:javascript复制
map2 = function(x, y, f, ...){
    out = vector("list", length(x))
    for (i in seq_along(x)) {
        out[[i]] = f(x[[i]], y[[i]], ...)
    }
    out
}

(实际的map2()并不是这样的,此处是给出R实现的一种思想)

根据这个函数,我们可以涉及map3()map4()等等,但这样实在无聊。purrr提供了pmap()函数,它可以将列表作为参数。如果我们想要生成均值、标准差和样本数都不同的正态分布,可以使用:

代码语言:javascript复制
n = list(1, 3, 5)
args1 = list(n, mu, sigma)

args1 %>% 
    pmap(rnorm) %>% 
    str()
#> List of 3
#>  $ : num 3.55
#>  $ : num [1:3] 8.4 10.9 -3.3
#>  $ : num [1:5] 3.9 -11.61 2.06 7.14 -16.25

如果没有为列表元素命名,那么pmap()在调用函数时会按照位置匹配。这样做容易出错而且可读性差,因此最后使用命名参数:

代码语言:javascript复制
args2 = list(mean = mu, sd = sigma, n = n)
args2 %>% 
    pmap(rnorm) %>% 
    str()
#> List of 3
#>  $ : num 6.18
#>  $ : num [1:3] 11.2 18 14.8
#>  $ : num [1:5] -5.27 6.57 1.88 6.53 -8.35

这样更加安全。

因为长度都相同,所以将各个参数保存在一个数据框中:

代码语言:javascript复制
params = tibble::tribble(
    ~mean, ~sd, ~n,
    5, 1, 1,
    10, 5, 3,
    -3, 10, 5
)

params %>% 
    pmap(rnorm)
#> [[1]]
#> [1] 5.41
#> 
#> [[2]]
#> [1]  5.4 10.2 14.4
#> 
#> [[3]]
#> [1] -8.653 -4.457  9.747 -4.916 -0.436
调用不同的函数

还有一种更复杂的情况:不但传给函数的参数不同,甚至函数本身也是不同的。

代码语言:javascript复制
f = c("runif", "rnorm", "rpois")
param = list(
    list(min = -1, max = 1),
    list(sd = 5),
    list(lambda = 10)
)

为了处理这种情况,我们使用invoke_map()函数:

代码语言:javascript复制
invoke_map(f, param, n = 5) %>% str()
#> List of 3
#>  $ : num [1:5] 0.167 -0.235 -0.366 -0.933 0.304
#>  $ : num [1:5] 6.961 3.642 13.405 0.536 -2.078
#>  $ : int [1:5] 8 8 8 6 11

第1个参数是一个函数列表或包含函数名称的字符串向量。第2个参数是列表的一个列表,给出了要传给各个函数的不同参数。随后的参数要传给每个函数

我们使用tribble()让参数配对更容易:

代码语言:javascript复制
sim = tibble::tribble(
    ~f, ~params,
    "runif", list(min = -1, max = 1),
    "rnorm", list(sd = 5),
    "rpois", list(lambda = 10)
)


sim %>% 
    dplyr::mutate(sim = invoke_map(f, params, n = 10))
#> # A tibble: 3 x 3
#>   f     params     sim       
#>   <chr> <list>     <list>    
#> 1 runif <list [2]> <dbl [10]>
#> 2 rnorm <list [1]> <dbl [10]>
#> 3 rpois <list [1]> <int [10]>

游走函数

当使用函数的目的是向屏幕提供输出或将文件保存到磁盘——重要的是操作过程而不是返回值,我们应该使用游走函数,而不是映射函数。

下面是一个示例:

代码语言:javascript复制
x = list(1, "a", 3)

x %>% 
    walk(print)
#> [1] 1
#> [1] "a"
#> [1] 3

一般来说,walk()函数不如walk2()和pwalk()实用。例如有一个图形列表和一个文件名向量,那么我们就可以使用pwalk()将每个文件保存到相应的磁盘位置:

代码语言:javascript复制
library(ggplot2)

plots = mtcars %>% 
    split(.$cyl) %>% 
    map(~ggplot(., aes(mpg, wt))   geom_point())
paths = stringr::str_c(names(plots), ".pdf")

pwalk(list(paths, plots), ggsave, path = tempdir())
#> Saving 7 x 5 in image
#> Saving 7 x 5 in image
#> Saving 7 x 5 in image

我们来查看一下是不是建立好了:

代码语言:javascript复制
dir(tempdir())
#> [1] "4.pdf" "6.pdf" "8.pdf"

for循环的其他模式

purrr还提供了其他一些函数,虽然这些函数的使用率低,但了解还是有必要的。本节就是对它们进行简单介绍

预测函数

一些函数可以与返回TRUEFALSE的预测函数一同使用。

keep()discard()函数可以分别保留输入中预测值为TRUEFALSE的元素(在数据框中就是指列):

代码语言:javascript复制
iris %>% 
    keep(is.factor) %>% 
    str()
#> 'data.frame':    150 obs. of  1 variable:
#>  $ Species: Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...


iris %>% 
    discard(is.factor) %>% 
    str()
#> 'data.frame':    150 obs. of  4 variables:
#>  $ Sepal.Length: num  5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
#>  $ Sepal.Width : num  3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
#>  $ Petal.Length: num  1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
#>  $ Petal.Width : num  0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...

some()every()函数分别用来确定预测值是否对某个元素为真以及是否对所有元素为真:

代码语言:javascript复制
x = list(1:5, letters, list(10))


x %>% 
    some(is_character)
#> [1] TRUE

x %>% 
    every(is_vector)
#> [1] TRUE

detect()可以找出预测值为真的第一个元素,detect_index()可以返回该元素的索引。

代码语言:javascript复制
x = sample(10)
x
#>  [1] 10  8  5  7  4  1  2  9  3  6

x %>% 
    detect(~ . >5)
#> [1] 10

x %>% 
    detect_index(~ . >5)
#> [1] 1

head_while()tail_while()分别从向量的开头和结尾找出预测值为真的元素:

代码语言:javascript复制
x %>% 
    head_while(~ . > 5)
#> [1] 10  8

x %>% 
    tail_while(~ . > 5)
#> [1] 6

归约和累计

操作一个复杂的列表,有时候我们想要不断合并两个预算两个元素(基础函数Reduce干的事情)。

代码语言:javascript复制
dfs = list(
    age = tibble(name = "John", age = 30),
    sex = tibble(name = c("John", "Mary"), sex = c("M", "F")),
    trt = tibble(name = "Mary", treatment = "A")
)

dfs %>% reduce(full_join)
#> Joining, by = "name"
#> Joining, by = "name"
#> # A tibble: 2 x 4
#>   name    age sex   treatment
#>   <chr> <dbl> <chr> <chr>    
#> 1 John     30 M     <NA>     
#> 2 Mary     NA F     A

这里我们使用reduce结合dplyr中的full_join()将它们轻松合并为一个数据框。

reduce()函数使用一个“二元函数”(即两个基本输入),将其不断应用于一个列表,直到最后只剩下一个元素。

累计函数与归约函数类似,但会保留中间结果,比如下面求取累计和:

代码语言:javascript复制
x = sample(10)

x
#>  [1]  9 10  8  5  6  2  3  4  7  1
x %>% accumulate(` `)
#>  [1]  9 19 27 32 38 40 43 47 54 55

0 人点赞