如何改进此算法以优化运行时间(在段中查找点)

Posted

技术标签:

【中文标题】如何改进此算法以优化运行时间(在段中查找点)【英文标题】:How can improve this algorithm to optimize the running time (find points in segments) 【发布时间】:2016-10-31 14:59:23 【问题描述】:

我有 2 个积分,第一个是段数 (Xi,Xj),第二个是可以或不能在这些段内的点数。

例如,输入可以是:

2 3
0 5
8 10
1 6 11

在第一行中,2 表示“2 个段”,3 表示“3 个点”。 2段是“0到5”和“8到10”,要寻找的点是1、6、11。 输出是

1 0 0

其中点 1 在“0 到 5”段中,而点 6 和 11 不在任何段中。 如果一个点出现在多个段中,例如 3,则输出为 2。

原始代码只是一个双循环来搜索段之间的点。我使用了 Java Arrays 快速排序(已修改,因此当它对段的端点进行排序时,也对起点进行排序,因此 start[i] 和 end[i] 属于同一段 i)来提高双循环的速度,但这还不够。

下一个代码可以正常工作,但是当段太多时它会变得很慢:

public class PointsAndSegments 

    private static int[] fastCountSegments(int[] starts, int[] ends, int[] points) 
        sort(starts, ends);
        int[] cnt2 = CountSegments(starts,ends,points);
        return cnt2;
    

    private static void dualPivotQuicksort(int[] a, int[] b, int left,int right, int div) 
    int len = right - left;
    if (len < 27)  // insertion sort for tiny array
        for (int i = left + 1; i <= right; i++) 
            for (int j = i; j > left && b[j] < b[j - 1]; j--) 
                swap(a, b, j, j - 1);
            
        
        return;
    
    int third = len / div;
    // "medians"
    int m1 = left  + third;
    int m2 = right - third;
    if (m1 <= left) 
        m1 = left + 1;
    
    if (m2 >= right) 
        m2 = right - 1;
    
    if (a[m1] < a[m2]) 
        swap(a, b, m1, left);
        swap(a, b, m2, right);
    
    else 
        swap(a, b, m1, right);
        swap(a, b, m2, left);
    
    // pivots
    int pivot1 = b[left];
    int pivot2 = b[right];
    // pointers
    int less  = left  + 1;
    int great = right - 1;
    // sorting
    for (int k = less; k <= great; k++) 
        if (b[k] < pivot1) 
            swap(a, b, k, less++);
         
        else if (b[k] > pivot2) 
            while (k < great && b[great] > pivot2) 
                great--;
            
            swap(a, b, k, great--);
            if (b[k] < pivot1) 
                swap(a, b, k, less++);
            
        
    
    // swaps
    int dist = great - less;
    if (dist < 13) 
       div++;
    
    swap(a, b, less  - 1, left);
    swap(a, b, great + 1, right);
    // subarrays
    dualPivotQuicksort(a, b, left,   less - 2, div);
    dualPivotQuicksort(a, b, great + 2, right, div);

    // equal elements
    if (dist > len - 13 && pivot1 != pivot2) 
        for (int k = less; k <= great; k++) 
            if (b[k] == pivot1) 
                swap(a, b, k, less++);
            
            else if (b[k] == pivot2) 
                swap(a, b, k, great--);
                if (b[k] == pivot1) 
                    swap(a, b, k, less++);
                
            
        
    
    // subarray
    if (pivot1 < pivot2) 
        dualPivotQuicksort(a, b, less, great, div);
    
    

    public static void sort(int[] a, int[] b) 
        sort(a, b, 0, b.length);
    

    public static void sort(int[] a, int[] b, int fromIndex, int toIndex) 
        rangeCheck(a.length, fromIndex, toIndex);
        dualPivotQuicksort(a, b, fromIndex, toIndex - 1, 3);
    

    private static void rangeCheck(int length, int fromIndex, int toIndex) 
        if (fromIndex > toIndex) 
            throw new IllegalArgumentException("fromIndex > toIndex");
        
        if (fromIndex < 0) 
            throw new ArrayIndexOutOfBoundsException(fromIndex);
        
        if (toIndex > length) 
            throw new ArrayIndexOutOfBoundsException(toIndex);
        
    

    private static void swap(int[] a, int[] b, int i, int j) 
        int swap1 = a[i];
        int swap2 = b[i];
        a[i] = a[j];
        b[i] = b[j];
        a[j] = swap1;
        b[j] = swap2;
    

    private static int[] naiveCountSegments(int[] starts, int[] ends, int[] points) 
        int[] cnt = new int[points.length];
        for (int i = 0; i < points.length; i++) 
            for (int j = 0; j < starts.length; j++) 
                if (starts[j] <= points[i] && points[i] <= ends[j]) 
                    cnt[i]++;
                
            
        
        return cnt;
    

    public static void main(String[] args) 
        Scanner scanner = new Scanner(System.in);
        int n, m;
        n = scanner.nextInt();
        m = scanner.nextInt();
        int[] starts = new int[n];
        int[] ends = new int[n];
        int[] points = new int[m];
        for (int i = 0; i < n; i++) 
            starts[i] = scanner.nextInt();
            ends[i] = scanner.nextInt();
        
        for (int i = 0; i < m; i++) 
            points[i] = scanner.nextInt();
        
        //use fastCountSegments
        int[] cnt = fastCountSegments(starts, ends, points);
        for (int x : cnt) 
            System.out.print(x + " ");
        
    

我认为问题出在 CountSegments() 方法中,但我不确定其他解决方法。据说,我应该使用分而治之的算法,但是 4 天后,我想出了任何解决方案。 我找到了a similar problem in CodeForces,但输出不同,大多数解决方案都是 C++。由于我刚开始学习 java 的 3 个月,我想我已经达到了我的知识极限。

【问题讨论】:

你必须告诉最大段数和分数,对于像在线评委这样的问题,应该提供吗? 段和点在 1 到 50,000 之间。 片段会重叠吗? 是的,它们可以重叠,并且 Segment=(a,b) 范围是 -10^8 【参考方案1】:

给定 OP 的约束,让n 为段数,m 为要查询的点数,其中n,m

由于每个查询都是一个验证问题:该点是否可以被一些区间覆盖,是或否,我们不需要找到哪个/多少个区间点已覆盖。

算法概述:

    首先按起点对所有区间进行排序,如果相同则按长度(最右边的终点)进行排序 尝试合并区间以获得一些不相交重叠区间。例如(0,5), (2,9), (3,7), (3,5), (12,15) ,你会得到 (0,9), (12,15)。随着间隔的排序,这可以在O(n) 中贪婪地完成 以上是预计算,现在对于每个点,我们使用不相交的间隔进行查询。如果任何区间包含这样的点,只需二进制搜索,每个查询是O(lg(n)),我们得到m 点,所以总共O(m lg(n))

结合整个算法,我们将得到一个O(nlg(n) + mlg(n))算法

【讨论】:

如果段在 2. 中合并并重叠,我如何计算点相交的段数?我也应该合并2中的点吗?我假设,所以在 3 中我用二进制搜索搜索点。我不确定我的 java lvl 是否足以在代码中转换这 3 个点,但我有一个起点 不,你搞混了。如果您不需要知道该点相交的段数,则此方法有效,而只需知道“是”或“否”。而在2中,就是3的准备,不考虑点,你只是尝试将输入数据变成一些更有利的形式(不相交和排序间隔),然后你开始使用这个转换后的数据逐个查询点.是的,在第 3 点中,正如我所说,我们使用二进制搜索的每一点,您可以阅读 C++ 中的 lower_bound() 或 upper_bound()(不确定 Java 是否有类似的内置函数) 抱歉,我想我选择了一个输入和输出示例,它没有指定如果点出现在多个段中,我必须计算。我刚刚修改了那个细节。 @Nooblhu 所以你改变了你的 OP,这是一个完全不同的问题...... 对不起,我并没有真正改变那个选项,检查我的代码,它确实计算段数。但我试图选择一个简单的例子,它只是混淆了我期望的输出【参考方案2】:

这是一个类似于@Shole 想法的实现:

public class SegmentsAlgorithm 

    private PriorityQueue<int[]> remainSegments = new PriorityQueue<>((o0, o1) -> Integer.compare(o0[0], o1[0]));
    private SegmentWeight[] arraySegments;

    public void addSegment(int begin, int end) 
        remainSegments.add(new int[]begin, end);
    

    public void prepareArrayCache() 
        List<SegmentWeight> preCalculate = new ArrayList<>();
        PriorityQueue<int[]> currentSegmentsByEnds = new PriorityQueue<>((o0, o1) -> Integer.compare(o0[1], o1[1]));
        int begin = remainSegments.peek()[0];
        while (!remainSegments.isEmpty() && remainSegments.peek()[0] == begin) 
            currentSegmentsByEnds.add(remainSegments.poll());
        
        preCalculate.add(new SegmentWeight(begin, currentSegmentsByEnds.size()));
        int next;
        while (!remainSegments.isEmpty()) 
            if (currentSegmentsByEnds.isEmpty()) 
                next = remainSegments.peek()[0];
             else 
                next = Math.min(currentSegmentsByEnds.peek()[1], remainSegments.peek()[0]);
            
            while (!currentSegmentsByEnds.isEmpty() && currentSegmentsByEnds.peek()[1] == next) 
                currentSegmentsByEnds.poll();
            
            while (!remainSegments.isEmpty() && remainSegments.peek()[0] == next) 
                currentSegmentsByEnds.add(remainSegments.poll());
            
            preCalculate.add(new SegmentWeight(next, currentSegmentsByEnds.size()));
        
        while (!currentSegmentsByEnds.isEmpty()) 
            next = currentSegmentsByEnds.peek()[1];
            while (!currentSegmentsByEnds.isEmpty() && currentSegmentsByEnds.peek()[1] == next) 
                currentSegmentsByEnds.poll();
            
            preCalculate.add(new SegmentWeight(next, currentSegmentsByEnds.size()));
        
        SegmentWeight[] arraySearch = new SegmentWeight[preCalculate.size()];
        int i = 0;
        for (SegmentWeight l : preCalculate) 
            arraySearch[i++] = l;
        
        this.arraySegments = arraySearch;
    

    public int searchPoint(int p) 
        int result = 0;
        if (arraySegments != null && arraySegments.length > 0 && arraySegments[0].begin <= p) 
            int index = Arrays.binarySearch(arraySegments, new SegmentWeight(p, 0), (o0, o1) -> Integer.compare(o0.begin, o1.begin));
            if (index < 0)  // Bug fixed
                index = - 2 - index;
            
            if (index >= 0 && index < arraySegments.length)  // Protection added
                result = arraySegments[index].weight;
            
        
        return result;
    

    public static void main(String[] args) 
        SegmentsAlgorithm algorithm = new SegmentsAlgorithm();
        int[][] segments = 0, 5,3, 10,8, 9,14, 20,12, 28;
        for (int[] segment : segments) 
            algorithm.addSegment(segment[0], segment[1]);
        
        algorithm.prepareArrayCache();

        int[] points = -1, 2, 4, 6, 11, 28;

        for (int point: points) 
            System.out.println(point + ": " + algorithm.searchPoint(point));
        
    

    public static class SegmentWeight 

        int begin;
        int weight;

        public SegmentWeight(int begin, int weight) 
            this.begin = begin;
            this.weight = weight;
        
    

打印出来:

-1: 0
2: 1
4: 2
6: 1
11: 2
28: 0

编辑:

public static void main(String[] args) 
    SegmentsAlgorithm algorithm = new SegmentsAlgorithm();
    Scanner scanner = new Scanner(System.in);
    int n = scanner.nextInt();
    int m = scanner.nextInt();
    for (int i = 0; i < n; i++) 
        algorithm.addSegment(scanner.nextInt(), scanner.nextInt());
    
    algorithm.prepareArrayCache();
    for (int i = 0; i < m; i++) 
        System.out.print(algorithm.searchPoint(scanner.nextInt())+ " ");
    
    System.out.println();

【讨论】:

您的代码运行良好,但我需要用测试用例证明这一点。您如何调整 2D 数组 [][] 段以像这样从 Scanner 获取输入:1st, 2nd, 3rd, 4th, 5th, 6th......? @Nooblhu 已编辑,适用于 Scanner。 不知道为什么,它在 searchPoint 方法中抛出 ArrayIndexOutofBounds:result = arraySegments[Math.abs(index)].weight;在 System.out.print(algorithm.searchPoint(scanner.nextInt())+ " ") 中调用时; (使用问题的测试用例) OutofBounds 错误已修复,但我不确定为什么结果不正确 示例:对于段 (0, 5) (-3, 2) (7, 10) 点:1 和 6:结果应该是 2 0(因为 1 在 (0,5) 和 (-3,2) 中,而 0 不在任何段中),但它输出 0 0

以上是关于如何改进此算法以优化运行时间(在段中查找点)的主要内容,如果未能解决你的问题,请参考以下文章

matlab改进遗传算法求解带时间窗的路径优化问题

优化 DBSCAN 以计算运行

如何优化此代码以运行更大的值? [复制]

优化调度基于matlab改进粒子群算法求解微电网优化调度问题含Matlab源码 052期

如何优化 SQL 查询以减少运行时间?

有没有办法优化此代码以查找数字的除数?