7 - AWQ Operator

Goal

The AWQ operator is essentially a matrix multiplication:

- Input A: ordinary contiguous memory, shape [M, K].
- Input B: heavily compressed int4 weights. Along the horizontal (N) direction, eight 4-bit values are packed into each int32; along the vertical (K) direction, the weights are quantized block-wise, one scale/zero per group of group_size rows.
- Output C: ordinary contiguous memory, shape [M, N].

This is illustrated in the diagram below (a small packing sketch follows it):
```mermaid
graph TD
    subgraph Input_A [Input matrix A: activations]
        A_data["Shape: [M, K]<br/>dtype: float16/bfloat16<br/>(contiguous layout)"]
    end

    subgraph Input_B [Input matrix B: quantized weights]
        direction TB
        subgraph B_Storage [Storage layout]
            QW["qweight<br/>[K, N/8]<br/>int32 (8 packed 4-bit values)"]
            QZ["qzeros<br/>[K/G, N/8]<br/>int32 (8 packed 4-bit values)"]
            QS["scales<br/>[K/G, N]<br/>float16"]
        end
        subgraph B_Logic [Logical view]
            W_Logic["Unpacked B': [K, N]<br/>(used for the compute)"]
        end
    end

    subgraph Operation [Triton kernel: awq_gemm_kernel]
        direction LR
        LoadA["Load a tile of A"]
        LoadB["Load qweight/zeros/scales"]
        Dequant["Shift-unpack + dequantize:<br/>W = (qweight - zero) * scale"]
        Dot["Tensor Core matmul<br/>tl.dot(A, W)"]
    end

    subgraph Output_C [Output matrix C]
        C_data["Shape: [M, N]<br/>dtype: float16"]
    end

    A_data --> LoadA
    QW --> LoadB
    QZ --> LoadB
    QS --> LoadB
    LoadB --> Dequant
    LoadA --> Dot
    Dequant --> Dot
    Dot --> C_data
```
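To make the qweight layout concrete, here is a minimal packing sketch in PyTorch. It is illustrative only: `pack_awq_qweight` and `AWQ_PACK_ORDER` are made-up names (real checkpoints are packed by AutoAWQ's own utilities), and the order [0, 2, 4, 6, 1, 3, 5, 7] used here is simply the inverse of the kernel's reverse order [0, 4, 1, 5, 2, 6, 3, 7], so the round-trip check at the end uses exactly the shifts the kernel applies.

```python
import torch

# Nibble n of each int32 holds logical column AWQ_PACK_ORDER[n] of its group of 8.
# This is the inverse permutation of the kernel's reverse order [0, 4, 1, 5, 2, 6, 3, 7].
AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]

def pack_awq_qweight(w_int4: torch.Tensor) -> torch.Tensor:
    """Illustrative packing of an int4 matrix [K, N] (values 0..15) into [K, N // 8] int32."""
    K, N = w_int4.shape
    assert N % 8 == 0
    w = w_int4.to(torch.int32).reshape(K, N // 8, 8)
    packed = torch.zeros(K, N // 8, dtype=torch.int32)
    for nibble, col in enumerate(AWQ_PACK_ORDER):
        packed |= w[:, :, col] << (4 * nibble)
    return packed

# Round trip: unpack with the kernel's shift pattern and recover the original matrix.
K, N = 128, 64
w_int4 = torch.randint(0, 16, (K, N))
qweight = pack_awq_qweight(w_int4)                      # [128, 8] int32
shifts = torch.tensor([0, 4, 1, 5, 2, 6, 3, 7], dtype=torch.int32) * 4
unpacked = (qweight.unsqueeze(-1) >> shifts) & 0xF      # [128, 8, 8]
assert (unpacked.reshape(K, N) == w_int4).all()
```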
Implementation code

The code below comes from the latest vLLM operator implementation; SGLang in fact adapts the same vLLM kernel.

The implementation:
```python
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/awq_triton.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import triton
import triton.language as tl

AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]


@triton.jit
def awq_gemm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    zeros_ptr,
    scales_ptr,
    M,
    N,
    K,
    group_size,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    pid_z = tl.program_id(1)

    # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
    # num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    accumulator_dtype = c_ptr.type.element_ty

    # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
    # accumulator = tl.arange(0, BLOCK_SIZE_N)
    # accumulator = tl.broadcast_to(accumulator[None, :],
    #                               (BLOCK_SIZE_M, BLOCK_SIZE_N))
    # accumulator = accumulator & 0x0
    # accumulator = accumulator.to(accumulator_dtype)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype)

    # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
    # that will map given indices to the correct order.
    reverse_awq_order_tensor = (
        (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None]
    ).reshape(8)

    # Create the necessary shifts to use to unpack.
    shifts = reverse_awq_order_tensor * 4
    shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8))
    shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N))

    # Offsets and masks.
    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    masks_am = offsets_am < M

    offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
    masks_bn = offsets_bn < N // 8

    offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
    masks_zn = offsets_zn < N // 8

    offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    masks_sn = offsets_sn < N

    offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offsets_a = K * offsets_am[:, None] + offsets_k[None, :]
    offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :]

    a_ptrs = a_ptr + offsets_a
    b_ptrs = b_ptr + offsets_b

    # NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv
    # block_offset = BLOCK_SIZE_K * SPLIT_K
    # for k in range(0, (K + block_offset - 1) // (block_offset)):
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
        masks_k = offsets_k < K
        masks_a = masks_am[:, None] & masks_k[None, :]
        a = tl.load(a_ptrs, mask=masks_a, other=0.0)

        masks_b = masks_k[:, None] & masks_bn[None, :]
        b = tl.load(b_ptrs, mask=masks_b, other=0.0)
        b = tl.interleave(b, b)
        b = tl.interleave(b, b)
        b = tl.interleave(b, b)

        # Dequantize b.
        offsets_szk = (
            BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K
        ) // group_size + tl.arange(0, 1)
        offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]
        masks_zk = offsets_szk < K // group_size
        masks_z = masks_zk[:, None] & masks_zn[None, :]
        zeros_ptrs = zeros_ptr + offsets_z
        zeros = tl.load(zeros_ptrs, mask=masks_z, other=0.0)
        zeros = tl.interleave(zeros, zeros)
        zeros = tl.interleave(zeros, zeros)
        zeros = tl.interleave(zeros, zeros)
        zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N))

        offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]
        masks_sk = offsets_szk < K // group_size
        masks_s = masks_sk[:, None] & masks_sn[None, :]
        scales_ptrs = scales_ptr + offsets_s
        scales = tl.load(scales_ptrs, mask=masks_s, other=0.0)
        scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))

        b = (b >> shifts) & 0xF
        zeros = (zeros >> shifts) & 0xF
        b = (b - zeros) * scales
        b = b.to(c_ptr.type.element_ty)

        # Accumulate results.
        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)

        offsets_k += BLOCK_SIZE_K * SPLIT_K
        a_ptrs += BLOCK_SIZE_K * SPLIT_K
        b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8)

    c = accumulator.to(c_ptr.type.element_ty)
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + pid_z * N * M + N * offs_cm[:, None] + offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


# input - [M, K]
# qweight - [K, N // 8]
# qzeros - [K // G, N // 8]
# scales - [K // G, N]
# split_k_iters - parallelism along K-dimension, int, power of 2.
def awq_gemm_triton(
    input: torch.Tensor,
    qweight: torch.Tensor,
    scales: torch.Tensor,
    qzeros: torch.Tensor,
    split_k_iters: int,
    block_size_m: int = 32,
    block_size_n: int = 32,
    block_size_k: int = 32,
) -> torch.Tensor:
    M, K = input.shape
    N = qweight.shape[1] * 8
    group_size = qweight.shape[0] // qzeros.shape[0]

    assert N > 0 and K > 0 and M > 0
    assert qweight.shape[0] == K and qweight.shape[1] == N // 8
    assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8
    assert scales.shape[0] == K // group_size and scales.shape[1] == N
    assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0
    assert split_k_iters <= 32
    assert group_size <= K
    assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        split_k_iters,
    )

    result = torch.zeros((split_k_iters, M, N), dtype=scales.dtype, device=input.device)

    # A = input, B = qweight, C = result
    # A = M x K, B = K x N, C = M x N
    awq_gemm_kernel[grid](
        input,
        qweight,
        result,
        qzeros,
        scales,
        M,
        N,
        K,
        group_size,
        BLOCK_SIZE_M=block_size_m,
        BLOCK_SIZE_N=block_size_n,
        BLOCK_SIZE_K=block_size_k,
        SPLIT_K=split_k_iters,
    )

    result = result.sum(0)
    return result
```

Code analysis:
```mermaid
graph TB
    subgraph Host_Side [1. Python/host side: dispatch and grid configuration]
        Init["Input A [M, K] (fp16)<br/>qweight [K, N/8] (int32)"]
        GridConfig["Compute grid size:<br/>(M/BLOCK_M * N/BLOCK_N, SPLIT_K)"]
        Workspace["Allocate temporary result [SPLIT_K, M, N] (fp16)"]
    end

    subgraph Triton_Kernel [2. Inside the Triton kernel: one program instance pid]
        direction TB
        subgraph Load_Phases [Load phase]
            LoadA["Load A tile [BLOCK_M, BLOCK_K]<br/>(fp16)"]
            LoadB["Load qweight tile [BLOCK_K, BLOCK_N/8]<br/>(int32)"]
        end
        subgraph Dequant_Logic [Core dequantization: register-level ops]
            direction LR
            Interleave["tl.interleave (to 8x width)<br/>replicate each int32"]
            Shift[">> shifts & 0xF<br/>(in [0,4,1,5,2,6,3,7] order)"]
            LoadSZ["Load scales [1, BLOCK_N] (fp16)<br/>Load qzeros [1, BLOCK_N/8] (int32)"]
            ApplySZ["W = (W_int4 - zero) * scale<br/>(cast to fp16)"]
        end
        subgraph Compute_Phase [Compute phase]
            Dot["tl.dot(A, W)<br/>executed on Tensor Cores"]
            Accumulate["Accumulator update (K loop)"]
        end
    end

    subgraph Reduction_Phase [3. Result merge: split-K reduction]
        Sum["result.sum(axis=0)<br/>add the partial sums from each K slice"]
        Final["Output C [M, N] (fp16)"]
    end

    Init --> GridConfig
    GridConfig --> Workspace
    Workspace --> Load_Phases
    LoadA --> Dot
    LoadB --> Dequant_Logic
    LoadSZ --> Dequant_Logic
    Dequant_Logic --> Dot
    Dot --> Accumulate
    Accumulate --> Sum
    Sum --> Final

    style Dequant_Logic fill:#f9f,stroke:#333,stroke-width:2px
    style Compute_Phase fill:#bbf,stroke:#333,stroke-width:2px
```
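Before walking through the per-iteration logic, here is a hedged host-side usage sketch. It assumes a CUDA GPU with Triton installed and feeds random placeholder tensors that merely have the right shapes and dtypes (qweight [K, N//8], qzeros [K//G, N//8], scales [K//G, N]); the values are meaningless, the point is only the calling convention and the split-K bookkeeping.

```python
import torch

# Hedged usage sketch (assumes a CUDA device and the awq_gemm_triton defined above).
M, K, N, group_size = 16, 4096, 11008, 128
device = "cuda"
int32_max = torch.iinfo(torch.int32).max

a = torch.randn(M, K, dtype=torch.float16, device=device)  # activations
qweight = torch.randint(0, int32_max, (K, N // 8), dtype=torch.int32, device=device)
qzeros = torch.randint(0, int32_max, (K // group_size, N // 8), dtype=torch.int32, device=device)
scales = torch.rand(K // group_size, N, dtype=torch.float16, device=device) * 0.01

# split_k_iters must be a power of two <= 32; each pid_z handles K / split_k_iters rows of K.
c = awq_gemm_triton(a, qweight, scales, qzeros, split_k_iters=8)
print(c.shape, c.dtype)  # torch.Size([16, 11008]) torch.float16
```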
Computation logic:

Inside the Triton kernel, each iteration of the K loop (one BLOCK_SIZE_K slice per SPLIT_K partition) does the following; a plain-PyTorch reference sketch follows the list:

- Load: read a [BLOCK_SIZE_K, BLOCK_SIZE_N / 8] tile from qweight.
- Unpack: (iweights >> shifts) & 0xF expands the tile to 8x its width, recovering the logical N dimension; the shifts follow the reverse AWQ order [0, 4, 1, 5, 2, 6, 3, 7].
- Group alignment: the current k index selects the matching row of scales and zeros. Because scales has shape [K/G, N] while k advances every iteration, a new scale row is fetched only after group_size rows of K have been consumed.
- Dequantize: W = (W_int4 - zero) * scale, then cast to the output dtype (fp16).
- Accumulate: a half-precision tl.dot with the activation tile A, added into the accumulator.
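To sanity-check these steps, the sketch below redoes the unpack / group-broadcast / dequantize / matmul in plain PyTorch. `awq_dequant_matmul_ref` is an invented name, not a vLLM API; it accumulates in float32, so compare it against the Triton output with a tolerance appropriate for fp16.

```python
import torch

# Invented reference (not part of vLLM): mirrors the kernel's per-K-tile math in one shot.
AWQ_REVERSE_ORDER = torch.tensor([0, 4, 1, 5, 2, 6, 3, 7], dtype=torch.int32)

def awq_dequant_matmul_ref(a, qweight, qzeros, scales, group_size):
    K, n_packed = qweight.shape
    N = n_packed * 8
    shifts = AWQ_REVERSE_ORDER.to(qweight.device) * 4

    # Unpack weights and zeros: nibble selection follows the same shift pattern as the kernel.
    w = ((qweight.unsqueeze(-1) >> shifts) & 0xF).reshape(K, N)
    z = ((qzeros.unsqueeze(-1) >> shifts) & 0xF).reshape(K // group_size, N)

    # Each scale/zero row covers group_size consecutive rows of K.
    z = z.repeat_interleave(group_size, dim=0)    # [K, N]
    s = scales.repeat_interleave(group_size, dim=0)  # [K, N]

    # Dequantize and multiply; fp32 accumulation keeps the reference numerically clean.
    w_fp = (w - z).to(torch.float32) * s.to(torch.float32)
    return (a.to(torch.float32) @ w_fp).to(a.dtype)  # [M, N]

# e.g. torch.testing.assert_close(awq_dequant_matmul_ref(a, qweight, qzeros, scales, group_size),
#                                 awq_gemm_triton(a, qweight, scales, qzeros, 8),
#                                 atol=1e-1, rtol=1e-2)
```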