SwinTransformer的相对位置索引的原理以及源码分析

文章目录

  • 1. 理论分析
  • 2. 完整代码


引用:参考博客链接


1. 理论分析

根据论文中提供的公式可知是在 Q Q Q K K K进行匹配并除以 d \sqrt d d 后加上了相对位置偏执 B B B

A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q K T d + B ) V \begin{aligned} &Attention(Q,K,V) = Softmax(\frac{QK^T}{\sqrt d}+B)V \end{aligned} Attention(Q,K,V)=Softmax(d QKT+B)V

    如下图,假设输入的feature map高宽都为2,那么首先我们可以构建出每个像素的绝对位置(左下方的矩阵),对于每个像素的绝对位置是使用行号和列号表示的。比如蓝色的像素对应的是第0行第0列所以绝对位置索引是(0,0),接下来再看看相对位置索引。首先看下蓝色的像素,在蓝色像素使用q与所有像素k进行匹配过程中,是以蓝色像素为参考点。然后用蓝色像素的绝对位置索引与其他位置索引进行相减,就得到其他位置相对蓝色像素的相对位置索引。例如黄色像素的绝对位置索引是(0,1),则它相对蓝色像素的相对位置索引为 ( 0 , 0 ) − ( 0 , 1 ) = ( 0 , − 1 ) (0,0)−(0,1)=(0,−1) (0,0)(0,1)=(0,1) 。那么同理可以得到其他位置相对蓝色像素的相对位置索引矩阵。同样,也能得到相对黄色,红色以及绿色像素的相对位置索引矩阵。接下来将每个相对位置索引矩阵按行展平,并拼接在一起可以得到下面的4x4矩阵 。

在这里插入图片描述

对应源代码为:

import torch
import torch.nn as  nn
from timm.models.layers import trunc_normal_

window_size = [2,2]
num_heads = 3

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 绝对位置索引

coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2

    请注意,这里描述的一直是相对位置索引,并不是相对位置偏执参数。因为后面我们会根据相对位置索引去取对应的参数。比如说黄色像素是在蓝色像素的右边,所以相对蓝色像素的相对位置索引为(0,−1)。绿色像素是在红色像素的右边,所以相对红色像素的相对位置索引为(0,−1)。可以发现这两者的相对位置索引都是(0,−1),所以他们使用的相对位置偏执参数都是一样的。

    其实讲到这基本已经讲完了,但在源码中作者为了方便把二维索引给转成了一维索引。具体这么转的呢,有人肯定想到,简单啊直接把行、列索引相加不就变一维了吗?比如上面的相对位置索引中有(0,−1)和(−1,0)在二维的相对位置索引中明显是代表不同的位置,但如果简单相加都等于-1那不就出问题了吗?接下来我们看看源码中是怎么做的。首先在原始的相对位置索引上加上 ( M − 1 ) (M-1) (M1) ( M M M为窗口的大小,在本示例中 M M M=2),加上之后索引中就不会有负数了。
    
在这里插入图片描述

对应源代码为:

relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1

    
接着将所有的行标都乘上2M-1。

在这里插入图片描述

    
对应源代码为:

relative_coords[:, :, 0] *= 2 * window_size[1] - 1

最后将行标和列标进行相加。这样即保证了相对位置关系,而且不会出现上述0 + ( − 1 ) = ( − 1 ) + 0 0+(-1)=(-1)+00+(−1)=(−1)+0的问题了,是不是很神奇。

在这里插入图片描述

relative_position_index = relative_coords.sum(-1)

    刚刚上面也说了,之前计算的是相对位置索引,并不是相对位置偏执参数。真正使用到的可训练参数 B ^ \hat{B} B^ 是保存在relative position bias table表里的,这个表的长度是等于 ( 2 M − 1 ) × ( 2 M − 1 ) (2M−1)×(2M−1) (2M1)×(2M1)的。那么上述公式中的相对位置偏执参数 B B B是根据上面的相对位置索引表根据查relative position bias table表得到的,如下图所示。

在这里插入图片描述

对应源代码为:

'''
relative_position_bias_table:
	其shape=((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
'''
relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
trunc_normal_(relative_position_bias_table, std=.02)

'''
index:
	虽然shape=(window_size[0] * window_size[1], window_size[0] * window_size[1]),
	但是只有(2 * window_size[0] - 1) * (2 * window_size[1] - 1)个不同的元素。
	作为索引,正好能一一对应relative_position_bias_table中的元素
'''
index = relative_position_index.view(-1)
relative_position_bias = relative_position_bias_table[index] # index的每一个不同的元素对应relative_position_bias_table中一个值

relative_position_bias = relative_position_bias.view(window_size[0] * window_size[1], window_size[0] * window_size[1], -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

2. 完整代码

import torch
import torch.nn as  nn
from timm.models.layers import trunc_normal_

window_size = [2,2]
num_heads = 3

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # # 绝对位置索引 # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords_temp = relative_coords.numpy()
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1

relative_coords[:, :, 0] *= 2 * window_size[1] - 1

relative_position_index = relative_coords.sum(-1)

relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
trunc_normal_(relative_position_bias_table, std=.02)

print('relative_position_bias_table:',relative_position_bias_table.shape)
print(relative_position_index.shape)

index = relative_position_index.view(-1)
print('index:',index.shape)
relative_position_bias = relative_position_bias_table[index]
print('relative_position_bias_Noreshape:',relative_position_bias.shape)

relative_position_bias = relative_position_bias.view(window_size[0] * window_size[1], window_size[0] * window_size[1], -1)

relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

print('relative_position_bias:',relative_position_bias.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/778649.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

软件设计之Java入门视频(12)

软件设计之Java入门视频(12) 视频教程来自B站尚硅谷: 尚硅谷Java入门视频教程,宋红康java基础视频 相关文件资料(百度网盘) 提取密码:8op3 idea 下载可以关注 软件管家 公众号 学习内容: 该视频共分为1-7…

gptoolbox matlab工具箱cmake 调试笔记

一、问题描述 起因:在matlab中运行Offset surface of triangle mesh in matlab的时候报错: 不支持将脚本 signed_distance 作为函数执行: E:\MATLAB_File\gptoolbox\mex\signed_distance.m> 出错 offset_bunny (第 22 行) D signed_distance(BC,V,F…

绝区贰--及时优化降低 LLM 成本和延迟

前言 大型语言模型 (LLM) 为各行各业带来了变革性功能,让用户能够利用尖端的自然语言处理技术处理各种应用。然而,这些强大的 AI 系统的便利性是有代价的 — 确实如此。随着 LLM 变得越来越普及,其计算成本和延迟可能会迅速增加,…

Python实战训练(方程与拟合曲线)

1.方程 求e^x-派(3.14)的解 用二分法来求解,先简单算出解所在的区间,然后用迭代法求逼近解,一般不能得到精准的解,所以设置一个能满足自己进度的标准来判断解是否满足 这里打印出解x0是因为在递归过程中…

经典双运算放大器LM358

前言 LM358双运放有几十年的历史了吧?通用运放,很常用,搞电路的避免不了接触运放,怎么选择运放,是工程师关心的问题吧? 从本文开始,将陆续发一些常用的运放,大家选型可以参考&#…

【最新整理】全国高校本科及专科招生和毕业数据集(2008-2022年)

整理了各省高校本科、专科招生和毕业数据等21个相关指标,包括招生、在校、毕业人数,以及财政教育支出、教育经费等数据。含原始数据、线性插值、回归填补三个版本,希望对大家有所帮助 一、数据介绍 数据名称:高校本科、专科招生…

如何处理 PostgreSQL 中由于表连接顺序不当导致的性能问题?

文章目录 一、理解表连接和连接顺序二、识别由于表连接顺序不当导致的性能问题三、影响表连接顺序的因素四、解决方案手动调整连接顺序创建合适的索引分析数据分布和优化查询逻辑 五、示例分析手动调整连接顺序创建索引优化查询逻辑 六、总结 在 PostgreSQL 中,表连…

[FreeRTOS 内部实现] 事件组

文章目录 事件组结构体创建事件组事件组等待位事件组设置位 事件组结构体 // 路径:Source/event_groups.c typedef struct xEventGroupDefinition {EventBits_t uxEventBits;List_t xTasksWaitingForBits; } EventGroup_t;uxEventBits 中的每一位表示某个事件是否…

【LeetCode:3101. 交替子数组计数 + 滑动窗口 + 数学公式】

🚀 算法题 🚀 🌲 算法刷题专栏 | 面试必备算法 | 面试高频算法 🍀 🌲 越难的东西,越要努力坚持,因为它具有很高的价值,算法就是这样✨ 🌲 作者简介:硕风和炜,…

怎样把自己电脑ip改成动态ip:步骤与解析

在今天的网络世界中,IP地址是计算机与互联网沟通的桥梁。而动态IP地址,作为其中的一种类型,由于其自动分配和管理的特性,为用户提供了更大的便利性和灵活性。那么,您是否想知道怎样将电脑IP改为动态呢?本文…

用Excel处理数据图像,出现交叉怎么办?

一、问题描述 用excel制作X-Y散点图,意外的出现了4个交叉点,而实际上的图表数据是没有交叉的。 二、模拟图表 模拟部分数据,并创建X-Y散点图,数据区域,X轴数据是依次增加的,因此散点图应该是没有交叉的。…

js好用的动态分页插件

js好用的动态分页插件是一款简单的分页样式插件,支持样式类型,当前页,每页显示数量,按钮数量,总条数,上一页文字,下一页文字,输入框跳转等功能。 js好用的动态分页插件

java基础:流程控制

一、用户交互Scanner (一)基础 1、概念:基本语法中我们并没有实现程序和人的交互,但是Java给我们提供了这样一个工具类,我们可以获取用户的输入。java.util.Scanner 是 Java5的新特征,我们可以通过Scanne…

Java实现登录验证 -- JWT令牌实现

目录 1.实现登录验证的引出原因 2.JWT令牌2.1 使用JWT令牌时2.2 令牌的组成 3. JWT令牌(token)生成和校验3.1 引入JWT令牌的依赖3.2 使用Jar包中提供的API来实现JWT令牌的生成和校验3.3 使用JWT令牌验证登录3.4 令牌的优缺点 1.实现登录验证的引出 传统…

Spring 泛型依赖注入

Spring 泛型依赖注入,是利用泛型的优点对代码时行精简,将可重复使用的代码全部放到一个类之中,方便以后的维护和修改,同时在不增加代码的情况下增加代码的复用性。 示例代码: 创建实体类 Product package test.spri…

在电子表格中对多列数据去重

一、数据展示 二、代码 Sub 选中区域数据去重()Dim arr()Dim c, d, id Selection.Counti 0For Each c In SelectionIf c.Value <> "" ThenReDim Preserve arr(0 To i)arr(i) c.Valuei i 1End IfNextarr 一维去重(arr)i 0For Each c In Range("O2&…

当需要对多个表进行联合更新操作时,怎样确保数据的一致性?

文章目录 一、问题分析二、解决方案三、示例代码&#xff08;以 MySQL 为例&#xff09;四、加锁机制示例五、测试和验证六、总结 在数据库管理中&#xff0c;经常会遇到需要对多个表进行联合更新的情况。这种操作带来了一定的复杂性&#xff0c;因为要确保在整个更新过程中数据…

Charles拦截发送数据包-cnblog

Charles拦截发送数据包 打开允许断点 右键要打断点的数据包&#xff0c;打断点 重新发请求进入断点模式 修改完毕后发送

集成学习(三)GBDT 梯度提升树

前面学习了&#xff1a;集成学习&#xff08;二&#xff09;Boosting-CSDN博客 梯度提升树&#xff1a;GBDT-Gradient Boosting Decision Tree 一、介绍 作为当代众多经典算法的基础&#xff0c;GBDT的求解过程可谓十分精妙&#xff0c;它不仅开创性地舍弃了使用原始标签进行…

浪潮信息携手算力企业为华东产业集群布局提供高质量算力支撑

随着信息技术的飞速发展&#xff0c;算力已成为推动数字经济发展的核心力量。近日&#xff0c;浪潮信息与五家领先的算力运营公司在南京正式签署战略合作协议&#xff0c;共同加速华东地区智算基础设施布局&#xff0c;为区域经济发展注入新动力。 进击的算力 江苏持续加码智算…