使用 STL 容器进行中位数计算时,正确的方法是啥?

Posted

技术标签:

【中文标题】使用 STL 容器进行中位数计算时,正确的方法是啥?【英文标题】:What is the right approach when using STL container for median calculation?使用 STL 容器进行中位数计算时,正确的方法是什么? 【发布时间】:2010-12-15 16:47:45 【问题描述】:

假设我需要从 1000000 个随机数值序列中检索中位数。

如果使用任何 std::list,我没有(内置)方法对序列进行排序以进行中值计算。

如果使用std::list,我不能随机访问值来检索排序序列的中间(中位数)。

自己实现排序并使用例如是否更好? std::vector,还是使用std::list 并使用std::list::iterator for-loop-walk 到中值更好?后者似乎不那么开销,但也感觉更难看..

或者我有更多更好的选择吗?

【问题讨论】:

【参考方案1】:

您可以使用库函数std::sortstd::vector 进行排序。

std::vector<int> vec;
// ... fill vector with stuff
std::sort(vec.begin(), vec.end());

【讨论】:

【参考方案2】:

任何随机访问容器(如std::vector)都可以使用std::sort 标头中的标准std::sort 算法进行排序。

要找到中位数,使用std::nth_element 会更快;这足以将一个选定的元素放在正确的位置,但不能完全对容器进行排序。所以你可以找到这样的中位数:

int median(vector<int> &v)

    size_t n = v.size() / 2;
    nth_element(v.begin(), v.begin()+n, v.end());
    return v[n];

【讨论】:

嗯。我没有意识到 nth_element 存在,我显然在我的回答中重新实现了它...... 需要注意的是nth_element会以不可预知的方式修改向量!如有必要,您可能希望对索引向量进行排序。 如果项数为偶数,则中位数为中间两个的平均值。 @sje397 是的,这个算法有一半是不正确的,即当向量包含偶数个元素时。调用 nth_element 函数 2 次(对于 2 个中间元素)是否比调用 sort 一次更昂贵?谢谢。 @F*** partial_sort 仍然是 O(N*log(N)) 并且 nth_element 是 O(N) (或者 O(2N) 如果执行两次,这仍然是线性的)所以我希望 nth_element随着 N 的增加会更快,但我还没有进行任何分析来证实这一点。【参考方案3】:

存在一个linear-time selection algorithm。下面的代码仅在容器具有随机访问迭代器时才有效,但可以修改为没有随机访问迭代器——您只需要更加小心避免使用像 end - beginiter + n 这样的快捷方式。

#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <vector>

template<class A, class C = std::less<typename A::value_type> >
class LinearTimeSelect 
public:
    LinearTimeSelect(const A &things) : things(things) 
    typename A::value_type nth(int n) 
        return nth(n, things.begin(), things.end());
    
private:
    static typename A::value_type nth(int n,
            typename A::iterator begin, typename A::iterator end) 
        int size = end - begin;
        if (size <= 5) 
            std::sort(begin, end, C());
            return begin[n];
        
        typename A::iterator walk(begin), skip(begin);
#ifdef RANDOM // randomized algorithm, average linear-time
        typename A::value_type pivot = begin[std::rand() % size];
#else // guaranteed linear-time, but usually slower in practice
        while (end - skip >= 5) 
            std::sort(skip, skip + 5);
            std::iter_swap(walk++, skip + 2);
            skip += 5;
        
        while (skip != end) std::iter_swap(walk++, skip++);
        typename A::value_type pivot = nth((walk - begin) / 2, begin, walk);
#endif
        for (walk = skip = begin, size = 0; skip != end; ++skip)
            if (C()(*skip, pivot)) std::iter_swap(walk++, skip), ++size;
        if (size <= n) return nth(n - size, walk, end);
        else return nth(n, begin, walk);
    
    A things;
;

int main(int argc, char **argv) 
    std::vector<int> seq;
    
        int i = 32;
        std::istringstream(argc > 1 ? argv[1] : "") >> i;
        while (i--) seq.push_back(i);
    
    std::random_shuffle(seq.begin(), seq.end());
    std::cout << "unordered: ";
    for (std::vector<int>::iterator i = seq.begin(); i != seq.end(); ++i)
        std::cout << *i << " ";
    LinearTimeSelect<std::vector<int> > alg(seq);
    std::cout << std::endl << "linear-time medians: "
        << alg.nth((seq.size()-1) / 2) << ", " << alg.nth(seq.size() / 2);
    std::sort(seq.begin(), seq.end());
    std::cout << std::endl << "medians by sorting: "
        << seq[(seq.size()-1) / 2] << ", " << seq[seq.size() / 2] << std::endl;
    return 0;

【讨论】:

【参考方案4】:

中位数比 Mike Seymour 的答案更复杂。中位数的不同取决于样本中的项目数是偶数还是奇数。如果项目数为偶数,则中位数为中间两项的平均值。这意味着整数列表的中位数可以是分数。最后,空列表的中位数是未定义的。这是通过我的基本测试用例的代码:

///Represents the exception for taking the median of an empty list
class median_of_empty_list_exception:public std::exception
  virtual const char* what() const throw() 
    return "Attempt to take the median of an empty list of numbers.  "
      "The median of an empty list is undefined.";
  
;

///Return the median of a sequence of numbers defined by the random
///access iterators begin and end.  The sequence must not be empty
///(median is undefined for an empty set).
///
///The numbers must be convertible to double.
template<class RandAccessIter>
double median(RandAccessIter begin, RandAccessIter end) 
  if(begin == end) throw median_of_empty_list_exception(); 
  std::size_t size = end - begin;
  std::size_t middleIdx = size/2;
  RandAccessIter target = begin + middleIdx;
  std::nth_element(begin, target, end);

  if(size % 2 != 0) //Odd number of elements
    return *target;
  else            //Even number of elements
    double a = *target;
    RandAccessIter targetNeighbor= target-1;
    std::nth_element(begin, targetNeighbor, end);
    return (a+*targetNeighbor)/2.0;
  

【讨论】:

我知道这是很久以前的事了,但因为我刚刚在 google 上找到了这个:std::nth_element 实际上也保证任何前面的元素都是 =。所以你可以只使用targetNeighbor = std::min_element(begin, target) 并跳过部分排序,这可能会快一点。 (nth_element 是平均线性的,而min_element 显然是线性的。)即使您更愿意再次使用nth_element,它也是等效的,并且可能会更快一点,只需执行nth_element(begin, targetNeighbor, target) @Dougal 我认为你的意思是 targetNeighbor = std::max_element(begin, target) 在这种情况下? @Dougal 我知道这条评论来自很久以前 ;),但我不知道你的方法应该如何工作,你确定这会给出正确的结果吗? @tobi303 你的永远是我的两倍。 :) 是的,它绝对应该:关键是在调用std::nth_element 之后,序列就像[smaller_than_target, target, bigger_than_target]。所以你知道target-1th元素在数组的前半部分,你只需要找到target之前元素的最大值就可以得到中位数。 @Dougal 啊,现在我明白了。谢谢【参考方案5】:

这是 Mike Seymour 答案的更完整版本:

// Could use pass by copy to avoid changing vector
double median(std::vector<int> &v)

  size_t n = v.size() / 2;
  std::nth_element(v.begin(), v.begin()+n, v.end());
  int vn = v[n];
  if(v.size()%2 == 1)
  
    return vn;
  else
  
    std::nth_element(v.begin(), v.begin()+n-1, v.end());
    return 0.5*(vn+v[n-1]);
  

它处理奇数或偶数长度的输入。

【讨论】:

对于通过副本,您的意思是删除输入中的引用 (&amp;) 吗? 我只是将评论作为一个注释,一个可以使用逐个复制,在这种情况下是的,应该删除&amp; 这个版本有一个bug。您需要在再次执行 nth_element 之前提取v[n],因为在第二轮之后v[n] 可能包含不同的值。 @MatthewFioravante,我明白了。根据docs,我猜nth_element 不需要稳定。 (相应地编辑了我的答案)。 与其第二次调用nth_element,不如直接从v[0]迭代到v[n]并确定那一半的最大值不是更有效吗?【参考方案6】:

该算法使用 STL nth_element (amortized O(N)) 算法和 max_element 算法 (O(n)) 有效地处理偶数和奇数大小的输入。请注意,nth_element 还有一个保证的副作用,即n 之前的所有元素都保证小于v[n],只是不一定要排序。

//post-condition: After returning, the elements in v may be reordered and the resulting order is implementation defined.
double median(vector<double> &v)

  if(v.empty()) 
    return 0.0;
  
  auto n = v.size() / 2;
  nth_element(v.begin(), v.begin()+n, v.end());
  auto med = v[n];
  if(!(v.size() & 1))  //If the set size is even
    auto max_it = max_element(v.begin(), v.begin()+n);
    med = (*max_it + med) / 2.0;
  
  return med;    

【讨论】:

我喜欢你的回答,但是当向量为空时返回零​​不适合我的应用程序,我希望在空向量的情况下出现异常。【参考方案7】:

把来自这个线程的所有见解放在一起,我最终有了这个例程。它适用于任何 stl 容器或任何提供输入迭代器的类,并处理奇数和偶数大小的容器。它还在容器的副本上工作,不修改原始内容。

template <typename T = double, typename C>
inline const T median(const C &the_container)

    std::vector<T> tmp_array(std::begin(the_container), 
                             std::end(the_container));
    size_t n = tmp_array.size() / 2;
    std::nth_element(tmp_array.begin(), tmp_array.begin() + n, tmp_array.end());

    if(tmp_array.size() % 2) return tmp_array[n]; 
    else
    
        // even sized vector -> average the two middle values
        auto max_it = std::max_element(tmp_array.begin(), tmp_array.begin() + n);
        return (*max_it + tmp_array[n]) / 2.0;
    

【讨论】:

正如 Matthew Fioravante ***.com/questions/1719070/… 所提到的,“您需要在再次执行 nth_element 之前提取 v[n],因为在第二轮之后 v[n] 可能包含不同的值。”所以,设 med = tmp_array[n],那么正确的返回行是:return (*max_it + med) / 2.0; @trig-ger nth_element 在此解决方案中仅使用一次。这不是问题。 static_assert(std::is_same_v&lt;typename C::value_type, T&gt;, "mismatched container and element types") 也许?【参考方案8】:

这是一个考虑了@MatthieuM 建议的答案。即不修改输入向量。它对偶数和奇数基数的范围使用单个部分排序(在索引向量上),而空范围由向量的at 方法抛出的异常处理:

double median(vector<int> const& v)

    bool isEven = !(v.size() % 2); 
    size_t n    = v.size() / 2;

    vector<size_t> vi(v.size()); 
    iota(vi.begin(), vi.end(), 0); 

    partial_sort(begin(vi), vi.begin() + n + 1, end(vi), 
        [&](size_t lhs, size_t rhs)  return v[lhs] < v[rhs]; ); 

    return isEven ? 0.5 * (v[vi.at(n-1)] + v[vi.at(n)]) : v[vi.at(n)];

Demo

【讨论】:

【参考方案9】:
you can use this approch. It also takes care of sliding window.
Here days are no of trailing elements for which we want to find median and this makes sure the original container is not changed


#include<bits/stdc++.h>

using namespace std;

int findMedian(vector<int> arr, vector<int> brr, int d, int i)

    int x,y;
    x= i-d;
    y=d;
    brr.assign(arr.begin()+x, arr.begin()+x+y);


    sort(brr.begin(), brr.end());

    if(d%2==0)
    
        return((brr[d/2]+brr[d/2 -1]));
    

    else
    
        return (2*brr[d/2]);
    

    // for (int i = 0; i < brr.size(); ++i)
    // 
    //     cout<<brr[i]<<" ";
    // 

    return 0;



int main()

    int n;
    int days;
    int input;
    int median;
    int count=0;

    cin>>n>>days;

    vector<int> arr;
    vector<int> brr;

    for (int i = 0; i < n; ++i)
    
        cin>>input;
        arr.push_back(input);
    

    for (int i = days; i < n; ++i)
    
        median=findMedian(arr,brr, days, i);

        
    



    return 0;

【讨论】:

添加代码sn -p时请尝试添加说明【参考方案10】:

Armadillo 的实现类似于https://***.com/users/2608582/matthew-fioravante 的答案https://***.com/a/34077478 中的实现

它使用一次调用nth_element 和一次调用max_element,它在这里: https://gitlab.com/conradsnicta/armadillo-code/-/blob/9.900.x/include/armadillo_bits/op_median_meat.hpp#L380

//! find the median value of a std::vector (contents is modified)
template<typename eT>
inline 
eT
op_median::direct_median(std::vector<eT>& X)
  
  arma_extra_debug_sigprint();
  
  const uword n_elem = uword(X.size());
  const uword half   = n_elem/2;
  
  typename std::vector<eT>::iterator first    = X.begin();
  typename std::vector<eT>::iterator nth      = first + half;
  typename std::vector<eT>::iterator pastlast = X.end();
  
  std::nth_element(first, nth, pastlast);
  
  if((n_elem % 2) == 0)  // even number of elements
    
    typename std::vector<eT>::iterator start   = X.begin();
    typename std::vector<eT>::iterator pastend = start + half;
    
    const eT val1 = (*nth);
    const eT val2 = (*(std::max_element(start, pastend)));
    
    return op_mean::robust_mean(val1, val2);
    
  else  // odd number of elements
    
    return (*nth);
    
  

【讨论】:

以上是关于使用 STL 容器进行中位数计算时,正确的方法是啥?的主要内容,如果未能解决你的问题,请参考以下文章

STL是啥意思?

C++ STL容器在for循环中删除迭代器 正确方法 it++正确吗

C++ STL容器在for循环中删除迭代器 正确方法 it++正确吗

STL是啥

GCC 用于 STL 的默认分配器是啥?

在 Vuex 中基于状态属性进行计算的正确方法是啥?