使用 MLX 和 Swift 进行设备端 ML 研究
Swift 编程语言在机器学习研究方面具有很大的潜力,因为它结合了像 Python 这样易于使用的高级语法和像 C++ 这样编译型语言的速度。
MLX 是一个用于 Apple 芯片上机器学习研究的数组框架。MLX 旨在用于研究,而不是在应用程序中生产部署模型。
MLX Swift 将 MLX 扩展到 Swift 语言,使 ML 研究人员在 Apple 芯片上进行实验更加容易。
作为此版本的一部分,我们包括
- MLX 核心的全面 Swift API
- 更高级别的神经网络和优化器软件包
- 使用 Mistral 7B 进行文本生成的示例
- MNIST 训练的示例
- MLX 的 C API,它充当 Swift 和 C++ 核心之间的桥梁
我们根据宽松的 MIT 许可证 发布以上所有内容。
这是使 ML 研究人员能够使用 Swift 进行实验的一大步。
动机
MLX 具有多项对于机器学习研究非常重要的功能,而现有的 Swift 库几乎都不支持这些功能。这些功能包括
- 对硬件加速的原生支持。MLX 可以在 CPU 或 GPU 上运行计算密集型操作。
- 用于训练神经网络和基于梯度的机器学习模型的自动微分
有关 MLX 的更多信息,请参阅文档。
Swift 编程语言速度快、易于使用,并且在 Apple 芯片上运行良好。借助 MLX Swift,您现在拥有了一个研究人员友好的机器学习框架,能够轻松地在不同平台和设备上进行实验。
快速浏览
使用 Xcode 或 SwiftPM 设置 MLX Swift 既快速又简单。
在 MLX Swift 中,构建和执行 N 维数组的操作非常简单。在以下示例中,所有操作都将在默认设备上运行,除非另有指定,否则默认设备为 GPU。
import MLX
import MLXRandom
let r = MLXRandom.normal([2])
print(r)
// array([-0.125875, 0.264235], dtype=float32)
let a = MLXArray(0 ..< 6, [3, 2])
print(a)
// array([[0, 1],
// [2, 3],
// [4, 5]], dtype=int32)
// last element of 0th row
print(a[0, -1])
// array(1, dtype=int32)
// slice of the first two rows
print(a[0 ..< 2])
// array([[0, 1],
// [2, 3]], dtype=int32)
// add with broadcast
let b = a + r
print(b)
// array([[-0.510713, 1.04633],
// [1.48929, 3.04633],
// [3.48929, 5.04633]], dtype=float32)
您还可以在 MLX Swift 中使用函数变换。MLX 中的函数变换对于使用自动微分训练模型以及优化计算图以提高速度或内存使用率非常有用。以下是计算函数梯度的示例。
func fn(_ x: MLXArray) -> MLXArray {
x.square()
}
let gradFn = grad(fn)
let x = MLXArray(1.5)
let dfdx = gradFn(x)
// prints 2 * 1.5 = 3
print(dfdx)
文档包含更多完整的示例,以帮助您开始使用 MLX Swift
- 使用 LLM 进行文本生成:一个完整的 LLM 文本生成示例,使用 Mistral 7B。该示例将使用任何 Mistral 或 Llama 风格的模型生成文本,包括预量化的 MLX 模型,其中许多模型可在 Hugging Face 上找到。
- 在 MNIST 上训练 MLP:该示例训练一个简单的多层感知器,以使用 MLX Swift 神经网络和优化器软件包对 MNIST 数字进行分类。
更多资源
以下是一些开始使用 MLX Swift 的更多资源
- Swift 文档和示例
- GitHub 仓库
- 如果您遇到任何问题或对改进有建议,我们鼓励您提交 issue。
- 我们欢迎贡献。如果您有兴趣为 MLX Swift 做出贡献,请查看我们的贡献指南。