算法-元素最近邻快速查找

背景

需求：给定一个元素，在一个按升序排列的集合中查找并输出与它距离最近的全部元素。如果集合中没有重复元素，通常只输出一个最近邻元素；如果集合中有重复元素（或左右两侧存在距离相等的元素），则可能输出多个最近邻元素。

求解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# -*- coding: utf-8 -*-
import bisect
import csv
import math

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from numpy import *

def ReadCsvData(file_name):
    """Load timestamps and 3-D positions from a CSV file.

    Every non-empty row must contain at least four numeric columns. Column 0
    is read as the timestamp and columns 1-3 as a position; any further
    columns are ignored. No header row is expected.

    Args:
        file_name: path of the CSV file to read.

    Returns:
        ``(time, loc)``, where ``time`` is a float array of shape ``(N,)``
        and ``loc`` is a float array of shape ``(N, 3)``.
    """
    time = []
    loc = []
    with open(file_name, 'r') as csvfile:
        for row in csv.reader(csvfile):
            # csv.reader yields [] for blank lines (e.g. trailing newlines);
            # skip them instead of failing with IndexError on row[0].
            if not row:
                continue
            time.append(float(row[0]))
            loc.append([float(row[1]), float(row[2]), float(row[3])])
    # reshape keeps loc two-dimensional even when no rows were read, so
    # callers can always slice loc[:, k].
    return np.array(time, dtype=float), np.array(loc, dtype=float).reshape(-1, 3)

def searchInsert(nums, target):
    """Binary-search an ascending sequence for ``target``.

    Args:
        nums: sequence sorted in ascending order.
        target: value to look for.

    Returns:
        An index holding ``target`` if one is hit during the search;
        otherwise the position at which ``target`` would have to be
        inserted to keep ``nums`` sorted (0 .. len(nums)).
    """
    lo, hi = 0, len(nums) - 1
    while lo <= hi:
        pivot = (lo + hi) // 2
        value = nums[pivot]
        if value == target:
            return pivot
        if value < target:
            lo = pivot + 1
        else:
            hi = pivot - 1
    # The loop only exits once lo == hi + 1, i.e. the insertion point.
    return lo

def rIndex(nums, target):
    """Return the indices of every element of ``nums`` nearest to ``target``.

    ``nums`` must be sorted in ascending order, because the lookup relies on
    binary search. All elements whose absolute distance to ``target`` equals
    the minimum distance are returned. Duplicates, or two neighbours equally
    far away on either side of ``target``, therefore all appear.

    Args:
        nums: ascending sequence of numbers (list or 1-D NumPy array).
        target: value to look up.

    Returns:
        A list of indices in ascending order, or ``-1`` when ``nums`` is
        empty (kept for backward compatibility with existing callers).
    """
    n = len(nums)
    if n == 0:
        return -1
    # Leftmost insertion point: nums[:pos] < target <= nums[pos:].
    pos = bisect.bisect_left(nums, target)
    # The nearest value is either just left of pos or at pos; either
    # candidate may fall outside the array at the two ends.
    candidates = [k for k in (pos - 1, pos) if 0 <= k < n]
    best = min([abs(nums[k] - target) for k in candidates])

    # Collect ties on each side separately (avoids the O(k^2) cost of
    # repeatedly prepending to one list), then join them in index order.
    left = []
    i = pos - 1
    while i >= 0 and abs(nums[i] - target) == best:
        left.append(i)
        i -= 1
    right = []
    j = pos
    while j < n and abs(nums[j] - target) == best:
        right.append(j)
        j += 1
    left.reverse()
    return left + right

if __name__ == '__main__':
    # Compare an ARKit trajectory with ground truth for one ADVIO sequence:
    # pair samples by nearest timestamp, then plot both paths and the
    # per-pair 3-D position error.

    ################# File path #######################
    gt_path = "../data/ADVIO/advio-02/ground-truth/pose.csv"
    ARKit_path = "../data/ADVIO/advio-02/iphone/arkit.csv"

    ################# Read data #######################
    time_gt, loc_gt = ReadCsvData(gt_path)
    time_ARKit, loc_ARKit = ReadCsvData(ARKit_path)
    # A single pre-formatted argument prints the same on Python 2 and 3
    # (the former `print a, b` statement is a SyntaxError on Python 3).
    print("%d %d" % (len(time_gt), len(time_ARKit)))

    ################# Match samples ###################
    # Pair every sample of the shorter stream with the first of its
    # time-nearest samples in the longer stream. rIndex uses binary search,
    # so this assumes both timestamp columns are in ascending order.
    if len(time_gt) < len(time_ARKit):
        gt_idx = np.arange(len(time_gt))
        ar_idx = np.array([rIndex(time_ARKit, t)[0] for t in time_gt], dtype=int)
    else:
        ar_idx = np.arange(len(time_ARKit))
        gt_idx = np.array([rIndex(time_gt, t)[0] for t in time_ARKit], dtype=int)
    # Euclidean distance between each matched pair of 3-D positions.
    errors = np.linalg.norm(loc_gt[gt_idx] - loc_ARKit[ar_idx], axis=1)
    print(errors)

    ################# Plot ############################
    # Figure 1 plots position column 0 against column 2 for both paths
    # (labelled X/Y; presumably the ground plane, TODO confirm).
    plt.figure(1)
    plt.plot(loc_gt[:, 0], loc_gt[:, 2], c='r')
    plt.plot(loc_ARKit[:, 0], loc_ARKit[:, 2], c='b')
    plt.xlabel("X (m)")
    plt.ylabel("Y (m)")
    plt.legend(["Ground-truth", "ARKit"])
    # Figure 2: error of each matched pair, in match order.
    plt.figure(2)
    plt.plot(errors, c='b')
    plt.xlabel('Index')
    plt.ylabel('Error (m)')
    plt.show()