Traffic Sign Recognition

Calibration

1. Launch the Gazebo node

roslaunch turtlebot3_gazebo turtlebot3_autorace_2020.launch

2. Drive the robot in front of a traffic sign so that the sign is clearly visible

roslaunch turtlebot3_teleop turtlebot3_teleop_key.launch

3. Open the image viewer

rqt_image_view
  • Select /camera/image_compensated
  • The sign should now be clearly visible in the image. Save the image, then crop out the part that contains only the traffic sign (a cropping sketch follows this list).
  • Place the cropped sign image in the /turtlebot3_autorace_2020/turtlebot3_autorace_detect/image/ directory.
  • The image file names must match the names specified in the source code: construction.png, intersection.png, left.png, right.png, parking.png, stop.png, tunnel.png.
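
If you would rather crop programmatically than with an image editor, a minimal OpenCV sketch could look like this (the file names here are placeholders, not part of the package):

import cv2

# Load the frame saved from rqt_image_view (placeholder path)
img = cv2.imread('capture.png')
# Drag a rectangle around the sign, then press SPACE or ENTER to confirm
x, y, w, h = cv2.selectROI('crop', img, showCrosshair=False)
sign = img[y:y + h, x:x + w]
# Save under the file name the detector expects, e.g. intersection.png
cv2.imwrite('intersection.png', sign)
cv2.destroyAllWindows()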

Testing

1. Shut down all running nodes, then launch Gazebo

roslaunch turtlebot3_gazebo turtlebot3_autorace_2020.launch

2. Launch intrinsic camera calibration

roslaunch turtlebot3_autorace_camera intrinsic_camera_calibration.launch

3. Launch extrinsic camera calibration

roslaunch turtlebot3_autorace_camera extrinsic_camera_calibration.launch 

4. Launch the keyboard teleoperation node

roslaunch turtlebot3_teleop turtlebot3_teleop_key.launch

Drive the robot in front of the sign, then shut down the teleoperation node.

5. Launch the traffic sign detection node

roslaunch turtlebot3_autorace_detect detect_sign.launch mission:=intersection

Choose the mission parameter for the sign you want to detect. The available values are intersection, construction, parking, level_crossing, and tunnel.
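
The other missions are launched the same way, for example:

roslaunch turtlebot3_autorace_detect detect_sign.launch mission:=construction
roslaunch turtlebot3_autorace_detect detect_sign.launch mission:=parking
roslaunch turtlebot3_autorace_detect detect_sign.launch mission:=level_crossing
roslaunch turtlebot3_autorace_detect detect_sign.launch mission:=tunnel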

6. Open the image viewer

rqt_image_view

Subscribe to the topic /detect/image_traffic_sign/compressed. The mission parameter determines which sign is matched (the intersection mission covers the intersection, left, and right signs):

  • mission:=intersection: intersection sign

  • mission:=intersection: left-turn sign

  • mission:=intersection: right-turn sign

  • mission:=construction: construction sign

  • mission:=parking: parking sign

  • mission:=level_crossing: crossing-gate sign

  • mission:=tunnel: tunnel sign

Implementation

detect_intersection_sign

turtlebot3_autorace_2020\turtlebot3_autorace_detect\nodes\detect_intersection_sign

fnPreproc

    def fnPreproc(self):
        # Initiate SIFT detector
        self.sift = cv2.SIFT_create()

        dir_path = os.path.dirname(os.path.realpath(__file__))
        dir_path = dir_path.replace('turtlebot3_autorace_detect/nodes', 'turtlebot3_autorace_detect/')
        dir_path += 'image/'
        # Load the reference sign images as grayscale
        self.img_intersection = cv2.imread(dir_path + 'intersection.png', 0)
        self.img_left = cv2.imread(dir_path + 'left.png', 0)
        self.img_right = cv2.imread(dir_path + 'right.png', 0)

        self.kp_intersection, self.des_intersection = self.sift.detectAndCompute(self.img_intersection, None)
        self.kp_left, self.des_left = self.sift.detectAndCompute(self.img_left, None)
        self.kp_right, self.des_right = self.sift.detectAndCompute(self.img_right, None)

        FLANN_INDEX_KDTREE = 0
        index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
        # Number of times the trees in the index are recursively traversed; higher values give better precision but take more time
        search_params = dict(checks=50)
        # Create the FLANN matcher
        self.flann = cv2.FlannBasedMatcher(index_params, search_params)
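
cbFindTrafficSign below publishes its result as a std_msgs/UInt8 whose value comes from a TrafficSign enum defined in the node's constructor (not shown in this excerpt). A sketch of what that enum presumably looks like; the member order determines the published value:

from enum import Enum

# Sketch of the enum assumed by cbFindTrafficSign; the real definition
# lives in the node's __init__. Enum values start at 1.
TrafficSign = Enum('TrafficSign', 'intersection left right')

print(TrafficSign.intersection.value)  # 1
print(TrafficSign.left.value)          # 2
print(TrafficSign.right.value)         # 3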

cbFindTrafficSign

    def cbFindTrafficSign(self, image_msg):
        # Process only every third frame to reduce the processing load
        if self.counter % 3 != 0:
            self.counter += 1
            return
        else:
            self.counter = 1
        # Convert the ROS image message to an OpenCV image
        if self.sub_image_type == "compressed":
            np_arr = np.frombuffer(image_msg.data, np.uint8)
            cv_image_input = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
        elif self.sub_image_type == "raw":
            cv_image_input = self.cvBridge.imgmsg_to_cv2(image_msg, "bgr8")

        MIN_MATCH_COUNT = 5

        # find the keypoints and descriptors with SIFT
        kp1, des1 = self.sift.detectAndCompute(cv_image_input, None)

        matches_intersection = self.flann.knnMatch(des1, self.des_intersection, k=2)
        matches_left = self.flann.knnMatch(des1, self.des_left, k=2)
        matches_right = self.flann.knnMatch(des1, self.des_right, k=2)
        # Selects which annotated image will be published below
        image_out_num = 1
        # ------------------ intersection sign detection ------------------ #
        good_intersection = []
        for m, n in matches_intersection:
            if m.distance < 0.7 * n.distance:
                good_intersection.append(m)
        if len(good_intersection) > MIN_MATCH_COUNT:
            src_pts = np.float32([kp1[m.queryIdx].pt for m in good_intersection]).reshape(-1, 1, 2)
            dst_pts = np.float32([self.kp_intersection[m.trainIdx].pt for m in good_intersection]).reshape(-1, 1, 2)

            M, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
            matches_intersection = mask.ravel().tolist()

            msg_sign = UInt8()
            msg_sign.data = self.TrafficSign.intersection.value

            self.pub_traffic_sign.publish(msg_sign)

            rospy.loginfo("detect intersection sign")
            image_out_num = 2

        # ------------------ left sign detection ------------------ #
        good_left = []
        for m, n in matches_left:
            if m.distance < 0.7 * n.distance:
                good_left.append(m)
        if len(good_left) > MIN_MATCH_COUNT:
            src_pts = np.float32([kp1[m.queryIdx].pt for m in good_left]).reshape(-1, 1, 2)
            dst_pts = np.float32([self.kp_left[m.trainIdx].pt for m in good_left]).reshape(-1, 1, 2)

            M, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
            matches_left = mask.ravel().tolist()

            msg_sign = UInt8()
            msg_sign.data = self.TrafficSign.left.value

            self.pub_traffic_sign.publish(msg_sign)
            rospy.loginfo("detect left sign")
            image_out_num = 3
        else:
            matches_left = None
        # ------------------ right sign detection ------------------ #
        good_right = []
        for m, n in matches_right:
            if m.distance < 0.7 * n.distance:
                good_right.append(m)
        if len(good_right) > MIN_MATCH_COUNT:
            src_pts = np.float32([kp1[m.queryIdx].pt for m in good_right]).reshape(-1, 1, 2)
            dst_pts = np.float32([self.kp_right[m.trainIdx].pt for m in good_right]).reshape(-1, 1, 2)

            M, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
            matches_right = mask.ravel().tolist()

            msg_sign = UInt8()
            msg_sign.data = self.TrafficSign.right.value

            self.pub_traffic_sign.publish(msg_sign)
            rospy.loginfo("detect right sign")
            image_out_num = 4
        else:
            matches_right = None
        # ------------------ publish detection results ------------------ #
        # No sign detected
        if image_out_num == 1:
            print('no sign detected ===================================')
            if self.pub_image_type == "compressed":
                self.pub_image_traffic_sign.publish(self.cvBridge.cv2_to_compressed_imgmsg(cv_image_input, "jpg"))

            elif self.pub_image_type == "raw":
                self.pub_image_traffic_sign.publish(self.cvBridge.cv2_to_imgmsg(cv_image_input, "bgr8"))
        # Publish the intersection sign result
        elif image_out_num == 2:
            draw_params_intersection = dict(matchColor=(255, 0, 0),  # draw matches in blue (BGR)
                                            singlePointColor=None,
                                            matchesMask=matches_intersection,  # draw only inliers
                                            flags=2)

            final_intersection = cv2.drawMatches(cv_image_input, kp1, self.img_intersection, self.kp_intersection,
                                                 good_intersection, None, **draw_params_intersection)

            if self.pub_image_type == "compressed":
                self.pub_image_traffic_sign.publish(self.cvBridge.cv2_to_compressed_imgmsg(final_intersection, "jpg"))

            elif self.pub_image_type == "raw":
                self.pub_image_traffic_sign.publish(self.cvBridge.cv2_to_imgmsg(final_intersection, "bgr8"))
        # Publish the left sign result
        elif image_out_num == 3:
            draw_params_left = dict(matchColor=(255, 0, 0),  # draw matches in blue (BGR)
                                    singlePointColor=None,
                                    matchesMask=matches_left,  # draw only inliers
                                    flags=2)

            final_left = cv2.drawMatches(cv_image_input, kp1, self.img_left, self.kp_left, good_left, None,
                                         **draw_params_left)

            if self.pub_image_type == "compressed":
                # publishes the traffic sign image in compressed format
                self.pub_image_traffic_sign.publish(self.cvBridge.cv2_to_compressed_imgmsg(final_left, "jpg"))

            elif self.pub_image_type == "raw":
                # publishes the traffic sign image in raw format
                self.pub_image_traffic_sign.publish(self.cvBridge.cv2_to_imgmsg(final_left, "bgr8"))
        # Publish the right sign result
        elif image_out_num == 4:
            draw_params_right = dict(matchColor=(255, 0, 0),  # draw matches in blue (BGR)
                                     singlePointColor=None,
                                     matchesMask=matches_right,  # draw only inliers
                                     flags=2)

            final_right = cv2.drawMatches(cv_image_input, kp1, self.img_right, self.kp_right, good_right, None,
                                          **draw_params_right)

            if self.pub_image_type == "compressed":
                self.pub_image_traffic_sign.publish(self.cvBridge.cv2_to_compressed_imgmsg(final_right, "jpg"))
            elif self.pub_image_type == "raw":
                self.pub_image_traffic_sign.publish(self.cvBridge.cv2_to_imgmsg(final_right, "bgr8"))
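
On the consuming side, a control node can subscribe to the UInt8 messages published above. A minimal rospy sketch, assuming the detector publishes on /detect/traffic_sign (verify the topic against the Publisher declaration in the node you actually run):

import rospy
from std_msgs.msg import UInt8

def cb_sign(msg):
    # msg.data carries the TrafficSign enum value published by the detector
    rospy.loginfo('detected sign id: %d', msg.data)

rospy.init_node('sign_listener')
# Topic name is an assumption; check the detector node's Publisher
rospy.Subscriber('/detect/traffic_sign', UInt8, cb_sign)
rospy.spin()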

detect_construction_sign

turtlebot3_autorace_2020\turtlebot3_autorace_detect\nodes\detect_construction_sign

fnPreproc

    def fnPreproc(self):
        # Initiate SIFT detector
        self.sift = cv2.SIFT_create()

        dir_path = os.path.dirname(os.path.realpath(__file__))
        dir_path = dir_path.replace('turtlebot3_autorace_detect/nodes', 'turtlebot3_autorace_detect/')
        dir_path += 'image/'

        self.img_construction = cv2.imread(dir_path + 'construction.png', 0)
        self.kp_construction, self.des_construction = self.sift.detectAndCompute(self.img_construction, None)

        FLANN_INDEX_KDTREE = 0
        index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
        search_params = dict(checks=50)

        self.flann = cv2.FlannBasedMatcher(index_params, search_params)

cbFindTrafficSign

    def cbFindTrafficSign(self, image_msg):
        # Drop frames, processing only every third one; how many you can afford depends on your computer's processing power
        if self.counter % 3 != 0:
            self.counter += 1
            return
        else:
            self.counter = 1

        if self.sub_image_type == "compressed":
            # Convert the compressed image to an OpenCV image
            np_arr = np.frombuffer(image_msg.data, np.uint8)
            cv_image_input = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
        elif self.sub_image_type == "raw":
            cv_image_input = self.cvBridge.imgmsg_to_cv2(image_msg, "bgr8")

        MIN_MATCH_COUNT = 5

        # find the keypoints and descriptors with SIFT
        kp1, des1 = self.sift.detectAndCompute(cv_image_input, None)

        matches_construction = self.flann.knnMatch(des1, self.des_construction, k=2)

        image_out_num = 1
        # ------------------ construction sign detection ------------------ #
        good_construction = []
        for m, n in matches_construction:
            if m.distance < 0.7 * n.distance:
                good_construction.append(m)
        if len(good_construction) > MIN_MATCH_COUNT:
            src_pts = np.float32([kp1[m.queryIdx].pt for m in good_construction]).reshape(-1, 1, 2)
            dst_pts = np.float32([self.kp_construction[m.trainIdx].pt for m in good_construction]).reshape(-1, 1, 2)

            M, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
            matches_construction = mask.ravel().tolist()
            # Publish the detected traffic sign
            msg_sign = UInt8()
            msg_sign.data = self.TrafficSign.construction.value

            self.pub_traffic_sign.publish(msg_sign)

            rospy.loginfo("construction")
            image_out_num = 2
        else:
            matches_construction = None
            rospy.loginfo("not found")
        # ------------------ publish detection results ------------------ #
        # No sign detected
        if image_out_num == 1:
            if self.pub_image_type == "compressed":
                self.pub_image_traffic_sign.publish(self.cvBridge.cv2_to_compressed_imgmsg(cv_image_input, "jpg"))

            elif self.pub_image_type == "raw":
                self.pub_image_traffic_sign.publish(self.cvBridge.cv2_to_imgmsg(cv_image_input, "bgr8"))
        # Sign detected
        elif image_out_num == 2:
            draw_params_construction = dict(matchColor=(255, 0, 0),  # draw matches in blue (BGR)
                                            singlePointColor=None,
                                            matchesMask=matches_construction,  # draw only inliers
                                            flags=2)

            final_construction = cv2.drawMatches(cv_image_input, kp1, self.img_construction, self.kp_construction,
                                                 good_construction, None, **draw_params_construction)

            if self.pub_image_type == "compressed":
                self.pub_image_traffic_sign.publish(self.cvBridge.cv2_to_compressed_imgmsg(final_construction, "jpg"))

            elif self.pub_image_type == "raw":
                self.pub_image_traffic_sign.publish(self.cvBridge.cv2_to_imgmsg(final_construction, "bgr8"))

detect_tunnel_sign

turtlebot3_autorace_2020\turtlebot3_autorace_detect\nodes\detect_tunnel_sign

fnPreproc

    def fnPreproc(self):
        # Initiate SIFT detector; the parameter values shown are OpenCV's defaults, written out explicitly
        self.sift = cv2.SIFT_create(nfeatures=0, nOctaveLayers=3, contrastThreshold=0.04, edgeThreshold=10,
                                    sigma=1.6)

        dir_path = os.path.dirname(os.path.realpath(__file__))
        dir_path = dir_path.replace('turtlebot3_autorace_detect/nodes', 'turtlebot3_autorace_detect/')
        dir_path += 'image/'

        self.img_tunnel = cv2.imread(dir_path + 'tunnel.png', 0)  # reference image, loaded as grayscale
        self.kp_tunnel, self.des_tunnel = self.sift.detectAndCompute(self.img_tunnel, None)

        FLANN_INDEX_KDTREE = 0
        index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
        search_params = dict(checks=50)

        self.flann = cv2.FlannBasedMatcher(index_params, search_params)

cbFindTrafficSign

    def cbFindTrafficSign(self, image_msg):
        # Process only every third frame to reduce the processing load
        if self.counter % 3 != 0:
            self.counter += 1
            return
        else:
            self.counter = 1
        # # Save the incoming frame to disk (handy for capturing reference images)
        # cv_image = self.cvBridge.imgmsg_to_cv2(image_msg, desired_encoding='bgr8')
        # cv2.imwrite('/home/itcast/Downloads/tunnel01.png', cv_image)

        if self.sub_image_type == "compressed":
            # Convert the compressed image to an OpenCV image
            np_arr = np.frombuffer(image_msg.data, np.uint8)
            cv_image_input = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
        elif self.sub_image_type == "raw":
            cv_image_input = self.cvBridge.imgmsg_to_cv2(image_msg, "bgr8")

        MIN_MATCH_COUNT = 5

        # find the keypoints and descriptors with SIFT
        kp1, des1 = self.sift.detectAndCompute(cv_image_input, None)
        matches_tunnel = self.flann.knnMatch(des1, self.des_tunnel, k=2)

        image_out_num = 1
        # ------------------ tunnel sign detection ------------------ #
        good_tunnel = []
        for m, n in matches_tunnel:
            if m.distance < 0.7 * n.distance:
                good_tunnel.append(m)
        if len(good_tunnel) > MIN_MATCH_COUNT:
            src_pts = np.float32([kp1[m.queryIdx].pt for m in good_tunnel]).reshape(-1, 1, 2)
            dst_pts = np.float32([self.kp_tunnel[m.trainIdx].pt for m in good_tunnel]).reshape(-1, 1, 2)

            M, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
            matchesMask_tunnel = mask.ravel().tolist()
            # Publish the detected traffic sign
            msg_sign = UInt8()
            msg_sign.data = self.TrafficSign.tunnel.value

            self.pub_traffic_sign.publish(msg_sign)

            rospy.loginfo("tunnel")
            image_out_num = 4
        else:
            matchesMask_tunnel = None
            rospy.loginfo("nothing")
        # ------------------ publish detection results ------------------ #
        # No sign detected
        if image_out_num == 1:
            if self.pub_image_type == "compressed":
                # publishes the traffic sign image in compressed format
                self.pub_image_traffic_sign.publish(self.cvBridge.cv2_to_compressed_imgmsg(cv_image_input, "jpg"))

            elif self.pub_image_type == "raw":
                # publishes the traffic sign image in raw format
                self.pub_image_traffic_sign.publish(self.cvBridge.cv2_to_imgmsg(cv_image_input, "bgr8"))
        # Sign detected
        elif image_out_num == 4:
            draw_params_tunnel = dict(matchColor=(255, 0, 0),  # draw matches in blue (BGR)
                                      singlePointColor=None,
                                      matchesMask=matchesMask_tunnel,  # draw only inliers
                                      flags=2)

            final_tunnel = cv2.drawMatches(cv_image_input, kp1, self.img_tunnel, self.kp_tunnel, good_tunnel, None,
                                           **draw_params_tunnel)

            if self.pub_image_type == "compressed":
                self.pub_image_traffic_sign.publish(self.cvBridge.cv2_to_compressed_imgmsg(final_tunnel, "jpg"))

            elif self.pub_image_type == "raw":
                self.pub_image_traffic_sign.publish(self.cvBridge.cv2_to_imgmsg(final_tunnel, "bgr8"))

SIFT (Scale-Invariant Feature Transform)

SIFT (Scale-Invariant Feature Transform) is an algorithm for extracting keypoints and computing their descriptors. It proceeds in the following steps:

  1. Scale-space extrema detection
  2. Keypoint localization
  3. Orientation assignment
  4. Keypoint descriptors
  5. Keypoint matching

import cv2
import numpy as np

img = cv2.imread('./img/box.png')
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

# Create the SIFT detector
sift = cv2.SIFT_create()
# Detect keypoints on the grayscale image
kp = sift.detect(gray, None)
kp1, des1 = sift.detectAndCompute(gray, None)

# Draw the detected keypoints
# img = cv2.drawKeypoints(gray, kp, None)
img = cv2.drawKeypoints(gray, kp1, None, flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)

cv2.imshow("dst", img)
if cv2.waitKey(0) == 27:
    cv2.destroyAllWindows()

Here sift.detect(gray, None) finds all the keypoints in the image; if you only want to search part of the image, pass a mask as the second argument. Each keypoint is a structured object with many attributes, such as its (x, y) coordinates, the size of its meaningful neighborhood, the angle describing its orientation, and the response describing the keypoint's strength.

The default flags value of cv2.drawKeypoints is cv2.DRAW_MATCHES_FLAGS_DEFAULT, which only draws a small circle at each keypoint's location. If you pass cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS instead, it draws a circle matching the keypoint's size and also marks its orientation.
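
For example, a small sketch that prints a few keypoint attributes:

import cv2

img = cv2.imread('./img/box.png', cv2.IMREAD_GRAYSCALE)
sift = cv2.SIFT_create()
kp = sift.detect(img, None)

for k in kp[:5]:
    # pt: (x, y) coordinates; size: diameter of the meaningful neighborhood;
    # angle: orientation in degrees; response: keypoint strength
    print(k.pt, k.size, k.angle, k.response)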

FLANN based Matcher

FLANN stands for Fast Library for Approximate Nearest Neighbors. It can be used to match the features of one image against those of other images.

For a FLANN-based matcher, we need to pass two dictionaries that specify the algorithm to be used and its related parameters.

IndexParams

The first is IndexParams. The information to pass for each algorithm is described in the FLANN documentation.

  • For SIFT, it is initialized as follows:
index_params = dict(algorithm = FLANN_INDEX_KDTREE, trees = 5)

SearchParams

The second dictionary is SearchParams. It specifies the number of times the trees in the index should be recursively traversed. Higher values give better precision but also take more time:

search_params = dict(checks=100)

The matching code is as follows:

import numpy as np
import cv2

MIN_MATCH_COUNT = 5


def main():
    im1 = cv2.imread('img/box.png', cv2.IMREAD_COLOR)  # queryImage
    im2 = cv2.imread('img/box_in_scene.png', cv2.IMREAD_COLOR)  # trainImage

    # im1 = cv2.imread('./intersection/intersection.png', cv2.IMREAD_COLOR)  # queryImage
    # im2 = cv2.imread('./img/intersection_pic.png', cv2.IMREAD_COLOR)  # trainImage
    # im1 = cv2.imread('./img/construction.png', cv2.IMREAD_COLOR)  # queryImage
    # im2 = cv2.imread('./img/a.png', cv2.IMREAD_COLOR)  # trainImage
    # im1 = cv2.imread('./sign/left.png', cv2.IMREAD_COLOR)  # queryImage
    # im2 = cv2.imread('img/intersection_pic.png', cv2.IMREAD_COLOR)  # trainImage
    img1 = cv2.cvtColor(im1, cv2.COLOR_BGR2GRAY)
    img2 = cv2.cvtColor(im2, cv2.COLOR_BGR2GRAY)

    # Keep the best 500 features
    sift = cv2.SIFT_create(500)

    # Find the keypoints and descriptors with SIFT
    kp1, des1 = sift.detectAndCompute(img1, None)
    kp2, des2 = sift.detectAndCompute(img2, None)
    print("kp1: {}, kp2: {}".format(len(kp1), len(kp2)))

    im1_copy = im1.copy()
    cv2.drawKeypoints(im1_copy, kp1, im1_copy, flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
    # cv2.imshow("kp1", im1_copy)
    im2_copy = im2.copy()
    cv2.drawKeypoints(im2_copy, kp2, im2_copy, flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
    # cv2.imshow("kp2", im2_copy)

    # cv2.imshow('img1',im1_copy)
    # cv2.imshow('img2',im2_copy)
    # cv2.waitKey(0)
    # return

    # FLANN parameters
    FLANN_INDEX_KDTREE = 0
    # KD-tree index built with 5 parallel trees
    index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
    # Number of times the trees in the index are recursively traversed; higher values give better precision but take more time
    search_params = dict(checks=50)  # or pass empty dictionary
    # Create the FLANN matcher
    flann = cv2.FlannBasedMatcher(index_params, search_params)

    # Run the match, returning the 2 best candidates per descriptor
    matches = flann.knnMatch(des1, des2, k=2)
    print(f'matches: {matches}')

Finding an Object with Feature Matching

To locate the object, we can use cv2.findHomography() from the calib3d module. We pass it the sets of corresponding feature points from the two images, and it finds the perspective transformation matrix of that object. We can then apply this transformation with cv2.perspectiveTransform() to find the object. At least four correct point pairs are required to compute the matrix. Matching can produce errors that would distort the result; to deal with this, we can use the RANSAC or LEAST_MEDIAN algorithm.

Code Implementation

Find feature points with SIFT, match them with FLANN, and locate the best match:

import numpy as np
import cv2

MIN_MATCH_COUNT = 5


def main():
    im1 = cv2.imread('img/box.png', cv2.IMREAD_COLOR)  # queryImage
    im2 = cv2.imread('img/box_in_scene.png', cv2.IMREAD_COLOR)  # trainImage

    # im1 = cv2.imread('./intersection/intersection.png', cv2.IMREAD_COLOR)  # queryImage
    # im2 = cv2.imread('./img/intersection_pic.png', cv2.IMREAD_COLOR)  # trainImage
    # im1 = cv2.imread('./img/construction.png', cv2.IMREAD_COLOR)  # queryImage
    # im2 = cv2.imread('./img/a.png', cv2.IMREAD_COLOR)  # trainImage
    # im1 = cv2.imread('./sign/left.png', cv2.IMREAD_COLOR)  # queryImage
    # im2 = cv2.imread('img/intersection_pic.png', cv2.IMREAD_COLOR)  # trainImage
    img1 = cv2.cvtColor(im1, cv2.COLOR_BGR2GRAY)
    img2 = cv2.cvtColor(im2, cv2.COLOR_BGR2GRAY)

    # Keep the best 500 features
    sift = cv2.SIFT_create(500)

    # Find the keypoints and descriptors with SIFT
    kp1, des1 = sift.detectAndCompute(img1, None)
    kp2, des2 = sift.detectAndCompute(img2, None)
    print("kp1: {}, kp2: {}".format(len(kp1), len(kp2)))

    im1_copy = im1.copy()
    cv2.drawKeypoints(im1_copy, kp1, im1_copy, flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
    # cv2.imshow("kp1", im1_copy)
    im2_copy = im2.copy()
    cv2.drawKeypoints(im2_copy, kp2, im2_copy, flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
    # cv2.imshow("kp2", im2_copy)

    # cv2.imshow('img1',im1_copy)
    # cv2.imshow('img2',im2_copy)
    # cv2.waitKey(0)
    # return

    # FLANN parameters
    FLANN_INDEX_KDTREE = 0
    # KD-tree index built with 5 parallel trees
    index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
    # Number of times the trees in the index are recursively traversed; higher values give better precision but take more time
    search_params = dict(checks=50)  # or pass empty dictionary
    # Create the FLANN matcher
    flann = cv2.FlannBasedMatcher(index_params, search_params)

    # Run the match, returning the 2 best candidates per descriptor
    matches = flann.knnMatch(des1, des2, k=2)
    print(f'matches: {matches}')
    # Keep only the good matches (Lowe's ratio test)
    good = []
    for m, n in matches:
        if m.distance < 0.7 * n.distance:
            good.append(m)

    # Mask used to draw only the good matches
    matchesMask = None
    print(f'number of good matches: {len(good)}')
    if len(good) > MIN_MATCH_COUNT:
        # Extract the coordinates of the matched keypoints: src_pts from the queryImage, dst_pts from the trainImage
        # m.queryIdx is the keypoint index; pt is its coordinates
        src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
        dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)

        # Use findHomography with RANSAC to keep outliers (badly deviating points) from skewing the result
        # Returns the perspective transformation matrix and an inlier mask
        M, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
        # Flattened to a list; marks whether the match at each index is an inlier
        matchesMask = mask.ravel().tolist()

        h, w = img1.shape
        pts = np.float32([[0, 0], [0, h - 1], [w - 1, h - 1], [w - 1, 0]]).reshape(-1, 1, 2)
        # Project the corners of img1 into img2 with the perspective transform
        dst = cv2.perspectiveTransform(pts, M)
        # Draw a blue box around the matched region in image 2
        im2 = cv2.polylines(im2, [np.int32(dst)], True, (255, 0, 0), 3, cv2.LINE_AA)

    else:
        print("Not enough matches are found - %d/%d" % (len(good), MIN_MATCH_COUNT))

    draw_params = dict(matchColor=(0, 255, 0),  # draw matches in green color
                       singlePointColor=None,
                       matchesMask=matchesMask,  # draw only inliers
                       flags=cv2.DRAW_MATCHES_FLAGS_NOT_DRAW_SINGLE_POINTS)
    # Draw the matches
    img3 = cv2.drawMatches(im1, kp1, im2, kp2, good, None, **draw_params)
    cv2.imshow('matches', img3)
    cv2.imshow('img1', im1)
    cv2.imshow('img2', im2)

    return 0


if __name__ == '__main__':
    main()

    cv2.waitKey(0)

good stores the query and train indices of every match pair. If there are enough of them, we extract the coordinates of all successfully matched keypoints and pass them to cv2.findHomography() to compute the perspective transformation matrix. Once we have this 3×3 transformation matrix, we can use it to map the corners of the queryImage to the corresponding points in the trainImage.