Background

In the previous article, C#/.NET做张量计算的一些痛点,思考和解决方案, I discussed some design issues of doing tensor computation in pure C#; this article is my own hands-on follow-up.

Matrix-multiplication optimization has been studied exhaustively on the C/C++ side, but hardly anyone in the C#/.NET world does this kind of optimization in pure C#, even though C# does expose AVX and other SIMD instructions. This article records my own experiments, focusing on the well-established optimization techniques.

Note: the optimization techniques in this article mainly follow the how-to-optimize-gemm tutorial.

Conventional optimizations

First, here is a naive implementation of the operator as a baseline. For convenience, it targets .NET 7 and uses generic math:

public unsafe void Exec(ReadOnlySpan<T> a, ReadOnlySpan<T> b, Span<T> c, in NativeLayout layoutA, in NativeLayout layoutB, in NativeLayout layoutC)
{
    int aRows = layoutA._shape[0];
    int aCols = layoutA._shape[1];
    int bCols = layoutB._shape[1];
    for (int i = 0; i < aRows; i++)
    {
        for (int j = 0; j < bCols; j++)
        {
            var res = T.Zero;
            for (int k = 0; k < aCols; k++)
            {
                res += a[i * aCols + k] * b[k * bCols + j];
            }
            c[i * bCols + j] = res;
        }
    }
}

The benchmarks focus on float32 and use the following six shape pairs: small, medium, large, and very large matrices, plus two extreme cases. For brevity, the naive implementation serves as the baseline, and the effect of each variant is reported as "relative time": the measured time of an implementation divided by the naive implementation's time at the same size.

{ 6, 8 }, { 8, 11 }
{ 32, 64 }, { 64, 96 }
{ 256, 512 }, { 512, 768 }
{ 1000, 1000 }, { 1000, 1000 }
{ 2, 1024 }, { 1024, 1 }
{ 1024, 1 }, { 2, 1024 }

Throughout this article, the computation is C = A * B, where A has m rows and n columns and B has n rows and p columns. Matrices are stored in row-major order.

To minimize external interference, all code runs on a headless server with the following configuration:

BenchmarkDotNet=v0.13.2, OS=ubuntu 22.04
Intel Xeon Gold 6148 CPU 2.40GHz, 1 CPU, 2 logical and 2 physical cores
.NET SDK=7.0.100-rc.2.22477.23
  [Host]     : .NET 7.0.0 (7.0.22.47203), X64 RyuJIT AVX2
  Job-GZJTJY : .NET 7.0.0 (7.0.22.47203), X64 RyuJIT AVX2

Swapping the outer loop order

First, let's try swapping the loop order from m-p-n to m-n-p. The rationale is better locality. Think of both matrices' data being pulled into the cache: since our matrices are row-major, computing the element at (i, j) of the result uses row i of A and column j of B, and we can assume that data is now resident in cache.

With m-p-n order, the next element computed is C[i, j + 1], i.e. row i of A against column j + 1 of B: locality in A is preserved, but locality in B is destroyed. With m-n-p order, locality is preserved on both sides.

In practice this plain loop swap is rarely used as an optimization on its own, and the optimizations C# performs in release mode already account for it, so the measured difference is small (in my tests, about a 17% improvement for two 1000x1000 matrices). I won't include code or benchmark data here; see this article if you are interested.
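For completeness, the swap can be sketched as follows. This is a minimal illustration under the article's row-major layout, not the article's own implementation; `MatMulIkj` and its parameter names are mine:

```csharp
static class LoopOrderDemo
{
    // C = A * B with A (m x n), B (n x p), all row-major; c must be zero-initialized.
    // Loop order is m-n-p (i-k-j): both B and C are scanned row by row,
    // so every access in the inner loop is sequential.
    public static void MatMulIkj(float[] a, float[] b, float[] c, int m, int n, int p)
    {
        for (int i = 0; i < m; i++)
            for (int k = 0; k < n; k++)
            {
                float aik = a[i * n + k];               // loaded once, reused across the row of B
                for (int j = 0; j < p; j++)
                    c[i * p + j] += aik * b[k * p + j]; // contiguous reads and writes
            }
    }
}
```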

Simple grouping: four operations at a time

This is a form of loop unrolling: the inner loop processes 4 elements at once. Locality barely changes, but the number of loop-bound checks drops to a quarter of the original. The implementation:

public unsafe void Exec(ReadOnlySpan<T> a, ReadOnlySpan<T> b, Span<T> c, in NativeLayout layoutA, in NativeLayout layoutB, in NativeLayout layoutC)
{
    int aRows = layoutA._shape[0];
    int aCols = layoutA._shape[1];
    int bCols = layoutB._shape[1];

    for (int i = 0; i <= aRows - 4; i += 4)
    {
        for (int j = 0; j < bCols; j++)
        {
            T c0 = T.Zero, c1 = T.Zero, c2 = T.Zero, c3 = T.Zero;
            int idx0 = i * aCols, idx1 = (i + 1) * aCols, idx2 = (i + 2) * aCols, idx3 = (i + 3) * aCols;
            for (int k = 0; k < aCols; k++)
            {
                var value = b[k * bCols + j];
                c0 += value * a[idx0++];
                c1 += value * a[idx1++];
                c2 += value * a[idx2++];
                c3 += value * a[idx3++];
            }
            c[i * bCols + j] = c0;
            c[(i + 1) * bCols + j] = c1;
            c[(i + 2) * bCols + j] = c2;
            c[(i + 3) * bCols + j] = c3;
        }
    }
    for (int i = aRows / 4 * 4; i < aRows; i++)
    {
        for (int j = 0; j < bCols; j++)
        {
            for (int k = 0; k < aCols; k++)
            {
                c[i * bCols + j] += a[i * aCols + k] * b[k * bCols + j];
            }
        }
    }
}

The benchmark results are as follows:

Shape         | 6, 8, 11 | 32, 64, 96 | 256, 512, 768 | 1000, 1000, 1000 | 2, 1024, 1 | 1024, 1, 1024
Relative time | 93%      | 58%        | 60%           | 78%              | 215%       | 72%

For larger matrices the time already drops substantially, but this is still far from enough. Notably, the row-vector times column-vector case actually got slower: the method computes several rows at once, and since A has only two rows the unrolled path never runs, so we merely added overhead.

Computing 4x4 groups at a time

Given the 4x1 approach, it is natural to ask whether a 4x4 grouping can trade still more space for time. The code, following the same idea:

public unsafe void Exec(ReadOnlySpan<T> a, ReadOnlySpan<T> b, Span<T> c, in NativeLayout layoutA, in NativeLayout layoutB, in NativeLayout layoutC)
{
    int aRows = layoutA._shape[0];
    int aCols = layoutA._shape[1];
    int bCols = layoutB._shape[1];

    T c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31, c32, c33;
    T b00, b01, b02, b03;

    for (int i = 0; i <= aRows - 4; i += 4)
    {
        for (int j = 0; j <= bCols - 4; j += 4)
        {
            c00 = c01 = c02 = c03 = c10 = c11 = c12 = c13 = c20 = c21 = c22 = c23 = c30 = c31 = c32 = c33 = T.Zero;
            b00 = b01 = b02 = b03 = T.Zero;
            int idx0 = i * aCols, idx1 = (i + 1) * aCols, idx2 = (i + 2) * aCols, idx3 = (i + 3) * aCols;

            for (int k = 0; k < aCols; k++)
            {
                b00 = b[k * bCols + j];
                b01 = b[k * bCols + j + 1];
                b02 = b[k * bCols + j + 2];
                b03 = b[k * bCols + j + 3];
                
                c00 += b00 * a[idx0];
                c10 += b00 * a[idx1];
                c20 += b00 * a[idx2];
                c30 += b00 * a[idx3];

                c01 += b01 * a[idx0];
                c11 += b01 * a[idx1];
                c21 += b01 * a[idx2];
                c31 += b01 * a[idx3];

                c02 += b02 * a[idx0];
                c12 += b02 * a[idx1];
                c22 += b02 * a[idx2];
                c32 += b02 * a[idx3];

                c03 += b03 * a[idx0++];
                c13 += b03 * a[idx1++];
                c23 += b03 * a[idx2++];
                c33 += b03 * a[idx3++];
            }
            c[i * bCols + j] = c00;
            c[i * bCols + j + 1] = c01;
            c[i * bCols + j + 2] = c02;
            c[i * bCols + j + 3] = c03;
            c[(i + 1) * bCols + j] = c10;
            c[(i + 1) * bCols + j + 1] = c11;
            c[(i + 1) * bCols + j + 2] = c12;
            c[(i + 1) * bCols + j + 3] = c13;
            c[(i + 2) * bCols + j] = c20;
            c[(i + 2) * bCols + j + 1] = c21;
            c[(i + 2) * bCols + j + 2] = c22;
            c[(i + 2) * bCols + j + 3] = c23;
            c[(i + 3) * bCols + j] = c30;
            c[(i + 3) * bCols + j + 1] = c31;
            c[(i + 3) * bCols + j + 2] = c32;
            c[(i + 3) * bCols + j + 3] = c33;
        }
    }
    for (int i = aRows / 4 * 4; i < aRows; i++)
    {
        for (int j = 0; j < bCols; j++)
        {
            for (int k = 0; k < aCols; k++)
            {
                c[i * bCols + j] += a[i * aCols + k] * b[k * bCols + j];
            }
        }
    }
    for (int i = 0; i < aRows / 4 * 4; i++)
    {
        for (int j = bCols / 4 * 4; j < bCols; j++)
        {
            for (int k = 0; k < aCols; k++)
            {
                c[i * bCols + j] += a[i * aCols + k] * b[k * bCols + j];
            }
        }
    }
}

The benchmark results are below; the 4x4 version improves on 4x1 across the board, but the gain is still not enough.

Shape         | 6, 8, 11 | 32, 64, 96 | 256, 512, 768 | 1000, 1000, 1000 | 2, 1024, 1 | 1024, 1, 1024
Relative time | 91%      | 40%        | 48%           | 52%              | 214%       | 57%

Introducing SIMD

SIMD ("single instruction, multiple data") is something of a secret weapon on x86. C# wraps SIMD in two forms: Vector&lt;T&gt; under the System.Numerics namespace, and the fixed-width Vector256&lt;T&gt;/Vector128&lt;T&gt; under System.Runtime.Intrinsics. On the whole I recommend the former, because with the latter, supporting multiple element types gets messy at the implementation level: you end up writing a separate kernel per type.

As an aside, Vector&lt;T&gt; generally performs only slightly worse than Vector256&lt;T&gt;/Vector128&lt;T&gt;, though there are a few pitfalls, such as this issue.
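A minimal probe of Vector&lt;T&gt; (illustrative only; `VectorProbe` is my own name, not from the article). It demonstrates the per-lane multiply-accumulate pattern the kernels below are built on; the lane count comes back as the result length, and on an AVX2 machine Vector&lt;float&gt; has 8 lanes, which is why an 8-wide kernel maps onto it naturally:

```csharp
using System.Numerics;

static class VectorProbe
{
    // One multiply-accumulate step over a full hardware vector.
    // Assumes x and y each hold at least Vector<float>.Count elements.
    public static float[] MulAccumulate(float[] x, float[] y)
    {
        var acc = Vector<float>.Zero;
        acc += new Vector<float>(x) * new Vector<float>(y); // per-lane mul, then add
        var result = new float[Vector<float>.Count];        // lane count is hardware-dependent
        acc.CopyTo(result);
        return result;
    }
}
```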

Following MMult_4x4_10.c, we write the code below. The original uses a 4x4 kernel with 128-bit SSE instructions on doubles; since we are computing floats and 256-bit AVX instructions are commonplace now, we rewrite it as an 8x8 kernel on top of 256-bit vectors.

public unsafe void Exec(ReadOnlySpan<T> a, ReadOnlySpan<T> b, Span<T> c, in NativeLayout layoutA, in NativeLayout layoutB, in NativeLayout layoutC)
{
    int aRows = layoutA._shape[0];
    int aCols = layoutA._shape[1];
    int bCols = layoutB._shape[1];

    int aIdx = 0, bIdx = 0;

    for (int i = 0; i <= aRows - 8; i += 8)
    {
        bIdx = 0;
        for (int j = 0; j <= bCols - 8; j += 8)
        {
            Kernel32b8x8(a.Slice(aIdx), b.Slice(bIdx), c.Slice(i * bCols + j), aRows, aCols, bCols);
            bIdx += 8;
        }
        aIdx += 8 * aCols;
    }
    for (int i = aRows / 8 * 8; i < aRows; i++)
    {
        for (int j = 0; j < bCols; j++)
        {
            for (int k = 0; k < aCols; k++)
            {
                c[i * bCols + j] += a[i * aCols + k] * b[k * bCols + j];
            }
        }
    }
    for (int i = 0; i < aRows / 8 * 8; i++)
    {
        for (int j = bCols / 8 * 8; j < bCols; j++)
        {
            for (int k = 0; k < aCols; k++)
            {
                c[i * bCols + j] += a[i * aCols + k] * b[k * bCols + j];
            }
        }
    }
}

public static unsafe void Kernel32b8x8(ReadOnlySpan<T> a, ReadOnlySpan<T> b, Span<T> c, int aRows, int aCols, int bCols)
{
    Vector<T> c_0_v = new Vector<T>();
    Vector<T> c_1_v = new Vector<T>();
    Vector<T> c_2_v = new Vector<T>();
    Vector<T> c_3_v = new Vector<T>();
    Vector<T> c_4_v = new Vector<T>();
    Vector<T> c_5_v = new Vector<T>();
    Vector<T> c_6_v = new Vector<T>();
    Vector<T> c_7_v = new Vector<T>();

    Vector<T> a_0_v, a_1_v, a_2_v, a_3_v, a_4_v, a_5_v, a_6_v, a_7_v;
    Vector<T> b_v;
    int offset = 0;

    for (int k = 0; k < aCols; k++)
    {
        a_0_v = new Vector<T>(a[k]);
        a_1_v = new Vector<T>(a[aCols + k]);
        a_2_v = new Vector<T>(a[2 * aCols + k]);
        a_3_v = new Vector<T>(a[3 * aCols + k]);
        a_4_v = new Vector<T>(a[4 * aCols + k]);
        a_5_v = new Vector<T>(a[5 * aCols + k]);
        a_6_v = new Vector<T>(a[6 * aCols + k]);
        a_7_v = new Vector<T>(a[7 * aCols + k]);

        b_v = new Vector<T>(b.Slice(offset));
        offset += bCols;

        c_0_v += a_0_v * b_v;
        c_1_v += a_1_v * b_v;
        c_2_v += a_2_v * b_v;
        c_3_v += a_3_v * b_v;
        c_4_v += a_4_v * b_v;
        c_5_v += a_5_v * b_v;
        c_6_v += a_6_v * b_v;
        c_7_v += a_7_v * b_v;
    }
    c_0_v.CopyTo(c);
    c_1_v.CopyTo(c.Slice(bCols));
    c_2_v.CopyTo(c.Slice(2 * bCols));
    c_3_v.CopyTo(c.Slice(3 * bCols));
    c_4_v.CopyTo(c.Slice(4 * bCols));
    c_5_v.CopyTo(c.Slice(5 * bCols));
    c_6_v.CopyTo(c.Slice(6 * bCols));
    c_7_v.CopyTo(c.Slice(7 * bCols));
}

The benchmark results are as follows:

Shape         | 6, 8, 11 | 32, 64, 96 | 256, 512, 768 | 1000, 1000, 1000 | 2, 1024, 1 | 1024, 1, 1024
Relative time | 152%     | 8%         | 11%           | 12%              | 216%       | 69%

The effect is dramatic: for medium-sized matrices and up we get roughly a 10x speedup. The small-matrix and two extreme cases got slower, but that is a minor blemish, since those special cases can be handled separately later.

There is a curious point here. In this optimization comparison, adding SIMD only roughly doubled performance, whereas here it gave us about an extra 5x on top of the previous version. Rewriting the 4x4 kernel as 8x8 is not the only reason (you can verify this yourself; I won't elaborate). The real reason is that with -O3, a C++ compiler already auto-vectorizes implicitly, so the explicit SIMD gain there is modest (there is still some gain, since the compiler sometimes does worse than hand-written SIMD), while C# does not apply this optimization by default, hence the difference.

If both the naive and the SIMD implementations are marked with the AggressiveOptimization attribute, the gap narrows somewhat, but the SIMD version is still more than 8x faster.
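Attaching the attribute looks like this. A hedged sketch: `Dot` is a stand-in method of my own, not the article's kernel; only the attribute placement is the point.

```csharp
using System.Runtime.CompilerServices;

static class OptDemo
{
    // AggressiveOptimization opts the method out of tiered compilation:
    // it is fully optimized on first JIT, which keeps benchmark
    // comparisons between implementations fair.
    [MethodImpl(MethodImplOptions.AggressiveOptimization)]
    public static float Dot(float[] a, float[] b)
    {
        float s = 0f;
        for (int i = 0; i < a.Length; i++)
            s += a[i] * b[i];
        return s;
    }
}
```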

Blocking

Cache misses are a major performance killer; making cache hits more reliable can improve performance considerably.

The idea of blocking is to confine each round of computation to one region, finishing all the rows and columns that region contributes to before moving to the next, which preserves locality. We can block either A or B; overall, blocking B works a bit better (because we are row-major). In the code below, _bh and _bw are the block height and width, which you must choose yourself; the tests here use 128x256.

public unsafe void Exec(ReadOnlySpan<T> a, ReadOnlySpan<T> b, Span<T> c, in NativeLayout layoutA, in NativeLayout layoutB, in NativeLayout layoutC)
{
    int aRows = layoutA._shape[0];
    int aCols = layoutA._shape[1];
    int bCols = layoutB._shape[1];
    int aIdx, bIdx = 0, cIdx = 0, bOffset;

    for (int k = 0; k < aCols; k += _bh)
    {
        int height = (aCols - k) > _bh ? _bh : (aCols - k);
        cIdx = bOffset = 0;
        bIdx = k * bCols;
        aIdx = k;
        for (int j = 0; j < bCols; j += _bw, bOffset += _bw, cIdx += _bw)
        {
            int width = (bCols - j) > _bw ? _bw : (bCols - j);
            ExecBlock(a.Slice(aIdx), b.Slice(bIdx + bOffset), c.Slice(cIdx), aRows, height, width, aRows, aCols, bCols);
        }
    }
}

// (m, p) * (p, n)
public static unsafe void ExecBlock(ReadOnlySpan<T> a, ReadOnlySpan<T> b, Span<T> c, int m, int p, int n, int aRows, int aCols, int bCols)
{
    int aIdx = 0;
    int bIdx, cIdx;

    for (int i = 0; i <= m - 8; i += 8)
    {
        bIdx = 0;
        cIdx = i * bCols;
        aIdx = i * aCols;
        for (int j = 0; j <= n - 8; j += 8, cIdx += 8)
        {
            Kernel32b8x8(a.Slice(aIdx), b.Slice(bIdx), c.Slice(cIdx), p, aRows, aCols, bCols);
            bIdx += 8;
        }
    }
    for (int i = m / 8 * 8; i < m; i++)
    {
        for (int j = 0; j < n; j++)
        {
            for (int k = 0; k < p; k++)
            {
                c[i * bCols + j] += a[i * aCols + k] * b[k * bCols + j];
            }
        }
    }
    for (int i = 0; i < m / 8 * 8; i++)
    {
        for (int j = n / 8 * 8; j < n; j++)
        {
            for (int k = 0; k < p; k++)
            {
                c[i * bCols + j] += a[i * aCols + k] * b[k * bCols + j];
            }
        }
    }
}

public static unsafe void Kernel32b8x8(ReadOnlySpan<T> a, ReadOnlySpan<T> b, Span<T> c, int p, int aRows, int aCols, int bCols)
{
    Vector<T> c_0_v = new Vector<T>(c);
    Vector<T> c_1_v = new Vector<T>(c.Slice(bCols));
    Vector<T> c_2_v = new Vector<T>(c.Slice(2 * bCols));
    Vector<T> c_3_v = new Vector<T>(c.Slice(3 * bCols));
    Vector<T> c_4_v = new Vector<T>(c.Slice(4 * bCols));
    Vector<T> c_5_v = new Vector<T>(c.Slice(5 * bCols));
    Vector<T> c_6_v = new Vector<T>(c.Slice(6 * bCols));
    Vector<T> c_7_v = new Vector<T>(c.Slice(7 * bCols));

    Vector<T> a_0_v, a_1_v, a_2_v, a_3_v, a_4_v, a_5_v, a_6_v, a_7_v;
    Vector<T> b_v;
    int offset = 0;

    for (int k = 0; k < p; k++)
    {
        a_0_v = new Vector<T>(a[k]);
        a_1_v = new Vector<T>(a[aCols + k]);
        a_2_v = new Vector<T>(a[2 * aCols + k]);
        a_3_v = new Vector<T>(a[3 * aCols + k]);
        a_4_v = new Vector<T>(a[4 * aCols + k]);
        a_5_v = new Vector<T>(a[5 * aCols + k]);
        a_6_v = new Vector<T>(a[6 * aCols + k]);
        a_7_v = new Vector<T>(a[7 * aCols + k]);

        b_v = new Vector<T>(b.Slice(offset));
        offset += bCols;

        c_0_v += a_0_v * b_v;
        c_1_v += a_1_v * b_v;
        c_2_v += a_2_v * b_v;
        c_3_v += a_3_v * b_v;
        c_4_v += a_4_v * b_v;
        c_5_v += a_5_v * b_v;
        c_6_v += a_6_v * b_v;
        c_7_v += a_7_v * b_v;
    }
    c_0_v.CopyTo(c);
    c_1_v.CopyTo(c.Slice(bCols));
    c_2_v.CopyTo(c.Slice(2 * bCols));
    c_3_v.CopyTo(c.Slice(3 * bCols));
    c_4_v.CopyTo(c.Slice(4 * bCols));
    c_5_v.CopyTo(c.Slice(5 * bCols));
    c_6_v.CopyTo(c.Slice(6 * bCols));
    c_7_v.CopyTo(c.Slice(7 * bCols));
}

The benchmark results are below, with a block size of 256x128. The improvement is slight; it grows for larger matrices, but is still unsatisfying.

Shape         | 6, 8, 11 | 32, 64, 96 | 256, 512, 768 | 1000, 1000, 1000 | 2, 1024, 1 | 1024, 1, 1024
Relative time | 167%     | 9%         | 10%           | 9%               | 187%       | 77%

For reference, here is the version that blocks A instead. In my measurements it is worse than blocking B: it only helps with well-chosen parameters, and with poor ones it can even hurt.

public unsafe void Exec(ReadOnlySpan<T> a, ReadOnlySpan<T> b, Span<T> c, in NativeLayout layoutA, in NativeLayout layoutB, in NativeLayout layoutC)
{
    int aRows = layoutA._shape[0];
    int aCols = layoutA._shape[1];
    int bCols = layoutB._shape[1];

    for (int i = 0; i < aRows; i += _bh)
    {
        int height = (aRows - i) > _bh ? _bh : (aRows - i);
        for (int k = 0; k < aCols; k += _bw)
        {
            int width = (aCols - k) > _bw ? _bw : (aCols - k);
            ExecBlock(a.Slice(i * aCols + k), b.Slice(k * bCols), c.Slice(i * bCols), height, width, bCols, aRows, aCols, bCols);
        }
    }
}

// (m, p) * (p, n)
public static unsafe void ExecBlock(ReadOnlySpan<T> a, ReadOnlySpan<T> b, Span<T> c, int m, int p, int n, int aRows, int aCols, int bCols)
{
    int aIdx = 0;
    int bIdx;

    for (int i = 0; i <= m - 8; i += 8)
    {
        bIdx = 0;
        for (int j = 0; j <= bCols - 8; j += 8)
        {
            Kernel32b8x8(a.Slice(aIdx), b.Slice(bIdx), c.Slice(i * bCols + j), p, aRows, aCols, bCols);
            bIdx += 8;
        }
        aIdx += 8 * aCols;
    }
    for (int i = m / 8 * 8; i < m; i++)
    {
        for (int j = 0; j < bCols; j++)
        {
            for (int k = 0; k < p; k++)
            {
                c[i * bCols + j] += a[i * aCols + k] * b[k * bCols + j];
            }
        }
    }
    for (int i = 0; i < m / 8 * 8; i++)
    {
        for (int j = bCols / 8 * 8; j < bCols; j++)
        {
            for (int k = 0; k < p; k++)
            {
                c[i * bCols + j] += a[i * aCols + k] * b[k * bCols + j];
            }
        }
    }
}

public static unsafe void Kernel32b8x8(ReadOnlySpan<T> a, ReadOnlySpan<T> b, Span<T> c, int p, int aRows, int aCols, int bCols)
{
    Vector<T> c_0_v = new Vector<T>(c);
    Vector<T> c_1_v = new Vector<T>(c.Slice(bCols));
    Vector<T> c_2_v = new Vector<T>(c.Slice(2 * bCols));
    Vector<T> c_3_v = new Vector<T>(c.Slice(3 * bCols));
    Vector<T> c_4_v = new Vector<T>(c.Slice(4 * bCols));
    Vector<T> c_5_v = new Vector<T>(c.Slice(5 * bCols));
    Vector<T> c_6_v = new Vector<T>(c.Slice(6 * bCols));
    Vector<T> c_7_v = new Vector<T>(c.Slice(7 * bCols));

    Vector<T> a_0_v, a_1_v, a_2_v, a_3_v, a_4_v, a_5_v, a_6_v, a_7_v;
    Vector<T> b_v;
    int offset = 0;

    for (int k = 0; k < p; k++)
    {
        a_0_v = new Vector<T>(a[k]);
        a_1_v = new Vector<T>(a[aCols + k]);
        a_2_v = new Vector<T>(a[2 * aCols + k]);
        a_3_v = new Vector<T>(a[3 * aCols + k]);
        a_4_v = new Vector<T>(a[4 * aCols + k]);
        a_5_v = new Vector<T>(a[5 * aCols + k]);
        a_6_v = new Vector<T>(a[6 * aCols + k]);
        a_7_v = new Vector<T>(a[7 * aCols + k]);

        b_v = new Vector<T>(b.Slice(offset));
        offset += bCols;

        c_0_v += a_0_v * b_v;
        c_1_v += a_1_v * b_v;
        c_2_v += a_2_v * b_v;
        c_3_v += a_3_v * b_v;
        c_4_v += a_4_v * b_v;
        c_5_v += a_5_v * b_v;
        c_6_v += a_6_v * b_v;
        c_7_v += a_7_v * b_v;
    }
    c_0_v.CopyTo(c);
    c_1_v.CopyTo(c.Slice(bCols));
    c_2_v.CopyTo(c.Slice(2 * bCols));
    c_3_v.CopyTo(c.Slice(3 * bCols));
    c_4_v.CopyTo(c.Slice(4 * bCols));
    c_5_v.CopyTo(c.Slice(5 * bCols));
    c_6_v.CopyTo(c.Slice(6 * bCols));
    c_7_v.CopyTo(c.Slice(7 * bCols));
}

Removing unnecessary multiplications

The vectorization and blocking ideas in the two previous sections are sound, I believe, but the implementation is too literal: the innermost kernel, the most performance-critical spot, contains far too many unnecessary multiplications for index computation. We rewrite the kernel as follows and test again:

public static unsafe void Kernel32b8x8(ReadOnlySpan<T> a, ReadOnlySpan<T> b, Span<T> c, int p, int aRows, int aCols, int bCols)
{
    int offset = 0;
    int aIdx;
    int cIdx = bCols;

    Vector<T> c_0_v = new Vector<T>(c);
    Vector<T> c_1_v = new Vector<T>(c.Slice(cIdx));
    cIdx += bCols;
    Vector<T> c_2_v = new Vector<T>(c.Slice(cIdx));
    cIdx += bCols;
    Vector<T> c_3_v = new Vector<T>(c.Slice(cIdx));
    cIdx += bCols;
    Vector<T> c_4_v = new Vector<T>(c.Slice(cIdx));
    cIdx += bCols;
    Vector<T> c_5_v = new Vector<T>(c.Slice(cIdx));
    cIdx += bCols;
    Vector<T> c_6_v = new Vector<T>(c.Slice(cIdx));
    cIdx += bCols;
    Vector<T> c_7_v = new Vector<T>(c.Slice(cIdx));

    Vector<T> b_v;

    for (int k = 0; k < p; k++)
    {
        b_v = new Vector<T>(b.Slice(offset));
        offset += bCols;
        aIdx = k;

        c_0_v += Vector.Multiply(b_v, a[aIdx]);
        aIdx += aCols;
        c_1_v += Vector.Multiply(b_v, a[aIdx]);
        aIdx += aCols;
        c_2_v += Vector.Multiply(b_v, a[aIdx]);
        aIdx += aCols;
        c_3_v += Vector.Multiply(b_v, a[aIdx]);
        aIdx += aCols;
        c_4_v += Vector.Multiply(b_v, a[aIdx]);
        aIdx += aCols;
        c_5_v += Vector.Multiply(b_v, a[aIdx]);
        aIdx += aCols;
        c_6_v += Vector.Multiply(b_v, a[aIdx]);
        aIdx += aCols;
        c_7_v += Vector.Multiply(b_v, a[aIdx]);
    }
    cIdx = bCols;
    c_0_v.CopyTo(c);
    c_1_v.CopyTo(c.Slice(cIdx));
    cIdx += bCols;
    c_2_v.CopyTo(c.Slice(cIdx));
    cIdx += bCols;
    c_3_v.CopyTo(c.Slice(cIdx));
    cIdx += bCols;
    c_4_v.CopyTo(c.Slice(cIdx));
    cIdx += bCols;
    c_5_v.CopyTo(c.Slice(cIdx));
    cIdx += bCols;
    c_6_v.CopyTo(c.Slice(cIdx));
    cIdx += bCols;
    c_7_v.CopyTo(c.Slice(cIdx));
}

The benchmark results are below; since the cases we mainly care about are now all under 10%, one decimal place is kept. Performance clearly improved again, though it is still not enough.

Shape         | 6, 8, 11 | 32, 64, 96 | 256, 512, 768 | 1000, 1000, 1000 | 2, 1024, 1 | 1024, 1, 1024
Relative time | 170.3%   | 8.3%       | 8.1%          | 7.6%             | 225.0%     | 81.6%

When the sizes don't divide evenly into 8x8 tiles we still do quite a bit of inefficient work that deserves optimization. But on one hand, the main shapes in the benchmark (the three medium-and-up cases) are all divisible by 8, so this barely affects the results; on the other hand, to keep the next optimization simple, that cleanup is deferred to the end.

Adding a packing pre-pass

Blocking exists mainly to preserve locality, but one problem remains: inside the innermost 8x8 kernel, moving from one row of 8 elements to the next jumps to non-contiguous memory, so the two reads certainly do not share a cache line. If we pre-process the small sub-matrix, copying it out into contiguous memory ahead of time, we greatly increase the chance of staying on the same cache line.

Since the matrices are row-major, we pack B first, where the effect is more pronounced. The code:

public unsafe void Exec(ReadOnlySpan<T> a, ReadOnlySpan<T> b, Span<T> c, in NativeLayout layoutA, in NativeLayout layoutB, in NativeLayout layoutC)
{
    int aRows = layoutA._shape[0];
    int aCols = layoutA._shape[1];
    int bCols = layoutB._shape[1];
    int bIdx, cIdx;

    for (int k = 0; k < aCols; k += _bh)
    {
        int height = (aCols - k) > _bh ? _bh : (aCols - k);
        cIdx = 0;
        bIdx = k * bCols;
        for (int j = 0; j < bCols; j += _bw, cIdx += _bw, bIdx += _bw)
        {
            int width = (bCols - j) > _bw ? _bw : (bCols - j);
            ExecBlock(a.Slice(k), b.Slice(bIdx), c.Slice(cIdx), aRows, height, width, aRows, aCols, bCols);

        }
    }
}

// (m, p) * (p, n)
private static unsafe void ExecBlock(ReadOnlySpan<T> a, ReadOnlySpan<T> b, Span<T> c, int m, int p, int n, int aRows, int aCols, int bCols)
{
    int aIdx, bIdx, cIdx;
    T[] packedB = new T[n * p];
    var packedSpanB = packedB.AsSpan();

    for (int i = 0; i <= m - 8; i += 8)
    {
        cIdx = i * bCols;
        aIdx = i * aCols;
        for (int j = 0; j <= n - 8; j += 8, cIdx += 8)
        {
            bIdx = j * p;
            if (i == 0) PackMatrixBWith8xP(b.Slice(j), packedSpanB.Slice(bIdx), p, bCols);
            Kernel32b8x8(a.Slice(aIdx), packedSpanB.Slice(bIdx), c.Slice(cIdx), p, 8, aRows, aCols, bCols);
        }
    }
    for (int i = m / 8 * 8; i < m; i++)
    {
        for (int j = 0; j < n; j++)
        {
            for (int k = 0; k < p; k++)
            {
                c[i * bCols + j] += a[i * aCols + k] * b[k * bCols + j];
            }
        }
    }
    for (int i = 0; i < m / 8 * 8; i++)
    {
        for (int j = n / 8 * 8; j < n; j++)
        {
            for (int k = 0; k < p; k++)
            {
                c[i * bCols + j] += a[i * aCols + k] * b[k * bCols + j];
            }
        }
    }
}

private static void PackMatrixBWith8xP(ReadOnlySpan<T> src, Span<T> dst, int p, int bCols)
{
    int dstIdx = 0;
    Vector<T> data;
    for(int k = 0; k < p; k++, dstIdx += 8)
    {
        data = new Vector<T>(src.Slice(k * bCols));
        data.CopyTo(dst.Slice(dstIdx));
    }
}

private static unsafe void Kernel32b8x8(ReadOnlySpan<T> a, ReadOnlySpan<T> b, Span<T> c, int p, int n, int aRows, int aCols, int bCols)
{
    int offset = 0;
    int aIdx;
    int cIdx = bCols;

    Vector<T> c_0_v = new Vector<T>(c);
    Vector<T> c_1_v = new Vector<T>(c.Slice(cIdx));
    cIdx += bCols;
    Vector<T> c_2_v = new Vector<T>(c.Slice(cIdx));
    cIdx += bCols;
    Vector<T> c_3_v = new Vector<T>(c.Slice(cIdx));
    cIdx += bCols;
    Vector<T> c_4_v = new Vector<T>(c.Slice(cIdx));
    cIdx += bCols;
    Vector<T> c_5_v = new Vector<T>(c.Slice(cIdx));
    cIdx += bCols;
    Vector<T> c_6_v = new Vector<T>(c.Slice(cIdx));
    cIdx += bCols;
    Vector<T> c_7_v = new Vector<T>(c.Slice(cIdx));

    Vector<T> b_v;

    for (int k = 0; k < p; k++)
    {
        b_v = new Vector<T>(b.Slice(offset));
        offset += n;
        aIdx = k;

        c_0_v += Vector.Multiply(b_v, a[aIdx]);
        aIdx += aCols;
        c_1_v += Vector.Multiply(b_v, a[aIdx]);
        aIdx += aCols;
        c_2_v += Vector.Multiply(b_v, a[aIdx]);
        aIdx += aCols;
        c_3_v += Vector.Multiply(b_v, a[aIdx]);
        aIdx += aCols;
        c_4_v += Vector.Multiply(b_v, a[aIdx]);
        aIdx += aCols;
        c_5_v += Vector.Multiply(b_v, a[aIdx]);
        aIdx += aCols;
        c_6_v += Vector.Multiply(b_v, a[aIdx]);
        aIdx += aCols;
        c_7_v += Vector.Multiply(b_v, a[aIdx]);
    }
    cIdx = bCols;
    c_0_v.CopyTo(c);
    c_1_v.CopyTo(c.Slice(cIdx));
    cIdx += bCols;
    c_2_v.CopyTo(c.Slice(cIdx));
    cIdx += bCols;
    c_3_v.CopyTo(c.Slice(cIdx));
    cIdx += bCols;
    c_4_v.CopyTo(c.Slice(cIdx));
    cIdx += bCols;
    c_5_v.CopyTo(c.Slice(cIdx));
    cIdx += bCols;
    c_6_v.CopyTo(c.Slice(cIdx));
    cIdx += bCols;
    c_7_v.CopyTo(c.Slice(cIdx));
}

The test results are as follows:

Shape         | 6, 8, 11 | 32, 64, 96 | 256, 512, 768 | 1000, 1000, 1000 | 2, 1024, 1 | 1024, 1, 1024
Relative time | 174.3%   | 9.2%       | 7.2%          | 6.1%             | 248.0%     | 55.2%

Packing A as well is also possible, but in my measurements it does not improve performance and even hurts slightly, so I won't dwell on it. The code with A-packing added:

public unsafe void Exec(ReadOnlySpan<T> a, ReadOnlySpan<T> b, Span<T> c, in NativeLayout layoutA, in NativeLayout layoutB, in NativeLayout layoutC)
{
    int aRows = layoutA._shape[0];
    int aCols = layoutA._shape[1];
    int bCols = layoutB._shape[1];
    int bIdx, cIdx;

    T[] packedB = new T[_bw * _bh];
    T[] packedA = new T[aRows * _bh];
    var packedSpanB = packedB.AsSpan();
    var packedSpanA = packedA.AsSpan();
    for (int k = 0; k < aCols; k += _bh)
    {
        int height = (aCols - k) > _bh ? _bh : (aCols - k);
        cIdx = 0;
        bIdx = k * bCols;
        for (int j = 0; j < bCols; j += _bw, cIdx += _bw, bIdx += _bw)
        {
            int width = (bCols - j) > _bw ? _bw : (bCols - j);
            ExecBlock(a.Slice(k), b.Slice(bIdx), c.Slice(cIdx), aRows, height, width, aRows, aCols, bCols, packedSpanB, packedSpanA, j == 0);
        }
    }
}

// (m, p) * (p, n)
private static unsafe void ExecBlock(ReadOnlySpan<T> a, ReadOnlySpan<T> b, Span<T> c, int m, int p, int n, int aRows, int aCols, int bCols, Span<T> packedSpanB, Span<T> packedSpanA, bool packA)
{
    int aIdx, packedAIdx, cIdx;

    for (int i = 0; i <= m - 8; i += 8)
    {
        cIdx = i * bCols;
        aIdx = i * aCols;
        packedAIdx = i * p;
        if (packA) PackMatrixAWith8xP(a.Slice(aIdx), packedSpanA.Slice(i * p), p, aCols);
        for (int j = 0; j <= n - 8; j += 8, cIdx += 8)
        {
            if (i == 0) PackMatrixBWith8xP(b.Slice(j), packedSpanB.Slice(j * p), p, bCols);
            Kernel32b8x8(packedSpanA.Slice(packedAIdx), packedSpanB.Slice(j * p), c.Slice(cIdx), 8, p, 8, aRows, aCols, bCols);

        }
    }
    for (int i = m / 8 * 8; i < m; i++)
    {
        for (int j = 0; j < n; j++)
        {
            for (int k = 0; k < p; k++)
            {
                c[i * bCols + j] += a[i * aCols + k] * b[k * bCols + j];
            }
        }
    }
    for (int i = 0; i < m / 8 * 8; i++)
    {
        for (int j = n / 8 * 8; j < n; j++)
        {
            for (int k = 0; k < p; k++)
            {
                c[i * bCols + j] += a[i * aCols + k] * b[k * bCols + j];
            }
        }
    }
}

private static void PackMatrixBWith8xP(ReadOnlySpan<T> src, Span<T> dst, int p, int bCols)
{
    int dstIdx = 0;
    Vector<T> data;
    for (int k = 0; k < p; k++, dstIdx += 8)
    {
        data = new Vector<T>(src.Slice(k * bCols));
        data.CopyTo(dst.Slice(dstIdx));
    }
}

private static void PackMatrixAWith8xP(ReadOnlySpan<T> src, Span<T> dst, int p, int aCols)
{
    int dstIdx = 0;
    int srcIdx;
    for (int k = 0; k < p; k++)
    {
        srcIdx = k;
        dst[dstIdx++] = src[srcIdx];
        srcIdx += aCols;
        dst[dstIdx++] = src[srcIdx];
        srcIdx += aCols;
        dst[dstIdx++] = src[srcIdx];
        srcIdx += aCols;
        dst[dstIdx++] = src[srcIdx];
        srcIdx += aCols;
        dst[dstIdx++] = src[srcIdx];
        srcIdx += aCols;
        dst[dstIdx++] = src[srcIdx];
        srcIdx += aCols;
        dst[dstIdx++] = src[srcIdx];
        srcIdx += aCols;
        dst[dstIdx++] = src[srcIdx];
    }
}

private static unsafe void Kernel32b8x8(ReadOnlySpan<T> a, ReadOnlySpan<T> b, Span<T> c, int m, int p, int n, int aRows, int aCols, int bCols)
{
    int offset = 0;
    int aIdx;
    int cIdx = bCols;

    Vector<T> c_0_v = new Vector<T>(c);
    Vector<T> c_1_v = new Vector<T>(c.Slice(cIdx));
    cIdx += bCols;
    Vector<T> c_2_v = new Vector<T>(c.Slice(cIdx));
    cIdx += bCols;
    Vector<T> c_3_v = new Vector<T>(c.Slice(cIdx));
    cIdx += bCols;
    Vector<T> c_4_v = new Vector<T>(c.Slice(cIdx));
    cIdx += bCols;
    Vector<T> c_5_v = new Vector<T>(c.Slice(cIdx));
    cIdx += bCols;
    Vector<T> c_6_v = new Vector<T>(c.Slice(cIdx));
    cIdx += bCols;
    Vector<T> c_7_v = new Vector<T>(c.Slice(cIdx));

    Vector<T> b_v;

    for (int k = 0; k < p; k++)
    {
        b_v = new Vector<T>(b.Slice(offset));
        offset += n;
        aIdx = k * m;

        c_0_v += Vector.Multiply(b_v, a[aIdx++]);
        c_1_v += Vector.Multiply(b_v, a[aIdx++]);
        c_2_v += Vector.Multiply(b_v, a[aIdx++]);
        c_3_v += Vector.Multiply(b_v, a[aIdx++]);
        c_4_v += Vector.Multiply(b_v, a[aIdx++]);
        c_5_v += Vector.Multiply(b_v, a[aIdx++]);
        c_6_v += Vector.Multiply(b_v, a[aIdx++]);
        c_7_v += Vector.Multiply(b_v, a[aIdx++]);
    }
    cIdx = bCols;
    c_0_v.CopyTo(c);
    c_1_v.CopyTo(c.Slice(cIdx));
    cIdx += bCols;
    c_2_v.CopyTo(c.Slice(cIdx));
    cIdx += bCols;
    c_3_v.CopyTo(c.Slice(cIdx));
    cIdx += bCols;
    c_4_v.CopyTo(c.Slice(cIdx));
    cIdx += bCols;
    c_5_v.CopyTo(c.Slice(cIdx));
    cIdx += bCols;
    c_6_v.CopyTo(c.Slice(cIdx));
    cIdx += bCols;
    c_7_v.CopyTo(c.Slice(cIdx));
}

The biggest advantage of packing A is that the block can be transposed during the pack, so that reading a column of 8 elements of A inside the kernel becomes a contiguous memory access. In practice, though, this benefit does not offset the cost of packing A. Arm's NEON instruction set has vfmaq_laneq_f32, which uses a lane index to select which element of one vector multiplies the other, conveniently implementing vector-by-scalar multiplication; with it you can load 8 elements at once and raise the compute-to-memory-access ratio. Unfortunately I have not found a corresponding instruction on x86 (possibly just my own unfamiliarity), so I had to abandon the pack-A approach.

FMA and pointer optimizations

The how-to-optimize-gemm tutorial stops at this point, but there is still plenty of headroom: on one hand, FMA instructions can fuse the multiply and add; on the other, on the C# side, replacing spans with raw pointers also brings gains.

To be clear, Span&lt;T&gt; in C#/.NET is very efficient: for raw element access it measures essentially as fast as pointer access (older .NET versions may show a tiny gap). It is also a higher-level abstraction whose length information, slicing, and bounds checking make development and debugging far more pleasant; the bounds checking in particular beats C++ hands down. Why switch to pointers, then? First, on .NET 7 and earlier, using FMA instructions requires pointers. Second, our implementation contains a micro-kernel that slices spans on every call, and a slice is essentially the construction of a new Span object, which carries some overhead. Once the code is essentially bug-free, squeezing out the last bit of performance means switching to pointers.

The FMA version is shown below. Since the goal here is simply to measure float32 throughput, the implementation is rather blunt; a real production implementation should dispatch properly on the element type.

public unsafe void Exec(ReadOnlySpan<T> a, ReadOnlySpan<T> b, Span<T> c, in NativeLayout layoutA, in NativeLayout layoutB, in NativeLayout layoutC)
{
    int aRows = layoutA._shape[0];
    int aCols = layoutA._shape[1];
    int bCols = layoutB._shape[1];
    int bIdx, cIdx = 0;

    var packedB = new T[_bh * _bw];
    var packedSpanB = packedB.AsSpan();
    fixed (T* ptrA = a, ptrB = b, ptrC = c, ptrPackedB = packedSpanB)
    {
        for (int j = 0; j < bCols; j += _bw, cIdx += _bw)
        {
            int width = (bCols - j) > _bw ? _bw : (bCols - j);
            for (int k = 0; k < aCols; k += _bh) // exchange order 
            {
                int height = (aCols - k) > _bh ? _bh : (aCols - k);
                bIdx = j + k * bCols;
                ExecBlock(ptrA + k, ptrB + bIdx, ptrC + cIdx, aRows, height, width, aRows, aCols, bCols, ptrPackedB);
            }
        }
    }
}

// (m, p) * (p, n)
private static unsafe void ExecBlock(T* a, T* b, T* c, int m, int p, int n, int aRows, int aCols, int bCols, T* packedSpanB)
{
    int aIdx, bIdx, cIdx;
    float* fa = (float*)a;
    float* fb = (float*)packedSpanB;
    float* fc = (float*)c;

    for (int i = 0; i <= m - 8; i += 8)
    {
        cIdx = i * bCols;
        aIdx = i * aCols;
        for (int j = 0; j <= n - 8; j += 8, cIdx += 8)
        {
            bIdx = j * p;
            if (i == 0) PackMatrixBWith8xP(b + j, packedSpanB + bIdx, p, bCols);
            Kernel32b8x8(fa + aIdx, fb + bIdx, fc + cIdx, p, 8, aRows, aCols, bCols);
        }
    }
    for (int i = m / 8 * 8; i < m; i++)
    {
        for (int j = 0; j < n; j++)
        {
            for (int k = 0; k < p; k++)
            {
                c[i * bCols + j] += a[i * aCols + k] * b[k * bCols + j];
            }
        }
    }
    for (int i = 0; i < m / 8 * 8; i++)
    {
        for (int j = n / 8 * 8; j < n; j++)
        {
            for (int k = 0; k < p; k++)
            {
                c[i * bCols + j] += a[i * aCols + k] * b[k * bCols + j];
            }
        }
    }
}

private unsafe static void PackMatrixBWith8xP(T* src, T* dst, int p, int bCols)
{
    int dstIdx = 0;
    Vector256<T> data;
    for (int k = 0; k < p; k++, dstIdx += 8)
    {
        data = Vector256.Load(src + k * bCols);
        data.Store(dst + dstIdx);
    }
}

private static unsafe void Kernel32b8x8(float* a, float* b, float* c, int p, int n, int aRows, int aCols, int bCols)
{
    var cFrom = c;

    var c_0_v = Vector256.Load(cFrom);
    cFrom += bCols;
    var c_1_v = Vector256.Load(cFrom);
    cFrom += bCols;
    var c_2_v = Vector256.Load(cFrom);
    cFrom += bCols;
    var c_3_v = Vector256.Load(cFrom);
    cFrom += bCols;
    var c_4_v = Vector256.Load(cFrom);
    cFrom += bCols;
    var c_5_v = Vector256.Load(cFrom);
    cFrom += bCols;
    var c_6_v = Vector256.Load(cFrom);
    cFrom += bCols;
    var c_7_v = Vector256.Load(cFrom);

    Vector256<float> b_v;

    for (int k = 0; k < p; k++)
    {
        b_v = Vector256.Load(b);
        b += n;
        var ta = a + k;

        c_0_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta), c_0_v);
        ta += aCols;
        c_1_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta), c_1_v);
        ta += aCols;
        c_2_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta), c_2_v);
        ta += aCols;
        c_3_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta), c_3_v);
        ta += aCols;
        c_4_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta), c_4_v);
        ta += aCols;
        c_5_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta), c_5_v);
        ta += aCols;
        c_6_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta), c_6_v);
        ta += aCols;
        c_7_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta), c_7_v);
    }

    c_0_v.Store(c);
    c += bCols;
    c_1_v.Store(c);
    c += bCols;
    c_2_v.Store(c);
    c += bCols;
    c_3_v.Store(c);
    c += bCols;
    c_4_v.Store(c);
    c += bCols;
    c_5_v.Store(c);
    c += bCols;
    c_6_v.Store(c);
    c += bCols;
    c_7_v.Store(c);
}

We get the following benchmark results (the block size here is changed to 64x64):

| Shape | 6, 8, 11 | 32, 64, 96 | 256, 512, 768 | 1000, 1000, 1000 | 2, 1024, 1 | 1024, 1, 1024 |
| --- | --- | --- | --- | --- | --- | --- |
| Relative time | 224.7% | 5.5% | 4.0% | 4.5% | 251.7% | 61.2% |
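The heart of the kernel above is the broadcast-then-FMA pattern: one scalar of A is splatted across all eight lanes and multiplied against a 1x8 strip of B, accumulating into a 1x8 strip of C. The following standalone sketch (the class name `FmaRowDemo` is mine, not part of the library) isolates that single step without pointers:

```csharp
using System;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;

class FmaRowDemo
{
    // One step of the 8x8 kernel: c[0..8) += aScalar * b[0..8).
    // The real kernel repeats this for 8 rows of C per iteration of k.
    static void Main()
    {
        var bV = Vector256.Create(1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f); // a 1x8 strip of B
        var cV = Vector256<float>.Zero;                            // the C accumulator
        float aScalar = 2f;                                        // one element of A

        // Broadcast the A element across all lanes, then fused multiply-add;
        // fall back to separate mul + add when FMA is unavailable.
        cV = Fma.IsSupported
            ? Fma.MultiplyAdd(bV, Vector256.Create(aScalar), cV)
            : cV + bV * Vector256.Create(aScalar);

        var cOut = new float[8];
        cV.CopyTo(cOut);
        Console.WriteLine(cOut[0]); // 2 * 1 = 2
        Console.WriteLine(cOut[7]); // 2 * 8 = 16
    }
}
```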

In addition, here is the FMA variant that also packs A; its performance is slightly lower:

public unsafe void Exec(ReadOnlySpan<T> a, ReadOnlySpan<T> b, Span<T> c, in NativeLayout layoutA, in NativeLayout layoutB, in NativeLayout layoutC)
{
    // The array should be contiguous here
    int aRows = layoutA._shape[0];
    int aCols = layoutA._shape[1];
    int bCols = layoutB._shape[1];
    int bIdx, cIdx = 0;

    T[] packedB = new T[_bw * _bh];
    T[] packedA = new T[aRows * _bh];
    var packedSpanB = packedB.AsSpan();
    var packedSpanA = packedA.AsSpan();
    fixed (T* ptrA = a, ptrB = b, ptrC = c, ptrPackedB = packedSpanB, ptrPackedA = packedSpanA)
    {
        for (int k = 0; k < aCols; k += _bh)
        {
            int height = (aCols - k) > _bh ? _bh : (aCols - k);
            cIdx = 0;
            bIdx = k * bCols;
            for (int j = 0; j < bCols; j += _bw, cIdx += _bw, bIdx += _bw)
            {
                int width = (bCols - j) > _bw ? _bw : (bCols - j);
                ExecBlock(ptrA + k, ptrB + bIdx, ptrC + cIdx, aRows, height, width, aRows, aCols, bCols, ptrPackedB, ptrPackedA, j == 0);
            }
        }
    }
}

// (m, p) * (p, n)
private static unsafe void ExecBlock(T* a, T* b, T* c, int m, int p, int n, int aRows, int aCols, int bCols, T* packedSpanB, T* packedSpanA, bool packA)
{
    int aIdx, packedAIdx, cIdx;
    float* fa = (float*)packedSpanA;
    float* fb = (float*)packedSpanB;
    float* fc = (float*)c;

    for (int i = 0; i <= m - 8; i += 8)
    {
        cIdx = i * bCols;
        aIdx = i * aCols;
        packedAIdx = i * p;
        if (packA) PackMatrixAWith8xP(a + aIdx, packedSpanA + i * p, p, aCols);
        for (int j = 0; j <= n - 8; j += 8, cIdx += 8)
        {
            if (i == 0) PackMatrixBWith8xP(b + j, packedSpanB + j * p, p, bCols);
            Kernel32b8x8(fa + packedAIdx, fb + j * p, fc + cIdx, 8, p, 8, aRows, aCols, bCols);
        }
    }

    for (int i = m / 8 * 8; i < m; i++)
    {
        for (int j = 0; j < n; j++)
        {
            for (int k = 0; k < p; k++)
            {
                c[i * bCols + j] += a[i * aCols + k] * b[k * bCols + j];
            }
        }
    }
    for (int i = 0; i < m / 8 * 8; i++)
    {
        for (int j = n / 8 * 8; j < n; j++)
        {
            for (int k = 0; k < p; k++)
            {
                c[i * bCols + j] += a[i * aCols + k] * b[k * bCols + j];
            }
        }
    }
}

private unsafe static void PackMatrixBWith8xP(T* src, T* dst, int p, int bCols)
{
    int dstIdx = 0;
    Vector256<T> data;
    for (int k = 0; k < p; k++, dstIdx += 8)
    {
        data = Vector256.Load(src + k * bCols);
        data.Store(dst + dstIdx);
    }
}

private unsafe static void PackMatrixAWith8xP(T* src, T* dst, int p, int aCols)
{
    int dstIdx = 0;
    int srcIdx;
    for (int k = 0; k < p; k++)
    {
        srcIdx = k;
        dst[dstIdx++] = src[srcIdx];
        srcIdx += aCols;
        dst[dstIdx++] = src[srcIdx];
        srcIdx += aCols;
        dst[dstIdx++] = src[srcIdx];
        srcIdx += aCols;
        dst[dstIdx++] = src[srcIdx];
        srcIdx += aCols;
        dst[dstIdx++] = src[srcIdx];
        srcIdx += aCols;
        dst[dstIdx++] = src[srcIdx];
        srcIdx += aCols;
        dst[dstIdx++] = src[srcIdx];
        srcIdx += aCols;
        dst[dstIdx++] = src[srcIdx];
    }
}

private static unsafe void Kernel32b8x8(float* a, float* b, float* c, int m, int p, int n, int aRows, int aCols, int bCols)
{
    int aIdx;
    var cFrom = c;

    var c_0_v = Vector256.Load(cFrom);
    cFrom += bCols;
    var c_1_v = Vector256.Load(cFrom);
    cFrom += bCols;
    var c_2_v = Vector256.Load(cFrom);
    cFrom += bCols;
    var c_3_v = Vector256.Load(cFrom);
    cFrom += bCols;
    var c_4_v = Vector256.Load(cFrom);
    cFrom += bCols;
    var c_5_v = Vector256.Load(cFrom);
    cFrom += bCols;
    var c_6_v = Vector256.Load(cFrom);
    cFrom += bCols;
    var c_7_v = Vector256.Load(cFrom);

    Vector256<float> b_v;

    for (int k = 0; k < p; k++)
    {
        b_v = Vector256.Load(b);
        b += n;
        var ta = a + k * m;

        c_0_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta++), c_0_v);
        c_1_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta++), c_1_v);
        c_2_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta++), c_2_v);
        c_3_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta++), c_3_v);
        c_4_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta++), c_4_v);
        c_5_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta++), c_5_v);
        c_6_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta++), c_6_v);
        c_7_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta++), c_7_v);
    }

    c_0_v.Store(c);
    c += bCols;
    c_1_v.Store(c);
    c += bCols;
    c_2_v.Store(c);
    c += bCols;
    c_3_v.Store(c);
    c += bCols;
    c_4_v.Store(c); 
    c += bCols;
    c_5_v.Store(c);
    c += bCols;
    c_6_v.Store(c);
    c += bCols;
    c_7_v.Store(c);
}
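The layout PackMatrixAWith8xP produces is worth spelling out: the 8-row, p-column tile of A is rewritten column by column, so element (row i, col k) lands at `dst[k * 8 + i]`, which is exactly why the kernel can read one packed column with eight consecutive `*ta++` loads. A minimal sketch (the class name `PackADemo` is mine, and the unrolled copies are collapsed into a loop):

```csharp
using System;

class PackADemo
{
    // Mirrors PackMatrixAWith8xP: an 8-row, p-column tile of A (row stride aCols)
    // is repacked column-major, so element (i, k) lands at dst[k * 8 + i].
    static void Main()
    {
        const int aCols = 5, p = 3;            // an 8x3 tile inside an 8x5 matrix
        var src = new float[8 * aCols];
        for (int i = 0; i < src.Length; i++) src[i] = i;

        var dst = new float[p * 8];
        int d = 0;
        for (int k = 0; k < p; k++)            // for each column of the tile
            for (int i = 0; i < 8; i++)        // gather its 8 rows contiguously
                dst[d++] = src[i * aCols + k];

        // The kernel then walks one packed column with eight `*ta++` loads.
        Console.WriteLine(dst[8 * 1 + 0]);     // (row 0, col 1) -> src[1] = 1
        Console.WriteLine(dst[8 * 1 + 7]);     // (row 7, col 1) -> src[36] = 36
    }
}
```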

Handling the edge cases

In practice the matrix dimensions are often not divisible by 8, and these cases should also be accelerated rather than falling back to the naive implementation.

The current implementation is fairly simple: where an 8x8 tile does not fit, fall back to 1x8 tiles, and accelerate the remaining fringe with unrolling (ideally this would be further subdivided into 4x8 tiles and so on). The code is as follows:

public unsafe void Exec(ReadOnlySpan<T> a, ReadOnlySpan<T> b, Span<T> c, in NativeLayout layoutA, in NativeLayout layoutB, in NativeLayout layoutC)
{
    // The array should be contiguous here
    int aRows = layoutA._shape[0];
    int aCols = layoutA._shape[1];
    int bCols = layoutB._shape[1];
    int bIdx, cIdx = 0;

    var packedB = new T[_bh * _bw];
    var packedSpanB = packedB.AsSpan();
    fixed (T* ptrA = a, ptrB = b, ptrC = c, ptrPackedB = packedSpanB)
    {
        for (int j = 0; j < bCols; j += _bw, cIdx += _bw)
        {
            int width = (bCols - j) > _bw ? _bw : (bCols - j);
            for (int k = 0; k < aCols; k += _bh) // exchange order 
            {
                int height = (aCols - k) > _bh ? _bh : (aCols - k);
                bIdx = j + k * bCols;
                ExecBlock(ptrA + k, ptrB + bIdx, ptrC + cIdx, aRows, height, width, aRows, aCols, bCols, ptrPackedB);
            }
        }
    }
}

// (m, p) * (p, n)
private static unsafe void ExecBlock(T* a, T* b, T* c, int m, int p, int n, int aRows, int aCols, int bCols, T* packedSpanB)
{
    int aIdx, bIdx, cIdx;
    float* fa = (float*)a;
    float* fb = (float*)packedSpanB;
    float* fc = (float*)c;
    bool hasPacked = false;

    for (int i = 0; i <= m - 8; i += 8)
    {
        cIdx = i * bCols;
        aIdx = i * aCols;
        for (int j = 0; j <= n - 8; j += 8, cIdx += 8)
        {
            bIdx = j * p;
            if (i == 0)
            {
                PackMatrixBWithPx8(b + j, packedSpanB + bIdx, p, bCols);
                hasPacked = true;
            }
            Kernel32b8x8(fa + aIdx, fb + bIdx, fc + cIdx, p, 8, aRows, aCols, bCols);
        }
    }
    int iStart = m / 8 * 8;
    aIdx = iStart * aCols;
    for (int i = iStart; i < m; i++, aIdx += aCols)
    {
        cIdx = i * bCols;
        for (int j = 0; j <= n - 8; j += 8, cIdx += 8)
        {
            if (!hasPacked && i == iStart) PackMatrixBWithPx8(b + j, packedSpanB + j * p, p, bCols);
            Kernel32b1x8(fa + aIdx, fb + j * p, fc + cIdx, p, 8, aRows, aCols, bCols);
        }
    }
    int jStart = n / 8 * 8;
    T c0, c1, c2, c3;
    for (int i = 0; i <= m - 4; i += 4)
    {
        for (int j = jStart; j < n; j++)
        {
            int aIdx0 = i * aCols, aIdx1 = aIdx0 + aCols, aIdx2 = aIdx1 + aCols, aIdx3 = aIdx2 + aCols;
            cIdx = i * bCols + j;
            c0 = c[cIdx];
            cIdx += bCols;
            c1 = c[cIdx];
            cIdx += bCols;
            c2 = c[cIdx];
            cIdx += bCols;
            c3 = c[cIdx];

            for (int k = 0; k < p; k++)
            {
                var bValue = b[k * bCols + j];
                c0 += a[aIdx0++] * bValue;
                c1 += a[aIdx1++] * bValue;
                c2 += a[aIdx2++] * bValue;
                c3 += a[aIdx3++] * bValue;
            }
            cIdx = i * bCols + j;
            c[cIdx] = c0;
            cIdx += bCols;
            c[cIdx] = c1;
            cIdx += bCols;
            c[cIdx] = c2;
            cIdx += bCols;
            c[cIdx] = c3;
        }
    }
    for (int i = m / 4 * 4; i < m; i++)
    {
        for (int k = 0; k < p; k++)
        {
            cIdx = i * bCols + jStart;
            bIdx = k * bCols + jStart;
            var aValue = a[i * aCols + k];
            for (int j = jStart; j < n; j++)
            {
                c[cIdx++] += aValue * b[bIdx++];
            }
        }
    }
}

private unsafe static void PackMatrixBWithPx8(T* src, T* dst, int p, int bCols)
{
    Vector256<T> data;
    for (int k = 0; k < p; k++, dst += 8, src += bCols)
    {
        data = Vector256.Load(src);
        data.Store(dst);
    }
}

private static unsafe void Kernel32b8x8(float* a, float* b, float* c, int p, int n, int aRows, int aCols, int bCols)
{
    var cFrom = c;

    var c_0_v = Avx2.LoadVector256(cFrom);
    cFrom += bCols;
    var c_1_v = Avx2.LoadVector256(cFrom);
    cFrom += bCols;
    var c_2_v = Avx2.LoadVector256(cFrom);
    cFrom += bCols;
    var c_3_v = Avx2.LoadVector256(cFrom);
    cFrom += bCols;
    var c_4_v = Avx2.LoadVector256(cFrom);
    cFrom += bCols;
    var c_5_v = Avx2.LoadVector256(cFrom);
    cFrom += bCols;
    var c_6_v = Avx2.LoadVector256(cFrom);
    cFrom += bCols;
    var c_7_v = Avx2.LoadVector256(cFrom);

    Vector256<float> b_v;

    for (int k = 0; k < p; k++)
    {
        b_v = Avx2.LoadVector256(b);
        b += n;
        var ta = a + k;

        c_0_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta), c_0_v);
        ta += aCols;
        c_1_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta), c_1_v);
        ta += aCols;
        c_2_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta), c_2_v);
        ta += aCols;
        c_3_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta), c_3_v);
        ta += aCols;
        c_4_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta), c_4_v);
        ta += aCols;
        c_5_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta), c_5_v);
        ta += aCols;
        c_6_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta), c_6_v);
        ta += aCols;
        c_7_v = Fma.MultiplyAdd(b_v, Vector256.Create(*ta), c_7_v);
    }

    Avx2.Store(c, c_0_v);
    c += bCols;
    Avx2.Store(c, c_1_v);
    c += bCols;
    Avx2.Store(c, c_2_v);
    c += bCols;
    Avx2.Store(c, c_3_v);
    c += bCols;
    Avx2.Store(c, c_4_v);
    c += bCols;
    Avx2.Store(c, c_5_v);
    c += bCols;
    Avx2.Store(c, c_6_v);
    c += bCols;
    Avx2.Store(c, c_7_v);
}

To make the effect visible, the benchmark below deliberately uses sizes that are not divisible by 8, comparing against the previous implementation without edge-case handling. Performance improves across the board, and for some sizes by as much as six to seven times.

| Method | Dimensions | Mean | Error | StdDev | Gen0 | Gen1 | Gen2 | Allocated |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| OldMatmul | (1023,1) (1,50) | 70,615.8 ns | 1,347.03 ns | 1,441.30 ns | 62.3779 | 62.3779 | 62.3779 | 216 KB |
| NewMatmul | (1023,1) (1,50) | 67,714.8 ns | 954.01 ns | 892.38 ns | 62.3779 | 62.3779 | 62.3779 | 216 KB |
| OldMatmul | (125,125) (125,125) | 255,549.1 ns | 1,036.57 ns | 918.89 ns | 5.8594 | 0.9766 | - | 77.22 KB |
| NewMatmul | (125,125) (125,125) | 80,186.4 ns | 1,514.84 ns | 1,416.99 ns | 5.9814 | 1.0986 | - | 77.22 KB |
| OldMatmul | (2,939) (939,50) | 131,929.5 ns | 855.23 ns | 758.14 ns | 1.2207 | - | - | 16.57 KB |
| NewMatmul | (2,939) (939,50) | 16,574.9 ns | 132.95 ns | 124.36 ns | 1.2817 | - | - | 16.57 KB |
| OldMatmul | (30,65) (65,91) | 64,089.7 ns | 505.97 ns | 448.53 ns | 2.0752 | - | - | 26.84 KB |
| NewMatmul | (30,65) (65,91) | 9,435.6 ns | 99.40 ns | 92.98 ns | 2.0905 | 0.0305 | - | 26.84 KB |
| OldMatmul | (50,939) (939,1) | 60,769.8 ns | 355.89 ns | 332.90 ns | 1.2207 | - | - | 16.38 KB |
| NewMatmul | (50,939) (939,1) | 21,044.7 ns | 124.23 ns | 116.21 ns | 1.2512 | - | - | 16.38 KB |
| OldMatmul | (6,7) (7,11) | 932.7 ns | 9.20 ns | 8.60 ns | 1.2846 | 0.0010 | - | 16.44 KB |
| NewMatmul | (6,7) (7,11) | 700.0 ns | 7.07 ns | 6.27 ns | 1.2846 | 0.0010 | - | 16.44 KB |
| OldMatmul | (67,1) (1,789) | 70,884.6 ns | 779.27 ns | 728.93 ns | 66.6504 | 66.6504 | 66.6504 | 222.7 KB |
| NewMatmul | (67,1) (1,789) | 68,316.9 ns | 588.33 ns | 521.54 ns | 66.6504 | 66.6504 | 66.6504 | 222.7 KB |

Multi-threading optimization

To be covered in a follow-up update.
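Until that update lands, a common direction (my own illustration, not the author's planned implementation) is to split the rows of C into independent bands: each thread owns a disjoint row band, so no two threads ever write the same element of C and no locking is needed. A minimal sketch with `Parallel.For` and a deliberately naive inner loop (the class name `ParallelRowsDemo` is hypothetical):

```csharp
using System;
using System.Threading.Tasks;

class ParallelRowsDemo
{
    static void Main()
    {
        int m = 64, n = 16, p = 16;
        var a = new float[m * n];
        var b = new float[n * p];
        var c = new float[m * p];
        for (int i = 0; i < a.Length; i++) a[i] = 1f;
        for (int i = 0; i < b.Length; i++) b[i] = 1f;

        const int band = 8; // matches the 8-row kernel height
        Parallel.For(0, (m + band - 1) / band, blk =>
        {
            // Each task computes rows [blk*band, iEnd) of C; bands are disjoint,
            // so the writes to c never race.
            int iEnd = Math.Min((blk + 1) * band, m);
            for (int i = blk * band; i < iEnd; i++)
                for (int j = 0; j < p; j++)
                {
                    float acc = 0f;
                    for (int k = 0; k < n; k++) acc += a[i * n + k] * b[k * p + j];
                    c[i * p + j] = acc;
                }
        });

        Console.WriteLine(c[0]);         // each entry is a sum of 16 ones = 16
        Console.WriteLine(c[m * p - 1]); // 16
    }
}
```

In a real integration each band would call the blocked SIMD path instead of the naive triple loop, and the per-call packing buffers would need to become per-thread to stay race-free.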