7-AWQ算子

AWQ算子实现

目标

AWQ算子本质上就是实现一个矩阵乘:

  • 输入 A: 正常的连续内存 [M, K]
  • 输入 B: 高度压缩的 int4 权重。水平方向(N)上每个 int32 打包了 8 个 4-bit 值,垂直方向(K)按 group_size 进行分组量化。
  • 输出 C: 正常的连续内存 [M, N]

如下图所示:

graph TD
    subgraph Input_A [输入矩阵 A: 激活值]
        A_data["形状: [M, K]<br/>类型: float16/bfloat16<br/>(连续内存排布)"]
    end

    subgraph Input_B [输入矩阵 B: 量化权重]
        direction TB
        subgraph B_Storage [存储结构]
            QW["qweight<br/>[K, N/8]<br/>int32 (打包 8个 4-bit)"]
            QZ["qzeros<br/>[K/G, N/8]<br/>int32 (打包 8个 4-bit)"]
            QS["scales<br/>[K/G, N]<br/>float16"]
        end

        subgraph B_Logic [逻辑视图]
            W_Logic["解包后 B': [K, N]<br/>(用于计算)"]
        end
    end

    subgraph Operation [Triton Kernel: awq_gemm_kernel]
        direction LR
        LoadA["加载 A 分块"]
        LoadB["加载 qweight/zeros/scales"]
        Dequant["位移解包 + 反量化:<br/>W = (qweight - zero) * scale"]
        Dot["Tensor Core 矩阵乘<br/>tl.dot(A, W)"]
    end

    subgraph Output_C [输出矩阵 C]
        C_data["形状: [M, N]<br/>类型: float16"]
    end

    A_data --> LoadA
    QW --> LoadB
    QZ --> LoadB
    QS --> LoadB
    LoadB --> Dequant
    LoadA --> Dot
    Dequant --> Dot
    Dot --> C_data
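
下面用一组假设的尺寸(K=4096、N=11008、group_size=128,数值仅作示意)可以快速验证各张量之间的形状约束:

# 假设的尺寸,仅用于演示 AWQ 各张量之间的形状关系
K, N, group_size = 4096, 11008, 128

qweight_shape = (K, N // 8)               # 每个 int32 沿 N 方向打包 8 个 4-bit 权重
qzeros_shape = (K // group_size, N // 8)  # 每 group_size 行共享一组零点,同样按 8 个 4-bit 打包
scales_shape = (K // group_size, N)       # 每 group_size 行共享一组缩放因子(float16,不打包)

print(qweight_shape)  # (4096, 1376)
print(qzeros_shape)   # (32, 1376)
print(scales_shape)   # (32, 11008)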

实现代码

代码来自最新的 vLLM 算子实现;事实上 SGLang 中的该算子也是从 vLLM 适配(adapt)而来的。

实现代码:

# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/awq_triton.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import triton
import triton.language as tl

AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]


@triton.jit
def awq_gemm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    zeros_ptr,
    scales_ptr,
    M,
    N,
    K,
    group_size,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    pid_z = tl.program_id(1)

    # NOTE: This doesn't work in TRITON_INTERPRET=1 mode.  Use below instead.
    # num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    accumulator_dtype = c_ptr.type.element_ty

    # NOTE: This doesn't work in TRITON_INTERPRET=1 mode.  Use below instead.
    # accumulator = tl.arange(0, BLOCK_SIZE_N)
    # accumulator = tl.broadcast_to(accumulator[None, :],
    # (BLOCK_SIZE_M, BLOCK_SIZE_N))
    # accumulator = accumulator & 0x0
    # accumulator = accumulator.to(accumulator_dtype)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype)

    # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
    # that will map given indices to the correct order.
    reverse_awq_order_tensor = (
        (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None]
    ).reshape(8)

    # Create the necessary shifts to use to unpack.
    shifts = reverse_awq_order_tensor * 4
    shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8))
    shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N))

    # Offsets and masks.
    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    masks_am = offsets_am < M

    offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
    masks_bn = offsets_bn < N // 8

    offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
    masks_zn = offsets_zn < N // 8

    offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    masks_sn = offsets_sn < N

    offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offsets_a = K * offsets_am[:, None] + offsets_k[None, :]
    offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :]

    a_ptrs = a_ptr + offsets_a
    b_ptrs = b_ptr + offsets_b

    # NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv
    # block_offset = BLOCK_SIZE_K * SPLIT_K
    # for k in range(0, (K + block_offset - 1) // (block_offset)):
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
        masks_k = offsets_k < K
        masks_a = masks_am[:, None] & masks_k[None, :]
        a = tl.load(a_ptrs, mask=masks_a, other=0.0)

        masks_b = masks_k[:, None] & masks_bn[None, :]
        b = tl.load(b_ptrs, mask=masks_b, other=0.0)
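        # 三次 interleave 把每个打包的 int32 沿 N 方向复制为 8 份,
        # 宽度由 BLOCK_SIZE_N // 8 扩展到 BLOCK_SIZE_N,随后用不同位移取出各自的 4-bit 值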
        b = tl.interleave(b, b)
        b = tl.interleave(b, b)
        b = tl.interleave(b, b)

        # Dequantize b.
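        # offsets_szk: 当前 K 块起始行所属的量化分组行号(即 scales / qzeros 的行索引)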
        offsets_szk = (
            BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K
        ) // group_size + tl.arange(0, 1)
        offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]
        masks_zk = offsets_szk < K // group_size
        masks_z = masks_zk[:, None] & masks_zn[None, :]
        zeros_ptrs = zeros_ptr + offsets_z
        zeros = tl.load(zeros_ptrs, mask=masks_z, other=0.0)
        zeros = tl.interleave(zeros, zeros)
        zeros = tl.interleave(zeros, zeros)
        zeros = tl.interleave(zeros, zeros)
        zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N))

        offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]
        masks_sk = offsets_szk < K // group_size
        masks_s = masks_sk[:, None] & masks_sn[None, :]
        scales_ptrs = scales_ptr + offsets_s
        scales = tl.load(scales_ptrs, mask=masks_s, other=0.0)
        scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))

        b = (b >> shifts) & 0xF
        zeros = (zeros >> shifts) & 0xF
        b = (b - zeros) * scales
        b = b.to(c_ptr.type.element_ty)

        # Accumulate results.
        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)

        offsets_k += BLOCK_SIZE_K * SPLIT_K
        a_ptrs += BLOCK_SIZE_K * SPLIT_K
        b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8)

    c = accumulator.to(c_ptr.type.element_ty)
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
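    # 每个 pid_z(Split-K 分片)把部分和写入 result[pid_z] 这一独立的 [M, N] 平面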
    c_ptrs = c_ptr + pid_z * N * M + N * offs_cm[:, None] + offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


# input   - [M, K]
# qweight - [K, N // 8]
# qzeros  - [K // G, N // 8]
# scales  - [K // G, N]
# split_k_iters - parallelism along K-dimension, int, power of 2.
def awq_gemm_triton(
    input: torch.Tensor,
    qweight: torch.Tensor,
    scales: torch.Tensor,
    qzeros: torch.Tensor,
    split_k_iters: int,
    block_size_m: int = 32,
    block_size_n: int = 32,
    block_size_k: int = 32,
) -> torch.Tensor:
    M, K = input.shape
    N = qweight.shape[1] * 8
    group_size = qweight.shape[0] // qzeros.shape[0]

    assert N > 0 and K > 0 and M > 0
    assert qweight.shape[0] == K and qweight.shape[1] == N // 8
    assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8
    assert scales.shape[0] == K // group_size and scales.shape[1] == N
    assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0
    assert split_k_iters <= 32
    assert group_size <= K
    assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        split_k_iters,
    )

    result = torch.zeros((split_k_iters, M, N), dtype=scales.dtype, device=input.device)

    # A = input, B = qweight, C = result
    # A = M x K, B = K x N, C = M x N
    awq_gemm_kernel[grid](
        input,
        qweight,
        result,
        qzeros,
        scales,
        M,
        N,
        K,
        group_size,
        BLOCK_SIZE_M=block_size_m,
        BLOCK_SIZE_N=block_size_n,
        BLOCK_SIZE_K=block_size_k,
        SPLIT_K=split_k_iters,
    )

    result = result.sum(0)

    return result
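
下面是一个假设的调用示例(张量为随机值,尺寸为假设值,仅演示形状约束与调用方式;数值正确性需要与反量化后的参考实现对比):

import torch

# 假设尺寸:M=16, K=N=4096, group_size=128;需要 CUDA/Triton 环境
M, K, N, group_size = 16, 4096, 4096, 128

input = torch.randn(M, K, dtype=torch.float16, device="cuda")
qweight = torch.randint(0, 2**31 - 1, (K, N // 8), dtype=torch.int32, device="cuda")
qzeros = torch.randint(0, 2**31 - 1, (K // group_size, N // 8), dtype=torch.int32, device="cuda")
scales = torch.rand(K // group_size, N, dtype=torch.float16, device="cuda")

out = awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters=8)
print(out.shape)  # torch.Size([16, 4096])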

代码分析:

graph TB
    subgraph Host_Side [1. Python/Host 端: 任务分发与 Grid 配置]
        Init["输入 A [M, K] (fp16)<br/>qweight [K, N/8] (int32)"]
        GridConfig["计算 Grid 尺寸:<br/>(M/BLOCK_M * N/BLOCK_N, SPLIT_K)"]
        Workspace["分配临时 Result [SPLIT_K, M, N] (fp16)"]
    end

    subgraph Triton_Kernel [2. Triton Kernel 内部: 每一个程序实例 pid]
        direction TB

        subgraph Load_Phases [加载阶段]
            LoadA["加载 A 分块 [BLOCK_M, BLOCK_K]<br/>(fp16)"]
            LoadB["加载 qweight 分块 [BLOCK_K, BLOCK_N/8]<br/>(int32)"]
        end

        subgraph Dequant_Logic [核心反量化: 寄存器级操作]
            direction LR
            Interleave["tl.interleave (8x)<br/>将 int32 广播扩展"]
            Shift[">> shifts & 0xF<br/>(按 [0,4,1,5,2,6,3,7] 顺序)"]
            LoadSZ["加载 scales [1, BLOCK_N] (fp16)<br/>加载 qzeros [1, BLOCK_N/8] (int32)"]
            ApplySZ["W = (W_int4 - zero) * scale<br/>(类型转换为 fp16)"]
        end

        subgraph Compute_Phase [计算阶段]
            Dot["tl.dot(A, W)<br/>利用 Tensor Core 执行"]
            Accumulate["Accumulator 累加 (K-loop)"]
        end
    end

    subgraph Reduction_Phase [3. 结果合并: Split-K Reduction]
        Sum["result.sum(axis=0)<br/>将不同 K 段的结果相加"]
        Final["输出 C [M, N] (fp16)"]
    end

    Init --> GridConfig
    GridConfig --> Workspace
    Workspace --> Load_Phases
    LoadA --> Dot
    LoadB --> Dequant_Logic
    LoadSZ --> Dequant_Logic
    Dequant_Logic --> Dot
    Dot --> Accumulate
    Accumulate --> Sum
    Sum --> Final

    style Dequant_Logic fill:#f9f,stroke:#333,stroke-width:2px
    style Compute_Phase fill:#bbf,stroke:#333,stroke-width:2px
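
Split-K 的归约思路可以用纯 PyTorch 做一个概念性验证(示意代码:这里按连续段切分 K 维,kernel 中实际是按 BLOCK_SIZE_K 交错划分,两者在数学上等价):

import torch

# 概念演示:把 K 维切成 SPLIT_K 段,各段分别做 matmul 再求和,结果等价于完整的 matmul
M, K, N, SPLIT_K = 8, 64, 16, 4
A = torch.randn(M, K)
B = torch.randn(K, N)

chunk = K // SPLIT_K
partial = torch.stack([
    A[:, i * chunk:(i + 1) * chunk] @ B[i * chunk:(i + 1) * chunk, :]
    for i in range(SPLIT_K)
])                                       # [SPLIT_K, M, N],对应 kernel 输出的 result
assert torch.allclose(partial.sum(0), A @ B, atol=1e-4)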

计算逻辑:

在 Triton kernel 内部,每一轮 K 维循环(每次处理 BLOCK_SIZE_K 行,按 BLOCK_SIZE_K * SPLIT_K 步进)依次执行:

  1. 加载 (Load): 从 A 取出一个 [BLOCK_SIZE_M, BLOCK_SIZE_K] 的分块,从 qweight 取出一个 [BLOCK_SIZE_K, BLOCK_SIZE_N / 8] 的分块。
  2. 解包 (Unpack): 先用三次 tl.interleave 把打包的 int32 沿 N 方向复制 8 份,再通过 (b >> shifts) & 0xF 按 AWQ 打包顺序取出各 4-bit 值,恢复到逻辑宽度 BLOCK_SIZE_N。
  3. 对齐 Group: 根据当前 k 索引定位对应的 scales 和 zeros。由于 scales 形状是 [K/G, N],k 每前进 group_size 步才会切换到下一组 scale/zero。
  4. 反量化 (Dequantize): 按 W = (b - zeros) * scales 把 int4 权重还原为半精度数值。
  5. 累加 (Accumulate): 与激活值 A 的分块做 tl.dot 半精度矩阵乘,结果累加到 accumulator;K 循环结束后再由 Split-K 维度的 result.sum(0) 完成最终归约。
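
下面是一个示意性的 CPU 复现(非 vLLM 官方实现,尺寸为假设值),把上面第 2~4 步的解包与反量化在 PyTorch 中展开,便于对照 kernel 中的 interleave / shift / mask 逻辑:

import torch

# AWQ 的打包顺序为 [0, 4, 1, 5, 2, 6, 3, 7],对应的位移为其 4 倍
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
shifts = torch.tensor([i * 4 for i in AWQ_REVERSE_ORDER], dtype=torch.int32)

def unpack_int32(packed: torch.Tensor) -> torch.Tensor:
    """把 [rows, cols] 的 int32 解包为 [rows, cols * 8] 的 4-bit 值(以 int32 存放)。"""
    # 等价于 kernel 中的三次 tl.interleave:把每个 int32 沿列复制 8 份
    expanded = packed[:, :, None].expand(-1, -1, 8).reshape(packed.shape[0], -1)
    # 再按 AWQ 顺序位移并取低 4 位
    return (expanded >> shifts.repeat(packed.shape[1])) & 0xF

# 随机构造一个很小的打包块(K=4, N=16, group_size=4,均为假设值)
qweight = torch.randint(0, 2**31 - 1, (4, 2), dtype=torch.int32)   # [K, N // 8]
qzeros = torch.randint(0, 2**31 - 1, (1, 2), dtype=torch.int32)    # [K // G, N // 8]
scales = torch.rand(1, 16, dtype=torch.float16)                    # [K // G, N]

w_int4 = unpack_int32(qweight)                    # [4, 16]
z_int4 = unpack_int32(qzeros)                     # [1, 16]
w = (w_int4 - z_int4).to(scales.dtype) * scales   # 反量化: W = (q - zero) * scale
print(w.shape)  # torch.Size([4, 16])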