scala中值实现
Posted
技术标签:
【中文标题】scala中值实现【英文标题】:scala median implementation 【发布时间】:2011-06-07 10:21:37 【问题描述】:scala 中中位数的快速实现是什么?
这是我在rosetta code找到的:
def median(s: Seq[Double]) =
val (lower, upper) = s.sortWith(_<_).splitAt(s.size / 2)
if (s.size % 2 == 0) (lower.last + upper.head) / 2.0 else upper.head
我不喜欢它,因为它会进行排序。我知道有一些方法可以计算线性时间的中位数。
编辑:
我想要一组可以在各种场景中使用的中值函数:
-
可以在线性时间内完成的快速、就地中值计算
适用于可以多次遍历的流的中位数,但您只能将
O(log n)
值保留在内存中 like this
适用于流的中位数,您最多可以在内存中保存O(log n)
值,并且您最多可以遍历流一次(这甚至可能吗?)
请仅发布编译且正确计算中位数的代码。为简单起见,您可以假设所有输入都包含奇数个值。
【问题讨论】:
快速谷歌搜索给了我this 和this。基本上,您正在寻找的是选择算法。 Scala 版本留给读者作为练习。 “好的”算法要复杂得多。 Google 为“中位数的中位数”或“五的中位数”。 一个良好实现(即库)的排序算法在您的应用程序的现实中可能证明比一些所谓的线性时间算法的某些实现更快。至于上面的代码,您可能会省略拆分并进行索引访问,具体取决于您假设的 Seq 实现类型。 我认为第三种情况是不可能的。例如,假设我得到了从 1000 到 1500 的数字。中位数是 1250。现在,如果我开始获得低于 1000 的数字,中位数将减少 1 直到达到 1000。同样,如果我开始获得高于 1500 的数字,中位数将增加直到达到 1500。所以你需要保持到目前为止看到的所有数字。 【参考方案1】:不可变算法
Taylor Leese 的 first algorithm indicated 是二次的,但具有线性平均值。但是,这取决于枢轴选择。所以我在这里提供了一个版本,它有一个可插入的枢轴选择,以及随机枢轴和中位数枢轴的中位数(保证线性时间)。
import scala.annotation.tailrec
@tailrec def findKMedian(arr: Array[Double], k: Int)(implicit choosePivot: Array[Double] => Double): Double =
val a = choosePivot(arr)
val (s, b) = arr partition (a >)
if (s.size == k) a
// The following test is used to avoid infinite repetition
else if (s.isEmpty)
val (s, b) = arr partition (a ==)
if (s.size > k) a
else findKMedian(b, k - s.size)
else if (s.size < k) findKMedian(b, k - s.size)
else findKMedian(s, k)
def findMedian(arr: Array[Double])(implicit choosePivot: Array[Double] => Double) = findKMedian(arr, (arr.size - 1) / 2)
随机枢轴(二次、线性平均)、不可变
这是随机枢轴选择。具有随机因素的算法分析比正常情况更棘手,因为它主要处理概率和统计数据。
def chooseRandomPivot(arr: Array[Double]): Double = arr(scala.util.Random.nextInt(arr.size))
中位数的中位数(线性),不可变
中位数方法,与上述算法一起使用时保证线性时间。首先,计算最多 5 个数字的中位数的算法,这是中位数算法的基础。这个是由Rex Kerr 在this answer 中提供的——算法很大程度上取决于它的速度。
def medianUpTo5(five: Array[Double]): Double =
def order2(a: Array[Double], i: Int, j: Int) =
if (a(i)>a(j)) val t = a(i); a(i) = a(j); a(j) = t
def pairs(a: Array[Double], i: Int, j: Int, k: Int, l: Int) =
if (a(i)<a(k)) order2(a,j,k); a(j)
else order2(a,i,l); a(i)
if (five.length < 2) return five(0)
order2(five,0,1)
if (five.length < 4) return (
if (five.length==2 || five(2) < five(0)) five(0)
else if (five(2) > five(1)) five(1)
else five(2)
)
order2(five,2,3)
if (five.length < 5) pairs(five,0,1,2,3)
else if (five(0) < five(2)) order2(five,1,4); pairs(five,1,4,2,3)
else order2(five,3,4); pairs(five,0,1,3,4)
然后,中位数算法本身的中位数。基本上,它保证选择的枢轴将大于至少 30% 并且小于列表的其他 30%,这足以保证前面算法的线性。有关详细信息,请查看另一个答案中提供的***链接。
def medianOfMedians(arr: Array[Double]): Double =
val medians = arr grouped 5 map medianUpTo5 toArray;
if (medians.size <= 5) medianUpTo5 (medians)
else medianOfMedians(medians)
就地算法
所以,这是算法的就地版本。我正在使用一个使用支持数组就地实现分区的类,以便对算法的更改最小化。
case class ArrayView(arr: Array[Double], from: Int, until: Int)
def apply(n: Int) =
if (from + n < until) arr(from + n)
else throw new ArrayIndexOutOfBoundsException(n)
def partitionInPlace(p: Double => Boolean): (ArrayView, ArrayView) =
var upper = until - 1
var lower = from
while (lower < upper)
while (lower < until && p(arr(lower))) lower += 1
while (upper >= from && !p(arr(upper))) upper -= 1
if (lower < upper) val tmp = arr(lower); arr(lower) = arr(upper); arr(upper) = tmp
(copy(until = lower), copy(from = lower))
def size = until - from
def isEmpty = size <= 0
override def toString = arr mkString ("ArraySize(", ", ", ")")
; object ArrayView
def apply(arr: Array[Double]) = new ArrayView(arr, 0, arr.size)
@tailrec def findKMedianInPlace(arr: ArrayView, k: Int)(implicit choosePivot: ArrayView => Double): Double =
val a = choosePivot(arr)
val (s, b) = arr partitionInPlace (a >)
if (s.size == k) a
// The following test is used to avoid infinite repetition
else if (s.isEmpty)
val (s, b) = arr partitionInPlace (a ==)
if (s.size > k) a
else findKMedianInPlace(b, k - s.size)
else if (s.size < k) findKMedianInPlace(b, k - s.size)
else findKMedianInPlace(s, k)
def findMedianInPlace(arr: Array[Double])(implicit choosePivot: ArrayView => Double) = findKMedianInPlace(ArrayView(arr), (arr.size - 1) / 2)
随机枢轴,就地
我只是为就地算法实现随机枢轴,因为中位数的中位数需要比我定义的 ArrayView
类目前提供的支持更多。
def chooseRandomPivotInPlace(arr: ArrayView): Double = arr(scala.util.Random.nextInt(arr.size))
直方图算法(O(log(n)) 内存),不可变
所以,关于流。对于只能遍历一次的流,不可能做任何小于O(n)
内存的事情,除非您碰巧知道字符串长度是多少(在这种情况下,它不再是我书中的流)。
使用桶也有点问题,但如果我们可以多次遍历它,那么我们就可以知道它的大小、最大值和最小值,并从那里开始工作。例如:
def findMedianHistogram(s: Traversable[Double]) =
def medianHistogram(s: Traversable[Double], discarded: Int, medianIndex: Int): Double =
// The buckets
def numberOfBuckets = (math.log(s.size).toInt + 1) max 2
val buckets = new Array[Int](numberOfBuckets)
// The upper limit of each bucket
val max = s.max
val min = s.min
val increment = (max - min) / numberOfBuckets
val indices = (-numberOfBuckets + 1 to 0) map (max + increment * _)
// Return the bucket a number is supposed to be in
def bucketIndex(d: Double) = indices indexWhere (d <=)
// Compute how many in each bucket
s foreach d => buckets(bucketIndex(d)) += 1
// Now make the buckets cumulative
val partialTotals = buckets.scanLeft(discarded)(_+_).drop(1)
// The bucket where our target is at
val medianBucket = partialTotals indexWhere (medianIndex <)
// Keep track of how many numbers there are that are less
// than the median bucket
val newDiscarded = if (medianBucket == 0) discarded else partialTotals(medianBucket - 1)
// Test whether a number is in the median bucket
def insideMedianBucket(d: Double) = bucketIndex(d) == medianBucket
// Get a view of the target bucket
val view = s.view filter insideMedianBucket
// If all numbers in the bucket are equal, return that
if (view forall (view.head ==)) view.head
// Otherwise, recurse on that bucket
else medianHistogram(view, newDiscarded, medianIndex)
medianHistogram(s, 0, (s.size - 1) / 2)
测试和基准测试
为了测试算法,我使用Scalacheck,并将每个算法的输出与带有排序的简单实现的输出进行比较。当然,这假设排序版本是正确的。
我正在使用所有提供的枢轴选择以及固定的枢轴选择(数组的中间,向下舍入)对上述每个算法进行基准测试。每种算法都使用三种不同的输入数组大小进行了测试,并且针对每种算法测试了 3 次。
这是测试代码:
import org.scalacheck.Prop, Pretty, Test
import Prop._
import Pretty._
def test(algorithm: Array[Double] => Double,
reference: Array[Double] => Double): String =
def prettyPrintArray(arr: Array[Double]) = arr mkString ("Array(", ", ", ")")
val resultEqualsReference = forAll (arr: Array[Double]) =>
arr.nonEmpty ==> (algorithm(arr) == reference(arr)) :| prettyPrintArray(arr)
Test.check(Test.Params(), resultEqualsReference)(Pretty.Params(verbosity = 0))
import java.lang.System.currentTimeMillis
def bench[A](n: Int)(body: => A): Long =
val start = currentTimeMillis()
1 to n foreach _ => body
currentTimeMillis() - start
import scala.util.Random.nextDouble
def benchmark(algorithm: Array[Double] => Double,
arraySizes: List[Int]): List[Iterable[Long]] =
for (size <- arraySizes)
yield for (iteration <- 1 to 3)
yield bench(50000)(algorithm(Array.fill(size)(nextDouble)))
def testAndBenchmark: String =
val immutablePivotSelection: List[(String, Array[Double] => Double)] = List(
"Random Pivot" -> chooseRandomPivot,
"Median of Medians" -> medianOfMedians,
"Midpoint" -> ((arr: Array[Double]) => arr((arr.size - 1) / 2))
)
val inPlacePivotSelection: List[(String, ArrayView => Double)] = List(
"Random Pivot (in-place)" -> chooseRandomPivotInPlace,
"Midpoint (in-place)" -> ((arr: ArrayView) => arr((arr.size - 1) / 2))
)
val immutableAlgorithms = for ((name, pivotSelection) <- immutablePivotSelection)
yield name -> (findMedian(_: Array[Double])(pivotSelection))
val inPlaceAlgorithms = for ((name, pivotSelection) <- inPlacePivotSelection)
yield name -> (findMedianInPlace(_: Array[Double])(pivotSelection))
val histogramAlgorithm = "Histogram" -> ((arr: Array[Double]) => findMedianHistogram(arr))
val sortingAlgorithm = "Sorting" -> ((arr: Array[Double]) => arr.sorted.apply((arr.size - 1) / 2))
val algorithms = sortingAlgorithm :: histogramAlgorithm :: immutableAlgorithms ::: inPlaceAlgorithms
val formattingString = "%%-%ds %%s" format (algorithms map (_._1.length) max)
// Tests
val testResults = for ((name, algorithm) <- algorithms)
yield formattingString format (name, test(algorithm, sortingAlgorithm._2))
// Benchmarks
val arraySizes = List(100, 500, 1000)
def formatResults(results: List[Long]) = results map ("%8d" format _) mkString
val benchmarkResults: List[String] = for
(name, algorithm) <- algorithms
results <- benchmark(algorithm, arraySizes).transpose
yield formattingString format (name, formatResults(results))
val header = formattingString format ("Algorithm", formatResults(arraySizes.map(_.toLong)))
"Tests" :: "*****" :: testResults :::
("" :: "Benchmark" :: "*********" :: header :: benchmarkResults) mkString ("", "\n", "\n")
结果
测试:
Tests
*****
Sorting OK, passed 100 tests.
Histogram OK, passed 100 tests.
Random Pivot OK, passed 100 tests.
Median of Medians OK, passed 100 tests.
Midpoint OK, passed 100 tests.
Random Pivot (in-place)OK, passed 100 tests.
Midpoint (in-place) OK, passed 100 tests.
基准测试:
Benchmark
*********
Algorithm 100 500 1000
Sorting 1038 6230 14034
Sorting 1037 6223 13777
Sorting 1039 6220 13785
Histogram 2918 11065 21590
Histogram 2596 11046 21486
Histogram 2592 11044 21606
Random Pivot 904 4330 8622
Random Pivot 902 4323 8815
Random Pivot 896 4348 8767
Median of Medians 3591 16857 33307
Median of Medians 3530 16872 33321
Median of Medians 3517 16793 33358
Midpoint 1003 4672 9236
Midpoint 1010 4755 9157
Midpoint 1017 4663 9166
Random Pivot (in-place) 392 1746 3430
Random Pivot (in-place) 386 1747 3424
Random Pivot (in-place) 386 1751 3431
Midpoint (in-place) 378 1735 3405
Midpoint (in-place) 377 1740 3408
Midpoint (in-place) 375 1736 3408
分析
所有算法(排序版本除外)都有与平均线性时间复杂度兼容的结果。
中位数的中位数,它保证了最坏情况下的线性时间复杂度比随机枢轴要慢得多。
固定枢轴选择比随机枢轴稍差,但在非随机输入上的性能可能要差得多。
就地版本大约快 230% ~ 250%,但进一步的测试(未显示)似乎表明这种优势随着阵列大小的增加而增加。
我对直方图算法感到非常惊讶。它显示线性时间复杂度平均值,也比中位数的中位数快 33%。然而,输入是随机的。最坏的情况是二次的——我在调试代码时看到了一些例子。
【讨论】:
这段代码的三个问题是(a)它不能编译(递归函数需要一个明确的返回类型),(b)它不是线性时间(因为分区是 O(n) 并且运行了 O(n) 次),并且 (c) 它会产生错误的答案。除此之外,是的。 @Malvolio 这里和那里有一些错误,但没有什么比认为它运行 O(n) 次更粗鲁了...... ;-) 无论如何,我不在乎算法是否有效或有正确的复杂度,我只是将别人声称是线性时间的算法翻译成Scala。 @Malvolio 这个算法是(或似乎是)O(nlogn),因为 Arr 的大小平均每次减半。然而,这种分析是肤浅的。该算法看起来很像快速排序,但只有一半的分区被递归到,这使得它已经比快速排序更快。此外,它不需要一直下降到 1 大小的分区。至于错误,它们是一个错误,主要与原始算法从分区中隐式删除a
以及声明a
时缺少“arr”有关。差一个错误太糟糕了。
@Malvolio random 并非“毫无意义”。如果您不使用随机,有人可能会猜到您正在使用的策略,选择您的程序花费 O(n^2) 时间的示例并挂起您的服务器。该算法平均而言是正确且线性的。
@Raphael 该论点完全有效。假设每次数组的长度减少两倍。然后第一次迭代需要 n 个时间单位,第二次迭代需要 n/2 个单位,第三次迭代需要 n/4 个单位,以此类推,总和为 n+n/2+n/4 + ... = 2*n。当然这只是一个直观的解释,严格的证明可以在任何关于算法的书中找到。以上是关于scala中值实现的主要内容,如果未能解决你的问题,请参考以下文章