diff --git a/Assets/NuGet.config b/Assets/NuGet.config
new file mode 100644
index 00000000..0c083882
--- /dev/null
+++ b/Assets/NuGet.config
@@ -0,0 +1,18 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Assets/NuGet.config.meta b/Assets/NuGet.config.meta
new file mode 100644
index 00000000..3b1ef5e2
--- /dev/null
+++ b/Assets/NuGet.config.meta
@@ -0,0 +1,2 @@
+fileFormatVersion: 2
+guid: c2580c867af34e1f5b32f8608fead674
\ No newline at end of file
diff --git a/Assets/Packages.meta b/Assets/Packages.meta
new file mode 100644
index 00000000..7189c497
--- /dev/null
+++ b/Assets/Packages.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: ab890539a7cfa610893c9abb7c829593
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.0.21.0.meta b/Assets/Packages/LLamaSharp.0.21.0.meta
new file mode 100644
index 00000000..c8b074e1
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.0.21.0.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 636965843ec59c693bf17b881f5faf41
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.0.21.0/.signature.p7s b/Assets/Packages/LLamaSharp.0.21.0/.signature.p7s
new file mode 100644
index 00000000..e3ece9f2
Binary files /dev/null and b/Assets/Packages/LLamaSharp.0.21.0/.signature.p7s differ
diff --git a/Assets/Packages/LLamaSharp.0.21.0/LLamaSharp.nuspec b/Assets/Packages/LLamaSharp.0.21.0/LLamaSharp.nuspec
new file mode 100644
index 00000000..138611b3
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.0.21.0/LLamaSharp.nuspec
@@ -0,0 +1,35 @@
+
+
+
+ LLamaSharp
+ 0.21.0
+ Rinne, Martin Evans, jlsantiago and all the other contributors in https://github.com/SciSharp/LLamaSharp/graphs/contributors.
+ MIT
+ https://licenses.nuget.org/MIT
+ README.md
+ https://avatars3.githubusercontent.com/u/44989469?s=200&v=4
+ LLamaSharp is a cross-platform library to run 🦙LLaMA/LLaVA models (and others) on your local device.
+ Based on [llama.cpp](https://github.com/ggerganov/llama.cpp), inference with LLamaSharp is efficient on both CPU and GPU.
+ With the higher-level APIs and RAG support, it's convenient to deploy LLM (Large Language Model) in your application with LLamaSharp.
+ Updated llama.cpp version to 5783575c9d99c4d9370495800663aa5397ceb0be
+ MIT, SciSharp STACK 2025
+ LLama, LLM, GPT, ChatGPT, NLP, AI, Chat Bot, SciSharp
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Assets/Packages/LLamaSharp.0.21.0/LLamaSharp.nuspec.meta b/Assets/Packages/LLamaSharp.0.21.0/LLamaSharp.nuspec.meta
new file mode 100644
index 00000000..ae8d5753
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.0.21.0/LLamaSharp.nuspec.meta
@@ -0,0 +1,7 @@
+fileFormatVersion: 2
+guid: bfe3f35f644b72f3996f44a59cab61f2
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.0.21.0/README.md b/Assets/Packages/LLamaSharp.0.21.0/README.md
new file mode 100644
index 00000000..e8ac579c
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.0.21.0/README.md
@@ -0,0 +1,269 @@
+
+
+[](https://discord.gg/7wNVU65ZDY)
+[](http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=sN9VVMwbWjs5L0ATpizKKxOcZdEPMrp8&authKey=RLDw41bLTrEyEgZZi%2FzT4pYk%2BwmEFgFcrhs8ZbkiVY7a4JFckzJefaYNW6Lk4yPX&noverify=0&group_code=985366726)
+[](https://www.nuget.org/packages/LLamaSharp)
+[](https://www.nuget.org/packages/LLamaSharp.Backend.Cpu)
+[](https://www.nuget.org/packages/LLamaSharp.Backend.Cuda11)
+[](https://www.nuget.org/packages/LLamaSharp.Backend.Cuda12)
+[](https://www.nuget.org/packages/LLamaSharp.semantic-kernel)
+[](https://www.nuget.org/packages/LLamaSharp.kernel-memory)
+[](https://www.nuget.org/packages/LLamaSharp.Backend.Vulkan)
+
+
+**LLamaSharp is a cross-platform library to run 🦙LLaMA/LLaVA models (and others) on your local device. Based on [llama.cpp](https://github.com/ggerganov/llama.cpp), inference with LLamaSharp is efficient on both CPU and GPU. With the higher-level APIs and RAG support, it's convenient to deploy LLMs (Large Language Models) in your application with LLamaSharp.**
+
+**Please star the repo to show your support for this project!🤗**
+
+---
+
+
+
+
+
+
+## 📖Documentation
+
+- [Quick start](https://scisharp.github.io/LLamaSharp/latest/QuickStart/)
+- [FAQ](https://scisharp.github.io/LLamaSharp/latest/FAQ/)
+- [Tutorial](https://scisharp.github.io/LLamaSharp/latest/Tutorials/NativeLibraryConfig/)
+- [Full documentation](https://scisharp.github.io/LLamaSharp/latest/)
+- [API reference](https://scisharp.github.io/LLamaSharp/latest/xmldocs/)
+
+
+## 📌Console Demo
+
+
+
+*Console demo animations: LLaMA and LLaVA.*
+
+
+
+
+## 🔗Integrations & Examples
+
+There are integrations for the following libraries, making it easier to develop your application. Integrations for semantic-kernel and kernel-memory are developed in the LLamaSharp repository, while others are developed in their own repositories.
+
+- [semantic-kernel](https://github.com/microsoft/semantic-kernel): an SDK that integrates LLMs like OpenAI, Azure OpenAI, and Hugging Face.
+- [kernel-memory](https://github.com/microsoft/kernel-memory): a multi-modal AI Service specialized in the efficient indexing of datasets through custom continuous data hybrid pipelines, with support for RAG ([Retrieval Augmented Generation](https://en.wikipedia.org/wiki/Prompt_engineering#Retrieval-augmented_generation)), synthetic memory, prompt engineering, and custom semantic memory processing.
+- [BotSharp](https://github.com/SciSharp/BotSharp): an open-source machine learning framework for building AI bot platforms.
+- [Langchain](https://github.com/tryAGI/LangChain): a framework for developing applications powered by language models.
+
+
+The following examples show how to build applications with LLamaSharp.
+
+- [Official Console Examples](./LLama.Examples/)
+- [Unity Demo](https://github.com/eublefar/LLAMASharpUnityDemo)
+- [LLamaStack (with WPF and Web demo)](https://github.com/saddam213/LLamaStack)
+- [Blazor Demo (with Model Explorer)](https://github.com/alexhiggins732/BLlamaSharp.ChatGpt.Blazor)
+- [ASP.NET Demo](./LLama.Web/)
+- [LLamaWorker (ASP.NET Web API like OAI and Function Calling Support)](https://github.com/sangyuxiaowu/LLamaWorker)
+- [VirtualPet (Desktop Application)](https://github.com/AcoranGonzalezMoray/VirtualPet-WindowsEdition)
+
+
+
+
+## 🚀Get started
+
+### Installation
+
+For high performance, LLamaSharp interacts with native libraries compiled from C++; these are called `backends`. We provide backend packages for Windows, Linux, and macOS with CPU, CUDA, Metal, and Vulkan support. You **don't** need to compile any C++ code; just install the backend packages.
+
+If no published backend matches your device, please open an issue to let us know. If compiling C++ code is not difficult for you, you could also follow [this guide](./docs/ContributingGuide.md) to compile a backend and run LLamaSharp with it.
+
+1. Install [LLamaSharp](https://www.nuget.org/packages/LLamaSharp) package on NuGet:
+
+```
+PM> Install-Package LLamaSharp
+```
+
+2. Install one or more of these backends, or use a self-compiled backend.
+
+ - [`LLamaSharp.Backend.Cpu`](https://www.nuget.org/packages/LLamaSharp.Backend.Cpu): Pure CPU for Windows, Linux & Mac. Metal (GPU) support for Mac.
+ - [`LLamaSharp.Backend.Cuda11`](https://www.nuget.org/packages/LLamaSharp.Backend.Cuda11): CUDA 11 for Windows & Linux.
+ - [`LLamaSharp.Backend.Cuda12`](https://www.nuget.org/packages/LLamaSharp.Backend.Cuda12): CUDA 12 for Windows & Linux.
+ - [`LLamaSharp.Backend.Vulkan`](https://www.nuget.org/packages/LLamaSharp.Backend.Vulkan): Vulkan for Windows & Linux.
+
+3. (optional) For [Microsoft semantic-kernel](https://github.com/microsoft/semantic-kernel) integration, install the [LLamaSharp.semantic-kernel](https://www.nuget.org/packages/LLamaSharp.semantic-kernel) package.
+4. (optional) To enable RAG support, install the [LLamaSharp.kernel-memory](https://www.nuget.org/packages/LLamaSharp.kernel-memory) package (this package currently requires `net6.0` or higher), which is based on the [Microsoft kernel-memory](https://github.com/microsoft/kernel-memory) integration. A rough sketch of this setup follows below.
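+
+In the sketch below, the `LLamaSharpConfig` type, the `WithLLamaSharpDefaults` extension, and the namespace names are assumptions based on the LLamaSharp.kernel-memory package; check its documentation for the exact API in your version.
+
+```cs
+using Microsoft.KernelMemory;
+using LLamaSharp.KernelMemory;
+
+// Placeholder model path; replace it with your own GGUF file.
+var config = new LLamaSharpConfig(@"<path-to-gguf>");
+
+// Build a kernel-memory instance that uses LLamaSharp for generation and embeddings.
+var memory = new KernelMemoryBuilder()
+    .WithLLamaSharpDefaults(config)
+    .Build();
+
+// Index some text, then ask a question grounded in it (RAG).
+await memory.ImportTextAsync("LLamaSharp runs GGUF models locally on CPU or GPU.");
+var answer = await memory.AskAsync("What does LLamaSharp run?");
+Console.WriteLine(answer.Result);
+```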
+
+### Model preparation
+
+There are two popular formats for LLM model files: the PyTorch format (.pth) and the Hugging Face format (.bin). LLamaSharp uses `GGUF` format files, which can be converted from these two formats. To get a `GGUF` file, there are two options:
+
+1. Search for the model name + 'gguf' on [Hugging Face](https://huggingface.co); you will find many model files that have already been converted to GGUF format. Please take note of their publishing time, because some older files may only work with older versions of LLamaSharp.
+
+2. Convert the PyTorch or Hugging Face format to GGUF format yourself. Please follow the instructions in [this part of the llama.cpp readme](https://github.com/ggerganov/llama.cpp?tab=readme-ov-file#prepare-and-quantize) to convert them with Python scripts.
+
+Generally, we recommend downloading models with quantization rather than fp16, because it significantly reduces the required memory size while only slightly impacting the generation quality.
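+
+If you only have an fp16 GGUF file, you can also quantize it locally with LLamaSharp. The minimal sketch below assumes the `LLamaQuantizer.Quantize` overload that accepts an ftype string (such as "Q4_K_M"); the file paths are placeholders.
+
+```cs
+using LLama;
+
+// Placeholder paths; replace them with your own files.
+string srcPath = @"model-f16.gguf";
+string dstPath = @"model-q4_k_m.gguf";
+
+// Quantize the fp16 weights down to Q4_K_M. Returns true on success.
+bool ok = LLamaQuantizer.Quantize(srcPath, dstPath, "Q4_K_M");
+Console.WriteLine(ok ? "Quantization finished." : "Quantization failed.");
+```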
+
+
+### Example of LLaMA chat session
+
+Here is a simple example of chatting with a bot based on an LLM in LLamaSharp. Please replace the model path with your own.
+
+```cs
+using LLama.Common;
+using LLama.Sampling; // DefaultSamplingPipeline lives in this namespace
+using LLama;
+
+string modelPath = @""; // change it to your own model path.
+
+var parameters = new ModelParams(modelPath)
+{
+    ContextSize = 1024, // The maximum length of the chat context, in tokens.
+ GpuLayerCount = 5 // How many layers to offload to GPU. Please adjust it according to your GPU memory.
+};
+using var model = LLamaWeights.LoadFromFile(parameters);
+using var context = model.CreateContext(parameters);
+var executor = new InteractiveExecutor(context);
+
+// Add chat histories as prompt to tell AI how to act.
+var chatHistory = new ChatHistory();
+chatHistory.AddMessage(AuthorRole.System, "Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.");
+chatHistory.AddMessage(AuthorRole.User, "Hello, Bob.");
+chatHistory.AddMessage(AuthorRole.Assistant, "Hello. How may I help you today?");
+
+ChatSession session = new(executor, chatHistory);
+
+InferenceParams inferenceParams = new InferenceParams()
+{
+ MaxTokens = 256, // No more than 256 tokens should appear in answer. Remove it if antiprompt is enough for control.
+    AntiPrompts = new List<string> { "User:" }, // Stop generation once antiprompts appear.
+
+ SamplingPipeline = new DefaultSamplingPipeline(),
+};
+
+Console.ForegroundColor = ConsoleColor.Yellow;
+Console.Write("The chat session has started.\nUser: ");
+Console.ForegroundColor = ConsoleColor.Green;
+string userInput = Console.ReadLine() ?? "";
+
+while (userInput != "exit")
+{
+    await foreach ( // Generate the response as a stream.
+ var text
+ in session.ChatAsync(
+ new ChatHistory.Message(AuthorRole.User, userInput),
+ inferenceParams))
+ {
+ Console.ForegroundColor = ConsoleColor.White;
+ Console.Write(text);
+ }
+ Console.ForegroundColor = ConsoleColor.Green;
+ userInput = Console.ReadLine() ?? "";
+}
+```
+
+For more examples, please refer to [LLamaSharp.Examples](./LLama.Examples).
+
+
+## 💡FAQ
+
+#### Why is my GPU not used when I have installed CUDA?
+
+1. If you are using backend packages, please make sure you have installed the CUDA backend package which matches the CUDA version installed on your system.
+2. Add the following line to the very beginning of your code. The log will show which native library file is loaded. If the CPU library is loaded, please try to compile the native library yourself and open an issue about it. If the CUDA library is loaded, please check that `GpuLayerCount > 0` when loading the model weights.
+
+```cs
+    NativeLibraryConfig.Instance.WithLogCallback(delegate (LLamaLogLevel level, string message) { Console.Write($"{level}: {message}"); });
+```
+
+
+#### Why is the inference so slow?
+
+Firstly, due to their large size, LLMs require more time to generate output than other models, especially when you are using models with more than 30B parameters.
+
+To see if that's a LLamaSharp performance issue, please follow the two tips below.
+
+1. If you are using CUDA, Metal or Vulkan, please set `GpuLayerCount` as large as possible (see the sketch below).
+2. If it's still slower than you expect, please try to run the same model with the same settings in the [llama.cpp examples](https://github.com/ggerganov/llama.cpp/tree/master/examples). If llama.cpp significantly outperforms LLamaSharp, it's likely a LLamaSharp bug; please report it to us.
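+
+A minimal sketch of tip 1, assuming a CUDA, Metal or Vulkan backend is installed (a `GpuLayerCount` larger than the model's layer count simply offloads all layers):
+
+```cs
+using LLama;
+using LLama.Common;
+
+var parameters = new ModelParams(@"<path-to-gguf>") // placeholder path
+{
+    GpuLayerCount = 999 // offload as many layers as possible; lower this if you run out of VRAM
+};
+using var model = LLamaWeights.LoadFromFile(parameters);
+```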
+
+
+#### Why does the program crash before any output is generated?
+
+Generally, there are two possible cases for this problem:
+
+1. The native library (backend) you are using is not compatible with the LLamaSharp version. If you compiled the native library yourself, please make sure you have checked out llama.cpp at the commit corresponding to your LLamaSharp version, which can be found at the bottom of this README.
+2. The model file you are using is not compatible with the backend. If you are using a GGUF file downloaded from Hugging Face, please check its publishing time.
+
+#### Why is my model generating output infinitely?
+
+Please set an anti-prompt or a maximum token count when executing the inference, as in the sketch below.
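+
+For example, a minimal sketch using the same `InferenceParams` type as in the chat session example above:
+
+```cs
+using LLama.Common;
+
+var inferenceParams = new InferenceParams
+{
+    MaxTokens = 256,                            // hard cap on the number of generated tokens
+    AntiPrompts = new List<string> { "User:" }  // stop as soon as the model starts a new "User:" turn
+};
+```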
+
+
+## 🙌Contributing
+
+All contributions are welcome! There's a TODO list in [LLamaSharp Dev Project](https://github.com/orgs/SciSharp/projects/5) and you can pick an interesting one to start. Please read the [contributing guide](./CONTRIBUTING.md) for more information.
+
+You can also do one of the following to help us make LLamaSharp better:
+
+- Submit a feature request.
+- Star and share LLamaSharp to let others know about it.
+- Write a blog or demo about LLamaSharp.
+- Help to develop Web API and UI integration.
+- Just open an issue about the problem you've found!
+
+## Join the community
+
+Join our chat on [Discord](https://discord.gg/7wNVU65ZDY) (please contact Rinne to join the dev channel if you want to be a contributor).
+
+Join [QQ group](http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=sN9VVMwbWjs5L0ATpizKKxOcZdEPMrp8&authKey=RLDw41bLTrEyEgZZi%2FzT4pYk%2BwmEFgFcrhs8ZbkiVY7a4JFckzJefaYNW6Lk4yPX&noverify=0&group_code=985366726)
+
+## Star history
+
+[](https://star-history.com/#SciSharp/LLamaSharp&Date)
+
+## Contributor wall of fame
+
+[](https://github.com/SciSharp/LLamaSharp/graphs/contributors)
+
+## Map of LLamaSharp and llama.cpp versions
+If you want to compile llama.cpp yourself, you **must** use the exact commit ID listed for each version.
+
+| LLamaSharp | Verified Model Resources | llama.cpp commit id |
+| - | -- | - |
+| v0.2.0 | This version is not recommended for use. | - |
+| v0.2.1 | [WizardLM](https://huggingface.co/TheBloke/wizardLM-7B-GGML/tree/previous_llama), [Vicuna (filenames with "old")](https://huggingface.co/eachadea/ggml-vicuna-13b-1.1/tree/main) | - |
+| v0.2.2, v0.2.3 | [WizardLM](https://huggingface.co/TheBloke/wizardLM-7B-GGML/tree/previous_llama_ggmlv2), [Vicuna (filenames without "old")](https://huggingface.co/eachadea/ggml-vicuna-13b-1.1/tree/main) | `63d2046` |
+| v0.3.0, v0.4.0 | [LLamaSharpSamples v0.3.0](https://huggingface.co/AsakusaRinne/LLamaSharpSamples/tree/v0.3.0), [WizardLM](https://huggingface.co/TheBloke/wizardLM-7B-GGML/tree/main) | `7e4ea5b` |
+| v0.4.1-preview | [Open llama 3b](https://huggingface.co/SlyEcho/open_llama_3b_ggml), [Open Buddy](https://huggingface.co/OpenBuddy/openbuddy-llama-ggml)| `aacdbd4` |
+|v0.4.2-preview | [Llama2 7B (GGML)](https://huggingface.co/TheBloke/llama-2-7B-Guanaco-QLoRA-GGML)| `3323112` |
+| v0.5.1 | [Llama2 7B (GGUF)](https://huggingface.co/TheBloke/llama-2-7B-Guanaco-QLoRA-GGUF)| `6b73ef1` |
+| v0.6.0 | | [`cb33f43`](https://github.com/ggerganov/llama.cpp/commit/cb33f43a2a9f5a5a5f8d290dd97c625d9ba97a2f) |
+| v0.7.0, v0.8.0 | [Thespis-13B](https://huggingface.co/TheBloke/Thespis-13B-v0.5-GGUF/tree/main?not-for-all-audiences=true), [LLaMA2-7B](https://huggingface.co/TheBloke/llama-2-7B-Guanaco-QLoRA-GGUF) | [`207b519`](https://github.com/ggerganov/llama.cpp/commit/207b51900e15cc7f89763a3bb1c565fe11cbb45d) |
+| v0.8.1 | | [`e937066`](https://github.com/ggerganov/llama.cpp/commit/e937066420b79a757bf80e9836eb12b88420a218) |
+| v0.9.0, v0.9.1 | [Mixtral-8x7B](https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF) | [`9fb13f9`](https://github.com/ggerganov/llama.cpp/blob/9fb13f95840c722ad419f390dc8a9c86080a3700) |
+| v0.10.0 | [Phi2](https://huggingface.co/TheBloke/phi-2-GGUF) | [`d71ac90`](https://github.com/ggerganov/llama.cpp/tree/d71ac90985854b0905e1abba778e407e17f9f887) |
+| v0.11.1, v0.11.2 | [LLaVA-v1.5](https://hf-mirror.com/jartine/llava-v1.5-7B-GGUF/blob/main/llava-v1.5-7b-mmproj-Q4_0.gguf), [Phi2](https://huggingface.co/TheBloke/phi-2-GGUF)| [`3ab8b3a`](https://github.com/ggerganov/llama.cpp/tree/3ab8b3a92ede46df88bc5a2dfca3777de4a2b2b6) |
+| v0.12.0 | LLama3 | [`a743d76`](https://github.com/ggerganov/llama.cpp/tree/a743d76a01f23038b2c85af1e9048ee836767b44) |
+| v0.13.0 | | [`1debe72`](https://github.com/ggerganov/llama.cpp/tree/1debe72737ea131cb52975da3d53ed3a835df3a6) |
+| v0.14.0 | Gemma2 | [`36864569`](https://github.com/ggerganov/llama.cpp/tree/368645698ab648e390dcd7c00a2bf60efa654f57) |
+| v0.15.0 | LLama3.1 | [`345c8c0c`](https://github.com/ggerganov/llama.cpp/tree/345c8c0c87a97c1595f9c8b14833d531c8c7d8df) |
+| v0.16.0 | | [`11b84eb4`](https://github.com/ggerganov/llama.cpp/tree/11b84eb4578864827afcf956db5b571003f18180) |
+| v0.17.0 | | [`c35e586e`](https://github.com/ggerganov/llama.cpp/tree/c35e586ea57221844442c65a1172498c54971cb0) |
+| v0.18.0 | | [`c35e586e`](https://github.com/ggerganov/llama.cpp/tree/c35e586ea57221844442c65a1172498c54971cb0) |
+| v0.19.0 | | [`958367bf`](https://github.com/ggerganov/llama.cpp/tree/958367bf530d943a902afa1ce1c342476098576b) |
+| v0.20.0 | | [`0827b2c1`](https://github.com/ggerganov/llama.cpp/tree/0827b2c1da299805288abbd556d869318f2b121e) |
+| v0.21.0 | | [`5783575c`](https://github.com/ggerganov/llama.cpp/tree/5783575c9d99c4d9370495800663aa5397ceb0be) |
+
+## License
+
+This project is licensed under the terms of the MIT license.
+
diff --git a/Assets/Packages/LLamaSharp.0.21.0/README.md.meta b/Assets/Packages/LLamaSharp.0.21.0/README.md.meta
new file mode 100644
index 00000000..5c6c6e6d
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.0.21.0/README.md.meta
@@ -0,0 +1,7 @@
+fileFormatVersion: 2
+guid: 63192d32e849551238a5cdc53236686c
+TextScriptImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.0.21.0/lib.meta b/Assets/Packages/LLamaSharp.0.21.0/lib.meta
new file mode 100644
index 00000000..5f07e65d
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.0.21.0/lib.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: e7631e8e5c1431433be8d6402f6abf7c
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.0.21.0/lib/netstandard2.0.meta b/Assets/Packages/LLamaSharp.0.21.0/lib/netstandard2.0.meta
new file mode 100644
index 00000000..62ce0503
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.0.21.0/lib/netstandard2.0.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 5ba284620da5bf88aa74980eff1af5af
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.0.21.0/lib/netstandard2.0/LLamaSharp.dll b/Assets/Packages/LLamaSharp.0.21.0/lib/netstandard2.0/LLamaSharp.dll
new file mode 100644
index 00000000..d9be0591
Binary files /dev/null and b/Assets/Packages/LLamaSharp.0.21.0/lib/netstandard2.0/LLamaSharp.dll differ
diff --git a/Assets/Packages/LLamaSharp.0.21.0/lib/netstandard2.0/LLamaSharp.dll.meta b/Assets/Packages/LLamaSharp.0.21.0/lib/netstandard2.0/LLamaSharp.dll.meta
new file mode 100644
index 00000000..08bfa91b
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.0.21.0/lib/netstandard2.0/LLamaSharp.dll.meta
@@ -0,0 +1,29 @@
+fileFormatVersion: 2
+guid: 381252d9d25ca5fcbbb704565fe77126
+labels:
+- NuGetForUnity
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 3
+ iconMap: {}
+ executionOrder: {}
+ defineConstraints: []
+ isPreloaded: 0
+ isOverridable: 0
+ isExplicitlyReferenced: 0
+ validateReferences: 1
+ platformData:
+ Any:
+ enabled: 1
+ settings: {}
+ Editor:
+ enabled: 0
+ settings:
+ DefaultValueInitialized: true
+ WindowsStoreApps:
+ enabled: 0
+ settings:
+ CPU: AnyCPU
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.0.21.0/lib/netstandard2.0/LLamaSharp.xml b/Assets/Packages/LLamaSharp.0.21.0/lib/netstandard2.0/LLamaSharp.xml
new file mode 100644
index 00000000..e21e24da
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.0.21.0/lib/netstandard2.0/LLamaSharp.xml
@@ -0,0 +1,7219 @@
+
+
+
+ LLamaSharp
+
+
+
+
+ Reserved to be used by the compiler for tracking metadata.
+ This class should not be used by developers in source code.
+
+
+ This definition is provided by the IsExternalInit NuGet package (https://www.nuget.org/packages/IsExternalInit).
+ Please see https://github.com/manuelroemer/IsExternalInit for more information.
+
+
+
+
+ The parameters for initializing a LLama context from a model.
+
+
+
+
+ Model context size (n_ctx)
+
+
+
+
+ maximum batch size that can be submitted at once (must be >=32 to use BLAS) (n_batch)
+
+
+
+
+ Physical batch size
+
+
+
+
+ max number of sequences (i.e. distinct states for recurrent models)
+
+
+
+
+ If true, extract embeddings (together with logits).
+
+
+
+
+ RoPE base frequency (null to fetch from the model)
+
+
+
+
+ RoPE frequency scaling factor (null to fetch from the model)
+
+
+
+
+ The encoding to use for models
+
+
+
+
+ Number of threads (null = autodetect) (n_threads)
+
+
+
+
+ Number of threads to use for batch processing (null = autodetect) (n_threads)
+
+
+
+
+ YaRN extrapolation mix factor (null = from model)
+
+
+
+
+ YaRN magnitude scaling factor (null = from model)
+
+
+
+
+ YaRN low correction dim (null = from model)
+
+
+
+
+ YaRN high correction dim (null = from model)
+
+
+
+
+ YaRN original context length (null = from model)
+
+
+
+
+ YaRN scaling method to use.
+
+
+
+
+ Override the type of the K cache
+
+
+
+
+ Override the type of the V cache
+
+
+
+
+ Whether to disable offloading the KQV cache to the GPU
+
+
+
+
+ Whether to use flash attention
+
+
+
+
+ defragment the KV cache if holes/size > defrag_threshold. Set to null or < 0 to disable (default)
+
+
+
+
+ How to pool (sum) embedding results by sequence id (ignored if no pooling layer)
+
+
+
+
+ Attention type to use for embeddings
+
+
+
+
+ Transform history to plain text and vice versa.
+
+
+
+
+ Convert a ChatHistory instance to plain text.
+
+ The ChatHistory instance
+
+
+
+
+ Converts plain text to a ChatHistory instance.
+
+ The role for the author.
+ The chat history as plain text.
+ The updated history.
+
+
+
+ Copy the transform.
+
+
+
+
+
+ The parameters used for inference.
+
+
+
+
+ number of tokens to keep from initial prompt
+
+
+
+
+ how many new tokens to predict (n_predict); set to -1 to generate indefinitely until the response completes.
+
+
+
+
+ Sequences where the model will stop generating further tokens.
+
+
+
+
+ Set a custom sampling pipeline to use.
+
+
+
+
+ A high level interface for LLama models.
+
+
+
+
+ The loaded context for this executor.
+
+
+
+
+ Identify if it's a multi-modal model and there is a image to process.
+
+
+
+
+ Multi-Modal Projections / Clip Model weights
+
+
+
+
+ List of images: List of images in byte array format.
+
+
+
+
+ Asynchronously infers a response from the model.
+
+ Your prompt
+ Any additional parameters
+ A cancellation token.
+
+
+
+
+ Convenience interface for implementing both type of parameters.
+
+ Mostly exists for backwards compatibility reasons, when these two were not split.
+
+
+
+ The parameters for initializing a LLama model.
+
+
+
+
+ main_gpu interpretation depends on split_mode:
+
+ -
+ None
+ The GPU that is used for the entire model.
+
+ -
+ Row
+ The GPU that is used for small tensors and intermediate results.
+
+ -
+ Layer
+ Ignored.
+
+
+
+
+
+
+ How to split the model across multiple GPUs
+
+
+
+
+ Number of layers to run in VRAM / GPU memory (n_gpu_layers)
+
+
+
+
+ Use mmap for faster loads (use_mmap)
+
+
+
+
+ Use mlock to keep model in memory (use_mlock)
+
+
+
+
+ Model path (model)
+
+
+
+
+ how split tensors should be distributed across GPUs
+
+
+
+
+ Load vocab only (no weights)
+
+
+
+
+ Validate model tensor data before loading
+
+
+
+
+ Override specific metadata items in the model
+
+
+
+
+ A fixed size array to set the tensor splits across multiple GPUs
+
+
+
+
+ The size of this array
+
+
+
+
+ Get or set the proportion of work to do on the given device.
+
+ "[ 3, 2 ]" will assign 60% of the data to GPU 0 and 40% to GPU 1.
+
+
+
+
+
+ Create a new tensor splits collection, copying the given values
+
+
+
+
+
+
+ Create a new tensor splits collection with all values initialised to the default
+
+
+
+
+ Set all values to zero
+
+
+
+
+
+
+
+
+
+
+ A JSON converter for
+
+
+
+
+
+
+
+
+
+
+ An override for a single key/value pair in model metadata
+
+
+
+
+ Get the key being overridden by this override
+
+
+
+
+ Create a new override for an int key
+
+
+
+
+
+
+ Create a new override for a float key
+
+
+
+
+
+
+ Create a new override for a boolean key
+
+
+
+
+
+
+ Create a new override for a string key
+
+
+
+
+
+
+ A JSON converter for
+
+
+
+
+
+
+
+
+
+
+ Descriptor of a native library.
+
+
+
+
+ Metadata of this library.
+
+
+
+
+ Prepares the native library file and returns its local path.
+ If it's a relative path, LLamaSharp will search for it in the search directories you set.
+
+ The system information of the current machine.
+ The log callback.
+
+ The relative paths of the library. You could return multiple paths to try them one by one. If no file is available, please return an empty array.
+
+
+
+
+ Takes a stream of tokens and transforms them.
+
+
+
+
+ Takes a stream of tokens and transforms them, returning a new stream of tokens asynchronously.
+
+
+
+
+
+
+ Copy the transform.
+
+
+
+
+
+ An interface for text transformations.
+ These can be used to compose a pipeline of text transformations, such as:
+ - Tokenization
+ - Lowercasing
+ - Punctuation removal
+ - Trimming
+ - etc.
+
+
+
+
+ Takes a string and transforms it.
+
+
+
+
+
+
+ Copy the transform.
+
+
+
+
+
+ Extension methods to the interface.
+
+
+
+ Gets an instance for the specified .
+ The executor.
+ The to use to transform an input list messages into a prompt.
+ The to use to transform the output into text.
+ An instance for the provided .
+ is null.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Format the chat messages into a string prompt.
+
+
+ Convert the chat options to inference parameters.
+
+
+ A default transform that appends "Assistant: " to the end.
+
+
+
+ AntipromptProcessor keeps track of past tokens looking for any set Anti-Prompts
+
+
+
+
+ Initializes a new instance of the class.
+
+ The antiprompts.
+
+
+
+ Add an antiprompt to the collection
+
+
+
+
+
+ Overwrite all current antiprompts with a new set
+
+
+
+
+
+ Add some text and check if the buffer now ends with any antiprompt
+
+
+ true if the text buffer ends with any antiprompt
+
+
+
+ A batched executor that can infer multiple separate "conversations" simultaneously.
+
+
+
+
+ Set to 1 using interlocked exchange while inference is running
+
+
+
+
+ Epoch is incremented twice every time Infer is called. Conversations can use this to keep track of
+ whether they're waiting for inference, or can be sampled.
+
+
+
+
+ The this executor is using
+
+
+
+
+ The this executor is using
+
+
+
+
+ Get the number of tokens in the batch, waiting for Infer() to be called
+
+
+
+
+ Number of batches in the queue, waiting for Infer() to be called
+
+
+
+
+ Check if this executor has been disposed.
+
+
+
+
+ Create a new batched executor
+
+ The model to use
+ Parameters to create a new context
+
+
+
+ Start a new
+
+
+
+
+
+ Load a conversation that was previously saved to a file. Once loaded the conversation will
+ need to be prompted.
+
+
+
+
+
+
+
+ Load a conversation that was previously saved into memory. Once loaded the conversation will need to be prompted.
+
+
+
+
+
+
+
+ Run inference for all conversations in the batch which have pending tokens.
+
+ If the result is `NoKvSlot` then there is not enough memory for inference, try disposing some conversation
+ threads and running inference again.
+
+
+
+
+
+
+
+ Get a reference to a batch that tokens can be added to.
+
+
+
+
+
+
+
+ Get a reference to a batch that embeddings can be added to.
+
+
+
+
+
+
+
+ A single conversation thread that can be prompted (adding tokens from the user) or inferred (extracting a token from the LLM)
+
+
+
+
+ Indicates if this conversation has been "forked" and may share logits with another conversation.
+
+
+
+
+ Stores the indices to sample from. Contains valid items.
+
+
+
+
+ The executor which this conversation belongs to
+
+
+
+
+ Unique ID for this conversation
+
+
+
+
+ Total number of tokens in this conversation, cannot exceed the context length.
+
+
+
+
+ Indicates if this conversation has been disposed, nothing can be done with a disposed conversation
+
+
+
+
+ Indicates if this conversation is waiting for inference to be run on the executor. "Prompt" and "Sample" cannot be called when this is true.
+
+
+
+
+ Indicates that this conversation should be sampled.
+
+
+
+
+ Finalizer for Conversation
+
+
+
+
+ End this conversation, freeing all resources used by it
+
+
+
+
+
+ Create a copy of the current conversation
+
+ The copy shares internal state, so consumes very little extra memory.
+
+
+
+
+
+ Get the index in the context which each token can be sampled from; the return value of this function can be used to retrieve logits
+ or to sample a token.
+
+ How far from the end of the previous prompt should logits be sampled. Any value other than 0 requires
+ allLogits to have been set during prompting.
+ For example if 5 tokens were supplied in the last prompt call:
+
+ - The logits of the first token can be accessed with 4
+ - The logits of the second token can be accessed with 3
+ - The logits of the third token can be accessed with 2
+ - The logits of the fourth token can be accessed with 1
+ - The logits of the fifth token can be accessed with 0
+
+
+
+
+ Thrown if this conversation was not prompted before the previous call to infer
+ Thrown if Infer() must be called on the executor
+
+
+
+ Get the logits from this conversation, ready for sampling
+
+ How far from the end of the previous prompt should logits be sampled. Any value other than 0 requires allLogits to have been set during prompting
+
+
+ Thrown if this conversation was not prompted before the previous call to infer
+ Thrown if Infer() must be called on the executor
+
+
+
+ Add tokens to this conversation
+
+
+ If true, generate logits for all tokens. If false, only generate logits for the last token.
+
+
+
+
+
+
+ Add tokens to this conversation
+
+
+ If true, generate logits for all tokens. If false, only generate logits for the last token.
+
+
+
+
+
+
+ Add a single token to this conversation
+
+
+
+
+
+
+
+
+ Prompt this conversation with an image embedding
+
+
+
+
+
+ Prompt this conversation with embeddings
+
+ The raw values of the embeddings. This span must divide equally by the embedding size of this model.
+
+
+
+ Directly modify the KV cache of this conversation
+
+
+ Thrown if this method is called while == true
+
+
+
+ Provides direct access to the KV cache of a .
+ See for how to use this.
+
+
+
+
+ Removes all tokens that have positions in [start, end)
+
+ Start position (inclusive)
+ End position (exclusive)
+
+
+
+ Removes all tokens starting from the given position
+
+ Start position (inclusive)
+ Number of tokens
+
+
+
+ Adds relative position "delta" to all tokens that have positions in [p0, p1).
+ If the KV cache is RoPEd, the KV data is updated
+ accordingly
+
+ Start position (inclusive)
+ End position (exclusive)
+ Amount to add on to each token position
+
+
+
+ Integer division of the positions by factor of `d > 1`.
+ If the KV cache is RoPEd, the KV data is updated accordingly.
+
+ Start position (inclusive). If less than zero, it is clamped to zero.
+ End position (exclusive). If less than zero, it is treated as "infinity".
+ Amount to divide each position by.
+
+
+
+ A function which can temporarily access the KV cache of a to modify it directly
+
+ The current end token of this conversation
+ An which allows direct access to modify the KV cache
+ The new end token position
+
+
+
+ Save the complete state of this conversation to a file. if the file already exists it will be overwritten.
+
+
+
+
+
+
+ Save the complete state of this conversation in system memory.
+
+
+
+
+
+ Load state from a file
+ This should only ever be called by the BatchedExecutor, on a newly created conversation object!
+
+
+
+
+
+
+ Load state from a previously saved state.
+ This should only ever be called by the BatchedExecutor, on a newly created conversation object!
+
+
+
+
+
+
+
+
+ In memory saved state of a
+
+
+
+
+ Indicates if this state has been disposed
+
+
+
+
+ Get the size in bytes of this state object
+
+
+
+
+
+
+
+ Internal constructor prevent anyone outside of LLamaSharp extending this class
+
+
+
+
+ Extension method for
+
+
+
+
+ Sample a token from this conversation using the given sampler chain
+
+ to sample from
+
+ Offset from the end of the conversation to the logits to sample, see for more details
+
+
+
+
+ Sample a token from this conversation using the given sampling pipeline
+
+ to sample from
+
+ Offset from the end of the conversation to the logits to sample, see for more details
+
+
+
+
+ Rewind a back to an earlier state by removing tokens from the end
+
+ The conversation to rewind
+ The number of tokens to rewind
+ Thrown if `tokens` parameter is larger than TokenCount
+
+
+
+ Shift all tokens over to the left, removing "count" tokens from the start and shifting everything over.
+ Leaves "keep" tokens at the start completely untouched. This can be used to free up space when the context
+ gets full, keeping the prompt at the start intact.
+
+ The conversation to rewind
+ How much to shift tokens over by
+ The number of tokens at the start which should not be shifted
+
+
+
+ Base class for exceptions thrown from
+
+
+
+
+ This exception is thrown when "Prompt()" is called on a which has
+ already been prompted and before "Infer()" has been called on the associated
+ .
+
+
+
+
+ This exception is thrown when "Sample()" is called on a which has
+ already been prompted and before "Infer()" has been called on the associated
+ .
+
+
+
+
+ This exception is thrown when "Sample()" is called on a which was not
+ first prompted.
+ .
+
+
+
+
+ This exception is thrown when is called when = true
+
+
+
+
+ This exception is thrown when "Save()" is called on a which has
+ already been prompted and before "Infer()" has been called.
+ .
+
+
+
+
+ Save the state of a particular sequence to specified path. Also save some extra data which will be returned when loading.
+ Data saved with this method must be saved with
+
+
+
+
+
+
+
+
+ Load the state from the specified path into a particular sequence. Also reading header data. Must only be used with
+ data previously saved with
+
+
+
+
+
+
+
+
+
+ The main chat session class.
+
+
+
+
+ The filename for the serialized model state (KV cache, etc).
+
+
+
+
+ The filename for the serialized executor state.
+
+
+
+
+ The filename for the serialized chat history.
+
+
+
+
+ The filename for the serialized input transform pipeline.
+
+
+
+
+ The filename for the serialized output transform.
+
+
+
+
+ The filename for the serialized history transform.
+
+
+
+
+ The executor for this session.
+
+
+
+
+ The chat history for this session.
+
+
+
+
+ The history transform used in this session.
+
+
+
+
+ The input transform pipeline used in this session.
+
+
+
+
+ The output transform used in this session.
+
+
+
+
+ Create a new chat session and preprocess history.
+
+ The executor for this session
+ History for this session
+ History Transform for this session
+ A new chat session.
+
+
+
+ Create a new chat session.
+
+ The executor for this session
+
+
+
+ Create a new chat session with a custom history.
+
+
+
+
+
+
+ Use a custom history transform.
+
+
+
+
+
+
+ Add a text transform to the input transform pipeline.
+
+
+
+
+
+
+ Use a custom output transform.
+
+
+
+
+
+
+ Save a session from a directory.
+
+
+
+
+
+
+
+ Get the session state.
+
+ SessionState object representing session state in-memory
+
+
+
+ Load a session from a session state.
+
+
+ If true loads transforms saved in the session state.
+
+
+
+
+
+ Load a session from a directory.
+
+
+ If true loads transforms saved in the session state.
+
+
+
+
+
+ Add a message to the chat history.
+
+
+
+
+
+
+ Add a system message to the chat history.
+
+
+
+
+
+
+ Add an assistant message to the chat history.
+
+
+
+
+
+
+ Add a user message to the chat history.
+
+
+
+
+
+
+ Remove the last message from the chat history.
+
+
+
+
+
+ Compute KV cache for the message and add it to the chat history.
+
+
+
+
+
+
+ Compute KV cache for the system message and add it to the chat history.
+
+
+
+
+ Compute KV cache for the user message and add it to the chat history.
+
+
+
+
+ Compute KV cache for the assistant message and add it to the chat history.
+
+
+
+
+ Replace a user message with a new message and remove all messages after the new message.
+ This is useful when the user wants to edit a message and regenerate the response.
+
+
+
+
+
+
+
+ Chat with the model.
+
+
+
+
+
+
+
+
+
+
+ Chat with the model.
+
+
+
+
+
+
+
+
+ Chat with the model.
+
+
+
+
+
+
+
+
+
+
+ Chat with the model.
+
+
+
+
+
+
+
+
+ Regenerate the last assistant message.
+
+
+
+
+
+
+
+
+ The state of a chat session in-memory.
+
+
+
+
+ Saved executor state for the session in JSON format.
+
+
+
+
+ Saved context state (KV cache) for the session.
+
+
+
+
+ The input transform pipeline used in this session.
+
+
+
+
+ The output transform used in this session.
+
+
+
+
+ The history transform used in this session.
+
+
+
+
+ The chat history messages for this session.
+
+
+
+
+ Create a new session state.
+
+
+
+
+
+
+
+
+
+
+ Save the session state to folder.
+
+
+
+
+
+ Load the session state from folder.
+
+
+
+ Throws when session state is incorrect
+
+
+
+ Role of the message author, e.g. user/assistant/system
+
+
+
+
+ Role is unknown
+
+
+
+
+ Message comes from a "system" prompt, not written by a user or language model
+
+
+
+
+ Message comes from the user
+
+
+
+
+ Messages was generated by the language model
+
+
+
+
+ The chat history class
+
+
+
+
+ Chat message representation
+
+
+
+
+ Role of the message author, e.g. user/assistant/system
+
+
+
+
+ Message content
+
+
+
+
+ Create a new instance
+
+ Role of message author
+ Message content
+
+
+
+ List of messages in the chat
+
+
+
+
+ Create a new instance of the chat content class
+
+
+
+
+ Create a new instance of the chat history from array of messages
+
+
+
+
+
+ Add a message to the chat history
+
+ Role of the message author
+ Message content
+
+
+
+ Serialize the chat history to JSON
+
+
+
+
+
+ Deserialize a chat history from JSON
+
+
+
+
+
+
+ A queue with fixed storage size.
+ Currently it's only a naive implementation and needs to be further optimized in the future.
+
+
+
+
+
+
+
+ Number of items in this queue
+
+
+
+
+ Maximum number of items allowed in this queue
+
+
+
+
+ Create a new queue
+
+ the maximum number of items to store in this queue
+
+
+
+ Fill the queue with the data. Please ensure that data.Count <= size
+
+
+
+
+
+
+ Enqueue an element.
+
+
+
+
+
+
+
+
+
+
+
+ The parameters used for inference.
+
+
+
+
+ number of tokens to keep from initial prompt when applying context shifting
+
+
+
+
+ how many new tokens to predict (n_predict); set to -1 to generate indefinitely until the response completes.
+
+
+
+
+ Sequences where the model will stop generating further tokens.
+
+
+
+
+
+
+
+ Type of "mirostat" sampling to use.
+ https://github.com/basusourya/mirostat
+
+
+
+
+ Disable Mirostat sampling
+
+
+
+
+ Original mirostat algorithm
+
+
+
+
+ Mirostat 2.0 algorithm
+
+
+
+
+ The parameters for initializing a LLama model.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ `Encoding` cannot be directly JSON serialized, instead store the name as a string which can
+
+
+
+
+
+
+
+
+
+ The model path.
+
+
+
+ Base class for LLamaSharp runtime errors (i.e. errors produced by llama.cpp, converted into exceptions)
+
+
+
+
+ Create a new RuntimeError
+
+
+
+
+
+ Loading model weights failed
+
+
+
+
+ The model path which failed to load
+
+
+
+
+
+
+
+ `llama_decode` return a non-zero status code
+
+
+
+
+ The return status code
+
+
+
+
+
+
+
+ `llama_decode` return a non-zero status code
+
+
+
+
+
+
+
+
+
+
+ `llama_get_logits_ith` returned null, indicating that the index was invalid
+
+
+
+
+ The incorrect index passed to the `llama_get_logits_ith` call
+
+
+
+
+
+
+
+ Extension methods to the IContextParams interface
+
+
+
+
+ Convert the given `IModelParams` into a `LLamaContextParams`
+
+
+
+
+
+
+
+
+
+ Extension methods to the IModelParams interface
+
+
+
+
+ Convert the given `IModelParams` into a `LLamaModelParams`
+
+
+
+
+
+
+
+
+
+ Find the index of `item` in `list`
+
+
+ list to search
+ item to search for
+
+
+
+
+ Check if the given set of tokens ends with any of the given strings
+
+ Tokens to check
+ Strings to search for
+ Model to use to convert tokens into bytes
+ Encoding to use to convert bytes into characters
+
+
+
+
+ Check if the given set of tokens ends with any of the given strings
+
+ Tokens to check
+ Strings to search for
+ Model to use to convert tokens into bytes
+ Encoding to use to convert bytes into characters
+
+
+
+
+ Extensions to the KeyValuePair struct
+
+
+
+
+ Deconstruct a KeyValuePair into its constituent parts.
+
+ The KeyValuePair to deconstruct
+ First element, the Key
+ Second element, the Value
+ Type of the Key
+ Type of the Value
+
+
+
+ Run a process for a certain amount of time and then terminate it
+
+
+
+ return code, standard output, standard error, flag indicating if process exited or was terminated
+
+
+
+ Extensions to span which apply in-place normalization
+
+
+
+
+ In-place multiply every element by 32760 and divide every element in the span by the max absolute value in the span
+
+
+ The same array
+
+
+
+ In-place multiply every element by 32760 and divide every element in the span by the max absolute value in the span
+
+
+ The same span
+
+
+
+ In-place divide every element in the array by the sum of absolute values in the array
+
+ Also known as "Manhattan normalization".
+
+ The same array
+
+
+
+ In-place divide every element in the span by the sum of absolute values in the span
+
+ Also known as "Manhattan normalization".
+
+ The same span
+
+
+
+ In-place divide every element by the euclidean length of the vector
+
+ Also known as "L2 normalization".
+
+ The same array
+
+
+
+ In-place divide every element by the euclidean length of the vector
+
+ Also known as "L2 normalization".
+
+ The same span
+
+
+
+ Creates a new array containing an L2 normalization of the input vector.
+
+
+ The same span
+
+
+
+ In-place apply p-normalization. https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
+
+ - For p = 1, this is taxicab normalization
+ - For p = 2, this is euclidean normalization
+ - As p => infinity, this approaches infinity norm or maximum norm
+
+
+
+
+ The same array
+
+
+
+ In-place apply p-normalization. https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
+
+ - For p = 1, this is taxicab normalization
+ - For p = 2, this is euclidean normalization
+ - As p => infinity, this approaches infinity norm or maximum norm
+
+
+
+
+ The same span
+
+
+
+ A llama_context, which holds all the context required to interact with a model
+
+
+
+
+ Total number of tokens in the context
+
+
+
+
+ Dimension of embedding vectors
+
+
+
+
+ The context params set for this context
+
+
+
+
+ The native handle, which is used to be passed to the native APIs
+
+ Be careful how you use this!
+
+
+
+ The encoding set for this model to deal with text input.
+
+
+
+
+ Get or set the number of threads to use for generation
+
+
+
+
+ Get or set the number of threads to use for batch processing
+
+
+
+
+ Get the maximum batch size for this context
+
+
+
+
+ Get the special tokens for the model associated with this context
+
+
+
+
+ Create a new LLamaContext for the given LLamaWeights
+
+
+
+
+
+
+
+
+ Tokenize a string.
+
+
+ Whether to add a bos to the text.
+ Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
+
+
+
+
+ Detokenize the tokens to text.
+
+
+
+
+
+
+ Save the state to specified path.
+
+
+
+
+
+ Save the state of a particular sequence to specified path.
+
+
+
+
+
+
+ Get the state data as an opaque handle, which can be loaded later using
+
+ Use if you intend to save this state to disk.
+
+
+
+
+ Get the state data as an opaque handle, which can be loaded later using
+
+ Use if you intend to save this state to disk.
+
+
+
+
+ Load the state from specified path.
+
+
+
+
+
+ Load the state from specified path into a particular sequence
+
+
+
+
+
+
+ Load the state from memory.
+
+
+
+
+
+ Load the state from memory into a particular sequence
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ A tuple, containing the decode result, the number of tokens that have not been decoded yet and the total number of tokens that have been decoded.
+
+
+
+
+
+
+ The state of this context, which can be reloaded later
+
+
+
+
+ Get the size in bytes of this state object
+
+
+
+
+
+
+
+ Write all the bytes of this state to the given stream
+
+
+
+
+
+ Write all the bytes of this state to the given stream
+
+
+
+
+
+ Load a state from a stream
+
+
+
+
+
+
+ Load a state from a stream
+
+
+
+
+
+
+ The state of a single sequence, which can be reloaded later
+
+
+
+
+ Get the size in bytes of this state object
+
+
+
+
+
+
+
+ Copy bytes to a destination pointer.
+
+ Destination to write to
+ Length of the destination buffer
+ Offset from start of src to start copying from
+ Number of bytes written to destination
+
+
+
+ Generate high dimensional embedding vectors from text
+
+
+
+
+ Dimension of embedding vectors
+
+
+
+
+ LLama Context
+
+
+
+
+ Create a new embedder, using the given LLamaWeights
+
+
+
+
+
+
+
+
+
+
+ Get high dimensional embedding vectors for the given text. Depending on the pooling type used when constructing
+ this embedder, this may return an embedding vector per token, or one single embedding vector for the entire string.
+
+ Embedding vectors are not normalized, consider using one of the extensions in .
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ The base class for stateful LLama executors.
+
+
+
+
+ The logger used by this executor.
+
+
+
+
+ The tokens that were already processed by the model.
+
+
+
+
+ The tokens that were consumed by the model during the current inference.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ The path of the session file.
+
+
+
+
+ A container of the tokens to be processed and after processed.
+
+
+
+
+ A container for the tokens of input.
+
+
+
+
+
+
+
+
+
+ The last tokens generated by the model.
+
+
+
+
+ The context used by the executor.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ This API is currently not verified.
+
+
+
+
+
+
+
+
+ This API has not been verified currently.
+
+
+
+
+
+ After running out of the context, take some tokens from the original prompt and recompute the logits in batches.
+
+
+
+
+
+ Try to reuse the matching prefix from the session file.
+
+
+
+
+ Decide whether to continue the loop.
+
+
+
+
+
+
+ Preprocess the inputs before the inference.
+
+
+
+
+
+
+ Do some post processing after the inference.
+
+
+
+
+
+
+
+ The core inference logic.
+
+
+
+
+
+
+ Save the current state to a file.
+
+
+
+
+
+ Get the current state data.
+
+
+
+
+
+ Load the state from data.
+
+
+
+
+
+ Load the state from a file.
+
+
+
+
+
+ Execute the inference.
+
+ The prompt. If null, generation will continue where it left off previously.
+
+
+
+
+
+
+ Asynchronously runs a prompt through the model to compute KV cache without generating any new tokens.
+ This can reduce the latency of the first response if the user's first input is not immediate.
+
+ Prompt to process
+
+
+
+
+ State arguments that are used in single inference
+
+
+
+
+
+
+
+
+
+ Count of tokens remaining to be used (n_remain)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ The LLama executor for instruct mode.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ The descriptor of the state of the instruct executor.
+
+
+
+
+ Whether the executor is running for the first time (running the prompt).
+
+
+
+
+ Instruction prefix tokens.
+
+
+
+
+ Instruction suffix tokens.
+
+
+
+
+ The LLama executor for interactive mode.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Define whether to continue the loop to generate responses.
+
+
+
+
+
+
+
+
+
+
+
+ Return whether to break the generation.
+
+
+
+
+
+
+
+
+
+
+ The descriptor of the state of the interactive executor.
+
+
+
+
+ Whether the executor is running for the first time (running the prompt).
+
+
+
+
+ The quantizer to quantize the model.
+
+
+
+
+ Quantize the model.
+
+ The model file to be quantized.
+ The path to save the quantized model.
+ The type of quantization.
+ Thread to be used during the quantization. By default it's the physical core number.
+
+
+ Whether the quantization is successful.
+
+
+
+
+ Quantize the model.
+
+ The model file to be quantized.
+ The path to save the quantized model.
+ The type of quantization.
+ Thread to be used during the quantization. By default it's the physical core number.
+
+
+ Whether the quantization is successful.
+
+
+
+
+ Parse a string into a LLamaFtype. This is a "relaxed" parsing, which allows any string which is contained within
+ the enum name to be used.
+
+ For example "Q5_K_M" will convert to "LLAMA_FTYPE_MOSTLY_Q5_K_M"
+
+
+
+
+
+
+
+ This executor infers the input as a one-time job. Previous inputs won't impact the
+ response to the current input.
+
+
+
+
+
+
+
+
+
+
+
+
+
+ The context used by the executor when running the inference.
+
+
+
+
+ If true, applies the default template to the prompt, as defined in the rules for llama_chat_apply_template.
+
+
+
+
+ The system message to use with the prompt. Only used when is true.
+
+
+
+
+ Create a new stateless executor which will use the given model
+
+
+
+
+
+
+
+
+
+
+ Converts a sequence of messages into text according to a model template
+
+
+
+
+ Custom template. May be null if a model was supplied to the constructor.
+
+
+
+
+ Keep a cache of roles converted into bytes. Roles are very frequently re-used, so this saves converting them many times.
+
+
+
+
+ Array of messages. The property indicates how many messages there are
+
+
+
+
+ Backing field for
+
+
+
+
+ Temporary array of messages in the format llama.cpp needs, used when applying the template
+
+
+
+
+ Indicates how many bytes are in array
+
+
+
+
+ Result bytes of last call to
+
+
+
+
+ Indicates if this template has been modified and needs regenerating
+
+
+
+
+ The encoding algorithm to use
+
+
+
+
+ Number of messages added to this template
+
+
+
+
+ Get the message at the given index
+
+
+
+ Thrown if index is less than zero or greater than or equal to
+
+
+
+ Whether to end the prompt with the token(s) that indicate the start of an assistant message.
+
+
+
+
+ Construct a new template, using the default model template
+
+
+
+
+
+
+ Construct a new template, using the default model template
+
+
+
+
+
+ Construct a new template, using a custom template.
+
+ Only supports a pre-defined list of templates. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
+
+
+
+
+ Add a new message to the end of this template
+
+
+
+ This template, for chaining calls.
+
+
+
+ Add a new message to the end of this template
+
+
+ This template, for chaining calls.
+
+
+
+ Remove a message at the given index
+
+
+ This template, for chaining calls.
+
+
+
+ Remove all messages from the template and resets internal state to accept/generate new messages
+
+
+
+
+ Apply the template to the messages and return a span containing the results
+
+ A span over the buffer that holds the applied template
+
+
+
+ A message that has been added to a template
+
+
+
+
+ The "role" string for this message
+
+
+
+
+ The text content of this message
+
+
+
+
+ Deconstruct this message into role and content
+
+
+
+
+
+
+ A class that contains all the transforms provided internally by LLama.
+
+
+
+
+ The default history transform.
+ Uses plain text with the following format:
+ [Author]: [Message]
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Drop the name at the beginning and the end of the text.
+
+
+
+
+
+
+
+ A text input transform that only trims the text.
+
+
+
+
+
+
+
+
+
+
+ A no-op text input transform.
+
+
+
+
+
+
+
+
+
+
+ A text output transform that removes the keywords from the response.
+
+
+
+
+ Keywords that you want to remove from the response.
+ This property is used for JSON serialization.
+
+
+
+
+ Maximum length of the keywords.
+ This property is used for JSON serialization.
+
+
+
+
+ If set to true, when getting a matched keyword, all the related tokens will be removed.
+ Otherwise only the part of keyword will be removed.
+ This property is used for JSON serialization.
+
+
+
+
+ JSON constructor.
+
+
+
+
+
+
+ Keywords that you want to remove from the response.
+ The extra length when searching for the keyword. For example, if your only keyword is "highlight",
+ maybe the token you get is "\r\nhighligt". In this condition, if redundancyLength=0, the token cannot be successfully matched because the length of "\r\nhighligt" (10)
+ has already exceeded the maximum length of the keywords (8). On the contrary, setting redundancyLengyh >= 2 leads to successful match.
+ The larger the redundancyLength is, the lower the processing speed. But as an experience, it won't introduce too much performance impact when redundancyLength <= 5
+ If set to true, when getting a matched keyword, all the related tokens will be removed. Otherwise only the part of keyword will be removed.
+
+
+
+
+
+
+
+
+
+ A set of model weights, loaded into memory.
+
+
+
+
+ The native handle, which is used in the native APIs
+
+ Be careful how you use this!
+
+
+
+ Total number of tokens in the context
+
+
+
+
+ Get the size of this model in bytes
+
+
+
+
+ Get the number of parameters in this model
+
+
+
+
+ Dimension of embedding vectors
+
+
+
+
+ Get the special tokens of this model
+
+
+
+
+ All metadata keys in this model
+
+
+
+
+ Load weights into memory
+
+
+
+
+
+
+ Load weights into memory
+
+ Parameters to use to load the model
+ A cancellation token that can interrupt model loading
+ Receives progress updates as the model loads (0 to 1)
+
+ Thrown if weights failed to load for any reason. e.g. Invalid file format or loading cancelled.
+ Thrown if the cancellation token is cancelled.
+
+
+
+
+
+
+ Create a llama_context using this model
+
+
+
+
+
+
+
+ Convert a string of text into tokens
+
+
+
+
+ Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
+
+
+
+
+ A set of llava model weights (mmproj), loaded into memory.
+
+
+
+
+ The native handle, which is used in the native APIs
+
+ Be careful how you use this!
+
+
+
+ Load weights into memory
+
+ path to the "mmproj" model file
+
+
+
+
+ Load weights into memory
+
+ path to the "mmproj" model file
+
+
+
+
+
+ Create the Image Embeddings from the bytes of an image.
+
+
+ Image bytes. Supported formats:
+
+ - JPG
+ - PNG
+ - BMP
+ - TGA
+
+
+
+
+
+
+ Create the Image Embeddings.
+
+ Image in binary format (it supports jpeg format only)
+ Number of threads to use
+ return the SafeHandle of these embeddings
+
+
+
+ Create the Image Embeddings from the bytes of an image.
+
+
+ Path to the image file. Supported formats:
+
+ - JPG
+ - PNG
+ - BMP
+ - TGA
+
+
+
+
+
+
+
+ Create the Image Embeddings from the bytes of an image.
+
+ Path to the image file. Supported formats:
+
+ - JPG
+ - PNG
+ - BMP
+ - TGA
+
+
+
+
+
+
+
+
+ Eval the image embeddings
+
+
+
+
+
+
+
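+ Putting the LLava members above together (paths are placeholders; method and parameter names follow the summaries and are otherwise assumed):
+
+ // Sketch: load the mmproj weights and embed a single image.
+ using var llava = LLavaWeights.LoadFromFile("mmproj-model-f16.gguf");    // placeholder path
+
+ var imageBytes = File.ReadAllBytes("picture.jpg");                       // placeholder path
+ using var embed = llava.CreateImageEmbeddings(imageBytes);               // image -> embedding handle
+
+ // Write the image embedding into a llama context at the current position.
+ var nPast = 0;
+ llava.EvalImageEmbed(context, embed, ref nPast);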
+
+
+
+
+ Return codes from llama_decode
+
+
+
+
+ An unspecified error
+
+
+
+
+ Ok.
+
+
+
+
+ Could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+
+
+
+
+ Return codes from llama_encode
+
+
+
+
+ An unspecified error
+
+
+
+
+ Ok.
+
+
+
+
+ Possible GGML quantisation types
+
+
+
+
+ Full 32 bit float
+
+
+
+
+ 16 bit float
+
+
+
+
+ 4 bit float
+
+
+
+
+ 4 bit float
+
+
+
+
+ 5 bit float
+
+
+
+
+ 5 bit float
+
+
+
+
+ 8 bit float
+
+
+
+
+ 8 bit float
+
+
+
+
+ "type-1" 2-bit quantization in super-blocks containing 16 blocks, each block having 16 weight.
+ Block scales and mins are quantized with 4 bits. This ends up effectively using 2.5625 bits per weight (bpw)
+
+
+
+
+ "type-0" 3-bit quantization in super-blocks containing 16 blocks, each block having 16 weights.
+ Scales are quantized with 6 bits. This ends up using 3.4375 bpw.
+
+
+
+
+ "type-1" 4-bit quantization in super-blocks containing 8 blocks, each block having 32 weights.
+ Scales and mins are quantized with 6 bits. This ends up using 4.5 bpw.
+
+
+
+
+ "type-1" 5-bit quantization. Same super-block structure as GGML_TYPE_Q4_K resulting in 5.5 bpw
+
+
+
+
+ "type-0" 6-bit quantization. Super-blocks with 16 blocks, each block having 16 weights.
+ Scales are quantized with 8 bits. This ends up using 6.5625 bpw
+
+
+
+
+ "type-0" 8-bit quantization. Only used for quantizing intermediate results.
+ The difference to the existing Q8_0 is that the block size is 256. All 2-6 bit dot products are implemented for this quantization type.
+
+
+
+
+ Integer, 8 bit
+
+
+
+
+ Integer, 16 bit
+
+
+
+
+ Integer, 32 bit
+
+
+
+
+ The value of this entry is the count of the number of possible quant types.
+
+
+
+
+
+
+ llama_split_mode
+
+
+
+ Single GPU
+
+
+
+
+ Split layers and KV across GPUs
+
+
+
+
+ Split layers and KV across GPUs, using tensor parallelism if supported
+
+
+
+
+ Disposes all contained disposables when this class is disposed
+
+
+
+
+
+
+
+
+
+
+
+
+ llama_attention_type
+
+
+
+ A batch allows submitting multiple tokens to multiple sequences simultaneously
+
+
+
+
+ Keep a list of where logits can be sampled from
+
+
+
+
+ Get the number of logit positions that will be generated from this batch
+
+
+
+
+ The number of tokens in this batch
+
+
+
+
+ Maximum number of tokens that can be added to this batch (automatically grows if exceeded)
+
+
+
+
+ Maximum number of sequences a token can be assigned to (automatically grows if exceeded)
+
+
+
+
+ Create a new batch for submitting inputs to llama.cpp
+
+
+
+
+ Add a single token to the batch at the same position in several sequences
+
+ https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
+ The token to add
+ The position to add it at
+ The set of sequences to add this token to
+
+ The index that the token was added at. Use this for GetLogitsIth
+
+
+
+ Add a single token to the batch at the same position in several sequences
+
+ https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
+ The token to add
+ The position to add it at
+ The set of sequences to add this token to
+
+ The index that the token was added at. Use this for GetLogitsIth
+
+
+
+ Add a single token to the batch at a certain position for a single sequences
+
+ https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
+ The token to add
+ The position to add it at
+ The sequence to add this token to
+
+ The index that the token was added at. Use this for GetLogitsIth
+
+
+
+ Add a range of tokens to a single sequence, start at the given position.
+
+ The tokens to add
+ The starting position to add tokens at
+ The sequence to add this token to
+ Whether the final token should generate logits
+ The index that the final token was added at. Use this for GetLogitsIth
+
+
+
+ Set TokenCount to zero for this batch
+
+
+
+
+ Get the positions where logits can be sampled from
+
+
+
+
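+ A sketch of submitting a prompt through a batch (Decode and the DecodeResult values are documented elsewhere in this file; the exact method and parameter names are assumed):
+
+ // Sketch: put a tokenized prompt into a batch and decode it.
+ // 'tokens' and 'context' are assumed to come from the earlier loading/tokenization sketch.
+ var batch = new LLamaBatch();
+
+ // Add all prompt tokens to sequence 0; only the final token requests logits.
+ batch.AddRange(tokens, 0, LLamaSeqId.Zero, true);
+
+ var result = context.NativeHandle.Decode(batch);
+ if (result != DecodeResult.Ok)
+     throw new Exception($"llama_decode failed: {result}");
+
+ // Reset the batch before reusing it for the next chunk of tokens.
+ batch.Clear();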
+
+ An embeddings batch allows submitting embeddings to multiple sequences simultaneously
+
+
+
+
+ Keep a list of where logits can be sampled from
+
+
+
+
+ Get the number of logit positions that will be generated from this batch
+
+
+
+
+ Size of an individual embedding
+
+
+
+
+ The number of items in this batch
+
+
+
+
+ Maximum number of items that can be added to this batch (automatically grows if exceeded)
+
+
+
+
+ Maximum number of sequences an item can be assigned to (automatically grows if exceeded)
+
+
+
+
+ Create a new batch for submitting inputs to llama.cpp
+
+
+
+
+ Add a single embedding to the batch at the same position in several sequences
+
+ https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
+ The embedding to add
+ The position to add it at
+ The set of sequences to add this token to
+
+ The index that the token was added at. Use this for GetLogitsIth
+
+
+
+ Add a single embedding to the batch for a single sequence
+
+
+
+
+
+ The index that the token was added at. Use this for GetLogitsIth
+
+
+
+ Called by embeddings batch to write embeddings into a destination span
+
+ Type of user data parameter passed in
+ Destination to write data to. Entire destination must be filled!
+ User data parameter passed in
+
+
+
+ Add a single embedding to the batch at the same position in several sequences
+
+ https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
+ Type of userdata passed to write delegate
+ Userdata passed to write delegate
+ Delegate called once to write data into a span
+ Position to write this embedding to
+ All sequences to assign this embedding to
+ Whether logits should be generated for this embedding
+ The index that the token was added at. Use this for GetLogitsIth
+
+
+
+ Add a single embedding to the batch at a position for one sequence
+
+ https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
+ Type of userdata passed to write delegate
+ Userdata passed to write delegate
+ Delegate called once to write data into a span
+ Position to write this embedding to
+ Sequence to assign this embedding to
+ Whether logits should be generated for this embedding
+ The index that the token was added at. Use this for GetLogitsIth
+
+
+
+ Set EmbeddingsCount to zero for this batch
+
+
+
+
+ Get the positions where logits can be sampled from
+
+
+
+
+
+
+
+ llama_chat_message
+
+
+
+ Pointer to the null terminated bytes that make up the role string
+
+
+
+
+ Pointer to the null terminated bytes that make up the content string
+
+
+
+
+ Called by llama.cpp with a progress value between 0 and 1
+
+
+
+ If the provided progress_callback returns true, model loading continues.
+ If it returns false, model loading is immediately aborted.
+ llama_progress_callback
+
+
+
+ A C# representation of the llama.cpp `llama_context_params` struct
+
+ changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
+ https://github.com/ggerganov/llama.cpp/pull/7544
+
+
+
+ text context, 0 = from model
+
+
+
+
+ logical maximum batch size that can be submitted to llama_decode
+
+
+
+
+ physical maximum batch size
+
+
+
+
+ max number of sequences (i.e. distinct states for recurrent models)
+
+
+
+
+ number of threads to use for generation
+
+
+
+
+ number of threads to use for batch processing
+
+
+
+
+ RoPE scaling type, from `enum llama_rope_scaling_type`
+
+
+
+
+ whether to pool (sum) embedding results by sequence id
+
+
+
+
+ Attention type to use for embeddings
+
+
+
+
+ RoPE base frequency, 0 = from model
+
+
+
+
+ RoPE frequency scaling factor, 0 = from model
+
+
+
+
+ YaRN extrapolation mix factor, negative = from model
+
+
+
+
+ YaRN magnitude scaling factor
+
+
+
+
+ YaRN low correction dim
+
+
+
+
+ YaRN high correction dim
+
+
+
+
+ YaRN original context size
+
+
+
+
+ defragment the KV cache if holes/size > defrag_threshold, Set to < 0 to disable (default)
+
+
+
+
+ ggml_backend_sched_eval_callback
+
+
+
+
+ User data passed into cb_eval
+
+
+
+
+ data type for K cache. EXPERIMENTAL
+
+
+
+
+ data type for V cache. EXPERIMENTAL
+
+
+
+
+ Deprecated!
+
+
+
+
+ if true, extract embeddings (together with logits)
+
+
+
+
+ whether to offload the KQV ops (including the KV cache) to GPU
+
+
+
+
+ whether to use flash attention. EXPERIMENTAL
+
+
+
+
+ whether to measure performance timings
+
+
+
+
+ ggml_abort_callback
+
+
+
+
+ User data passed into abort_callback
+
+
+
+
+ Get the default LLamaContextParams
+
+
+
+
+
+ Supported model file types
+
+ C# representation of llama_ftype
+
+
+
+ All f32
+
+ Benchmark@7B: 26GB
+
+
+
+ Mostly f16
+
+ Benchmark@7B: 13GB
+
+
+
+ Mostly 8 bit
+
+ Benchmark@7B: 6.7GB, +0.0004ppl
+
+
+
+ Mostly 4 bit
+
+ Benchmark@7B: 3.50GB, +0.2499 ppl
+
+
+
+ Mostly 4 bit
+
+ Benchmark@7B: 3.90GB, +0.1846 ppl
+
+
+
+ Mostly 5 bit
+
+ Benchmark@7B: 4.30GB @ 7B tokens, +0.0796 ppl
+
+
+
+ Mostly 5 bit
+
+ Benchmark@7B: 4.70GB, +0.0415 ppl
+
+
+
+ K-Quant 2 bit
+
+ Benchmark@7B: 2.67GB @ 7B parameters, +0.8698 ppl
+
+
+
+ K-Quant 3 bit (Small)
+
+ Benchmark@7B: 2.75GB, +0.5505 ppl
+
+
+
+ K-Quant 3 bit (Medium)
+
+ Benchmark@7B: 3.06GB, +0.2437 ppl
+
+
+
+ K-Quant 3 bit (Large)
+
+ Benchmark@7B: 3.35GB, +0.1803 ppl
+
+
+
+ K-Quant 4 bit (Small)
+
+ Benchmark@7B: 3.56GB, +0.1149 ppl
+
+
+
+ K-Quant 4 bit (Medium)
+
+ Benchmark@7B: 3.80GB, +0.0535 ppl
+
+
+
+ K-Quant 5 bit (Small)
+
+ Benchmark@7B: 4.33GB, +0.0353 ppl
+
+
+
+ K-Quant 5 bit (Medium)
+
+ Benchmark@7B: 4.45GB, +0.0142 ppl
+
+
+
+ K-Quant 6 bit
+
+ Benchmark@7B: 5.15GB, +0.0044 ppl
+
+
+
+ except 1d tensors
+
+
+
+
+ except 1d tensors
+
+
+
+
+ except 1d tensors
+
+
+
+
+ except 1d tensors
+
+
+
+
+ except 1d tensors
+
+
+
+
+ except 1d tensors
+
+
+
+
+ except 1d tensors
+
+
+
+
+ except 1d tensors
+
+
+
+
+ except 1d tensors
+
+
+
+
+ except 1d tensors
+
+
+
+
+ except 1d tensors
+
+
+
+
+ except 1d tensors
+
+
+
+
+ except 1d tensors
+
+
+
+
+ except 1d tensors
+
+
+
+
+ except 1d tensors
+
+
+
+
+ except 1d tensors
+
+
+
+
+ File type was not specified
+
+
+
+
+ A safe handle for a LLamaKvCacheView
+
+
+
+
+ Number of KV cache cells. This will be the same as the context size.
+
+
+
+
+ Get the total number of tokens in the KV cache.
+
+ For example, if there are two populated
+ cells, the first with 1 sequence id in it and the second with 2 sequence
+ ids then you'll have 3 tokens.
+
+
+
+
+ Maximum number of sequences visible for a cell. There may be more sequences than this
+ in reality, this is simply the maximum number this view can see.
+
+
+
+
+ Number of populated cache cells
+
+
+
+
+ Maximum contiguous empty slots in the cache.
+
+
+
+
+ Index to the start of the MaxContiguous slot range. Can be negative when cache is full.
+
+
+
+
+ Initialize a LLamaKvCacheViewSafeHandle which will call `llama_kv_cache_view_free` when disposed
+
+
+
+
+
+
+ Allocate a new KV cache view which can be used to inspect the KV cache
+
+
+ The maximum number of sequences visible in this view per cell
+
+
+
+
+
+
+
+ Read the current KV cache state into this view.
+
+
+
+
+ Get the raw KV cache view
+
+
+
+
+
+ Get the cell at the given index
+
+ The index of the cell [0, CellCount)
+ Data about the cell at the given index
+ Thrown if index is out of range (0 <= index < CellCount)
+
+
+
+ Get all of the sequences assigned to the cell at the given index. This will contain at most as many entries as the
+ maximum number of sequences visible in this view, even if the cell actually has more sequences; allocate a new view with a larger maxSequences parameter
+ if necessary. Invalid sequences will be negative values.
+
+ The index of the cell [0, CellCount)
+ A span containing the sequences assigned to this cell
+ Thrown if index is out of range (0 <= index < CellCount)
+
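+ For debugging, the view can be allocated, refreshed and walked roughly as follows (member names such as Update, GetCell and the cell's Position property are inferred from the summaries above and may differ):
+
+ // Sketch: inspect which KV cache cells are populated.
+ using var view = LLamaKvCacheViewSafeHandle.Allocate(context.NativeHandle, 4);   // up to 4 sequences visible per cell
+
+ view.Update();   // read the current KV cache state into the view
+
+ for (var i = 0; i < view.CellCount; i++)
+ {
+     var cell = view.GetCell(i);
+     if (cell.Position >= 0)
+         Console.WriteLine($"Cell {i}: position {cell.Position}, sequences {view.GetCellSequences(i).Length}");
+ }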
+
+
+ Create an empty KV cache view. (use only for debugging purposes)
+
+
+
+
+
+
+
+ Free a KV cache view. (use only for debugging purposes)
+
+
+
+
+ Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
+
+
+
+
+
+
+ Information associated with an individual cell in the KV cache view (llama_kv_cache_view_cell)
+
+
+
+
+ The position for this cell. Takes KV cache shifts into account.
+ May be negative if the cell is not populated.
+
+
+
+
+ An updateable view of the KV cache (llama_kv_cache_view)
+
+
+
+
+ Number of KV cache cells. This will be the same as the context size.
+
+
+
+
+ Maximum number of sequences that can exist in a cell. It's not an error
+ if there are more sequences in a cell than this value, however they will
+ not be visible in the view cells_sequences.
+
+
+
+
+ Number of tokens in the cache. For example, if there are two populated
+ cells, the first with 1 sequence id in it and the second with 2 sequence
+ ids then you'll have 3 tokens.
+
+
+
+
+ Number of populated cache cells.
+
+
+
+
+ Maximum contiguous empty slots in the cache.
+
+
+
+
+ Index to the start of the max_contiguous slot range. Can be negative
+ when cache is full.
+
+
+
+
+ Information for an individual cell.
+
+
+
+
+ The sequences for each cell. There will be n_seq_max items per cell.
+
+
+
+
+ Severity level of a log message. This enum should always be aligned with
+ the one defined on llama.cpp side at
+ https://github.com/ggerganov/llama.cpp/blob/0eb4e12beebabae46d37b78742f4c5d4dbe52dc1/ggml/include/ggml.h#L559
+
+
+
+
+ Logs are never written.
+
+
+
+
+ Logs that are used for interactive investigation during development.
+
+
+
+
+ Logs that track the general flow of the application.
+
+
+
+
+ Logs that highlight an abnormal or unexpected event in the application flow, but do not otherwise cause the application execution to stop.
+
+
+
+
+ Logs that highlight when the current flow of execution is stopped due to a failure.
+
+
+
+
+ Continue log level is equivalent to None in the way it is used in llama.cpp.
+
+
+
+
+ Keeps track of the previous log level to be able to handle the Continue log level.
+
+
+
+
+ Override a key/value pair in the llama model metadata (llama_model_kv_override)
+
+
+
+
+ Key to override
+
+
+
+
+ Type of value
+
+
+
+
+ Add 4 bytes of padding, to align the next fields to 8 bytes
+
+
+
+
+ Value, **must** only be used if Tag == LLAMA_KV_OVERRIDE_INT
+
+
+
+
+ Value, **must** only be used if Tag == LLAMA_KV_OVERRIDE_FLOAT
+
+
+
+
+ Value, **must** only be used if Tag == LLAMA_KV_OVERRIDE_BOOL
+
+
+
+
+ Value, **must** only be used if Tag == String
+
+
+
+
+ Specifies what type of value is being overridden by LLamaModelKvOverride
+
+ llama_model_kv_override_type
+
+
+
+ Overriding an int value
+
+
+
+
+ Overriding a float value
+
+
+
+
+ Overriding a bool value
+
+
+
+
+ Overriding a string value
+
+
+
+
+ A C# representation of the llama.cpp `llama_model_params` struct
+
+
+
+
+ NULL-terminated list of devices to use for offloading (if NULL, all available devices are used)
+ todo: add support for llama_model_params.devices
+
+
+
+
+ Number of layers to store in VRAM
+
+
+
+
+ how to split the model across multiple GPUs
+
+
+
+
+ the GPU that is used for the entire model when split_mode is LLAMA_SPLIT_MODE_NONE
+
+
+
+
+ how to split layers across multiple GPUs (size: )
+
+
+
+
+ called with a progress value between 0 and 1, pass NULL to disable. If the provided progress_callback
+ returns true, model loading continues. If it returns false, model loading is immediately aborted.
+
+
+
+
+ context pointer passed to the progress callback
+
+
+
+
+ override key-value pairs of the model meta data
+
+
+
+
+ only load the vocabulary, no weights
+
+
+
+
+ use mmap if possible
+
+
+
+
+ force system to keep model in RAM
+
+
+
+
+ validate model tensor data
+
+
+
+
+ Create a LLamaModelParams with default values
+
+
+
+
+
+ Quantizer parameters used in the native API
+
+ llama_model_quantize_params
+
+
+
+ number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
+
+
+
+
+ quantize to this llama_ftype
+
+
+
+
+ output tensor type
+
+
+
+
+ token embeddings tensor type
+
+
+
+
+ allow quantizing non-f32/f16 tensors
+
+
+
+
+ quantize output.weight
+
+
+
+
+ only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
+
+
+
+
+ quantize all tensors to the default type
+
+
+
+
+ quantize to the same number of shards
+
+
+
+
+ pointer to importance matrix data
+
+
+
+
+ pointer to vector containing overrides
+
+
+
+
+ Create a LLamaModelQuantizeParams with default values
+
+
+
+
+
+ Input data for llama_decode
+ A llama_batch object can contain input about one or many sequences
+ The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
+
+
+
+
+ The number of items pointed at by pos, seq_id and logits.
+
+
+
+
+ Either `n_tokens` of `llama_token`, or `NULL`, depending on how this batch was created
+
+
+
+
+ Either `n_tokens * embd * sizeof(float)` or `NULL`, depending on how this batch was created
+
+
+
+
+ the positions of the respective token in the sequence
+ (if set to NULL, the token position will be tracked automatically by llama_decode)
+
+
+
+
+ https://github.com/ggerganov/llama.cpp/blob/master/llama.h#L139 ???
+
+
+
+
+ the sequence to which the respective token belongs
+ (if set to NULL, the sequence ID will be assumed to be 0)
+
+
+
+
+ if zero, the logits for the respective token will not be output
+ (if set to NULL, only the logits for last token will be returned)
+
+
+
+
+
+
+ llama_pooling_type
+
+
+
+ No specific pooling type. Use the model default if this is specified in the context parameters.
+
+
+
+
+ Do not pool embeddings (per-token embeddings)
+
+
+
+
+ Take the mean of every token embedding
+
+
+
+
+ Return the embedding for the special "CLS" token
+
+
+
+
+ Used by reranking models to attach the classification head to the graph
+
+
+
+
+ Indicates position in a sequence
+
+
+
+
+ The raw value
+
+
+
+
+ Create a new LLamaPos
+
+
+
+
+
+ Convert a LLamaPos into an integer (extract the raw value)
+
+
+
+
+
+
+ Convert an integer into a LLamaPos
+
+
+
+
+
+
+ Increment this position
+
+
+
+
+
+
+ Increment this position
+
+
+
+
+
+
+
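+ Because LLamaPos is a thin wrapper over an int with implicit conversions in both directions (as described above), it can be mixed freely with integer positions:
+
+ // Sketch: LLamaPos round-trips through int and supports increments.
+ LLamaPos pos = 10;   // implicit conversion from int
+ pos++;               // increment the position
+ int raw = pos;       // implicit conversion back to int (raw == 11)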
+
+ llama_rope_type
+
+
+
+ ID for a sequence in a batch
+
+
+
+
+ LLamaSeqId with value 0
+
+
+
+
+ The raw value
+
+
+
+
+ Create a new LLamaSeqId
+
+
+
+
+
+ Convert a LLamaSeqId into an integer (extract the raw value)
+
+
+
+
+
+ Convert an integer into a LLamaSeqId
+
+
+
+
+
+
+
+
+
+ LLama performance information
+
+ llama_perf_context_data
+
+
+
+ Timestamp when reset was last called
+
+
+
+
+ Loading milliseconds
+
+
+
+
+ total milliseconds spent prompt processing
+
+
+
+
+ Total milliseconds in eval/decode calls
+
+
+
+
+ number of tokens in eval calls for the prompt (with batch size > 1)
+
+
+
+
+ number of eval calls
+
+
+
+
+ Timestamp when reset was last called
+
+
+
+
+ Time spent loading
+
+
+
+
+ total milliseconds spent prompt processing
+
+
+
+
+ Total milliseconds in eval/decode calls
+
+
+
+
+ number of tokens in eval calls for the prompt (with batch size > 1)
+
+
+
+
+ number of eval calls
+
+
+
+
+ LLama performance information
+
+ llama_perf_sampler_data
+
+
+
+ A single token
+
+
+
+
+ Token Value used when token is inherently null
+
+
+
+
+ The raw value
+
+
+
+
+ Create a new LLamaToken
+
+
+
+
+
+ Convert a LLamaToken into an integer (extract the raw value)
+
+
+
+
+
+
+ Convert an integer into a LLamaToken
+
+
+
+
+
+
+ Get attributes for this token
+
+
+
+
+
+
+ Get attributes for this token
+
+
+
+
+
+
+ Get score for this token
+
+
+
+
+
+
+ Check if this is a control token
+
+
+
+
+
+
+ Check if this is a control token
+
+
+
+
+
+
+ Check if this token should end generation
+
+
+
+
+
+
+ Check if this token should end generation
+
+
+
+
+
+
+
+
+
+ Token attributes
+
+ C# equivalent of llama_token_attr
+
+
+
+ A single token along with probability of this token being selected
+
+
+
+
+ token id
+
+
+
+
+ log-odds of the token
+
+
+
+
+ probability of the token
+
+
+
+
+ Create a new LLamaTokenData
+
+
+
+
+
+
+
+ Contains an array of LLamaTokenData, potentially sorted.
+
+
+
+
+ The LLamaTokenData
+
+
+
+
+ Indicates if `data` is sorted by logits in descending order. If this is false the token data is in _no particular order_.
+
+
+
+
+ Create a new LLamaTokenDataArray
+
+
+
+
+
+
+ Create a new LLamaTokenDataArray, copying the data from the given logits
+
+
+
+
+
+
+ Create a new LLamaTokenDataArray, copying the data from the given logits into temporary memory.
+
+ The memory must not be modified while this is in use.
+
+ Temporary memory which will be used to work on these logits. Must be at least as large as logits array
+
+
+
+
+ Overwrite the logit values for all given tokens
+
+ tuples of token and logit value to overwrite
+
+
+
+ Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
+
+
+
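+ A sketch of working with the candidate array (the OverwriteLogits tuple shape and the name of the sort/normalise method are assumptions based on the summaries above):
+
+ // Sketch: build candidates from raw logits, ban one token, then sort and normalise.
+ // 'logits' is a span of raw logits from the context; 'bannedToken' is any LLamaToken.
+ var candidates = LLamaTokenDataArray.Create(logits);
+
+ // Force a specific token to be impossible by overwriting its logit.
+ candidates.OverwriteLogits(new[] { (bannedToken, float.NegativeInfinity) });
+
+ // Sort by logit (descending) and compute probabilities from the logits.
+ candidates.Softmax();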
+
+ Contains a pointer to an array of LLamaTokenData which is pinned in memory.
+
+ C# equivalent of llama_token_data_array
+
+
+
+ A pointer to an array of LlamaTokenData
+
+ Memory must be pinned in place for all the time this LLamaTokenDataArrayNative is in use (i.e. `fixed` or `.Pin()`)
+
+
+
+ Number of LLamaTokenData in the array
+
+
+
+
+ The index in the array (i.e. not the token id)
+
+
+
+
+ A pointer to an array of LlamaTokenData
+
+
+
+
+ Indicates if the items in the array are sorted, so the most likely token is first
+
+
+
+
+ The index of the selected token (i.e. not the token id)
+
+
+
+
+ Number of LLamaTokenData in the array. Set this to shrink the array
+
+
+
+
+ Create a new LLamaTokenDataArrayNative around the data in the LLamaTokenDataArray
+
+ Data source
+ Created native array
+ A memory handle, pinning the data in place until disposed
+
+
+
+ C# equivalent of llama_vocab struct. This struct is an opaque type, with no fields in the API and is only used for typed pointers.
+
+
+
+
+ Get attributes for a specific token
+
+
+
+
+
+
+
+ Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
+
+
+
+
+
+
+
+ Identify if Token Id is a control token or a render-able token
+
+
+
+
+
+
+
+ beginning-of-sentence
+
+
+
+
+
+
+ end-of-sentence
+
+
+
+
+
+
+ end-of-turn
+
+
+
+
+
+
+ sentence separator
+
+
+
+
+
+
+ next-line
+
+
+
+
+
+
+ padding
+
+
+
+
+
+
+
+
+ llama_vocab_pre_type
+
+
+
+
+
+ llama_vocab_type
+
+
+
+ For models without vocab
+
+
+
+
+ LLaMA tokenizer based on byte-level BPE with byte fallback
+
+
+
+
+ GPT-2 tokenizer based on byte-level BPE
+
+
+
+
+ BERT tokenizer based on WordPiece
+
+
+
+
+ T5 tokenizer based on Unigram
+
+
+
+
+ RWKV tokenizer based on greedy tokenization
+
+
+
+
+ LLaVa Image embeddings
+
+ llava_image_embed
+
+
+
+ Set configurations for all the native libraries, including LLama and LLava
+
+
+
+
+ Set configurations for all the native libraries, including LLama and LLava
+
+
+
+
+ Configuration for LLama native library
+
+
+
+
+ Configuration for LLava native library
+
+
+
+
+ Check if the native library has already been loaded. Configuration cannot be modified if this is true.
+
+
+
+
+ Set the log callback that will be used for all llama.cpp log messages
+
+
+
+
+
+
+ Set the log callback that will be used for all llama.cpp log messages
+
+
+
+
+
+
+ Try to load the native library with the current configurations,
+ but do not actually set it to .
+
+ You can still modify the configuration after calling this, but only before any call from .
+
+
+ The loaded library. If loading failed, this will be null.
+ However, if you are using .NET Standard 2.0, this will never return null.
+
+ Whether loading was successful.
+
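+ Configuration must happen before the first native call; a hedged sketch of routing log messages for both libraries (the All/LLama/LLava accessors and the callback signature are inferred from the summaries above):
+
+ // Sketch: hook llama.cpp / llava log output before any model is loaded.
+ NativeLibraryConfig.All.WithLogCallback((level, message) => Console.Write($"[{level}] {message}"));
+
+ // Or configure the two libraries independently, e.g. with an ILogger overload (assumed).
+ NativeLibraryConfig.LLama.WithLogCallback(logger);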
+
+
+ A class to set same configurations to multiple libraries at the same time.
+
+
+
+
+ Do an action for all the configs in this container.
+
+
+
+
+
+ Set the log callback that will be used for all llama.cpp log messages
+
+
+
+
+
+
+ Set the log callback that will be used for all llama.cpp log messages
+
+
+
+
+
+
+ Try to load the native library with the current configurations,
+ but do not actually set it to .
+
+ You can still modify the configuration after calling this, but only before any call from .
+
+ Whether loading was successful.
+
+
+
+ The name of the native library
+
+
+
+
+ The native library compiled from llama.cpp.
+
+
+
+
+ The native library compiled from the LLaVA example of llama.cpp.
+
+
+
+
+ A native library specified with a local file path.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Information of a native library file.
+
+ Which kind of library it is.
+ Whether it's compiled with cublas.
+ Whether it's compiled with vulkan.
+ Which AvxLevel it's compiled with.
+
+
+
+ Information of a native library file.
+
+ Which kind of library it is.
+ Whether it's compiled with cublas.
+ Whether it's compiled with vulkan.
+ Which AvxLevel it's compiled with.
+
+
+ Which kind of library it is.
+
+
+ Whether it's compiled with cublas.
+
+
+ Whether it's compiled with vulkan.
+
+
+ Which AvxLevel it's compiled with.
+
+
+
+ Avx support configuration
+
+
+
+
+ No AVX
+
+
+
+
+ Advanced Vector Extensions (supported by most processors after 2011)
+
+
+
+
+ AVX2 (supported by most processors after 2013)
+
+
+
+
+ AVX512 (supported by some processors after 2016, not widely supported)
+
+
+
+
+ Try to load libllama/llava_shared, using CPU feature detection to try and load a more specialised DLL if possible
+
+ The library handle to unload later, or IntPtr.Zero if no library was loaded
+
+
+
+ Operating system information.
+
+
+
+
+
+
+
+ Operating system information.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Get the system information of the current machine.
+
+
+
+
+
+
+ When you are using .NET Standard 2.0, dynamic native library loading is not supported.
+ This class will be returned in .
+
+
+
+
+
+
+
+
+
+
+ A LoRA adapter which can be applied to a context for a specific model
+
+
+
+
+ The model which this LoRA adapter was loaded with.
+
+
+
+
+ The full path of the file this adapter was loaded from
+
+
+
+
+ Native pointer of the loaded adapter, will be automatically freed when the model is unloaded
+
+
+
+
+ Indicates if this adapter has been unloaded
+
+
+
+
+ Unload this adapter
+
+
+
+
+ Direct translation of the llama.cpp API
+
+
+
+
+ A method that does nothing. This is a native method, calling it will force the llama native dependencies to be loaded.
+
+
+
+
+
+ Call once at the end of the program - currently only used for MPI
+
+
+
+
+ Get the maximum number of devices supported by llama.cpp
+
+
+
+
+
+ Check if memory mapping is supported
+
+
+
+
+
+ Check if memory locking is supported
+
+
+
+
+
+ Check if GPU offload is supported
+
+
+
+
+
+ Check if RPC offload is supported
+
+
+
+
+
+ Initialize the llama + ggml backend. Call once at the start of the program.
+
+ This is private because LLamaSharp automatically calls it, and it's only valid to call it once!
+
+
+
+
+ Load session file
+
+
+
+
+
+
+
+
+
+
+ Save session file
+
+
+
+
+
+
+
+
+
+ Set whether to use causal attention or not. If set to true, the model will only attend to the past tokens
+
+
+
+
+ Set whether the model is in embeddings mode or not.
+
+
+ If true, embeddings will be returned but logits will not
+
+
+
+ Set abort callback
+
+
+
+
+ Get the n_seq_max for this context
+
+
+
+
+
+
+ Get all output token embeddings.
+ When pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model, the embeddings for which
+ llama_batch.logits[i] != 0 are stored contiguously in the order they have appeared in the batch.
+ shape: [n_outputs*n_embd]
+ Otherwise, returns an empty span.
+
+
+
+
+
+
+ Apply chat template. Inspired by hf apply_chat_template() on python.
+
+ A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead.
+ Pointer to a list of multiple llama_chat_message
+ Number of llama_chat_message in this chat
+ Whether to end the prompt with the token(s) that indicate the start of an assistant message.
+ A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)
+ The size of the allocated buffer
+ The total number of bytes of the formatted prompt. If it is larger than the size of the buffer, you may need to re-allocate it and then re-apply the template.
+
+
+
+ Get list of built-in chat templates
+
+
+
+
+
+
+
+ Print out timing information for this context
+
+
+
+
+
+ Print system information
+
+
+
+
+
+ Convert a single token into text
+
+
+
+ buffer to write string into
+ User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix')
+ If true, special tokens are rendered in the output
+ The length written, or if the buffer is too small a negative that indicates the length required
+
+
+
+ Convert text into tokens
+
+
+
+
+ The tokens pointer must be large enough to hold the resulting tokens.
+
+ add_special: allow adding BOS and EOS tokens if the model is configured to do so.
+ Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.
+ Returns the number of tokens on success, no more than n_max_tokens.
+ Returns a negative number on failure - the number of tokens that would have been returned
+
+
+
+
+ Convert the provided tokens into text (inverse of llama_tokenize()).
+
+
+
+
+ The char pointer must be large enough to hold the resulting text.
+
+ remove_special: allow removing BOS and EOS tokens if the model is configured to do so.
+ unparse_special: if true, special tokens are rendered in the output.
+ Returns the number of chars/bytes on success, no more than textLengthMax. Returns a negative number on failure - the number of chars/bytes that would have been returned.
+
+
+
+ Register a callback to receive llama log messages
+
+
+
+
+
+ Returns the number of tokens in the KV cache (slow, use only for debug)
+ If a KV cell has multiple sequences assigned to it, it will be counted multiple times
+
+
+
+
+
+
+ Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
+
+
+
+
+
+
+ Clear the KV cache. Both cell info is erased and KV data is zeroed
+
+
+
+
+
+ Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
+
+
+
+
+
+ Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
+
+
+
+ Copy all tokens that belong to the specified sequence to another sequence
+ Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
+
+
+
+
+
+
+
+
+
+ Removes all tokens that do not belong to the specified sequence
+
+
+
+
+
+
+ Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
+ If the KV cache is RoPEd, the KV data is updated accordingly:
+ - lazily on next llama_decode()
+ - explicitly with llama_kv_cache_update()
+
+
+
+
+
+
+
+
+
+ Integer division of the positions by factor of `d > 1`
+ If the KV cache is RoPEd, the KV data is updated accordingly:
+ - lazily on next llama_decode()
+ - explicitly with llama_kv_cache_update()
+
+ p0 < 0 : [0, p1]
+
+ p1 < 0 : [p0, inf)
+
+
+
+
+
+
+
+
+
+ Returns the largest position present in the KV cache for the specified sequence
+
+
+
+
+
+
+
+ Allocates a batch of tokens on the heap
+ Each token can be assigned up to n_seq_max sequence ids
+ The batch has to be freed with llama_batch_free()
+ If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
+ Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
+ The rest of the llama_batch members are allocated with size n_tokens
+ All members are left uninitialized
+
+
+
+ Each token can be assigned up to n_seq_max sequence ids
+
+
+
+ Frees a batch of tokens allocated with llama_batch_init()
+
+
+
+
+
+ Apply a loaded control vector to a llama_context, or if data is NULL, clear
+ the currently loaded vector.
+ n_embd should be the size of a single layer's control, and data should point
+ to an n_embd x n_layers buffer starting from layer 1.
+ il_start and il_end are the layer range the vector should apply to (both inclusive)
+ See llama_control_vector_load in common to load a control vector.
+
+
+
+
+
+
+
+
+
+
+
+ Build a split GGUF final path for this chunk.
+ llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf"
+
+
+
+
+
+
+ Returns the split_path length.
+
+
+
+ Extract the path prefix from the split_path if and only if the split_no and split_count match.
+ llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0"
+
+
+
+
+
+
+ Returns the split_prefix length.
+
+
+
+ Sanity check for clip <-> llava embed size match
+
+ LLama Context
+ Llava Model
+ True if validation succeeds
+
+
+
+ Build an image embed from image file bytes
+
+ SafeHandle to the Clip Model
+ Number of threads
+ Binary image in jpeg format
+ Bytes length of the image
+ SafeHandle to the Embeddings
+
+
+
+ Build an image embed from a path to an image filename
+
+ SafeHandle to the Clip Model
+ Number of threads
+ Image filename (jpeg) to generate embeddings
+ SafeHandle to the embeddings
+
+
+
+ Free an embedding made with llava_image_embed_make_*
+
+ Embeddings to release
+
+
+
+ Write the image represented by embed into the llama context with batch size n_batch, starting at context
+ pos n_past. on completion, n_past points to the next position in the context after the image embed.
+
+ Llama Context
+ Embedding handle
+ True on success
+
+
+
+ Get the loaded native library. If you are using netstandard2.0, it will always return null.
+
+
+
+
+
+
+
+ Returns 0 on success
+
+
+
+
+ Returns 0 on success
+
+
+
+ Configure llama.cpp logging
+
+
+
+
+ Callback from llama.cpp with log messages
+
+
+
+
+
+
+ Register a callback to receive llama log messages
+
+
+
+
+
+ A GC handle for the current log callback to ensure the callback is not collected
+
+
+
+
+ Register a callback to receive llama log messages
+
+
+
+
+
+ Register a callback to receive llama log messages
+
+
+
+
+
+ RoPE scaling type.
+
+ C# equivalent of llama_rope_scaling_type
+
+
+
+ No particular scaling type has been specified
+
+
+
+
+ Do not apply any RoPE scaling
+
+
+
+
+ Positional linear interpolation, as described by kaiokendev: https://kaiokendev.github.io/til#extending-context-to-8k
+
+
+
+
+ YaRN scaling: https://arxiv.org/pdf/2309.00071.pdf
+
+
+
+
+ LongRope scaling
+
+
+
+
+ A safe wrapper around a llama_context
+
+
+
+
+ Total number of tokens in the context
+
+
+
+
+ Dimension of embedding vectors
+
+
+
+
+ Get the maximum batch size for this context
+
+
+
+
+ Get the physical maximum batch size for this context
+
+
+
+
+ Get or set the number of threads used for generation of a single token.
+
+
+
+
+ Get or set the number of threads used for prompt and batch processing (multiple token).
+
+
+
+
+ Get the pooling type for this context
+
+
+
+
+ Get the model which this context is using
+
+
+
+
+ Get the vocabulary for the model this context is using
+
+
+
+
+
+
+
+ Create a new llama_state for the given model
+
+
+
+
+
+
+
+
+ Create a new llama_context with the given model. **This should never be called directly! Always use SafeLLamaContextHandle.Create**!
+
+
+
+
+
+
+
+ Frees all allocated memory in the given llama_context
+
+
+
+
+
+ Set a callback which can abort computation
+
+
+
+
+
+
+
+ If this returns true computation is cancelled
+
+
+
+
+
+
+
+
+
+ Positive return values do not indicate a fatal error, but rather a warning:
+ - 0: success
+ - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+ - < 0: error
+
+
+
+
+ Processes a batch of tokens with the encoder part of the encoder-decoder model. Stores the encoder output
+ internally for later use by the decoder cross-attention layers.
+
+
+
+ 0 = success
+ < 0 = error
+
+
+
+ Set the number of threads used for decoding
+
+
+ n_threads is the number of threads used for generation (single token)
+ n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
+
+
+
+
+ Get the number of threads used for generation of a single token.
+
+
+
+
+
+
+ Get the number of threads used for prompt and batch processing (multiple token).
+
+
+
+
+
+
+ Token logits obtained from the last call to llama_decode
+ The logits for the last token are stored in the last row
+ Can be mutated in order to change the probabilities of the next token.
+ Rows: n_tokens
+ Cols: n_vocab
+
+
+
+
+
+
+ Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab
+
+
+
+
+
+
+
+ Get the size of the context window for the model for this context
+
+
+
+
+
+
+ Get the batch size for this context
+
+
+
+
+
+
+ Get the ubatch size for this context
+
+
+
+
+
+
+ Returns the **actual** size in bytes of the state (logits, embedding and kv_cache).
+ Only use when saving the state, not when restoring it, otherwise the size may be too small.
+
+
+
+
+
+
+ Copies the state to the specified destination address.
+ Destination needs to have allocated enough memory.
+
+
+
+
+ the number of bytes copied
+
+
+
+ Set the state reading from the specified address
+
+
+
+
+ the number of bytes read
+
+
+
+ Get the exact size needed to copy the KV cache of a single sequence
+
+
+
+
+ Copy the KV cache of a single sequence into the specified buffer
+
+
+
+
+
+
+
+
+
+ Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
+
+
+
+
+
+
+ - Positive: Ok
+ - Zero: Failed to load
+
+
+
+
+ Defragment the KV cache. This will be applied:
+ - lazily on next llama_decode()
+ - explicitly with llama_kv_cache_update()
+
+
+
+
+
+
+ Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
+
+
+
+
+
+ Check if the context supports KV cache shifting
+
+
+
+
+
+
+ Wait until all computations are finished. This is automatically done when using any of the functions to obtain computation results
+ and is not necessary to call it explicitly in most cases.
+
+
+
+
+
+ Get the pooling type for this context
+
+
+
+
+
+
+ Get the embeddings for a sequence id.
+ Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
+ when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
+ otherwise: float[n_embd] (1-dimensional)
+
+ A pointer to the first float in an embedding, length = ctx.EmbeddingSize
+
+
+
+ Get the embeddings for the ith sequence.
+ Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
+
+ A pointer to the first float in an embedding, length = ctx.EmbeddingSize
+
+
+
+ Add a LoRA adapter to this context
+
+
+
+
+
+
+
+
+ Remove a LoRA adapter from this context
+
+
+ Indicates if the LoRA was in this context and was removed
+
+
+
+ Remove all LoRA adapters from this context
+
+
+
+
+ Token logits obtained from the last call to llama_decode.
+ The logits for the last token are stored in the last row.
+ Only tokens with `logits = true` requested are present.
+ Can be mutated in order to change the probabilities of the next token.
+ Rows: n_tokens
+ Cols: n_vocab
+
+
+ The number of tokens whose logits should be retrieved, in [numTokens x n_vocab] format.
+ Tokens are ordered as they appear in the LLamaBatch (first tokens first, etc).
+ This is helpful when requesting logits for many tokens in a sequence, or when decoding multiple sequences in one go.
+
+
+
+
+
+ Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab
+
+
+
+
+
+
+ Get the embeddings for the ith sequence.
+ Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
+
+ A pointer to the first float in an embedding, length = ctx.EmbeddingSize
+
+
+
+ Get the embeddings for a specific sequence.
+ Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
+
+ A pointer to the first float in an embedding, length = ctx.EmbeddingSize
+
+
+
+ Convert the given text into tokens
+
+ The text to tokenize
+ Whether the "BOS" token should be added
+ Encoding to use for the text
+ Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
+
+
+
+
+
+ Convert a single llama token into bytes
+
+ Token to decode
+ A span to attempt to write into. If this is too small nothing will be written
+ The size of this token. **nothing will be written** if this is larger than `dest`
+
+
+
+ This object exists to ensure there is only ever 1 inference running at a time. This is a workaround for thread safety issues in llama.cpp itself.
+ Most notably CUDA, which seems to use some global singleton resources and will crash if multiple inferences are run (even against different models).
+
+ For more information see these issues:
+ - https://github.com/SciSharp/LLamaSharp/issues/596
+ - https://github.com/ggerganov/llama.cpp/issues/3960
+
+ If these are ever resolved this lock can probably be removed.
+
+
+
+
+ Wait until all computations are finished. This is automatically done when using any of the functions to obtain computation results
+ and is not necessary to call it explicitly in most cases.
+
+
+
+
+ Processes a batch of tokens with the encoder part of the encoder-decoder model. Stores the encoder output
+ internally for later use by the decoder cross-attention layers.
+
+
+ 0 = success
+ < 0 = error (the KV cache state is restored to the state before this call)
+
+
+
+
+
+ Positive return values do not indicate a fatal error, but rather a warning:
+ - 0: success
+ - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+ - < 0: error (the KV cache state is restored to the state before this call)
+
+
+
+
+ Decode a set of tokens in batch-size chunks.
+
+
+
+
+
+ A tuple, containing the decode result and the number of tokens that have not been decoded yet.
+
+
+
+
+
+ Positive return values do not indicate a fatal error, but rather a warning:
+ - 0: success
+ - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+ - < 0: error
+
+
+
+
+ Get the size of the state, when saved as bytes
+
+
+
+
+ Get the size of the KV cache for a single sequence ID, when saved as bytes
+
+
+
+
+
+
+ Get the raw state of this context, encoded as bytes. Data is written into the `dest` pointer.
+
+ Destination to write to
+ Number of bytes available to write to in dest (check required size with `GetStateSize()`)
+ The number of bytes written to dest
+ Thrown if dest is too small
+
+
+
+ Get the raw state of a single sequence from this context, encoded as bytes. Data is written into the `dest` pointer.
+
+ Destination to write to
+ Number of bytes available to write to in dest (check required size with `GetStateSize()`)
+ The sequence to get state data for
+ The number of bytes written to dest
+
+
+
+ Set the raw state of this context
+
+ The pointer to read the state from
+ Number of bytes that can be safely read from the pointer
+ Number of bytes read from the src pointer
+
+
+
+ Set the raw state of a single sequence
+
+ The pointer to read the state from
+ Sequence ID to set
+ Number of bytes that can be safely read from the pointer
+ Number of bytes read from the src pointer
+
+
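+ A sketch of round-tripping the context state through a byte buffer using the members above (the exact member names and unsafe pointer signatures are assumptions):
+
+ // Sketch: snapshot the context state and restore it later.
+ var size = context.NativeHandle.GetStateSize();
+ var buffer = new byte[checked((int)size)];
+
+ unsafe
+ {
+     fixed (byte* ptr = buffer)
+     {
+         var written = context.NativeHandle.GetState(ptr, size);   // save: write state into the buffer
+
+         // ... run some inference that mutates the context ...
+
+         var read = context.NativeHandle.SetState(ptr, size);      // restore: read state back from the buffer
+     }
+ }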
+
+ Get performance information
+
+
+
+
+
+ Reset all performance information for this context
+
+
+
+
+ Check if the context supports KV cache shifting
+
+
+
+
+ Apply KV cache updates (such as K-shifts, defragmentation, etc.)
+
+
+
+
+ Defragment the KV cache. This will be applied:
+ - lazily on next llama_decode()
+ - explicitly with llama_kv_cache_update()
+
+
+
+
+
+ Get a new KV cache view that can be used to debug the KV cache
+
+
+
+
+
+
+ Count the number of used cells in the KV cache (i.e. have at least one sequence assigned to them)
+
+
+
+
+
+ Returns the number of tokens in the KV cache (slow, use only for debug)
+ If a KV cell has multiple sequences assigned to it, it will be counted multiple times
+
+
+
+
+
+ Clear the KV cache - both cell info is erased and KV data is zeroed
+
+
+
+
+ Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
+
+
+
+
+
+
+
+ Copy all tokens that belong to the specified sequence to another sequence. Note that
+ this does not allocate extra KV cache memory - it simply assigns the tokens to the
+ new sequence
+
+
+
+
+
+
+
+
+ Removes all tokens that do not belong to the specified sequence
+
+
+
+
+
+ Adds relative position "delta" to all tokens that belong to the specified sequence
+ and have positions in [p0, p1). If the KV cache is RoPEd, the KV data is updated
+ accordingly
+
+
+
+
+
+
+
+
+ Integer division of the positions by factor of `d > 1`.
+ If the KV cache is RoPEd, the KV data is updated accordingly.
+ p0 < 0 : [0, p1]
+ p1 < 0 : [p0, inf)
+
+
+
+
+
+
+
+
+ Returns the largest position present in the KV cache for the specified sequence
+
+
+
+
+
+
+ Base class for all llama handles to native resources
+
+
+
+
+
+
+
+
+
+
+ A reference to a set of llama model weights
+
+
+
+
+ Get the rope (positional embedding) type for this model
+
+
+
+
+ The number of tokens in the context that this model was trained for
+
+
+
+
+ Get the rope frequency this model was trained with
+
+
+
+
+ Dimension of embedding vectors
+
+
+
+
+ Get the size of this model in bytes
+
+
+
+
+ Get the number of parameters in this model
+
+
+
+
+ Get the number of layers in this model
+
+
+
+
+ Get the number of heads in this model
+
+
+
+
+ Returns true if the model contains an encoder that requires llama_encode() call
+
+
+
+
+ Returns true if the model contains a decoder that requires llama_decode() call
+
+
+
+
+ Returns true if the model is recurrent (like Mamba, RWKV, etc.)
+
+
+
+
+ Get a description of this model
+
+
+
+
+ Get the number of metadata key/value pairs
+
+
+
+
+
+ Get the vocabulary of this model
+
+
+
+
+
+
+
+ Load a model from the given file path into memory
+
+
+
+
+
+
+
+
+ Load the model from a file
+ If the file is split into multiple parts, the file name must follow this pattern: {name}-%05d-of-%05d.gguf
+ If the split file name does not follow this pattern, use llama_model_load_from_splits
+
+
+
+ The loaded model, or null on failure.
+
+
+
+ Load the model from multiple splits (support custom naming scheme)
+ The paths must be in the correct order
+
+
+
+
+
+ Apply a LoRA adapter to a loaded model
+ path_base_model is the path to a higher quality model to use as a base for
+ the layers modified by the adapter. Can be NULL to use the current loaded model.
+ The model needs to be reloaded before applying a new adapter, otherwise the adapter
+ will be applied on top of the previous one
+
+
+
+
+
+
+ Returns 0 on success
+
+
+
+ Frees all allocated memory associated with a model
+
+
+
+
+
+ Get the number of metadata key/value pairs
+
+
+
+
+
+
+ Get metadata key name by index
+
+ Model to fetch from
+ Index of key to fetch
+ buffer to write result into
+ The length of the string on success (even if the buffer is too small). -1 if the key does not exist.
+
+
+
+ Get metadata value as a string by index
+
+ Model to fetch from
+ Index of val to fetch
+ Buffer to write result into
+ The length of the string on success (even if the buffer is too small). -1 if the key does not exist.
+
+
+
+ Get metadata value as a string by key name
+
+
+
+
+ The length of the string on success, or -1 on failure
+
+
+
+ Get the number of tokens in the model vocabulary
+
+
+
+
+
+
+ Get the size of the context window for the model
+
+
+
+
+
+
+ Get the dimension of embedding vectors from this model
+
+
+
+
+
+
+ Get the number of layers in this model
+
+
+
+
+
+
+ Get the number of heads in this model
+
+
+
+
+
+
+ Get a string describing the model type
+
+
+
+
+ The length of the string on success (even if the buffer is too small), or -1 on failure
+
+
+
+ Get the size of the model in bytes
+
+
+ The size of the model
+
+
+
+ Get the number of parameters in this model
+
+
+ The number of parameters in the model
+
+
+
+ Get the model's RoPE frequency scaling factor
+
+
+
+
+
+
+ For encoder-decoder models, this function returns id of the token that must be provided
+ to the decoder to start generating output sequence. For other models, it returns -1.
+
+
+
+
+
+ Returns true if the model contains an encoder that requires llama_encode() call
+
+
+
+
+
+
+ Returns true if the model contains a decoder that requires llama_decode() call
+
+
+
+
+
+
+ Returns true if the model is recurrent (like Mamba, RWKV, etc.)
+
+
+
+
+
+
+ Load a LoRA adapter from file. The adapter will be associated with this model but will not be applied
+
+
+
+
+
+
+
+ Convert a single llama token into bytes
+
+ Token to decode
+ A span to attempt to write into. If this is too small nothing will be written
+ User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix')
+ If true, special characters will be converted to text. If false they will be invisible.
+ The size of this token. **nothing will be written** if this is larger than `dest`
+
+
+
+ Convert a sequence of tokens into characters.
+
+
+
+
+ The section of the span which has valid data in it.
+ If there was insufficient space in the output span this will be
+ filled with as many characters as possible, starting from the _last_ token.
+
+
+
+
+ Convert a string of text into tokens
+
+
+
+
+ Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
+
+
+
+
+ Create a new context for this model
+
+
+
+
+
+
+ Get the metadata value for the given key
+
+ The key to fetch
+ The value, null if there is no such key
+
+
+
+ Get the metadata key for the given index
+
+ The index to get
+ The key, null if there is no such key or if the buffer was too small
+
+
+
+ Get the metadata value for the given index
+
+ The index to get
+ The value, null if there is no such value or if the buffer was too small
+
+
+
+ Get the default chat template. Returns nullptr if not available
+ If name is NULL, returns the default chat template
+
+
+
+
+
+
+ Get tokens for a model
+
+
+
+
+ Total number of tokens in this vocabulary
+
+
+
+
+ Get the type of this vocabulary
+
+
+
+
+ Get the Beginning of Sentence token for this model
+
+
+
+
+ Get the End of Sentence token for this model
+
+
+
+
+ Get the newline token for this model
+
+
+
+
+ Get the padding token for this model
+
+
+
+
+ Get the sentence separator token for this model
+
+
+
+
+ Codellama beginning of infill prefix
+
+
+
+
+ Codellama beginning of infill middle
+
+
+
+
+ Codellama beginning of infill suffix
+
+
+
+
+ Codellama pad
+
+
+
+
+ Codellama rep
+
+
+
+
+ Codellama rep
+
+
+
+
+ end-of-turn token
+
+
+
+
+ For encoder-decoder models, this function returns id of the token that must be provided
+ to the decoder to start generating output sequence.
+
+
+
+
+ Check if the current model requires a BOS token added
+
+
+
+
+ Check if the current model requires a EOS token added
+
+
+
+
+ A chain of sampler stages that can be used to select tokens from logits.
+
+ Wraps a handle returned from `llama_sampler_chain_init`. Other samplers are owned by this chain and are never directly exposed.
+
+
+
+ Get the number of samplers in this chain
+
+
+
+
+
+
+
+ Apply this sampler to a set of candidates
+
+
+
+
+
+ Sample and accept a token from the idx-th output of the last evaluation. Shorthand for:
+
+
+ var logits = ctx.GetLogitsIth(idx);
+ var token_data_array = LLamaTokenDataArray.Create(logits);
+ using LLamaTokenDataArrayNative.Create(token_data_array, out var native_token_data);
+ sampler_chain.Apply(native_token_data);
+ var token = native_token_data.Data.Span[native_token_data.Selected];
+ sampler_chain.Accept(token);
+ return token;
+
+
+
+
+
+
+
+ Reset the state of this sampler
+
+
+
+
+ Accept a token and update the internal state of this sampler
+
+
+
+
+
+ Get the name of the sampler at the given index
+
+
+
+
+
+
+ Get the seed of the sampler at the given index, if applicable. Returns LLAMA_DEFAULT_SEED otherwise.
+
+
+
+
+
+
+ Create a new sampler chain
+
+
+
+
+
+
+ Clone a sampler stage from another chain and add it to this chain
+
+ The chain to clone a stage from
+ The index of the stage to clone
+
+
+
+ Remove a sampler stage from this chain
+
+
+
+
+
+
+ Add a custom sampler stage
+
+
+
+
+
+
+ Add a sampler which picks the most likely token.
+
+
+
+
+
+ Add a sampler which picks from the probability distribution of all tokens
+
+
+
+
+
+ Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
+
+
+ The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
+ The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
+ The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
+
+
+
+
+
+ Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
+
+
+ The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
+ The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
+
+
+
+
+ Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
+
+
+
+
+
+ Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
+
+
+
+
+
+ Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
+
+
+
+
+ Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
+
+
+
+
+ Apply temperature to the logits.
+ If temperature is less than zero the maximum logit is left unchanged and the rest are set to -infinity
+
+
+
+
+
+ Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
+
+
+
+
+
+
+
+ XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
+
+
+
+
+
+
+
+
+ This sampler is meant to be used for fill-in-the-middle infilling, after top_k + top_p sampling
+
+ 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
+ 2. combine probs of tokens that have the same prefix
+
+ example:
+
+ - before:
+ "abc": 0.5
+ "abcd": 0.2
+ "abcde": 0.1
+ "dummy": 0.1
+
+ - after:
+ "abc": 0.8
+ "dummy": 0.1
+
+ 3. discard non-EOG tokens with low prob
+ 4. if no tokens are left -> pick EOT
+
+
+
+
+
+ Create a sampler which makes tokens impossible unless they match the grammar
+
+
+
+ Root rule of the grammar
+
+
+
+
+ Create a sampler using lazy grammar sampling: https://github.com/ggerganov/llama.cpp/pull/9639
+
+
+ Grammar in GBNF form
+ Root rule of the grammar
+ A list of tokens that will trigger the grammar sampler.
+ A list of words that will trigger the grammar sampler.
+
+
+
+
+ Create a sampler that applies various repetition penalties.
+
+ Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
+
+ How many tokens of history to consider when calculating penalties
+ Repetition penalty
+ Frequency penalty
+ Presence penalty
+
+
+
+
+ DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677.
+ Porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
+
+ The model this sampler will be used with
+
+ penalty multiplier, 0.0 = disabled
+ exponential base
+ repeated sequences longer than this are penalized
+ how many tokens to scan for repetitions (0 = entire context)
+
+
+
+ Create a sampler that applies a bias directly to the logits
+
+
+
+
+
+
+
+
+
+ llama_sampler_chain_params
+
+
+
+ whether to measure performance timings
+
+
+
+
+ Get the default LLamaSamplerChainParams
+
+
+
+
+
+ A bias to apply directly to a logit
+
+
+
+
+ The token to apply the bias to
+
+
+
+
+ The bias to add
+
+
+
+
+
+
+ llama_sampler_i
+
+
+
+ Get the name of this sampler
+
+
+
+
+
+
+ Update internal sampler state after a token has been chosen
+
+
+
+
+
+
+ Apply this sampler to a set of logits
+
+
+
+
+
+
+ Reset the internal state of this sampler
+
+
+
+
+
+ Create a clone of this sampler
+
+
+
+
+
+
+ Free all resources held by this sampler
+
+
+
+
+
+
+
+ llama_sampler
+
+
+
+ Holds the function pointers which make up the actual sampler
+
+
+
+
+ Any additional context this sampler needs; it may be anything. We will use it
+ to hold a GCHandle.
+
+
+
+
+ This GCHandle roots this object, preventing it from being freed.
+
+
+
+
+ A reference to the user code which implements the custom sampler
+
+
+
+
+ Get a pointer to a `llama_sampler` (LLamaSamplerNative) struct, suitable for passing to `llama_sampler_chain_add`
+
+
+
+
+
+
+ A custom sampler stage for modifying logits or selecting a token
+
+
+
+
+ The human readable name of this stage
+
+
+
+
+ Apply this stage to a set of logits.
+ This can modify logits or select a token (or both).
+ If logits are modified the Sorted flag must be set to false.
+
+
+ If the logits are no longer sorted after the custom sampler has run, it is critically important to
+ set Sorted=false. If unsure, always set it to false; this is a safe default.
+
+
+
+
+
+ Update the internal state of the sampler when a token is chosen
+
+
+
+
+
+ Reset the internal state of this sampler
+
+
+
+
+ Create a clone of this sampler
+
+
+
+
+ A Reference to a llava Image Embed handle
+
+
+
+
+ Get the model used to create this image embedding
+
+
+
+
+ Get the number of dimensions in an embedding
+
+
+
+
+ Get the number of "patches" in an image embedding
+
+
+
+
+ Create an image embed from an image file
+
+
+
+ Path to the image file. Supported formats:
+
+ - JPG
+ - PNG
+ - BMP
+ - TGA
+
+
+
+
+
+
+
+ Create an image embed from an image file
+
+
+ Path to the image file. Supported formats:
+
+ - JPG
+ - PNG
+ - BMP
+ - TGA
+
+
+
+
+
+
+
+
+ Create an image embed from the bytes of an image.
+
+
+
+ Image bytes. Supported formats:
+
+ - JPG
+ - PNG
+ - BMP
+ - TGA
+
+
+
+
+
+
+ Create an image embed from the bytes of an image.
+
+
+ Image bytes. Supported formats:
+
+ - JPG
+ - PNG
+ - BMP
+ - TGA
+
+
+
+
+
+
+
+
+
+
+ Copy the embeddings data to the destination span
+
+
+
+
+
+
+ A reference to a set of llava model weights.
+
+
+
+
+ Get the number of dimensions in an embedding
+
+
+
+
+ Get the number of "patches" in an image embedding
+
+
+
+
+
+
+
+ Load a model from the given file path into memory
+
+ MMP File (Multi-Modal Projections)
+ Verbosity level
+ SafeHandle of the Clip Model
+
+
+
+
+
+ Create the Image Embeddings.
+
+ LLama Context
+ Image filename (it supports jpeg format only)
+ return the SafeHandle of these embeddings
+
+
+
+ Create the Image Embeddings.
+
+ Image in binary format (it supports jpeg format only)
+ Number of threads to use
+ return the SafeHandle of these embeddings
+
+
+
+ Create the Image Embeddings.
+
+ LLama Context
+ Image in binary format (it supports jpeg format only)
+ return the SafeHandle of these embeddings
+
+
+
+ Create the Image Embeddings.
+
+ Image in binary format (it supports jpeg format only)
+ Number of threads to use
+ return the SafeHandle of these embeddings
+
+
+
+ Evaluates the image embeddings.
+
+ Llama Context
+ The current embeddings to evaluate
+
+ True on success
+
+
+
+ Load MULTI MODAL PROJECTIONS model / Clip Model
+
+ Model path/file
+ Verbosity level
+ SafeLlavaModelHandle
+
+
+
+ Frees MULTI MODAL PROJECTIONS model / Clip Model
+
+ Internal Pointer to the model
+
+
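Putting the llava pieces above together: the multi-modal projection (mmproj/Clip) model is loaded once, and image embeddings are then created per image from a file or from raw bytes. The sketch below is hedged; the names (LLavaWeights.LoadFromFile, CreateImageEmbeddings) and the mmproj path are assumptions based on the descriptions above, not a verified API reference.

```csharp
using LLama;

static class LlavaSketch
{
    // Hedged sketch: load the multi-modal projection model and create embeddings for one image.
    // The member names and the file paths are assumptions based on the documentation above.
    public static void EmbedImage(LLamaContext context)
    {
        using var llava = LLavaWeights.LoadFromFile("path/to/mmproj.gguf");

        // Create embeddings from an image file; an overload taking raw image bytes
        // is described above as well.
        using var embed = llava.CreateImageEmbeddings(context, "path/to/image.jpg");
    }
}
```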
+
+
+
+
+ Create a new sampler wrapping a llama.cpp sampler chain
+
+
+
+
+ Create a sampling chain. This will be called once; the base class will automatically dispose the chain.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ An implementation of ISamplingPipeline which mimics the default llama.cpp sampling
+
+
+
+
+ Bias values to add to certain logits
+
+
+
+
+ Repetition penalty, as described in https://arxiv.org/abs/1909.05858
+
+
+
+
+ Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
+ Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
+ so far, decreasing the model's likelihood to repeat the same line verbatim.
+
+
+
+
+ Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
+ Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
+ text so far, increasing the model's likelihood to talk about new topics.
+
+
+
+
+ How many tokens should be considered for penalties
+
+
+
+
+ Whether the newline token should be protected from being modified by penalty
+
+
+
+
+ Whether the EOS token should be suppressed. Setting this to 'true' prevents EOS from being sampled
+
+
+
+
+ Temperature to apply (higher temperature is more "creative")
+
+
+
+
+ Number of tokens to keep in TopK sampling
+
+
+
+
+ P value for locally typical sampling
+
+
+
+
+ P value for TopP sampling
+
+
+
+
+ P value for MinP sampling
+
+
+
+
+ Grammar to apply to constrain possible tokens
+
+
+
+
+ The minimum number of tokens to keep for samplers which remove tokens
+
+
+
+
+ Seed to use for random sampling
+
+
+
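Taken together, the settings above are the knobs on the default pipeline. A configuration sketch is shown below; the property names are inferred from the descriptions above and may differ between LLamaSharp versions, so treat them as assumptions rather than the definitive API.

```csharp
using LLama.Sampling;

// Sketch of configuring the default sampling pipeline. Property names follow the
// descriptions above (temperature, top-k, top-p, min-p, penalties, seed) but are
// assumptions about the exact API surface, not verified against a specific version.
var pipeline = new DefaultSamplingPipeline
{
    Temperature = 0.7f,      // higher temperature is more "creative"
    TopK = 40,               // number of tokens kept by top-k
    TopP = 0.9f,             // nucleus sampling threshold
    MinP = 0.05f,            // min-p threshold
    RepeatPenalty = 1.1f,    // repetition penalty
    FrequencyPenalty = 0f,   // OpenAI-style frequency penalty (-2.0 .. 2.0)
    PresencePenalty = 0f,    // OpenAI-style presence penalty (-2.0 .. 2.0)
    Seed = 42,               // seed for random sampling
};
```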
+
+
+
+
+ A grammar in GBNF form
+
+
+
+
+
+
+ A grammar in GBNF form
+
+
+
+
+
+
+
+
+
+
+
+
+ A sampling pipeline which always selects the most likely token
+
+
+
+
+ Grammar to apply to constrain possible tokens
+
+
+
+
+
+
+
+ Convert a span of logits into a single sampled token. This interface can be implemented to completely customise the sampling process.
+
+
+
+
+ Sample a single token from the given context at the given position
+
+ The context being sampled from
+ Position to sample logits from
+
+
+
+
+ Reset all internal state of the sampling pipeline
+
+
+
+
+ Update the pipeline, with knowledge that a particular token was just accepted
+
+
+
+
+
+ Extension methods for ISamplingPipeline
+
+
+
+
+ Sample a single token from the given context at the given position
+
+
+ The context being sampled from
+ Position to sample logits from
+
+
+
+
+ Decodes a stream of tokens into a stream of characters
+
+
+
+
+ The number of decoded characters waiting to be read
+
+
+
+
+ If true, special characters will be converted to text. If false, they will be invisible.
+
+
+
+
+ Create a new decoder
+
+ Text encoding to use
+ Model weights
+
+
+
+ Create a new decoder
+
+ Context to retrieve encoding and model weights from
+
+
+
+ Create a new decoder
+
+ Text encoding to use
+ Context to retrieve model weights from
+
+
+
+ Create a new decoder
+
+ Text encoding to use
+ Models weights to use
+
+
+
+ Add a single token to the decoder
+
+
+
+
+
+ Add a single token to the decoder
+
+
+
+
+
+ Add all tokens in the given enumerable
+
+
+
+
+
+ Add all tokens in the given span
+
+
+
+
+
+ Read all decoded characters and clear the buffer
+
+
+
+
+
+ Read all decoded characters as a string and clear the buffer
+
+
+
+
+
+ Set the decoder back to its initial state
+
+
+
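The decoder described above is an accumulator: tokens go in one at a time, and decoded characters come out whenever you read. A hedged usage sketch follows; the member names track the descriptions above and are assumptions about the exact signatures.

```csharp
using System;
using System.Collections.Generic;
using LLama;
using LLama.Native;

static class DecoderSketch
{
    // Sketch only: feed tokens into the decoder as they are produced and read back text.
    // `context` is an existing LLamaContext and `sampledTokens` are tokens produced by a
    // sampling pipeline; both are assumed to exist elsewhere in the program.
    public static void PrintTokens(LLamaContext context, IEnumerable<LLamaToken> sampledTokens)
    {
        var decoder = new StreamingTokenDecoder(context);

        foreach (var token in sampledTokens)
        {
            decoder.Add(token);            // add a single token to the decoder
            Console.Write(decoder.Read()); // read decoded characters and clear the buffer
        }

        decoder.Reset();                   // set the decoder back to its initial state
    }
}
```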
+
+ A prompt formatter that will use llama.cpp's template formatter
+ If your model is not supported, you will need to define your own formatter according to the chat prompt specification for your model
+
+
+
+
+ A prompt formatter that will use llama.cpp's template formatter
+ If your model is not supported, you will need to define your own formatter according to the chat prompt specification for your model
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Apply the template to the messages and return the resulting prompt as a string
+
+
+ The formatted template string as defined by the model
+
+
+
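A hedged sketch of the formatter described above is shown below: messages are added with a role and content, the model's built-in chat template is applied, and the formatted prompt is decoded to a string. The class and member names (LLamaTemplate, Add, Apply) are assumptions based on the descriptions above; check the installed version before relying on them.

```csharp
using System.Text;
using LLama;

static class TemplateSketch
{
    // Hedged sketch: build a prompt using the model's chat template. Names are assumptions.
    public static string FormatPrompt(LLamaWeights model)
    {
        var template = new LLamaTemplate(model);

        template.Add("system", "You are a helpful assistant.");
        template.Add("user", "Hello!");

        // Apply the template and decode the formatted prompt into a string.
        return Encoding.UTF8.GetString(template.Apply());
    }
}
```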
diff --git a/Assets/Packages/LLamaSharp.0.21.0/lib/netstandard2.0/LLamaSharp.xml.meta b/Assets/Packages/LLamaSharp.0.21.0/lib/netstandard2.0/LLamaSharp.xml.meta
new file mode 100644
index 00000000..d0872ff9
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.0.21.0/lib/netstandard2.0/LLamaSharp.xml.meta
@@ -0,0 +1,7 @@
+fileFormatVersion: 2
+guid: d026a86dc56a0efc8b38919a8f0a0207
+TextScriptImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0.meta
new file mode 100644
index 00000000..3a170dec
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: a7da5a0204981c7c1836ed7925726d56
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/.signature.p7s b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/.signature.p7s
new file mode 100644
index 00000000..99d44527
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/.signature.p7s differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/LLamaSharp.Backend.Cpu.nuspec b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/LLamaSharp.Backend.Cpu.nuspec
new file mode 100644
index 00000000..a1bde715
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/LLamaSharp.Backend.Cpu.nuspec
@@ -0,0 +1,18 @@
+
+
+
+ LLamaSharp.Backend.Cpu
+ 0.21.0
+ LLamaSharp.Backend.Cpu, the backend for LLamaSharp
+ llama.cpp Authors
+ false
+ MIT
+ https://licenses.nuget.org/MIT
+ icon512.png
+ https://github.com/SciSharp/LLamaSharp
+ LLamaSharp.Backend.Cpu is a backend for LLamaSharp to use with Cpu only.
+
+ Copyright 2023 The llama.cpp Authors. All rights reserved.
+ LLamaSharp LLama LLM GPT AI ChatBot SciSharp
+
+
\ No newline at end of file
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/LLamaSharp.Backend.Cpu.nuspec.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/LLamaSharp.Backend.Cpu.nuspec.meta
new file mode 100644
index 00000000..d108bf99
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/LLamaSharp.Backend.Cpu.nuspec.meta
@@ -0,0 +1,7 @@
+fileFormatVersion: 2
+guid: 9c95dd7d1a81e38d88fa6c68ef8a80ec
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/icon512.png b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/icon512.png
new file mode 100644
index 00000000..d7940900
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/icon512.png differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/icon512.png.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/icon512.png.meta
new file mode 100644
index 00000000..eb21a9cc
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/icon512.png.meta
@@ -0,0 +1,117 @@
+fileFormatVersion: 2
+guid: fe8951864279553a4aa23f7cfb96bc31
+TextureImporter:
+ internalIDToNameTable: []
+ externalObjects: {}
+ serializedVersion: 13
+ mipmaps:
+ mipMapMode: 0
+ enableMipMap: 1
+ sRGBTexture: 1
+ linearTexture: 0
+ fadeOut: 0
+ borderMipMap: 0
+ mipMapsPreserveCoverage: 0
+ alphaTestReferenceValue: 0.5
+ mipMapFadeDistanceStart: 1
+ mipMapFadeDistanceEnd: 3
+ bumpmap:
+ convertToNormalMap: 0
+ externalNormalMap: 0
+ heightScale: 0.25
+ normalMapFilter: 0
+ flipGreenChannel: 0
+ isReadable: 0
+ streamingMipmaps: 0
+ streamingMipmapsPriority: 0
+ vTOnly: 0
+ ignoreMipmapLimit: 0
+ grayScaleToAlpha: 0
+ generateCubemap: 6
+ cubemapConvolution: 0
+ seamlessCubemap: 0
+ textureFormat: 1
+ maxTextureSize: 2048
+ textureSettings:
+ serializedVersion: 2
+ filterMode: 1
+ aniso: 1
+ mipBias: 0
+ wrapU: 0
+ wrapV: 0
+ wrapW: 0
+ nPOTScale: 1
+ lightmap: 0
+ compressionQuality: 50
+ spriteMode: 0
+ spriteExtrude: 1
+ spriteMeshType: 1
+ alignment: 0
+ spritePivot: {x: 0.5, y: 0.5}
+ spritePixelsToUnits: 100
+ spriteBorder: {x: 0, y: 0, z: 0, w: 0}
+ spriteGenerateFallbackPhysicsShape: 1
+ alphaUsage: 1
+ alphaIsTransparency: 0
+ spriteTessellationDetail: -1
+ textureType: 0
+ textureShape: 1
+ singleChannelComponent: 0
+ flipbookRows: 1
+ flipbookColumns: 1
+ maxTextureSizeSet: 0
+ compressionQualitySet: 0
+ textureFormatSet: 0
+ ignorePngGamma: 0
+ applyGammaDecoding: 0
+ swizzle: 50462976
+ cookieLightType: 0
+ platformSettings:
+ - serializedVersion: 4
+ buildTarget: DefaultTexturePlatform
+ maxTextureSize: 2048
+ resizeAlgorithm: 0
+ textureFormat: -1
+ textureCompression: 1
+ compressionQuality: 50
+ crunchedCompression: 0
+ allowsAlphaSplitting: 0
+ overridden: 0
+ ignorePlatformSupport: 0
+ androidETC2FallbackOverride: 0
+ forceMaximumCompressionQuality_BC6H_BC7: 0
+ - serializedVersion: 4
+ buildTarget: Standalone
+ maxTextureSize: 2048
+ resizeAlgorithm: 0
+ textureFormat: -1
+ textureCompression: 1
+ compressionQuality: 50
+ crunchedCompression: 0
+ allowsAlphaSplitting: 0
+ overridden: 0
+ ignorePlatformSupport: 0
+ androidETC2FallbackOverride: 0
+ forceMaximumCompressionQuality_BC6H_BC7: 0
+ spriteSheet:
+ serializedVersion: 2
+ sprites: []
+ outline: []
+ customData:
+ physicsShape: []
+ bones: []
+ spriteID:
+ internalID: 0
+ vertices: []
+ indices:
+ edges: []
+ weights: []
+ secondaryTextures: []
+ spriteCustomMetadata:
+ entries: []
+ nameFileIdTable: {}
+ mipmapLimitGroupName:
+ pSDRemoveMatte: 0
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes.meta
new file mode 100644
index 00000000..f700da56
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 7e3b623f23645fbdbac08f862161c33e
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64.meta
new file mode 100644
index 00000000..02559478
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 9183345255e26f454bde3ccf838fc4ff
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native.meta
new file mode 100644
index 00000000..22cf2270
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 83493f3c59e82962ebd6bd5c580cecd5
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx.meta
new file mode 100644
index 00000000..cdaefd22
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 05dc3c20bbd0790ba843b3ea0a30ad58
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libggml-base.so b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libggml-base.so
new file mode 100644
index 00000000..d7649341
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libggml-base.so differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libggml-base.so.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libggml-base.so.meta
new file mode 100644
index 00000000..2435e26c
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libggml-base.so.meta
@@ -0,0 +1,69 @@
+fileFormatVersion: 2
+guid: 18ead8342c8644389b4bf8a480c12a43
+labels:
+- NuGetForUnity
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 3
+ iconMap: {}
+ executionOrder: {}
+ defineConstraints: []
+ isPreloaded: 0
+ isOverridable: 0
+ isExplicitlyReferenced: 0
+ validateReferences: 1
+ platformData:
+ Any:
+ enabled: 0
+ settings:
+ 'Exclude ': 1
+ Exclude Android: 1
+ Exclude Bratwurst: 1
+ Exclude CloudRendering: 1
+ Exclude Editor: 1
+ Exclude EmbeddedLinux: 1
+ Exclude GameCoreScarlett: 1
+ Exclude GameCoreXboxOne: 1
+ Exclude Linux64: 1
+ Exclude OSXUniversal: 1
+ Exclude PS4: 1
+ Exclude PS5: 1
+ Exclude QNX: 1
+ Exclude ReservedCFE: 1
+ Exclude Switch: 1
+ Exclude VisionOS: 1
+ Exclude WebGL: 1
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude WindowsStoreApps: 1
+ Exclude XboxOne: 1
+ Exclude iOS: 1
+ Exclude tvOS: 1
+ Editor:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Linux
+ Linux64:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ OSXUniversal:
+ enabled: 0
+ settings: {}
+ Win:
+ enabled: 0
+ settings: {}
+ Win64:
+ enabled: 0
+ settings: {}
+ WindowsStoreApps:
+ enabled: 0
+ settings: {}
+ iOS:
+ enabled: 0
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libggml-cpu.so b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libggml-cpu.so
new file mode 100644
index 00000000..160b803a
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libggml-cpu.so differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libggml-cpu.so.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libggml-cpu.so.meta
new file mode 100644
index 00000000..4df26c28
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libggml-cpu.so.meta
@@ -0,0 +1,69 @@
+fileFormatVersion: 2
+guid: 3c93668b5c8c4730b199b379fea4a185
+labels:
+- NuGetForUnity
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 3
+ iconMap: {}
+ executionOrder: {}
+ defineConstraints: []
+ isPreloaded: 0
+ isOverridable: 0
+ isExplicitlyReferenced: 0
+ validateReferences: 1
+ platformData:
+ Any:
+ enabled: 0
+ settings:
+ 'Exclude ': 1
+ Exclude Android: 1
+ Exclude Bratwurst: 1
+ Exclude CloudRendering: 1
+ Exclude Editor: 1
+ Exclude EmbeddedLinux: 1
+ Exclude GameCoreScarlett: 1
+ Exclude GameCoreXboxOne: 1
+ Exclude Linux64: 1
+ Exclude OSXUniversal: 1
+ Exclude PS4: 1
+ Exclude PS5: 1
+ Exclude QNX: 1
+ Exclude ReservedCFE: 1
+ Exclude Switch: 1
+ Exclude VisionOS: 1
+ Exclude WebGL: 1
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude WindowsStoreApps: 1
+ Exclude XboxOne: 1
+ Exclude iOS: 1
+ Exclude tvOS: 1
+ Editor:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Linux
+ Linux64:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ OSXUniversal:
+ enabled: 0
+ settings: {}
+ Win:
+ enabled: 0
+ settings: {}
+ Win64:
+ enabled: 0
+ settings: {}
+ WindowsStoreApps:
+ enabled: 0
+ settings: {}
+ iOS:
+ enabled: 0
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libggml.so b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libggml.so
new file mode 100644
index 00000000..a54abbcd
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libggml.so differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libggml.so.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libggml.so.meta
new file mode 100644
index 00000000..4204338f
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libggml.so.meta
@@ -0,0 +1,69 @@
+fileFormatVersion: 2
+guid: 9a6be9d930cd4ec7a4f02fb6ac468d3a
+labels:
+- NuGetForUnity
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 3
+ iconMap: {}
+ executionOrder: {}
+ defineConstraints: []
+ isPreloaded: 0
+ isOverridable: 0
+ isExplicitlyReferenced: 0
+ validateReferences: 1
+ platformData:
+ Any:
+ enabled: 0
+ settings:
+ 'Exclude ': 1
+ Exclude Android: 1
+ Exclude Bratwurst: 1
+ Exclude CloudRendering: 1
+ Exclude Editor: 1
+ Exclude EmbeddedLinux: 1
+ Exclude GameCoreScarlett: 1
+ Exclude GameCoreXboxOne: 1
+ Exclude Linux64: 1
+ Exclude OSXUniversal: 1
+ Exclude PS4: 1
+ Exclude PS5: 1
+ Exclude QNX: 1
+ Exclude ReservedCFE: 1
+ Exclude Switch: 1
+ Exclude VisionOS: 1
+ Exclude WebGL: 1
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude WindowsStoreApps: 1
+ Exclude XboxOne: 1
+ Exclude iOS: 1
+ Exclude tvOS: 1
+ Editor:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Linux
+ Linux64:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ OSXUniversal:
+ enabled: 0
+ settings: {}
+ Win:
+ enabled: 0
+ settings: {}
+ Win64:
+ enabled: 0
+ settings: {}
+ WindowsStoreApps:
+ enabled: 0
+ settings: {}
+ iOS:
+ enabled: 0
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libllama.so b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libllama.so
new file mode 100644
index 00000000..00296671
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libllama.so differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libllama.so.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libllama.so.meta
new file mode 100644
index 00000000..7457655a
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libllama.so.meta
@@ -0,0 +1,69 @@
+fileFormatVersion: 2
+guid: 0d49ee6176a54a96a4e54de5d967aaea
+labels:
+- NuGetForUnity
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 3
+ iconMap: {}
+ executionOrder: {}
+ defineConstraints: []
+ isPreloaded: 0
+ isOverridable: 0
+ isExplicitlyReferenced: 0
+ validateReferences: 1
+ platformData:
+ Any:
+ enabled: 0
+ settings:
+ 'Exclude ': 1
+ Exclude Android: 1
+ Exclude Bratwurst: 1
+ Exclude CloudRendering: 1
+ Exclude Editor: 1
+ Exclude EmbeddedLinux: 1
+ Exclude GameCoreScarlett: 1
+ Exclude GameCoreXboxOne: 1
+ Exclude Linux64: 1
+ Exclude OSXUniversal: 1
+ Exclude PS4: 1
+ Exclude PS5: 1
+ Exclude QNX: 1
+ Exclude ReservedCFE: 1
+ Exclude Switch: 1
+ Exclude VisionOS: 1
+ Exclude WebGL: 1
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude WindowsStoreApps: 1
+ Exclude XboxOne: 1
+ Exclude iOS: 1
+ Exclude tvOS: 1
+ Editor:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Linux
+ Linux64:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ OSXUniversal:
+ enabled: 0
+ settings: {}
+ Win:
+ enabled: 0
+ settings: {}
+ Win64:
+ enabled: 0
+ settings: {}
+ WindowsStoreApps:
+ enabled: 0
+ settings: {}
+ iOS:
+ enabled: 0
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libllava_shared.so b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libllava_shared.so
new file mode 100644
index 00000000..a4e6a752
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libllava_shared.so differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libllava_shared.so.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libllava_shared.so.meta
new file mode 100644
index 00000000..290f72a6
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx/libllava_shared.so.meta
@@ -0,0 +1,69 @@
+fileFormatVersion: 2
+guid: 2846428a4ce74e2590d4983c7afc719d
+labels:
+- NuGetForUnity
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 3
+ iconMap: {}
+ executionOrder: {}
+ defineConstraints: []
+ isPreloaded: 0
+ isOverridable: 0
+ isExplicitlyReferenced: 0
+ validateReferences: 1
+ platformData:
+ Any:
+ enabled: 0
+ settings:
+ 'Exclude ': 1
+ Exclude Android: 1
+ Exclude Bratwurst: 1
+ Exclude CloudRendering: 1
+ Exclude Editor: 1
+ Exclude EmbeddedLinux: 1
+ Exclude GameCoreScarlett: 1
+ Exclude GameCoreXboxOne: 1
+ Exclude Linux64: 1
+ Exclude OSXUniversal: 1
+ Exclude PS4: 1
+ Exclude PS5: 1
+ Exclude QNX: 1
+ Exclude ReservedCFE: 1
+ Exclude Switch: 1
+ Exclude VisionOS: 1
+ Exclude WebGL: 1
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude WindowsStoreApps: 1
+ Exclude XboxOne: 1
+ Exclude iOS: 1
+ Exclude tvOS: 1
+ Editor:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Linux
+ Linux64:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ OSXUniversal:
+ enabled: 0
+ settings: {}
+ Win:
+ enabled: 0
+ settings: {}
+ Win64:
+ enabled: 0
+ settings: {}
+ WindowsStoreApps:
+ enabled: 0
+ settings: {}
+ iOS:
+ enabled: 0
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2.meta
new file mode 100644
index 00000000..b66a900e
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 607ff69954bcd9198962176e57ef3e3b
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libggml-base.so b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libggml-base.so
new file mode 100644
index 00000000..d7649341
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libggml-base.so differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libggml-base.so.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libggml-base.so.meta
new file mode 100644
index 00000000..da6b6cf1
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libggml-base.so.meta
@@ -0,0 +1,69 @@
+fileFormatVersion: 2
+guid: 917aad0d60c5467bb00ed734df2fe416
+labels:
+- NuGetForUnity
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 3
+ iconMap: {}
+ executionOrder: {}
+ defineConstraints: []
+ isPreloaded: 0
+ isOverridable: 0
+ isExplicitlyReferenced: 0
+ validateReferences: 1
+ platformData:
+ Any:
+ enabled: 0
+ settings:
+ 'Exclude ': 1
+ Exclude Android: 1
+ Exclude Bratwurst: 1
+ Exclude CloudRendering: 1
+ Exclude Editor: 1
+ Exclude EmbeddedLinux: 1
+ Exclude GameCoreScarlett: 1
+ Exclude GameCoreXboxOne: 1
+ Exclude Linux64: 1
+ Exclude OSXUniversal: 1
+ Exclude PS4: 1
+ Exclude PS5: 1
+ Exclude QNX: 1
+ Exclude ReservedCFE: 1
+ Exclude Switch: 1
+ Exclude VisionOS: 1
+ Exclude WebGL: 1
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude WindowsStoreApps: 1
+ Exclude XboxOne: 1
+ Exclude iOS: 1
+ Exclude tvOS: 1
+ Editor:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Linux
+ Linux64:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ OSXUniversal:
+ enabled: 0
+ settings: {}
+ Win:
+ enabled: 0
+ settings: {}
+ Win64:
+ enabled: 0
+ settings: {}
+ WindowsStoreApps:
+ enabled: 0
+ settings: {}
+ iOS:
+ enabled: 0
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libggml-cpu.so b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libggml-cpu.so
new file mode 100644
index 00000000..b0b05aed
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libggml-cpu.so differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libggml-cpu.so.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libggml-cpu.so.meta
new file mode 100644
index 00000000..b8e99d64
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libggml-cpu.so.meta
@@ -0,0 +1,69 @@
+fileFormatVersion: 2
+guid: 23037954d7714574a2edc4b7f517f786
+labels:
+- NuGetForUnity
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 3
+ iconMap: {}
+ executionOrder: {}
+ defineConstraints: []
+ isPreloaded: 0
+ isOverridable: 0
+ isExplicitlyReferenced: 0
+ validateReferences: 1
+ platformData:
+ Any:
+ enabled: 0
+ settings:
+ 'Exclude ': 1
+ Exclude Android: 1
+ Exclude Bratwurst: 1
+ Exclude CloudRendering: 1
+ Exclude Editor: 1
+ Exclude EmbeddedLinux: 1
+ Exclude GameCoreScarlett: 1
+ Exclude GameCoreXboxOne: 1
+ Exclude Linux64: 1
+ Exclude OSXUniversal: 1
+ Exclude PS4: 1
+ Exclude PS5: 1
+ Exclude QNX: 1
+ Exclude ReservedCFE: 1
+ Exclude Switch: 1
+ Exclude VisionOS: 1
+ Exclude WebGL: 1
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude WindowsStoreApps: 1
+ Exclude XboxOne: 1
+ Exclude iOS: 1
+ Exclude tvOS: 1
+ Editor:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Linux
+ Linux64:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ OSXUniversal:
+ enabled: 0
+ settings: {}
+ Win:
+ enabled: 0
+ settings: {}
+ Win64:
+ enabled: 0
+ settings: {}
+ WindowsStoreApps:
+ enabled: 0
+ settings: {}
+ iOS:
+ enabled: 0
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libggml.so b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libggml.so
new file mode 100644
index 00000000..a54abbcd
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libggml.so differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libggml.so.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libggml.so.meta
new file mode 100644
index 00000000..748b8a6c
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libggml.so.meta
@@ -0,0 +1,69 @@
+fileFormatVersion: 2
+guid: a53660bb5d7749f2a79cff8824747e9e
+labels:
+- NuGetForUnity
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 3
+ iconMap: {}
+ executionOrder: {}
+ defineConstraints: []
+ isPreloaded: 0
+ isOverridable: 0
+ isExplicitlyReferenced: 0
+ validateReferences: 1
+ platformData:
+ Any:
+ enabled: 0
+ settings:
+ 'Exclude ': 1
+ Exclude Android: 1
+ Exclude Bratwurst: 1
+ Exclude CloudRendering: 1
+ Exclude Editor: 1
+ Exclude EmbeddedLinux: 1
+ Exclude GameCoreScarlett: 1
+ Exclude GameCoreXboxOne: 1
+ Exclude Linux64: 1
+ Exclude OSXUniversal: 1
+ Exclude PS4: 1
+ Exclude PS5: 1
+ Exclude QNX: 1
+ Exclude ReservedCFE: 1
+ Exclude Switch: 1
+ Exclude VisionOS: 1
+ Exclude WebGL: 1
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude WindowsStoreApps: 1
+ Exclude XboxOne: 1
+ Exclude iOS: 1
+ Exclude tvOS: 1
+ Editor:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Linux
+ Linux64:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ OSXUniversal:
+ enabled: 0
+ settings: {}
+ Win:
+ enabled: 0
+ settings: {}
+ Win64:
+ enabled: 0
+ settings: {}
+ WindowsStoreApps:
+ enabled: 0
+ settings: {}
+ iOS:
+ enabled: 0
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libllama.so b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libllama.so
new file mode 100644
index 00000000..00296671
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libllama.so differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libllama.so.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libllama.so.meta
new file mode 100644
index 00000000..7a2fbf35
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libllama.so.meta
@@ -0,0 +1,69 @@
+fileFormatVersion: 2
+guid: dde802fab50c43d795a81ca4e4a900c0
+labels:
+- NuGetForUnity
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 3
+ iconMap: {}
+ executionOrder: {}
+ defineConstraints: []
+ isPreloaded: 0
+ isOverridable: 0
+ isExplicitlyReferenced: 0
+ validateReferences: 1
+ platformData:
+ Any:
+ enabled: 0
+ settings:
+ 'Exclude ': 1
+ Exclude Android: 1
+ Exclude Bratwurst: 1
+ Exclude CloudRendering: 1
+ Exclude Editor: 1
+ Exclude EmbeddedLinux: 1
+ Exclude GameCoreScarlett: 1
+ Exclude GameCoreXboxOne: 1
+ Exclude Linux64: 1
+ Exclude OSXUniversal: 1
+ Exclude PS4: 1
+ Exclude PS5: 1
+ Exclude QNX: 1
+ Exclude ReservedCFE: 1
+ Exclude Switch: 1
+ Exclude VisionOS: 1
+ Exclude WebGL: 1
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude WindowsStoreApps: 1
+ Exclude XboxOne: 1
+ Exclude iOS: 1
+ Exclude tvOS: 1
+ Editor:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Linux
+ Linux64:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ OSXUniversal:
+ enabled: 0
+ settings: {}
+ Win:
+ enabled: 0
+ settings: {}
+ Win64:
+ enabled: 0
+ settings: {}
+ WindowsStoreApps:
+ enabled: 0
+ settings: {}
+ iOS:
+ enabled: 0
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libllava_shared.so b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libllava_shared.so
new file mode 100644
index 00000000..a4e6a752
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libllava_shared.so differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libllava_shared.so.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libllava_shared.so.meta
new file mode 100644
index 00000000..13dec23e
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx2/libllava_shared.so.meta
@@ -0,0 +1,69 @@
+fileFormatVersion: 2
+guid: 76b656e0c348450b8d9de65aa2d37a6a
+labels:
+- NuGetForUnity
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 3
+ iconMap: {}
+ executionOrder: {}
+ defineConstraints: []
+ isPreloaded: 0
+ isOverridable: 0
+ isExplicitlyReferenced: 0
+ validateReferences: 1
+ platformData:
+ Any:
+ enabled: 0
+ settings:
+ 'Exclude ': 1
+ Exclude Android: 1
+ Exclude Bratwurst: 1
+ Exclude CloudRendering: 1
+ Exclude Editor: 1
+ Exclude EmbeddedLinux: 1
+ Exclude GameCoreScarlett: 1
+ Exclude GameCoreXboxOne: 1
+ Exclude Linux64: 1
+ Exclude OSXUniversal: 1
+ Exclude PS4: 1
+ Exclude PS5: 1
+ Exclude QNX: 1
+ Exclude ReservedCFE: 1
+ Exclude Switch: 1
+ Exclude VisionOS: 1
+ Exclude WebGL: 1
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude WindowsStoreApps: 1
+ Exclude XboxOne: 1
+ Exclude iOS: 1
+ Exclude tvOS: 1
+ Editor:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Linux
+ Linux64:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ OSXUniversal:
+ enabled: 0
+ settings: {}
+ Win:
+ enabled: 0
+ settings: {}
+ Win64:
+ enabled: 0
+ settings: {}
+ WindowsStoreApps:
+ enabled: 0
+ settings: {}
+ iOS:
+ enabled: 0
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512.meta
new file mode 100644
index 00000000..61372f60
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 8c8f2d554d6da6a27a474454cea38f9e
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libggml-base.so b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libggml-base.so
new file mode 100644
index 00000000..d7649341
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libggml-base.so differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libggml-base.so.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libggml-base.so.meta
new file mode 100644
index 00000000..aa897283
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libggml-base.so.meta
@@ -0,0 +1,69 @@
+fileFormatVersion: 2
+guid: fa873b6f27534d4d85806a86dcbde8d7
+labels:
+- NuGetForUnity
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 3
+ iconMap: {}
+ executionOrder: {}
+ defineConstraints: []
+ isPreloaded: 0
+ isOverridable: 0
+ isExplicitlyReferenced: 0
+ validateReferences: 1
+ platformData:
+ Any:
+ enabled: 0
+ settings:
+ 'Exclude ': 1
+ Exclude Android: 1
+ Exclude Bratwurst: 1
+ Exclude CloudRendering: 1
+ Exclude Editor: 1
+ Exclude EmbeddedLinux: 1
+ Exclude GameCoreScarlett: 1
+ Exclude GameCoreXboxOne: 1
+ Exclude Linux64: 1
+ Exclude OSXUniversal: 1
+ Exclude PS4: 1
+ Exclude PS5: 1
+ Exclude QNX: 1
+ Exclude ReservedCFE: 1
+ Exclude Switch: 1
+ Exclude VisionOS: 1
+ Exclude WebGL: 1
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude WindowsStoreApps: 1
+ Exclude XboxOne: 1
+ Exclude iOS: 1
+ Exclude tvOS: 1
+ Editor:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Linux
+ Linux64:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ OSXUniversal:
+ enabled: 0
+ settings: {}
+ Win:
+ enabled: 0
+ settings: {}
+ Win64:
+ enabled: 0
+ settings: {}
+ WindowsStoreApps:
+ enabled: 0
+ settings: {}
+ iOS:
+ enabled: 0
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libggml-cpu.so b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libggml-cpu.so
new file mode 100644
index 00000000..e24ab24d
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libggml-cpu.so differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libggml-cpu.so.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libggml-cpu.so.meta
new file mode 100644
index 00000000..c0bef9a0
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libggml-cpu.so.meta
@@ -0,0 +1,69 @@
+fileFormatVersion: 2
+guid: da86c7c680e345a3949a7d3ad444218f
+labels:
+- NuGetForUnity
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 3
+ iconMap: {}
+ executionOrder: {}
+ defineConstraints: []
+ isPreloaded: 0
+ isOverridable: 0
+ isExplicitlyReferenced: 0
+ validateReferences: 1
+ platformData:
+ Any:
+ enabled: 0
+ settings:
+ 'Exclude ': 1
+ Exclude Android: 1
+ Exclude Bratwurst: 1
+ Exclude CloudRendering: 1
+ Exclude Editor: 1
+ Exclude EmbeddedLinux: 1
+ Exclude GameCoreScarlett: 1
+ Exclude GameCoreXboxOne: 1
+ Exclude Linux64: 1
+ Exclude OSXUniversal: 1
+ Exclude PS4: 1
+ Exclude PS5: 1
+ Exclude QNX: 1
+ Exclude ReservedCFE: 1
+ Exclude Switch: 1
+ Exclude VisionOS: 1
+ Exclude WebGL: 1
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude WindowsStoreApps: 1
+ Exclude XboxOne: 1
+ Exclude iOS: 1
+ Exclude tvOS: 1
+ Editor:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Linux
+ Linux64:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ OSXUniversal:
+ enabled: 0
+ settings: {}
+ Win:
+ enabled: 0
+ settings: {}
+ Win64:
+ enabled: 0
+ settings: {}
+ WindowsStoreApps:
+ enabled: 0
+ settings: {}
+ iOS:
+ enabled: 0
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libggml.so b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libggml.so
new file mode 100644
index 00000000..a54abbcd
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libggml.so differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libggml.so.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libggml.so.meta
new file mode 100644
index 00000000..a9b5e6d8
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libggml.so.meta
@@ -0,0 +1,69 @@
+fileFormatVersion: 2
+guid: 75ec28c73b2448f9afa757f930683404
+labels:
+- NuGetForUnity
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 3
+ iconMap: {}
+ executionOrder: {}
+ defineConstraints: []
+ isPreloaded: 0
+ isOverridable: 0
+ isExplicitlyReferenced: 0
+ validateReferences: 1
+ platformData:
+ Any:
+ enabled: 0
+ settings:
+ 'Exclude ': 1
+ Exclude Android: 1
+ Exclude Bratwurst: 1
+ Exclude CloudRendering: 1
+ Exclude Editor: 1
+ Exclude EmbeddedLinux: 1
+ Exclude GameCoreScarlett: 1
+ Exclude GameCoreXboxOne: 1
+ Exclude Linux64: 1
+ Exclude OSXUniversal: 1
+ Exclude PS4: 1
+ Exclude PS5: 1
+ Exclude QNX: 1
+ Exclude ReservedCFE: 1
+ Exclude Switch: 1
+ Exclude VisionOS: 1
+ Exclude WebGL: 1
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude WindowsStoreApps: 1
+ Exclude XboxOne: 1
+ Exclude iOS: 1
+ Exclude tvOS: 1
+ Editor:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Linux
+ Linux64:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ OSXUniversal:
+ enabled: 0
+ settings: {}
+ Win:
+ enabled: 0
+ settings: {}
+ Win64:
+ enabled: 0
+ settings: {}
+ WindowsStoreApps:
+ enabled: 0
+ settings: {}
+ iOS:
+ enabled: 0
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libllama.so b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libllama.so
new file mode 100644
index 00000000..00296671
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libllama.so differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libllama.so.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libllama.so.meta
new file mode 100644
index 00000000..48530e5f
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libllama.so.meta
@@ -0,0 +1,69 @@
+fileFormatVersion: 2
+guid: 61ef9eec92ea4e698bc81de0fa3cb18e
+labels:
+- NuGetForUnity
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 3
+ iconMap: {}
+ executionOrder: {}
+ defineConstraints: []
+ isPreloaded: 0
+ isOverridable: 0
+ isExplicitlyReferenced: 0
+ validateReferences: 1
+ platformData:
+ Any:
+ enabled: 0
+ settings:
+ 'Exclude ': 1
+ Exclude Android: 1
+ Exclude Bratwurst: 1
+ Exclude CloudRendering: 1
+ Exclude Editor: 1
+ Exclude EmbeddedLinux: 1
+ Exclude GameCoreScarlett: 1
+ Exclude GameCoreXboxOne: 1
+ Exclude Linux64: 1
+ Exclude OSXUniversal: 1
+ Exclude PS4: 1
+ Exclude PS5: 1
+ Exclude QNX: 1
+ Exclude ReservedCFE: 1
+ Exclude Switch: 1
+ Exclude VisionOS: 1
+ Exclude WebGL: 1
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude WindowsStoreApps: 1
+ Exclude XboxOne: 1
+ Exclude iOS: 1
+ Exclude tvOS: 1
+ Editor:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Linux
+ Linux64:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ OSXUniversal:
+ enabled: 0
+ settings: {}
+ Win:
+ enabled: 0
+ settings: {}
+ Win64:
+ enabled: 0
+ settings: {}
+ WindowsStoreApps:
+ enabled: 0
+ settings: {}
+ iOS:
+ enabled: 0
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libllava_shared.so b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libllava_shared.so
new file mode 100644
index 00000000..a4e6a752
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libllava_shared.so differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libllava_shared.so.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libllava_shared.so.meta
new file mode 100644
index 00000000..29dbf961
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/avx512/libllava_shared.so.meta
@@ -0,0 +1,69 @@
+fileFormatVersion: 2
+guid: 368e27e8a1a3460fb5c7bfdfaf77b16e
+labels:
+- NuGetForUnity
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 3
+ iconMap: {}
+ executionOrder: {}
+ defineConstraints: []
+ isPreloaded: 0
+ isOverridable: 0
+ isExplicitlyReferenced: 0
+ validateReferences: 1
+ platformData:
+ Any:
+ enabled: 0
+ settings:
+ 'Exclude ': 1
+ Exclude Android: 1
+ Exclude Bratwurst: 1
+ Exclude CloudRendering: 1
+ Exclude Editor: 1
+ Exclude EmbeddedLinux: 1
+ Exclude GameCoreScarlett: 1
+ Exclude GameCoreXboxOne: 1
+ Exclude Linux64: 1
+ Exclude OSXUniversal: 1
+ Exclude PS4: 1
+ Exclude PS5: 1
+ Exclude QNX: 1
+ Exclude ReservedCFE: 1
+ Exclude Switch: 1
+ Exclude VisionOS: 1
+ Exclude WebGL: 1
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude WindowsStoreApps: 1
+ Exclude XboxOne: 1
+ Exclude iOS: 1
+ Exclude tvOS: 1
+ Editor:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Linux
+ Linux64:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ OSXUniversal:
+ enabled: 0
+ settings: {}
+ Win:
+ enabled: 0
+ settings: {}
+ Win64:
+ enabled: 0
+ settings: {}
+ WindowsStoreApps:
+ enabled: 0
+ settings: {}
+ iOS:
+ enabled: 0
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx.meta
new file mode 100644
index 00000000..a75f5d6b
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 3ab8e2b24642e875eb58f013926f8091
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libggml-base.so b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libggml-base.so
new file mode 100644
index 00000000..d7649341
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libggml-base.so differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libggml-base.so.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libggml-base.so.meta
new file mode 100644
index 00000000..90e2a21b
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libggml-base.so.meta
@@ -0,0 +1,69 @@
+fileFormatVersion: 2
+guid: 48d31ae2093548bdaac7698f604aee83
+labels:
+- NuGetForUnity
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 3
+ iconMap: {}
+ executionOrder: {}
+ defineConstraints: []
+ isPreloaded: 0
+ isOverridable: 0
+ isExplicitlyReferenced: 0
+ validateReferences: 1
+ platformData:
+ Any:
+ enabled: 0
+ settings:
+ 'Exclude ': 1
+ Exclude Android: 1
+ Exclude Bratwurst: 1
+ Exclude CloudRendering: 1
+ Exclude Editor: 1
+ Exclude EmbeddedLinux: 1
+ Exclude GameCoreScarlett: 1
+ Exclude GameCoreXboxOne: 1
+ Exclude Linux64: 1
+ Exclude OSXUniversal: 1
+ Exclude PS4: 1
+ Exclude PS5: 1
+ Exclude QNX: 1
+ Exclude ReservedCFE: 1
+ Exclude Switch: 1
+ Exclude VisionOS: 1
+ Exclude WebGL: 1
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude WindowsStoreApps: 1
+ Exclude XboxOne: 1
+ Exclude iOS: 1
+ Exclude tvOS: 1
+ Editor:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Linux
+ Linux64:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ OSXUniversal:
+ enabled: 0
+ settings: {}
+ Win:
+ enabled: 0
+ settings: {}
+ Win64:
+ enabled: 0
+ settings: {}
+ WindowsStoreApps:
+ enabled: 0
+ settings: {}
+ iOS:
+ enabled: 0
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libggml-cpu.so b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libggml-cpu.so
new file mode 100644
index 00000000..41aacd01
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libggml-cpu.so differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libggml-cpu.so.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libggml-cpu.so.meta
new file mode 100644
index 00000000..3e8fa733
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libggml-cpu.so.meta
@@ -0,0 +1,69 @@
+fileFormatVersion: 2
+guid: 09a973d2b7f3445bb984be8664f3ff23
+labels:
+- NuGetForUnity
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 3
+ iconMap: {}
+ executionOrder: {}
+ defineConstraints: []
+ isPreloaded: 0
+ isOverridable: 0
+ isExplicitlyReferenced: 0
+ validateReferences: 1
+ platformData:
+ Any:
+ enabled: 0
+ settings:
+ 'Exclude ': 1
+ Exclude Android: 1
+ Exclude Bratwurst: 1
+ Exclude CloudRendering: 1
+ Exclude Editor: 1
+ Exclude EmbeddedLinux: 1
+ Exclude GameCoreScarlett: 1
+ Exclude GameCoreXboxOne: 1
+ Exclude Linux64: 1
+ Exclude OSXUniversal: 1
+ Exclude PS4: 1
+ Exclude PS5: 1
+ Exclude QNX: 1
+ Exclude ReservedCFE: 1
+ Exclude Switch: 1
+ Exclude VisionOS: 1
+ Exclude WebGL: 1
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude WindowsStoreApps: 1
+ Exclude XboxOne: 1
+ Exclude iOS: 1
+ Exclude tvOS: 1
+ Editor:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Linux
+ Linux64:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ OSXUniversal:
+ enabled: 0
+ settings: {}
+ Win:
+ enabled: 0
+ settings: {}
+ Win64:
+ enabled: 0
+ settings: {}
+ WindowsStoreApps:
+ enabled: 0
+ settings: {}
+ iOS:
+ enabled: 0
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libggml.so b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libggml.so
new file mode 100644
index 00000000..a54abbcd
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libggml.so differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libggml.so.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libggml.so.meta
new file mode 100644
index 00000000..553b2481
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libggml.so.meta
@@ -0,0 +1,69 @@
+fileFormatVersion: 2
+guid: 680f408be02e47e5bcfaabf0c24bbce0
+labels:
+- NuGetForUnity
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 3
+ iconMap: {}
+ executionOrder: {}
+ defineConstraints: []
+ isPreloaded: 0
+ isOverridable: 0
+ isExplicitlyReferenced: 0
+ validateReferences: 1
+ platformData:
+ Any:
+ enabled: 0
+ settings:
+ 'Exclude ': 1
+ Exclude Android: 1
+ Exclude Bratwurst: 1
+ Exclude CloudRendering: 1
+ Exclude Editor: 1
+ Exclude EmbeddedLinux: 1
+ Exclude GameCoreScarlett: 1
+ Exclude GameCoreXboxOne: 1
+ Exclude Linux64: 1
+ Exclude OSXUniversal: 1
+ Exclude PS4: 1
+ Exclude PS5: 1
+ Exclude QNX: 1
+ Exclude ReservedCFE: 1
+ Exclude Switch: 1
+ Exclude VisionOS: 1
+ Exclude WebGL: 1
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude WindowsStoreApps: 1
+ Exclude XboxOne: 1
+ Exclude iOS: 1
+ Exclude tvOS: 1
+ Editor:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Linux
+ Linux64:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ OSXUniversal:
+ enabled: 0
+ settings: {}
+ Win:
+ enabled: 0
+ settings: {}
+ Win64:
+ enabled: 0
+ settings: {}
+ WindowsStoreApps:
+ enabled: 0
+ settings: {}
+ iOS:
+ enabled: 0
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libllama.so b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libllama.so
new file mode 100644
index 00000000..00296671
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libllama.so differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libllama.so.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libllama.so.meta
new file mode 100644
index 00000000..e58fd4ef
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libllama.so.meta
@@ -0,0 +1,69 @@
+fileFormatVersion: 2
+guid: c5485f03b76349e5a9754eab4aabe868
+labels:
+- NuGetForUnity
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 3
+ iconMap: {}
+ executionOrder: {}
+ defineConstraints: []
+ isPreloaded: 0
+ isOverridable: 0
+ isExplicitlyReferenced: 0
+ validateReferences: 1
+ platformData:
+ Any:
+ enabled: 0
+ settings:
+ 'Exclude ': 1
+ Exclude Android: 1
+ Exclude Bratwurst: 1
+ Exclude CloudRendering: 1
+ Exclude Editor: 1
+ Exclude EmbeddedLinux: 1
+ Exclude GameCoreScarlett: 1
+ Exclude GameCoreXboxOne: 1
+ Exclude Linux64: 1
+ Exclude OSXUniversal: 1
+ Exclude PS4: 1
+ Exclude PS5: 1
+ Exclude QNX: 1
+ Exclude ReservedCFE: 1
+ Exclude Switch: 1
+ Exclude VisionOS: 1
+ Exclude WebGL: 1
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude WindowsStoreApps: 1
+ Exclude XboxOne: 1
+ Exclude iOS: 1
+ Exclude tvOS: 1
+ Editor:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Linux
+ Linux64:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ OSXUniversal:
+ enabled: 0
+ settings: {}
+ Win:
+ enabled: 0
+ settings: {}
+ Win64:
+ enabled: 0
+ settings: {}
+ WindowsStoreApps:
+ enabled: 0
+ settings: {}
+ iOS:
+ enabled: 0
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libllava_shared.so b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libllava_shared.so
new file mode 100644
index 00000000..a4e6a752
Binary files /dev/null and b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libllava_shared.so differ
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libllava_shared.so.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libllava_shared.so.meta
new file mode 100644
index 00000000..14b00255
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/linux-x64/native/noavx/libllava_shared.so.meta
@@ -0,0 +1,69 @@
+fileFormatVersion: 2
+guid: 58acff590380401eabaf97813ca0014e
+labels:
+- NuGetForUnity
+PluginImporter:
+ externalObjects: {}
+ serializedVersion: 3
+ iconMap: {}
+ executionOrder: {}
+ defineConstraints: []
+ isPreloaded: 0
+ isOverridable: 0
+ isExplicitlyReferenced: 0
+ validateReferences: 1
+ platformData:
+ Any:
+ enabled: 0
+ settings:
+ 'Exclude ': 1
+ Exclude Android: 1
+ Exclude Bratwurst: 1
+ Exclude CloudRendering: 1
+ Exclude Editor: 1
+ Exclude EmbeddedLinux: 1
+ Exclude GameCoreScarlett: 1
+ Exclude GameCoreXboxOne: 1
+ Exclude Linux64: 1
+ Exclude OSXUniversal: 1
+ Exclude PS4: 1
+ Exclude PS5: 1
+ Exclude QNX: 1
+ Exclude ReservedCFE: 1
+ Exclude Switch: 1
+ Exclude VisionOS: 1
+ Exclude WebGL: 1
+ Exclude Win: 1
+ Exclude Win64: 1
+ Exclude WindowsStoreApps: 1
+ Exclude XboxOne: 1
+ Exclude iOS: 1
+ Exclude tvOS: 1
+ Editor:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ DefaultValueInitialized: true
+ OS: Linux
+ Linux64:
+ enabled: 1
+ settings:
+ CPU: x86_64
+ OSXUniversal:
+ enabled: 0
+ settings: {}
+ Win:
+ enabled: 0
+ settings: {}
+ Win64:
+ enabled: 0
+ settings: {}
+ WindowsStoreApps:
+ enabled: 0
+ settings: {}
+ iOS:
+ enabled: 0
+ settings: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/osx-arm64.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/osx-arm64.meta
new file mode 100644
index 00000000..5ddfaf29
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/osx-arm64.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 83514264c94bfa0628fe0cd2d6de7237
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/osx-arm64/native.meta b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/osx-arm64/native.meta
new file mode 100644
index 00000000..1f46cc37
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/osx-arm64/native.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 4ebbcc51596ead803b2aa42229f6a4e6
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/osx-arm64/native/ggml-metal.metal b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/osx-arm64/native/ggml-metal.metal
new file mode 100644
index 00000000..44f04c90
--- /dev/null
+++ b/Assets/Packages/LLamaSharp.Backend.Cpu.0.21.0/runtimes/osx-arm64/native/ggml-metal.metal
@@ -0,0 +1,6735 @@
+#define GGML_COMMON_DECL_METAL
+#define GGML_COMMON_IMPL_METAL
+#if defined(GGML_METAL_EMBED_LIBRARY)
+__embed_ggml-common.h__
+#else
+// TODO: this should not be a relative path, but can't figure out how to set Metal include paths in Package.swift
+#include "../ggml-common.h"
+#endif
+#include "ggml-metal-impl.h"
+
+#include <metal_stdlib>
+
+using namespace metal;
+
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
+#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
+
+#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+
+// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+//
+// cmd:
+// .../usr/bin/metal -dM -E -c ggml/src/ggml-metal/ggml-metal.metal
+// .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/ggml-metal.metal
+//
+#if __METAL_VERSION__ < 310 && defined(GGML_METAL_USE_BF16)
+#undef GGML_METAL_USE_BF16
+#endif
+
+#if defined(GGML_METAL_USE_BF16)
+typedef matrix<bfloat, 4, 4> bfloat4x4;
+#endif
+
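+// lookup table for the IQ4_NL / IQ4_XS non-linear codebook: each 4-bit quant in the iq4
+// dequantizers below indexes one of these 16 reconstruction values, scaled by the block scale d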
+constexpr constant static float kvalues_iq4nl_f[16] = {
+ -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
+};
+
+// NOTE: this is not dequantizing - we are simply fitting the template
+template <typename type4x4>
+void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
+ reg = (type4x4)(*src);
+}
+
+template <typename type4x4>
+void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
+ reg = (type4x4)(*src);
+}
+
+template <typename type4>
+void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
+ reg = (type4)(*(src + il));
+}
+
+#if defined(GGML_METAL_USE_BF16)
+template <typename type4x4>
+void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
+ reg = (type4x4)(*src);
+}
+#endif
+
+template <typename type4x4>
+void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
+ const float d2 = d1 / 256.f;
+ const float md = -8.h * xb->d;
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
+ const ushort mask1 = mask0 << 8;
+
+ float4x4 reg_f;
+
+ for (int i = 0; i < 8; i++) {
+ reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
+ reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
+ }
+
+ reg = (type4x4) reg_f;
+}
+
+template <typename type4>
+void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+ const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
+ const float d2 = d1 / 256.f;
+ const float md = -8.h * xb->d;
+ const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
+ const ushort mask1 = mask0 << 8;
+
+ for (int i = 0; i < 2; i++) {
+ reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md;
+ reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 2);
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
+ const float d2 = d1 / 256.f;
+ const float m = xb->m;
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
+ const ushort mask1 = mask0 << 8;
+
+ float4x4 reg_f;
+
+ for (int i = 0; i < 8; i++) {
+ reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
+ reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
+ }
+
+ reg = (type4x4) reg_f;
+}
+
+template <typename type4>
+void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 2);
+ const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
+ const float d2 = d1 / 256.f;
+ const float m = xb->m;
+ const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
+ const ushort mask1 = mask0 << 8;
+
+ for (int i = 0; i < 2; i++) {
+ reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m;
+ reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+ const float d = xb->d;
+ const float md = -16.h * xb->d;
+ const ushort mask = il ? 0x00F0 : 0x000F;
+
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+ const int x_mv = il ? 4 : 0;
+
+ const int gh_mv = il ? 12 : 0;
+ const int gh_bk = il ? 0 : 4;
+
+ float4x4 reg_f;
+
+ for (int i = 0; i < 8; i++) {
+ // extract the 5-th bits for x0 and x1
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+ // combine the 4-bits from qs with the 5th bit
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+ reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
+ reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
+ }
+
+ reg = (type4x4) reg_f;
+}
+
+template <typename type4>
+void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+ const float d = xb->d;
+ const float md = -16.h * xb->d;
+ const ushort mask = (il/4) ? 0x00F0 : 0x000F;
+
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+ const int x_mv = (il/4) ? 4 : 0;
+
+ const int gh_mv = (il/4) ? 12 : 0;
+ const int gh_bk = (il/4) ? 0 : 4;
+
+ for (int ii = 0; ii < 2; ii++) {
+ int i = 2*(il%4) + ii;
+
+ // extract the 5-th bits for x0 and x1
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+ // combine the 4-bits from qs with the 5th bit
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+ reg[2*ii + 0] = d * x0 + md;
+ reg[2*ii + 1] = d * x1 + md;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+ const float d = xb->d;
+ const float m = xb->m;
+ const ushort mask = il ? 0x00F0 : 0x000F;
+
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+ const int x_mv = il ? 4 : 0;
+
+ const int gh_mv = il ? 12 : 0;
+ const int gh_bk = il ? 0 : 4;
+
+ float4x4 reg_f;
+
+ for (int i = 0; i < 8; i++) {
+ // extract the 5-th bits for x0 and x1
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+ // combine the 4-bits from qs with the 5th bit
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+ reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
+ reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
+ }
+
+ reg = (type4x4) reg_f;
+}
+
+template <typename type4>
+void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+ const float d = xb->d;
+ const float m = xb->m;
+ const ushort mask = (il/4) ? 0x00F0 : 0x000F;
+
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+ const int x_mv = (il/4) ? 4 : 0;
+
+ const int gh_mv = (il/4) ? 12 : 0;
+ const int gh_bk = (il/4) ? 0 : 4;
+
+ for (int ii = 0; ii < 2; ii++) {
+ int i = 2*(il%4) + ii;
+
+ // extract the 5-th bits for x0 and x1
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+ // combine the 4-bits from qs with the 5th bit
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+ reg[2*ii + 0] = d * x0 + m;
+ reg[2*ii + 1] = d * x1 + m;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
+ device const int8_t * qs = ((device const int8_t *)xb->qs);
+ const float d = xb->d;
+
+ float4x4 reg_f;
+
+ for (int i = 0; i < 16; i++) {
+ reg_f[i/4][i%4] = (qs[i + 16*il] * d);
+ }
+
+ reg = (type4x4) reg_f;
+}
+
+template <typename type4>
+void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
+ device const int8_t * qs = ((device const int8_t *)xb->qs);
+ const float d = xb->d;
+
+ for (int i = 0; i < 4; i++) {
+ reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
+ }
+}
+
+template <typename type4x4>
+void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
+ const float d = xb->d;
+ const float min = xb->dmin;
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
+ float dl, ml;
+ uint8_t sc = xb->scales[il];
+
+ q = q + 32*(il/8) + 16*(il&1);
+ il = (il/2)%4;
+
+ half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+ uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
+ const half d_all = xb->d;
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
+ device const uint8_t * h = (device const uint8_t *)xb->hmask;
+ device const int8_t * scales = (device const int8_t *)xb->scales;
+
+ q = q + 32 * (il/8) + 16 * (il&1);
+ h = h + 16 * (il&1);
+ uint8_t m = 1 << (il/2);
+ uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
+ ((il/4)>0 ? 12 : 3);
+ uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
+ uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
+ int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
+ : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+ float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+ const float ml = 4.f * dl;
+
+ il = (il/2) & 3;
+ const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+ const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ dl *= coef;
+
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
+ }
+}
+
+static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+ return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+ : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
+}
+
+template <typename type4x4>
+void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) {
+ device const uchar * q = xb->qs;
+
+ short is = (il/4) * 2;
+ q = q + (il/4) * 32 + 16 * (il&1);
+ il = il & 3;
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
+ const float min = xb->dmin;
+ const float dl = d * sc[0];
+ const float ml = min * sc[1];
+
+ const ushort mask = il < 2 ? 0x0F : 0xF0;
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
+ device const uint8_t * q = xb->qs;
+ device const uint8_t * qh = xb->qh;
+
+ short is = (il/4) * 2;
+ q = q + 32 * (il/4) + 16 * (il&1);
+ qh = qh + 16 * (il&1);
+ uint8_t ul = 1 << (il/2);
+ il = il & 3;
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+ const float d = il < 2 ? xb->d : xb->d / 16.f;
+ const float min = xb->dmin;
+ const float dl = d * sc[0];
+ const float ml = min * sc[1];
+
+ const ushort mask = il<2 ? 0x0F : 0xF0;
+ const float qh_val = il<2 ? 16.f : 256.f;
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
+ const half d_all = xb->d;
+ device const uint8_t * ql = (device const uint8_t *)xb->ql;
+ device const uint8_t * qh = (device const uint8_t *)xb->qh;
+ device const int8_t * scales = (device const int8_t *)xb->scales;
+
+ ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
+ qh = qh + 32*(il/8) + 16*(il&1);
+ float sc = scales[(il%2) + 2 * ((il/2))];
+ il = (il/2) & 3;
+
+ const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
+ const float coef = il>1 ? 1.f/16.f : 1.f;
+ const float ml = d_all * sc * 32.f;
+ const float dl = d_all * sc * coef;
+ for (int i = 0; i < 16; ++i) {
+ const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
+ : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
+ reg[i/4][i%4] = dl * q - ml;
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const float d = xb->d;
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
+ device const uint16_t * q2 = xb->qs + 4*ib32;
+ const uint32_t aux32_g = q2[0] | (q2[1] << 16);
+ const uint32_t aux32_s = q2[2] | (q2[3] << 16);
+ thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
+ const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
+ constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
+ uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
+ for (int i = 0; i < 8; ++i) {
+ reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+ }
+ grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
+ signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
+ for (int i = 0; i < 8; ++i) {
+ reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const float d = xb->d;
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ device const uint16_t * q2 = xb->qs + 4*ib32;
+ const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
+ constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
+ uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
+ for (int i = 0; i < 8; ++i) {
+ reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+ }
+ grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
+ signs = ksigns_iq2xs[q2[2*il+1] >> 9];
+ for (int i = 0; i < 8; ++i) {
+ reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const float d = xb->d;
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ device const uint8_t * q3 = xb->qs + 8*ib32;
+ device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
+ const uint32_t aux32 = gas[0] | (gas[1] << 16);
+ const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
+ constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
+ constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
+ uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
+ for (int i = 0; i < 4; ++i) {
+ reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
+ reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
+ }
+ grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
+ grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
+ signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
+ for (int i = 0; i < 4; ++i) {
+ reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
+ reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const float d = xb->d;
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ device const uint8_t * qs = xb->qs + 8*ib32;
+ device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
+ const uint8_t qh = xb->qh[ib32] >> 4*il;
+ const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
+ constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
+ for (int i = 0; i < 4; ++i) {
+ reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
+ reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
+ }
+ grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
+ grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
+ for (int i = 0; i < 4; ++i) {
+ reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
+ reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const float d = xb->d;
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
+ device const uint8_t * signs = qs + QK_K/8;
+ const uint8_t qh = xb->qh[ib32] >> 4*il;
+ const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
+ constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
+ for (int i = 0; i < 8; ++i) {
+ reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
+ reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const int ib32 = il/2;
+ il = il%2;
+ const float d = xb->d;
+ device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
+ device const uint16_t * qh = xb->qh;
+ const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
+ const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
+ const uint16_t h = qh[ib32] >> 6*il;
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
+ for (int i = 0; i < 4; ++i) {
+ reg[0][i] = dl * (grid1[i] & 0xf) + ml;
+ reg[1][i] = dl * (grid1[i] >> 4) + ml;
+ reg[2][i] = dl * (grid2[i] & 0xf) + ml;
+ reg[3][i] = dl * (grid2[i] >> 4) + ml;
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const int ib32 = il/2;
+ il = il%2;
+ device const uint16_t * sc = (device const uint16_t *)xb->scales;
+
+ iq1m_scale_t scale;
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+ const float d = scale.f16;
+
+ device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
+ device const uint8_t * qh = xb->qh + 2*ib32 + il;
+
+ const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
+ const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+ const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
+ for (int i = 0; i < 4; ++i) {
+ reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
+ reg[1][i] = dl * (grid1[i] >> 4) + ml1;
+ reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
+ reg[3][i] = dl * (grid2[i] >> 4) + ml2;
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
+ device const uint16_t * q4 = (device const uint16_t *)xb->qs;
+ const float d = xb->d;
+ uint32_t aux32;
+ thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+ for (int i = 0; i < 4; ++i) {
+ aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
+ reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
+ reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
+ reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
+ reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
+ }
+}
+
+template <typename type4>
+void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) {
+ device const uint16_t * q4 = (device const uint16_t *)xb->qs;
+ const float d = xb->d;
+ uint32_t aux32;
+ thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+ aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f;
+ reg[0] = d * kvalues_iq4nl_f[q8[0]];
+ reg[1] = d * kvalues_iq4nl_f[q8[1]];
+ reg[2] = d * kvalues_iq4nl_f[q8[2]];
+ reg[3] = d * kvalues_iq4nl_f[q8[3]];
+}
+
+template <typename type4x4>
+void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
+ const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
+ const float d = (float)xb->d * (ls - 32);
+ uint32_t aux32;
+ thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+ for (int i = 0; i < 4; ++i) {
+ aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
+ reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
+ reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
+ reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
+ reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
+ }
+}
+
+enum ggml_sort_order {
+ GGML_SORT_ORDER_ASC,
+ GGML_SORT_ORDER_DESC,
+};
+
+// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across all dims
+// cons: not very efficient
+kernel void kernel_add(
+ constant ggml_metal_kargs_bin & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig.z;
+ const int i02 = tgpig.y;
+ const int i01 = tgpig.x;
+
+ const int i13 = i03%args.ne13;
+ const int i12 = i02%args.ne12;
+ const int i11 = i01%args.ne11;
+
+ device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+ device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+ device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
+
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ const int i10 = i0%args.ne10;
+ *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10));
+ }
+}
+
+kernel void kernel_sub(
+ constant ggml_metal_kargs_bin & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig.z;
+ const int i02 = tgpig.y;
+ const int i01 = tgpig.x;
+
+ const int i13 = i03%args.ne13;
+ const int i12 = i02%args.ne12;
+ const int i11 = i01%args.ne11;
+
+ device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+ device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+ device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
+
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ const int i10 = i0%args.ne10;
+ *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10));
+ }
+}
+
+kernel void kernel_mul(
+ constant ggml_metal_kargs_bin & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig.z;
+ const int i02 = tgpig.y;
+ const int i01 = tgpig.x;
+
+ const int i13 = i03%args.ne13;
+ const int i12 = i02%args.ne12;
+ const int i11 = i01%args.ne11;
+
+ device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
+ device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+ device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1;
+
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ const int i10 = i0%args.ne10;
+ *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
+ }
+}
+
+kernel void kernel_div(
+ constant ggml_metal_kargs_bin & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig.z;
+ const int i02 = tgpig.y;
+ const int i01 = tgpig.x;
+
+ const int i13 = i03%args.ne13;
+ const int i12 = i02%args.ne12;
+ const int i11 = i01%args.ne11;
+
+ device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
+ device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+ device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1;
+
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ const int i10 = i0%args.ne10;
+ *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
+ }
+}
+
+template<typename T>
+kernel void kernel_repeat(
+ constant ggml_metal_kargs_repeat & args,
+ device const char * src0,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i3 = tgpig.z;
+ const int i2 = tgpig.y;
+ const int i1 = tgpig.x;
+
+ const int i03 = i3%args.ne03;
+ const int i02 = i2%args.ne02;
+ const int i01 = i1%args.ne01;
+
+ device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
+ device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1;
+
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ const int i00 = i0%args.ne00;
+ *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
+ }
+}
+
+typedef decltype(kernel_repeat<float>) kernel_repeat_t;
+
+template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
+template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
+template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
+template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
+
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_add_row(
+ constant ggml_metal_kargs_bin & args,
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ const uint nb = args.ne00/4;
+ dst[tpig] = src0[tpig] + src1[tpig % nb];
+}
+
+kernel void kernel_sub_row(
+ constant ggml_metal_kargs_bin & args,
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ const uint nb = args.ne00/4;
+ dst[tpig] = src0[tpig] - src1[tpig % nb];
+}
+
+kernel void kernel_mul_row(
+ constant ggml_metal_kargs_bin & args,
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ const uint nb = args.ne00/4;
+ dst[tpig] = src0[tpig] * src1[tpig % nb];
+}
+
+kernel void kernel_div_row(
+ constant ggml_metal_kargs_bin & args,
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ const uint nb = args.ne00/4;
+ dst[tpig] = src0[tpig] / src1[tpig % nb];
+}
+
+kernel void kernel_scale(
+ device const float * src0,
+ device float * dst,
+ constant float & scale,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_scale_4(
+ device const float4 * src0,
+ device float4 * dst,
+ constant float & scale,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_clamp(
+ device const float * src0,
+ device float * dst,
+ constant float & min,
+ constant float & max,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]);
+}
+
+kernel void kernel_relu(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = max(0.0f, src0[tpig]);
+}
+
+kernel void kernel_sigmoid(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
+}
+
+kernel void kernel_tanh(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float & x = src0[tpig];
+ dst[tpig] = precise::tanh(x);
+}
+
+constant float GELU_COEF_A = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+kernel void kernel_gelu(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float & x = src0[tpig];
+
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_4(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+
+ // BEWARE !!!
+    // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
+ // This was observed with Falcon 7B and 40B models
+ //
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_quick(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float & x = src0[tpig];
+
+ dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_gelu_quick_4(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+
+ dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_silu(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float & x = src0[tpig];
+ dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_silu_4(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+ dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_elu(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float & x = src0[tpig];
+ dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
+}
+
+kernel void kernel_sqr(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * src0[tpig];
+}
+
+kernel void kernel_sqrt(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = sqrt(src0[tpig]);
+}
+
+kernel void kernel_sin(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = sin(src0[tpig]);
+}
+
+kernel void kernel_cos(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = cos(src0[tpig]);
+}
+
+kernel void kernel_sum_rows(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tpig[[thread_position_in_grid]]) {
+ int64_t i3 = tpig.z;
+ int64_t i2 = tpig.y;
+ int64_t i1 = tpig.x;
+
+ if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+ return;
+ }
+
+ device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
+ device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
+
+ float row_sum = 0;
+
+ for (int64_t i0 = 0; i0 < ne00; i0++) {
+ row_sum += src_row[i0];
+ }
+
+ dst_row[0] = row_sum;
+}
+
+template<typename T>
+kernel void kernel_soft_max(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant float & scale,
+ constant float & max_bias,
+ constant float & m0,
+ constant float & m1,
+ constant uint32_t & n_head_log2,
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = (tgpig) / (ne02*ne01);
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+
+ device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
+ device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+
+ float slope = 1.0f;
+
+ // ALiBi
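+    // per-head slope: m0^(h + 1) for the first n_head_log2 heads, m1^(2*(h - n_head_log2) + 1) for the rest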
+ if (max_bias > 0.0f) {
+ const int64_t h = i02;
+
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ slope = pow(base, exp);
+ }
+
+ // parallel max
+ float lmax = -INFINITY;
+
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
+ }
+
+ // find the max value in the block
+ float max_val = simd_max(lmax);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = -INFINITY;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = max_val;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ max_val = buf[tiisg];
+ max_val = simd_max(max_val);
+ }
+
+ // parallel sum
+ float lsum = 0.0f;
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
+ lsum += exp_psrc0;
+ pdst[i00] = exp_psrc0;
+ }
+
+ // This barrier fixes a failing test
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+ threadgroup_barrier(mem_flags::mem_none);
+
+ float sum = simd_sum(lsum);
+
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[tiisg];
+ sum = simd_sum(sum);
+ }
+
+ const float inv_sum = 1.0f/sum;
+
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ pdst[i00] *= inv_sum;
+ }
+}
+
+template<typename T>
+kernel void kernel_soft_max_4(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant float & scale,
+ constant float & max_bias,
+ constant float & m0,
+ constant float & m1,
+ constant uint32_t & n_head_log2,
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = (tgpig) / (ne02*ne01);
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+
+ device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
+ device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+
+ float slope = 1.0f;
+
+ if (max_bias > 0.0f) {
+ const int64_t h = i02;
+
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ slope = pow(base, exp);
+ }
+
+ // parallel max
+ float4 lmax4 = -INFINITY;
+
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
+ }
+
+ const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+
+ float max_val = simd_max(lmax);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = -INFINITY;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = max_val;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ max_val = buf[tiisg];
+ max_val = simd_max(max_val);
+ }
+
+ // parallel sum
+ float4 lsum4 = 0.0f;
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
+ lsum4 += exp_psrc4;
+ pdst4[i00] = exp_psrc4;
+ }
+
+ const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+ // This barrier fixes a failing test
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+ threadgroup_barrier(mem_flags::mem_none);
+
+ float sum = simd_sum(lsum);
+
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[tiisg];
+ sum = simd_sum(sum);
+ }
+
+ const float inv_sum = 1.0f/sum;
+
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ pdst4[i00] *= inv_sum;
+ }
+}
+
+typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
+typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
+
+template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max<half>;
+template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>;
+template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
+template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
+
+kernel void kernel_diag_mask_inf(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int & n_past,
+ uint3 tpig[[thread_position_in_grid]]) {
+ const int64_t i02 = tpig[2];
+ const int64_t i01 = tpig[1];
+ const int64_t i00 = tpig[0];
+
+ if (i00 > n_past + i01) {
+ dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
+ } else {
+ dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+ }
+}
+
+kernel void kernel_diag_mask_inf_8(
+ device const float4 * src0,
+ device float4 * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int & n_past,
+ uint3 tpig[[thread_position_in_grid]]) {
+
+ const int64_t i = 2*tpig[0];
+
+ dst[i+0] = src0[i+0];
+ dst[i+1] = src0[i+1];
+ int64_t i4 = 4*i;
+ const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+ const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
+ const int64_t i00 = i4;
+ for (int k = 3; k >= 0; --k) {
+ if (i00 + 4 + k <= n_past + i01) {
+ break;
+ }
+ dst[i+1][k] = -INFINITY;
+ if (i00 + k > n_past + i01) {
+ dst[i][k] = -INFINITY;
+ }
+ }
+}
+
+// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
+// TODO: optimize
+kernel void kernel_ssm_conv_f32(
+ device const void * src0,
+ device const void * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t ir = tgpig.x;
+ const int64_t i2 = tgpig.y;
+ const int64_t i3 = tgpig.z;
+
+ const int64_t nc = ne10;
+ //const int64_t ncs = ne00;
+ //const int64_t nr = ne01;
+ //const int64_t n_t = ne1;
+ //const int64_t n_s = ne2;
+
+ device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
+ device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
+ device float * x = (device float *) ((device char *) dst + ir*nb0 + i2*nb1 + i3*nb2);
+
+ float sumf = 0.0f;
+
+ for (int64_t i0 = 0; i0 < nc; ++i0) {
+ sumf += s[i0] * c[i0];
+ }
+
+ x[0] = sumf;
+}
+
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
+// TODO: optimize
+kernel void kernel_ssm_scan_f32(
+ device const void * src0,
+ device const void * src1,
+ device const void * src2,
+ device const void * src3,
+ device const void * src4,
+ device const void * src5,
+ device float * dst,
+ constant int64_t & d_state,
+ constant int64_t & d_inner,
+ constant int64_t & n_seq_tokens,
+ constant int64_t & n_seqs,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant uint64_t & nb20,
+ constant uint64_t & nb21,
+ constant uint64_t & nb22,
+ constant uint64_t & nb30,
+ constant uint64_t & nb31,
+ constant uint64_t & nb40,
+ constant uint64_t & nb41,
+ constant uint64_t & nb42,
+ constant uint64_t & nb50,
+ constant uint64_t & nb51,
+ constant uint64_t & nb52,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t ir = tgpig.x;
+ const int64_t i3 = tgpig.y;
+
+ const int64_t nc = d_state;
+ //const int64_t nr = d_inner;
+ const int64_t n_t = n_seq_tokens;
+ //const int64_t n_s = n_seqs;
+
+ for (int64_t i2 = 0; i2 < n_t; ++i2) {
+ device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
+ device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
+ device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
+ device const float * A = (device const float *) ((device const char *) src3 + ir*nb31);
+ device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
+ device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
+ device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
+ device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13);
+
+ if (i2 > 0) {
+ s0 = s;
+ }
+
+ // i1 == 0
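+        // softplus(dt): for dt > 20, log(1 + exp(dt)) equals dt to within float precision, so dt is passed through unchanged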
+ float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
+ float x_dt = x[0] * dt_soft_plus;
+ float sumf = 0.0f;
+
+ for (int64_t i0 = 0; i0 < nc; ++i0) {
+ int64_t i = i0;
+ float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+ sumf += state * C[i0];
+ s[i] = state;
+ }
+
+ y[0] = sumf;
+ }
+}
+
+kernel void kernel_argmax(
+ device const void * x,
+ device int32_t * dst,
+ constant int64_t & ncols,
+ constant uint64_t & nb01,
+ threadgroup float * shared_maxval [[threadgroup(0)]],
+ threadgroup int32_t * shared_argmax [[threadgroup(1)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ device const float * x_row = (device const float *) ((device const char *) x + tgpig * nb01);
+
+ float lmax = -INFINITY;
+ int32_t larg = -1;
+
+ for (int i00 = tpitg; i00 < ncols; i00 += ntg) {
+ if (x_row[i00] > lmax) {
+ lmax = x_row[i00];
+ larg = i00;
+ }
+ }
+
+ // find the argmax value in the block
+ float max_val = simd_max(lmax);
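+    // keep each lane's candidate index only where it holds the simdgroup maximum, then take the largest such index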
+ int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));
+
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ shared_maxval[tiisg] = -INFINITY;
+ shared_argmax[tiisg] = -1;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ shared_maxval[sgitg] = max_val;
+ shared_argmax[sgitg] = arg_val;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ max_val = shared_maxval[tiisg];
+ arg_val = shared_argmax[tiisg];
+
+ float max_val_reduced = simd_max(max_val);
+ int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));
+
+ dst[tgpig] = arg_val_reduced;
+
+ return;
+ }
+
+ dst[tgpig] = arg_val;
+}
+
+kernel void kernel_norm(
+ constant ggml_metal_kargs_norm & args,
+ device const char * src0,
+ device char * dst,
+ threadgroup float * shmem_f32 [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ ushort tpitg[[thread_position_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort ntg[[threads_per_threadgroup]]) {
+ if (sgitg == 0) {
+ shmem_f32[tiisg] = 0.0f;
+ }
+
+ device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+
+ float4 sumf4(0.0f);
+
+ float sumf = 0.0f;
+
+ for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+ sumf4 += x[i00];
+ }
+ sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3];
+ sumf = simd_sum(sumf);
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ shmem_f32[sgitg] = sumf;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sumf = shmem_f32[tiisg];
+ sumf = simd_sum(sumf);
+
+ const float mean = sumf/args.ne00;
+
+ device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
+
+ sumf = 0.0f;
+ for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+ y[i00] = x[i00] - mean;
+ sumf += dot(y[i00], y[i00]);
+ }
+ sumf = simd_sum(sumf);
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ shmem_f32[sgitg] = sumf;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sumf = shmem_f32[tiisg];
+ sumf = simd_sum(sumf);
+
+ const float variance = sumf/args.ne00;
+
+ const float scale = 1.0f/sqrt(variance + args.eps);
+ for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+ y[i00] = y[i00] * scale;
+ }
+}
+
+kernel void kernel_rms_norm(
+ constant ggml_metal_kargs_rms_norm & args,
+ device const char * src0,
+ device char * dst,
+ threadgroup float * shmem_f32 [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ ushort tpitg[[thread_position_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort ntg[[threads_per_threadgroup]]) {
+ if (sgitg == 0) {
+ shmem_f32[tiisg] = 0.0f;
+ }
+
+ device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+
+ float sumf = 0.0f;
+
+ // parallel sum
+ for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+ sumf += dot(x[i00], x[i00]);
+ }
+ sumf = simd_sum(sumf);
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ shmem_f32[sgitg] = sumf;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sumf = shmem_f32[tiisg];
+ sumf = simd_sum(sumf);
+
+ const float mean = sumf/args.ne00;
+ const float scale = 1.0f/sqrt(mean + args.eps);
+
+ device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
+ for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+ y[i00] = x[i00] * scale;
+ }
+}
+
+kernel void kernel_group_norm(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int32_t & n_groups,
+ constant float & eps,
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t ne = ne00*ne01*ne02;
+ const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
+
+ int start = tgpig * gs;
+ int end = start + gs;
+
+ start += tpitg;
+
+ if (end >= ne) {
+ end = ne;
+ }
+
+ float tmp = 0.0f; // partial sum for thread in warp
+
+ for (int j = start; j < end; j += ntg) {
+ tmp += src0[j];
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ tmp = simd_sum(tmp);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = tmp;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ tmp = buf[tiisg];
+ tmp = simd_sum(tmp);
+ }
+
+ const float mean = tmp / gs;
+ tmp = 0.0f;
+
+ for (int j = start; j < end; j += ntg) {
+ float xi = src0[j] - mean;
+ dst[j] = xi;
+ tmp += xi * xi;
+ }
+
+ tmp = simd_sum(tmp);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = tmp;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ tmp = buf[tiisg];
+ tmp = simd_sum(tmp);
+ }
+
+ const float variance = tmp / gs;
+ const float scale = 1.0f/sqrt(variance + eps);
+ for (int j = start; j < end; j += ntg) {
+ dst[j] *= scale;
+ }
+}
+
+// function to calculate the inner product between half a q4_0 block and 16 floats (yl); sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
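+// each mask below picks one nibble out of the pair of bytes packed in qs[i/2]:
+// 0x000F/0x00F0 select the low/high nibble of the first byte, 0x0F00/0xF000 of the second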
+inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+
+ float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
+
+ device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2);
+
+ for (int i = 0; i < 8; i += 2) {
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
+ acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
+ acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
+ acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
+ }
+
+ return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);
+}
+
+// function to calculate the inner product between half a q4_1 block and 16 floats (yl); sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+ float m = qb_curr->m;
+
+ float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
+
+ device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2);
+
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
+ acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
+ acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
+ acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
+ }
+
+ return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
+}
+
+// function to calculate the inner product between half a q5_0 block and 16 floats (yl); sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+
+ float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
+ const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010));
+ acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
+ acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
+ acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+ }
+
+ return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]);
+}
+
+// function to calculate the inner product between half a q5_1 block and 16 floats (yl); sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_1/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+ float m = qb_curr->m;
+
+ float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
+ const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010));
+ acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
+ acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
+ acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+ }
+
+ return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
+}
+
+// putting them in the kernel causes a significant performance penalty
+#define N_DST 4 // each SIMD group works on 4 rows
+#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
+//Note: This is a template, but strictly speaking it only applies to
+// quantizations where the block size is 32. It also does not
+// guard against the number of rows not being divisible by
+// N_DST, so this is another explicit assumption of the implementation.
+template<typename block_q_type, int nr, int nsg, int nw, typename args_t>
+void mul_vec_q_n_f32_impl(
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+ const int nb = args.ne00/QK4_0;
+
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * nsg + sgitg) * nr;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ //device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
+
+ // pointers to src0 rows
+ device const block_q_type * ax[nr];
+ for (int row = 0; row < nr; ++row) {
+ const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+ ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
+ }
+
+ float yl[16]; // src1 vector cache
+ float sumf[nr] = {0.f};
+
+ const short ix = (tiisg/2);
+ const short il = (tiisg%2)*8;
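+ // ix selects the block this thread starts from; il (0 or 8) selects which half of the block it handles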
+
+ device const float * yb = y + ix*QK4_0 + il;
+
+ // each thread in a SIMD group deals with half a block.
+ for (int ib = ix; ib < nb; ib += nw/2) {
+ float sumy[2] = { 0.f, 0.f };
+
+#pragma unroll
+ for (int i = 0; i < 8; i += 2) {
+ sumy[0] += yb[i + 0] + yb[i + 1];
+ yl[i + 0] = yb[i + 0];
+ yl[i + 1] = yb[i + 1]/256.f;
+
+ sumy[1] += yb[i + 16] + yb[i + 17];
+ yl[i + 8] = yb[i + 16]/16.f;
+ yl[i + 9] = yb[i + 17]/4096.f;
+ }
+
+#pragma unroll
+ for (int row = 0; row < nr; row++) {
+ sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
+ }
+
+ yb += QK4_0 * 16;
+ }
+
+ device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
+
+ for (int row = 0; row < nr; ++row) {
+ const float tot = simd_sum(sumf[row]);
+
+ if (tiisg == 0 && first_row + row < args.ne01) {
+ dst_f32[first_row + row] = tot;
+ }
+ }
+}
+
+kernel void kernel_mul_mv_q4_0_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+kernel void kernel_mul_mv_q4_1_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+kernel void kernel_mul_mv_q5_0_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+kernel void kernel_mul_mv_q5_1_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+#define NB_Q8_0 8
+
+template<typename args_t>
+void kernel_mul_mv_q8_0_f32_impl(
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+ const int nr = N_DST;
+ const int nsg = N_SIMDGROUP;
+ const int nw = N_SIMDWIDTH;
+
+ const int nb = args.ne00/QK8_0;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0*nsg + sgitg)*nr;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
+
+ // pointers to src0 rows
+ device const block_q8_0 * ax[nr];
+ for (int row = 0; row < nr; ++row) {
+ const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+ ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
+ }
+
+ float yl[NB_Q8_0];
+ float sumf[nr] = { 0.f };
+
+ const short ix = tiisg/4;
+ const short il = tiisg%4;
+
+ device const float * yb = y + ix*QK8_0 + il*NB_Q8_0;
+
+ // each thread in a SIMD group deals with NB_Q8_0 quants at a time
+ for (int ib = ix; ib < nb; ib += nw/4) {
+ for (short i = 0; i < NB_Q8_0; ++i) {
+ yl[i] = yb[i];
+ }
+
+ for (int row = 0; row < nr; row++) {
+ device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
+ float sumq = 0.f;
+ for (short iq = 0; iq < NB_Q8_0; ++iq) {
+ sumq += qs[iq] * yl[iq];
+ }
+ sumf[row] += sumq*ax[row][ib].d;
+ }
+
+ yb += nw*NB_Q8_0;
+ }
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+ for (int row = 0; row < nr; ++row) {
+ const float tot = simd_sum(sumf[row]);
+
+ if (tiisg == 0 && first_row + row < args.ne01) {
+ dst_f32[first_row + row] = tot;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_q8_0_f32")]]
+kernel void kernel_mul_mv_q8_0_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+// mat-vec kernel processing in chunks of float4
+// chpb - chunks per quantization block
+template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &)>
+void kernel_mul_mv_ext_q4_f32_impl(
+ constant ggml_metal_kargs_mul_mv_ext & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ const short chpt = 4; // chunks per thread
+
+ //const short nxpsg = (32);
+ const short nypsg = (32/nxpsg);
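+ // nxpsg threads cooperate along one row; nypsg = 32/nxpsg rows are handled per simdgroup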
+
+ const short tx = tiisg%nxpsg;
+ const short ty = tiisg/nxpsg;
+
+ const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
+ const int i11 = tgpig.y*r1ptg;
+ const int i1m = tgpig.z;
+
+ const int i12 = i1m%args.ne12;
+ const int i13 = i1m/args.ne12;
+
+ const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
+
+ device const float4 * y4[r1ptg];
+
+ for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
+ y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1;
+ }
+
+ float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
+
+ short cch = tx%chpb; // current chunk index
+
+ for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) {
+ float4 lx[chpt];
+
+#pragma unroll(chpt)
+ for (short ch = 0; ch < chpt; ++ch) {
+ deq_t4(xq, cch, lx[ch]);
+
+ cch += nxpsg;
+ if (cch >= chpb) {
+ xq += cch/chpb;
+ cch %= chpb;
+ }
+ }
+
+#pragma unroll(chpt)
+ for (short ch = 0; ch < chpt; ++ch) {
+#pragma unroll(r1ptg)
+ for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+ sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);
+
+ }
+ }
+
+#pragma unroll(r1ptg)
+ for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+ y4[ir1] += chpt*nxpsg;
+ }
+ }
+
+ // reduce only the threads in each row
+ for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+ if (nxpsg >= 32) {
+ sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
+ }
+ if (nxpsg >= 16) {
+ sumf[ir1] += simd_shuffle_down(sumf[ir1], 8);
+ }
+ if (nxpsg >= 8) {
+ sumf[ir1] += simd_shuffle_down(sumf[ir1], 4);
+ }
+ if (nxpsg >= 4) {
+ sumf[ir1] += simd_shuffle_down(sumf[ir1], 2);
+ }
+ if (nxpsg >= 2) {
+ sumf[ir1] += simd_shuffle_down(sumf[ir1], 1);
+ }
+
+ //sumf[ir1] = simd_sum(sumf[ir1]);
+ }
+
+ if (tx == 0) {
+ for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
+ device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
+
+ if (i01 < args.ne01) {
+ dst_f32[i01] = sumf[ir1];
+ }
+ }
+ }
+}
+
+// mat-vec kernel processing in chunks of float4x4
+template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>
+void kernel_mul_mv_ext_q4x4_f32_impl(
+ constant ggml_metal_kargs_mul_mv_ext & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ const short chpt = 1;
+
+ //const short nxpsg = (32);
+ const short nypsg = (32/nxpsg);
+
+ const short tx = tiisg%nxpsg;
+ const short ty = tiisg/nxpsg;
+
+ const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
+ const int i11 = tgpig.y*r1ptg;
+ const int i1m = tgpig.z;
+
+ const int i12 = i1m%args.ne12;
+ const int i13 = i1m/args.ne12;
+
+ const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
+
+ device const float4x4 * y4x4[r1ptg];
+
+ for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
+ y4x4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4x4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4x4 *) src1;
+ }
+
+ float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
+
+ short cch = tx%chpb;
+
+ for (int ich = tx; 16*ich < args.ne00; ich += chpt*nxpsg) {
+ float4x4 lx[chpt];
+
+#pragma unroll(chpt)
+ for (short ch = 0; ch < chpt; ++ch) {
+ deq_t4x4(xq, cch, lx[ch]);
+
+ cch += nxpsg;
+ if (cch >= chpb) {
+ xq += cch/chpb;
+ cch %= chpb;
+ }
+ }
+
+#pragma unroll(chpt)
+ for (short ch = 0; ch < chpt; ++ch) {
+#pragma unroll(r1ptg)
+ for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+ sumf[ir1] +=
+ dot(lx[ch][0], y4x4[ir1][ch*nxpsg][0]) +
+ dot(lx[ch][1], y4x4[ir1][ch*nxpsg][1]) +
+ dot(lx[ch][2], y4x4[ir1][ch*nxpsg][2]) +
+ dot(lx[ch][3], y4x4[ir1][ch*nxpsg][3]);
+
+ }
+ }
+
+#pragma unroll(r1ptg)
+ for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+ y4x4[ir1] += chpt*nxpsg;
+ }
+ }
+
+ for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+ if (nxpsg >= 32) {
+ sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
+ }
+ if (nxpsg >= 16) {
+ sumf[ir1] += simd_shuffle_down(sumf[ir1], 8);
+ }
+ if (nxpsg >= 8) {
+ sumf[ir1] += simd_shuffle_down(sumf[ir1], 4);
+ }
+ if (nxpsg >= 4) {
+ sumf[ir1] += simd_shuffle_down(sumf[ir1], 2);
+ }
+ if (nxpsg >= 2) {
+ sumf[ir1] += simd_shuffle_down(sumf[ir1], 1);
+ }
+
+ //sumf[ir1] = simd_sum(sumf[ir1]);
+ }
+
+ if (tx == 0) {
+ for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
+ device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
+
+ if (i01 < args.ne01) {
+ dst_f32[i01] = sumf[ir1];
+ }
+ }
+ }
+}
+
+// dispatchers needed for compile-time nxpsg
+// epb - elements per quantization block
+template<short r1ptg, typename q_t, short epb, void (*deq_t4)(device const q_t *, short, thread float4 &)>
+kernel void kernel_mul_mv_ext_q4_f32_disp(
+ constant ggml_metal_kargs_mul_mv_ext & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ switch (args.nxpsg) {
+ case 4: kernel_mul_mv_ext_q4_f32_impl<4, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+ case 8: kernel_mul_mv_ext_q4_f32_impl<8, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+ case 16: kernel_mul_mv_ext_q4_f32_impl<16, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+ case 32: kernel_mul_mv_ext_q4_f32_impl<32, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+ }
+}
+
+template<short r1ptg, typename q_t, short epb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>
+kernel void kernel_mul_mv_ext_q4x4_f32_disp(
+ constant ggml_metal_kargs_mul_mv_ext & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ switch (args.nxpsg) {
+ case 4: kernel_mul_mv_ext_q4x4_f32_impl<4, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+ case 8: kernel_mul_mv_ext_q4x4_f32_impl<8, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+ case 16: kernel_mul_mv_ext_q4x4_f32_impl<16, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+ case 32: kernel_mul_mv_ext_q4x4_f32_impl<32, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+ }
+}
+
+typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
+typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>) mul_mv_ext_q4x4_f32_t;
+
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>;
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>;
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_0, 32, dequantize_q4_0_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_1, 32, dequantize_q4_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_1, 32, dequantize_q4_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_1, 32, dequantize_q4_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_1, 32, dequantize_q4_1_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_0, 32, dequantize_q5_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_0, 32, dequantize_q5_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_0, 32, dequantize_q5_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_0, 32, dequantize_q5_0_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_1, 32, dequantize_q5_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_1, 32, dequantize_q5_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_1, 32, dequantize_q5_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_1, 32, dequantize_q5_1_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q8_0, 32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q8_0, 32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0, 32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0, 32, dequantize_q8_0_t4>;
+
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>;
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q4_K, 256, dequantize_q4_K>;
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q4_K, 256, dequantize_q4_K>;
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q4_K, 256, dequantize_q4_K>;
+
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q5_K, 256, dequantize_q5_K>;
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q5_K, 256, dequantize_q5_K>;
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q5_K, 256, dequantize_q5_K>;
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q5_K, 256, dequantize_q5_K>;
+
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_K, 256, dequantize_q6_K>;
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q6_K, 256, dequantize_q6_K>;
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
+
+#define N_MV_T_T 4
+
+template<typename T0, typename T04, typename T1, typename T14, typename args_t>
+void kernel_mul_mv_impl(
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig,
+ ushort tiisg) {
+ const int r0 = tgpig.x;
+ const int rb = tgpig.y*N_MV_T_T;
+ const int im = tgpig.z;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+ device const T0 * x = (device const T0 *) (src0 + offset0);
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
+
+ if (args.ne00 < 128) {
+ for (int row = 0; row < N_MV_T_T; ++row) {
+ int r1 = rb + row;
+ if (r1 >= args.ne11) {
+ break;
+ }
+
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const T1 * y = (device const T1 *) (src1 + offset1);
+
+ float sumf = 0;
+ for (int i = tiisg; i < args.ne00; i += 32) {
+ sumf += (T0) x[i] * (T1) y[i];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+ }
+ }
+ } else {
+ device const T04 * x4 = (device const T04 *) x;
+ for (int row = 0; row < N_MV_T_T; ++row) {
+ int r1 = rb + row;
+ if (r1 >= args.ne11) {
+ break;
+ }
+
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const T1 * y = (device const T1 *) (src1 + offset1);
+ device const T14 * y4 = (device const T14 *) y;
+
+ float sumf = 0;
+ for (int i = tiisg; i < args.ne00/4; i += 32) {
+ sumf += dot((float4) x4[i], (float4) y4[i]);
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
+ dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+ }
+ }
+ }
+}
+
+template<typename T0, typename T04, typename T1, typename T14>
+kernel void kernel_mul_mv(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_impl<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(
+ args,
+ src0,
+ src1,
+ dst,
+ tgpig,
+ tiisg);
+}
+
+typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
+
+template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv;
+template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv;
+template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv;
+template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv;
+#endif
+
+template<typename T, typename T4>
+kernel void kernel_mul_mv_1row(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]]) {
+
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const T * x = (device const T *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+ float sumf = 0;
+ if (args.ne00 < 128) {
+ for (int i = tiisg; i < args.ne00; i += 32) {
+ sumf += (float) x[i] * (float) y[i];
+ }
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst_f32[r0] = all_sum;
+ }
+ } else {
+ device const T4 * x4 = (device const T4 *) x;
+ device const float4 * y4 = (device const float4 *) y;
+
+ for (int i = tiisg; i < args.ne00/4; i += 32) {
+ sumf += dot((float4) x4[i], y4[i]);
+ }
+
+ float all_sum = simd_sum(sumf);
+
+ if (tiisg == 0) {
+ for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
+ dst_f32[r0] = all_sum;
+ }
+ }
+}
+
+typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
+
+template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row;
+#endif
+
+// Assumes row size (ne00) is a multiple of 4
+template<typename T4>
+kernel void kernel_mul_mv_l4(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]]) {
+
+ const int nrows = args.ne11;
+ const int r0 = tgpig.x;
+ const int im = tgpig.z;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+ device const T4 * x4 = (device const T4 *) (src0 + offset0);
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
+
+ for (int r1 = 0; r1 < nrows; ++r1) {
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const float4 * y4 = (device const float4 *) (src1 + offset1);
+
+ float sumf = 0;
+ for (int i = tiisg; i < args.ne00/4; i += 32) {
+ sumf += dot((float4) x4[i], y4[i]);
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+ }
+ }
+}
+
+typedef decltype(kernel_mul_mv_l4<half4>) mul_mv_l4_t;
+
+template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4;
+#endif
+
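+// ramp in [0, 1]: 1 for dims below the low correction bound, 0 above the high bound, linear in between;
+// YaRN uses it to blend interpolated and extrapolated rotation angles per dimension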
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+ const float y = (i0 / 2 - low) / max(0.001f, high - low);
+ return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+ float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale,
+ thread float * cos_theta, thread float * sin_theta) {
+ // Get n-d rotational scaling corrected for extrapolation
+ float theta_interp = freq_scale * theta_extrap;
+ float theta = theta_interp;
+ if (ext_factor != 0.0f) {
+ float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+ // Get n-d magnitude scaling corrected for interpolation
+ mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
+ }
+ *cos_theta = cos(theta) * mscale;
+ *sin_theta = sin(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
+ return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+}
+
+static void rope_yarn_corr_dims(
+ int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+ // start and end correction dims
+ dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
+ dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
+}
+
+template<typename T>
+kernel void kernel_rope_norm(
+ constant ggml_metal_kargs_rope & args,
+ device const char * src0,
+ device const char * src1,
+ device const char * src2,
+ device char * dst,
+ ushort tiitg[[thread_index_in_threadgroup]],
+ ushort3 tptg [[threads_per_threadgroup]],
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
+ const int i3 = tgpig[2];
+ const int i2 = tgpig[1];
+ const int i1 = tgpig[0];
+
+ float corr_dims[2];
+ rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
+
+ device const int32_t * pos = (device const int32_t *) src1;
+
+ const float theta_base = (float) pos[i2];
+ const float inv_ndims = -1.f/args.n_dims;
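+ // rotation angle for dimension i0 is pos * freq_base^(-i0/n_dims)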
+
+ float cos_theta;
+ float sin_theta;
+
+ for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
+ if (i0 < args.n_dims) {
+ const int ic = i0/2;
+
+ const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
+
+ const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+
+ rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
+
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+ const float x0 = src[0];
+ const float x1 = src[1];
+
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[1] = x0*sin_theta + x1*cos_theta;
+ } else {
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+ dst_data[0] = src[0];
+ dst_data[1] = src[1];
+ }
+ }
+}
+
+template<typename T>
+kernel void kernel_rope_neox(
+ constant ggml_metal_kargs_rope & args,
+ device const char * src0,
+ device const char * src1,
+ device const char * src2,
+ device char * dst,
+ ushort tiitg[[thread_index_in_threadgroup]],
+ ushort3 tptg [[threads_per_threadgroup]],
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
+ const int i3 = tgpig[2];
+ const int i2 = tgpig[1];
+ const int i1 = tgpig[0];
+
+ float corr_dims[2];
+ rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
+
+ device const int32_t * pos = (device const int32_t *) src1;
+
+ const float theta_base = (float) pos[i2];
+ const float inv_ndims = -1.f/args.n_dims;
+
+ float cos_theta;
+ float sin_theta;
+
+ for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
+ if (i0 < args.n_dims) {
+ const int ic = i0/2;
+
+ const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
+
+ const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+
+ rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
+
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
+
+ const float x0 = src[0];
+ const float x1 = src[args.n_dims/2];
+
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
+ } else {
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+ dst_data[0] = src[0];
+ dst_data[1] = src[1];
+ }
+ }
+}
+
+typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
+typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
+
+template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm;
+template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm;
+
+template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox;
+template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox;
+
+typedef void (im2col_t)(
+ device const float * x,
+ device char * dst,
+ constant int32_t & ofs0,
+ constant int32_t & ofs1,
+ constant int32_t & IW,
+ constant int32_t & IH,
+ constant int32_t & CHW,
+ constant int32_t & s0,
+ constant int32_t & s1,
+ constant int32_t & p0,
+ constant int32_t & p1,
+ constant int32_t & d0,
+ constant int32_t & d1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]);
+
+template<typename T>
+kernel void kernel_im2col(
+ device const float * x,
+ device char * dst,
+ constant int32_t & ofs0,
+ constant int32_t & ofs1,
+ constant int32_t & IW,
+ constant int32_t & IH,
+ constant int32_t & CHW,
+ constant int32_t & s0,
+ constant int32_t & s1,
+ constant int32_t & p0,
+ constant int32_t & p1,
+ constant int32_t & d0,
+ constant int32_t & d1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+// const int64_t IC = tgpg[0];
+ const int64_t OH = tgpg[1];
+ const int64_t OW = tgpg[2];
+
+// const int64_t N = ntg[0];
+ const int64_t KH = ntg[1];
+ const int64_t KW = ntg[2];
+
+ const int64_t in = tpitg[0];
+ const int64_t ikh = tpitg[1];
+ const int64_t ikw = tpitg[2];
+
+ const int64_t iic = tgpig[0];
+ const int64_t ioh = tgpig[1];
+ const int64_t iow = tgpig[2];
+
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
+
+ const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*CHW + (iic*(KH*KW) + ikh*KW + ikw);
+
+ device T * pdst = (device T *) (dst);
+
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+ pdst[offset_dst] = 0.0f;
+ } else {
+ const int64_t offset_src = in*ofs0 + iic*ofs1 + iih*IW + iiw;
+ pdst[offset_dst] = x[offset_src];
+ }
+}
+
+template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col;
+template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col;
+
+typedef void (im2col_ext_t)(
+ device const float * x,
+ device char * dst,
+ constant int32_t & ofs0,
+ constant int32_t & ofs1,
+ constant int32_t & IW,
+ constant int32_t & IH,
+ constant int32_t & CHW,
+ constant int32_t & s0,
+ constant int32_t & s1,
+ constant int32_t & p0,
+ constant int32_t & p1,
+ constant int32_t & d0,
+ constant int32_t & d1,
+ constant int32_t & N,
+ constant int32_t & KH,
+ constant int32_t & KW,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]);
+
+template<typename T>
+kernel void kernel_im2col_ext(
+ device const float * x,
+ device char * dst,
+ constant int32_t & ofs0,
+ constant int32_t & ofs1,
+ constant int32_t & IW,
+ constant int32_t & IH,
+ constant int32_t & CHW,
+ constant int32_t & s0,
+ constant int32_t & s1,
+ constant int32_t & p0,
+ constant int32_t & p1,
+ constant int32_t & d0,
+ constant int32_t & d1,
+ constant int32_t & N,
+ constant int32_t & KH,
+ constant int32_t & KW,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
+ const int64_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]
+
+ const int64_t d = tgpig[0] / CHW;
+ const int64_t chw = tgpig[0] % CHW;
+ const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
+ const int64_t HW = tgpig[0] % KHW;
+
+ const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
+ if (tpitg_0 >= N) {
+ return;
+ }
+
+ const int64_t tpitg_1 = HW / KW;
+ const int64_t tpitg_2 = HW % KW;
+
+ const int64_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
+ const int64_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
+
+ const int64_t offset_dst =
+ (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
+ (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
+
+ device T * pdst = (device T *) (dst);
+
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+ pdst[offset_dst] = 0.0f;
+ } else {
+ const int64_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
+ pdst[offset_dst] = x[offset_src + iih * IW + iiw];
+ }
+}
+
+template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext;
+template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext;
+
+typedef void (conv_transpose_1d_t)(
+ device const float * src0,
+ device const float * src1,
+ device char * dst,
+ constant int32_t & IC,
+ constant int32_t & IL,
+ constant int32_t & K,
+ constant int32_t & s0,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]]);
+
+template<typename T>
+kernel void kernel_conv_transpose_1d(
+ device const T * src0,
+ device const float * src1,
+ device char * dst,
+ constant int32_t & IC,
+ constant int32_t & IL,
+ constant int32_t & K,
+ constant int32_t & s0,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]]) {
+
+ float v = 0.0f;
+
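+ // accumulate every input position whose K-wide, stride-s0 output window covers this output index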
+ for (int64_t c = 0; c < IC; c++) {
+ const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
+ const int32_t input_offset = c * IL;
+
+ for (int64_t i = 0; i < IL; i++) {
+ if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) {
+ v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i];
+ }
+ }
+ }
+
+ device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1);
+
+ dst_ptr[0] = v;
+}
+
+template [[host_name("kernel_conv_transpose_1d_f32_f32")]]
+kernel void kernel_conv_transpose_1d<float>(
+ device const float * src0,
+ device const float * src1,
+ device char * dst,
+ constant int32_t & IC,
+ constant int32_t & IL,
+ constant int32_t & K,
+ constant int32_t & s0,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]]);
+
+template [[host_name("kernel_conv_transpose_1d_f16_f32")]]
+kernel void kernel_conv_transpose_1d<half>(
+ device const half * src0,
+ device const float * src1,
+ device char * dst,
+ constant int32_t & IC,
+ constant int32_t & IL,
+ constant int32_t & K,
+ constant int32_t & s0,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]]);
+
+kernel void kernel_upscale_f32(
+ device const char * src0,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant float & sf0,
+ constant float & sf1,
+ constant float & sf2,
+ constant float & sf3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ const int64_t i3 = tgpig.z;
+ const int64_t i2 = tgpig.y;
+ const int64_t i1 = tgpig.x;
+
+ const int64_t i03 = i3/sf3;
+ const int64_t i02 = i2/sf2;
+ const int64_t i01 = i1/sf1;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int64_t i00 = i0/sf0;
+
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ dst_ptr[0] = src0_ptr[0];
+ }
+}
+
+kernel void kernel_pad_f32(
+ device const char * src0,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ const int64_t i3 = tgpig.z;
+ const int64_t i2 = tgpig.y;
+ const int64_t i1 = tgpig.x;
+
+ const int64_t i03 = i3;
+ const int64_t i02 = i2;
+ const int64_t i01 = i1;
+
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
+
+ if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ if (i0 < ne00) {
+ dst_ptr[i0] = src0_ptr[i0];
+ } else {
+ dst_ptr[i0] = 0.0f;
+ }
+ }
+
+ return;
+ }
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ dst_ptr[i0] = 0.0f;
+ }
+}
+
+kernel void kernel_pad_reflect_1d_f32(
+ device const char * src0,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant int64_t & ne0,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int32_t & p0,
+ constant int32_t & p1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ const int64_t i3 = tgpig.z;
+ const int64_t i2 = tgpig.y;
+ const int64_t i1 = tgpig.x;
+
+ const int64_t i03 = i3;
+ const int64_t i02 = i2;
+ const int64_t i01 = i1;
+
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
+
+ if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
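+ // left pad mirrors across the first sample, the middle copies the input, the right pad mirrors across the last sample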
+ if (i0 < p0) {
+ dst_ptr[i0] = src0_ptr[p0 - i0];
+ } else if (i0 < ne0 - p1) {
+ dst_ptr[i0] = src0_ptr[i0 - p0];
+ } else {
+ dst_ptr[i0] = src0_ptr[(ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1];
+ }
+ }
+ }
+}
+
+kernel void kernel_arange_f32(
+ device char * dst,
+ constant int64_t & ne0,
+ constant float & start,
+ constant float & step,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ device float * dst_ptr = (device float *) dst;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ dst_ptr[i0] = start + step * i0;
+ }
+}
+
+kernel void kernel_timestep_embedding_f32(
+ device const char * src0,
+ device char * dst,
+ constant uint64_t & nb1,
+ constant int & dim,
+ constant int & max_period,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ int i = tgpig.x;
+ device float * embed_data = (device float *)(dst + i*nb1);
+
+ int half_ = dim / 2;
+ for (int j = tpitg.x; j < half_; j += ntg.x) {
+ float timestep = ((device float *)src0)[i];
+ float freq = (float)exp(-log((float)max_period) * j / half_);
+ float arg = timestep * freq;
+ embed_data[j ] = cos(arg);
+ embed_data[j + half_] = sin(arg);
+ }
+
+ if (dim % 2 != 0 && tpitg.x == 0) {
+ embed_data[dim] = 0.f;
+ }
+}
+
+// bitonic sort implementation following the CUDA kernels as reference
+typedef void (argsort_t)(
+ device const float * x,
+ device int32_t * dst,
+ constant int64_t & ncols,
+ constant int64_t & ncols_pad,
+ threadgroup int32_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]);
+
+template<ggml_sort_order order>
+kernel void kernel_argsort_f32_i32(
+ device const float * x,
+ device int32_t * dst,
+ constant int64_t & ncols,
+ constant int64_t & ncols_pad,
+ threadgroup int32_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
+ // bitonic sort
+ int col = tpitg[0];
+ int row = tgpig[1];
+
+ if (col >= ncols_pad) return;
+
+ device const float * x_row = x + row * ncols;
+ threadgroup int32_t * dst_row = shared_values;
+
+ // initialize indices
+ dst_row[col] = col;
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
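+ // ncols_pad is a power of two; indices >= ncols act as padding, end up past the valid range and are dropped on write-out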
+
+ for (int k = 2; k <= ncols_pad; k *= 2) {
+ for (int j = k / 2; j > 0; j /= 2) {
+ int ixj = col ^ j;
+ if (ixj > col) {
+ if ((col & k) == 0) {
+ if (dst_row[col] >= ncols ||
+ (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+ x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+ ) {
+ SWAP(dst_row[col], dst_row[ixj]);
+ }
+ } else {
+ if (dst_row[ixj] >= ncols ||
+ (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+ ) {
+ SWAP(dst_row[col], dst_row[ixj]);
+ }
+ }
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ }
+
+ // copy the result to dst without the padding
+ if (col < ncols) {
+ dst[row * ncols + col] = dst_row[col];
+ }
+}
+
+template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32;
+template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32;
+
+kernel void kernel_leaky_relu_f32(
+ device const float * src0,
+ device float * dst,
+ constant float & slope,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+}
+
+// ref: https://arxiv.org/pdf/2307.08691.pdf
+template<
+ typename q_t, // query types in shared memory
+ typename q4_t,
+ typename q8x8_t,
+ typename k_t, // key types in shared memory
+ typename k4x4_t,
+ typename k8x8_t,
+ typename v_t, // value types in shared memory
+ typename v4x4_t,
+ typename v8x8_t,
+ typename qk_t, // Q*K types
+ typename qk8x8_t,
+ typename s_t, // soft-max types
+ typename s8x8_t,
+ typename o_t, // attention accumulation types
+ typename o4_t,
+ typename o8x8_t,
+ typename kd4x4_t, // key type in device memory
+ short nl_k,
+ void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
+ typename vd4x4_t, // value type in device memory
+ short nl_v,
+ void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
+ short D, // head size
+ short Q = 8, // queries per threadgroup
+ short KV = 8, // key/value processed per each simdgroup
+ short C = 32> // cache items per threadgroup
+kernel void kernel_flash_attn_ext(
+ constant ggml_metal_kargs_flash_attn_ext & args,
+ device const char * q,
+ device const char * k,
+ device const char * v,
+ device const char * mask,
+ device char * dst,
+ threadgroup half * shmem_f16 [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 ntg[[threads_per_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ const short nsg = ntg.y; // number of simdgroups
+
+ const int iq3 = tgpig[2];
+ const int iq2 = tgpig[1];
+ const int iq1 = tgpig[0]*Q;
+
+ const short D4 = D/4;
+ const short D8 = D/8;
+ const short D16 = D/16;
+ const short NW = N_SIMDWIDTH;
+ const short SH = (2*C + Q); // scratch elements per query row per simdgroup (in units of s_t == float)
+
+ const short TS = nsg*SH; // scratch elements per query row across all simdgroups (in units of s_t == float)
+ const short T = D + 2*TS; // shared memory stride per query row, in units of half (one s_t == two halves)
+
+ threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t
+ threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*D); // reuse query data for accumulation
+ threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*D); // same as above but in o4_t
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
+
+ threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
+ threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
+
+ threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
+ threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
+
+ // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+ o8x8_t lo[D8];
+
+ // load heads from Q to shared memory
+ for (short j = sgitg; j < Q; j += nsg) {
+ device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
+
+ for (short i = tiisg; i < D4; i += NW) {
+ if (iq1 + j < args.ne01) {
+ sq4[j*D4 + i] = (q4_t) q4[i];
+ } else {
+ sq4[j*D4 + i] = (q4_t) 0.0f;
+ }
+ }
+ }
+
+ // zero out lo
+ for (short i = 0; i < D8; ++i) {
+ lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
+ }
+
+ // zero out shared memory SH
+ for (short j = 0; j < Q; ++j) {
+ for (short i = tiisg; i < SH; i += NW) {
+ ss[j*TS + i] = 0.0f;
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ {
+ half S[Q] = { [0 ... Q-1] = 0.0f };
+ half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
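+ // online softmax state per query row: M = running maximum score, S = running sum of exp(score - M)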
+
+ // thread indices inside the simdgroup
+ // TODO: see if we can utilize quad-group functions for better performance
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3)
+ const short tx = tiisg%4;
+ const short ty = tiisg/4;
+
+ // broadcast kv
+ //const short rk2 = args.ne02/args.ne12;
+ //const short rk3 = args.ne03/args.ne13;
+
+ const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+ const short ikv3 = iq3/(args.ne03/args.ne_12_3);
+
+ // load the queries from shared memory into local memory
+ q8x8_t mq[D8];
+
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_load(mq[i], sq + i*8, D);
+ }
+
+ const bool has_mask = mask != q;
+
+ half slope = 1.0f;
+
+ // ALiBi
+ if (args.max_bias > 0.0f) {
+ const short h = iq2;
+
+ const half base = h < args.n_head_log2 ? args.m0 : args.m1;
+ const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
+
+ slope = pow(base, exph);
+ }
+
+ // loop over the KV cache
+ // each simdgroup handles blocks of Q rows and C columns
+ for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
+ const int ic = ic0 + C*sgitg;
+ if (ic >= args.ne11) {
+ break;
+ }
+
+ if (has_mask) {
+ // used to detect blocks full of -INF
+ half smax = -INFINITY;
+
+ // load the mask in shared memory
+ #pragma unroll(Q)
+ for (short j = 0; j < Q; ++j) {
+ device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
+
+ const half m = pm[ic + tiisg];
+
+ ss[j*TS + C + tiisg] = m;
+ smax = max(smax, m);
+ }
+
+ smax = simd_max(smax);
+
+ if (smax == -INFINITY) {
+ continue;
+ }
+ }
+
+ // Q*K^T
+ {
+ for (short cc = 0; cc < C/8; ++cc) {
+ qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
+
+ // this is a compile-time check, so it does not have runtime overhead
+ if (is_same<kd4x4_t, k4x4_t>::value) {
+ // we can read directly from global memory
+ device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+
+ #pragma unroll(D8)
+ for (short i = 0; i < D8; ++i) {
+ k8x8_t mk;
+ simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
+
+ simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+ }
+ } else {
+ for (short ii = 0; ii < D16; ii += 4) {
+ device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+
+ if (D16%4 == 0) {
+ // the head is evenly divisible by 4*16 = 64, so no need for bound checks
+ {
+ k4x4_t tmp;
+ deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
+ sk4x4[4*ty + tx] = tmp;
+ }
+
+ simdgroup_barrier(mem_flags::mem_threadgroup);
+
+ #pragma unroll(4)
+ for (short k = 0; k < 4; ++k) {
+ k8x8_t mk;
+
+ simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
+ simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
+
+ simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
+ simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
+ }
+ } else {
+ if (ii + tx < D16) {
+ k4x4_t tmp;
+ deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
+ sk4x4[4*ty + tx] = tmp;
+ }
+
+ simdgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (short k = 0; k < 4 && ii + k < D16; ++k) {
+ k8x8_t mk;
+
+ simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
+ simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
+
+ simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
+ simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
+ }
+ }
+ }
+ }
+
+ // cast qk_t -> s_t
+ //s8x8_t mqks(1.0f);
+ //simdgroup_multiply(mqks, mqk, mqks);
+ //simdgroup_store(mqks, ss + 8*cc, TS, 0, false);
+
+ simdgroup_store(mqk, ss + 8*cc, TS, 0, false);
+ }
+ }
+
+ // online softmax
+ {
+ for (ushort j = 0; j < Q; ++j) {
+ const half m = M[j];
+
+ // scale and apply the logit softcap / mask
+ half s = ss[j*TS + tiisg]*args.scale;
+
+ if (args.logit_softcap != 0.0f) {
+ s = args.logit_softcap*precise::tanh(s);
+ }
+
+ // mqk = mqk + mask*slope
+ s += slope*ss[j*TS + C + tiisg];
+
+ M[j] = simd_max(max(M[j], s));
+
+ const half ms = exp(m - M[j]);
+ const half vs = exp(s - M[j]);
+
+ S[j] = S[j]*ms + simd_sum(vs);
+
+ // the P matrix from the paper (Q rows, C columns)
+ ss[j*TS + tiisg] = vs;
+
+ // create a QxQ diagonal matrix for rescaling the output
+ if (tiisg == j) {
+ ss[j*TS + 2*C + j] = ms;
+ }
+ }
+ }
+
+ // O = diag(ms)*O
+ {
+ s8x8_t mm;
+ simdgroup_load(mm, ss + 2*C, TS, 0, false);
+
+ #pragma unroll(D8)
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_multiply(lo[i], mm, lo[i]);
+ }
+ }
+
+ // O = O + (Q*K^T)*V
+ {
+ for (short cc = 0; cc < C/8; ++cc) {
+ s8x8_t ms;
+ simdgroup_load(ms, ss + 8*cc, TS, 0, false);
+
+ if (is_same<vd4x4_t, v4x4_t>::value) {
+ // we can read directly from global memory
+ device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+
+ #pragma unroll(D8)
+ for (short i = 0; i < D8; ++i) {
+ v8x8_t mv;
+ simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
+
+ simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
+ }
+ } else {
+ for (short ii = 0; ii < D16; ii += 4) {
+ device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+
+ if (D16%4 == 0) {
+ // no need for bound checks
+ {
+ v4x4_t tmp;
+ deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
+ sv4x4[4*ty + tx] = tmp;
+ }
+
+ simdgroup_barrier(mem_flags::mem_threadgroup);
+
+ #pragma unroll(4)
+ for (short k = 0; k < 4; ++k) {
+ v8x8_t mv;
+
+ simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+
+ simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+ }
+ } else {
+ if (ii + tx < D16) {
+ v4x4_t tmp;
+ deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
+ sv4x4[4*ty + tx] = tmp;
+ }
+
+ simdgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (short k = 0; k < 4 && ii + k < D16; ++k) {
+ v8x8_t mv;
+
+ simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+
+ simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+ for (short j = 0; j < Q; ++j) {
+ if (tiisg == 0) {
+ ss[j*TS + 0] = S[j];
+ ss[j*TS + 1] = M[j];
+ }
+ }
+ }
+
+ // reduce the warps sequentially
+ for (ushort sg = 1; sg < nsg; ++sg) {
+ half S = { 0.0f };
+ half M = { -__FLT16_MAX__/2 };
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // each simdgroup stores its output to shared memory, reusing sq
+ if (sgitg == sg) {
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_store(lo[i], so + i*8, D, 0, false);
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // the first simdgroup accumulates the results from the other simdgroups
+ if (sgitg == 0) {
+ for (short j = 0; j < Q; ++j) {
+ const half S0 = ss[j*TS + 0];
+ const half S1 = ss[j*TS + sg*SH + 0];
+
+ const half M0 = ss[j*TS + 1];
+ const half M1 = ss[j*TS + sg*SH + 1];
+
+ M = max(M0, M1);
+
+ const half ms0 = exp(M0 - M);
+ const half ms1 = exp(M1 - M);
+
+ S = S0*ms0 + S1*ms1;
+
+ if (tiisg == 0) {
+ ss[j*TS + 0] = S;
+ ss[j*TS + 1] = M;
+
+ ss[j*TS + 2*C + j ] = ms0;
+ ss[j*TS + 2*C + j + sg*SH] = ms1;
+ }
+ }
+
+ // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+ {
+ s8x8_t ms0;
+ s8x8_t ms1;
+
+ simdgroup_load(ms0, ss + 2*C, TS, 0, false);
+ simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
+
+ #pragma unroll(D8)
+ for (short i = 0; i < D8; ++i) {
+ o8x8_t t;
+
+ simdgroup_load (t, so + i*8, D, 0, false);
+ simdgroup_multiply(t, ms1, t);
+
+ simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
+ }
+ }
+ }
+ }
+
+ // store result to shared memory (reuse sq)
+ if (sgitg == 0) {
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_store(lo[i], so + i*8, D, 0, false);
+ }
+ }
+
+ device float4 * dst4 = (device float4 *) dst;
+
+ // final rescale with 1/S and store to global memory
+ if (sgitg == 0) {
+ for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
+ const float S = ss[j*TS + 0];
+
+ for (short i = tiisg; i < D4; i += NW) {
+ dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
+ }
+ }
+ }
+}
+
+// TODO: this is quite ugly. In the future these types will be hardcoded in the kernel, but for now
+// they are kept as template parameters to make it easy to explore different combinations
+//
+#define FA_TYPES \
+ half, half4, simdgroup_half8x8, \
+ half, half4x4, simdgroup_half8x8, \
+ half, half4x4, simdgroup_half8x8, \
+ float, simdgroup_float8x8, \
+ float, simdgroup_float8x8, \
+ half, half4, simdgroup_half8x8
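+
+// note: each row of FA_TYPES above supplies one group of type template arguments of
+// kernel_flash_attn_ext - roughly the query, key, value, Q*K accumulator, softmax scratch and
+// output accumulator types; the exact grouping follows the kernel's template parameter list
+// declared earlier in this file.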
+
+typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
+
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80>;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96>;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256>;
+
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64>;
+template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80>;
+template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96>;
+template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112>;
+template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256>;
+#endif
+
+template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;
+
+#undef FA_TYPES
+
+template<
+ typename q4_t, // query types in shared memory
+ typename q4x4_t,
+ typename k4x4_t, // key types in shared memory
+ typename v4x4_t, // value types in shared memory
+ typename qk_t, // Q*K types
+ typename s_t, // soft-max types
+ typename s4_t,
+ typename s4x4_t,
+ typename o4x4_t, // attention accumulation types
+ typename kd4x4_t, // key type in device memory
+ short nl_k,
+ void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
+ typename vd4x4_t, // key type in device memory
+ short nl_v,
+ void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
+ short D, // head size
+ short Q = 1, // queries per threadgroup
+ short C = 32> // cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec(
+ constant ggml_metal_kargs_flash_attn_ext & args,
+ device const char * q,
+ device const char * k,
+ device const char * v,
+ device const char * mask,
+ device char * dst,
+ threadgroup half * shmem_f16 [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 ntg[[threads_per_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ const short nsg = ntg.y; // number of simdgroups
+
+ const int iq3 = tgpig[2];
+ const int iq2 = tgpig[1];
+ const int iq1 = tgpig[0];
+
+ const short D4 = D/4;
+ const short D16 = D/16;
+ const short NW = N_SIMDWIDTH;
+ const short NL = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
+ const short SH = 2*C; // shared memory per simdgroup
+
+ const short T = D + nsg*SH; // shared memory size per query in (half)
+
+ //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t
+ threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 + 0*D); // same as above but in q4x4_t
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*D); // scratch buffer for attention
+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*D); // same as above but in s4_t
+ threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask
+ threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D + Q*T); // scratch buffer for the results
+
+ // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+ o4x4_t lo[D16/NL];
+
+ // load heads from Q to shared memory
+ device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
+
+ for (short i = tiisg; i < D4; i += NW) {
+ if (iq1 < args.ne01) {
+ sq4[i] = (q4_t) q4[i];
+ } else {
+ sq4[i] = (q4_t) 0.0f;
+ }
+ }
+
+ // zero out lo
+ for (short i = 0; i < D16/NL; ++i) {
+ lo[i] = (o4x4_t) 0.0f;
+ }
+
+ // zero out shared memory SH
+ for (short i = tiisg; i < SH/4; i += NW) {
+ ss4[i] = (s4_t) 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ {
+ half S = 0.0f;
+ half M = -__FLT16_MAX__/2;
+
+ // thread indices inside the simdgroup
+ const short tx = tiisg%NL;
+ const short ty = tiisg/NL;
+
+ // broadcast kv
+ //const short rk2 = args.ne02/args.ne12;
+ //const short rk3 = args.ne03/args.ne13;
+
+ const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+ const short ikv3 = iq3/(args.ne03/args.ne_12_3);
+
+ // load the queries from shared memory into local memory
+ q4x4_t mq[D16/NL];
+
+ #pragma unroll(D16/NL)
+ for (short ii = 0; ii < D16; ii += NL) {
+ mq[ii/NL] = sq4x4[ii + tx];
+ }
+
+ const bool has_mask = mask != q;
+
+ // pointer to the mask
+ device const half * pm = (device const half *) (mask + iq1*args.nb31);
+
+ half slope = 1.0f;
+
+ // ALiBi
+ if (args.max_bias > 0.0f) {
+ const short h = iq2;
+
+ const half base = h < args.n_head_log2 ? args.m0 : args.m1;
+ const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
+
+ slope = pow(base, exph);
+ }
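+
+ // note: the ALiBi slope computed above is base^exph, i.e. m0^(h + 1) for the first
+ // n_head_log2 heads and m1^(2*(h - n_head_log2) + 1) for the rest; it scales the additive
+ // mask term added to the attention scores below.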
+
+ // loop over the KV cache
+ // each simdgroup handles blocks of Q rows and C columns
+ for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
+ const int ic = ic0 + C*sgitg;
+ if (ic >= args.ne11) {
+ break;
+ }
+
+ if (has_mask) {
+ sm[tiisg] = pm[ic + tiisg];
+ }
+
+ // Q*K^T
+ {
+ // each simdgroup processes 1 query and 4 (NW/NL) keys
+ for (short cc = 0; cc < C/4; ++cc) {
+ qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
+
+ device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+
+ #pragma unroll(D16/NL)
+ for (short ii = 0; ii < D16; ii += NL) {
+ const short i = ii + tx;
+
+ k4x4_t mk;
+ deq_k(pk + i/nl_k, i%nl_k, mk);
+
+ // note: this is less precise than the version below
+ //mqka[0] += dot(mq[ii/NL][0], mk[0]);
+ //mqka[1] += dot(mq[ii/NL][1], mk[1]);
+ //mqka[2] += dot(mq[ii/NL][2], mk[2]);
+ //mqka[3] += dot(mq[ii/NL][3], mk[3]);
+
+ mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]);
+ mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]);
+ mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]);
+ mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]);
+ }
+
+ qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
+
+ // simdgroup reduce
+ // [ 0 .. 7] -> [ 0]
+ // [ 8 .. 15] -> [ 8]
+ // [16 .. 23] -> [16]
+ // [24 .. 31] -> [24]
+ //mqk += simd_shuffle_down(mqk, 16);
+ //mqk += simd_shuffle_down(mqk, 8);
+ mqk += simd_shuffle_down(mqk, 4);
+ mqk += simd_shuffle_down(mqk, 2);
+ mqk += simd_shuffle_down(mqk, 1);
+
+ // mqk = mqk*scale + mask*slope
+ if (tx == 0) {
+ mqk *= args.scale;
+
+ if (args.logit_softcap != 0.0f) {
+ mqk = args.logit_softcap*precise::tanh(mqk);
+ }
+
+ mqk += sm[4*cc + ty]*slope;
+
+ ss[4*cc + ty] = mqk;
+ }
+ }
+ }
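+
+ // note: each key is handled here by NL = NW/4 = 8 lanes - every lane accumulates partial dot
+ // products for its slice of the head dimension, the shuffles by 4/2/1 above fold the 8 partials
+ // into the lane with tx == 0, and that lane applies the scale, the optional logit softcap and
+ // the mask before writing one score per key into ss.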
+
+ simdgroup_barrier(mem_flags::mem_threadgroup);
+
+ // online softmax
+ {
+ const half m = M;
+ const half s = ss[tiisg];
+
+ M = simd_max(max(M, s));
+
+ const half ms = exp(m - M);
+ const half vs = exp(s - M);
+
+ S = S*ms + simd_sum(vs);
+
+ // the P matrix from the paper (Q rows, C columns)
+ ss[tiisg] = vs;
+
+ // O = diag(ms)*O
+ #pragma unroll(D16/NL)
+ for (short ii = 0; ii < D16; ii += NL) {
+ lo[ii/NL] *= ms;
+ }
+ }
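+
+ // note: this is the standard online (streaming) softmax update: given the new scores s and the
+ // previous running maximum m, the running sum becomes S = S*exp(m - M) + sum(exp(s - M)) with
+ // M the new maximum, and the output accumulator is rescaled by exp(m - M), so the softmax never
+ // needs a second pass over the scores.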
+
+ simdgroup_barrier(mem_flags::mem_threadgroup);
+
+ // O = O + (Q*K^T)*V
+ {
+ for (short cc = 0; cc < C/4; ++cc) {
+ device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+
+ const s4x4_t ms(ss[4*cc + ty]);
+
+ #pragma unroll(D16/NL)
+ for (short ii = 0; ii < D16; ii += NL) {
+ const short i = ii + tx;
+
+ v4x4_t mv;
+ deq_v(pv4 + i/nl_v, i%nl_v, mv);
+
+ lo[ii/NL] += mv*ms;
+ }
+ }
+ }
+ }
+
+ // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+ if (tiisg == 0) {
+ ss[0] = (s_t) S;
+ ss[1] = (s_t) M;
+ }
+ }
+
+ // simdgroup reduce
+ // [ 0, 8, 16, 24] -> [ 0]
+ // [ 1, 9, 17, 25] -> [ 1]
+ // [ 2, 10, 18, 26] -> [ 2]
+ // [ 3, 11, 19, 27] -> [ 3]
+ // [ 4, 12, 20, 28] -> [ 4]
+ // [ 5, 13, 21, 29] -> [ 5]
+ // [ 6, 14, 22, 30] -> [ 6]
+ // [ 7, 15, 23, 31] -> [ 7]
+ for (short ii = 0; ii < D16; ii += NL) {
+ lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
+ lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8);
+ //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4);
+ //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2);
+ //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1);
+
+ lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
+ lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8);
+ //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4);
+ //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2);
+ //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1);
+
+ lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
+ lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8);
+ //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4);
+ //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2);
+ //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1);
+
+ lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
+ lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8);
+ //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4);
+ //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2);
+ //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1);
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // store results to shared memory
+ for (short i = tiisg; i < D16; i += NL) {
+ sr4x4[i] = lo[i/NL];
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // parallel reduce
+ for (short r = nsg/2; r > 0; r >>= 1) {
+ if (sgitg < r) {
+ const half S0 = ss[ 0];
+ const half S1 = ss[r*SH + 0];
+
+ const half M0 = ss[ 1];
+ const half M1 = ss[r*SH + 1];
+
+ const half M = max(M0, M1);
+
+ const half ms0 = exp(M0 - M);
+ const half ms1 = exp(M1 - M);
+
+ const half S = S0*ms0 + S1*ms1;
+
+ if (tiisg == 0) {
+ ss[0] = S;
+ ss[1] = M;
+ }
+
+ // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+ for (short i = tiisg; i < D16; i += NW) {
+ sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1;
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
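+
+ // note: unlike the sequential simdgroup reduction in kernel_flash_attn_ext above, this is a tree
+ // reduction - at each step simdgroup sgitg merges the partial (S, M, O) of simdgroup sgitg + r
+ // into its own using the same max/exp rescaling identity, halving r until only simdgroup 0 holds
+ // the combined result.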
+
+ device float4x4 * dst44 = (device float4x4 *) dst;
+
+ // final rescale with 1/S and store to global memory
+ if (sgitg == 0) {
+ const float S = ss[0];
+
+ for (short i = tiisg; i < D16; i += NW) {
+ dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
+ }
+ }
+}
+
+// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
+// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
+//
+#define FA_TYPES \
+ half4, half4x4, \
+ half4x4, \
+ half4x4, \
+ float, \
+ half, half4, half4x4, \
+ half4x4
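+
+// note: the rows of FA_TYPES above line up with the type template parameters of
+// kernel_flash_attn_ext_vec: (q4_t, q4x4_t), k4x4_t, v4x4_t, qk_t, (s_t, s4_t, s4x4_t), o4x4_t -
+// i.e. everything is kept in half precision except the Q*K accumulator qk_t, which stays float.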
+
+typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128>;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256>;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;
+
+#undef FA_TYPES
+
+template<typename T>
+kernel void kernel_set(
+ constant ggml_metal_kargs_set & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i13 = tgpig[2];
+ const int i12 = tgpig[1];
+ const int i11 = tgpig[0];
+
+ const int64_t n = i13*args.ne12*args.ne11*args.ne10 + i12*args.ne11*args.ne10 + i11*args.ne10;
+
+ const int64_t i3 = n / (args.ne12*args.ne11*args.ne10);
+ const int64_t i2 = (n - i3*args.ne12*args.ne11*args.ne10) / (args.ne11*args.ne10);
+ const int64_t i1 = (n - i3*args.ne12*args.ne11*args.ne10 - i2*args.ne11*args.ne10) / args.ne10;
+
+ device T * dst_data = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + args.offs);
+
+ for (int64_t i10 = tpitg.x; i10 < args.ne10; i10 += ntg.x) {
+ device const T * src = (device T *) (src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10);
+ dst_data[i10] = (T) src[0];
+ }
+}
+
+typedef decltype(kernel_set<float>) kernel_set_t;
+
+template [[host_name("kernel_set_f32")]] kernel kernel_set_t kernel_set<float>;
+template [[host_name("kernel_set_i32")]] kernel kernel_set_t kernel_set<int32_t>;
+
+template<typename T0, typename T1>
+kernel void kernel_cpy(
+ constant ggml_metal_kargs_cpy & args,
+ device const char * src0,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig[2];
+ const int i02 = tgpig[1];
+ const int i01 = tgpig[0];
+
+ const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+ const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
+ const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
+ const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
+ const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
+
+ device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+ for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
+ device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+ dst_data[i00] = (T1) src[0];
+ }
+}
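+
+// note: kernel_cpy copies one source row per threadgroup - n is the flat index of the first element
+// of row (i01, i02, i03) and is re-expanded into (i0..i3) using the *destination* extents ne0..ne3,
+// which is what lets the same kernel implement reshapes/permutes; e.g. with ne0 = 4, ne1 = 3 and
+// ne2 = ne3 = 1, a flat n = 7 lands at i1 = 1, i0 = 3 in dst (illustrative numbers only).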
+
+typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
+
+template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
+template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
+#endif
+template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
+template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
+template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
+#endif
+
+kernel void kernel_cpy_f32_q8_0(
+ constant ggml_metal_kargs_cpy & args,
+ device const char * src0,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig[2];
+ const int i02 = tgpig[1];
+ const int i01 = tgpig[0];
+
+ const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+ const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+ const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+ const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+ const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0;
+
+ device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+ for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
+ device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK8_0; j++) {
+ const float v = src[j];
+ amax = MAX(amax, fabs(v));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK8_0].d = d;
+
+ for (int j = 0; j < QK8_0; ++j) {
+ const float x0 = src[j]*id;
+
+ dst_data[i00/QK8_0].qs[j] = round(x0);
+ }
+ }
+}
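+
+// note: Q8_0 stores each group of QK8_0 = 32 floats as one half-precision scale d = amax/127 plus
+// 32 signed bytes round(x/d); e.g. (illustrative numbers only) a block whose largest magnitude is
+// 1.27 gets d = 0.01 and the value 0.50 is stored as 50.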
+
+kernel void kernel_cpy_f32_q4_0(
+ constant ggml_metal_kargs_cpy & args,
+ device const char * src0,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig[2];
+ const int i02 = tgpig[1];
+ const int i01 = tgpig[0];
+
+ const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+ const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+ const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+ const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+ const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0;
+
+ device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+ for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
+ device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < QK4_0; j++) {
+ const float v = src[j];
+ if (amax < fabs(v)) {
+ amax = fabs(v);
+ max = v;
+ }
+ }
+
+ const float d = max / -8;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK4_0].d = d;
+
+ for (int j = 0; j < QK4_0/2; ++j) {
+ const float x0 = src[0 + j]*id;
+ const float x1 = src[QK4_0/2 + j]*id;
+
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+ dst_data[i00/QK4_0].qs[j] = xi0;
+ dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
+ }
+ }
+}
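+
+// note: Q4_0 picks the value with the largest magnitude and sets d = max/-8, so that value maps to
+// nibble 0 and the stored 4-bit codes 0..15 decode to (code - 8)*d; e.g. (illustrative numbers only)
+// max = -0.8 gives d = 0.1, and a value of 0.35 becomes code 12, which decodes back to 0.4.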
+
+kernel void kernel_cpy_f32_q4_1(
+ constant ggml_metal_kargs_cpy & args,
+ device const char * src0,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig[2];
+ const int i02 = tgpig[1];
+ const int i01 = tgpig[0];
+
+ const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+ const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+ const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+ const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+ const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1;
+
+ device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+ for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
+ device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+
+ float min = FLT_MAX;
+ float max = -FLT_MAX;
+
+ for (int j = 0; j < QK4_1; j++) {
+ const float v = src[j];
+ if (min > v) min = v;
+ if (max < v) max = v;
+ }
+
+ const float d = (max - min) / ((1 << 4) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK4_1].d = d;
+ dst_data[i00/QK4_1].m = min;
+
+ for (int j = 0; j < QK4_1/2; ++j) {
+ const float x0 = (src[0 + j] - min)*id;
+ const float x1 = (src[QK4_1/2 + j] - min)*id;
+
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+ dst_data[i00/QK4_1].qs[j] = xi0;
+ dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
+ }
+ }
+}
+
+kernel void kernel_cpy_f32_q5_0(
+ constant ggml_metal_kargs_cpy & args,
+ device const char * src0,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig[2];
+ const int i02 = tgpig[1];
+ const int i01 = tgpig[0];
+
+ const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+ const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+ const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+ const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+ const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0;
+
+ device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+ for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
+ device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < QK5_0; j++) {
+ const float v = src[j];
+ if (amax < fabs(v)) {
+ amax = fabs(v);
+ max = v;
+ }
+ }
+
+ const float d = max / -16;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK5_0].d = d;
+
+ uint32_t qh = 0;
+ for (int j = 0; j < QK5_0/2; ++j) {
+ const float x0 = src[0 + j]*id;
+ const float x1 = src[QK5_0/2 + j]*id;
+
+ const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+ const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+ dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+ }
+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+ for (int j = 0; j < 4; ++j) {
+ dst_data[i00/QK5_0].qh[j] = qh8[j];
+ }
+ }
+}
+
+kernel void kernel_cpy_f32_q5_1(
+ constant ggml_metal_kargs_cpy & args,
+ device const char * src0,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig[2];
+ const int i02 = tgpig[1];
+ const int i01 = tgpig[0];
+
+ const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+ const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+ const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+ const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+ const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1;
+
+ device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+ for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
+ device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+
+ float max = src[0];
+ float min = src[0];
+
+ for (int j = 1; j < QK5_1; j++) {
+ const float v = src[j];
+ min = v < min ? v : min;
+ max = v > max ? v : max;
+ }
+
+ const float d = (max - min) / 31;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK5_1].d = d;
+ dst_data[i00/QK5_1].m = min;
+
+ uint32_t qh = 0;
+ for (int j = 0; j < QK5_1/2; ++j) {
+ const float x0 = (src[0 + j] - min)*id;
+ const float x1 = (src[QK5_1/2 + j] - min)*id;
+
+ const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+ const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+ dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+ }
+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+ for (int j = 0; j < 4; ++j) {
+ dst_data[i00/QK5_1].qh[j] = qh8[j];
+ }
+ }
+}
+
+static inline int best_index_int8(int n, constant float * val, float x) {
+ if (x <= val[0]) return 0;
+ if (x >= val[n-1]) return n-1;
+ int ml = 0, mu = n-1;
+ while (mu-ml > 1) {
+ int mav = (ml+mu)/2;
+ if (x < val[mav]) mu = mav; else ml = mav;
+ }
+ return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
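+
+// note: best_index_int8 is a binary search over the sorted lookup table val[0..n-1] (here the
+// 16-entry kvalues_iq4nl_f grid) followed by a nearest-neighbour tie-break between the two
+// bracketing entries.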
+
+kernel void kernel_cpy_f32_iq4_nl(
+ constant ggml_metal_kargs_cpy & args,
+ device const char * src0,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig[2];
+ const int i02 = tgpig[1];
+ const int i01 = tgpig[0];
+
+ const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+ const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+ const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+ const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+ const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL;
+
+ device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+ for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
+ device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < QK4_0; j++) {
+ const float v = src[j];
+ if (amax < fabs(v)) {
+ amax = fabs(v);
+ max = v;
+ }
+ }
+
+ const float d = max / kvalues_iq4nl_f[0];
+ const float id = d ? 1.0f/d : 0.0f;
+
+ float sumqx = 0, sumq2 = 0;
+ for (int j = 0; j < QK4_NL/2; ++j) {
+ const float x0 = src[0 + j]*id;
+ const float x1 = src[QK4_NL/2 + j]*id;
+
+ const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
+ const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
+
+ dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
+
+ const float v0 = kvalues_iq4nl_f[xi0];
+ const float v1 = kvalues_iq4nl_f[xi1];
+ const float w0 = src[0 + j]*src[0 + j];
+ const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
+ sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
+ sumq2 += w0*v0*v0 + w1*v1*v1;
+
+ }
+
+ dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
+ }
+}
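+
+// note: IQ4_NL quantizes against the non-linear 16-value grid - an initial scale d = max/kvalues[0]
+// selects the codes via best_index_int8, and the scale is then refitted by weighted least squares,
+// d = sum(w*v*x)/sum(w*v*v) with weights w = x^2, which is what sumqx and sumq2 accumulate above.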
+
+kernel void kernel_concat(
+ constant ggml_metal_kargs_concat & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+
+ const int i3 = tgpig.z;
+ const int i2 = tgpig.y;
+ const int i1 = tgpig.x;
+
+ int o[4] = {0, 0, 0, 0};
+ o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));
+
+ device const float * x;
+
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
+ x = (device const float *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00);
+ } else {
+ x = (device const float *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10);
+ }
+
+ device float * y = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+ *y = *x;
+ }
+}
+
+template<typename args_t>
+void kernel_mul_mv_q2_K_f32_impl(
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const int nb = args.ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int ix = tiisg/8; // 0...3
+ const int it = tiisg%8; // 0...7
+ const int iq = it/4; // 0 or 1
+ const int ir = it%4; // 0...3
+ const int is = (8*ir)/16;// 0 or 1
+
+ device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
+
+ for (int ib = ix; ib < nb; ib += 4) {
+
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+ yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
+ yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
+ yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
+ }
+
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
+ device const half * dh = &x[ib].d;
+
+ for (int row = 0; row < N_DST; row++) {
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
+ }
+ float dall = dh[0];
+ float dmin = dh[1] * 1.f/16.f;
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
+ dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
+
+ qs += args.nb01/2;
+ sc += args.nb01;
+ dh += args.nb01/2;
+ }
+
+ y4 += 4 * QK_K;
+ }
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+ for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst_f32[first_row + row] = all_sum;
+ }
+ }
+}
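+
+// note: in the Q2_K dot product above each 16-bit word of qs packs eight 2-bit quants across four
+// interleaved lanes (the 0x0003/0x000c/0x0030/0x00c0 masks and their high-byte twins), each scale
+// byte holds a 4-bit scale in the low nibble and a 4-bit min in the high nibble, and the result is
+// assembled as dall*sum(scale_i*dot_i) - dmin*sum(min_i*sumy_i).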
+
+[[host_name("kernel_mul_mv_q2_K_f32")]]
+kernel void kernel_mul_mv_q2_K_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q2_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template<typename args_t>
+void kernel_mul_mv_q3_K_f32_impl(
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const int nb = args.ne00/QK_K;
+
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0);
+ device const float * yy = (device const float *) (src1 + offset1);
+
+ float yl[32];
+
+ //const uint16_t kmask1 = 0x3030;
+ //const uint16_t kmask2 = 0x0f0f;
+
+ const int tid = tiisg/4;
+ const int ix = tiisg%4;
+ const int ip = tid/4; // 0 or 1
+ const int il = 2*((tid%4)/2); // 0 or 2
+ const int ir = tid%2;
+ const int n = 8;
+ const int l0 = n*ir;
+
+ // One would think that the Metal compiler would figure out that ip and il can only have
+ // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+ // with these two tables.
+ //
+ // Possible masks for the high bit
+ const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0
+ {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2
+ {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0
+ {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+ // Possible masks for the low 2 bits
+ const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+ const ushort4 hm = mm[2*ip + il/2];
+
+ const short shift = 2*il;
+
+ const float v1 = il == 0 ? 4.f : 64.f;
+ const float v2 = 4.f * v1;
+
+ const uint16_t s_shift1 = 4*ip;
+ const uint16_t s_shift2 = s_shift1 + il;
+
+ const int q_offset = 32*ip + l0;
+ const int y_offset = 128*ip + 32*il + l0;
+
+ device const float * y1 = yy + ix*QK_K + y_offset;
+
+ uint32_t scales32, aux32;
+ thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+ thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+ float sumf1[2] = {0.f};
+ float sumf2[2] = {0.f};
+ for (int i = ix; i < nb; i += 4) {
+ for (int l = 0; l < 8; ++l) {
+ yl[l+ 0] = y1[l+ 0];
+ yl[l+ 8] = y1[l+16];
+ yl[l+16] = y1[l+32];
+ yl[l+24] = y1[l+48];
+ }
+
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
+ device const uint16_t * a = (device const uint16_t *)(x[i].scales);
+ device const half * dh = &x[i].d;
+
+ for (int row = 0; row < 2; ++row) {
+ const float d_all = (float)dh[0];
+
+ scales16[0] = a[4];
+ scales16[1] = a[5];
+ aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+ scales16[0] = a[il+0];
+ scales16[1] = a[il+1];
+ scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
+
+ float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
+ for (int l = 0; l < n; l += 2) {
+ const int32_t qs = q[l/2];
+ s1 += yl[l+0] * (qs & qm[il/2][0]);
+ s2 += yl[l+1] * (qs & qm[il/2][1]);
+ s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+ s4 += yl[l+16] * (qs & qm[il/2][2]);
+ s5 += yl[l+17] * (qs & qm[il/2][3]);
+ s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
+ }
+ float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+ float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+ sumf1[row] += d1 * (scales[0] - 32);
+ sumf2[row] += d2 * (scales[2] - 32);
+
+ s1 = s2 = s3 = s4 = s5 = s6 = 0;
+ for (int l = 0; l < n; l += 2) {
+ const int32_t qs = q[l/2+8];
+ s1 += yl[l+8] * (qs & qm[il/2][0]);
+ s2 += yl[l+9] * (qs & qm[il/2][1]);
+ s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+ s4 += yl[l+24] * (qs & qm[il/2][2]);
+ s5 += yl[l+25] * (qs & qm[il/2][3]);
+ s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
+ }
+ d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+ d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+ sumf1[row] += d1 * (scales[1] - 32);
+ sumf2[row] += d2 * (scales[3] - 32);
+
+ q += args.nb01/2;
+ h += args.nb01/2;
+ a += args.nb01/2;
+ dh += args.nb01/2;
+ }
+
+ y1 += 4 * QK_K;
+ }
+
+ for (int row = 0; row < 2; ++row) {
+ const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+ sumf1[row] = simd_sum(sumf);
+ }
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+ if (tiisg == 0) {
+ for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
+ dst_f32[first_row + row] = sumf1[row];
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_q3_K_f32")]]
+kernel void kernel_mul_mv_q3_K_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q3_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template<typename args_t>
+void kernel_mul_mv_q4_K_f32_impl(
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const uint16_t kmask1 = 0x3f3f;
+ const uint16_t kmask2 = 0x0f0f;
+ const uint16_t kmask3 = 0xc0c0;
+
+ const int ix = tiisg/8; // 0...3
+ const int it = tiisg%8; // 0...7
+ const int iq = it/4; // 0 or 1
+ const int ir = it%4; // 0...3
+
+ const int nb = args.ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+ //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int first_row = r0 * N_DST;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
+
+ float yl[16];
+ float yh[16];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
+
+ uint16_t sc16[4];
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
+
+ for (int ib = ix; ib < nb; ib += 4) {
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
+ yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
+ yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
+ yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
+ }
+
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
+ device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
+ device const half * dh = &x[ib].d;
+
+ for (int row = 0; row < N_DST; row++) {
+ sc16[0] = sc[0] & kmask1;
+ sc16[1] = sc[2] & kmask1;
+ sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
+ sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
+
+ device const uint16_t * q2 = q1 + 32;
+
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
+ acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
+ acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
+ acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
+ acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
+ acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
+ acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
+ acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
+ }
+
+ float dall = dh[0];
+ float dmin = dh[1];
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
+ (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
+ (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+ q1 += args.nb01/2;
+ sc += args.nb01/2;
+ dh += args.nb01/2;
+ }
+
+ y4 += 4 * QK_K;
+ }
+
+ device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
+
+ for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst_f32[first_row + row] = all_sum;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_q4_K_f32")]]
+kernel void kernel_mul_mv_q4_K_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q4_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template<typename args_t>
+void kernel_mul_mv_q5_K_f32_impl(
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const int nb = args.ne00/QK_K;
+
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
+ device const float * yy = (device const float *) (src1 + offset1);
+
+ float sumf[2]={0.f};
+
+ float yl[16], yh[16];
+
+ const uint16_t kmask1 = 0x3f3f;
+ const uint16_t kmask2 = 0x0f0f;
+ const uint16_t kmask3 = 0xc0c0;
+
+ const int tid = tiisg/4;
+ const int ix = tiisg%4;
+ const int iq = tid/4;
+ const int ir = tid%4;
+ const int n = 8;
+
+ const int l0 = n*ir;
+ const int q_offset = 32*iq + l0;
+ const int y_offset = 64*iq + l0;
+
+ const uint8_t hm1 = 1u << (2*iq);
+ const uint8_t hm2 = hm1 << 1;
+ const uint8_t hm3 = hm1 << 4;
+ const uint8_t hm4 = hm2 << 4;
+
+ uint16_t sc16[4];
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
+
+ device const float * y1 = yy + ix*QK_K + y_offset;
+
+ for (int i = ix; i < nb; i += 4) {
+ device const uint8_t * q1 = x[i].qs + q_offset;
+ device const uint8_t * qh = x[i].qh + l0;
+ device const half * dh = &x[i].d;
+ device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
+
+ device const float * y2 = y1 + 128;
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int l = 0; l < 8; ++l) {
+ yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
+ yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
+ yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
+ yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
+ }
+
+ for (int row = 0; row < 2; ++row) {
+ device const uint8_t * q2 = q1 + 64;
+
+ sc16[0] = a[0] & kmask1;
+ sc16[1] = a[2] & kmask1;
+ sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
+ sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
+
+ float4 acc1 = {0.f};
+ float4 acc2 = {0.f};
+ for (int l = 0; l < n; ++l) {
+ uint8_t h = qh[l];
+ acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+ acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+ acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+ acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+ acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+ acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+ acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+ acc2[3] += h & hm4 ? yh[l+8] : 0.f;
+ }
+ const float dall = dh[0];
+ const float dmin = dh[1];
+ sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
+ sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+ sc8[4] * (acc1[2] + 16.f*acc2[2]) +
+ sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+ q1 += args.nb01;
+ qh += args.nb01;
+ dh += args.nb01/2;
+ a += args.nb01/2;
+ }
+
+ y1 += 4 * QK_K;
+ }
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+ for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
+ const float tot = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst_f32[first_row + row] = tot;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_q5_K_f32")]]
+kernel void kernel_mul_mv_q5_K_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q5_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template<typename args_t>
+void kernel_mul_mv_q6_K_f32_impl(
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const uint8_t kmask1 = 0x03;
+ const uint8_t kmask2 = 0x0C;
+ const uint8_t kmask3 = 0x30;
+ const uint8_t kmask4 = 0xC0;
+
+ const int nb = args.ne00/QK_K;
+
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int row = 2*r0 + sgitg;
+
+ if (row >= args.ne0) {
+ return;
+ }
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
+ device const float * yy = (device const float *) (src1 + offset1);
+
+ float sumf = 0;
+
+ const int tid = tiisg/2;
+ const int ix = tiisg%2;
+ const int ip = tid/8; // 0 or 1
+ const int il = tid%8;
+ const int n = 4;
+ const int l0 = n*il;
+ const int is = 8*ip + l0/16;
+
+ const int y_offset = 128*ip + l0;
+ const int q_offset_l = 64*ip + l0;
+ const int q_offset_h = 32*ip + l0;
+
+ for (int i = ix; i < nb; i += 2) {
+ device const uint8_t * q1 = x[i].ql + q_offset_l;
+ device const uint8_t * q2 = q1 + 32;
+ device const uint8_t * qh = x[i].qh + q_offset_h;
+ device const int8_t * sc = x[i].scales + is;
+
+ device const float * y = yy + i * QK_K + y_offset;
+
+ const float dall = x[i].d;
+
+ float4 sums = {0.f, 0.f, 0.f, 0.f};
+ for (int l = 0; l < n; ++l) {
+ sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+ sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+ sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+ sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+ }
+
+ sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+
+ }
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+ const float tot = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst_f32[row] = tot;
+ }
+}
+
+[[host_name("kernel_mul_mv_q6_K_f32")]]
+kernel void kernel_mul_mv_q6_K_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q6_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+// ======================= "True" 2-bit
+
+template<typename args_t>
+void kernel_mul_mv_iq2_xxs_f32_impl(
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const int nb = args.ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int nb32 = nb * (QK_K / 32);
+
+ threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
+ threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256);
+ {
+ int nval = 4;
+ int pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xxs_grid[pos + i];
+ nval = 2;
+ pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+ const int ix = tiisg;
+
+ device const float * y4 = y + 32 * ix;
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+ for (int i = 0; i < 32; ++i) {
+ yl[i] = y4[i];
+ }
+
+ const int ibl = ib32 / (QK_K / 32);
+ const int ib = ib32 % (QK_K / 32);
+
+ device const block_iq2_xxs * xr = x + ibl;
+ device const uint16_t * q2 = xr->qs + 4 * ib;
+ device const half * dh = &xr->d;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ const float db = dh[0];
+ device const uint8_t * aux8 = (device const uint8_t *)q2;
+ const uint32_t aux32 = q2[2] | (q2[3] << 16);
+ const float d = db * (0.5f + (aux32 >> 28));
+
+ float sum = 0;
+ for (int l = 0; l < 4; ++l) {
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
+ const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
+ for (int j = 0; j < 8; ++j) {
+ sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+ }
+ }
+ sumf[row] += d * sum;
+
+ dh += args.nb01/2;
+ q2 += args.nb01/2;
+ }
+
+ y4 += 32 * 32;
+ }
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+ for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst_f32[first_row + row] = all_sum * 0.25f;
+ }
+ }
+}
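+
+// note: the iq2_xxs path first stages the iq2xxs_grid and ksigns_iq2xs lookup tables in threadgroup
+// memory (each thread copies a few entries, then a barrier), since every lane reads them repeatedly;
+// the sub-block scale is d = db*(0.5f + (aux32 >> 28)) and the final multiplication by 0.25f applies
+// the constant factor that is part of the iq2_xxs scale encoding.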
+
+[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
+kernel void kernel_mul_mv_iq2_xxs_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template<typename args_t>
+void kernel_mul_mv_iq2_xs_f32_impl(
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const int nb = args.ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int nb32 = nb * (QK_K / 32);
+
+ threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
+ threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 512);
+ {
+ int nval = 8;
+ int pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xs_grid[pos + i];
+ nval = 2;
+ pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+ const int ix = tiisg;
+
+ device const float * y4 = y + 32 * ix;
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+ for (int i = 0; i < 32; ++i) {
+ yl[i] = y4[i];
+ }
+
+ const int ibl = ib32 / (QK_K / 32);
+ const int ib = ib32 % (QK_K / 32);
+
+ device const block_iq2_xs * xr = x + ibl;
+ device const uint16_t * q2 = xr->qs + 4 * ib;
+ device const uint8_t * sc = xr->scales + ib;
+ device const half * dh = &xr->d;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ const float db = dh[0];
+ const uint8_t ls1 = sc[0] & 0xf;
+ const uint8_t ls2 = sc[0] >> 4;
+ const float d1 = db * (0.5f + ls1);
+ const float d2 = db * (0.5f + ls2);
+
+ float sum1 = 0, sum2 = 0;
+ for (int l = 0; l < 2; ++l) {
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
+ const uint8_t signs = ssigns[(q2[l] >> 9)];
+ for (int j = 0; j < 8; ++j) {
+ sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+ }
+ }
+ for (int l = 2; l < 4; ++l) {
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
+ const uint8_t signs = ssigns[(q2[l] >> 9)];
+ for (int j = 0; j < 8; ++j) {
+ sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+ }
+ }
+ sumf[row] += d1 * sum1 + d2 * sum2;
+
+ dh += args.nb01/2;
+ q2 += args.nb01/2;
+ sc += args.nb01;
+ }
+
+ y4 += 32 * 32;
+ }
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+ for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst_f32[first_row + row] = all_sum * 0.25f;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_iq2_xs_f32")]]
+kernel void kernel_mul_mv_iq2_xs_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template<typename args_t>
+void kernel_mul_mv_iq3_xxs_f32_impl(
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const int nb = args.ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int nb32 = nb * (QK_K / 32);
+
+ threadgroup uint32_t * svalues = (threadgroup uint32_t *)(shmem);
+ threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256);
+ {
+ int nval = 4;
+ int pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3xxs_grid[pos + i];
+ nval = 2;
+ pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+ const int ix = tiisg;
+
+ device const float * y4 = y + 32 * ix;
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+ for (int i = 0; i < 32; ++i) {
+ yl[i] = y4[i];
+ }
+
+ const int ibl = ib32 / (QK_K / 32);
+ const int ib = ib32 % (QK_K / 32);
+
+ device const block_iq3_xxs * xr = x + ibl;
+ device const uint8_t * q3 = xr->qs + 8 * ib;
+ device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
+ device const half * dh = &xr->d;
+
+ for (int row = 0; row < N_DST; row++) {
+ const float db = dh[0];
+ const uint32_t aux32 = gas[0] | (gas[1] << 16);
+ const float d = db * (0.5f + (aux32 >> 28));
+
+ float2 sum = {0};
+ for (int l = 0; l < 4; ++l) {
+ const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
+ const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
+ const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
+ for (int j = 0; j < 4; ++j) {
+ sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+ sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+ }
+ }
+ sumf[row] += d * (sum[0] + sum[1]);
+
+ dh += args.nb01/2;
+ q3 += args.nb01;
+ gas += args.nb01/2;
+ }
+
+ y4 += 32 * 32;
+ }
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+ for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst_f32[first_row + row] = all_sum * 0.5f;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_iq3_xxs_f32")]]
+kernel void kernel_mul_mv_iq3_xxs_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template <typename args_t>
+void kernel_mul_mv_iq3_s_f32_impl(
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const int nb = args.ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int nb32 = nb * (QK_K / 32);
+
+ threadgroup uint32_t * svalues = (threadgroup uint32_t *) shmem;
+ {
+ int nval = 8;
+ int pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3s_grid[pos + i];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+ const int ix = tiisg;
+
+ device const float * y4 = y + 32 * ix;
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+ for (int i = 0; i < 32; ++i) {
+ yl[i] = y4[i];
+ }
+
+ const int ibl = ib32 / (QK_K / 32);
+ const int ib = ib32 % (QK_K / 32);
+
+ device const block_iq3_s * xr = x + ibl;
+ device const uint8_t * qs = xr->qs + 8 * ib;
+ device const uint8_t * qh = xr->qh + ib;
+ device const uint8_t * sc = xr->scales + (ib/2);
+ device const uint8_t * signs = xr->signs + 4 * ib;
+ device const half * dh = &xr->d;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ const float db = dh[0];
+ const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
+
+ float2 sum = {0};
+ for (int l = 0; l < 4; ++l) {
+ const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
+ const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
+ const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
+ const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
+ for (int j = 0; j < 4; ++j) {
+ sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
+ sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
+ }
+ }
+ sumf[row] += d * (sum[0] + sum[1]);
+
+ dh += args.nb01/2;
+ qs += args.nb01;
+ qh += args.nb01;
+ sc += args.nb01;
+ signs += args.nb01;
+ }
+
+ y4 += 32 * 32;
+ }
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+ for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst_f32[first_row + row] = all_sum;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_iq3_s_f32")]]
+kernel void kernel_mul_mv_iq3_s_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template <typename args_t>
+void kernel_mul_mv_iq2_s_f32_impl(
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const int nb = args.ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int nb32 = nb * (QK_K / 32);
+
+ //threadgroup uint64_t * svalues = (threadgroup uint64_t *) shmem;
+ //{
+ // int nval = 32;
+ // int pos = (32*sgitg + tiisg)*nval;
+ // for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2s_grid[pos + i];
+ // threadgroup_barrier(mem_flags::mem_threadgroup);
+ //}
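+ // the threadgroup staging of iq2s_grid above is disabled; the grid is read directly from constant memory in the inner loop below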
+
+ const int ix = tiisg;
+
+ device const float * y4 = y + 32 * ix;
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+ for (int i = 0; i < 32; ++i) {
+ yl[i] = y4[i];
+ }
+
+ const int ibl = ib32 / (QK_K / 32);
+ const int ib = ib32 % (QK_K / 32);
+
+ device const block_iq2_s * xr = x + ibl;
+ device const uint8_t * qs = xr->qs + 4 * ib;
+ device const uint8_t * qh = xr->qh + ib;
+ device const uint8_t * sc = xr->scales + ib;
+ device const uint8_t * signs = qs + QK_K/8;
+ device const half * dh = &xr->d;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ const float db = dh[0];
+ const float d1 = db * (0.5f + (sc[0] & 0xf));
+ const float d2 = db * (0.5f + (sc[0] >> 4));
+
+ float2 sum = {0};
+ for (int l = 0; l < 2; ++l) {
+ //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
+ //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
+ constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
+ for (int j = 0; j < 8; ++j) {
+ sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
+ sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
+ }
+ }
+ sumf[row] += d1 * sum[0] + d2 * sum[1];
+
+ dh += args.nb01/2;
+ qs += args.nb01;
+ qh += args.nb01;
+ sc += args.nb01;
+ signs += args.nb01;
+ }
+
+ y4 += 32 * 32;
+ }
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+ for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst_f32[first_row + row] = all_sum * 0.25f;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_iq2_s_f32")]]
+kernel void kernel_mul_mv_iq2_s_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template <typename args_t>
+void kernel_mul_mv_iq1_s_f32_impl(
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const int nb = args.ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int nb32 = nb * (QK_K / 32);
+
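+ // iq1_s looks up iq1s_grid_gpu directly from constant memory, so no threadgroup staging is needed here (shmem is unused)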
+ const int ix = tiisg;
+
+ device const float * y4 = y + 32 * ix;
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+ float sumy = 0;
+ for (int i = 0; i < 32; ++i) {
+ yl[i] = y4[i];
+ sumy += yl[i];
+ }
+
+ const int ibl = ib32 / (QK_K / 32);
+ const int ib = ib32 % (QK_K / 32);
+
+ device const block_iq1_s * xr = x + ibl;
+ device const uint8_t * qs = xr->qs + 4 * ib;
+ device const uint16_t * qh = xr->qh + ib;
+ device const half * dh = &xr->d;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
+ constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
+ constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
+
+ float sum = 0;
+ for (int j = 0; j < 4; ++j) {
+ sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
+ + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
+ + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
+ + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
+ }
+ sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);
+
+ dh += args.nb01/2;
+ qs += args.nb01;
+ qh += args.nb01/2;
+ }
+
+ y4 += 32 * 32;
+ }
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+ for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst_f32[first_row + row] = all_sum;
+ }
+ }
+}
+
+template <typename args_t>
+void kernel_mul_mv_iq1_m_f32_impl(
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const int nb = args.ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int nb32 = nb * (QK_K / 32);
+
+ const int ix = tiisg;
+
+ device const float * y4 = y + 32 * ix;
+
+ iq1m_scale_t scale;
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+ float4 sumy = {0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+ yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
+ yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
+ yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
+ }
+
+ const int ibl = ib32 / (QK_K / 32);
+ const int ib = ib32 % (QK_K / 32);
+
+ device const block_iq1_m * xr = x + ibl;
+ device const uint8_t * qs = xr->qs + 4 * ib;
+ device const uint8_t * qh = xr->qh + 2 * ib;
+ device const uint16_t * sc = (device const uint16_t *)xr->scales;
+
+ for (int row = 0; row < N_DST; row++) {
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
+ constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
+ constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
+
+ float2 sum = {0.f};
+ for (int j = 0; j < 4; ++j) {
+ sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
+ + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
+ sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
+ + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
+ }
+ const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+ const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+
+ sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
+ (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
+
+ sc += args.nb01/2;
+ qs += args.nb01;
+ qh += args.nb01;
+ }
+
+ y4 += 32 * 32;
+ }
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+ for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst_f32[first_row + row] = all_sum;
+ }
+ }
+}
+
+template <typename args_t>
+void kernel_mul_mv_iq4_nl_f32_impl(
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ threadgroup float * shmem_f32 = (threadgroup float *) shmem;
+ const int nb = args.ne00/QK4_NL;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+ const int first_row = (r0 * 2 + sgitg) * 2;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
+
+ const int ix = tiisg/2; // 0...15
+ const int it = tiisg%2; // 0 or 1
+
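+ // stage the 16-entry non-linear iq4_nl codebook into threadgroup memory (each of the 32 lanes writes one duplicated entry), then synchronize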
+ shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ float4 yl[4];
+ float sumf[2]={0.f}, all_sum;
+
+ device const float * yb = y + ix * QK4_NL + it * 8;
+
+ uint32_t aux32[2];
+ thread const uint8_t * q8 = (thread const uint8_t *)aux32;
+
+ float4 qf1, qf2;
+
+ for (int ib = ix; ib < nb; ib += 16) {
+
+ device const float4 * y4 = (device const float4 *)yb;
+ yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+
+ for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
+
+ device const block_iq4_nl & xb = x[row*nb + ib];
+ device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
+
+ float4 acc1 = {0.f}, acc2 = {0.f};
+
+ aux32[0] = q4[0] | (q4[1] << 16);
+ aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
+ aux32[0] &= 0x0f0f0f0f;
+ qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
+ qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
+ acc1 += yl[0] * qf1;
+ acc2 += yl[1] * qf2;
+
+ aux32[0] = q4[2] | (q4[3] << 16);
+ aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
+ aux32[0] &= 0x0f0f0f0f;
+ qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
+ qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
+ acc1 += yl[2] * qf1;
+ acc2 += yl[3] * qf2;
+
+ acc1 += acc2;
+
+ sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
+
+ }
+
+ yb += 16 * QK4_NL;
+ }
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+ for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst_f32[first_row + row] = all_sum;
+ }
+ }
+}
+
+template <typename args_t>
+void kernel_mul_mv_iq4_xs_f32_impl(
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ threadgroup float * shmem_f32 = (threadgroup float *) shmem;
+ const int nb = args.ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+ const int first_row = (r0 * 2 + sgitg) * 2;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
+
+ const int ix = tiisg/16; // 0 or 1
+ const int it = tiisg%16; // 0...15
+ const int ib = it/2;
+ const int il = it%2;
+
+ shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ float4 yl[4];
+ float sumf[2]={0.f}, all_sum;
+
+ device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
+
+ uint32_t aux32[2];
+ thread const uint8_t * q8 = (thread const uint8_t *)aux32;
+
+ float4 qf1, qf2;
+
+ for (int ibl = ix; ibl < nb; ibl += 2) {
+ device const float4 * y4 = (device const float4 *)yb;
+ yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+
+ for (int row = 0; row < 2; ++row) {
+ device const block_iq4_xs & xb = x[row*nb + ibl];
+ device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
+
+ float4 acc1 = {0.f}, acc2 = {0.f};
+
+ aux32[0] = (q4[0] ) & 0x0f0f0f0f;
+ aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
+ qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
+ qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
+ acc1 += yl[0] * qf1;
+ acc2 += yl[1] * qf2;
+
+ aux32[0] = (q4[1] ) & 0x0f0f0f0f;
+ aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
+ qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
+ qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
+ acc1 += yl[2] * qf1;
+ acc2 += yl[3] * qf2;
+
+ acc1 += acc2;
+
+ const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
+ sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
+
+ }
+
+ yb += 2 * QK_K;
+ }
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+ for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst_f32[first_row + row] = all_sum;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_iq1_s_f32")]]
+kernel void kernel_mul_mv_iq1_s_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_iq1_m_f32")]]
+kernel void kernel_mul_mv_iq1_m_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_iq4_nl_f32")]]
+kernel void kernel_mul_mv_iq4_nl_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_iq4_xs_f32")]]
+kernel void kernel_mul_mv_iq4_xs_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
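+// get_rows: gather the rows of src0 selected by the int32 indices in src1 into dst, dequantizing 16 weights (one float4x4) per loop iteration for quantized types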
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
+kernel void kernel_get_rows_q(
+ device const void * src0,
+ device const void * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
+ float4x4 temp;
+ dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+ *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+ }
+}
+
+template<typename T>
+kernel void kernel_get_rows_f(
+ device const void * src0,
+ device const void * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+ (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
+ ((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
+ }
+}
+
+kernel void kernel_get_rows_i32(
+ device const void * src0,
+ device const void * src1,
+ device int32_t * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+ (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
+ ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
+ }
+}
+
+
+#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
+#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
+#define BLOCK_SIZE_K 32
+#define THREAD_MAT_M 4 // each thread takes 4 simdgroup matrices from matrix A
+#define THREAD_MAT_N 2 // each thread takes 2 simdgroup matrices from matrix B
+#define THREAD_PER_BLOCK 128
+#define THREAD_PER_ROW 2 // 2 threads for each row in matrix A to load numbers
+#define THREAD_PER_COL 4 // 4 threads for each row in matrix B to load numbers
+#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
+#define SG_MAT_ROW 8
+
+// each block_q contains 16*nl weights
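+// e.g. a q4_0 block holds 32 weights, so nl = 2; the 256-weight K-quant blocks use nl = QK_NL = 16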
+template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
+kernel void kernel_mul_mm(
+ constant ggml_metal_kargs_mul_mm & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiitg[[thread_index_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
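+ // threadgroup memory layout: the first 4096 bytes hold the dequantized tile of src0 (as T), the rest holds the f32 tile of src1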
+ threadgroup T * sa = (threadgroup T *)(shmem);
+ threadgroup float * sb = (threadgroup float *)(shmem + 4096);
+
+ const int r0 = tgpig.y;
+ const int r1 = tgpig.x;
+ const int im = tgpig.z;
+
+ // if this block is of 64x32 shape or smaller
+ const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
+ const short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
+ // a thread shouldn't load data outside of the matrix
+ const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+ const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+ simdgroup_T8x8 ma[4];
+ simdgroup_float8x8 mb[2];
+ simdgroup_float8x8 mc[8];
+
+ for (short i = 0; i < 8; i++){
+ mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+ }
+
+ short il = (tiitg % THREAD_PER_ROW);
+
+ const int i12 = im%args.ne12;
+ const int i13 = im/args.ne12;
+
+ const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const short offset1 = il/nl;
+
+ device const block_q * x = (device const block_q *)(src0
+ + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
+
+ device const float * y = (device const float *)(src1
+ + args.nb13*i13
+ + args.nb12*i12
+ + args.nb11*(r1*BLOCK_SIZE_N + thread_col)
+ + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+
+ for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
+ // load data and store to threadgroup memory
+ T4x4 temp_a;
+ dequantize_func(x, il, temp_a);
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ #pragma unroll(16)
+ for (short i = 0; i < 16; i++) {
+ *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
+ + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
+ + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
+ }
+
+ *(threadgroup float2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y);
+
+ il = (il + 2 < nl) ? il + 2 : il % 2;
+ x = (il < 2) ? x + (2 + nl - 1)/nl : x;
+ y += BLOCK_SIZE_K;
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // load matrices from threadgroup memory and conduct outer products
+ threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
+ threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
+
+ #pragma unroll(4)
+ for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
+ #pragma unroll(4)
+ for (short i = 0; i < 4; i++) {
+ simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
+ }
+
+ simdgroup_barrier(mem_flags::mem_none);
+
+ #pragma unroll(2)
+ for (short i = 0; i < 2; i++) {
+ simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
+ }
+
+ #pragma unroll(8)
+ for (short i = 0; i < 8; i++){
+ simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
+ }
+
+ lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
+ lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
+ }
+ }
+
+ if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) {
+ device float * C = (device float *) dst +
+ (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
+ (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
+
+ for (short i = 0; i < 8; i++) {
+ simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0);
+ }
+ } else {
+ // block is smaller than 64x32, we should avoid writing data outside of the matrix
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup float * temp_str = ((threadgroup float *) shmem) \
+ + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
+ for (short i = 0; i < 8; i++) {
+ simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (sgitg == 0) {
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+ device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.ne0 + im*args.ne1*args.ne0;
+ device float4 * D4 = (device float4 *) D;
+
+ threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
+ threadgroup float4 * C4 = (threadgroup float4 *) C;
+
+ int i = 0;
+ for (; i < n_rows/4; i++) {
+ *(D4 + i) = *(C4 + i);
+ }
+
+ i *= 4;
+ for (; i < n_rows; i++) {
+ *(D + i) = *(C + i);
+ }
+ }
+ }
+ }
+}
+
+// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
+// TODO: this kernel needs to be reimplemented from scratch for better performance
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+void kernel_mul_mm_id_impl(
+ int32_t ne00,
+ int32_t ne02,
+ uint64_t nb01,
+ uint64_t nb02,
+ int32_t ne11,
+ int32_t ne12,
+ uint64_t nb10,
+ uint64_t nb11,
+ uint64_t nb12,
+ int32_t ne0,
+ int32_t ne1,
+ int64_t ne0ne1,
+ device const char * src0,
+ device const char * src1,
+ threadgroup ushort2 * rowids,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiitg[[thread_index_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ threadgroup half * sa = (threadgroup half *)(shmem);
+ threadgroup float * sb = (threadgroup float *)(shmem + 4096);
+
+ const int r0 = tgpig.y;
+ const int r1 = tgpig.x;
+
+ if (r1*BLOCK_SIZE_N >= ne1) return;
+
+ // if this block is of 64x32 shape or smaller
+ short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
+ short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
+ // a thread shouldn't load data outside of the matrix
+ short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+ short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+ simdgroup_half8x8 ma[4];
+ simdgroup_float8x8 mb[2];
+ simdgroup_float8x8 mc[8];
+ for (int i = 0; i < 8; i++){
+ mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+ }
+ short il = (tiitg % THREAD_PER_ROW);
+
+ ushort offset1 = il/nl;
+
+ threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
+
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
+ device const float * y = (device const float *)(src1
+ + nb12 * id[1]
+ + nb11 * (id[0] % ne11)
+ + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+
+ for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+ // load data and store to threadgroup memory
+ half4x4 temp_a;
+ dequantize_func(x, il, temp_a);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (int i = 0; i < 16; i++) {
+ *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+ + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+ + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+ }
+
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
+ il = (il + 2 < nl) ? il + 2 : il % 2;
+ x = (il < 2) ? x + (2+nl-1)/nl : x;
+ y += BLOCK_SIZE_K;
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // load matrices from threadgroup memory and conduct outer products
+ threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+ threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
+ #pragma unroll(BLOCK_SIZE_K/8)
+ for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+ #pragma unroll(4)
+ for (int i = 0; i < 4; i++) {
+ simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
+ }
+ simdgroup_barrier(mem_flags::mem_none);
+ #pragma unroll(2)
+ for (int i = 0; i < 2; i++) {
+ simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
+ }
+
+ lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
+ lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
+ #pragma unroll(8)
+ for (int i = 0; i < 8; i++){
+ simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
+ }
+ }
+ }
+
+ {
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup float * temp_str = ((threadgroup float *) shmem) \
+ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
+ for (int i = 0; i < 8; i++) {
+ simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (sgitg == 0) {
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+ threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
+ int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1;
+
+ device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff;
+ device float4 * D4 = (device float4 *) D;
+
+ threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
+ threadgroup float4 * C4 = (threadgroup float4 *) C;
+
+ int i = 0;
+ for (; i < n_rows/4; i++) {
+ *(D4 + i) = *(C4 + i);
+ }
+
+ i *= 4;
+ for (; i < n_rows; i++) {
+ *(D + i) = *(C + i);
+ }
+ }
+ }
+ }
+}
+
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+kernel void kernel_mul_mm_id(
+ constant ggml_metal_kargs_mul_mm_id & args,
+ device const char * src0s,
+ device const char * src1,
+ device char * dst,
+ device const char * ids,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiitg[[thread_index_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ const int32_t i02 = tgpig.z;
+
+ tgpig.z = 0;
+
+ device const char * src0 = src0s + i02*args.nb02;
+
+ // row indices
+ threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
+
+ // TODO: parallelize this loop
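+ // collect the (ii0, ii1) coordinates of every src1 row that is routed to expert i02; _ne1 counts them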
+ int32_t _ne1 = 0;
+ for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
+ for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
+ int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
+ if (id == i02) {
+ if (tiitg == 0) {
+ rowids[_ne1] = ushort2(ii0, ii1);
+ }
+ _ne1++;
+ }
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
+ args.ne00,
+ args.ne02,
+ args.nb01,
+ args.nb02,
+ args.ne11,
+ args.ne12,
+ args.nb10,
+ args.nb11,
+ args.nb12,
+ args.ne0,
+ _ne1,
+ (int64_t)args.ne0*args.ne1,
+ src0,
+ src1,
+ rowids,
+ dst,
+ shmem,
+ tgpig,
+ tiitg,
+ sgitg);
+}
+
+#define QK_NL 16
+
+//
+// get rows
+//
+
+typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
+
+template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f;
+template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f;
+#endif
+
+typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
+
+template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q;
+
+//
+// matrix-matrix multiplication
+//
+
+typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
+
+template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm;
+#endif
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm;
+
+//
+// indirect matrix-matrix multiplication
+//
+
+typedef decltype(kernel_mul_mm_id