dplyr - 按组大小过滤
Posted
技术标签:
【中文标题】dplyr - 按组大小过滤【英文标题】:dplyr - filter by group size 【发布时间】:2017-08-23 22:26:28 【问题描述】:过滤 data.frame 以仅获取大小为 5 的组的最佳方法是什么?
所以我的数据如下:
require(dplyr)
n <- 1e5
x <- rnorm(n)
# Category size ranging each from 1 to 5
cat <- rep(seq_len(n/3), sample(1:5, n/3, replace = TRUE))[1:n]
dat <- data.frame(x = x, cat = cat)
我能想到的 dplyr 方法是
dat <- group_by(dat, cat)
system.time(
out1 <- dat %>% filter(n() == 5L)
)
# user system elapsed
# 1.157 0.218 1.497
但这很慢……在 dplyr 中有没有更好的方法?
到目前为止,我的解决方法如下:
system.time(
all_ind <- rep(seq_len(n_groups(dat)), group_size(dat))
take_only <- which(group_size(dat) == 5L)
out2 <- dat[all_ind %in% take_only, ]
)
# user system elapsed
# 0.026 0.008 0.036
all.equal(out1, out2) # TRUE
但这并不像 dplyr 那样......
【问题讨论】:
等到你看到 data.table 解决方案。根本不是 dplyery。 也很简单。类似setDT(dat)[, if(.N == 5) .SD, by = cat]
data.table 中更快的方法是setDT(dat)[dat[, .I[.N==5], by = cat]$V1]
@ChirayuChamoli:猜你的意思是m
:是的,它已被订购。
@ChirayuChamoli:很好的答案。刚刚在下面添加了
【参考方案1】:
这是您可以尝试的另一种 dplyr 方法
semi_join(dat, count(dat, cat) %>% filter(n == 5), by = "cat")
--
这是基于 OP 原始方法的另一种方法,稍作修改:
n <- 1e5
x <- rnorm(n)
# Category size ranging each from 1 to 5
cat <- rep(seq_len(n/3), sample(1:5, n/3, replace = TRUE))[1:n]
dat <- data.frame(x = x, cat = cat)
# second data set for the dt approch
dat2 <- data.frame(x = x, cat = cat)
sol_floo0 <- function(dat)
dat <- group_by(dat, cat)
all_ind <- rep(seq_len(n_groups(dat)), group_size(dat))
take_only <- which(group_size(dat) == 5L)
dat[all_ind %in% take_only, ]
sol_floo0_v2 <- function(dat)
g <- group_by(dat, cat) %>% group_size()
ind <- rep(g == 5, g)
dat[ind, ]
microbenchmark::microbenchmark(times = 10,
sol_floo0(dat),
sol_floo0_v2(dat2))
#Unit: milliseconds
# expr min lq mean median uq max neval cld
# sol_floo0(dat) 43.72903 44.89957 45.71121 45.10773 46.59019 48.64595 10 b
# sol_floo0_v2(dat2) 29.83724 30.56719 32.92777 31.97169 34.10451 38.31037 10 a
all.equal(sol_floo0(dat), sol_floo0_v2(dat2))
#[1] TRUE
【讨论】:
谢谢。与我的filter
解决方案相比,这是一个巨大的加速。仍在寻找比我的第二个 hybrid 解决方案更好的解决方案...
第二个版本几乎与包含分组的 DT 解决方案一样快。很棒的一个【参考方案2】:
按时间比较答案:
require(dplyr)
require(data.table)
n <- 1e5
x <- rnorm(n)
# Category size ranging each from 1 to 5
cat <- rep(seq_len(n/3), sample(1:5, n/3, replace = TRUE))[1:n]
dat <- data.frame(x = x, cat = cat)
# second data set for the dt approch
dat2 <- data.frame(x = x, cat = cat)
sol_floo0 <- function(dat)
dat <- group_by(dat, cat)
all_ind <- rep(seq_len(n_groups(dat)), group_size(dat))
take_only <- which(group_size(dat) == 5L)
dat[all_ind %in% take_only, ]
sol_floo0_v2 <- function(dat)
g <- group_by(dat, cat) %>% group_size()
ind <- rep(g == 5, g)
dat[ind, ]
sol_docendo_discimus <- function(dat)
dat <- group_by(dat, cat)
semi_join(dat, count(dat, cat) %>% filter(n == 5), by = "cat")
sol_akrun <- function(dat2)
setDT(dat2)[dat2[, .I[.N==5], by = cat]$V1]
sol_sotos <- function(dat2)
setDT(dat2)[, if(.N == 5) .SD, by = cat]
sol_chirayu_chamoli <- function(dat)
rle_ <- rle(dat$cat)
dat[dat$cat %in% rle_$values[rle_$lengths==5], ]
microbenchmark::microbenchmark(times = 20,
sol_floo0(dat),
sol_floo0_v2(dat),
sol_docendo_discimus(dat),
sol_akrun(dat2),
sol_sotos(dat2),
sol_chirayu_chamoli(dat))
结果:
Unit: milliseconds
expr min lq mean median uq max neval cld
sol_floo0(dat) 58.00439 65.28063 93.54014 69.82658 82.79997 280.23114 20 cd
sol_floo0_v2(dat) 42.27791 50.27953 72.51729 58.63931 67.62540 238.97413 20 bc
sol_docendo_discimus(dat) 100.54095 113.15476 126.74142 121.69013 132.62533 183.05818 20 d
sol_akrun(dat2) 26.88369 34.01925 41.04378 37.07957 45.44784 63.95430 20 ab
sol_sotos(dat2) 16.10177 19.78403 24.04375 23.06900 28.05470 35.83611 20 a
sol_chirayu_chamoli(dat) 20.67951 24.18100 38.01172 27.61618 31.97834 230.51026 20 ab
【讨论】:
为了更好的衡量标准,您也应该在函数的时间安排中包含group_by
步骤
@docendodiscimus:对于一个 SO 问题,你是对的,它更好。在我的分析中,我已经将它们分组。这就是为什么我在测量中排除它......
sol_docendo_discimus
不需要group_by
调用!这就是运行时间如此之高的原因。【参考方案3】:
我概括了docendo discimus 编写的函数,以便与现有的 dplyr 函数一起使用:
#' inherit dplyr::filter
#' @param min minimal group size, use \codemin = NULL to filter on maximal group size only
#' @param max maximal group size, use \codemax = NULL to filter on minimal group size only
#' @export
#' @source Stack Overflow answer by docendo discimus, \urlhttps://***.com/a/43110620/4575331
filter_group_size <- function(.data, min = NULL, max = min)
g <- dplyr::group_size(.data)
if (is.null(min) & is.null(max))
stop('`min` and `max` cannot both be NULL.')
if (is.null(max))
max <- base::max(g, na.rm = TRUE)
ind <- base::rep(g >= min & g <= max, g)
.data[ind, ]
让我们检查一下5
的最小组大小:
dat2 %>%
group_by(cat) %>%
filter_group_size(5, NULL) %>%
summarise(n = n()) %>%
arrange(desc(n))
# # A tibble: 6,634 x 2
# cat n
# <int> <int>
# 1 NA 19
# 2 1 5
# 3 2 5
# 4 6 5
# 5 15 5
# 6 17 5
# 7 21 5
# 8 27 5
# 9 33 5
# 10 37 5
# # ... with 6,624 more rows
很好,现在检查 OP 的问题;正好是5
的组大小:
dat2 %>%
group_by(cat) %>%
filter_group_size(5) %>%
summarise(n = n()) %>%
pull(n) %>%
unique()
# [1] 5
万岁。
【讨论】:
我有一些这个版本失败的数据。似乎 group_by 对数据进行排序,这可能会导致错误的索引。data.frame(x=c(2,2,1)) %>% group_by(x) %>% group_size
导致 c(1,2)
而不是 c(2, 1)
首先,您使用的group_size()
是dplyr
包中的一个函数,与我的回答无关。其次,dplyr
包中的 group_by()
函数按字母顺序对组进行排序。【参考方案4】:
我知道您要求dplyr
解决方案,但如果您将它与一些purrr
结合使用,您可以在一行中得到它而无需指定任何新功能。 (虽然有点慢。)
library(dplyr)
library(purrr)
library(tidyr)
dat %>%
group_by(cat) %>%
nest() %>%
mutate(n = map(data, n_distinct)) %>%
unnest(n = n) %>%
filter(n == 5) %>%
select(cat, n)
【讨论】:
【参考方案5】:你可以用n()
做的更简洁:
library(dplyr)
dat %>% group_by(cat) %>% filter(n() == 5)
【讨论】:
如果您已经阅读了问题的一部分“我能想出的 dplyr 方式是”,您就会知道他/她已经尝试过了。 OP 认为它太慢了。【参考方案6】:加速 dplyr-way n()
过滤器的一种非常简单的方法是将结果存储在新列中。如果以后有多个filter
s,则计算组大小的初始时间摊销。
library(dplyr)
prep_group <- function(dat)
dat %>%
group_by(cat) %>%
mutate(
Occurrences = n()
) %>%
ungroup()
# Create a new data frame with the `Occurrences` column:
# dat_prepped <- dat %>% prep_group
过滤Occurrences
字段比变通解决方案快得多:
sol_floo0 <- function(dat)
dat <- group_by(dat, cat)
all_ind <- rep(seq_len(n_groups(dat)), group_size(dat))
take_only <- which(group_size(dat) == 5L)
dat[all_ind %in% take_only, ]
sol_floo0_v2 <- function(dat)
g <- group_by(dat, cat) %>% group_size()
ind <- rep(g == 5, g)
dat[ind, ]
sol_cached <- function(dat)
out <- filter(dat, Occurrences == 5L)
n <- 1e5
x <- rnorm(n)
# Category size ranging each from 1 to 5
cat <- rep(seq_len(n/3), sample(1:5, n/3, replace = TRUE))[1:n]
dat <- data.frame(x = x, cat = cat)
dat_prepped <- prep_group(dat)
microbenchmark::microbenchmark(times=50, sol_floo0(dat), sol_floo0_v2(dat), sol_cached(dat_prepped))
Unit: microseconds
expr min lq mean median uq max neval cld
sol_floo0(dat) 33345.764 35603.446 42430.441 37994.477 41379.411 144103.471 50 c
sol_floo0_v2(dat) 26180.539 27842.927 29694.203 29089.672 30997.411 37412.899 50 b
sol_cached(dat_prepped) 801.402 930.025 1342.348 1098.843 1328.192 5049.895 50 a
使用count()
-> left_join()
可以进一步加快准备工作:
prep_join <- function(dat)
dat %>%
left_join(
dat %>%
count(cat, name="Occurrences")
)
microbenchmark::microbenchmark(times=10, prep_group(dat), prep_join(dat))
Unit: milliseconds
expr min lq mean median uq max neval cld
prep_group(dat) 45.67805 47.68100 48.98929 49.11258 50.08214 52.44737 10 b
prep_join(dat) 35.01945 36.20857 37.96460 36.86776 38.71056 45.59041 10 a
【讨论】:
【参考方案7】:dat %>%
dplyr::group_by(cat) %>%
dplyr::add_tally() %>%
dplyr::filter(n == 5)
【讨论】:
欢迎来到 Stack Overflow。没有任何解释的代码转储很少有帮助。 Stack Overflow 是关于学习的,而不是提供 sn-ps 来盲目复制和粘贴。请edit您的问题并解释它如何回答所提出的具体问题。见How to Answer。在用现有答案回答老问题(这个问题超过 4 岁)时,这一点尤其重要。这个答案如何改进已经存在的内容(特别是乔的答案)?以上是关于dplyr - 按组大小过滤的主要内容,如果未能解决你的问题,请参考以下文章
R(和 dplyr?) - 按组从数据帧中采样,最大样本大小为 n