Sparse vector Multiplication
Posted tobeabetterpig
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Sparse vector Multiplication相关的知识,希望对你有一定的参考价值。
Sparse vector Multiplication https://github.com/tongzhang1994/Facebook-Interview-Coding/blob/master/Sparce%20Matrix%20Multiplication.java public class Solution {//assume inputs are like {{2, 4}, {0, 10}, {3, 15}},index0 is index of non-zero vals,index1 is the val private Comparator<ArrayList<Integer>> sparseVectorComparator = new Comparator<ArrayList<Integer>>(){ public int compare(ArrayList<Integer> a, ArrayList<Integer> b) { return a.get(0) - b.get(0); } };//remember to add ";" !!! public int sparseVectorMultiplication(ArrayList<ArrayList<Integer>> a, ArrayList<ArrayList<Integer>> b) { if (a == null || b == null || a.size() == 0 || b.size() == 0) { return 0; } int m = a.size(); int n = b.size(); int res = 0; //two inputs are unsorted, directly iterate the elements(brute force); O(m*n) time; if use sort, O(mlogm + nlogn) for (int i = 0; i < m; i++) { ArrayList<Integer> pairA = a.get(i); for (int j = 0; j < n; j++) { ArrayList<Integer> pairB = b.get(i); if (pairA.get(0) == pairB.get(0)) {//if their indices are the same, calculate and break res += pairA.get(1) * pairB.get(1); break;//pairA has been calculated, jump to next pair } } } //if we need to sort the inputs Collections.sort(a, sparseVectorComparator); //two inputs are sorted by index0, use two pointers(move the smaller, calculate the equal); O(m+n) time int i = 0; int j = 0; while (i < m && j < n) { ArrayList<Integer> pairA = a.get(i); ArrayList<Integer> pairB = b.get(j); if (pairA.get(0) < pairB.get(0)) { i++; } else if (pairA.get(0) > pairB.get(0)) { j++; } else { res += pairA.get(1) * pairB.get(1); i++; j++; } } //two inputs are sorted by index0, have same size, sometimes dense, sometimes sparse; two pointes + binary search int i = 0; int j = 0; int countA = 0; int countB = 0; while (i < m && j < n) { ArrayList<Integer> pairA = a.get(i); ArrayList<Integer> pairB = b.get(j); if (pairA.get(0) < pairB.get(0)) { i++; countA++; countB = 0; if (countA > Math.log(m)) { i = search(a, i, m, pairB.get(0)); countA = 0; } } else if (pairA.get(0) > pairB.get(0)) { j++; countB++; countA = 0; if (countB > Math.log(n)) { j = search(b, j, n, pairA.get(0)); countB = 0; } } else { res += pairA.get(1) * pairB.get(1); i++; j++; countA = 0; countB = 0; } } //two inputs are sorted by index0, input b is much larger than input a, iterate a and binary search b; O(m*logn) time int i = 0; int j = 0; while (i < m) { ArrayList<Integer> pairA = a.get(i++); j = search(b, j, n, pairA.get(0)); ArrayList<Integer> pairB = b.get(j++); if (pairA.get(0) == pairB.get(0)) { res += pairA.get(1) * pairB.get(1); } } return res; } private int search(ArrayList<ArrayList<Integer>> array, int start, int end, int target) { while (start + 1 < end) { int mid = start + (end - start) / 2; ArrayList<Integer> pair = array.get(mid); if (pair.get(0) == target) { return mid; } else if (pair.get(0) < target) { start = mid; } else { end = mid; } } if (array.get(end).get(0) == target) { return end; } return start; } } 面试官先问每个vector很大,不能在内存中存下怎么办,我说只需存下非零元素和他们的下标就行,然后问面试官是否可用预处理后的 这两个vector非零元素的index和value作为输入,面试官同意后写完O(M*N)的代码(输入未排序,只能一个个找),MN分别是两个vector长度。 又问这两个输入如果是根据下标排序好的怎么办,是否可以同时利用两个输入都是排序好这一个特性,最后写出了O(M + N)的双指针方法, 每次移动pair里index0较小的指针,如果相等则进行计算,再移动两个指针。 又问如果一个向量比另一个长很多怎么办,我说可以遍历长度短的那一个,然后用二分搜索的方法在另一个vector中找index相同的那个元素, 相乘加入到结果中,这样的话复杂度就是O(M*logN)。 又问如果两个数组一样长,且一会sparse一会dense怎么办。他说你可以在two pointer的扫描中内置一个切换二分搜索的机制。 看差值我说过,设计个反馈我说过,他说不好。他期待的解答是,two pointers找到下个位置需要m次比较,而直接二分搜需要log(n)次比较。 那么在你用two pointers方法移动log(n)次以后,就可以果断切换成二分搜索模式了。 Binary search如果找到了一个元素index,那就用这次的index作为下次binary search的开始。可以节约掉之前的东西,不用search了。 然后问,如果找不到呢,如何优化。说如果找不到,也返回上次search结束的index,然后下次接着search。 就是上一次找到了,就用这个index继续找这次的;如果找不到,也有一个ending index,就用那个index当starting index。 比如[1, 89,100],去找90;如果不存在,那么binary search的ending index应该是89,所以下次就从那个index开始。 如果找不到,会返回要插入的位置index + 1,index是要插入的位置,我写的就是返回要插入的index的。 但是不管返回89还是100的index都无所谓,反正只差一个,对performance没有明显影响的。
以上是关于Sparse vector Multiplication的主要内容,如果未能解决你的问题,请参考以下文章
scipy.sparse.csr_matrix 行过滤 - 如何正确实现?