《统计学习方法》第3章习题
Posted 程劼
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了《统计学习方法》第3章习题相关的知识,希望对你有一定的参考价值。
习题3.1
略
习题3.2
根据例 3.2 构造的 kd 树,可知最近邻点为 \\((2,3)^T\\)
习题3.3
k 近邻法主要需要构造相应的 kd 树。这里用 Python 实现 kd 树的构造与搜索
import heapq
import numpy as np
class KDNode:
def __init__(self, data, axis=0, left=None, right=None):
self.data = data
self.axis = axis
self.left = left
self.right = right
class KDTree:
def __init__(self, data):
self.raw_data = data
self.k = data.shape[1]
def construct(self):
data = self.raw_data
self.root = self._insert_node(data, 0)
def search(self, x, near_k=1, p=2):
self.knn = [(-np.inf, None)]*near_k
self._visit(self.root, x, p)
self.knn = np.array([i[1].data for i in heapq.nlargest(near_k, self.knn)])
return self.knn
def pre_order_traverse(self, node):
print(node.data)
if node.left:
self.pre_order_traverse(node.left)
if node.right:
self.pre_order_traverse(node.right)
def _insert_node(self, data, depth=0):
if len(data) == 0:
return None
axis = depth % self.k
data = sorted(data, key = lambda x: x[axis])
middle = len(data) // 2
return KDNode(
data[middle],
axis,
self._insert_node(data[:middle], depth+1),
self._insert_node(data[middle+1:], depth+1)
)
def _visit(self, node, x, p=2):
if node is not None:
dis = x[node.axis] - node.data[node.axis]
self._visit(node.left if dis < 0 else node.right, x, p)
curr_dis = np.linalg.norm(x-node.data, p)
heapq.heappushpop(self.knn, (-curr_dis, node))
if -(self.knn[0][0]) > abs(dis):
self._visit(node.right if dis < 0 else node.left, x, p)
if __name__ == "__main__":
data = np.array([
[2,3],
[5,4],
[9,6],
[4,7],
[8,1],
[7,2]
])
tree = KDTree(data)
tree.construct()
print(tree.search(np.array([3, 4.5]), 2))
通过调用 KDTree 的 search 方法即可实现查找 x 的 k 近邻。 结果为 \\([(2,3)^T, (5,4)^T]\\)
以上是关于《统计学习方法》第3章习题的主要内容,如果未能解决你的问题,请参考以下文章
PTA的Python练习题-第4章-7 统计学生平均成绩与及格人数
R语言基础题及答案——R语言与统计分析第六章课后习题(汤银才)