Schwertlilien
As a recoder: notes and ideas.

2025-5-12

基于强化学习的Transformer目标检测模型框架(RL-DETR)

模型方法

本工作提出一种融合强化学习思想的Transformer目标检测框架,旨在提升标准DETR架构中对象查询(object queries)更新策略的表达能力与泛化性能。我们将DETR的每一层解码器视为强化学习中的一个时间步(time step),通过建模状态(state)、动作(action)和奖励(reward),引入高斯策略建模(Gaussian Policy)以及基于GPRO的优化函数,引导Transformer更有效地学习目标位置与类别信息。

整体结构

整个模型由以下三部分组成:

  1. 图像编码模块:输入图像通过预训练的ViT模型提取图像特征,输出 \(x \in \mathbb{R}^{H \times W \times d}\),其中 \(d\) 为特征维度。

  2. Transformer编码器-解码器结构:包括6层标准编码器和6层解码器,编码器提取全局语义信息,解码器迭代更新对象查询,最终输出边界框和类别。

  3. RL增强查询更新机制:解码器每层对查询的更新通过高斯采样实现,动作具有随机性;每一层的中间预测均计算奖励以用于强化学习训练。

状态定义

在第 \(k\) 层解码器中,状态定义为该层的对象查询:

\[ s_k = q_k \in \mathbb{R}^{N \times d}, \]

其中 \(N\) 是对象查询的数量,\(d\) 是特征维度,\(q_k\) 表示第 \(k\) 层解码器的查询输入,是从 \(q_{k-1}\) 经过处理后得到的。

动作定义

动作是对当前查询 \(q_k\) 的更新过程,即:

\[ a_k: q_k \longrightarrow q_{k+1}, \]

不同于传统DETR中确定性更新,我们引入高斯策略,将动作表示为一个从高斯分布中采样的随机变量:

\[ q_{k+1} \sim \mathcal{N}(\mu_k, \Sigma_k), \]

其中

\[ \mu_k = W_{\mu} h_k, \quad \log \sigma_k^2 = W_{\sigma} h_k, \]

\(h_k \in \mathbb{R}^{N \times d}\) 是通过标准解码器结构(包括自注意力、交叉注意力和前馈网络)对查询进行处理后的中间特征表示:

\[ h_k = \text{FFN}(\text{CrossAttn}(\text{SelfAttn}(q_k), x)). \]

从而每个位置的下一步查询由如下方式采样:

\[ q_{k+1} = \mu_k + \epsilon_k \cdot \sigma_k, \quad \epsilon_k \sim \mathcal{N}(0, I). \]

这种方式引入了动作的概率建模,支持策略梯度优化。

奖励设计

每一层输出的对象查询 \(q_{k+1}\) 会生成边界框和类别预测,通过以下方式与真实标签比较并计算奖励:

  1. 使用 Hungarian 匹配算法匹配预测与真实目标;
  2. 计算预测框与真实框的平均 IoU:\(\text{mean}(\text{IoU}_k)\)
  3. 计算分类准确率 \(\text{accuracy}_k\)

定义奖励为:

\[ r_k = \text{mean}(\text{IoU}_k) + \alpha \cdot \text{accuracy}_k, \]

其中 \(\alpha\) 为平衡系数,控制IoU与分类准确度的权重。

或使用负损失作为奖励(用于策略梯度):

\[ r_k = - \left( \text{box\_loss}_k + \text{class\_loss}_k \right). \]

强化学习优化目标

为了更稳定地更新策略,我们引入 GPRO(Generalized PPO with Reference Policy)优化目标,具体定义如下:

\[ \mathcal{J}_{\text{GPRO}}(\theta)=\mathbb{E}_{q \sim P(Q), \{o_i\}_{i=1}^G \sim \pi_{\theta_{\text{old}}}(o|q)} \left\{\frac{1}{G} \sum_{i=1}^{G}\left[\min\left(\frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}A_i, \text{clip}\left(\frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}, 1 - \epsilon, 1 + \epsilon \right) A_i\right) - \beta \mathbb{D}_{\text{KL}}(\pi_{\theta} \| \pi_{\text{ref}})\right]\right\}, \]

其中:

  • \(\pi_{\theta}(o|q)\) 为当前策略生成的动作分布;
  • \(\pi_{\theta_{\text{old}}}\) 为旧策略;
  • \(\pi_{\text{ref}}\) 为参考策略(如标准DETR的确定性更新);
  • \(A_i\) 为优势函数,定义为:

\[ A_i = r_k^i + \gamma V(q_{k+1}^i) - V(q_k^i), \]

  • \(\epsilon\) 控制策略更新范围,\(\beta\) 控制与参考策略的KL散度正则。

该目标通过约束更新步幅与参考策略距离,确保稳定训练,并鼓励在合理范围内探索。

小结

本模型将DETR解码器过程重新建模为强化学习中的多步决策问题,借助高斯策略建模查询更新过程,并以GPRO作为训练目标,稳定有效地提升了目标检测的性能,增强了模型对复杂查询空间的探索能力。

搜索
匹配结果数:
未搜索到匹配的文章。