𝚲

重来TDD:矩阵

Transpose the Mind

比起上一篇锦上添花的 Printf矩阵变换 堪称 TDD 奥义的体现1

人工智能作为现象级的热点,忽然间网上懂神经网络的人,比懂写程序的人还多。好处是,让本文的阅读门槛低了不少。

向量的类型表示

聊矩阵的前提,是要确立向量的类型表示。有点编程经验的人,脑海里可能立马想到了用 Array(数组) 类型来表示 2。比如一个三维向量,用数组会是这样的:

scala> val vect3d = Array(1, 2, 3)
val vect3d: Array[Int] = Array(1, 2, 3)

很遗憾,上面的表示缺失了向量 维度的数量 ,这个至关重要。要把维度的信息添加上,向量的类型表示则会是 Vect[N <: Int, A], 且 N 必须是整数。若以 ADT 的形式来呈现 3,则是:

import scala.compiletime.ops.int.S

enum Vect[N <: Int, +A]:
  case `[]` extends Vect[0, Nothing]
  case ::[N <: Int, +A](
    x: A, xs: Vect[N, A]
  ) extends Vect[S[N], A]

List[A] 非常像,一个向量只会是以下两种可能:

  1. [] ,类似 Nil, 表示零维度的向量,对它来说 A 就是 Nothing 4
  2. :: ,类似 Cons 5,表示一个标量与另一个向量组成的复合向量。

如果说,xs 的维度数量是 N,那么复合得到的向量的维度数量是 S[N],它等同于 N+1 6。此时再看 vect3d

scala> val vect3d = 1 :: 2 :: 3 :: `[]`
val vect3d: 2 :: Int = ::(1,::(2,::(3,[])))

注意,若是觉得 2 :: Int 看着不对劲的,请自行默念三遍 S[N] = N+1。这里的 2 :: Int 等价于 Vect[3, Int]

为了友好地展示向量,我额外实现了一个 show 函数,效果如下:

scala> show(vect3d)
val res0: String = [1, 2, 3]

矩阵与变换函数

矩阵可以理解为,每个维度都是向量的向量。那么,它的类型可以表示为:

type Mat[R <: Int, C <: Int, A] 
  = Vect[R, Vect[C, A]]

进而,对矩阵进行变换的函数类型表示应该是:

def trans[R <: Int, C <: Int, A]: 
  Mat[R, C, A] => Mat[C, R, A] = ???

这样,在类型上严谨地表示出变换操作的含义,即 RC 代表的行列对调。既然如此,有没有感觉实现起来就是临门一脚的事了?

嘿嘿嘿。脑子里是不是在想,来个两层 for 循环嵌套先,数组下标一顿操作?诶,数组下标?这里没有数组,更没有下标,怎么办?

You have to let it all go, Neo. Fear, doubt, and disbelief. Free your mind.

Morpheus, The Matrix (1999) 7

这种指令式编程的思维习惯,我们暂时收住先。不妨来看看,一个典型的函数式编程思维是如何解决这类问题的?

def toList[A]: Vect[?, A] => List[A] = ???

要不试着实现这个 toList 函数,全当是个热身吧。对于一个向量,无论有多少个维度,它只会有两种情况:

def toList[A]: Vect[?, A] => List[A] = 
  case `[]`    => ???
  case x :: xs => ???  

当是 [] 的情况,答案是显而易见的,

def toList[A]: Vect[?, A] => List[A] = 
  case `[]`    => Nil
  case x :: xs => ???  

另一种嘛,呃……如果正向思考不明显,可以尝试一下逆向思考。上文提到,向量和列表都是只有两种情况。列表的另一种情况,也是 一个元素另一个列表 的复合列表。这里的那 一个元素 已经有了,就是 x;而 另一个列表 是不是可以继续由 另一个向量 得到呢?这不巧了嘛,我们手上正好有个函数能做到啊!

def toList[A]: Vect[?, A] => List[A] = 
  case `[]`    => Nil
  case x :: xs => x :: toList(xs)  

拿结果类型这么反向推导一下,就把 ??? 的坑给添上了。试试吧,看编译器认不认。

scala> toList(vect3d)
val res1: List[Int] = List(1, 2, 3)

现在,我们回来看 trans 函数。熟悉的配方,矩阵作为一个向量,就是只会有:

def trans[R <: Int, C <: Int, A]: Mat[R, C, A] => Mat[C, R, A] =
  case `[]`    => ???
  case x :: xs => ???

第一种情况为 [],这意味着 R0,此时的 trans 表示为 Mat[0, C, A] => Mat[C, 0, A]。而如何得到 Mat[C, 0, A] 呢? 套一下矩阵的定义,Mat[C, 0, A] 就是一个 C 个维度是 Vect[0, A] 的向量。而 Vect[0, A] 就是 [] 嘛,把它填充到 C 个维度里,不就是 Mat[C, 0, A] 了嘛。好,假定存在这个 fill 函数:

def trans[R <: Int, C <: Int, A]: Mat[R, C, A] => Mat[C, R, A] =
  case `[]`    => fill[C](`[]`)
  case x :: xs => ???

第二种情况稍显复杂,但 toList 的经验告诉我们,用 trans 肯定能从 xs 得到另一个变换好的向量,至于它怎么与 x 合起来,我们暂且假定有个 zip 函数能够帮我们做到。

def trans[R <: Int, C <: Int, A]: Mat[R, C, A] => Mat[C, R, A] =
  case `[]`    => fill[C](`[]`)
  case x :: xs => zip(x, trans(xs))

至于 zip 函数的类型表示,我们来看类型上下文:

R-1 不好表示,那就两边都加一, 变成 RS[R]。 看不懂 S[R] ,那就是刚才默念不到位。

def zip[R <: Int, C <: Int, A]:
  (Vect[C, A], Mat[C, R, A]) => 
    Mat[C, S[R], A] = ???

一顿推导下来,trans 函数被拆分为 fillzip 两个函数,留给我们逐个击破。fill 看着简单了,其实不然,受篇幅限制这里不展开了,感兴趣请告诉我,我另开一篇细说。跳到 zip ,还是老套路,向量只有两种情况:

def zip[R <: Int, C <: Int, A]: (Vect[C, A], Mat[C, R, A]) => Mat[C, S[R], A] =
  case (`[]`, `[]`)       => ???
  case (x :: xs, y :: ys) => ???

诶,不应该是四种情况吗?注意,参数中向量和矩阵都是 C 个维度,只会有两种情况,没错。第一种情况,答案不能再明显了,跳过。第二种情况,有点懵。没事,再看其类型上下文 8

按惯例,另一个 的部分递归解决,即 zip(xs, ys): Mat[C-1, S[R], A]。这与最终的结果 Mat[C, S[R], A] 就差一个 Vect[S[R], A]。这不又巧了,就是 (x :: y): Vect[S[R], A]啊。

def zip[R <: Int, C <: Int, A]: (Vect[C, A], Mat[C, R, A]) => Mat[C, S[R], A] =
  case (`[]`, `[]`)       => `[]`
  case (x :: xs, y :: yx) => 
    (x :: y) :: zip(xs, ys)

脑子没跟上的话,就结合下图进行颅内推演一下吧。

运算推演

完整的代码照例查看脚注 9, 最终的实现效果应该如下,

scala> show(mat)
val res0: String = [
  [1, 2]
  [3, 4]
  [5, 6]
]

scala> show(trans(mat))
val res1: String = [
  [1, 3, 5]
  [2, 4, 6]
]

或许你会想,“这就是个学术示例,工业界不会真有人这么玩类型吧”。工业界是否有人这么做,我确实不知道,但我发现 Mojo 作为一门新晋 AI 编程语言,致力于模型训练推理的编译优化,其官方文档也有类似的示例,引用至此供大家自行判断 10

fn concat[ty: DType, len1: Int, len2: Int](
  lhs: SIMD[ty, len1], rhs: SIMD[ty, len2]
) -> SIMD[ty, len1+len2]:

  var result = SIMD[ty, len1 + len2]()
  for i in range(len1):
    result[i] = SIMD[ty, 1](lhs[i])
  for j in range(len2):
    result[len1 + j] = SIMD[ty, 1](rhs[j])
  return result

var a = SIMD[DType.float32, 2](1, 2)
var x = concat(a, a)

print(
  'result type:', x.element_type, 
  'length:', len(x)
)

何为TDD,又为何?

什么是类型驱动开发?我尝试说下个人浅见:

  1. 基于业务场景,建立数据及其函数的类型表示;
  2. 在类型表示基础上,双向推导,挖坑假设;
  3. 迭代往复上述两个步骤,直到所有坑都被填满。

最后,引用一段电影念白来与君共勉。

I’m trying to free your mind, Neo. But I can only show you the door. You’re the one that has to walk through it.

Morpheus, The Matrix (1999) 11

照旧,若觉得不错,还请不吝分享,这是对我最大的鼓励。


  1. 矩阵变换的示例也借鉴自 「Type-Driven Development with Idris」 一书。↩︎

  2. 想到用 Tuple 表示的,给你点个赞。↩︎

  3. https://en.wikipedia.org/wiki/Algebraic_data_type↩︎

  4. 既然它什么东西都不是,那么它就可以是任何东西。😇↩︎

  5. Cons 是 Construct 的缩写。↩︎

  6. S 是 Succ 的缩写,其含义什么的后续。↩︎

  7. https://www.imdb.com/title/tt0133093/quotes/?item=qt0324296&ref_=ext_shr_lnk↩︎

  8. 不得不承认 Idris 查看类型上下文上是非常直观方便的。↩︎

  9. https://github.com/zhongl/type-driven-development-with-scala3/blob/main/ch03/matrix.worksheet.sc↩︎

  10. https://docs.modular.com/mojo/manual/parameters/#parameter-expressions-are-just-mojo-code↩︎

  11. https://www.imdb.com/title/tt0133093/quotes/?item=qt0324256&ref_=ext_shr_lnk↩︎