diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 511b95833..8cf347e57 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,6 +1,6 @@ name: CI -on: [push, pull_request] +on: [pull_request] jobs: format-check: diff --git a/THIRDPARTYNOTICES.txt b/THIRDPARTYNOTICES.txt index f377e67bb..d959effbb 100644 --- a/THIRDPARTYNOTICES.txt +++ b/THIRDPARTYNOTICES.txt @@ -206,3 +206,207 @@ Notice for apache/tvm limitations under the License. ------------------------------------------------------------------------------------ +Notice for IST-DASLab/marlin/ +------------------------------- + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +------------------------------------------------------------------------------------ diff --git a/bitblas/base/roller/arch/__init__.py b/bitblas/base/arch/__init__.py similarity index 100% rename from bitblas/base/roller/arch/__init__.py rename to bitblas/base/arch/__init__.py diff --git a/bitblas/base/roller/arch/arch_base.py b/bitblas/base/arch/arch_base.py similarity index 100% rename from bitblas/base/roller/arch/arch_base.py rename to bitblas/base/arch/arch_base.py diff --git a/bitblas/base/roller/arch/cpu.py b/bitblas/base/arch/cpu.py similarity index 100% rename from bitblas/base/roller/arch/cpu.py rename to bitblas/base/arch/cpu.py diff --git a/bitblas/base/roller/arch/cuda.py b/bitblas/base/arch/cuda.py similarity index 100% rename from bitblas/base/roller/arch/cuda.py rename to bitblas/base/arch/cuda.py diff --git a/bitblas/base/roller/__init__.py b/bitblas/base/roller/__init__.py index 9afd7cff0..3f728e695 100644 --- a/bitblas/base/roller/__init__.py +++ b/bitblas/base/roller/__init__.py @@ -4,4 +4,4 @@ from .rasterization import NoRasterization, Rasterization2DRow, Rasterization2DColumn # noqa: F401 from .hint import Hint # noqa: F401 from .policy import DefaultPolicy, TensorCorePolicy # noqa: F401 -from .arch import TileDevice, CUDA # noqa: F401 +from ..arch import TileDevice, CUDA # noqa: F401 diff --git a/bitblas/base/roller/policy/default.py b/bitblas/base/roller/policy/default.py index 730c8336f..e9f7b809f 100644 --- a/bitblas/base/roller/policy/default.py +++ b/bitblas/base/roller/policy/default.py @@ -9,7 +9,7 @@ import numpy as np from bitblas import tvm -from ..arch import TileDevice +from ...arch import TileDevice from ..bestfit import BestFit from ..hint import Hint, Stride, TileDict from .common import coalesced_factor, coalesced_tensor_shape, factorize, get_all_factors diff --git a/bitblas/base/roller/policy/tensorcore.py b/bitblas/base/roller/policy/tensorcore.py index ae45b5893..e69bcabc3 100644 --- a/bitblas/base/roller/policy/tensorcore.py +++ b/bitblas/base/roller/policy/tensorcore.py @@ -5,7 +5,7 @@ from typing import Dict, List, Tuple, Optional import numpy as np -from ..arch import TileDevice +from ...arch import TileDevice from ..hint import Hint, Stride, TileDict, IntrinInfo from ..node import PrimFuncNode from .common import coalesced_factor, factorize, get_all_factors diff --git 
a/bitblas/base/utils.py b/bitblas/base/utils.py index 4cd82fa93..1596b3c86 100644 --- a/bitblas/base/utils.py +++ b/bitblas/base/utils.py @@ -13,7 +13,7 @@ from tvm.relax.expr import Function import bitblas from .analysis import get_root_block, get_reduction_blocks, find_var_from_func -from bitblas.base.roller.arch import CUDA +from bitblas.base.arch import CUDA from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags import tempfile diff --git a/bitblas/builder/__init__.py b/bitblas/builder/__init__.py new file mode 100644 index 000000000..8e4d715c9 --- /dev/null +++ b/bitblas/builder/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .lib_generator import LibraryGenerator # noqa: F401 +from .wrapper import TIRWrapper # noqa: F401 diff --git a/bitblas/builder/lib_generator/__init__.py b/bitblas/builder/lib_generator/__init__.py new file mode 100644 index 000000000..a0800751a --- /dev/null +++ b/bitblas/builder/lib_generator/__init__.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from typing import Optional +from bitblas import TileDevice +import ctypes +import os +import tempfile +import subprocess +import logging + +logger = logging.getLogger(__name__) + + +class LibraryGenerator(object): + srcpath: Optional[str] = None + libpath: Optional[str] = None + lib_code: Optional[str] = None + + def __init__(self, arch: TileDevice): + self.arch = arch + + def update_lib_code(self, lib_code: str): + self.lib_code = lib_code + + # Assume currently we only support CUDA compilation + def load_lib(self): + return ctypes.CDLL(self.libpath) + + def compile_lib(self, timeout: float = None): + arch = self.arch + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) + compute_version = arch.compute_capability + libpath = src.name.replace(".cu", ".so") + + command = [ + "nvcc", + "-std=c++17", + "-Xcudafe", + "--diag_suppress=177", + "--compiler-options", + "'-fPIC'", + "-lineinfo", + "--shared", + src.name, + "-lcuda", + f"-gencode=arch=compute_{compute_version},code=compute_{compute_version}", + "-o", + libpath, + ] + src.write(self.lib_code) + src.flush() + try: + ret = subprocess.run(command, timeout=timeout) + except subprocess.TimeoutExpired: + logger.warning(f"Compilation Timeout! {command}") + return None + if ret.returncode != 0: + logger.warning(f"Compilation Failed! {command}") + return None + self.srcpath = src.name + self.libpath = libpath + + def remove_lib(self): + if self.libpath: + os.remove(self.libpath) + self.libpath = None + + def get_source_path(self): + return self.srcpath + + def get_lib_path(self): + return self.libpath + + def set_lib_path(self, libpath): + self.libpath = libpath + + def set_src_path(self, srcpath): + self.srcpath = srcpath diff --git a/bitblas/builder/wrapper/__init__.py b/bitblas/builder/wrapper/__init__.py new file mode 100644 index 000000000..c864f7a4b --- /dev/null +++ b/bitblas/builder/wrapper/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from .tir import TIRWrapper # noqa: F401 diff --git a/bitblas/builder/wrapper/base.py b/bitblas/builder/wrapper/base.py new file mode 100644 index 000000000..1705af2cc --- /dev/null +++ b/bitblas/builder/wrapper/base.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+from abc import ABC, abstractmethod + + +class BaseWrapper(ABC): + + @abstractmethod + def wrap(self, *args, **kwargs): + raise NotImplementedError diff --git a/bitblas/builder/wrapper/tir.py b/bitblas/builder/wrapper/tir.py new file mode 100644 index 000000000..2d0162f66 --- /dev/null +++ b/bitblas/builder/wrapper/tir.py @@ -0,0 +1,404 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from bitblas import tvm +from typing import Optional, List, Dict, Union +from tvm import IRModule +from bitblas import TileDevice +from bitblas.utils import match_global_kernel +from bitblas.utils.rtmod_analysis import get_annotated_device_mod +import re +from .base import BaseWrapper +import logging + +logger = logging.getLogger(__name__) + + +class TIRCUDASourceWrapper(object): + _TYPE_MAP = { + "float32": "float", + "float16": "half", + "bfloat16": "__nv_bfloat162", + "e4m3_float8": "__nv_fp8_e4m3", + "e5m2_float8": "__nv_fp8_e5m2", + "float64": "double", + "int64": "int64_t", + "int32": "int", + "uint32": "unsigned int", + "bool": "int8_t", + "int8": "int8_t", + "uint8": "uint8_t", + "int16": "int16_t", + "uchar": "uint8_t", + } + + def __init__(self, optimized_mod: IRModule, source: str, arch: TileDevice): + self.mod = optimized_mod + self.arch = arch + self.source = source + self.function_name: Optional[str] = None + self.dynamic_smem_buf: Optional[int] = None + self.block_info: Union[List[int], Dict] = [1, 1, 1] + self.grid_info: Union[List[int], Dict] = [1, 1, 1] + self.parse_source_information() + self.srcpath: Optional[str] = None + self.libpath: Optional[str] = None + self.lib_code: Optional[str] = self.update_lib_code(source) + + def parse_source_information(self): + device_mod = get_annotated_device_mod(self.mod, self.arch.target) + assert (len(device_mod.functions) == 1 + ), "Only support one function in the module for static shape kernel." 
+ for g_var, func in device_mod.functions.items(): + self.function_name = g_var.name_hint + attrs = func.attrs + if "dyn_shared_memory_buf" in attrs: + self.dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"]) + if "thread_extent" in attrs: + thread_extent = attrs["thread_extent"] + for tag, extent in thread_extent.items(): + if "threadIdx" in tag: + self.block_info["xyz".index(tag[-1])] = extent + elif "blockIdx" in tag: + self.grid_info["xyz".index(tag[-1])] = extent + + def get_dynamic_symbolic_set(self, prim_func): + # Determine the set of dynamic symbols used in the function + dynamic_symbolic_set = set() + for param in prim_func.params: + buffer = prim_func.buffer_map[param] + for dim in buffer.shape: + if isinstance(dim, tvm.tir.Var): + dynamic_symbolic_set.add(dim.name) + return dynamic_symbolic_set + + def get_cuda_init_func(self): + # Initialize an empty string for the CUDA function call + call_str = """""" + # If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call + if self.dynamic_smem_buf is not None: + call_str = """ + cudaFuncSetAttribute({}, + cudaFuncAttributeMaxDynamicSharedMemorySize, {}); + """.format(self.function_name, self.dynamic_smem_buf) + # Format the initialization function using the call_str + init_funcs = """ + extern "C" void init() {{ + {} + }} + """.format(call_str) + return init_funcs + + def update_lib_code(self, code: str): + # Update the library code with the given code string + self.lib_code = code + # Find the index of the global kernel function in the code + index = match_global_kernel(code) + # Extract the declaration of the function starting from the found index + declaration = code[index:].split(";")[0] + + function_name = self.function_name + # Get the CUDA initialization function + init_func = self.get_cuda_init_func() + + # Locate the opening brace of the function to insert arguments + index = code.index("{", index) + function_args = [] + # Populate the function arguments from the primary function's parameters and buffers + for param in self.prim_func.params: + buffer = self.prim_func.buffer_map[param] + function_args.append({ + "name": buffer.name, + "type": self._TYPE_MAP[buffer.dtype] + "* __restrict__", + }) + + dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) + # Add dynamic symbolic parameters as integers to the function arguments + for dyn_sym in dynamic_symbolic_set: + function_args.append({"name": dyn_sym, "type": "int"}) + + function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},) + # Format the function arguments for declaration + def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) + + def func_call_args(s, function_args): + # Extract the function call arguments matching the function definition + pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)" + matches = re.findall(pattern, s) + call_args = [] + for match in matches: + for arg in function_args: + if arg["name"] == match: + call_args.append(match) + return call_args + + call_args = ", ".join(func_call_args(declaration, function_args)) + block_info, grid_info = self.block_info, self.grid_info + + def legalize_c(p): + # Convert TIR expressions to legal C expressions + # Directly convert to string since the special case handling + # does not alter the string representation for `tvm.tir.Var` and `IntImm`. 
+ # Replace Python's floor division operator with C's division operator + if isinstance(p, tvm.tir.IntImm): + p = int(p) + return str(p).replace("//", "/") + + # Prepare the block and grid dimensions for the CUDA kernel launch + block_str = "dim3({}, {}, {})".format( + legalize_c(block_info[0]), + legalize_c(block_info[1]), + legalize_c(block_info[2]), + ) + grid_str = "dim3({}, {}, {})".format( + legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2])) + # Determine the shared memory size, defaulting to 0 if not specified + smem_str = 0 if self.dynamic_smem_buf is None else self.dynamic_smem_buf + # Format the CUDA kernel launch string + if len(dynamic_symbolic_set) != 0: + call_str = "if ({} == 0) return; \n\t\t".format(list(dynamic_symbolic_set)[0]) + else: + call_str = "" + call_str += "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str, + smem_str, call_args) + # Create the host function wrapper for the CUDA kernel + host_func = """ + extern "C" void call({}) {{ + {} + }} + """.format(def_args, call_str) + # Combine the source, initialization function, and host function to form the complete library code + lib_code = self.source + init_func + host_func + return lib_code + + @property + def prim_func(self): + return self.mod["main"] + + +class TIRCUDASourceWrapperWithDynamic(TIRCUDASourceWrapper): + + def __init__(self, optimized_mod: IRModule, source: str, arch: TileDevice): + super().__init__(optimized_mod, source, arch) + + def get_cuda_init_func(self): + # Initialize an empty string to accumulate CUDA function calls for setting dynamic shared memory + call_str = """""" + # Iterate over functions and their dynamic shared memory requirements + for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items(): + if dynamic_smem_buf is not None: + # Format the cudaFuncSetAttribute call for dynamic shared memory + call_str += """ + cudaFuncSetAttribute({}, + cudaFuncAttributeMaxDynamicSharedMemorySize, {}); + """.format(function_name, dynamic_smem_buf) + # Define the init function that will set the attributes for each kernel + init_funcs = """ +extern "C" void init() {{ + {} +}} + """.format(call_str) + return init_funcs + + def create_dispatch_func(self, code, function_informations): + # Extract the set of dynamic symbolic names used in the primary function + dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) + + # Find the location of the global kernel function in the code + index = match_global_kernel(code) + + # Analyze the function declaration to prepare for argument extraction + dummy_declaration = code[index:].split(";")[0] + + function_name = self.function_name + + # Identify the start of the function body to insert arguments + index = code.index("{", index) + function_args = [] + # Collect function arguments based on primary function's parameters and buffer mappings + for param in self.prim_func.params: + buffer = self.prim_func.buffer_map[param] + function_args.append({ + "name": buffer.name, + "type": self._TYPE_MAP[buffer.dtype] + "* __restrict__", + }) + # Add dynamic symbols as integer arguments + for dyn_sym in dynamic_symbolic_set: + function_args.append({"name": dyn_sym, "type": "int"}) + + function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},) + + # Format the argument definitions for function declaration + def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) + + def func_call_args(s: str, function_args): + # Extract and clean the function call 
arguments to match the declaration + pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)" + matches = re.findall(pattern, s) + call_args = [] + for match in matches: + match = re.sub(r"\d+", "", match) # Remove numbers + match = re.sub(r"_", "", match) # Remove underscores + for arg in function_args: + if arg["name"] == match: + call_args.append(match) + return call_args + + call_args = ", ".join(func_call_args(dummy_declaration, function_args)) + + def legalize_c(p): + # Convert TIR expressions to legal C expressions + # Directly convert to string since the special case handling + # does not alter the string representation for `tvm.tir.Var` and `IntImm`. + # Replace Python's floor division operator with C's division operator + if isinstance(p, tvm.tir.IntImm): + p = int(p) + return str(p).replace("//", "/") + + last_range = 0 + num_items = len(function_informations) + _call_str = """""" + for function_name, info in function_informations.items(): + # Prepare block and grid configurations for kernel launches + block_info, grid_info = info["block_info"], info["grid_info"] + block_str = "dim3({}, {}, {})".format( + legalize_c(block_info[0]), + legalize_c(block_info[1]), + legalize_c(block_info[2]), + ) + grid_str = "dim3({}, {}, {})".format( + legalize_c(grid_info[0]), + legalize_c(grid_info[1]), + legalize_c(grid_info[2]), + ) + # Handle dynamic shared memory specification + smem_str = (0 if info["dynamic_smem_buf"] is None else info["dynamic_smem_buf"]) + opt_shapes = info["opt_shapes"] + # Generate conditional kernel launch code based on dynamic symbolic ranges + (symbolic,) = list(dynamic_symbolic_set) + range_str = opt_shapes[symbolic] + if last_range == 0: + call_str = "if ({} == 0) return; \n".format(symbolic,) + call_str += "if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( + symbolic, + range_str, + function_name, + grid_str, + block_str, + smem_str, + call_args, + ) + else: + call_str = "\t\telse if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( + symbolic, + range_str, + function_name, + grid_str, + block_str, + smem_str, + call_args, + ) + if last_range == num_items - 1: + call_str += ( + "\t\telse {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( + function_name, grid_str, block_str, smem_str, call_args)) + last_range += 1 + _call_str += call_str + + # Wrap the kernel dispatch logic in an external C function + host_func = """ +extern "C" void call({}) {{ + {} +}} + """.format(def_args, _call_str) + return host_func + + def parse_source_information(self): + # Parse device module to extract execution configurations for each function + device_mod = get_annotated_device_mod(self.mod, self.arch.target) + block_info_map = {} + grid_info_map = {} + dynamic_smem_buf_map = {} + for g_var, func in device_mod.functions.items(): + # Default block and grid configurations + block_info = [1, 1, 1] + grid_info = [1, 1, 1] + function_name = g_var.name_hint + attrs = func.attrs + dynamic_smem_buf = None + if "dyn_shared_memory_buf" in attrs: + dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"]) + if "thread_extent" in attrs: + # Extract block and grid sizes from thread extents + thread_extent = attrs["thread_extent"] + for tag, extent in thread_extent.items(): + if "threadIdx" in tag: + block_info["xyz".index(tag[-1])] = extent + elif "blockIdx" in tag: + grid_info["xyz".index(tag[-1])] = extent + # Map the extracted configurations to each function + block_info_map[function_name] = block_info + grid_info_map[function_name] = 
grid_info + dynamic_smem_buf_map[function_name] = dynamic_smem_buf + # Store the mappings for use in code generation + self.block_info = block_info_map + self.grid_info = grid_info_map + self.dynamic_smem_buf = dynamic_smem_buf_map + + def update_lib_code(self, code: str): + # Organize function information for code generation + function_informations = {} + for g_var, func in self.mod.functions.items(): + if g_var.name_hint == "main": + continue + function_name = g_var.name_hint + attrs = func.attrs + assert "opt_shapes" in attrs + opt_shapes = attrs["opt_shapes"] + function_informations[function_name] = { + "function_name": function_name, + "opt_shapes": opt_shapes, + "block_info": self.block_info[function_name], + "grid_info": self.grid_info[function_name], + "dynamic_smem_buf": self.dynamic_smem_buf[function_name], + } + + def compare_map_objects(map_obj): + comparable_representation = list(map_obj.values()) + return comparable_representation + + function_informations = dict( + sorted( + function_informations.items(), + key=lambda item: compare_map_objects(item[1]["opt_shapes"]))) + + self.lib_code = code + + # Generate the initialization and dispatch functions + init_func = self.get_cuda_init_func() + host_func = self.create_dispatch_func(code, function_informations) + # Concatenate source code with generated code segments + lib_code = self.source + init_func + host_func + return lib_code + + @property + def prim_func(self): + return self.mod["main"] + + +class TIRWrapper(BaseWrapper): + + def __init__(self, arch: TileDevice): + super().__init__() + self.optimized_mod = None + self.arch = arch + self.lib = None + + def assign_optimized_module(self, optimized_mod: IRModule): + self.optimized_mod = optimized_mod + + # Get Scheduled Rt Module and return source to be compiled + def wrap(self, c_source: str, is_dynamic: bool = False): + assert self.optimized_mod is not None, "Please assign optimized module first." 
+ wrapper_class = TIRCUDASourceWrapper if not is_dynamic else TIRCUDASourceWrapperWithDynamic + wrapper = wrapper_class(self.optimized_mod, c_source, self.arch) + return wrapper.lib_code diff --git a/bitblas/cache/operator.py b/bitblas/cache/operator.py index 0c41ab686..295630f5d 100644 --- a/bitblas/cache/operator.py +++ b/bitblas/cache/operator.py @@ -107,16 +107,16 @@ def _save_operator_config_and_artifact(self, config, op_inst, config_path): with open(optimized_file_path, "w") as optimized_file: if op_inst.optimized_func is not None: optimized_file.write(op_inst.optimized_func.script(show_meta=False)) - if op_inst.wrapper.lib_name is not None: + if op_inst.wrapper.libpath is not None: # copy lib name to the same directory as the artifact - src_name = op_inst.wrapper.src_name + srcpath = op_inst.wrapper.srcpath shutil.copy( - src_name, + srcpath, os.path.join(config_path, os.path.basename("wrapper_source.cu")), ) - lib_name = op_inst.wrapper.lib_name + libpath = op_inst.wrapper.libpath shutil.copy( - lib_name, + libpath, os.path.join(config_path, os.path.basename("wrapper_compiled.so")), ) @@ -130,7 +130,7 @@ def _load_operators_from_arch_path(self, arch_path, target): self._load_operator(config_path, target) def _load_operator(self, config_path, target): - mapping, config, rt_mod, src_name, lib_name = None, None, None, None, None + mapping, config, rt_mod, srcpath, libpath = None, None, None, None, None for file in os.listdir(config_path): full_path = os.path.join(config_path, file) if file == "mapping.json": @@ -142,19 +142,23 @@ def _load_operator(self, config_path, target): elif file.endswith(".tar"): rt_mod = tvm.runtime.load_module(full_path) elif file == "wrapper_compiled.so": - lib_name = full_path + libpath = full_path elif file == "wrapper_source.cu": - src_name = full_path + srcpath = full_path if mapping and config and rt_mod: - self._instantiate_and_add_operator(mapping, config, rt_mod, src_name, lib_name, target) + self._instantiate_and_add_operator(mapping, config, rt_mod, srcpath, libpath, target) - def _instantiate_and_add_operator(self, mapping, config, rt_mod, src_name, lib_name, target): + def _instantiate_and_add_operator(self, mapping, config, rt_mod, srcpath, libpath, target): config_cls = getattr(bitblas, mapping["config_type"]) operator_cls = getattr(bitblas, mapping["operator_type"]) op_inst = operator_cls( - config=config_cls(**config), target=target, enable_tuning=False, from_database=True) - op_inst.update_runtime_module(rt_mod, src_name=src_name, lib_name=lib_name) + config=config_cls(**config), + target=target, + enable_tuning=False, + from_database=True, + ) + op_inst.update_runtime_module(rt_mod, srcpath=srcpath, libpath=libpath) self.add(config_cls(**config), op_inst) diff --git a/bitblas/generator.py b/bitblas/generator.py deleted file mode 100644 index 4ac6f2be2..000000000 --- a/bitblas/generator.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- - -class BitBLASGenerator: - - def __init__(self): - # Initialize the generator with configuration - pass - - def generate_cuda_code(self): - pass - - def generate_header(self): - pass diff --git a/bitblas/ops/general_matmul.py b/bitblas/ops/general_matmul/__init__.py similarity index 88% rename from bitblas/ops/general_matmul.py rename to bitblas/ops/general_matmul/__init__.py index 97dd7d13f..5c3f6d2e6 100644 --- a/bitblas/ops/general_matmul.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -4,17 +4,16 @@ from tvm.target import Target import operator from functools import reduce -from bitblas.base.roller.arch.cuda import CUDA +from bitblas.base.arch.cuda import CUDA from typing import Any, Literal, Optional, Tuple, Union -from .operator import Operator, TransformKind, OPExecutorCPU -from .impl.matmul_dequantize_impl import ( - select_implementation as weight_dequantize_implementation,) -from .impl.matmul_impl import select_implementation as consistent_implementation -from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 +from ..operator import OperatorConfig, Operator, TransformKind, OPExecutorCPU +from .tirscript.matmul_dequantize_impl import select_implementation as weight_dequantize_implementation +from .tirscript.matmul_impl import select_implementation as consistent_implementation +from ...base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 from bitblas.utils.target_detector import auto_detect_nvidia_target from dataclasses import dataclass -from .ladder_permutate import LadderPermutate, LadderPermutateConfig -from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig +from ..ladder_permutate import LadderPermutate, LadderPermutateConfig +from ..lop3_permutate import LOP3Permutate, LOP3PermutateConfig import logging import torch @@ -42,7 +41,7 @@ def is_native_compute(A_dtype, W_dtype) -> bool: @dataclass(frozen=True) -class MatmulConfig: +class MatmulConfig(OperatorConfig): M: Union[int, Tuple[int]] = None N: int = None K: int = None @@ -186,7 +185,7 @@ def __post_init__(self): class Matmul(Operator): - # TODO(lei): This should be improved into a general datatype. + # TODO(lei): This should be improved into a general datatype class. 
BITBLAS_TRICK_DTYPE_MAP = { "float64": ("fp", 64), "float32": ("fp", 32), @@ -216,6 +215,7 @@ def __init__( target: Optional[Union[str, Target]] = None, enable_tuning: bool = True, from_database: bool = False, + backend: str = "tir", ): # if from database, we should disable default schedule # to save compilation time @@ -228,6 +228,7 @@ def __init__( self.source_format = source_format self.bit = bit + self.backend = backend super().__init__(name, config, target) if source_format == "int" and self.with_zeros: @@ -239,6 +240,10 @@ def __init__( if target.kind.name != "cuda": raise ValueError("Currently only support cuda target") + self.dispatch_tir(target, from_database, source_format, enable_tuning) + + def dispatch_tir(self, target: Target, from_database: bool = False, source_format: str = "uint", enable_tuning: bool = True): + '''Dispatch the tir script implementation''' self.arch = CUDA(target) if isinstance(self.M, Tuple): @@ -252,6 +257,50 @@ def __init__( self._build_default_module(target) self.workspace = None + if source_format == "nf": + self.lut = torch.tensor( + [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, + ], + dtype=getattr(torch, self.A_dtype), + ).cuda() + else: + self.lut = None + + # create permutate_opertors + self.ladder_permutate_a = self._assign_ladder_permutate_a(target, enable_tuning) + self.ladder_permutate_b = self._assign_ladder_permutate_b(target, enable_tuning) + self.lop3_permutate = self._assign_lop3_permutate(target, enable_tuning) + # create cpu weight executors + self.input_executors = self._create_input_executors() + self.weight_executors = self._create_weight_executors() + + if enable_tuning: + self.hardware_aware_finetune() + + # output data type + self.torch_output_dtype = getattr(torch, self.out_dtype) + + def _alloc_workspace(self): + return torch.empty(WORKSPACE_SIZE, dtype=torch.float16).cuda() + + def _assign_ladder_permutate_a(self, target: Target, enable_tuning: bool): + ladder_permutate_a = None if self.propagate_a: # for general purpose, we use propagate_a to control the ladder permutation. 
ladder_permutate_config = LadderPermutateConfig( @@ -263,14 +312,18 @@ def __init__( transpose_matrix=False, transform_kind=self.propagate_a, ) - self.ladder_permutate_a = LadderPermutate( + ladder_permutate_a = LadderPermutate( config=ladder_permutate_config, target=target, enable_tuning=enable_tuning, ) - self.workspace = torch.empty(WORKSPACE_SIZE, dtype=torch.float16).cuda() - else: - self.ladder_permutate_a = None + self.workspace = self._alloc_workspace() + return ladder_permutate_a + + def _assign_ladder_permutate_b(self, target: Target, enable_tuning: bool): + # unused variables + del target + del enable_tuning if self.propagate_b: ladder_permutate_config = LadderPermutateConfig( @@ -283,13 +336,16 @@ def __init__( transpose_matrix=self.layout == "nt", transform_kind=self.propagate_b, ) - self.ladder_permutate_b = LadderPermutate( + return LadderPermutate( config=ladder_permutate_config, target=tvm.target.Target("llvm"), ) - else: - self.ladder_permutate_b = None + return None + def _assign_lop3_permutate(self, target: Target, enable_tuning: bool): + # unused variables + del target + del enable_tuning if self.fast_decoding: assert self.source_format in ["int", "uint"] lop3_permutate_config = LOP3PermutateConfig( @@ -299,68 +355,25 @@ def __init__( dequantize_bits=self.bit, storage_dtype=self.storage_dtype, ) - self.lop3_permutate = LOP3Permutate( + return LOP3Permutate( config=lop3_permutate_config, target=tvm.target.Target("llvm"), ) - else: - self.lop3_permutate = None + return None + def _create_input_executors(self): input_executors = OPExecutorCPU() - if self.ladder_permutate_a is not None: + if self.propagate_a is not TransformKind.NonTransform: input_executors.append(self.ladder_permutate_a) - self.input_executors = input_executors - + return input_executors + + def _create_weight_executors(self): weight_executors = OPExecutorCPU() - if self.lop3_permutate is not None: + if self.fast_decoding: weight_executors.append(self.lop3_permutate) - - if self.ladder_permutate_b is not None: + if self.propagate_b is not TransformKind.NonTransform: weight_executors.append(self.ladder_permutate_b) - - self.weight_executors = weight_executors - - if enable_tuning: - self.hardware_aware_finetune() - - if source_format == "nf": - self.lut = torch.tensor( - [ - -1.0, - -0.6961928009986877, - -0.5250730514526367, - -0.39491748809814453, - -0.28444138169288635, - -0.18477343022823334, - -0.09105003625154495, - 0.0, - 0.07958029955625534, - 0.16093020141124725, - 0.24611230194568634, - 0.33791524171829224, - 0.44070982933044434, - 0.5626170039176941, - 0.7229568362236023, - 1.0, - ], - dtype=getattr(torch, self.A_dtype), - ).cuda() - else: - self.lut = None - - # output data type - self.torch_output_dtype = getattr(torch, self.out_dtype) - - def _build_default_module(self, target: Target): - try: - self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) - except Exception: - self.optimized_func = None - logger.warning( - "[BitBLAS][Warning] Apply default schedule failed, should do hardware-aware optimization manually." - ) - - self._build_runtime_module(target) + return weight_executors def _select_implementation(self): if is_native_compute(self.A_dtype, self.W_dtype): diff --git a/bitblas/ops/general_matmul/cuda/__init__.py b/bitblas/ops/general_matmul/cuda/__init__.py new file mode 100644 index 000000000..a0366abd3 --- /dev/null +++ b/bitblas/ops/general_matmul/cuda/__init__.py @@ -0,0 +1,98 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +# TODO: Not Implemented Yet +from bitblas.ops.operator import TransformKind +from bitblas.base import TileDevice +from .template import i4_scale_template_source + + +class MatmulDequantizeCudaEmitter: + + def __init__( + self, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", + propagate_a: TransformKind = TransformKind.NonTransform, + propagate_b: TransformKind = TransformKind.NonTransform, + ): + self.N = N + self.K = K + self.in_dtype = in_dtype + self.out_dtype = out_dtype + self.accum_dtype = accum_dtype + self.bit = bit + self.storage_dtype = storage_dtype + self.source_format = source_format + self.with_scaling = with_scaling + self.with_zeros = with_zeros + self.group_size = group_size if group_size != -1 else K + self.fast_decoding = fast_decoding + self.with_bias = with_bias + self.zeros_mode = zeros_mode + self.propagate_a = self._legalize_transform_kind(propagate_a) + self.propagate_b = self._legalize_transform_kind(propagate_b) + + def _legalize_group_size(self): + if self.group_size == -1: + self.group_size = self.K + + def _legalize_transform_kind(self, propagate): + if propagate is None: + return TransformKind.NonTransform + if isinstance(propagate, bool): + return (TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform) + elif isinstance(propagate, int): + return TransformKind(propagate) + + def is_available(self, arch: TileDevice): + conditions = [] + # group size must be -1, 128, k + conditions.append(self.group_size in [-1, 128, self.K]) + # source format must be int + conditions.append(self.source_format == "int") + # with scaling must be true + conditions.append(self.with_scaling) + # with zeros must be false + conditions.append(not self.with_zeros) + # bit must be 4 + conditions.append(self.bit == 4) + # in_dtype must be float16 + conditions.append(self.in_dtype == "float16") + # out_dtype must be float16 + conditions.append(self.out_dtype == "float16") + # accum_dtype must be float32 + conditions.append(self.accum_dtype == "float32") + # sm version must be 80 (A100) + conditions.append(self.arch.sm_version == 80) + return all(conditions) + + def get_weight_transform(self): + raise NotImplementedError + + def get_scale_transform(self): + raise NotImplementedError + + def get_wrapped_source(self): + wrapped_source = f""" + extern "C" void init() {{ + + }} + extern "C" void call(half* __restrict__ A, int8_t* __restrict__ B, half* __restrict__ Scale, half* __restrict__ C, int m, void* workspace, cudaStream_t stream=cudaStreamDefault) {{ + marlin_cuda(A, B, C, Scale, m, {self.N}, {self.K}, workspace, {self.group_size}, 0, -1, -1, 108, 16); + }} + """ + return i4_scale_template_source + wrapped_source diff --git a/bitblas/ops/general_matmul/cuda/template.py b/bitblas/ops/general_matmul/cuda/template.py new file mode 100644 index 000000000..a088e12fa --- /dev/null +++ b/bitblas/ops/general_matmul/cuda/template.py @@ -0,0 +1,830 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +i4_scale_template_source = """ +// Copyright 2018 The apache/tvm Authors. All Rights Reserved. +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// +// Modifications Copyright (c) Microsoft. +// The code below is mostly copied from marlin_cuda in IST-DASLab/marlin. + +#ifndef MARLIN_CUDA_KERNEL_CUH +#define MARLIN_CUDA_KERNEL_CUH + + +#include +#include +#include +#include + + +constexpr int ceildiv(int a, int b) { + return (a + b - 1) / b; +} + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed for instance as inputs to tensor core +// operations. Consequently, all corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { + return elems[i]; + } +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales + +// Predicated asynchronous global->shared copy; used for inputs A where we apply predication to handle batchsizes that +// are not multiples of 16. +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" :: "r"((int) pred), "r"(smem), "l"(glob_ptr), "n"(BYTES) + ); +} + +// Asynchronous global->shared copy with a cache hint indicating that the values may be evicted immediately; used for +// quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need +// for inputs A and outputs C. +__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .b64 p;\n" + " createpolicy.fractional.L2::evict_first.b64 p, 1.0;" + " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n" + "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES) + ); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" :: "n"(n)); +} + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 output/accumulation. 
+__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]) + ); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem) + ); +} + +// Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to +// automatically recognize it in all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut) + ); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values. +// We mostly follow the strategy in the link below, with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + FragB frag_b; + frag_b[0] = __hsub2( + *reinterpret_cast(&lo), + *reinterpret_cast(&SUB) + ); + frag_b[1] = __hfma2( + *reinterpret_cast(&hi), + *reinterpret_cast(&MUL), *reinterpret_cast(&ADD) + ); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used only for grouped quantization. +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible globally. + asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible globally, while releasing the barrier. 
+ asm volatile ("fence.acq_rel.gpu;\n"); + asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); + } +} + + +template < + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m dimension (batchsize) of the threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const int stages, // number of stages for the async global->shared fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks with a separate quantization scale +> +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the same size, which might involve multiple + // column "slices" (of width 16 * `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it ensures good utilization of all SMs + // for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as + // possible. + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + // Ensure that the number of tiles in each stripe is a multiple of the groupsize; this avoids an annoying special case + // where a stripe starts in the middle of group. + if (group_blocks != -1) + iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks)); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to top + + // We can easily implement parallel problem execution by just remapping indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for synchronization. 
+ auto init_slice = [&] () { + slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) + slice_iters = 0; + if (slice_iters == 0) + return; + if (slice_row + slice_iters > k_tiles) + slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) + slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) + slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + // We typically use `constexpr` to indicate that this value is a compile-time constant + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; // delta between subsequent A tiles in global memory + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // between shared memory tile reads + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // within a shared memory tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); // overall size of a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); // number of shared write iterations for a tile + + int b_gl_stride = 16 * prob_n / 32; + constexpr int b_sh_stride = 32 * thread_n_blocks / 4; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); + constexpr int b_sh_wr_delta = threads; + constexpr int b_sh_rd_delta = threads; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_sh_stage = s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. 
+ int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x; + int b_sh_rd = threadIdx.x; + + int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; + int s_sh_wr = threadIdx.x; + int s_sh_rd; + // We use a different scale layout for grouped and column-wise quantization as we scale a `half2` tile in column-major + // layout in the former and in row-major in the latter case. + if (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + + // Precompute which thread should not read memory in which iterations; this is needed if there are more threads than + // required for a certain tilesize or when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // To ensure that writing and reading A tiles to/from shared memory, the latter in fragment format, is fully bank + // conflict free, we need to use a rather fancy XOR-based layout. The key here is that neither reads nor writes of + // the 16-byte `int4` blocks of 8 consecutive threads involve the same shared memory banks. Further, it seems (based + // on NSight-Compute) that each warp must also write a consecutive memory segment? + auto transform_a = [&] (int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main loop unrolls, all shared memory + // accesses are static, we simply precompute both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependencies between + // subsequent accesses with a tile by maintining multiple pointers (we have enough registers), a tiny optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s = sh_b + (stages * b_sh_stage); + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; + + // Zero accumulators. 
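A quick way to convince oneself that the `transform_a` remapping above behaves as the comment claims is to model it in plain Python. The sketch below is illustrative only; indices are in 16-byte `int4` units, `row_len` plays the role of `a_gl_rd_delta_o`, and the expression keeps C's operator precedence (where `^` binds weaker than `+`).

```python
# Model of the XOR-based swizzle in `transform_a` above (illustrative only).
def transform_a(i, row_len):
    row = i // row_len
    return (row_len * row + (i % row_len)) ^ row

row_len, num_rows = 8, 64          # e.g. thread_k_blocks = 4, thread_m_blocks = 4
n = row_len * num_rows
mapped = [transform_a(i, row_len) for i in range(n)]
assert sorted(mapped) == list(range(n))   # the remapping is a permutation of the whole tile

# Eight consecutive elements of any source row land in eight distinct slots modulo 8,
# so the 16-byte accesses of 8 consecutive threads hit disjoint bank groups.
for row in range(num_rows):
    assert len({transform_a(row * row_len + c, row_len) % row_len for c in range(row_len)}) == row_len
print("swizzle checks passed")
```

Returning to the kernel, the accumulator-zeroing helper announced above comes next.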
+  auto zero_accums = [&] () {
+    #pragma unroll
+    for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
+      reinterpret_cast<float*>(frag_c)[i] = 0;
+  };
+
+  // Asynchronously fetch the next A, B and s tile from global to the next shared memory pipeline location.
+  auto fetch_to_shared = [&] (int pipe, int a_off, bool pred = true) {
+    if (pred) {
+      int4* sh_a_stage = sh_a + a_sh_stage * pipe;
+      #pragma unroll
+      for (int i = 0; i < a_sh_wr_iters; i++) {
+        cp_async4_pred(
+          &sh_a_stage[a_sh_wr_trans[i]],
+          &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
+          a_sh_wr_pred[i]
+        );
+      }
+      int4* sh_b_stage = sh_b + b_sh_stage * pipe;
+      #pragma unroll
+      for (int i = 0; i < b_sh_wr_iters; i++) {
+        cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
+        B_ptr[i] += b_gl_rd_delta_o;
+      }
+      // Only fetch scales if this tile starts a new group
+      if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
+        int4* sh_s_stage = sh_s + s_sh_stage * pipe;
+        if (s_sh_wr_pred)
+          cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
+        s_gl_rd += s_gl_rd_delta;
+      }
+    }
+    // Insert a fence even when we are winding down the pipeline to ensure that waiting is also correct at this point.
+    cp_async_fence();
+  };
+
+  // Wait until the next thread tile has been loaded to shared memory.
+  auto wait_for_stage = [&] () {
+    // We only have `stages - 2` active fetches since we are double buffering and can only issue the next fetch when
+    // it is guaranteed that the previous shared memory load is fully complete (as it may otherwise be overwritten).
+    cp_async_wait<stages - 2>();
+    __syncthreads();
+  };
+
+  // Load the next sub-tile from the current location in the shared memory pipe into the current register buffer.
+  auto fetch_to_registers = [&] (int k, int pipe) {
+    // It may seem inefficient that we reload the groups for every sub-tile; however, this does not seem to be a
+    // significant bottleneck, while some theoretically better attempts have led to bad instruction ordering by the
+    // compiler and correspondingly a noticeable drop in performance.
+    if (group_blocks != -1) {
+      int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));
+      reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
+    }
+    int4* sh_a_stage = sh_a + a_sh_stage * pipe;
+    #pragma unroll
+    for (int i = 0; i < thread_m_blocks; i++)
+      ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
+    int4* sh_b_stage = sh_b + b_sh_stage * pipe;
+    frag_b_quant[k % 2] = *reinterpret_cast<I4*>(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
+  };
+
+  // Execute the actual tensor core matmul of a sub-tile.
+  auto matmul = [&] (int k) {
+    // We have the m dimension as the inner loop in order to encourage overlapping dequantization and matmul operations.
+    #pragma unroll
+    for (int j = 0; j < 4; j++) {
+      int b_quant = frag_b_quant[k % 2][j];
+      int b_quant_shift = b_quant >> 8;
+      FragB frag_b0 = dequant(b_quant);
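For orientation, the two `dequant` calls in this loop together unpack the eight 4-bit codes held in one 32-bit word; `b_quant >> 8`, used just below, selects the other half of the nibbles. Ignoring the interleaved nibble order and the lop3-based bit tricks the real helper uses, the arithmetic amounts to the following illustrative Python, assuming the usual symmetric 4-bit scheme with an implicit zero point of 8:

```python
# Functional model of 4-bit weight dequantization with a per-group scale (illustrative only).
def unpack_nibbles(word32):
    return [(word32 >> (4 * i)) & 0xF for i in range(8)]

def dequantize(word32, scale):
    # Codes 0..15 represent the integers -8..7; the group scale is applied afterwards.
    return [(q - 8) * scale for q in unpack_nibbles(word32)]

packed = 0x76543210                      # eight 4-bit codes: 0, 1, 2, ..., 7
print(dequantize(packed, scale=0.01))    # [-0.08, -0.07, ..., -0.01]
```

The comment and `scale` calls that follow apply the per-group scale right after dequantization, unless scales are per-column and applied once at write-out.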
+      // If there are no groups, we can just scale the final output once and can avoid doing so for each weight.
+      if (group_blocks != -1)
+        scale(frag_b0, frag_s[k % 2][j], 0);
+      FragB frag_b1 = dequant(b_quant_shift);
+      if (group_blocks != -1)
+        scale(frag_b1, frag_s[k % 2][j], 1);
+      #pragma unroll
+      for (int i = 0; i < thread_m_blocks; i++) {
+        mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
+        mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
+      }
+    }
+  };
+
+  // Since we slice across the k dimension of a tile in order to increase the number of warps while keeping the n
+  // dimension of a tile reasonable, we have multiple warps that accumulate their partial sums of the same output
+  // location, which we have to reduce over in the end. We do this in shared memory.
+  auto thread_block_reduce = [&] () {
+    constexpr int red_off = threads / b_sh_stride / 2;
+    if (red_off >= 1) {
+      int red_idx = threadIdx.x / b_sh_stride;
+      constexpr int red_sh_stride = b_sh_stride * 4 * 2;
+      constexpr int red_sh_delta = b_sh_stride;
+      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
+
+      // Parallel logarithmic shared memory reduction. We make sure to avoid any unnecessary read or write iterations,
+      // e.g., for two warps we write only once by warp 1 and read only once by warp 0.
+
+      #pragma unroll
+      for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
+        #pragma unroll
+        for (int i = red_off; i > 0; i /= 2) {
+          if (i <= red_idx && red_idx < 2 * i) {
+            #pragma unroll
+            for (int j = 0; j < 4 * 2; j++) {
+              int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
+              if (i < red_off) {
+                float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
+                float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
+                #pragma unroll
+                for (int k = 0; k < 4; k++)
+                  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k];
+              }
+              sh[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
+            }
+          }
+          __syncthreads();
+        }
+        if (red_idx == 0) {
+          #pragma unroll
+          for (int i = 0; i < 4 * 2; i++) {
+            float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
+            #pragma unroll
+            for (int j = 0; j < 4; j++)
+              reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += c_rd[j];
+          }
+        }
+        __syncthreads();
+      }
+    }
+  };
+
+  // Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over
+  // the results. As the striped partitioning minimizes the number of such reductions and our outputs are usually rather
+  // small, we perform this reduction serially in L2 cache.
+  auto global_reduce = [&] (bool first = false, bool last = false) {
+    // We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step.
+    // To do this, we write out results in FP16 (but still reduce with FP32 compute).
+    constexpr int active_threads = 32 * thread_n_blocks / 4;
+    if (threadIdx.x < active_threads) {
+      int c_gl_stride = prob_n / 8;
+      int c_gl_wr_delta_o = 8 * c_gl_stride;
+      int c_gl_wr_delta_i = 4 * (active_threads / 32);
+      int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4;
+      c_gl_wr += (2 * thread_n_blocks) * slice_col;
+      constexpr int c_sh_wr_delta = active_threads;
+      int c_sh_wr = threadIdx.x;
+
+      int row = (threadIdx.x % 32) / 4;
+
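Abstracting away the staggered shared-memory writes and reads, `thread_block_reduce` above is a binary-tree reduction over the warps that k-slice the same output tile. A minimal illustrative Python model of the dataflow, assuming a power-of-two warp count as in the kernel:

```python
# Binary-tree reduction over k-slicing warps (illustrative only).
def tree_reduce(partials):
    """partials: one partial-sum vector per warp that shares the same output tile."""
    vals = [list(p) for p in partials]
    step = len(vals) // 2
    while step >= 1:
        for lo in range(step):
            hi = lo + step                            # warp `hi` folds into warp `lo`
            vals[lo] = [a + b for a, b in zip(vals[lo], vals[hi])]
        step //= 2
    return vals[0]                                    # warp 0 ends up with the full sum

print(tree_reduce([[1, 2], [3, 4], [5, 6], [7, 8]]))  # [16, 20]
```

The `global_reduce` helper, continued below, then handles the remaining cross-threadblock part of the reduction.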
+      if (!first) {
+        // Interestingly, doing direct global accesses here really seems to mess up the compiler and lead to slowdowns,
+        // hence we also use async-copies even though these fetches are not actually asynchronous.
+        #pragma unroll
+        for (int i = 0; i < thread_m_blocks * 4; i++) {
+          cp_async4_pred(
+            &sh[c_sh_wr + c_sh_wr_delta * i],
+            &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)],
+            i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m
+          );
+        }
+        cp_async_fence();
+        cp_async_wait<0>();
+      }
+
+      #pragma unroll
+      for (int i = 0; i < thread_m_blocks * 4; i++) {
+        if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
+          if (!first) {
+            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
+            #pragma unroll
+            for (int j = 0; j < 2 * 4; j++) {
+              reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += __half2float(
+                reinterpret_cast<__half*>(&c_red)[j]
+              );
+            }
+          }
+          if (!last) {
+            int4 c;
+            #pragma unroll
+            for (int j = 0; j < 2 * 4; j++) {
+              reinterpret_cast<__half*>(&c)[j] = __float2half(
+                reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]
+              );
+            }
+            C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c;
+          }
+        }
+      }
+    }
+  };
+
+  // Write out the reduced final result in the correct layout. We only actually reshuffle matrix fragments in this step,
+  // the reduction above is performed in fragment layout.
+  auto write_result = [&] () {
+    int c_gl_stride = prob_n / 8;
+    constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
+    int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
+    constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks));
+
+    int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));
+    c_gl_wr += (2 * thread_n_blocks) * slice_col;
+    int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
+    c_sh_wr += 32 * (threadIdx.x / 32);
+    int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));
+
+    int c_gl_wr_end = c_gl_stride * prob_m;
+
+    // We first reorder in shared memory to guarantee the most efficient final global write patterns
+    auto write = [&] (int idx, float c0, float c1, FragS& s) {
+      half2 res = __halves2half2(__float2half(c0), __float2half(c1));
+      if (group_blocks == -1) // for per-column quantization we finally apply the scale here
+        res = __hmul2(res, s[0]);
+      ((half2*) sh)[idx] = res;
+    };
+    if (threadIdx.x / 32 < thread_n_blocks / 4) {
+      #pragma unroll
+      for (int i = 0; i < thread_m_blocks; i++) {
+        #pragma unroll
+        for (int j = 0; j < 4; j++) {
+          int wr = c_sh_wr + 8 * j;
+          write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
+          write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
+          write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
+          write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
+        }
+        c_sh_wr += 16 * (4 * c_sh_stride);
+      }
+    }
+    __syncthreads();
+
+    #pragma unroll
+    for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) {
+      if (c_gl_wr < c_gl_wr_end) {
+        C[c_gl_wr] = sh[c_sh_rd];
+        c_gl_wr += c_gl_wr_delta;
+        c_sh_rd += c_sh_rd_delta;
+      }
+    }
+  };
+
+  // Start global fetch and register load pipelines.
+  auto start_pipes = [&] () {
+    #pragma unroll
+    for (int i = 0; i < stages - 1; i++)
+      fetch_to_shared(i, i, i < slice_iters);
+    zero_accums();
+    wait_for_stage();
+    fetch_to_registers(0, 0);
+    a_gl_rd += a_gl_rd_delta_o * (stages - 1);
+  };
+  start_pipes();
+
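Before the main loop, it may help to see the fetch pipeline that `start_pipes` just primed as a flat sequence of events. The following Python sketch is illustrative only; it models one tile per pipeline slot and ignores the per-sub-tile register double buffering.

```python
# Event-order sketch of a (stages - 1)-deep prefetch pipeline (illustrative only).
def pipeline(total_tiles, stages=4):
    events = [f"prefetch tile {t}" for t in range(min(stages - 1, total_tiles))]
    for t in range(total_tiles):
        nxt = t + stages - 1
        if nxt < total_tiles:
            events.append(f"issue fetch of tile {nxt} while computing tile {t}")
        events.append(f"wait until tile {t} is resident, then run matmul on tile {t}")
    return events

for event in pipeline(total_tiles=6):
    print(event)
```

In the kernel, the same overlap happens at sub-tile granularity, which is why the main loop below issues the next global fetch one sub-tile before the end of each stage.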
+  // Main loop.
+  while (slice_iters) {
+    // We unroll over both the global fetch and the register load pipeline to ensure all shared memory accesses are
+    // static. Note that both pipelines have even length meaning that the next iteration will always start at index 0.
+    #pragma unroll
+    for (int pipe = 0; pipe < stages;) {
+      #pragma unroll
+      for (int k = 0; k < b_sh_wr_iters; k++) {
+        fetch_to_registers(k + 1, pipe % stages);
+        if (k == b_sh_wr_iters - 2) {
+          fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages);
+          pipe++;
+          wait_for_stage();
+        }
+        matmul(k);
+      }
+      slice_iters--;
+      if (slice_iters == 0)
+        break;
+    }
+    a_gl_rd += a_gl_rd_delta_o * stages;
+
+    // Process results and, if necessary, proceed to the next column slice. While this pattern may not be the most
+    // readable, other ways of writing the loop seemed to lead to noticeably worse performance after compilation.
+    if (slice_iters == 0) {
+      cp_async_wait<0>();
+      bool last = slice_idx == slice_count - 1;
+      // For per-column scales, we only fetch them here in the final step before write-out
+      if (group_blocks == -1 && last) {
+        if (s_sh_wr_pred)
+          cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
+        cp_async_fence();
+      }
+      thread_block_reduce();
+      if (group_blocks == -1 && last) {
+        cp_async_wait<0>();
+        __syncthreads();
+        if (threadIdx.x / 32 < thread_n_blocks / 4) {
+          reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
+          reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
+        }
+      }
+      if (slice_count > 1) { // only globally reduce if there is more than one block in a slice
+        barrier_acquire(&locks[slice_col], slice_idx);
+        global_reduce(slice_idx == 0, last);
+        barrier_release(&locks[slice_col], last);
+      }
+      if (last) // only the last block in a slice actually writes the result
+        write_result();
+      slice_row = 0;
+      slice_col_par++;
+      slice_col++;
+      init_slice();
+      if (slice_iters) {
+        a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
+        #pragma unroll
+        for (int i = 0; i < b_sh_wr_iters; i++)
+          B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
+        if (slice_col == 0) {
+          #pragma unroll
+          for (int i = 0; i < b_sh_wr_iters; i++)
+            B_ptr[i] -= b_gl_stride;
+        }
+        s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
+        start_pipes();
+      }
+    }
+  }
+}
+
+
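Stepping back: when several threadblocks own stripes of the same column slice, the main loop above serializes their output reduction through the `locks` array (`barrier_acquire`, `global_reduce`, `barrier_release`). The Python model below is illustrative only; it runs sequentially, whereas on the GPU each step is a different threadblock spinning on the lock.

```python
# Serial reduction order for one column slice, guarded by a counter lock (illustrative only).
def serial_slice_reduce(partials):
    """Each entry is one threadblock's partial result for the same output slice."""
    out, lock = [0.0] * len(partials[0]), 0
    for slice_idx, part in enumerate(partials):
        assert lock == slice_idx                  # what barrier_acquire spins on
        out = [o + p for o, p in zip(out, part)]  # global_reduce: accumulate into the output
        lock += 1                                 # barrier_release (the last block resets it to 0)
    return out

print(serial_slice_reduce([[1.0, 2.0], [0.5, 0.5], [0.25, 0.25]]))  # [1.75, 2.75]
```

The last block in a slice (`slice_idx == slice_count - 1`) skips the write-back inside `global_reduce` and instead emits the final result through `write_result`.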
+// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per scheduler allows some more
+// latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles.
+const int THREADS = 256;
+const int STAGES = 4; // 4 pipeline stages fit into shared memory
+const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
+
+#define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \
+  else if ( \
+    thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && \
+    group_blocks == GROUP_BLOCKS \
+  ) { \
+    cudaFuncSetAttribute( \
+      Marlin<THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
+      cudaFuncAttributeMaxDynamicSharedMemorySize, \
+      SHARED_MEM \
+    ); \
+    Marlin< \
+      THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS \
+    ><<<blocks, THREADS, SHARED_MEM, stream>>>( \
+      A_ptr, B_ptr, C_ptr, s_ptr, \
+      prob_m, prob_n, prob_k, \
+      locks \
+    ); \
+  }
+
+const int ERR_PROB_SHAPE = 1;
+const int ERR_KERN_SHAPE = 2;
+
+int marlin_cuda(
+  const void* A,
+  const void* B,
+  void* C,
+  void* s,
+  int prob_m,
+  int prob_n,
+  int prob_k,
+  void* workspace,
+  int groupsize = -1,
+  int dev = 0,
+  cudaStream_t stream = 0,
+  int thread_k = -1,
+  int thread_n = -1,
+  int sms = -1,
+  int max_par = 16
+) {
+  int tot_m = prob_m;
+  int tot_m_blocks = ceildiv(tot_m, 16);
+  int pad = 16 * tot_m_blocks - tot_m;
+
+  if (sms == -1)
+    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
+  if (thread_k == -1 || thread_n == -1) {
+    if (prob_m <= 16) {
+      // For small batch sizes, better partitioning is slightly more important than better compute utilization
+      thread_k = 128;
+      thread_n = 128;
+    } else {
+      thread_k = 64;
+      thread_n = 256;
+    }
+  }
+
+  int thread_k_blocks = thread_k / 16;
+  int thread_n_blocks = thread_n / 16;
+  int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
+  int blocks = sms;
+
+  if (prob_n % thread_n != 0 || prob_k % thread_k != 0 || (group_blocks != -1 && prob_k % group_blocks != 0))
+    return ERR_PROB_SHAPE;
+  if (prob_m == 0 || prob_n == 0 || prob_k == 0)
+    return 0;
+
+  const int4* A_ptr = (const int4*) A;
+  const int4* B_ptr = (const int4*) B;
+  int4* C_ptr = (int4*) C;
+  const int4* s_ptr = (const int4*) s;
+
+  int cols = prob_n / thread_n;
+  int* locks = (int*) workspace;
+
+  int ret = 0;
+  for (int i = 0; i < tot_m_blocks; i += 4) {
+    int thread_m_blocks = tot_m_blocks - i;
+    prob_m = tot_m - 16 * i;
+    int par = 1;
+    if (thread_m_blocks > 4) {
+      // Note that parallel > 1 currently only works for inputs without any padding
+      par = (16 * thread_m_blocks - pad) / 64;
+      if (par > max_par)
+        par = max_par;
+      prob_m = 64 * par;
+      i += 4 * (par - 1);
+      thread_m_blocks = 4;
+    }
+
+    // For compilation speed, we only define the kernel configurations that have seemed useful (in terms of performance)
+    // in our testing; however, many more are, in principle, possible.
+    if (false) {}
+    CALL_IF(1, 8, 8, -1)
+    CALL_IF(1, 8, 8, 8)
+    CALL_IF(1, 16, 4, -1)
+    CALL_IF(1, 16, 4, 8)
+    CALL_IF(2, 16, 4, -1)
+    CALL_IF(2, 16, 4, 8)
+    CALL_IF(3, 16, 4, -1)
+    CALL_IF(3, 16, 4, 8)
+    CALL_IF(4, 16, 4, -1)
+    CALL_IF(4, 16, 4, 8)
+    else
+      ret = ERR_KERN_SHAPE;
+
+    A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
+    C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
+  }
+
+  return ret;
+}
+
+
+#endif
+"""
diff --git a/bitblas/ops/general_matmul/tilelang/__init__.py b/bitblas/ops/general_matmul/tilelang/__init__.py
new file mode 100644
index 000000000..92956855c
--- /dev/null
+++ b/bitblas/ops/general_matmul/tilelang/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+ +# TODO: Not Implemented Yet diff --git a/bitblas/ops/general_matmul/tirscript/__init__.py b/bitblas/ops/general_matmul/tirscript/__init__.py new file mode 100644 index 000000000..f783e05b3 --- /dev/null +++ b/bitblas/ops/general_matmul/tirscript/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from .matmul_dequantize_impl import select_implementation as matmul_dequantize_select_implementation # noqa: F401 +from .matmul_impl import select_implementation as matmul_select_implementation # noqa: F401 diff --git a/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py b/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py new file mode 100644 index 000000000..65f1c75e5 --- /dev/null +++ b/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py @@ -0,0 +1,968 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pre-transformed tir expression of matmul +from bitblas import tvm +from tvm import te, DataType +from tvm.tir import IndexMap +from bitblas.ops.operator import TransformKind +from bitblas.gpu.matmul_analysis import get_propagate_map +from bitblas.quantization import ( + _tir_packed_int_to_int_convert, + _tir_packed_to_signed_convert, + _tir_packed_to_unsigned_convert, + _tir_u32_to_f4_to_f16, + _tir_u8_to_f8_e4m3_to_f16, + _tir_packed_to_unsigned_convert_with_zeros, +) + + +class MatMulNTDequantizeEmitter: + + def __init__( + self, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", + propagate_a: TransformKind = TransformKind.NonTransform, + propagate_b: TransformKind = TransformKind.NonTransform, + ): + self.M = self._validate_dimension(M, "M") + self.N = N + self.K = K + self.in_dtype = in_dtype + self.out_dtype = out_dtype + self.accum_dtype = accum_dtype + self.bit = bit + self.storage_dtype = storage_dtype + self.source_format = source_format + self.with_scaling = with_scaling + self.with_zeros = with_zeros + self.group_size = group_size if group_size != -1 else K + self.fast_decoding = fast_decoding + self.with_bias = with_bias + self.zeros_mode = zeros_mode + self.propagate_a = self._legalize_transform_kind(propagate_a) + self.propagate_b = self._legalize_transform_kind(propagate_b) + + self._validate_bit() + self._validate_layout() + + @staticmethod + def _validate_dimension(dim, name): + if not isinstance(dim, int): + return tvm.te.var(name.lower()) + return dim + + def _validate_bit(self): + if self.bit not in [1, 2, 4, 8]: + raise ValueError(f"Unsupported bit: {self.bit}") + + def _validate_layout(self): + # TODO: extend the dequantize operators into General Layout + pass + + def _legalize_group_size(self): + if self.group_size == -1: + self.group_size = self.K + + def _legalize_transform_kind(self, propagate): + if propagate is None: + return TransformKind.NonTransform + if isinstance(propagate, bool): + return (TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform) + elif isinstance(propagate, int): + return TransformKind(propagate) + + def _create_placeholders(self): + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + in_dtype = self.in_dtype + bit = self.bit + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + + 
A = te.placeholder((self.M, self.K), name="A", dtype=in_dtype) + B = te.placeholder((self.N, self.K // storage_nbit * bit), name="B", dtype=storage_dtype) + if self.propagate_a: + A = te.placeholder((self.M // l, self.K // r, l, r), name="A", dtype=in_dtype) + if self.propagate_b: + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) + qr = r * bit // storage_nbit + B = te.placeholder((self.N // l, (self.K // scaling_factor) // qr, l, qr), + name="B", + dtype=storage_dtype) + + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=in_dtype) + Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=in_dtype) + QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * bit), + name="QZeros", + dtype=self.storage_dtype) + Bias = te.placeholder((self.N,), name="Bias", dtype=in_dtype) + return A, B, LUT, Scale, Zeros, QZeros, Bias + + def _propagate_input(self, tensor, transform_kind=TransformKind.NonTransform, matrix_name="A"): + if transform_kind == TransformKind.NonTransform: + return tensor + in_dtype = self.in_dtype + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + _, inversed_index_map = get_propagate_map( + trans=False, dtype=in_dtype, matrix_name=matrix_name) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return tensor[new_index] + + return te.compute( + (self.M, self.K), + fcompute, + name=f"{matrix_name}_reindex", + ) + + def _propagage_weight(self, tensor, transform_kind=TransformKind.NonTransform, matrix_name="B"): + if transform_kind == TransformKind.NonTransform: + return tensor + in_dtype = self.in_dtype + bit = self.bit + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit())) + + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + _, inversed_index_map = get_propagate_map( + trans=True, dtype=in_dtype, matrix_name=matrix_name) + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) + initial_indices = inversed_index_map.initial_indices + scaling_final_indices = inversed_index_map.map_indices( + initial_indices[:-1] + [initial_indices[-1] * scaling_factor]) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + inversed_index_map = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + qr = r * bit // storage_nbit + + def fcompute(i, j): + warp_i, warp_j = i % l, j % qr + spatial_args = i // l, j // qr + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return tensor[new_index] + + return te.compute( + (self.N, self.K // storage_nbit * bit), + fcompute, + name=f"{matrix_name}_reindex", + ) + + def _decode_func(self, B, LUT, Scale, Zeros, QZeros): + bit = self.bit + in_dtype = self.in_dtype 
+ storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + + # TODO: Move the decode function into a more general place + def decode(n, k): + w = None + if self.with_zeros and self.zeros_mode == "quantized": + qzeros_dequantize = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + QZeros[k, n // n_float_per_elem], + n % n_float_per_elem, + dtype=self.storage_dtype, + ) + w = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit)( + bit, + B[n, k // n_float_per_elem], + k % n_float_per_elem, + qzeros_dequantize, + dtype=in_dtype, + ) + elif self.source_format == "uint": + if bit == 8: + w = B[n, k].astype(in_dtype) + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif self.source_format == "int": + if bit == 1: + w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + if bit == 8: + w = B[n, k].astype(in_dtype) + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif self.source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif self.source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype) + elif self.source_format == "nf": + index = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", + ) + w = LUT[index] + else: + raise ValueError(f"Unsupported source_format: {self.source_format}") + + assert w is not None, "w is None" + + group_size = self.group_size + zeros_mode = self.zeros_mode + + if not self.with_scaling: + return w + + if not self.with_zeros: + return w * Scale[n, k // group_size] + + if zeros_mode == "original": + w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] + elif zeros_mode == "rescale": + w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] + elif zeros_mode == "quantized": + w = w * Scale[n, k // group_size] + else: + raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) + + return w + + return te.compute((self.N, self.K), decode, name="B_decode") + + def _compute_matmul(self, A, B_decode): + k = te.reduce_axis((0, self.K), name="k") + C = te.compute( + (self.M, self.N), + lambda i, j: te.sum( + A[i, k].astype(self.accum_dtype) * B_decode[j, k].astype(self.accum_dtype), axis=k), + name="C", + ) + return C + + def _convert_dtype(self, tensor): + if self.accum_dtype != self.out_dtype: + return te.compute((self.M, self.N), + lambda i, j: tensor[i, j].astype(self.out_dtype), + name="D") + return tensor + + def _apply_bias(self, tensor, Bias): + if self.with_bias: + return te.compute((self.M, self.N), lambda i, j: tensor[i, j] + Bias[j], name="E") + return tensor + + def emit(self): + A, B, LUT, Scale, Zeros, QZeros, Bias = self._create_placeholders() + A_reindex = self._propagate_input(A, self.propagate_a, "A") + B_reindex = self._propagage_weight(B, self.propagate_b, "B") + + B_decode = self._decode_func(B_reindex, LUT, Scale, Zeros, QZeros) + C = self._compute_matmul(A_reindex, B_decode) + D = self._convert_dtype(C) + last_output = self._apply_bias(D, Bias) + + args = [A, B] + if 
self.source_format == "nf": + args.append(LUT) + if self.with_scaling: + args.append(Scale) + if self.with_zeros: + args.append(QZeros if self.zeros_mode == "quantized" else Zeros) + if self.with_bias: + args.append(Bias) + args.append(last_output) + + func = te.create_prim_func(args).with_attr( + "dequantize_info", + { + "B_decode": { + "decode_block": "B_decode", + "fast_decoding": self.fast_decoding, + "source_format": { + "bits": self.bit, + "format": self.source_format, + }, + "storage_dtype": self.storage_dtype, + "target_format": self.in_dtype, + "with_zeros": self.with_zeros, + "zeros_mode": self.zeros_mode, + "with_scaling": self.with_scaling, + "group_size": self.group_size, + } + }, + ) + if self.propagate_a: + func = func.with_attr("input_transform_kind", self.propagate_a.value) + if self.propagate_b: + func = func.with_attr("weight_transform_kind", self.propagate_b.value) + return tvm.IRModule.from_expr(func) + + +def matmul_nt_dequantize_b( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", +): + assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) + if not isinstance(M, int): + M = tvm.te.var("m") + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + if group_size == -1: + group_size = K + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((N, K // storage_nbit * bit), name="B", dtype=storage_dtype) + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) + Zeros = te.placeholder((N, K // group_size), name="Zeros", dtype=in_dtype) + QZeros = te.placeholder(((K // group_size), N // storage_nbit * bit), + name="QZeros", + dtype=storage_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def qzeros_dequantize(k, n): + return _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + QZeros[k, n // n_float_per_elem], + n % n_float_per_elem, + dtype=storage_dtype, + ) + + Dequantize_qzeros = None + if with_zeros and zeros_mode == "quantized": + Dequantize_qzeros = te.compute( + (K // group_size, N), + qzeros_dequantize, + name="Dequantize_zeros", + ) + + def decode_func(n, k): + if with_zeros and zeros_mode == "quantized": + assert Dequantize_qzeros is not None, "Dequantize_zeros is None" + w = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit)( + bit, + B[n, k // n_float_per_elem], + k % n_float_per_elem, + Dequantize_qzeros[k // group_size, n], + dtype=in_dtype, + ) + elif source_format == "uint": + if bit == 8: + # 8 bit does not need to be compressed + w = B[n, k].astype(in_dtype) + else: + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "int": + if bit == 1: + # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. 
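As a side note, every branch of this `decode_func` addresses the packed weight with `B[n, k // n_float_per_elem]` plus a field offset `k % n_float_per_elem`. The sketch below is an illustrative Python model of that storage convention; it assumes the little-endian field order used by the `_tir_packed_*` converters, and the helper names are made up for the example.

```python
# Model of the bit-packed weight storage the TE expressions above index into (illustrative only).
# With storage_dtype "int8" and bit = 4, each storage element holds n_float_per_elem = 2 codes.
def pack(codes, bit=4, storage_nbit=8):
    per_elem = storage_nbit // bit
    packed = []
    for i in range(0, len(codes), per_elem):
        word = 0
        for j, q in enumerate(codes[i:i + per_elem]):
            word |= (q & ((1 << bit) - 1)) << (bit * j)
        packed.append(word)
    return packed

def unpack(packed, k, bit=4, storage_nbit=8):
    per_elem = storage_nbit // bit
    # Mirrors B[n, k // n_float_per_elem] and k % n_float_per_elem above.
    return (packed[k // per_elem] >> (bit * (k % per_elem))) & ((1 << bit) - 1)

codes = [3, 12, 7, 0, 15, 1, 8, 5]
packed = pack(codes)
assert [unpack(packed, k) for k in range(len(codes))] == codes
```

The signed `int` branch continued below applies the same unpacking, followed by a signed conversion (and, for `bit == 1`, a map to the values -1 and 1).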
+ w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif bit == 8: + # 8 bit does not need to be compressed + w = B[n, k].astype(in_dtype) + else: + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype) + elif source_format == "nf": + w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", # assume the index data type is int32 + )] + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + if not with_scaling: + return w + + if not with_zeros: + return w * Scale[n, k // group_size] + + if zeros_mode == "original": + w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] + elif zeros_mode == "rescale": + w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] + elif zeros_mode == "quantized": + w = w * Scale[n, k // group_size] + else: + raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) + + return w + + B_decode = te.compute((N, K), decode_func, name="B_decode") + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k), + name="C", + ) + + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + args = [A, B] + if source_format == "nf": + args.append(LUT) + if with_scaling: + args.append(Scale) + if with_zeros: + if zeros_mode == "quantized": + args.append(QZeros) + else: + args.append(Zeros) + if with_bias: + last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + args.append(Bias) + args.append(last_output) + + func = te.create_prim_func(args).with_attr( + "dequantize_info", + { + "B_decode": { + "decode_block": "B_decode", + "fast_decoding": fast_decoding, + "source_format": { + "bits": bit, + "format": source_format, + }, + "storage_dtype": storage_dtype, + "target_format": in_dtype, + "with_scaling": with_scaling, + "with_zeros": with_zeros, + "zeros_mode": zeros_mode, + "group_size": group_size, + } + }, + ) + return tvm.IRModule.from_expr(func) + + +def matmul_nt_dequantize_b_propagate_b( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", + transform_kind: TransformKind = TransformKind.IntraWarpTransform, +): + assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) + if not isinstance(M, int): + M = tvm.te.var("m") + + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + + _, inverse_indexmap = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) + initial_indices = 
inverse_indexmap.initial_indices + scaling_final_indices = inverse_indexmap.map_indices(initial_indices[:-1] + + [initial_indices[-1] * scaling_factor]) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + inverse_indexmap = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + if group_size == -1: + group_size = K + qr = r * bit // storage_nbit + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((N // l, (K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) + Zeros = te.placeholder((N, K // group_size), name="Zeros", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % qr + spatial_args = i // l, j // qr + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inverse_indexmap.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return B[new_index] + + B_reindex = te.compute( + (N, K // storage_nbit * bit), + fcompute, + name="B_reindex", + ) + + def decode_func(n, k): + if source_format == "uint": + if bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[n, k].astype(in_dtype) + else: + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "int": + if bit == 1: + # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. 
+ w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[n, k].astype(in_dtype) + else: + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[n, k], dtype=in_dtype) + elif source_format == "nf": + w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", # assume the index data type is int32 + )] + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + if not with_scaling: + return w + + if not with_zeros: + return w * Scale[n, k // group_size] + + if zeros_mode == "original": + w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] + elif zeros_mode == "rescale": + w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] + else: + raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) + + return w + + B_decode = te.compute((N, K), decode_func, name="B_decode") + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + args = [A, B] + if source_format == "nf": + args.append(LUT) + if with_scaling: + args.append(Scale) + if with_zeros: + args.append(Zeros) + if with_bias: + last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + args.append(Bias) + args.append(last_output) + + func = te.create_prim_func(args).with_attr( + "dequantize_info", + { + "B_decode": { + "decode_block": "B_decode", + "fast_decoding": fast_decoding, + "source_format": { + "bits": bit, + "format": source_format, + }, + "storage_dtype": storage_dtype, + "target_format": in_dtype, + "with_zeros": with_zeros, + "zeros_mode": zeros_mode, + "with_scaling": with_scaling, + "group_size": group_size, + } + }, + ) + func = func.with_attr("weight_transform_kind", transform_kind.value) + return tvm.IRModule.from_expr(func) + + +def matmul_nt_dequantize_b_propagate_a_propagate_b( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", + transform_kind_input: TransformKind = TransformKind.IntraWarpTransform, + transform_kind_weight: TransformKind = TransformKind.IntraWarpTransform, +): + assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) + if not isinstance(M, int): + M = tvm.te.var("m") + + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") + A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i 
// l, j // r + if transform_kind_input >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return A[new_index] + + A_reindex = te.compute( + (M, K), + fcompute, + name="A_reindex", + ) + + _, inversed_index_map = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) + initial_indices = inversed_index_map.initial_indices + scaling_final_indices = inversed_index_map.map_indices( + initial_indices[:-1] + [initial_indices[-1] * scaling_factor]) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + inversed_index_map = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + if group_size == -1: + group_size = K + qr = r * bit // storage_nbit + B = te.placeholder((N // l, (K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) + Zeros = te.placeholder((N, K // group_size), name="Zeros", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % qr + spatial_args = i // l, j // qr + if transform_kind_weight >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return B[new_index] + + B_reindex = te.compute( + (N, K // storage_nbit * bit), + fcompute, + name="B_reindex", + ) + + def decode_func(n, k): + if source_format == "uint": + if bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[n, k].astype(in_dtype) + else: + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "int": + # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. 
+ if bit == 1: + w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[n, k].astype(in_dtype) + else: + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[n, k], dtype=in_dtype) + elif source_format == "nf": + w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", # assume the index data type is int32 + )] + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + if not with_scaling: + return w + + if not with_zeros: + return w * Scale[n, k // group_size] + + if zeros_mode == "original": + w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] + elif zeros_mode == "rescale": + w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] + else: + raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) + + return w + + B_decode = te.compute((N, K), decode_func, name="B_decode") + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A_reindex[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), + axis=k, + ), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + args = [A, B] + if source_format == "nf": + args.append(LUT) + if with_scaling: + args.append(Scale) + if with_zeros: + args.append(Zeros) + if with_bias: + last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + args.append(Bias) + args.append(last_output) + + func = te.create_prim_func(args).with_attr( + "dequantize_info", + { + "B_decode": { + "decode_block": "B_decode", + "fast_decoding": fast_decoding, + "source_format": { + "bits": bit, + "format": source_format, + }, + "storage_dtype": storage_dtype, + "target_format": in_dtype, + "with_zeros": with_zeros, + "zeros_mode": zeros_mode, + "with_scaling": with_scaling, + "group_size": group_size, + } + }, + ) + func = func.with_attr("input_transform_kind", transform_kind_input.value) + func = func.with_attr("weight_transform_kind", transform_kind_weight.value) + return tvm.IRModule.from_expr(func) + + +# Should be refactored with Emitter +def select_implementation( + M=None, + N=1024, + K=1024, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + layout="nt", + zeros_mode="original", + propagate_a=False, + propagate_b=False, +): + if layout == "nn": + raise ValueError( + "Currently only support propagate_a=False and propagate_b=False for layout=nn in Dequantize Implementation" + ) + elif layout == "nt": + if propagate_a and propagate_b: + return matmul_nt_dequantize_b_propagate_a_propagate_b( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + with_zeros, + group_size, + 
fast_decoding, + with_bias, + zeros_mode, + transform_kind_input=propagate_a, + transform_kind_weight=propagate_b, + ) + elif propagate_a: + raise NotImplementedError + elif propagate_b: + return matmul_nt_dequantize_b_propagate_b( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + with_zeros, + group_size, + fast_decoding, + with_bias, + zeros_mode, + transform_kind=propagate_b, + ) + else: + return matmul_nt_dequantize_b( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + with_zeros, + group_size, + fast_decoding, + with_bias, + zeros_mode, + ) + else: + raise ValueError(f"Unsupported layout: {layout}") diff --git a/bitblas/ops/general_matmul/tirscript/matmul_impl.py b/bitblas/ops/general_matmul/tirscript/matmul_impl.py new file mode 100644 index 000000000..b093f0d9c --- /dev/null +++ b/bitblas/ops/general_matmul/tirscript/matmul_impl.py @@ -0,0 +1,356 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pre-transformed tir expression of matmul +from bitblas import tvm +from tvm import te +from bitblas.gpu.matmul_analysis import get_propagate_map +from bitblas.ops.operator import TransformKind + + +def matmul_nn( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, +): + if not isinstance(M, int): + M = tvm.te.var("m") + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((K, N), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum(A[i, k].astype(accum_dtype) * B[k, j].astype(accum_dtype), axis=k), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + last_output = E + + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + + func = te.create_prim_func(args) + + return tvm.IRModule.from_expr(func) + + +def matmul_nt( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, +): + if not isinstance(M, int): + M = tvm.te.var("m") + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((N, K), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum(A[i, k].astype(accum_dtype) * B[j, k].astype(accum_dtype), axis=k), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + last_output = E + + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + + func = te.create_prim_func(args) + + return tvm.IRModule.from_expr(func) + + +def matmul( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", +): + if layout == "nn": + return matmul_nn(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias) + return matmul_nt(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias) + + +def 
matmul_nt_propagate_a( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + transform_kind: TransformKind = TransformKind.IntraWarpTransform, +): + if not isinstance(M, int): + M = tvm.te.var("m") + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + + _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") + + A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) + B = te.placeholder((N, K), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return A[new_index] + + A_reindex = te.compute( + (M, K), + fcompute, + name="A_reindex", + ) + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A_reindex[i, k].astype(accum_dtype) * B[j, k].astype(accum_dtype), axis=k), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + last_output = E + + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + + func = te.create_prim_func(args) + func = func.with_attr("input_transform_kind", transform_kind.value) + + return tvm.IRModule.from_expr(func) + + +def matmul_nt_propagate_b( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + transform_kind: TransformKind = TransformKind.IntraWarpTransform, +): + if not isinstance(M, int): + M = tvm.te.var("m") + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + + _, inversed_index_map = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") + + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((N // l, K // r, l, r), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return B[new_index] + + B_reindex = te.compute( + (N, K), + fcompute, + name="B_reindex", + ) + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A[i, k].astype(accum_dtype) * B_reindex[j, k].astype(accum_dtype), axis=k), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + last_output = E + + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + + func = te.create_prim_func(args) + func = func.with_attr("weight_transform_kind", transform_kind.value) + + return tvm.IRModule.from_expr(func) + + +def matmul_nt_propagate_a_propagate_b( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + 
with_bias=False, + transform_kind_input: TransformKind = TransformKind.IntraWarpTransform, + transform_kind_weight: TransformKind = TransformKind.IntraWarpTransform, +): + if not isinstance(M, int): + M = tvm.te.var("m") + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + + A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) + B = te.placeholder((N // l, K // r, l, r), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + if transform_kind_input >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return A[new_index] + + A_reindex = te.compute( + (M, K), + fcompute, + name="A_reindex", + ) + + _, inversed_index_map = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + if transform_kind_weight >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return B[new_index] + + B_reindex = te.compute( + (N, K), + fcompute, + name="B_reindex", + ) + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A_reindex[i, k].astype(accum_dtype) * B_reindex[j, k].astype(accum_dtype), + axis=k, + ), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + last_output = E + + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + + func = te.create_prim_func(args) + func = func.with_attr("input_transform_kind", transform_kind_input.value) + func = func.with_attr("weight_transform_kind", transform_kind_weight.value) + + return tvm.IRModule.from_expr(func) + + +def select_implementation( + M=None, + N=16384, + K=16384, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", + propagate_a: TransformKind = TransformKind.NonTransform, + propagate_b: TransformKind = TransformKind.NonTransform, +): + if layout == "nn": + if propagate_a or propagate_b: + raise ValueError( + "Currently only support propagate_a=False and propagate_b=False for layout=nn") + return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) + elif layout == "nt": + if propagate_a and propagate_b: + return matmul_nt_propagate_a_propagate_b( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + with_bias, + transform_kind_input=propagate_a, + transform_kind_weight=propagate_b, + ) + elif propagate_a: + return matmul_nt_propagate_a( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + with_bias, + transform_kind=propagate_a, + ) + elif propagate_b: + return matmul_nt_propagate_b( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + with_bias, + transform_kind=propagate_b, + ) + else: + return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) + else: + raise ValueError(f"Unsupported layout: {layout}") diff --git a/bitblas/ops/ladder_permutate.py 
b/bitblas/ops/ladder_permutate/__init__.py similarity index 96% rename from bitblas/ops/ladder_permutate.py rename to bitblas/ops/ladder_permutate/__init__.py index 70999b09d..6644705cd 100644 --- a/bitblas/ops/ladder_permutate.py +++ b/bitblas/ops/ladder_permutate/__init__.py @@ -2,8 +2,8 @@ # Licensed under the MIT License. from tvm.target import Target from typing import Literal, Union -from .operator import Operator -from .impl.ladder_permutate_impl import select_implementation +from ..operator import Operator +from .ladder_permutate_impl import select_implementation from dataclasses import dataclass diff --git a/bitblas/ops/ladder_permutate/ladder_permutate_impl.py b/bitblas/ops/ladder_permutate/ladder_permutate_impl.py new file mode 100644 index 000000000..76b5a01fb --- /dev/null +++ b/bitblas/ops/ladder_permutate/ladder_permutate_impl.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from bitblas.gpu.matmul_analysis import get_propagate_map +from typing import Literal +from tvm import te, IRModule, DataType +from tvm.tir import IndexMap + + +def select_implementation( + M: int, + N: int, + datatype: Literal["float16", "int8", "e4m3_float8", "e5m2_float8"] = "float16", + dequantize_bits: int = -1, + storage_dtype: Literal["float16", "int8", "uint8", "int32", "uint32"] = "float16", + propagate_kind: Literal["A", "B"] = "B", + transpose_matrix: bool = False, + transform_kind: int = 0, + target_instruction: Literal["nvidia-mma"] = "nvidia-mma", +): + if target_instruction != "nvidia-mma": + raise ValueError("Currently only support nvidia-mma instruction") + + # This is trick to get the basic tile size for the current datatype + # as for nvidia tensorcore instruction, the basic tile size is 16x16/16x32 for float16/int8 + l = r = 16 # noqa: E741 + if datatype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + intra_index_map, _ = get_propagate_map( + transpose_matrix, dtype=datatype, matrix_name=propagate_kind) + + target_dtype = DataType(datatype) + scaling_factor = 1 + if dequantize_bits > 0 and dequantize_bits < target_dtype.bits: + scaling_factor = ((target_dtype.bits // dequantize_bits) * DataType(storage_dtype).bits // + target_dtype.bits) + r = r // scaling_factor + initial_indices = intra_index_map.initial_indices + scaling_final_indices = intra_index_map.map_indices(initial_indices[:-1] + + [initial_indices[-1] * scaling_factor]) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + intra_index_map = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + inp = te.placeholder((M, N // scaling_factor), name="inp", dtype=storage_dtype) + args = [inp] + + assert transform_kind != 0, "Permute only apply when transform_kind >= 1" + if transform_kind >= 1: + arg = args[-1] + + inter_warp = te.compute( + (M // l, (N // scaling_factor) // r, l, r), + lambda i, j, ii, jj: arg[i * l + ii, j * r + jj], + name="inter_warp_permutate", + ) + args.append(inter_warp) + if transform_kind >= 2: + arg = args[-1] + + def fcompute(*args): + warp_i, warp_j = args[-2:] + spatial_args = args[:-2] + permutate_i, permutate_j = intra_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, permutate_i, permutate_j) + return arg[new_index] + + intra_warp = te.compute( + (M // l, (N // scaling_factor) // r, l, r), + fcompute, + name="intra_warp_permutate", + ) + args.append(intra_warp) + args = [args[0], args[-1]] + + func = te.create_prim_func(args) + + 
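+    # Only the flat (M, N // scaling_factor) input and the final permuted
+    # (M // l, (N // scaling_factor) // r, l, r) stage are exposed as the PrimFunc's I/O;
+    # any intermediate permutation stage stays internal to the function.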
return IRModule.from_expr(func) diff --git a/bitblas/ops/lop3_permutate.py b/bitblas/ops/lop3_permutate/__init__.py similarity index 95% rename from bitblas/ops/lop3_permutate.py rename to bitblas/ops/lop3_permutate/__init__.py index 867432a5e..7715be471 100644 --- a/bitblas/ops/lop3_permutate.py +++ b/bitblas/ops/lop3_permutate/__init__.py @@ -2,8 +2,8 @@ # Licensed under the MIT License. from tvm.target import Target from typing import Literal, Union -from .operator import Operator -from .impl.lop3_permutate_impl import select_implementation +from ..operator import Operator +from .lop3_permutate_impl import select_implementation from dataclasses import dataclass import torch diff --git a/bitblas/ops/lop3_permutate/lop3_permutate_impl.py b/bitblas/ops/lop3_permutate/lop3_permutate_impl.py new file mode 100644 index 000000000..07d8f4f0c --- /dev/null +++ b/bitblas/ops/lop3_permutate/lop3_permutate_impl.py @@ -0,0 +1,152 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from typing import Literal +from tvm import DataType +from tvm import IRModule +from tvm.ir import GlobalVar +from tvm.script import tir as T + + +# fmt: off +# TIR interleave weight impl-> 2D implementation +def tir_interleave_weight( + N: int = 2, + K: int = 16, + bits: int = 4, + QK: int = -1, + target_dtype: str = "float16", + storage_dtype: str = "int32", +): + if QK == -1: + QK = K * bits // 32 + bits_stride = DataType(target_dtype).bits + mask = (1 << bits) - 1 # for 4bit the val is 0x0000000f + num_groups = 32 // bits_stride + elems_per_group = bits_stride // bits + + @T.prim_func + def interleave_weight(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), storage_dtype)): + for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): + with T.block("B"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + offset = v2 * elems_per_group + v3 + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) + + @T.prim_func + def interleave_weight_f16_2b(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), + storage_dtype)): + B_tmp_1 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_2 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_3 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): + with T.block("B_tmp"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + offset = v2 * elems_per_group + v3 + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) + + for ax0, ax1 in T.grid(N, QK): + with T.block("B"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + B_tmp_1[v0, v1] = B[v0, v1] & T.uint32(0xFF0000FF) + B_tmp_2[v0, v1] = ((B[v0, v1] & T.uint32(0x00FF0000)) << 8) >> 16 + B_tmp_3[v0, v1] = ((B[v0, v1] & T.uint32(0x0000FF00)) << 16) >> 8 + B[v0, v1] = B_tmp_1[v0, v1] | B_tmp_2[v0, v1] | B_tmp_3[v0, v1] + + @T.prim_func + def interleave_weight_f16_1b(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), + storage_dtype)): + B_tmp_1 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_2 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_3 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_4 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_5 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_6 = 
T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_7 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): + with T.block("B_tmp"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + offset = v2 * elems_per_group + v3 + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) + + for ax0, ax1 in T.grid(N, QK): + with T.block("B"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + B_tmp_1[v0, v1] = B[v0, v1] & T.uint32(0xF000000F) + B_tmp_2[v0, v1] = ((B[v0, v1] & T.uint32(0x000000F0)) >> 4) << 8 + B_tmp_3[v0, v1] = ((B[v0, v1] & T.uint32(0x00000F00)) >> 8) << 16 + B_tmp_4[v0, v1] = ((B[v0, v1] & T.uint32(0x0000F000)) >> 12) << 24 + B_tmp_5[v0, v1] = ((B[v0, v1] & T.uint32(0x000F0000)) >> 16) << 8 + B_tmp_6[v0, v1] = ((B[v0, v1] & T.uint32(0x00F00000)) >> 20) << 12 + B_tmp_7[v0, v1] = ((B[v0, v1] & T.uint32(0x00F00000)) >> 24) << 20 + B[v0, v1] = ( + B_tmp_1[v0, v1] + | B_tmp_2[v0, v1] + | B_tmp_3[v0, v1] + | B_tmp_4[v0, v1] + | B_tmp_5[v0, v1] + | B_tmp_6[v0, v1] + | B_tmp_7[v0, v1]) + + @T.prim_func + def interleave_weight_int8_1b(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), + storage_dtype)): + B_tmp_1 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_2 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_3 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_4 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_5 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): + with T.block("B_tmp"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + offset = v2 * elems_per_group + v3 + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) + + for ax0, ax1 in T.grid(N, QK): + with T.block("B"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + B_tmp_1[v0, v1] = B[v0, v1] & T.uint32(0xF0F00F0F) + B_tmp_2[v0, v1] = ((B[v0, v1] & T.uint32(0x000000F0)) >> 4) << 16 + B_tmp_3[v0, v1] = ((B[v0, v1] & T.uint32(0x0000F000)) >> 12) << 24 + B_tmp_4[v0, v1] = ((B[v0, v1] & T.uint32(0x000F0000)) >> 16) << 4 + B_tmp_5[v0, v1] = ((B[v0, v1] & T.uint32(0x0F000000)) >> 24) << 12 + B[v0, v1] = ( + B_tmp_1[v0, v1] + | B_tmp_2[v0, v1] + | B_tmp_3[v0, v1] + | B_tmp_4[v0, v1] + | B_tmp_5[v0, v1]) + + if target_dtype == "float16" and bits == 2: + return interleave_weight_f16_2b + elif target_dtype == "float16" and bits == 1: + return interleave_weight_f16_1b + elif target_dtype == "int8" and bits == 1: + return interleave_weight_int8_1b + + return interleave_weight + + +# fmt: on + + +def select_implementation( + M: int, + N: int, + datatype: Literal["float16", "int8"] = "float16", + storage_dtype: Literal["int8", "uint8", "int32", "uint32"] = "int32", + dequantize_bits: int = 4, +): + func = tir_interleave_weight( + N=M, + K=N, + bits=dequantize_bits, + target_dtype=datatype, + storage_dtype=storage_dtype, + ) + mod = IRModule() + mod.update_func(GlobalVar("main"), func) + return mod diff --git a/bitblas/ops/matmul.py b/bitblas/ops/matmul.py index 34014abb9..af0370294 100644 --- a/bitblas/ops/matmul.py +++ b/bitblas/ops/matmul.py @@ -148,7 +148,7 @@ def __init__( input_executors = TransformExecutorCPU() if self.ladder_permutate_a is not None: - input_executors.append(self.ladder_permutate_b) + 
input_executors.append(self.ladder_permutate_a) self.input_executors = input_executors @@ -161,17 +161,6 @@ def __init__( if enable_tuning: self.hardware_aware_finetune() - def _build_default_module(self, target: Target): - try: - self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) - except Exception: - self.optimized_func = None - logger.warning( - "[BitBLAS][Warning] Apply default schedule failed, should do hardware-aware optimization manually." - ) - - self._build_runtime_module(target) - def _select_implementation(self): return select_implementation( M=self.M, diff --git a/bitblas/ops/matmul_dequantize.py b/bitblas/ops/matmul_dequantize.py index 7381b3f12..6971547b0 100644 --- a/bitblas/ops/matmul_dequantize.py +++ b/bitblas/ops/matmul_dequantize.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. from bitblas import tvm from tvm.target import Target -from bitblas.base.roller.arch.cuda import CUDA +from bitblas.base.arch.cuda import CUDA from typing import Any, List, Literal, Optional, Tuple, Union from .operator import Operator, TransformKind from .impl.matmul_dequantize_impl import select_implementation @@ -198,17 +198,6 @@ def __init__( if enable_tuning: self.hardware_aware_finetune() - def _build_default_module(self, target: Target): - try: - self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) - except Exception: - self.optimized_func = None - logger.warning( - "[BitBLAS][Warning] Apply default schedule failed, should do hardware-aware optimization manually." - ) - - self._build_runtime_module(target) - def _select_implementation(self): return select_implementation( M=self.M, diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index f3d391778..9c592f9f2 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -6,17 +6,16 @@ from tvm.target import Target from tvm.tir import PrimFunc from tvm.contrib.dlpack import to_pytorch_func -from tvm._ffi.base import _LIB, raise_last_ffi_error -from tvm._ffi._ctypes.types import TVMValue, ArgTypeCode import bitblas import ctypes from typing import List, Dict, Any, Optional import numpy as np from ..base import fast_tune, fast_tune_with_dynamic_range from copy import deepcopy -from bitblas.base.roller.arch import get_arch +from bitblas.base.arch import get_arch from bitblas.utils.tensor_adapter import tvm_tensor_to_torch -from bitblas.wrapper import CUDASourceWrapper, CUDASourceWrapperWithDynamic +from bitblas.builder.wrapper import TIRWrapper +from bitblas.builder.lib_generator import LibraryGenerator from dataclasses import dataclass from enum import IntEnum import logging @@ -30,10 +29,9 @@ class TransformKind(IntEnum): IntraWarpTransform = 2 -@dataclass +@dataclass(frozen=True) class OperatorConfig: """Base class for operator configurations. Used for typing.""" - pass @@ -58,9 +56,8 @@ def __init__(self, name, config: OperatorConfig, target: Target = None): self.num_output_args: int = ( 1 # todo(lei): should be analyzed from the prim_func. 
) - self.wrapper = None - self.src_name = None - self.lib_name = None + self.lib_generator = LibraryGenerator(self.arch) + self.wrapper = TIRWrapper(self.arch) self.lib = None def get_source(self, target: Target = None) -> str: @@ -106,7 +103,7 @@ def tvm_callback_cuda_postproc(code, _): **self.pass_context }): rt_mod = tvm.build(self.optimized_func, target=target, name=self.name) - except Exception: # noqa: F841 + except Exception: # noqa: F841 logger.debug( "Failed to build optimized function for CUDA target with default schedule, Please consider enable hardware aware tuning!" ) @@ -124,18 +121,15 @@ def tvm_callback_cuda_postproc(code, _): self.torch_func = to_pytorch_func(rt_mod) if self.arch.platform == "CUDA": try: - if (self.dynamic_range is not None and len(self.optimized_func.functions) > 1): - wrapper = CUDASourceWrapperWithDynamic(self.optimized_func, - self.get_source(target), self.arch) - else: - wrapper = CUDASourceWrapper(self.optimized_func, self.get_source(target), - self.arch) - wrapper.compile_lib() - self.wrapper = wrapper - self.src_name = self.wrapper.src_name - self.lib_name = self.wrapper.lib_name - self.lib = self.wrapper.load_lib() + is_dynamic = ( + self.dynamic_range is not None and len(self.optimized_func.functions) > 1) + self.wrapper.assign_optimized_module(self.optimized_func) + wrapped_source = self.wrapper.wrap(self.get_source(target), is_dynamic) + self.lib_generator.update_lib_code(wrapped_source) + self.lib_generator.compile_lib() + self.lib = self.lib_generator.load_lib() self.lib.init() + except Exception as e: build_runtime_library_error = e logger.debug( @@ -159,6 +153,17 @@ def apply_default_schedule(self, func_mod: IRModule, target: Target) -> IRModule return optimized_mod return None + def _build_default_module(self, target: Target): + try: + self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) + except Exception: + self.optimized_func = None + logger.warning( + "[BitBLAS][Warning] Apply default schedule failed. Please perform hardware-aware tuning manually." + ) + + self._build_runtime_module(target) + def post_process(self, code: str) -> str: return code @@ -222,8 +227,8 @@ def var_warpper(v): def map_numpy_type(intype): typemap = { - 'e4m3_float8': 'float8_e4m3fn', - 'e5m2_float8': 'float8_e5m2', + "e4m3_float8": "float8_e4m3fn", + "e5m2_float8": "float8_e5m2", } if intype in typemap: return typemap[intype] @@ -266,16 +271,9 @@ def _tensor_adapter(self, tensor, device): else: raise RuntimeError("Not supported type: ", type(tensor)) - def _forward_from_tvm_args(self, *args): - _tvm_args = [self._tensor_adapter(arg, self.arch.device) for arg in args] - self.rt_mod(*_tvm_args) - - def _forward_from_tvm_nd_array(self, *args): - self.rt_mod(*args) - def _forward_from_torch_func(self, *args): - # torch func is not reliable as some datatypes they don't support - # like float8. 
+ # Torch func is not reliable as the runtime overhead dlpack + # is not negaliable, ref to https://discuss.tvm.apache.org/t/strange-overhead-of-tvm-runtime-ndarray-from-dlpack/16516 self.torch_func(*args) return args[-1] @@ -292,39 +290,26 @@ def _forward_from_prebuild_lib(self, *args, stream=0): def call_lib(self, *args, stream=0): self.lib.call(*args, ctypes.c_void_p(stream)) - def _forward_from_tvm_lib_func(self, values): - tcodes = (ctypes.c_int * self.num_args)() - ret_val = TVMValue() - ret_tcode = ctypes.c_int() - for i in range(self.num_args): - tcodes[i] = ArgTypeCode.NDARRAY_HANDLE - if (_LIB.TVMFuncCall( - self.function_handle, - values, - tcodes, - ctypes.c_int(self.num_args), - ctypes.byref(ret_val), - ctypes.byref(ret_tcode), - ) != 0): - raise_last_ffi_error() - def __call__(self, *args: Any) -> Any: return self.forward(*args) def update_func(self, func: PrimFunc): self.prim_func_mod["main"] = func - def update_runtime_module(self, rt_mod, src_name=None, lib_name=None): + def update_runtime_module(self, rt_mod, srcpath=None, libpath=None): self.rt_mod = rt_mod self.time_evaluator = rt_mod.time_evaluator(rt_mod.entry_name, self.arch.device, number=10) self.function_handle = rt_mod.get_function(rt_mod.entry_name).handle self.torch_func = to_pytorch_func(rt_mod) - if src_name is not None: - self.src_name = src_name - if lib_name is not None: - self.lib_name = lib_name - self.lib = ctypes.CDLL(lib_name) + if srcpath is not None: + assert self.lib_generator is not None, "lib_generator is not initialized" + self.lib_generator.set_src_path(srcpath) + if libpath is not None: + assert self.lib_generator is not None, "lib_generator is not initialized" + self.lib_generator.set_lib_path(libpath) + self.lib = ctypes.CDLL(libpath) self.lib.init() + # TODO: update the lib code from srcpath @abstractmethod def _select_implementation(self) -> IRModule: @@ -334,6 +319,18 @@ def _select_implementation(self) -> IRModule: def prim_func(self): return self.prim_func_mod["main"] + @property + def srcpath(self): + return self.lib_generator.get_source_path() + + @property + def libpath(self): + return self.lib_generator.get_lib_path() + + @property + def wrapped_source(self): + return self.lib_generator.lib_code + class OPExecutorCPU: """ diff --git a/bitblas/utils/__init__.py b/bitblas/utils/__init__.py index 00bddc2a5..bdf9589f7 100644 --- a/bitblas/utils/__init__.py +++ b/bitblas/utils/__init__.py @@ -3,3 +3,4 @@ from .post_process import match_global_kernel, tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 # noqa: F401 from .tensor_adapter import tvm_tensor_to_torch, lazy_tvm_tensor_to_torch, lazy_torch_to_tvm_tensor # noqa: F401 from .target_detector import get_all_nvidia_targets, auto_detect_nvidia_target # noqa: F401 +from .rtmod_analysis import get_annotated_device_mod # noqa: F401 diff --git a/bitblas/utils/rtmod_analysis.py b/bitblas/utils/rtmod_analysis.py new file mode 100644 index 000000000..69a08dfdc --- /dev/null +++ b/bitblas/utils/rtmod_analysis.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from bitblas import tvm +from tvm import IRModule +from tvm.runtime import ndarray +from tvm.driver import lower +from tvm.target import Target +from typing import Tuple, List + + +def get_annotated_device_mod(mod: IRModule, target: Target) -> "IRModule": + """ + Lower the given IRModule and create a device module for the specified target. + + Parameters: + - mod: The input IRModule. 
+ - target: The compilation target. + + Returns: + - A device module ready for execution. + """ + input_mod = lower(mod) + target_input_mod = {target: input_mod} + annotated_mods = {} + runtime = None + target_host = None + for tgt, mod in target_input_mod.items(): + if not isinstance(tgt, (str, Target)): + raise ValueError("The key of inputs must be str or " + "Target when inputs is dict.") + if not isinstance(mod, tvm.IRModule): + raise ValueError("inputs must be Schedule, IRModule, " + "or dict of str to IRModule.") + annotated_mods[tgt] = mod.with_attr("runtime", runtime) + annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) + if not target_host: + for tar, _ in annotated_mods.items(): + device_type = ndarray.device(tar.kind.name, 0).device_type + if device_type == ndarray.cpu(0).device_type: + target_host = tar + break + if not target_host: + target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" + annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) + for target, mod in annotated_mods.items(): + mixed_mod_passes = tvm.get_global_func("driver.mixed_mod_passes") + device_mod_passes = tvm.get_global_func("driver.device_mod_passes") + mod = mixed_mod_passes(mod, target)(mod) + device_mod = device_mod_passes(mod, target)(mod) + return device_mod + + +def get_thread_block_information(mod: IRModule) -> Tuple[List[int], List[int]]: + """ + Extracts the thread block and grid dimensions for the reduction block within a given IRModule. + + Parameters: + - mod: The input IRModule from which to extract thread block and grid information. + + Returns: + A tuple containing two lists: + - The first list contains the dimensions of the thread block (threadIdx.x, threadIdx.y, threadIdx.z). + - The second list contains the dimensions of the grid (blockIdx.x, blockIdx.y, blockIdx.z). 
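+
+    Note: only loops bound to threadIdx.* / blockIdx.* contribute; axes that are never bound keep the default extent of 1.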
+ """ + + # Initialize the schedule from the IRModule + sch = tvm.tir.Schedule(mod) + + # Get the root block and its child blocks + root_block = sch.get_block("root") + child_blocks = sch.get_child_blocks(root_block) + + # Initialize default block and grid dimensions (1, 1, 1) + block_dims, grid_dims = [1, 1, 1], [1, 1, 1] + + for block in child_blocks: + # Get the loops surrounding the main block + loops = sch.get_loops(block) + + # Iterate over each loop to extract thread and block bindings + for loop in loops: + stmt = sch.get(loop) + thread_binding = stmt.thread_binding + extent = int(stmt.extent) + + # Skip loops without thread binding + if thread_binding: + if "threadIdx" in thread_binding.thread_tag: + block_dims["xyz".index(thread_binding.thread_tag[-1])] = extent + elif "blockIdx" in thread_binding.thread_tag: + grid_dims["xyz".index(thread_binding.thread_tag[-1])] = extent + + return block_dims, grid_dims diff --git a/bitblas/wrapper/general.py b/bitblas/wrapper/general.py index 1271329f1..aa76f6158 100644 --- a/bitblas/wrapper/general.py +++ b/bitblas/wrapper/general.py @@ -131,23 +131,23 @@ def __init__(self, optimized_mod: IRModule, source: str, arch: TileDevice): self.block_info: Union[List[int], Dict] = [1, 1, 1] self.grid_info: Union[List[int], Dict] = [1, 1, 1] self.parse_source_information() - self.src_name: Optional[str] = None - self.lib_name: Optional[str] = None + self.srcpath: Optional[str] = None + self.libpath: Optional[str] = None self.lib_code: Optional[str] = self.update_lib_code(source) def load_lib(self): - return ctypes.CDLL(self.lib_name) + return ctypes.CDLL(self.libpath) def remove_lib(self): - if self.lib_name: - os.remove(self.lib_name) - self.lib_name = None + if self.libpath: + os.remove(self.libpath) + self.libpath = None def compile_lib(self, timeout: float = None): arch = self.arch src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) compute_version = arch.compute_capability - lib_name = src.name.replace(".cu", ".so") + libpath = src.name.replace(".cu", ".so") command = [ "nvcc", @@ -162,7 +162,7 @@ def compile_lib(self, timeout: float = None): "-lcuda", f"-gencode=arch=compute_{compute_version},code=compute_{compute_version}", "-o", - lib_name, + libpath, ] src.write(self.lib_code) src.flush() @@ -174,8 +174,8 @@ def compile_lib(self, timeout: float = None): if ret.returncode != 0: logger.warning(f"Compilation Failed! {command}") return None - self.src_name = src.name - self.lib_name = lib_name + self.srcpath = src.name + self.libpath = libpath def parse_source_information(self): device_mod = get_annotated_device_mod(self.mod, self.arch.target) diff --git a/testing/cpp/CMakeLists.txt b/testing/cpp/CMakeLists.txt index cf8eb0d3a..b92fa8da7 100644 --- a/testing/cpp/CMakeLists.txt +++ b/testing/cpp/CMakeLists.txt @@ -12,4 +12,5 @@ find_package(GTest REQUIRED) include_directories(${GTEST_INCLUDE_DIRS}) +add_subdirectory(efficient_i4_cuda_impl) add_subdirectory(lop3_type_conversion) diff --git a/testing/cpp/efficient_i4_cuda_impl/CMakeLists.txt b/testing/cpp/efficient_i4_cuda_impl/CMakeLists.txt new file mode 100644 index 000000000..36ffdf548 --- /dev/null +++ b/testing/cpp/efficient_i4_cuda_impl/CMakeLists.txt @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+function (ADD_CUDA_TEST_EXECUTABLE name) + add_executable(${name} ${name}.cu) + set_target_properties(${name} PROPERTIES CUDA_ARCHITECTURES 80) + # add flags + target_compile_options(${name} PRIVATE --expt-relaxed-constexpr) + set_target_properties(${name} PROPERTIES + CUDA_SEPARABLE_COMPILATION ON) + target_link_libraries(${name} gtest gtest_main) +endfunction(ADD_CUDA_TEST_EXECUTABLE) + +ADD_CUDA_TEST_EXECUTABLE(efficient_i4) + +function (ADD_CPP_TEST_EXECUTABLE name) + add_executable(${name} ${name}.cpp) + target_link_libraries(${name} gtest gtest_main pthread) +endfunction(ADD_CPP_TEST_EXECUTABLE) + +ADD_CPP_TEST_EXECUTABLE(param_permutate) diff --git a/testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu b/testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu new file mode 100644 index 000000000..257f49a31 --- /dev/null +++ b/testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu @@ -0,0 +1,391 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#include +#include +#include +#include +#include "i4matmul.hpp" + +#define cudaCheckLastError(ans) \ + { \ + gpuAssert((ans), __FILE__, __LINE__); \ + } +inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) +{ + if (code != cudaSuccess) + { + fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); + if (abort) + exit(code); + } +} + +void general_compress(const int8_t *lowbit, int8_t *compressed, const int nbit, const int N, bool isSigned = false) +{ + int zero_point = isSigned ? ((1 << (nbit - 1)) - 1) : 0; + const int nbit_per_byte = 8 / nbit; + + for (int i = 0; i < N / nbit_per_byte; i++) + { + compressed[i] = 0; + for (int j = 0; j < nbit_per_byte; j++) + { + compressed[i] |= ((lowbit[nbit_per_byte * i + j] + zero_point) << (nbit * j)); + } + } +} + + +// Helper function to interleave the perm array +std::vector interleave_perms(const std::vector& perm) { + std::vector interleaved_perm; + std::array interleave = {0, 2, 4, 6, 1, 3, 5, 7}; + + int num_rows = perm.size() / 8; + for (int i = 0; i < num_rows; ++i) { + std::array row; + std::copy(perm.begin() + i * 8, perm.begin() + (i + 1) * 8, row.begin()); + for (int j : interleave) { + interleaved_perm.push_back(row[j]); + } + } + + return interleaved_perm; +} + + +std::tuple, std::vector, std::vector> get_perms() { + std::vector perm; + + for (int i = 0; i < 32; ++i) { + std::vector perm1; + int col = i / 4; + for (int block : {0, 1}) { + for (int row : { + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1 + }) { + perm1.push_back(16 * row + col + 8 * block); + } + } + for (int j = 0; j < 4; ++j) { + for (int p : perm1) { + perm.push_back(p + 256 * j); + } + } + } + + // Interleave the perm array + perm = interleave_perms(perm); + + std::vector scale_perm; + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + scale_perm.push_back(i + 8 * j); + } + } + + std::vector scale_perm_single; + for (int i = 0; i < 4; ++i) { + for (int j : {0, 1, 8, 9, 16, 17, 24, 25}) { + scale_perm_single.push_back(2 * i + j); + } + } + + return std::make_tuple(perm, scale_perm, scale_perm_single); +} + +void weight_pre_process(const int8_t *lowbit, int8_t *compressed, const int nbit, const int K, const int N) +{ + int8_t* tmp1 = new int8_t[K * N]; + const int maxq = 15; + auto [perm, scale_perm, scale_perm_single] = get_perms(); + const int tile_size = 16; + // transform the lowbit matrix to the compressed matrix + for (int i = 0; i < (K / tile_size); i += 1) + { + for (int j = 0; j < (N / tile_size); 
j += 1) + { + for (int k = 0; k < tile_size; k++) + { + for (int l = 0; l < tile_size; l++) + { + int idx_target = i * N * tile_size + j * tile_size * tile_size + k * tile_size + l; + int idx_source = (i * tile_size + k) * N + j * tile_size + l; + tmp1[idx_target] = lowbit[idx_source] + (maxq + 1) / 2; + } + } + } + } + // print the first 10 of tmp2 + printf("tmp1\n"); + for (int i = 0; i < 10; i++) + { + printf("%d ", tmp1[i]); + } + printf(" ... "); + for (int i = K * N - 10; i < K * N; i++) + { + printf("%d ", tmp1[i]); + } + printf("\n"); + // permute the matrix + int32_t* tmp2 = new int32_t[K * N]; + const int perm_size = perm.size(); + for (int i = 0; i < (N * K / perm_size); i++) + { + for (int j = 0; j < perm_size; j++) + { + int idx_target = i * perm_size + j; + int idx_source = i * perm_size + perm[j]; + tmp2[idx_target] = (int32_t)tmp1[idx_source]; + } + } + // print the first 10 of tmp2 + printf("tmp2\n"); + for (int i = 0; i < 10; i++) + { + printf("%d ", tmp2[i]); + } + printf(" ... "); + for (int i = K * N / (32 / nbit) - 10; i < K * N / (32 / nbit); i++) + { + printf("%d ", tmp2[i]); + } + printf("\n"); + // compress + int32_t* tmp3 = new int32_t[K * N / (32 / nbit)]; + // set zero + for (int i = 0; i < K * N / (32 / nbit); i++) + { + tmp3[i] = 0; + } + for (int i = 0; i < (K / tile_size); i++) + { + for (int j = 0; j < (N * tile_size / 8); j++) + { + for (int k = 0; k < 8; k++) + { + int idx_target = i * N * tile_size / 8 + j; + int idx_source = i * N * tile_size + j * 8 + k; + tmp3[idx_target] |= (tmp2[idx_source] << (nbit * (k % 8))); + } + } + } + // print the first 10 of tmp3 + printf("tmp3\n"); + for (int i = 0; i < 10; i++) + { + printf("%d ", tmp3[i]); + } + printf(" ... "); + for (int i = K * N / (32 / nbit) - 10; i < K * N / (32 / nbit); i++) + { + printf("%d ", tmp3[i]); + } + printf("\n"); + // copy tmp3 to compressed + for (int i = 0; i < K * N / (32 / nbit); i++) + { + ((int32_t *)(compressed))[i] = tmp3[i]; + } +} + +void scale_pre_process(const half *scale, half *scale_perm, const int K, const int N, int group_size) +{ + auto [perm, scale_perm_group, scale_perm_single] = get_perms(); + if (group_size == -1) + group_size = K; + if (group_size == K){ + const int perm_size = scale_perm_single.size(); + for (int i = 0; i < (N * K / group_size / perm_size); i++) + { + for (int j = 0; j < perm_size; j++) + { + int idx_target = i * perm_size + j; + int idx_source = i * perm_size + scale_perm_single[j]; + if (idx_target < 10){ + printf("idx_target = %d, idx_source = %d\n", idx_target, idx_source); + } + scale_perm[idx_target] = scale[idx_source]; + } + } + } + else{ + const int perm_size = scale_perm_group.size(); + for (int i = 0; i < (N * K / group_size / perm_size); i++) + { + for (int j = 0; j < perm_size; j++) + { + int idx_target = i * perm_size + j; + int idx_source = i * perm_size + scale_perm_group[j]; + scale_perm[idx_target] = scale[idx_source]; + } + } + } + // print the first 10 of tmp2 + printf("scale_perm\n"); + for (int i = 0; i < 10; i++) + { + printf("%f ", (float)scale_perm[i]); + } + printf(" ... 
"); + for (int i = K * N / group_size - 10; i < K * N / group_size; i++) + { + printf("%f ", (float)scale_perm[i]); + } +} + +TEST(EfficientI4MatmulTest, GEMVTest) +{ + const int prom_m = 1; + const int prom_n = 256; + const int prom_k = 256; + const int bits = 4; + const int group_size = prom_k; + + half* A = new half[prom_m * prom_k]; + int8_t* B = new int8_t[prom_k * prom_n]; + int8_t* qB_interleave = new int8_t[prom_k * prom_n / (8 / bits)]; + half* C = new half[prom_m * prom_n]; + half* s = new half[prom_n * (prom_k / group_size)]; + half* s_perm = new half[prom_n * (prom_k / group_size)]; + + // Initialize A and B + for (int i = 0; i < prom_m * prom_k; i++) + { + A[i] = __float2half(rand() / (float)RAND_MAX); + } + for (int i = 0; i < prom_k * prom_n; i++) + { + B[i] = rand() % 4 - 2; + } + for (int i = 0; i < prom_k * prom_n / group_size; i++) + { + // s[i] = __float2half(0.1); + s[i] = __float2half(rand() / (float)RAND_MAX); + } + + weight_pre_process(B, qB_interleave, bits, prom_k, prom_n); + // print the first 10 elements and last 10 elements of C + for (int i = 0; i < 10; i++) + { + printf("%d ", B[i]); + } + printf(" ... "); + for (int i = prom_k * prom_n - 10; i < prom_k * prom_n; i++) + { + printf("%d ", B[i]); + } + // print interleave of B + for (int i = 0; i < 10; i++) + { + printf("%d ", qB_interleave[i]); + } + printf(" ... "); + for (int i = prom_k * prom_n / (8 / bits) - 10; i < prom_k * prom_n / (8 / bits); i++) + { + printf("%d ", qB_interleave[i]); + } + printf("\n"); + // print last 10 of qb_interleave + for (int i = prom_k * prom_n / (8 / bits) - 10; i < prom_k * prom_n / (8 / bits); i++) + { + printf("%d ", qB_interleave[i]); + } + printf("\n"); + // print last 10 of B + for (int i = prom_k * prom_n - 10; i < prom_k * prom_n; i++) + { + printf("%d ", B[i]); + } + printf("\n"); + // print last 10 of s + for (int i = prom_n * (prom_k / group_size) - 10; i < prom_n * (prom_k / group_size); i++) + { + printf("%f ", __half2float(s[i])); + } + printf("\n"); + scale_pre_process(s, s_perm, prom_k, prom_n, group_size); + // define cuda variables + float* d_workspace = nullptr; + cudaCheckLastError(cudaMalloc((void**)&d_workspace, prom_n * prom_k * 16 * sizeof(float))); + + half* d_A; + int8_t* d_qB; + half* d_C; + half* d_s; + cudaCheckLastError(cudaMalloc((void**)&d_A, prom_m * prom_k * sizeof(half))); + cudaCheckLastError(cudaMalloc((void**)&d_qB, prom_k * prom_n / (8 / bits) * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void**)&d_C, prom_m * prom_n * sizeof(half))); + cudaCheckLastError(cudaMalloc((void**)&d_s, prom_n * (prom_k / group_size) * sizeof(half))); + // copy A and B to device + cudaCheckLastError(cudaMemcpy(d_A, A, prom_m * prom_k * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(d_qB, qB_interleave, prom_n * prom_k / (8 / bits) * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(d_s, s_perm, prom_n * (prom_k / group_size) * sizeof(half), cudaMemcpyHostToDevice)); + + // allocate workspace + // call the kernel + int ret = marlin_cuda(d_A, d_qB, d_C, d_s, prom_m, prom_n, prom_k, d_workspace, group_size == prom_k? -1: group_size); + printf("ret = %d\n", ret); + + // copy C back to host + cudaCheckLastError(cudaMemcpy(C, d_C, prom_m * prom_n * sizeof(half), cudaMemcpyDeviceToHost)); + // print the first 10 elements and last 10 elements of C + for (int i = 0; i < 10; i++) + { + printf("%f ", __half2float(C[i])); + } + printf(" ... 
"); + for (int i = prom_m * prom_n - 10; i < prom_m * prom_n; i++) + { + printf("%f ", __half2float(C[i])); + } + printf("\n"); + + // ref calculation + float* ref_C = new float[prom_m * prom_n]; + // zero fill + for (int i = 0; i < prom_m * prom_n; i++) + { + ref_C[i] = __float2half(0.0); + } + // + for (int i = 0; i < prom_m; i++) + { + for (int j = 0; j < prom_n; j++) + { + ref_C[i * prom_n + j] = __float2half(0.0); + for (int k = 0; k < prom_k; k++) + { + ref_C[i * prom_n + j] += float(A[i * prom_k + k]) * (float(B[k * prom_n + j]) * float(s[(k / group_size) * prom_n + j])); + } + } + } + for (int i = 0; i < 10; i++) + { + printf("%f ", __half2float(ref_C[i])); + } + printf(" ... "); + for (int i = prom_m * prom_n - 10; i < prom_m * prom_n; i++) + { + printf("%f ", __half2float(ref_C[i])); + } + printf("\n"); + + // check the result + for (int i = 0; i < prom_m * prom_n; i++) + { + EXPECT_NEAR(__half2float(C[i]), __half2float(ref_C[i]), 1e-1); + } + + // free memory + delete[] A; + delete[] B; + delete[] C; + cudaCheckLastError(cudaFree(d_A)); + cudaCheckLastError(cudaFree(d_qB)); + cudaCheckLastError(cudaFree(d_C)); +} diff --git a/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp b/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp new file mode 100644 index 000000000..a12a57dcd --- /dev/null +++ b/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp @@ -0,0 +1,826 @@ +// Copyright 2018 The apache/tvm Authors. All Rights Reserved. +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// +// Modifications Copyright (c) Microsoft. +// The code below is mostly copied from marlin_cuda in IST-DASLab/marlin. + +#ifndef MARLIN_CUDA_KERNEL_CUH +#define MARLIN_CUDA_KERNEL_CUH + + +#include +#include +#include +#include + + +constexpr int ceildiv(int a, int b) { + return (a + b - 1) / b; +} + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed for instance as inputs to tensor core +// operations. Consequently, all corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { + return elems[i]; + } +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales + +// Predicated asynchronous global->shared copy; used for inputs A where we apply predication to handle batchsizes that +// are not multiples of 16. 
+__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" :: "r"((int) pred), "r"(smem), "l"(glob_ptr), "n"(BYTES) + ); +} + +// Asynchronous global->shared copy with a cache hint indicating that the values may be evicted immediately; used for +// quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need +// for inputs A and outputs C. +__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .b64 p;\n" + " createpolicy.fractional.L2::evict_first.b64 p, 1.0;" + " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n" + "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES) + ); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" :: "n"(n)); +} + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 output/accumulation. +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]) + ); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem) + ); +} + +// Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to +// automatically recognize it in all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut) + ); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values. +// We mostly follow the strategy in the link below, with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`. 
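+  // Worked example (fp16 bit patterns): EX = 0x6400 is 1024.0, so `lo` holds 1024 + q_lo and subtracting
+  // SUB = 0x6408 (1032.0 = 1024 + 8) yields q_lo - 8; the high nibble is picked up at 16x its value, so the
+  // FMA with MUL = 0x2c00 (1/16) and ADD = 0xd480 (-72 = -(1024/16 + 8)) yields q_hi - 8.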
+ const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + FragB frag_b; + frag_b[0] = __hsub2( + *reinterpret_cast(&lo), + *reinterpret_cast(&SUB) + ); + frag_b[1] = __hfma2( + *reinterpret_cast(&hi), + *reinterpret_cast(&MUL), *reinterpret_cast(&ADD) + ); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used only for grouped quantization. +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible globally. + asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible globally, while releasing the barrier. + asm volatile ("fence.acq_rel.gpu;\n"); + asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); + } +} + + +template < + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m dimension (batchsize) of the threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const int stages, // number of stages for the async global->shared fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks with a separate quantization scale +> +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the same size, which might involve multiple + // column "slices" (of width 16 * `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it ensures good utilization of all SMs + // for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as + // possible. 
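+  // Concretely, in the 3x3-tile, 5-SM example above each threadblock owns ceildiv(3 * 3, 5) = 2 consecutive tiles of
+  // that column-major tile order (ignoring the group-size rounding applied below), which gives exactly the numbering shown.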
+ + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + // Ensure that the number of tiles in each stripe is a multiple of the groupsize; this avoids an annoying special case + // where a stripe starts in the middle of group. + if (group_blocks != -1) + iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks)); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to top + + // We can easily implement parallel problem execution by just remapping indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for synchronization. + auto init_slice = [&] () { + slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) + slice_iters = 0; + if (slice_iters == 0) + return; + if (slice_row + slice_iters > k_tiles) + slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) + slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) + slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + // We typically use `constexpr` to indicate that this value is a compile-time constant + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; // delta between subsequent A tiles in global memory + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // between shared memory tile reads + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // within a shared memory tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); // overall size of a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); // 
number of shared write iterations for a tile + + int b_gl_stride = 16 * prob_n / 32; + constexpr int b_sh_stride = 32 * thread_n_blocks / 4; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); + constexpr int b_sh_wr_delta = threads; + constexpr int b_sh_rd_delta = threads; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_sh_stage = s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x; + int b_sh_rd = threadIdx.x; + + int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; + int s_sh_wr = threadIdx.x; + int s_sh_rd; + // We use a different scale layout for grouped and column-wise quantization as we scale a `half2` tile in column-major + // layout in the former and in row-major in the latter case. + if (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + + // Precompute which thread should not read memory in which iterations; this is needed if there are more threads than + // required for a certain tilesize or when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // To ensure that writing and reading A tiles to/from shared memory, the latter in fragment format, is fully bank + // conflict free, we need to use a rather fancy XOR-based layout. The key here is that neither reads nor writes of + // the 16-byte `int4` blocks of 8 consecutive threads involve the same shared memory banks. Further, it seems (based + // on NSight-Compute) that each warp must also write a consecutive memory segment? + auto transform_a = [&] (int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main loop unrolls, all shared memory + // accesses are static, we simply precompute both transformed reads and writes. 
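+  // The precomputed tables below are small (a_sh_wr_iters and b_sh_wr_iters * thread_m_blocks entries) and, since all
+  // indexing into them is fully unrolled to compile-time constants, they can live entirely in registers.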
+ int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependencies between + // subsequent accesses within a tile by maintaining multiple pointers (we have enough registers), a tiny optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s = sh_b + (stages * b_sh_stage); + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; + + // Zero accumulators. + auto zero_accums = [&] () { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast<float*>(frag_c)[i] = 0; + }; + + // Asynchronously fetch the next A, B and s tile from global to the next shared memory pipeline location. + auto fetch_to_shared = [&] (int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i] + ); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + B_ptr[i] += b_gl_rd_delta_o; + } + // Only fetch scales if this tile starts a new group + if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) + cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + s_gl_rd += s_gl_rd_delta; + } + } + // Insert a fence even when we are winding down the pipeline to ensure that waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&] () { + // We only have `stages - 2` active fetches since we are double buffering and can only issue the next fetch when + // it is guaranteed that the previous shared memory load is fully complete (as it may otherwise be overwritten). + cp_async_wait<stages - 2>(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe into the current register buffer. + auto fetch_to_registers = [&] (int k, int pipe) { + // It may seem inefficient that we reload the groups for every sub-tile; however, this does not seem to be a + // significant bottleneck, while some theoretically better attempts have led to bad instruction ordering by the + // compiler and correspondingly a noticeable drop in performance.
+ if (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast<I4*>(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&] (int k) { + // We have the m dimension as the inner loop in order to encourage overlapping dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant = frag_b_quant[k % 2][j]; + int b_quant_shift = b_quant >> 8; + FragB frag_b0 = dequant(b_quant); + // If there are no groups, we can just scale the final output once and can avoid doing so for each weight. + if (group_blocks != -1) + scale(frag_b0, frag_s[k % 2][j], 0); + FragB frag_b1 = dequant(b_quant_shift); + if (group_blocks != -1) + scale(frag_b1, frag_s[k % 2][j], 1); + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the number of warps while keeping the n + // dimension of a tile reasonable, we have multiple warps that accumulate their partial sums of the same output + // location, which we have to reduce over in the end. We do this in shared memory. + auto thread_block_reduce = [&] () { + constexpr int red_off = threads / b_sh_stride / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride; + constexpr int red_sh_stride = b_sh_stride * 4 * 2; + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any unnecessary read or write iterations, + // e.g., for two warps we write only once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over + // the results. As the striped partitioning minimizes the number of such reductions and our outputs are usually rather + // small, we perform this reduction serially in L2 cache.
+  auto global_reduce = [&] (bool first = false, bool last = false) {
+    // We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step.
+    // To do this, we write out results in FP16 (but still reduce with FP32 compute).
+    constexpr int active_threads = 32 * thread_n_blocks / 4;
+    if (threadIdx.x < active_threads) {
+      int c_gl_stride = prob_n / 8;
+      int c_gl_wr_delta_o = 8 * c_gl_stride;
+      int c_gl_wr_delta_i = 4 * (active_threads / 32);
+      int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4;
+      c_gl_wr += (2 * thread_n_blocks) * slice_col;
+      constexpr int c_sh_wr_delta = active_threads;
+      int c_sh_wr = threadIdx.x;
+
+      int row = (threadIdx.x % 32) / 4;
+
+      if (!first) {
+        // Interestingly, doing direct global accesses here really seems to mess up the compiler and lead to slowdowns,
+        // hence we also use async-copies even though these fetches are not actually asynchronous.
+        #pragma unroll
+        for (int i = 0; i < thread_m_blocks * 4; i++) {
+          cp_async4_pred(
+            &sh[c_sh_wr + c_sh_wr_delta * i],
+            &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)],
+            i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m
+          );
+        }
+        cp_async_fence();
+        cp_async_wait<0>();
+      }
+
+      #pragma unroll
+      for (int i = 0; i < thread_m_blocks * 4; i++) {
+        if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
+          if (!first) {
+            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
+            #pragma unroll
+            for (int j = 0; j < 2 * 4; j++) {
+              reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += __half2float(
+                reinterpret_cast<__half*>(&c_red)[j]
+              );
+            }
+          }
+          if (!last) {
+            int4 c;
+            #pragma unroll
+            for (int j = 0; j < 2 * 4; j++) {
+              reinterpret_cast<__half*>(&c)[j] = __float2half(
+                reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]
+              );
+            }
+            C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c;
+          }
+        }
+      }
+    }
+  };
+
+  // Write out the reduced final result in the correct layout. We only actually reshuffle matrix fragments in this
+  // step; the reduction above is performed in fragment layout.
+  auto write_result = [&] () {
+    int c_gl_stride = prob_n / 8;
+    constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
+    int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
+    constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks));
+
+    int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));
+    c_gl_wr += (2 * thread_n_blocks) * slice_col;
+    int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
+    c_sh_wr += 32 * (threadIdx.x / 32);
+    int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));
+
+    int c_gl_wr_end = c_gl_stride * prob_m;
+
+    // We first reorder in shared memory to guarantee the most efficient final global write patterns.
+    auto write = [&] (int idx, float c0, float c1, FragS& s) {
+      half2 res = __halves2half2(__float2half(c0), __float2half(c1));
+      if (group_blocks == -1) // for per-column quantization we finally apply the scale here
+        res = __hmul2(res, s[0]);
+      ((half2*) sh)[idx] = res;
+    };
+    if (threadIdx.x / 32 < thread_n_blocks / 4) {
+      #pragma unroll
+      for (int i = 0; i < thread_m_blocks; i++) {
+        #pragma unroll
+        for (int j = 0; j < 4; j++) {
+          int wr = c_sh_wr + 8 * j;
+          write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
+          write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
+          write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
+          write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
+        }
+        c_sh_wr += 16 * (4 * c_sh_stride);
+      }
+    }
+    __syncthreads();
+
+    #pragma unroll
+    for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) {
+      if (c_gl_wr < c_gl_wr_end) {
+        C[c_gl_wr] = sh[c_sh_rd];
+        c_gl_wr += c_gl_wr_delta;
+        c_sh_rd += c_sh_rd_delta;
+      }
+    }
+  };
+
+  // Start global fetch and register load pipelines.
+  auto start_pipes = [&] () {
+    #pragma unroll
+    for (int i = 0; i < stages - 1; i++)
+      fetch_to_shared(i, i, i < slice_iters);
+    zero_accums();
+    wait_for_stage();
+    fetch_to_registers(0, 0);
+    a_gl_rd += a_gl_rd_delta_o * (stages - 1);
+  };
+  start_pipes();
+
+  // Main loop.
+  while (slice_iters) {
+    // We unroll over both the global fetch and the register load pipeline to ensure all shared memory accesses are
+    // static. Note that both pipelines have even length meaning that the next iteration will always start at index 0.
+    #pragma unroll
+    for (int pipe = 0; pipe < stages;) {
+      #pragma unroll
+      for (int k = 0; k < b_sh_wr_iters; k++) {
+        fetch_to_registers(k + 1, pipe % stages);
+        if (k == b_sh_wr_iters - 2) {
+          fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages);
+          pipe++;
+          wait_for_stage();
+        }
+        matmul(k);
+      }
+      slice_iters--;
+      if (slice_iters == 0)
+        break;
+    }
+    a_gl_rd += a_gl_rd_delta_o * stages;
+
+    // Process results and, if necessary, proceed to the next column slice. While this pattern may not be the most
+    // readable, other ways of writing the loop seemed to noticeably worsen performance after compilation.
+    if (slice_iters == 0) {
+      cp_async_wait<0>();
+      bool last = slice_idx == slice_count - 1;
+      // For per-column scales, we only fetch them here in the final step before write-out
+      if (group_blocks == -1 && last) {
+        if (s_sh_wr_pred)
+          cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
+        cp_async_fence();
+      }
+      thread_block_reduce();
+      if (group_blocks == -1 && last) {
+        cp_async_wait<0>();
+        __syncthreads();
+        if (threadIdx.x / 32 < thread_n_blocks / 4) {
+          reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
+          reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
+        }
+      }
+      if (slice_count > 1) { // only globally reduce if there is more than one block in a slice
+        barrier_acquire(&locks[slice_col], slice_idx);
+        global_reduce(slice_idx == 0, last);
+        barrier_release(&locks[slice_col], last);
+      }
+      if (last) // only the last block in a slice actually writes the result
+        write_result();
+      slice_row = 0;
+      slice_col_par++;
+      slice_col++;
+      init_slice();
+      if (slice_iters) {
+        a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
+        #pragma unroll
+        for (int i = 0; i < b_sh_wr_iters; i++)
+          B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
+        if (slice_col == 0) {
+          #pragma unroll
+          for (int i = 0; i < b_sh_wr_iters; i++)
+            B_ptr[i] -= b_gl_stride;
+        }
+        s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
+        start_pipes();
+      }
+    }
+  }
+}
+
+
+// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per scheduler allows some more
+// latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles.
+const int THREADS = 256;
+const int STAGES = 4; // 4 pipeline stages fit into shared memory
+const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (lower than on 8.0)
+
+#define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \
+  else if ( \
+    thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && \
+    group_blocks == GROUP_BLOCKS \
+  ) { \
+    cudaFuncSetAttribute( \
+      Marlin<THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
+      cudaFuncAttributeMaxDynamicSharedMemorySize, \
+      SHARED_MEM \
+    ); \
+    Marlin< \
+      THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS \
+    ><<<blocks, THREADS, SHARED_MEM, stream>>>( \
+      A_ptr, B_ptr, C_ptr, s_ptr, \
+      prob_m, prob_n, prob_k, \
+      locks \
+    ); \
+  }
+
+const int ERR_PROB_SHAPE = 1;
+const int ERR_KERN_SHAPE = 2;
+
+int marlin_cuda(
+  const void* A,
+  const void* B,
+  void* C,
+  void* s,
+  int prob_m,
+  int prob_n,
+  int prob_k,
+  void* workspace,
+  int groupsize = -1,
+  int dev = 0,
+  cudaStream_t stream = 0,
+  int thread_k = -1,
+  int thread_n = -1,
+  int sms = -1,
+  int max_par = 16
+) {
+  int tot_m = prob_m;
+  int tot_m_blocks = ceildiv(tot_m, 16);
+  int pad = 16 * tot_m_blocks - tot_m;
+
+  if (sms == -1)
+    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
+  if (thread_k == -1 || thread_n == -1) {
+    if (prob_m <= 16) {
+      // For small batch sizes, better partitioning is slightly more important than better compute utilization
+      thread_k = 128;
+      thread_n = 128;
+    } else {
+      thread_k = 64;
+      thread_n = 256;
+    }
+  }
+
+  int thread_k_blocks = thread_k / 16;
+  int thread_n_blocks = thread_n / 16;
+  int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
+  int blocks = sms;
+
+  if (prob_n % thread_n != 0 || prob_k % thread_k != 0 || (group_blocks != -1 && prob_k % group_blocks != 0))
+    return ERR_PROB_SHAPE;
+  if (prob_m == 0 || prob_n == 0 || prob_k == 0)
+    return 0;
+
+  const int4* A_ptr = (const int4*) A;
+  const int4* B_ptr = (const int4*) B;
+  int4* C_ptr = (int4*) C;
+  const int4* s_ptr = (const int4*) s;
+
+  int cols = prob_n / thread_n;
+  int* locks = (int*) workspace;
+
+  int ret = 0;
+  for (int i = 0; i < tot_m_blocks; i += 4) {
+    int thread_m_blocks = tot_m_blocks - i;
+    prob_m = tot_m - 16 * i;
+    int par = 1;
+    if (thread_m_blocks > 4) {
+      // Note that parallel > 1 currently only works for inputs without any padding
+      par = (16 * thread_m_blocks - pad) / 64;
+      if (par > max_par)
+        par = max_par;
+      prob_m = 64 * par;
+      i += 4 * (par - 1);
+      thread_m_blocks = 4;
+    }
+
+    // For compilation speed, we only define the kernel configurations that have seemed useful (in terms of
+    // performance) in our testing; many more are possible in principle.
+    if (false) {}
+    CALL_IF(1, 8, 8, -1)
+    CALL_IF(1, 8, 8, 8)
+    CALL_IF(1, 16, 4, -1)
+    CALL_IF(1, 16, 4, 8)
+    CALL_IF(2, 16, 4, -1)
+    CALL_IF(2, 16, 4, 8)
+    CALL_IF(3, 16, 4, -1)
+    CALL_IF(3, 16, 4, 8)
+    CALL_IF(4, 16, 4, -1)
+    CALL_IF(4, 16, 4, 8)
+    else
+      ret = ERR_KERN_SHAPE;
+
+    A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
+    C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
+  }
+
+  return ret;
+}
+
+
+#endif
diff --git a/testing/cpp/efficient_i4_cuda_impl/param_permutate.cpp b/testing/cpp/efficient_i4_cuda_impl/param_permutate.cpp
new file mode 100644
index 000000000..64248b3d1
--- /dev/null
+++ b/testing/cpp/efficient_i4_cuda_impl/param_permutate.cpp
@@ -0,0 +1,89 @@
+#include <gtest/gtest.h>
+#include <tuple>
+#include <vector>
+#include <array>
+#include <algorithm>
+#include <iostream>
+
+// Helper function to interleave the perm array
+std::vector<int> interleave_perms(const std::vector<int>& perm) {
+  std::vector<int> interleaved_perm;
+  std::array<int, 8> interleave = {0, 2, 4, 6, 1, 3, 5, 7};
+
+  int num_rows = perm.size() / 8;
+  for (int i = 0; i < num_rows; ++i) {
+    std::array<int, 8> row;
+    std::copy(perm.begin() + i * 8, perm.begin() + (i + 1) * 8, row.begin());
+    for (int j : interleave) {
+      interleaved_perm.push_back(row[j]);
+    }
+  }
+
+  return interleaved_perm;
+}
+
+std::tuple<std::vector<int>, std::vector<int>, std::vector<int>> get_perms() {
+  std::vector<int> perm;
+
+  for (int i = 0; i < 32; ++i) {
+    std::vector<int> perm1;
+    int col = i / 4;
+    for (int block : {0, 1}) {
+      for (int row : {
+             2 * (i % 4),
+             2 * (i % 4) + 1,
+             2 * (i % 4 + 4),
+             2 * (i % 4 + 4) + 1
+           }) {
+        perm1.push_back(16 * row + col + 8 * block);
+      }
+    }
+    for (int j = 0; j < 4; ++j) {
+      for (int p : perm1) {
+        perm.push_back(p + 256 * j);
+      }
+    }
+  }
+
+  // Interleave the perm array
+  perm = interleave_perms(perm);
+
+  std::vector<int> scale_perm;
+  for (int i = 0; i < 8; ++i) {
+    for (int j = 0; j < 8; ++j) {
+      scale_perm.push_back(i + 8 * j);
+    }
+  }
+
+  std::vector<int> scale_perm_single;
+  for (int i = 0; i < 4; ++i) {
+    for (int j : {0, 1, 8, 9, 16, 17, 24, 25}) {
+      scale_perm_single.push_back(2 * i + j);
+    }
+  }
+
+  return std::make_tuple(perm, scale_perm, scale_perm_single);
+}
+
+TEST(EfficientI4MatmulTest, ParamPermutate)
+{
+  auto [perm, scale_perm, scale_perm_single] = get_perms();
+
+  std::cout << "perm: ";
+  for (int i = 0; i < 10; ++i) {
+    std::cout << perm[i] << " ";
+  }
+  std::cout << std::endl;
+
+  std::cout << "scale_perm: ";
+  for (const auto& val : scale_perm) {
+    std::cout << val << " ";
+  }
+  std::cout << std::endl;
+
+  std::cout << "scale_perm_single: ";
+  for (const auto&
val : scale_perm_single) { + std::cout << val << " "; + } + std::cout << std::endl; +} diff --git a/testing/python/builder/test_backend_tir_builder.py b/testing/python/builder/test_backend_tir_builder.py new file mode 100644 index 000000000..22c134b12 --- /dev/null +++ b/testing/python/builder/test_backend_tir_builder.py @@ -0,0 +1,55 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import bitblas +from bitblas import MatmulConfig, Matmul +import logging +from bitblas import set_log_level +from bitblas.builder.wrapper import TIRWrapper + +set_log_level(logging.DEBUG) + + +def get_codegen_result(ops): + code = ops.get_source() + return code + + +def matmul_backend_code_wrap( + M, + N, + K, + A_dtype, + W_dtype, + accum_dtype, + out_dtype, + with_bias, +): + import torch + torch.random.manual_seed(0) + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + with_bias=with_bias, + ) + matmul = Matmul(config=matmul_config, enable_tuning=False) + backend = TIRWrapper(arch=matmul.arch) + backend.assign_optimized_module(matmul.optimized_func) + wrapped_code = backend.wrap(matmul.get_source(), is_dynamic=isinstance(M, list)) + assert "void call" in wrapped_code + + +def test_matmul_transform_weight(): + matmul_backend_code_wrap(1, 768, 768, "float16", "uint4", "float16", "float16", False) + matmul_backend_code_wrap(768, 768, 768, "float16", "uint4", "float16", "float16", False) + matmul_backend_code_wrap([1, 768], 768, 768, "float16", "uint4", "float16", "float16", False) + + +# fmt: on +if __name__ == "__main__": + bitblas.testing.main() diff --git a/testing/python/module/test_bitblas_linear.py b/testing/python/module/test_bitblas_linear.py index eee08c93c..f329a146e 100644 --- a/testing/python/module/test_bitblas_linear.py +++ b/testing/python/module/test_bitblas_linear.py @@ -6,11 +6,11 @@ import time import numpy as np import torch.nn as nn -import pytest torch.manual_seed(0) bitblas.set_log_level("DEBUG") + def correctness_consistent(m, in_features, out_features, bias): linear_torch = (nn.Linear(in_features, out_features, bias=bias).to(torch.float16).cuda()) linear_bitblas = BitBLASLinear( @@ -45,6 +45,7 @@ def test_correctness_consistent(): correctness_consistent(1024, 1024, 1024, True) correctness_consistent([1, 1024], 1024, 1024, True) + def correctness_weight_only_dequantize( m, in_features, diff --git a/testing/python/type_conversion/test_int4b_fp16_convert.py b/testing/python/type_conversion/test_int4b_fp16_convert.py index 92b0e0788..3a58a47e1 100644 --- a/testing/python/type_conversion/test_int4b_fp16_convert.py +++ b/testing/python/type_conversion/test_int4b_fp16_convert.py @@ -5,7 +5,6 @@ import torch import numpy as np from tvm.script import tir as T -import numpy as np def general_compress_to_int8(lowprecision_weight, source_bits=4): @@ -21,9 +20,7 @@ def general_compress_to_int8(lowprecision_weight, source_bits=4): ) for j in range(lowprecision_weight.shape[-1] // elems_per_byte): for k in range(elems_per_byte): - int8_weight[:, j] |= lowprecision_weight[:, j * elems_per_byte + k] << ( - source_bits * k - ) + int8_weight[:, j] |= lowprecision_weight[:, j * elems_per_byte + k] << (source_bits * k) return int8_weight @@ -44,25 +41,25 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): if nbits == 1 and target_dtype == "int8": # special handling for 1b interleave - n16_weight = new_qweight & np.int32(0xF0F00F0F) - n16_weight |= ((new_qweight & 
np.int32(0x000000F0)) >> 4) << 16 - n16_weight |= ((new_qweight & np.int32(0x0000F000)) >> 12) << 24 - n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4 - n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12 + n16_weight = new_qweight & np.int32(np.uint32(0xF0F00F0F)) + n16_weight |= ((new_qweight & np.int32(np.uint32(0x000000F0))) >> 4) << 16 + n16_weight |= ((new_qweight & np.int32(np.uint32(0x0000F000))) >> 12) << 24 + n16_weight |= ((new_qweight & np.int32(np.uint32(0x000F0000))) >> 16) << 4 + n16_weight |= ((new_qweight & np.int32(np.uint32(0x0F000000))) >> 24) << 12 return n16_weight.view(np.int8) elif nbits == 2 and target_dtype == "float16": - n8_weight = new_qweight & np.int32(0xFF0000FF) - n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16 - n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8 + n8_weight = new_qweight & np.int32(np.uint32(0xFF0000FF)) + n8_weight |= ((new_qweight & np.int32(np.uint32(0x0000FF00))) >> 8) << 16 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x00FF0000))) >> 16) << 8 return n8_weight.view(np.int8) elif nbits == 1 and target_dtype == "float16": - n8_weight = new_qweight & 0xF000000F - n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8 - n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16 - n8_weight |= ((new_qweight & 0x0000F000) >> 12) << 24 - n8_weight |= ((new_qweight & 0x000F0000) >> 16) << 4 - n8_weight |= ((new_qweight & 0x00F00000) >> 20) << 12 - n8_weight |= ((new_qweight & 0x0F000000) >> 24) << 20 + n8_weight = new_qweight & np.int32(np.uint32(0xF000000F)) + n8_weight |= ((new_qweight & np.int32(np.uint32(0x000000F0))) >> 4) << 8 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x00000F00))) >> 8) << 16 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x0000F000))) >> 12) << 24 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x000F0000))) >> 16) << 4 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x00F00000))) >> 20) << 12 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x0F000000))) >> 24) << 20 return new_qweight.view(np.int8) @@ -80,17 +77,11 @@ def interleave_weight(A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32 with T.block("B"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + ( - offset // num_groups - ) * bits - B[v0, v1] = B[v0, v1] | ( - ((A[v0, v1] >> (bits * offset)) & mask) << shift - ) + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) @T.prim_func - def interleave_weight_f16_2b( - A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32") - ): + def interleave_weight_f16_2b(A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32")): B_tmp_1 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_2 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_3 = T.alloc_buffer((N, QK), "int32", scope="local") @@ -98,12 +89,8 @@ def interleave_weight_f16_2b( with T.block("B_tmp"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + ( - offset // num_groups - ) * bits - B[v0, v1] = B[v0, v1] | ( - ((A[v0, v1] >> (bits * offset)) & mask) << shift - ) + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) for ax0, ax1 in T.grid(N, QK): with T.block("B"): @@ -114,9 +101,7 @@ 
def interleave_weight_f16_2b( B[v0, v1] = B_tmp_1[v0, v1] | B_tmp_2[v0, v1] | B_tmp_3[v0, v1] @T.prim_func - def interleave_weight_f16_1b( - A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32") - ): + def interleave_weight_f16_1b(A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32")): B_tmp_1 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_2 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_3 = T.alloc_buffer((N, QK), "int32", scope="local") @@ -128,12 +113,8 @@ def interleave_weight_f16_1b( with T.block("B_tmp"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + ( - offset // num_groups - ) * bits - B[v0, v1] = B[v0, v1] | ( - ((A[v0, v1] >> (bits * offset)) & mask) << shift - ) + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) for ax0, ax1 in T.grid(N, QK): with T.block("B"): @@ -152,13 +133,10 @@ def interleave_weight_f16_1b( | B_tmp_4[v0, v1] | B_tmp_5[v0, v1] | B_tmp_6[v0, v1] - | B_tmp_7[v0, v1] - ) + | B_tmp_7[v0, v1]) @T.prim_func - def interleave_weight_int8_1b( - A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32") - ): + def interleave_weight_int8_1b(A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32")): B_tmp_1 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_2 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_3 = T.alloc_buffer((N, QK), "int32", scope="local") @@ -168,12 +146,8 @@ def interleave_weight_int8_1b( with T.block("B_tmp"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + ( - offset // num_groups - ) * bits - B[v0, v1] = B[v0, v1] | ( - ((A[v0, v1] >> (bits * offset)) & mask) << shift - ) + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) for ax0, ax1 in T.grid(N, QK): with T.block("B"): @@ -188,8 +162,7 @@ def interleave_weight_int8_1b( | B_tmp_2[v0, v1] | B_tmp_3[v0, v1] | B_tmp_4[v0, v1] - | B_tmp_5[v0, v1] - ) + | B_tmp_5[v0, v1]) if target_dtype == "float16" and bits == 2: return interleave_weight_f16_2b @@ -207,7 +180,7 @@ def test_lop3_interleave_weight(): K = 16 target_dtype = "float16" torch.manual_seed(0) - uint_max = 2 ** (source_nbits) - 1 + uint_max = 2**(source_nbits) - 1 raw_data = torch.randint(0, uint_max, (N, K), dtype=torch.int8).cpu().numpy() compressed_b = general_compress_to_int8(raw_data, source_nbits) interleaved_weight = interleave_weight(compressed_b, source_nbits, target_dtype)
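Aside for reviewers (not part of the patch): test_lop3_interleave_weight above packs eight 4-bit values per int8 pair via general_compress_to_int8 before interleaving. The minimal NumPy sketch below round-trips that nibble packing so the assumed bit layout (low nibble first) is easy to inspect in isolation; the helper names pack_int4_to_int8 and unpack_int8_to_int4 are illustrative only and do not exist in the codebase.

# Illustrative sketch: round-trip check for the low-nibble-first 4-bit packing
# mirrored from general_compress_to_int8; helper names are hypothetical.
import numpy as np

def pack_int4_to_int8(w4):
    # Pack two 4-bit values per byte; element 2*j lands in the low nibble of byte j.
    assert w4.shape[-1] % 2 == 0
    packed = np.zeros(w4.shape[:-1] + (w4.shape[-1] // 2,), dtype=np.int8)
    for j in range(packed.shape[-1]):
        for k in range(2):
            packed[..., j] |= (w4[..., 2 * j + k] & 0xF) << (4 * k)
    return packed

def unpack_int8_to_int4(packed):
    # Reverse the packing: even outputs come from low nibbles, odd outputs from high nibbles.
    out = np.zeros(packed.shape[:-1] + (packed.shape[-1] * 2,), dtype=np.int8)
    for k in range(2):
        out[..., k::2] = (packed.view(np.uint8) >> (4 * k)) & 0xF
    return out

if __name__ == "__main__":
    w4 = np.random.randint(0, 16, size=(4, 16), dtype=np.int8)
    assert np.array_equal(unpack_int8_to_int4(pack_int4_to_int8(w4)), w4)
    print("4-bit pack/unpack round-trip OK")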