NearestNeighbors sklearn 的自定义指标
Posted
技术标签:
【中文标题】NearestNeighbors sklearn 的自定义指标【英文标题】:Custom metric for NearestNeighbors sklearn 【发布时间】:2018-12-01 14:15:15 【问题描述】:您好,我正在进行一个使用 512 位哈希创建集群的项目。我正在使用自定义度量按位汉明距离。但是当我将两个散列与这个函数进行比较时,我得到的距离结果与使用 NearestNeighbors 不同。
将此扩展到 DBSCAN,使用 eps=5,创建的集群具有一定的一致性,正在正确集群。但是我尝试检查来自同一个集群的点之间的距离,我获得了巨大的距离。这是一个例子。
示例: 这是由 DBSCAN 创建的来自 2 个集群的点列表,正如您在使用函数计算距离时所看到的,给出的数字大于 30,但 NN 给出的结果与 eps=5 一致。
from sklearn.neighbors import NearestNeighbors
hash_list_1 = [2711636196460699638441853508983975450613573844625556129377064665736210167114069990407028214648954985399518205946842968661290371575620508000646896480583712,
2711636396252606881895803338309150146134565539796776390549907030396205082681800682439355456735713892762967881436259141637319066484744271299497977370896760,
2711636396252606881918517135048330084905033589325484952567856239496981859330884970846906663264518266744879431357749780779892124020350824669153434630258784,
2711636396252797418317524490088561493800258861799581574018898781319096107333812163580085003775074676924785748114206505865657620572909617106316367216148512,
2711636196460318585955127494483972276879239064090689852809978361705086216958169367104329622890567955158961917611852516176654399246340379120409329566384160,
2711636396252606881918605860499354102197401318666579124151729671752374458560929422237113300739169875232495266727513833203360007861082211711747836501459040,
2685449071597530523833230885351500532369477539914318172159429043161052628696351016818586542171509728747070238075233795777242761861490021015910382103951968,
2685449271584547381638295372872027557715092296493457397817270817010861872186702795218797216694169625716749654321460983923962566367029011600112932108533792,
2685449071792640184514638654713547133316375160837810451952682241651988724244365461216285304336254942220323815140042850082680124299635209323646382761738272,
1847461275963134712629870519594779049860430827711272857522520377357653173694038204556169999876899727026751811340128091158803029889914422883922033917198368,
2711636396252606881901567718540735842607739343712295416931961674938924754114357607352250040524848697769853213132484145241805622979375000168935113673834592,
2711636396252606881901567718538101947732706353297593371282460773094032493492652041376662823635245997887100968237677157520342076957158825588198798784364576]
hash_list_2 = [1677246762479319235863065539858628614044010438213592493389244703420353559152336301659250128835190166728647823546464421558167523127086351613289685036466208,
1677246762479700308655934218233084077989052614799077817712715603728397519829375248244181345837838956827991047769168833176865438232999821278031784406056992,
1677246762479700314487411751526941880161990070273187005125752885368412445003620183982282356578440274746789782460884881633682918768578649732794162647826464,
1677246762479319238759152196394352786642547660315097253847095508872934279466872914748604884925141826161428241625796765725368284151706959618924400925900832,
1677246762479890853811162999308711253291696853123890392766127782305403145675433285374478727414572392743118524142664546768046227747593095585347134902140960,
1677246765601448867710925237522621090876591539557992237656925108430781026329148912958069241932475038282622646533152559554888274158032061637714105308528752,
1678883457783648388335228538833424204662395277995143067623864457726472665342252064374635323999849241968448535982901839797440478656657327613912450890367008,
1677246765601448864793634462245189770642489500950753120409198344054454862566173176691699195659218600616315451200851360013275424257209428603245704937128032,
1677246762479700314471974894075267160937462491405299015541470373650765401692659096424270522124311243007780041455682577230603077926878181390448030335795232,
1677246762479700317400446530288778920091525622772690226165317385340164047644547471081180880454458397836230795248631079659291423401151022423365062554976288,
1677246762479700317400446530288758590086745084806873060513679821541689120894219120403259478342385343805541797540566045409406476458247878183422733877936160,
2516871453405060707064684111867902766968378200849671168835363528433280949578746081906100803610196553501646503982070255639855643685380535999494563083255776,
1677246762479319230037086118512223039643232176451879100417048497454912234466993748113993020733268935613563596294183318283010061477487433484794582123053088,
1677246762479319235834673272207747972667132521699112379991979781620810490520617303678451683578338921267417975279632387450778387555221361833006151849902112,
1677246762479700305748490595643272813492272250002832996415474372704463760357437926852625171223210803220593114114602433734175731538424778624130491225112608]
def custom_metric(x, y):
return bin(int(x[0]) ^ int(y[0])).count('1')
objective_hash = hash_list_1[0]
complete_list = hash_list_1 + hash_list_2
distance = [custom_metric([objective_hash], [hash_point]) for hash_point in complete_list]
print("Function iteration distance:")
print(distance)
neighbors_model = NearestNeighbors(radius=100, algorithm='ball_tree',
leaf_size=2,
metric=custom_metric,
metric_params=None,
n_jobs=4)
X = [[x] for x in complete_list]
neighbors_model.fit(X)
distance, neighborhoods = neighbors_model.radius_neighbors(objective_hash, 100, return_distance=True)
print("Nearest Neighbors distance:")
print(distance)
print("Nearest Neighbors index:")
print(neighborhoods)
【问题讨论】:
检查任何意外的数据转换。例如,也许你所有的数据都被转换成双精度数? 性能方面,你的距离函数太可怕了。计算字符串中的字符。哎哟。我建议你使用实际的 bity 并在 sklearn 中的 cython 中重写这部分,以获得不错的性能。 @Anony-Mousse 正如您所建议的,问题是 numpy 无法处理 int 这么大的数字,并转换为浮点数。我知道它不是很有效,但现在我有兴趣让它正常工作,然后进行优化。我解决了预先计算距离的问题。 你可能应该使用字节数组...并使用高效的位操作。 【参考方案1】:问题:
Numpy 无法处理这么大的数字,并将它们转换为浮点数会损失很多精度。
解决办法:
使用您的自定义指标预先计算所有距离并将它们提供给 DBSCAN 算法。
【讨论】:
以上是关于NearestNeighbors sklearn 的自定义指标的主要内容,如果未能解决你的问题,请参考以下文章
scikit_learn (sklearn)库中NearestNeighbors(最近邻)函数的各参数说明
KNN K-最近邻:train_test_split 和 knn.kneighbors