Triton Learning Notes 2: Loop Optimization Tricks

Puzzle 8: Long Softmax

Puzzle 8 computes the softmax over a batch of logits. The problem statement is as follows:

Softmax of a batch of logits.

Uses one program block axis. Block size B0 represents the batch of x of length N0.

Block logit length T. Process it B1 < T elements at a time.

.. math::
    z_{i, j} = \text{softmax}(x_{i,1} \ldots x_{i, T}) \text{ for } i = 1\ldots N_0

Note softmax needs to be computed in numerically stable form as in Python. In addition, in Triton they recommend not using exp but instead using exp2. You need the identity

.. math::
    \exp(x) = 2^{\log_2(e) x}

Advanced: there is one way to do this with 3 loops. You can also do it with 2 loops if you are clever.

Hint: you will find this identity useful:

.. math::
    \exp(x_i - m) = \exp(x_i - m/2 - m/2) = \exp(x_i - m/2) / \exp(m/2)

"""

def softmax_spec(x: Float32[4, 200]) -> Float32[4, 200]:
    x_max = x.max(1, keepdim=True)[0]
    x = x - x_max
    x_exp = x.exp()
    return x_exp / x_exp.sum(1, keepdim=True)

This puzzle calls for two solutions: a brute-force one with 3 loops, and a clever one with only 2 loops. Let's start with the brute-force version.

Brute-force approach

  1. One loop computes the maximum of each row.

  2. Each element of a row has its row maximum subtracted and is then exponentiated,

  3. with a second loop accumulating the sum of these exponentiated values.

  4. A third loop computes the final softmax (see the plain-PyTorch sketch after this list).
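To make the three passes concrete, here is a minimal plain-PyTorch sketch of the same blocked three-pass softmax (the function name chunked_softmax is illustrative, not from the puzzle):

import torch

def chunked_softmax(x: torch.Tensor, B1: int) -> torch.Tensor:
    """Three-pass row softmax, reading x in chunks of B1 columns."""
    N0, T = x.shape
    # Pass 1: row maxima.
    row_max = torch.full((N0, 1), float('-inf'))
    for j in range(0, T, B1):
        row_max = torch.maximum(row_max, x[:, j:j + B1].max(1, keepdim=True)[0])
    # Pass 2: sum of exp(x - row_max) per row.
    row_sum = torch.zeros((N0, 1))
    for j in range(0, T, B1):
        row_sum += (x[:, j:j + B1] - row_max).exp().sum(1, keepdim=True)
    # Pass 3: normalized output.
    z = torch.empty_like(x)
    for j in range(0, T, B1):
        z[:, j:j + B1] = (x[:, j:j + B1] - row_max).exp() / row_sum
    return z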

Relevant Triton APIs

  1. tl.full(shape, value, dtype) directly initializes a vector of size shape, filled with value, of type dtype. It can be used to initialize the running maximum to a very small value; later I found that tl.zeros also works (see the sketch below).
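For example, both initializations below are valid for the running row maximum (a minimal sketch, inside a @triton.jit kernel body with B0 a tl.constexpr):

row_max = tl.full((B0, 1), float('-inf'), dtype=tl.float32)  # tight initial value
# ... or ...
row_max = tl.zeros((B0, 1), dtype=tl.float32)
# tl.zeros also works: combining with tl.maximum keeps row_max = max(0, true max),
# which still upper-bounds every x, so exp(x - row_max) <= 1 and cannot overflow;
# softmax is invariant to the shift value, so the result is unchanged.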

Solution

@triton.jit
def softmax_kernel_brute_force(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    """3 loops ver."""
    block_id_i = tl.program_id(0)
    offset_x = block_id_i * B0 + tl.arange(0, B0)
    mask_x = offset_x < N0

    row_max = tl.zeros([B0, 1], dtype=tl.float32)
    row_sum_exp = tl.zeros([B0, 1], dtype=tl.float32)

    # Loop 1: running maximum of each row.
    for idj in tl.range(0, T, B1):
        offset_y = idj + tl.arange(0, B1)
        mask_y = offset_y < T
        offset_xy = offset_x[:, None] * T + offset_y[None, :]
        mask_xy = mask_x[:, None] & mask_y[None, :]
        block_value = tl.load(x_ptr + offset_xy, mask_xy, other=float('-inf'))
        row_max = tl.maximum(row_max, tl.max(block_value, axis=1, keep_dims=True))

    # Loop 2: accumulate sum of exp(x - row_max) per row.
    # exp_approx (defined with the two-loop solution below) computes exp via exp2.
    for idj in tl.range(0, T, B1):
        offset_y = idj + tl.arange(0, B1)
        mask_y = offset_y < T
        offset_xy = offset_x[:, None] * T + offset_y[None, :]
        mask_xy = mask_x[:, None] & mask_y[None, :]
        block_value = tl.load(x_ptr + offset_xy, mask_xy, other=float('-inf'))
        block_value -= row_max
        row_sum_exp += tl.sum(exp_approx(block_value), axis=1, keep_dims=True)

    # Loop 3: normalize and store the softmax values.
    for idj in tl.range(0, T, B1):
        offset_y = idj + tl.arange(0, B1)
        mask_y = offset_y < T
        offset_xy = offset_x[:, None] * T + offset_y[None, :]
        mask_xy = mask_x[:, None] & mask_y[None, :]
        block_value = tl.load(x_ptr + offset_xy, mask_xy, other=float('-inf'))
        softmax_value = exp_approx(block_value - row_max) / row_sum_exp
        tl.store(z_ptr + offset_xy, softmax_value, mask_xy)
    return
  1. The code is rather verbose, but the core idea is exactly the three loops described above. Note that every pass reloads the same blocks of x from global memory, which is what the two-loop version below reduces.

两个循环思路

  1. The idea is essentially online softmax: the max pass and the exp-sum pass are fused into a single loop.

$$
\begin{aligned}
& m_0 \leftarrow -\infty \\
& d_0 \leftarrow 0 \\
& \text{for } j \leftarrow 1, V \text{ do} \\
& \quad m_j \leftarrow \max(m_{j-1}, x_j) \\
& \quad d_j \leftarrow d_{j-1} \times e^{m_{j-1}-m_j} + e^{x_j-m_j} \quad \text{(update row\_sum\_exp within the loop)} \\
& \text{end for} \\
& \text{for } i \leftarrow 1, V \text{ do} \\
& \quad y_i \leftarrow \frac{e^{x_i-m_V}}{d_V} \\
& \text{end for}
\end{aligned}
$$
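To sanity-check the recurrence, here is a minimal plain-Python sketch of online softmax on a single row (the function name online_softmax is illustrative):

import math

def online_softmax(x):
    """Two-pass softmax on one row, following the recurrence above."""
    m, d = float("-inf"), 0.0
    # Pass 1: jointly track the running max m and the rescaled exp-sum d.
    for xj in x:
        m_new = max(m, xj)
        d = d * math.exp(m - m_new) + math.exp(xj - m_new)
        m = m_new
    # Pass 2: normalize with the final max and exp-sum.
    return [math.exp(xi - m) / d for xi in x]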

Solution

@triton.jit
def exp_approx(x):
    # exp(x) = 2^(log2(e) * x), with 1.44269504 ≈ log2(e)
    return tl.exp2(1.44269504 * x)



@triton.jit
def softmax_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    """2 loops ver."""
    block_id_i = tl.program_id(0)
    offset_x = block_id_i * B0 + tl.arange(0, B0)
    mask_x = offset_x < N0

    row_max = tl.zeros([B0, 1], dtype=tl.float32)
    row_sum_exp = tl.zeros([B0, 1], dtype=tl.float32)

    # Loop 1: online softmax. Update the running max, and rescale the running
    # exp-sum by exp(old_max - new_max) before adding the new block's terms.
    for idj in tl.range(0, T, B1):
        offset_y = idj + tl.arange(0, B1)
        mask_y = offset_y < T
        offset_xy = offset_x[:, None] * T + offset_y[None, :]
        mask_xy = mask_x[:, None] & mask_y[None, :]
        block_value = tl.load(x_ptr + offset_xy, mask_xy, other=float('-inf'))

        new_row_max = tl.maximum(tl.max(block_value, axis=1, keep_dims=True), row_max)
        row_sum_exp = (row_sum_exp * exp_approx(row_max - new_row_max)
                       + tl.sum(exp_approx(block_value - new_row_max), axis=1, keep_dims=True))
        row_max = new_row_max

    # Loop 2: normalize and store.
    for idj in tl.range(0, T, B1):
        offset_y = idj + tl.arange(0, B1)
        mask_y = offset_y < T
        offset_xy = offset_x[:, None] * T + offset_y[None, :]
        mask_xy = mask_x[:, None] & mask_y[None, :]
        block_value = tl.load(x_ptr + offset_xy, mask_xy, other=float('-inf'))
        z_value = exp_approx(block_value - row_max) / row_sum_exp
        tl.store(z_ptr + offset_xy, z_value, mask_xy)

    return
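For completeness, here is a hypothetical host-side launch (tensor names and block sizes are illustrative; the Triton-Puzzles test harness normally does this for you). N1 is not used in the kernel body, so T is passed as a placeholder:

import torch
import triton

N0, T = 4, 200                 # shapes matching softmax_spec
B0, B1 = 1, 32                 # illustrative block sizes
x = torch.randn(N0, T, device="cuda")
z = torch.empty_like(x)
grid = (triton.cdiv(N0, B0),)  # one program per B0 rows
softmax_kernel[grid](x, z, N0, T, T, B0=B0, B1=B1)
# softmax_kernel_brute_force has the same signature and launches identically.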

Puzzle 9: Simple FlashAttention

A scalar version of FlashAttention.

Uses zero programs. Block size B0 represents the batches of q to process out of N0. Sequence length is T. Process it B1 < T elements (k, v) at a time for some B1.

.. math::
    z_{i} = \sum_{j=1}^{T} \text{softmax}(q_i k_1, \ldots, q_i k_T)_j v_{j} \text{ for } i = 1\ldots N_0

This can be done in 1 loop using a similar trick from the last puzzle.

Hint: Use tl.where to mask q dot k to -inf to avoid overflow (NaN).

This one is close to FlashAttention v1: a single pass over the sequence.

The complete recurrence of FlashAttention v1:

$$
\begin{aligned}
x_i &\leftarrow Q[k,:] \cdot K^T[:,i] \\
m_i &\leftarrow \max(m_{i-1}, x_i) \\
d_i' &\leftarrow d_{i-1}' \cdot e^{m_{i-1} - m_i} + e^{x_i - m_i} \\
O_i' &\leftarrow O_{i-1}' \cdot \frac{d_{i-1}'}{d_i'} \cdot e^{m_{i-1} - m_i} + \frac{e^{x_i - m_i}}{d_i'} \cdot V[i,:]
\end{aligned}
$$

The final output:

$$
O[k,:] \leftarrow O_N'
$$

Where:

  • $Q[k,:]$ is the $k$-th row vector of the matrix $Q$.

  • $K^T[:,i]$ is the $i$-th column vector of the matrix $K^T$.

  • $x_i$ is the pre-softmax logit.

  • $m_i$ is the running maximum.

  • $d_i'$ is the running sum of exponentials.

  • $O_i'$ is the partial output accumulator.

  • $V[i,:]$ is the $i$-th row vector of the matrix $V$.

  • $O[k,:]$ is the $k$-th row vector of the output matrix.
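To see the recurrence in action, here is a minimal scalar Python sketch of the one-pass update (function and variable names are illustrative, not from the puzzle):

import math

def flash_attention_scalar(q, k, v):
    """One scalar query q against key/value lists k, v, in a single pass."""
    m, d, o = float("-inf"), 0.0, 0.0
    for ki, vi in zip(k, v):
        x = q * ki                    # pre-softmax logit x_i
        m_new = max(m, x)             # running max m_i
        alpha = math.exp(m - m_new)   # rescale factor e^{m_{i-1} - m_i}
        d_new = d * alpha + math.exp(x - m_new)   # exp-sum d_i'
        # partial output O_i'
        o = o * (d * alpha / d_new) + math.exp(x - m_new) / d_new * vi
        m, d = m_new, d_new
    return o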

Solution

@triton.jit
def myexp(x):
    # exp(x) = 2^(log2(e) * x), with 1.44269504 ≈ log2(e)
    return tl.exp2(1.44269504 * x)

@triton.jit
def flashatt_kernel(
    q_ptr, k_ptr, v_ptr, z_ptr, N0, T, B0: tl.constexpr, B1: tl.constexpr
):
    block_id_i = tl.program_id(0)
    
    off_i = block_id_i * B0 + tl.arange(0, B0)
    mask_i = off_i < N0
    inf = 1.0e6  # large finite stand-in for infinity, avoids NaN from inf arithmetic

    # Need `other`!!!
    q = tl.load(q_ptr + off_i, mask=mask_i)

    # The variable names from Triton's official FlashAttention tutorial
    # are noted in the comments below for reference.
    # Our variable names are consistent with Puzzle 8.

    # l_i
    exp_sum = tl.zeros((B0,), dtype=tl.float32)
    # m_i
    qk_max = tl.full((B0,), -inf, dtype=tl.float32)
    z = tl.zeros((B0,), dtype=tl.float32)

    for id_j in tl.range(0, T, B1):
        off_j = id_j + tl.arange(0, B1)
        mask_j = off_j < T
        mask_ij = mask_i[:, None] & mask_j[None, :]

        k = tl.load(k_ptr + off_j, mask=mask_j)
        qk = q[:, None] * k[None, :] + tl.where(mask_ij, 0, -inf)

        # m_ij
        new_max = tl.maximum(tl.max(qk, axis=1), qk_max)
        qk_exp = myexp(qk - new_max[:, None])
        # alpha
        factor = myexp(qk_max - new_max)
        # l_ij
        new_exp_sum = exp_sum * factor + tl.sum(qk_exp, axis=1)
        v = tl.load(v_ptr + off_j, mask=mask_j, other=0.0)
        z = z * factor + tl.sum(qk_exp * v[None, :], axis=1)

        qk_max = new_max
        exp_sum = new_exp_sum

    z = z / exp_sum
    tl.store(z_ptr + off_i, z, mask=mask_i)
    return
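A hypothetical host-side launch and check against the PyTorch definition (tensor names, sizes, and block sizes are illustrative):

import torch
import triton

N0, T = 200, 200
B0, B1 = 32, 32
q = torch.randn(N0, device="cuda")
k = torch.randn(T, device="cuda")
v = torch.randn(T, device="cuda")
z = torch.empty(N0, device="cuda")
flashatt_kernel[(triton.cdiv(N0, B0),)](q, k, v, z, N0, T, B0=B0, B1=B1)

# Reference: z_i = sum_j softmax(q_i * k)_j * v_j
expected = (torch.softmax(q[:, None] * k[None, :], dim=1) * v[None, :]).sum(1)
print(torch.allclose(z, expected, atol=1e-4))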
