QARepVGG--含demo实现

news/2025/2/23 13:00:12

文章目录

  • 前言
  • 引入
  • Demo实现
  • 总结


前言

 在上一篇博文RepVGG中,介绍了RepVGG网络。RepVGG 作为一种高效的重参数化网络,通过训练时的多分支结构(3x3卷积、1x1卷积、恒等映射)和推理时的单分支合并,在精度与速度间取得了优秀平衡。然而,其在低精度(如INT8)量化后常出现显著精度损失。
 本文将要介绍的QARepVGG(Make RepVGG Greater Again: A Quantization-aware Approach)的提出正是为了解决这一问题。其核心贡献在于基础的Block设计:
在这里插入图片描述

引入

 文章做了详细的消融实验来一步一步的推理出这种结构,本文在此不多做赘述。只大概提一下:RepVGG其实是由三个单元构成:权重、BN和ReLU。卷积操作一般不会影响权重值的改变,基本服从0~1分布;而根据BN层的公式,会出现一个乘法项,导致方差可能发生改变;另外,如果输入的数值范围很大,经过ReLU也会产生大的方差项,导致量化困难。
 因此,QARepVGG去掉了BN层,并在三个分支后新加了一个BN层来将分布改成一个量化友好的分布。
 当然,建议读者阅读原论文,好多实验的设计跟分析很透彻。

Demo实现

 本文旨在复现一个QARepVGG Block,读者可一键运行:

import torch
import torch.nn as nn

class QARepVGGBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        assert in_channels == out_channels, "输入输出通道必须相同!"
        
        # 分支1:3x3卷积 + BN
        self.conv3x3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn3x3 = nn.BatchNorm2d(out_channels)
        
        # 分支2:1x1卷积(无BN)
        self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        
        # 分支3:恒等映射(无BN)
        self.identity = nn.Identity()  # 直接传递输入
        
        # 合并后的BN层
        self.final_bn = nn.BatchNorm2d(out_channels)
        
        # 初始化权重(关键!)
        self._init_weights()

    def _init_weights(self):
        """显式初始化权重"""
        nn.init.kaiming_normal_(self.conv3x3.weight, mode='fan_out', nonlinearity='relu')
        nn.init.zeros_(self.conv1x1.weight)  # 初始化为零,与恒等映射互补

    def forward(self, x):
        # 分支1:3x3卷积 + BN
        branch3x3 = self.bn3x3(self.conv3x3(x))
        # 分支2:1x1卷积
        branch1x1 = self.conv1x1(x)
        # 分支3:恒等映射
        branch_id = self.identity(x)
        # 合并后通过最终BN
        out = self.final_bn(branch3x3 + branch1x1 + branch_id)
        return out

    def reparameterize(self):
        """将多分支合并为单一3x3卷积,并融合BN参数"""
        # 1. 将各分支转换为等效3x3卷积
        # 分支1:3x3卷积 + BN3x3
        kernel3x3, bias3x3 = self._fuse_conv_bn(self.conv3x3, self.bn3x3)
        
        # 分支2:1x1卷积(无BN),填充为3x3
        kernel1x1 = self._pad_1x1_to_3x3(self.conv1x1.weight)
        bias1x1 = torch.zeros_like(bias3x3)  # 无偏置
        
        # 分支3:恒等映射(视为1x1单位矩阵卷积,填充为3x3)
        identity_kernel = torch.eye(self.conv3x3.in_channels, device=self.conv3x3.weight.device)
        identity_kernel = identity_kernel.view(self.conv3x3.in_channels, self.conv3x3.in_channels, 1, 1)
        kernel_id = self._pad_1x1_to_3x3(identity_kernel)
        bias_id = torch.zeros_like(bias3x3)
        
        # 2. 合并所有分支的权重和偏置
        merged_kernel = kernel3x3 + kernel1x1 + kernel_id
        merged_bias = bias3x3 + bias1x1 + bias_id
        
        # 3. 融合最终BN层参数
        scale = self.final_bn.weight / (self.final_bn.running_var + self.final_bn.eps).sqrt()
        merged_kernel = merged_kernel * scale.view(-1, 1, 1, 1)
        merged_bias = scale * (merged_bias - self.final_bn.running_mean) + self.final_bn.bias
        
        # 4. 构建合并后的卷积层
        merged_conv = nn.Conv2d(
            self.conv3x3.in_channels,
            self.conv3x3.out_channels,
            kernel_size=3,
            padding=1,
            bias=True
        )
        merged_conv.weight.data = merged_kernel
        merged_conv.bias.data = merged_bias
        return merged_conv

    def _fuse_conv_bn(self, conv, bn):
        """融合卷积和BN的权重与偏置"""
        kernel = conv.weight
        running_mean = bn.running_mean
        running_var = bn.running_var
        gamma = bn.weight
        beta = bn.bias
        eps = bn.eps

        std = (running_var + eps).sqrt()
        scale_factor = gamma / std

        fused_kernel = kernel * scale_factor.view(-1, 1, 1, 1)
        fused_bias = beta - running_mean * scale_factor
        return fused_kernel, fused_bias

    def _pad_1x1_to_3x3(self, kernel):
        """将1x1卷积核填充为3x3(中心为原权重,其余为0)"""
        if kernel.size(-1) == 1:
            padded = torch.zeros(kernel.size(0), kernel.size(1), 3, 3, device=kernel.device)
            padded[:, :, 1, 1] = kernel.squeeze()
            return padded
        return kernel 


def test_qarepvgg():
    torch.manual_seed(42)

    # 输入数据(小方差加速BN收敛)
    x = torch.randn(2, 3, 4, 4) * 0.1

    # 初始化模块
    block = QARepVGGBlock(3, 3)

    # 训练模式:更新BN统计量
    block.train()
    for _ in range(100):  # 充分训练
        y = block(x)
        y.sum().backward()  # 伪反向传播

    # 推理模式:合并权重
    block.eval()
    with torch.no_grad():
        # 原始输出
        orig_out = block(x)

        # 合并后的卷积
        merged_conv = block.reparameterize()
        merged_out = merged_conv(x)

    # 打印关键数据
    print("out:", orig_out.mean().item())
    print("merge:", merged_out.mean().item())
    print("diff:", torch.abs(orig_out - merged_out).max().item())

    # 验证一致性(容差1e-6)
    assert torch.allclose(orig_out, merged_out, atol=1e-6), f"合并失败!最大差值:{torch.abs(orig_out - merged_out).max().item()}"
    print("✅ 测试通过!")

test_qarepvgg()

在这里插入图片描述

总结

 欢迎留言交流讨论。


http://www.niftyadmin.cn/n/5863419.html

相关文章

cocos2dx Win10环境搭建(VS2019)

一、cocos2dx 介绍 Cocos2d-x是一个开源的跨平台游戏开发引擎,主要用于开发2D游戏。它基于Cocos2d-iphone引擎进行了移植,支持C, Lua和Javascript等多种编程语言。以下是Cocos2d-x的一些基本概念和使用场景: 基本概念: 场景&…

使用 DistilBERT 进行资源高效的自然语言处理

DistilBERT 是 BERT 的一个更小、更快的版本,在减少资源消耗的同时仍能保持良好性能。对于计算能力和内存受限的环境来说,它是一个理想的选择。 在自然语言处理(NLP)中,像 BERT 这样的模型提供了高精度和出色的性能。然…

Redission可重试、超时续约的实现原理

Redission遇到其他进程已经占用资源的时候会在指定时间waitTime内进行重试。实现过程如下: 执行获取锁的lua脚本时,会返回一个值, 如果获取锁成功,返回nil,也就是java里的null 如果获取锁失败,用语句“PT…

(三)趣学设计模式 之 抽象工厂模式!

目录 一、 啥是抽象工厂模式?二、 为什么要用抽象工厂模式?三、 抽象工厂模式怎么实现?四、 抽象工厂模式的应用场景五、 抽象工厂模式的优点和缺点六、 抽象工厂模式与工厂方法模式的区别七、 总结 🌟我的其他文章也讲解的比较有…

编程小白冲Kaggle每日打卡(12)--kaggle学堂:<机器学习简介>模型如何工作

Kaggle官方课程链接:How Models Work 本专栏旨在Kaggle官方课程的汉化,让大家更方便地看懂。 How Models Work 第一步,如果你是机器学习的新手。 Introduction 我们将从概述机器学习模型的工作原理和使用方法开始。如果你以前做过统计建模…

YOLOv8与DAttention机制的融合:复杂场景下目标检测性能的增强

文章目录 1. YOLOv8简介2. DAttention (DAT)注意力机制概述2.1 DAttention机制的工作原理 3. YOLOv8与DAttention (DAT)的结合3.1 引入DAT的动机3.2 集成方法3.3 代码实现 4. 实验与结果分析4.1 实验设置4.2 结果分析推理速度性能对比 5. 深度分析:DAttention在YOLO…

分发糖果(力扣135)

题目说相邻的两个孩子中评分更高的孩子获得的糖果更多,表示我们既要考虑到跟左边的孩子比较,也要考虑右边的孩子,但是我们如果两边一起考虑一定会顾此失彼。这里就引入一个思想:先满足右边大于左边时的糖果分发情况,再…

贪心算法

int a[1000], b5, c8; swap(b, c); // 交换操作 memset(a, 0, sizeof(a)); // 初始化为0或-1 引导问题 为一个小老鼠准备了M磅的猫粮,准备去和看守仓库的猫做交易,因为仓库里有小老鼠喜欢吃的五香豆,第i个房间有J[i] 磅的五香豆&#xf…