2025-5-12
基于强化学习的Transformer目标检测模型框架(RL-DETR)
模型方法
本工作提出一种融合强化学习思想的Transformer目标检测框架,旨在提升标准DETR架构中对象查询(object queries)更新策略的表达能力与泛化性能。我们将DETR的每一层解码器视为强化学习中的一个时间步(time step),通过建模状态(state)、动作(action)和奖励(reward),引入高斯策略建模(Gaussian Policy)以及基于GPRO的优化函数,引导Transformer更有效地学习目标位置与类别信息。
整体结构
整个模型由以下三部分组成:
图像编码模块:输入图像通过预训练的ViT模型提取图像特征,输出 \(x \in \mathbb{R}^{H \times W \times d}\),其中 \(d\) 为特征维度。
Transformer编码器-解码器结构:包括6层标准编码器和6层解码器,编码器提取全局语义信息,解码器迭代更新对象查询,最终输出边界框和类别。
RL增强查询更新机制:解码器每层对查询的更新通过高斯采样实现,动作具有随机性;每一层的中间预测均计算奖励以用于强化学习训练。
状态定义
在第 \(k\) 层解码器中,状态定义为该层的对象查询:
\[ s_k = q_k \in \mathbb{R}^{N \times d}, \]
其中 \(N\) 是对象查询的数量,\(d\) 是特征维度,\(q_k\) 表示第 \(k\) 层解码器的查询输入,是从 \(q_{k-1}\) 经过处理后得到的。
动作定义
动作是对当前查询 \(q_k\) 的更新过程,即:
\[ a_k: q_k \longrightarrow q_{k+1}, \]
不同于传统DETR中确定性更新,我们引入高斯策略,将动作表示为一个从高斯分布中采样的随机变量:
\[ q_{k+1} \sim \mathcal{N}(\mu_k, \Sigma_k), \]
其中
\[ \mu_k = W_{\mu} h_k, \quad \log \sigma_k^2 = W_{\sigma} h_k, \]
\(h_k \in \mathbb{R}^{N \times d}\) 是通过标准解码器结构(包括自注意力、交叉注意力和前馈网络)对查询进行处理后的中间特征表示:
\[ h_k = \text{FFN}(\text{CrossAttn}(\text{SelfAttn}(q_k), x)). \]
从而每个位置的下一步查询由如下方式采样:
\[ q_{k+1} = \mu_k + \epsilon_k \cdot \sigma_k, \quad \epsilon_k \sim \mathcal{N}(0, I). \]
这种方式引入了动作的概率建模,支持策略梯度优化。
奖励设计
每一层输出的对象查询 \(q_{k+1}\) 会生成边界框和类别预测,通过以下方式与真实标签比较并计算奖励:
- 使用 Hungarian 匹配算法匹配预测与真实目标;
- 计算预测框与真实框的平均 IoU:\(\text{mean}(\text{IoU}_k)\);
- 计算分类准确率 \(\text{accuracy}_k\)。
定义奖励为:
\[ r_k = \text{mean}(\text{IoU}_k) + \alpha \cdot \text{accuracy}_k, \]
其中 \(\alpha\) 为平衡系数,控制IoU与分类准确度的权重。
或使用负损失作为奖励(用于策略梯度):
\[ r_k = - \left( \text{box\_loss}_k + \text{class\_loss}_k \right). \]
强化学习优化目标
为了更稳定地更新策略,我们引入 GPRO(Generalized PPO with Reference Policy)优化目标,具体定义如下:
\[ \mathcal{J}_{\text{GPRO}}(\theta)=\mathbb{E}_{q \sim P(Q), \{o_i\}_{i=1}^G \sim \pi_{\theta_{\text{old}}}(o|q)} \left\{\frac{1}{G} \sum_{i=1}^{G}\left[\min\left(\frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}A_i, \text{clip}\left(\frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}, 1 - \epsilon, 1 + \epsilon \right) A_i\right) - \beta \mathbb{D}_{\text{KL}}(\pi_{\theta} \| \pi_{\text{ref}})\right]\right\}, \]
其中:
- \(\pi_{\theta}(o|q)\) 为当前策略生成的动作分布;
- \(\pi_{\theta_{\text{old}}}\) 为旧策略;
- \(\pi_{\text{ref}}\) 为参考策略(如标准DETR的确定性更新);
- \(A_i\) 为优势函数,定义为:
\[ A_i = r_k^i + \gamma V(q_{k+1}^i) - V(q_k^i), \]
- \(\epsilon\) 控制策略更新范围,\(\beta\) 控制与参考策略的KL散度正则。
该目标通过约束更新步幅与参考策略距离,确保稳定训练,并鼓励在合理范围内探索。
小结
本模型将DETR解码器过程重新建模为强化学习中的多步决策问题,借助高斯策略建模查询更新过程,并以GPRO作为训练目标,稳定有效地提升了目标检测的性能,增强了模型对复杂查询空间的探索能力。