news 2026/7/6 3:24:49

Journey 6 | Toolchain QAT: ObserverBase Source Code Analysis

张小明

Front-end development engineer

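1. Overview

ObserverBase is the abstract base class of every Observer in the horizon_plugin_pytorch quantization framework. It defines the common interface and core functionality of a quantization calibrator, and provides the foundation on which the individual strategies (MinMax, MSE, KL, and so on) are built.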

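2. ABCMeta in Depth

2.1 Python's metaclass mechanism

In Python, classes are objects too, and every class is created by a metaclass.

By default, every class is created by the metaclass type. When metaclass=ABCMeta is specified, class creation is controlled by ABCMeta instead.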
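The relationship between a class and its metaclass can be checked directly. The sketch below is only an illustration: `Plain` and `WithABC` are hypothetical names, not part of horizon_plugin_pytorch.

```python
from abc import ABCMeta


class Plain:
    """An ordinary class; its metaclass defaults to `type`."""


class WithABC(metaclass=ABCMeta):
    """A class whose creation is handled by ABCMeta."""


# A class is itself an object, and type() reports which metaclass built it.
print(type(Plain))    # <class 'type'>
print(type(WithABC))  # <class 'abc.ABCMeta'>

# ABCMeta is a subclass of type, so WithABC is still a normal class object.
print(issubclass(ABCMeta, type))  # True
```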

ObserverBase is declared this way:

    from abc import ABCMeta, abstractmethod

    import torch


    class ObserverBase(torch.nn.Module, metaclass=ABCMeta):
        @abstractmethod
        def forward(self, x):
            pass

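2.2 The @abstractmethod decorator

The decorator itself only sets a flag on the function. A simplified version:

    def abstractmethod(funcobj):
        """Mark a method as abstract."""
        funcobj.__isabstractmethod__ = True  # only sets a flag
        return funcobj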

2.3 How ObserverBase uses it

    # The base class declares the abstract method
    class ObserverBase(torch.nn.Module, metaclass=ABCMeta):
        @abstractmethod
        def forward(self, x):
            pass

    # ObserverBase.__abstractmethods__ = frozenset({'forward'})


    # A subclass implements it
    class MinMaxObserver(ObserverBase):
        def forward(self, x_orig):
            return x_orig

    # MinMaxObserver.__abstractmethods__ = frozenset() -> can be instantiated
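The effect of `__abstractmethods__` on instantiation can be reproduced without PyTorch. This is a minimal sketch, assuming plain Python classes behave the same way here as `torch.nn.Module` subclasses; `BaseObserver` and `IdentityObserver` are hypothetical stand-ins, not the real Horizon classes.

```python
from abc import ABCMeta, abstractmethod


class BaseObserver(metaclass=ABCMeta):  # hypothetical stand-in for ObserverBase
    @abstractmethod
    def forward(self, x):
        pass


class IdentityObserver(BaseObserver):  # hypothetical concrete subclass
    def forward(self, x):
        return x


# The decorator only sets a flag; ABCMeta collects the flagged names
# into __abstractmethods__ when the class is created.
print(BaseObserver.forward.__isabstractmethod__)  # True
print(BaseObserver.__abstractmethods__)           # frozenset({'forward'})
print(IdentityObserver.__abstractmethods__)       # frozenset()

# Instantiating a class with a non-empty __abstractmethods__ raises TypeError.
try:
    BaseObserver()
    base_instantiated = True
except TypeError:
    base_instantiated = False

print(base_instantiated)               # False
print(IdentityObserver().forward(3))   # 3
```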

3. Full ObserverBase Source

    class ObserverBase(torch.nn.Module, metaclass=ABCMeta):
        r"""Base observer Module.

        Any observer implementation should derive from this class.

        Concrete observers should follow the same API. In forward, they will
        update the statistics of the observed Tensor. And they should provide a
        `calculate_qparams` function that computes the quantization parameters
        given the collected statistics.

        Args:
            averaging_constant: Averaging constant for min/max.
            ch_axis: Channel axis.
            dtype: Quantized data type.
            qscheme: Quantization scheme to be used.
            quant_min: Min quantization value. Will follow dtype if unspecified.
            quant_max: Max quantization value. Will follow dtype if unspecified.
            is_sync_quantize: If sync statistics when training with multiple
                devices.
            factory_kwargs: kwargs which are passed to factory functions for
                min_val and max_val.
        """

        _version = 3

        eps: torch.Tensor
        min_val: torch.Tensor
        max_val: torch.Tensor
        is_sync_quantize: Optional[bool] = True

        @typechecked
        def __init__(
            self,
            averaging_constant: float = 0.01,
            ch_axis: int = -1,
            dtype: Union[torch.dtype, QuantDType] = qint8,
            qscheme: torch.qscheme = torch.per_tensor_symmetric,
            quant_min: int = None,
            quant_max: int = None,
            is_sync_quantize: Optional[bool] = None,
            factory_kwargs: Dict = None,
            compute_scale_strategy=ComputeScaleStrategy.STATISTIC,
        ):
            super(ObserverBase, self).__init__()
            if qscheme == torch.per_channel_symmetric:
                assert (
                    ch_axis >= 0
                ), "ch_axis should be non-negative when using per_channel_symmetric qcsheme"
            else:
                assert (
                    ch_axis < 0
                ), "ch_axis should be negative when using per_tensor_symmetric qcsheme"
            dtype = get_horizon_quant_dtype(dtype)
            assert qscheme in (
                torch.per_tensor_symmetric,
                torch.per_channel_symmetric,
            ), (
                "only support per_tensor_symmetric and per_channel_symmetric "
                "qscheme"
            )
            self.averaging_constant = averaging_constant
            self.ch_axis = ch_axis
            self.dtype = dtype
            self.qscheme = qscheme
            self._set_quant_min_max(self.dtype, quant_min, quant_max)
            if is_sync_quantize is not None:
                self.is_sync_quantize = is_sync_quantize
            self.compute_scale_strategy = compute_scale_strategy
            factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
            self.register_buffer(
                "eps",
                torch.tensor([torch.finfo(torch.float32).eps], **factory_kwargs),
            )
            self.register_buffer("min_val", torch.tensor([], **factory_kwargs))
            self.register_buffer("max_val", torch.tensor([], **factory_kwargs))

        def _set_quant_min_max(
            self,
            dtype,
            quant_min=None,
            quant_max=None,
        ):
            if (quant_min is not None) and (quant_max is not None):
                assert quant_min < quant_max, (
                    "qmin must be strictly less than qmax for user-specified "
                    "quantization range."
                )
                assert (
                    quant_min <= 0 <= quant_max
                ), "Used-specified quantization range must include 0."
                assert qinfo(dtype).min <= quant_min, "quant_min out of bound"
                assert quant_max <= qinfo(dtype).max, "quant_max out of bound"
                self.quant_min, self.quant_max = quant_min, quant_max
            else:
                self.quant_min, self.quant_max = (
                    qinfo(self.dtype).min,
                    qinfo(self.dtype).max,
                )

        def reset_dtype(self, dtype):
            dtype = get_horizon_quant_dtype(dtype)
            if dtype == self.dtype:
                return
            self.dtype = dtype
            self._set_quant_min_max(self.dtype)

        def _load_from_state_dict(
            self,
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        ):
            # buffers has been renamed from min/max_vals to min/max_val
            buffer_name_mapping = {"min_vals": "min_val", "max_vals": "max_val"}
            for old_name in buffer_name_mapping:
                k = prefix + old_name
                if k in state_dict:
                    v = state_dict.pop(k)
                    state_dict[prefix + buffer_name_mapping[old_name]] = v

            eps_key = prefix + "eps"
            if eps_key not in state_dict:
                # eps was moved to a buffer in version 2
                eps = torch.tensor([torch.finfo(torch.float32).eps])
                state_dict[eps_key] = eps

            local_state = ["min_val", "max_val"]
            for name in local_state:
                key = prefix + name
                if key in state_dict:
                    # if ndim=0, make it ndim=1
                    state_dict[key] = state_dict[key].reshape(-1)
                    val = state_dict[key]

                    # Custom handling to allow loading min_val or max_val
                    # of size N into uninitialized buffers of size 0. The
                    # buffers are resized here, and the values are copied in
                    # the default state_dict loading code of the parent.
                    if name == "min_val" and hasattr(self, "min_val"):
                        self.min_val.resize_(val.shape)
                    elif hasattr(self, "max_val"):
                        self.max_val.resize_(val.shape)

            super(ObserverBase, self)._load_from_state_dict(
                state_dict,
                prefix,
                local_metadata,
                strict,
                missing_keys,
                unexpected_keys,
                error_msgs,
            )

        def _load_from_state_dict_script(
            self,
            state_dict: Union[Dict[str, torch.Tensor], Dict[str, torch.Tensor]],
            prefix: str,
            local_metadata: Dict[str, torch.Tensor],
            strict: bool,
            missing_keys: List[str],
            unexpected_keys: List[str],
            error_msgs: List[str],
        ):
            self._load_from_state_dict(
                state_dict,
                prefix,
                local_metadata,
                strict,
                missing_keys,
                unexpected_keys,
                error_msgs,
            )

        def sync_minmax(self, min_val, max_val):
            if dist.is_initialized() and min_val.is_cuda:
                dist.all_reduce(min_val, op=dist.ReduceOp.MIN)
                dist.all_reduce(max_val, op=dist.ReduceOp.MAX)

        def calculate_qparams(self):
            r"""Calculate the quantization parameters.

            Returns:
                scales: Scales tensor of shape (#channels,)
                zero_points: Zero points tensor of shape (#channels,)
            """
            if self.min_val.numel() == 0 or self.max_val.numel() == 0:
                warnings.warn(
                    "Must run observer before calling calculate_qparams. "
                    "Returning default scale and zero point. "
                    "This is an expected behavior if you use KLObserver "
                    "and set 1 < update_interval <= total steps. ",
                )
                return torch.tensor(
                    [1.0], device=self.min_val.device
                ), torch.tensor([0], device=self.min_val.device)
            scale = _compute_scale_symmetric(
                self.min_val,
                self.max_val,
                self.quant_min,
                self.quant_max,
                self.eps,
                self.compute_scale_strategy,
            )
            return scale, None

        def repr_msgs(self):
            msges = []
            # only print minmax value for per tensor
            if hasattr(self, "min_val") and self.min_val.numel() == 1:
                msges.append("min_val={}".format(self.min_val.item()))
            if hasattr(self, "max_val") and self.max_val.numel() == 1:
                msges.append("max_val={}".format(self.max_val.item()))
            return msges

        def extra_repr(self):
            return ",".join(self.repr_msgs())

        @abstractmethod
        def forward(self, x):
            pass

        with_args = classmethod(_with_args)

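4. Core Attributes

4.1 Quantization configuration attributes

    # Basic quantization parameters
    self.averaging_constant: float   # moving-average coefficient
    self.ch_axis: int                # channel axis (used for per_channel quantization)
    self.dtype: QuantDType           # quantized data type (qint8, qint4, etc.)
    self.qscheme: torch.qscheme      # quantization scheme (per_tensor / per_channel)
    self.quant_min: int              # minimum quantized value
    self.quant_max: int              # maximum quantized value
    self.is_sync_quantize: bool      # whether to sync statistics across devices
    self.compute_scale_strategy      # scale computation strategy (STATISTIC/POT/FP16, etc.)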

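4.2 Statistics buffers

    self.register_buffer("eps", torch.tensor([torch.finfo(torch.float32).eps]))
    self.register_buffer("min_val", torch.tensor([]))
    self.register_buffer("max_val", torch.tensor([]))

Why these are registered with register_buffer:

  • No gradient computation: the statistics are not model parameters
  • They move with the model: model.cuda() migrates them automatically
  • They are saved in state_dict: calibration results can be persisted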

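5. Core Methods

5.1 __init__ - Initialization

Parameters:

| Parameter | Default | Description |
| --- | --- | --- |
| averaging_constant | 0.01 | Moving-average coefficient; the larger the value, the more weight the current batch gets |
| ch_axis | -1 | Channel axis; negative means per_tensor, non-negative means per_channel |
| dtype | qint8 | Quantized data type |
| qscheme | per_tensor_symmetric | Quantization scheme |
| quant_min / quant_max | None | Custom quantization range; when None, it is derived from dtype |
| is_sync_quantize | None (falls back to the class-level default True) | Whether to sync statistics across devices in multi-device training |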

Key validation logic:

    # per_channel requires a valid ch_axis
    if qscheme == torch.per_channel_symmetric:
        assert ch_axis >= 0, "ch_axis should be non-negative"
    else:
        assert ch_axis < 0, "ch_axis should be negative for per_tensor"

    # Only symmetric quantization is supported
    assert qscheme in (
        torch.per_tensor_symmetric,
        torch.per_channel_symmetric,
    )
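The quant_min / quant_max handling follows the same validate-then-fall-back pattern. Below is a minimal pure-Python sketch of the checks in `_set_quant_min_max`, not the library code itself: `resolve_quant_range` is a hypothetical name, the dtype limits are passed in as plain integers instead of being looked up with `qinfo`, and `ValueError` replaces the original `assert` statements.

```python
def resolve_quant_range(dtype_min, dtype_max, quant_min=None, quant_max=None):
    """Return the (quant_min, quant_max) pair an observer would end up with.

    Hypothetical helper: dtype_min/dtype_max stand in for qinfo(dtype).min/.max.
    """
    if quant_min is not None and quant_max is not None:
        if not quant_min < quant_max:
            raise ValueError("quant_min must be strictly less than quant_max")
        if not quant_min <= 0 <= quant_max:
            raise ValueError("user-specified range must include 0")
        if quant_min < dtype_min or quant_max > dtype_max:
            raise ValueError("user-specified range exceeds the dtype range")
        return quant_min, quant_max
    # If either bound is missing, fall back to the full dtype range.
    return dtype_min, dtype_max


# qint8 covers [-128, 127]
print(resolve_quant_range(-128, 127))             # (-128, 127)
print(resolve_quant_range(-128, 127, -127, 127))  # (-127, 127)
```

As in the source above, a range is only treated as user-specified when both bounds are given; supplying just one of them silently falls back to the dtype range.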

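5.2 forward - Updating statistics (abstract method)

Design intent:

  • Subclasses must implement this method (enforced by ABCMeta)
  • During calibration, every forward pass collects statistics of the activations
  • It returns the original input (the data flow is not modified)

Typical implementation pattern:

    def forward(self, x_orig):
        # compute_statistics and update_statistics are placeholders
        # for subclass-specific logic.

        # 1. Compute the statistics of the current batch
        min_val_cur, max_val_cur = compute_statistics(x_orig)

        # 2. Sync across devices (optional)
        if self.is_sync_quantize:
            self.sync_minmax(min_val_cur, max_val_cur)

        # 3. Update the accumulated statistics (moving average)
        self.min_val = update_statistics(self.min_val, min_val_cur)
        self.max_val = update_statistics(self.max_val, max_val_cur)

        return x_orig  # returned unchanged, forward propagation is not affected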
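Step 3 of this pattern can be illustrated with plain floats. The update rule below is an assumption: it is the usual exponential moving average, `running + c * (current - running)`, and the source does not confirm that Horizon's observers use exactly this form. `update_moving_average` is a hypothetical name.

```python
def update_moving_average(running, current, averaging_constant=0.01):
    """Blend the current batch statistic into the running one.

    Assumed rule (not confirmed by the source): a larger averaging_constant
    gives the current batch more weight. The first batch initializes the value.
    """
    if running is None:
        return current
    return running + averaging_constant * (current - running)


batches = [[-1.0, 0.5, 2.0], [-3.0, 0.1, 1.0]]
min_val = max_val = None
for batch in batches:
    min_val = update_moving_average(min_val, min(batch), averaging_constant=0.5)
    max_val = update_moving_average(max_val, max(batch), averaging_constant=0.5)

print(min_val, max_val)  # -2.0 1.5
```

With the default averaging_constant of 0.01 the running values would move much more slowly; 0.5 is used here only to make the arithmetic easy to follow.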

5.3 calculate_qparams - Computing quantization parameters

Core computation (_compute_scale_symmetric):

    def _compute_scale_symmetric(min_val, max_val, quant_min, quant_max, eps, strategy):
        # Symmetric quantization: scale = max(|min|, |max|) / (quant_range / 2)
        scale = (
            torch.max(-min_val, max_val)
            .clamp_min(0)
            .div(float(quant_max - quant_min) / 2)
            .clamp_min(eps)
        )

        # Optional scale constraint strategies
        if strategy == ComputeScaleStrategy.KPOT:  # K-POT (trainable POT)
            scale = k_pot_scale(scale)
        elif strategy == ComputeScaleStrategy.POT:  # Power-of-Two
            scale = 2 ** torch.ceil(torch.log2(scale))
        elif strategy == ComputeScaleStrategy.FP16:  # FP16 precision
            scale = _get_fp16_scale(scale)

        return scale
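A worked example with concrete numbers may help. The sketch below redoes the STATISTIC and POT branches on scalar floats; it assumes qint8 with range [-128, 127], and `compute_scale_symmetric` and `round_up_to_power_of_two` are hypothetical scalar stand-ins for the tensor code above.

```python
import math

FLOAT32_EPS = 2.0 ** -23  # value of torch.finfo(torch.float32).eps


def compute_scale_symmetric(min_val, max_val, quant_min, quant_max, eps=FLOAT32_EPS):
    """Scalar version of: max(-min, max).clamp_min(0) / (range / 2), clamped to eps."""
    abs_max = max(-min_val, max_val, 0.0)
    scale = abs_max / (float(quant_max - quant_min) / 2)
    return max(scale, eps)


def round_up_to_power_of_two(scale):
    """POT constraint: round the scale up to the nearest power of two."""
    return 2.0 ** math.ceil(math.log2(scale))


# Observed range [-1.0, 2.0] with qint8: scale = 2.0 / 127.5
scale = compute_scale_symmetric(-1.0, 2.0, -128, 127)
pot_scale = round_up_to_power_of_two(scale)

print(round(scale, 6))  # 0.015686
print(pot_scale)        # 0.03125
```

The eps clamp matters when nothing but zeros has been observed: the scale would otherwise be 0, and any later division by it would fail.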

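5.4 sync_minmax - Multi-device synchronization

    def sync_minmax(self, min_val, max_val):
        if dist.is_initialized() and min_val.is_cuda:
            dist.all_reduce(min_val, op=dist.ReduceOp.MIN)
            dist.all_reduce(max_val, op=dist.ReduceOp.MAX)

How it works:

  • all_reduce aggregates the statistics from all devices
  • The MIN operation takes the minimum across all devices
  • The MAX operation takes the maximum across all devices
  • This keeps calibration results consistent in multi-device training; the sync only runs when torch.distributed is initialized and the tensors are on CUDA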
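The effect of the two reductions can be simulated without any GPUs. The sketch below is not `torch.distributed`: `all_reduce` here is a hypothetical pure-Python function that models each rank's value as one list element, and the per-rank numbers are made up for illustration.

```python
def all_reduce(values, op):
    """Simulate all_reduce: every rank ends up holding the same reduced value."""
    reduced = op(values)
    return [reduced for _ in values]


# One entry per rank: the min/max each device saw in its own batch shard.
rank_mins = [-0.8, -1.5, -0.3, -1.1]
rank_maxs = [2.0, 1.7, 2.4, 1.9]

synced_mins = all_reduce(rank_mins, min)  # models ReduceOp.MIN
synced_maxs = all_reduce(rank_maxs, max)  # models ReduceOp.MAX

print(synced_mins)  # [-1.5, -1.5, -1.5, -1.5]
print(synced_maxs)  # [2.4, 2.4, 2.4, 2.4]
```

After the reduction every rank derives its scale from the same global range, which is why the resulting quantization parameters agree across devices.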

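5.5 _load_from_state_dict - Loading state

Key functions:

  • Version compatibility (renames the legacy keys min_vals / max_vals to min_val / max_val, and fills in eps when it is missing)
  • Dynamically resizes the buffers
  • Supports loading parameters from a calibrated model into a QAT model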
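The key-migration part of this method can be sketched on an ordinary dict. This is an illustration only: `migrate_observer_keys` is a hypothetical name, plain lists stand in for tensors, and the `act_observer.` prefix is made up.

```python
FLOAT32_EPS = 2.0 ** -23  # value of torch.finfo(torch.float32).eps


def migrate_observer_keys(state_dict, prefix):
    """Rename legacy buffer keys and add a default eps entry if it is missing."""
    renames = {"min_vals": "min_val", "max_vals": "max_val"}
    for old_name, new_name in renames.items():
        old_key = prefix + old_name
        if old_key in state_dict:
            state_dict[prefix + new_name] = state_dict.pop(old_key)
    # Older checkpoints did not store eps as a buffer.
    state_dict.setdefault(prefix + "eps", [FLOAT32_EPS])
    return state_dict


legacy = {"act_observer.min_vals": [-1.0], "act_observer.max_vals": [2.0]}
migrated = migrate_observer_keys(legacy, "act_observer.")

print(sorted(migrated))
# ['act_observer.eps', 'act_observer.max_val', 'act_observer.min_val']
```

The real method additionally reshapes each value to one dimension and resizes the empty buffers before delegating to the parent class, so that a buffer of size 0 can receive N loaded values.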

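6. Class Hierarchy

    ObserverBase (abstract base class)
    │
    ├── MinMaxObserver        # moving-average min/max statistics
    │   │
    │   └── ClipObserver      # min/max statistics with clipping
    │
    ├── FixedScaleObserver    # fixed scale (no statistics)
    │
    ├── PercentileObserver    # percentile statistics
    │
    ├── MSEObserver           # searches for the scale that minimizes MSE
    │
    ├── KLObserver            # KL-divergence calibration
    │
    ├── MixObserver           # mixes several methods
    │
    └── HistogramObserver     # histogram statistics (supports several metrics)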

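7. Design Highlights

  1. Unified interface: all Observers follow the same API, which makes them easy to swap and extend
  2. Abstract base class constraint: ABCMeta forces subclasses to implement forward
  3. State persistence: statistics are stored as buffers, so calibration results can be reused
  4. Distributed support: multi-device synchronization is built in
  5. Version compatibility: _load_from_state_dict handles checkpoints from earlier versions
  6. Flexible configuration: supports several quantization schemes, data types, and scale strategies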

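8. Comparison with PyTorch's Native Observer

| Feature | PyTorch ObserverBase | Horizon ObserverBase |
| --- | --- | --- |
| Quantization scheme | Supports asymmetric quantization | Symmetric quantization only |
| Scale constraint | | POT/FP16/KPOT strategies |
| Distributed sync | Must be implemented by the user | Built-in sync_minmax |
| Data types | Standard torch.dtype | Extended QuantDType (qint4, etc.) |
| Version management | | _version field supports migration |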