diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 90262f428..12408b72d 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -51,6 +51,13 @@ jobs: cd benchmark/operators python ./benchmark_ops_matmul.py + benchmark_head: + # Run on pull requests when the comment starts with `/run-benchmark` + if: github.event.issue.pull_request != null && startsWith(github.event.comment.body, '/run-benchmark') + runs-on: self-hosted + needs: [benchmark_base] + + steps: - name: Checkout PR branch code uses: actions/checkout@v2 with: @@ -92,7 +99,7 @@ jobs: python ./benchmark_ops_matmul.py benchmark_compare: - if: github.event.issue.pull_request != '' && contains(github.event.comment.body, '/run-benchmark') + if: github.event.issue.pull_request != null && contains(github.event.comment.body, '/run-benchmark') needs: [benchmark_base, benchmark_head] runs-on: self-hosted diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9b76866c5..be3688ee5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,7 +56,7 @@ jobs: run: | source bitblas_ci/bin/activate python -m pip install --upgrade pip - if [ -f requirements-dev.txt ]; then python -m pip install -r requirements-dev.txt; fi + if [ -f requirements-test.txt ]; then python -m pip install -r requirements-test.txt; fi - name: Install project in wheel mode run: | diff --git a/3rdparty/tvm b/3rdparty/tvm index 0b0faa5cd..a9be6adb4 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 0b0faa5cd7ae077bc730c2638bf2ab29adaede5d +Subproject commit a9be6adb4793eca1ca7f1de544571b087f3c8b32 diff --git a/benchmark/operators/benchmark_ops_matmul.py b/benchmark/operators/benchmark_ops_matmul.py index 723cf035b..a17baa154 100644 --- a/benchmark/operators/benchmark_ops_matmul.py +++ b/benchmark/operators/benchmark_ops_matmul.py @@ -108,41 +108,41 @@ def prepare_benchmark_sets(self): "FP16xFP16_ACCFP16_NT", [ *self.prepare_set_group_4x("FP16xFP16_ACCFP16_NT", 16384, 16384, 16384), - *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 3200, 3200), - *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8640, 3200), - *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 3200, 8640), - *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 5120, 5120), - *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 13824, 5120), - *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 5120, 13824), - *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 6656, 6656), - *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 17920, 6656), - *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 6656, 17920), - *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 1024, 8192), - *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8192, 8192), - *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 28672, 8192), - *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8192, 28672), + # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 3200, 3200), + # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8640, 3200), + # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 3200, 8640), + # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 5120, 5120), + # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 13824, 5120), + # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 5120, 13824), + # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 6656, 6656), + # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 17920, 6656), + # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 6656, 17920),
+ # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 1024, 8192), + # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8192, 8192), + # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 28672, 8192), + # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8192, 28672), ], ) - self.add_benchmark_set( - "INT8xINT8_ACCINT32_NT", - [ - *self.prepare_set_group_4x("INT8xINT8_ACCINT32_NT", 16384, 16384, 16384), - *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 3200, 3200), - *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8640, 3200), - *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 3200, 8640), - *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 5120, 5120), - *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 13824, 5120), - *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 5120, 13824), - *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 6656, 6656), - *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 17920, 6656), - *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 6656, 17920), - *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 1024, 8192), - *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8192, 8192), - *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 28672, 8192), - *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8192, 28672), - ], - ) + # self.add_benchmark_set( + # "INT8xINT8_ACCINT32_NT", + # [ + # *self.prepare_set_group_4x("INT8xINT8_ACCINT32_NT", 16384, 16384, 16384), + # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 3200, 3200), + # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8640, 3200), + # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 3200, 8640), + # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 5120, 5120), + # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 13824, 5120), + # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 5120, 13824), + # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 6656, 6656), + # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 17920, 6656), + # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 6656, 17920), + # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 1024, 8192), + # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8192, 8192), + # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 28672, 8192), + # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8192, 28672), + # ], + # ) def generate_operator_config(self, name: str, M, N, K) -> MatmulConfig: """Generate configuration for the given operator.""" diff --git a/bitblas/__init__.py b/bitblas/__init__.py index ee79bc3c9..91e88133c 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -36,7 +36,7 @@ ) from . import testing # noqa: F401 -from .utils import auto_detect_nvidia_target # noqa: F401 +from .utils import auto_detect_nvidia_target, apply_transform_on_input # noqa: F401 from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401 from .ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK # noqa: F401 from .ops.matmul_dequantize import MatmulWeightOnlyDequantizeConfig, MatmulWeightOnlyDequantize # noqa: F401 diff --git a/bitblas/builder/lib_generator/__init__.py b/bitblas/builder/lib_generator/__init__.py index a0800751a..fd877c679 100644 --- a/bitblas/builder/lib_generator/__init__.py +++ b/bitblas/builder/lib_generator/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
from typing import Optional -from bitblas import TileDevice +from bitblas.base.arch import TileDevice import ctypes import os import tempfile diff --git a/bitblas/builder/wrapper/tir.py b/bitblas/builder/wrapper/tir.py index 2d0162f66..0bedf70ed 100644 --- a/bitblas/builder/wrapper/tir.py +++ b/bitblas/builder/wrapper/tir.py @@ -3,13 +3,14 @@ from bitblas import tvm from typing import Optional, List, Dict, Union from tvm import IRModule -from bitblas import TileDevice +from bitblas.base.arch import TileDevice from bitblas.utils import match_global_kernel from bitblas.utils.rtmod_analysis import get_annotated_device_mod import re -from .base import BaseWrapper import logging +from .base import BaseWrapper + logger = logging.getLogger(__name__) diff --git a/bitblas/gpu/intrin/lop3.py b/bitblas/gpu/intrin/lop3.py index dc4bb587b..08ed8f7f0 100644 --- a/bitblas/gpu/intrin/lop3.py +++ b/bitblas/gpu/intrin/lop3.py @@ -3,7 +3,7 @@ from bitblas import tvm from tvm.tir.function import TensorIntrin from tvm.script import tir as T -from typing import Dict, Literal +from typing import Dict, Literal, List from bitblas.quantization import ( _tir_packed_int_to_int_convert, _tir_packed_to_signed_convert, @@ -769,6 +769,7 @@ def get_fast_decode_intrin( with_scale=False, with_zeros=False, zeros_mode="original", + storage_scope="local", ): """ loops extent is the number of elements to be decoded in one stage @@ -814,7 +815,7 @@ def fast_decode_desc(compressed: T.handle, decompressed: T.handle) -> None: n_storage_elems, ], dtype=storage_dtype, - scope="local", + scope=storage_scope, ) Decompressed = T.match_buffer( decompressed, @@ -822,7 +823,7 @@ def fast_decode_desc(compressed: T.handle, decompressed: T.handle) -> None: loops_extent, ], dtype=target_dtype, - scope="local", + scope=storage_scope, ) with T.block("root"): @@ -846,7 +847,8 @@ def fast_decode_impl(compressed: T.handle, decompressed: T.handle) -> None: n_storage_elems, ], dtype=storage_dtype, - scope="local", + scope=storage_scope, + offset_factor=n_storage_elems, ) Decompressed = T.match_buffer( decompressed, @@ -854,7 +856,8 @@ def fast_decode_impl(compressed: T.handle, decompressed: T.handle) -> None: loops_extent, ], dtype=target_dtype, - scope="local", + scope=storage_scope, + offset_factor=loops_extent, ) with T.block("root"): @@ -863,8 +866,8 @@ def fast_decode_impl(compressed: T.handle, decompressed: T.handle) -> None: T.call_extern( "handle", func_name, - Compressed.data, - Decompressed.data, + Compressed.access_ptr("r"), + Decompressed.access_ptr("w"), loops_extent, ) @@ -878,7 +881,7 @@ def fast_decode_desc(compressed: T.handle, decompressed: T.handle, scale: T.hand n_storage_elems, ], dtype=storage_dtype, - scope="local", + scope=storage_scope, ) Decompressed = T.match_buffer( decompressed, @@ -886,7 +889,7 @@ def fast_decode_desc(compressed: T.handle, decompressed: T.handle, scale: T.hand loops_extent, ], dtype=target_dtype, - scope="local", + scope=storage_scope, ) Scale = T.match_buffer( scale, @@ -920,7 +923,7 @@ def fast_decode_impl(compressed: T.handle, decompressed: T.handle, scale: T.hand n_storage_elems, ], dtype=storage_dtype, - scope="local", + scope=storage_scope, ) Decompressed = T.match_buffer( decompressed, @@ -928,7 +931,7 @@ def fast_decode_impl(compressed: T.handle, decompressed: T.handle, scale: T.hand loops_extent, ], dtype=target_dtype, - scope="local", + scope=storage_scope, ) Scale = T.match_buffer( scale, @@ -988,7 +991,7 @@ def fast_decode_desc( n_storage_elems, ], dtype=storage_dtype, - scope="local", + 
scope=storage_scope, ) Decompressed = T.match_buffer( decompressed, @@ -996,7 +999,7 @@ def fast_decode_desc( loops_extent, ], dtype=target_dtype, - scope="local", + scope=storage_scope, ) Scale = T.match_buffer( scale, @@ -1004,7 +1007,7 @@ def fast_decode_desc( 1, ], dtype=target_dtype, - scope="local", + scope=storage_scope, ) Zeros = T.match_buffer( zeros, @@ -1012,7 +1015,7 @@ def fast_decode_desc( 1, ], dtype=storage_dtype, - scope="local", + scope=storage_scope, ) with T.block("root"): T.reads(*get_dequantize_buffers_list( @@ -1053,7 +1056,7 @@ def fast_decode_impl( n_storage_elems, ], dtype=storage_dtype, - scope="local", + scope=storage_scope, ) Decompressed = T.match_buffer( decompressed, @@ -1061,7 +1064,7 @@ def fast_decode_impl( loops_extent, ], dtype=target_dtype, - scope="local", + scope=storage_scope, ) Scale = T.match_buffer( scale, @@ -1071,7 +1074,7 @@ def fast_decode_impl( dtype=target_dtype, offset_factor=1, strides=[s0], - scope="local", + scope=storage_scope, ) Zeros = T.match_buffer( zeros, @@ -1081,7 +1084,7 @@ def fast_decode_impl( dtype=storage_dtype, offset_factor=1, strides=[s1], - scope="local", + scope=storage_scope, ) with T.block("root"): T.reads(Compressed[0:n_storage_elems], Scale[0:1], Zeros[0:1]) @@ -1128,7 +1131,7 @@ def fast_decode_desc( n_storage_elems, ], dtype=storage_dtype, - scope="local", + scope=storage_scope, ) Decompressed = T.match_buffer( decompressed, @@ -1136,7 +1139,7 @@ def fast_decode_desc( loops_extent, ], dtype=target_dtype, - scope="local", + scope=storage_scope, ) Scale = T.match_buffer( scale, @@ -1192,7 +1195,7 @@ def fast_decode_impl( n_storage_elems, ], dtype=storage_dtype, - scope="local", + scope=storage_scope, ) Decompressed = T.match_buffer( decompressed, @@ -1200,7 +1203,7 @@ def fast_decode_impl( loops_extent, ], dtype=target_dtype, - scope="local", + scope=storage_scope, ) Scale = T.match_buffer( scale, @@ -1238,353 +1241,83 @@ def fast_decode_impl( return fast_decode_desc, fast_decode_impl -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u4_to_int8_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=4, storage_dtype="int8", target_dtype="float16", loops_extent=8), -) - -LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u2_to_int8_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=2, storage_dtype="int8", target_dtype="float16", loops_extent=8), -) - -LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u1_to_int8_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=1, storage_dtype="int8", target_dtype="float16", loops_extent=8), -) - -LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u4_to_int32_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=4, storage_dtype="int32", target_dtype="float16", loops_extent=8), -) - -LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_u4_to_int32_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="int32", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - -LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_INTRIN = 
("lop3_fast_decode_u4_to_uint32_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=4, storage_dtype="uint32", target_dtype="float16", loops_extent=8), -) - -LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_u4_to_uint32_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="uint32", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN = ( - "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_zeros_original_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="original", - ), -) - -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN = ( - "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_zeros_rescale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="rescale", - ), -) - -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN = ( - "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_zeros_quantized_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="quantized", - ), -) - -LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=2, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - -LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN = ( - "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_zeros_original_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN, - *get_fast_decode_intrin( - source_bit=2, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="original", - ), -) - -LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN = ( - "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_zeros_rescale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=2, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="rescale", - ), -) - -LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN = ( - "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_zeros_quantized_") -TensorIntrin.register( - 
LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN, - *get_fast_decode_intrin( - source_bit=2, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="quantized", - ), -) - -LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_u1_to_int8_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=1, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - -LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN = ( - "lop3_fast_decode_u1_to_int8_to_f16_l8_scale_zeros_original_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN, - *get_fast_decode_intrin( - source_bit=1, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="original", - ), -) - -LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN = ( - "lop3_fast_decode_u1_to_int8_to_f16_l8_scale_zeros_rescale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=1, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="rescale", - ), -) - -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L8_INTRIN = ("lop3_fast_decode_u4_to_int8_to_i8_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=4, storage_dtype="int8", target_dtype="int8", loops_extent=8), -) - -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u4_to_int8_to_i8_l16_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L16_INTRIN, - *get_fast_decode_intrin( - source_bit=4, storage_dtype="int8", target_dtype="int8", loops_extent=16), -) - -LOP3_FAST_DECODE_UINT2_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u2_to_int8_to_i8_l16_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT2_TO_INT8_TO_INT8_L16_INTRIN, - *get_fast_decode_intrin( - source_bit=2, storage_dtype="int8", target_dtype="int8", loops_extent=16), -) - -LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_i2_to_int8_to_i8_l16_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN, - *get_fast_decode_intrin( - source_bit=2, - source_format="int", - storage_dtype="int8", - target_dtype="int8", - loops_extent=16), -) - -LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u1_to_int8_to_i8_l16_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN, - *get_fast_decode_intrin( - source_bit=1, storage_dtype="int8", target_dtype="int8", loops_extent=16), -) - -LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_i1_to_int8_to_i8_l16_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN, - *get_fast_decode_intrin( - source_bit=1, - source_format="int", - storage_dtype="int8", - target_dtype="int8", - loops_extent=16), -) - -LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i4_to_int8_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="int8", - source_format="int", - target_dtype="float16", - loops_extent=8, - ), -) - -LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - 
"lop3_fast_decode_i4_to_int8_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="int8", - source_format="int", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - -LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i2_to_int8_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=2, - storage_dtype="int8", - source_format="int", - target_dtype="float16", - loops_extent=8, - ), -) - -LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_i2_to_int8_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=2, - storage_dtype="int8", - source_format="int", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - -LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i1_to_int8_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=1, - storage_dtype="int8", - source_format="int", - target_dtype="float16", - loops_extent=8, - ), -) - -LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_i1_to_int8_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=1, - storage_dtype="int8", - source_format="int", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) +# Define the intrin definitions +intrin_definitions = [ + # (source_bit, storage_dtype, target_dtype, loops_extent, storage_scope, source_format, with_scale, with_zeros, zeros_mode) + (4, "int8", "float16", 8, "local", "uint", False, False, "original"), + (4, "int8", "float16", 8, "warp", "uint", False, False, "original"), + (2, "int8", "float16", 8, "local", "uint", False, False, "original"), + (1, "int8", "float16", 8, "local", "uint", False, False, "original"), + (4, "int32", "float16", 8, "local", "uint", False, False, "original"), + (4, "int32", "float16", 8, "local", "uint", True, False, "original"), + (4, "uint32", "float16", 8, "local", "uint", False, False, "original"), + (4, "uint32", "float16", 8, "local", "uint", True, False, "original"), + (4, "int8", "float16", 8, "local", "uint", True, False, "original"), + (4, "int8", "float16", 8, "local", "uint", True, True, "original"), + (4, "int8", "float16", 8, "local", "uint", True, True, "rescale"), + (4, "int8", "float16", 8, "local", "uint", True, True, "quantized"), + (2, "int8", "float16", 8, "local", "uint", True, False, "original"), + (2, "int8", "float16", 8, "local", "uint", True, True, "original"), + (2, "int8", "float16", 8, "local", "uint", True, True, "rescale"), + (2, "int8", "float16", 8, "local", "uint", True, True, "quantized"), + (1, "int8", "float16", 8, "local", "uint", True, False, "original"), + (1, "int8", "float16", 8, "local", "uint", True, True, "original"), + (1, "int8", "float16", 8, "local", "uint", True, True, "rescale"), + (4, "int8", "int8", 8, "local", "uint", False, False, "original"), + (4, "int8", "int8", 16, "local", "uint", False, False, "original"), + (2, "int8", "int8", 16, "local", "uint", False, False, "original"), + (2, "int8", "int8", 16, "local", "int", False, False, "original"), + (1, "int8", "int8", 16, "local", "uint", False, False, "original"), + (1, "int8", "int8", 16, "local", "int", False, 
False, "original"), + (4, "int8", "float16", 8, "local", "int", False, False, "original"), + (4, "int8", "float16", 8, "local", "int", True, False, "original"), + (2, "int8", "float16", 8, "local", "int", False, False, "original"), + (2, "int8", "float16", 8, "local", "int", True, False, "original"), + (1, "int8", "float16", 8, "local", "int", False, False, "original"), +] + + +# Register the intrin +def initialize_tensor_intrin(): + registered_intrins: List[str] = [] + for params in intrin_definitions: + # Repack from the params + source_bit, storage_dtype, target_dtype, loops_extent, storage_scope, source_format, with_scale, with_zeros, zeros_mode = params + + # Construct the name + name_parts = [ + "lop3_fast_decode", f"{source_format[0]}{source_bit}", f"to_{storage_dtype}", + f"to_{target_dtype}", f"l{loops_extent}" + ] + if with_scale: + name_parts.append("scale") + if with_zeros: + name_parts.extend(["zeros", zeros_mode]) + if storage_scope == "warp": + name_parts.append("warp") + + name = "_".join(part for part in name_parts if part) + "_" + + # Get intrin desc and implementation + intrin = get_fast_decode_intrin( + source_bit=source_bit, + storage_dtype=storage_dtype, + source_format=source_format, + target_dtype=target_dtype, + loops_extent=loops_extent, + with_scale=with_scale, + with_zeros=with_zeros, + zeros_mode=zeros_mode, + storage_scope=storage_scope) + + # Register the intrin + TensorIntrin.register(name, *intrin) + registered_intrins.append(name) + + return registered_intrins + + +registered_intrins = initialize_tensor_intrin() def get_lop3_intrin_group( @@ -1595,6 +1328,7 @@ def get_lop3_intrin_group( with_scaling: bool = False, with_zeros: bool = False, zeros_mode: Literal["original", "rescale", "quantized"] = "original", + storage_scope: str = "local", ) -> Dict[str, str]: """ This function is used to get the intrinsic group of the LOP3 operation to avoid the overhead of fast decoding. @@ -1615,6 +1349,15 @@ def get_lop3_intrin_group( with_scale : bool, optional A boolean parameter that indicates whether scaling should be applied. By default, it is False. + with_zeros : bool, optional + A boolean parameter that indicates whether zeros should be used. By default, it is False. + + zeros_mode : Literal["original", "rescale", "quantized"], optional + The mode of zeros. It can be either "original", "rescale", or "quantized". By default, it is "original". + + storage_scope : Literal["local", "warp"], optional + The scope of the storage. It can be either "local" or "warp". By default, it is "local". + Returns ------- Dict[str, str] @@ -1630,11 +1373,13 @@ def get_lop3_intrin_group( raise ValueError("Invalid source_format. 
Expected 'int' or 'uint'.") source_symbol = "i" if source_format == "int" else "u" - _intrin = f"lop3_fast_decode_{source_symbol}{source_bit}_to_{storage_dtype}_to_{target_dtype}_l{loop_extent}_" + _intrin = f"lop3_fast_decode_{source_symbol}{source_bit}_to_{storage_dtype}_to_{out_dtype}_l{loop_extent}_" if with_scaling: _intrin += "scale_" if with_zeros: _intrin += f"zeros_{zeros_mode}_" + if storage_scope == "warp": + _intrin += "warp_" import_c_map = { "i4_to_f16": decode_i4_to_f16, diff --git a/bitblas/gpu/matmul_analysis.py b/bitblas/gpu/matmul_analysis.py index 6537a555a..210c560a1 100644 --- a/bitblas/gpu/matmul_analysis.py +++ b/bitblas/gpu/matmul_analysis.py @@ -734,6 +734,63 @@ def ldmatrix_permutation_16x32_32x16_32x16(kernel_i, kernel_j): return ldmatrix_index_map, inversed_index_map +# This function returns the index map for stage 3 of the Ladder +# weight propagation, which can be used to avoid the ldmatrix +# instructions. +def get_ladder_stage3_map(dtype="float16", index_dtype="int32"): + + def shared_32x8_to_mma_32x8_layout(i, j): + thread_id = (i % 8) * 4 + (j // 2) + local_id = (i // 8) * 2 + (j % 2) + return thread_id, local_id + + def shared_32x16_to_mma_32x16_layout(i, j): + thread_id = (i % 8) * 4 + (j // 4) + local_id = (i // 8) * 4 + (j % 4) + return thread_id, local_id + + assert dtype in [ + "float16", + "int8", + "e4m3_float8", + "e5m2_float8", + ], "Only support float16, int8, e4m3_float8, e5m2_float8" + if dtype == "float16": + stage3_layout = shared_32x8_to_mma_32x8_layout + elif dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + stage3_layout = shared_32x16_to_mma_32x16_layout + else: + raise ValueError("Unknown dtype ", dtype) + + # The intra-warp memory layout is produced by ldmatrix, so we lift the ldmatrix permutation out + def ladder_stage3_permutation_16x16_32x8_32x8_16x16(kernel_i, kernel_j): + thread_id = kernel_i * 2 + kernel_j // 8 + local_id = kernel_j % 8 + new_thread_id, new_local_id = stage3_layout(thread_id, local_id) + new_kernel_i = (new_thread_id * 8 + new_local_id) // 16 + new_kernel_j = (new_thread_id * 8 + new_local_id) % 16 + return new_kernel_i, new_kernel_j + + def ladder_stage3_permutation_16x32_32x16_32x16_16x32(kernel_i, kernel_j): + thread_id = kernel_i * 2 + kernel_j // 16 + local_id = kernel_j % 16 + new_thread_id, new_local_id = stage3_layout(thread_id, local_id) + new_kernel_i = (new_thread_id * 16 + new_local_id) // 32 + new_kernel_j = (new_thread_id * 16 + new_local_id) % 32 + return new_kernel_i, new_kernel_j + + if dtype == "float16": + stage3_index_map = ladder_stage3_permutation_16x16_32x8_32x8_16x16 + else: + stage3_index_map = ladder_stage3_permutation_16x32_32x16_32x16_16x32 + + stage3_index_map = IndexMap.from_func(stage3_index_map, index_dtype=index_dtype) + # TODO(lei): index_dtype should be analyzed from the schedule + row, col = [16, 16] if dtype == "float16" else [16, 32] + inversed_index_map = stage3_index_map.inverse([row, col]) + return stage3_index_map, inversed_index_map + + def layout_propagate_chain( sch: tir.Schedule, start_block: BlockRV, diff --git a/bitblas/gpu/matmul_mma.py b/bitblas/gpu/matmul_mma.py index 629d93768..8700e6580 100644 --- a/bitblas/gpu/matmul_mma.py +++ b/bitblas/gpu/matmul_mma.py @@ -8,13 +8,14 @@ from tvm import tir, DataType from tvm.target import Target -from bitblas.base.roller import Hint -from bitblas.base.roller.rasterization import NoRasterization -from bitblas.base import analysis -from bitblas.gpu.base import GPUScheduleRule -from bitblas.gpu.matmul_mma_dequantize import
MatmulTensorizationMMAWithDequantizeInfo -from bitblas.base.analysis import get_coalesced_veclen -from bitblas.gpu.matmul_analysis import ( +from ..ops.operator import TransformKind +from ..base.roller import Hint +from ..base.roller.rasterization import NoRasterization +from ..base import analysis +from .base import GPUScheduleRule +from .matmul_mma_dequantize import MatmulTensorizationMMAWithDequantizeInfo +from ..base.analysis import get_coalesced_veclen +from .matmul_analysis import ( auto_inline_consumer_chain, is_transpose_block, is_identity_block, @@ -345,8 +346,9 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring dequantize_rule = MatmulTensorizationMMAWithDequantizeInfo() return dequantize_rule.apply_config(func, config) - if hasattr(config, "block_reduction_depth") and config.block_reduction_depth is not None: - return self.apply_block_reduction_with_config(func, config) + is_cross_thread_reduce = ( + hasattr(config, "block_reduction_depth") and config.block_reduction_depth is not None) + block_reduction_depth = config.block_reduction_depth if is_cross_thread_reduce else 1 from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel get_mma_intrin_group,) @@ -398,6 +400,10 @@ def check_has_dynamic(func: tir.PrimFunc): shared_scope = config.shared_scope intrin_info = config.intrin_info + input_transform_kind = intrin_info.input_transform_kind + weight_transform_kind = intrin_info.weight_transform_kind + assert input_transform_kind <= TransformKind.IntraWarpTransform, "Only support up to intra-warp transform" + intrin_group = get_mma_intrin_group( load_scope=shared_scope, store_scope=shared_scope if cache_write_required else "global", @@ -410,10 +416,8 @@ def check_has_dynamic(func: tir.PrimFunc): smooth_b=intrin_info.smooth_b, not_use_mma_store_intrinic=False, ) - # Start Schedule # Step 0. Get schedule config. - # NOTE: we can analyze the config by the hardware spec in the future warp_row_tiles = config.warp[0] warp_col_tiles = config.warp[1] @@ -421,8 +425,8 @@ def check_has_dynamic(func: tir.PrimFunc): block_col_warps = config.block[1] // warp_col_tiles stage = config.pipeline_stage use_async = config.use_async + reduce_k = block_reduction_depth chunk = config.rstep[0] - # tensor core intrinsic size micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] @@ -436,7 +440,7 @@ def get_axis(l, r, trans): # noqa: E741 def can_enable_swizzle(dtype: str, smooth: bool): # inject_permuted_layout only support float16 currently if dtype == "float16" or dtype == "int8": - if chunk * DataType(dtype).bits != 512: + if (chunk * reduce_k) * DataType(dtype).bits != (512): # currently the swizzle rule only support 512 bit. 
return False # if we use smooth layout, we don't need to do swizzling @@ -451,7 +455,7 @@ def can_enable_swizzle(dtype: str, smooth: bool): i_factors, j_factors, k_factors = ( [None, 1, block_row_warps, warp_row_tiles // micro_size_x], [1, None, block_col_warps, warp_col_tiles // micro_size_y], - [None, chunk // micro_size_k], + [None, (reduce_k * chunk) // micro_size_k], ) num_ty = i_factors[2] @@ -489,8 +493,11 @@ def can_enable_swizzle(dtype: str, smooth: bool): i0, i1, i2, i3 = sch.split(i, factors=i_factors) j0, j1, j2, j3 = sch.split(j, factors=j_factors) k0, k1 = sch.split(k, k_factors) - - sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) + if reduce_k > 1: + k0, kr = sch.split(k0, [None, reduce_k]) + sch.reorder(i0, j0, i1, j1, i2, j2, kr, k0, k1, i3, j3) + else: + sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) block_idy = sch.fuse(i0, j0) block_idx = sch.fuse(i1, j1) @@ -500,8 +507,12 @@ def can_enable_swizzle(dtype: str, smooth: bool): sch.bind(batch, "blockIdx.z") sch.bind(block_idx, "blockIdx.x") sch.bind(block_idy, "blockIdx.y") - sch.bind(thread_idy, "threadIdx.y") - sch.bind(thread_idz, "threadIdx.z") + if reduce_k > 1: + thread_idz = j2 = thread_idy = sch.fuse(thread_idy, thread_idz) + sch.bind(kr, "threadIdx.z") + else: + sch.bind(thread_idy, "threadIdx.y") + sch.bind(thread_idz, "threadIdx.z") # rewrite smooth layout of shared memory def smooth_smem_layout_rewrite(block, scope, l=16, r=16, enable=True): # noqa: E741 @@ -525,18 +536,26 @@ def smooth_smem_layout_rewrite(block, scope, l=16, r=16, enable=True): # noqa: block_outer, ("read", 1), *b_lr, enable=intrin_info.inter_transform_b) smooth_smem_layout_rewrite(block_outer, ("write", 0), enable=True) - def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False, trans=False): + def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False, reduce_k=1): block_read = sch.cache_read(block, idx, shared_scope) sch.compute_at(block_read, k0, preserve_unit_loops=True) ndim = len(sch.get(block_read).iter_vars) fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - f_0, f_1, f_2, f_3, f_4 = sch.split( - fused, factors=[num_ty, num_tz, None, warp_size, vec_len]) + if reduce_k > 1: + f_r, f_0, f_1, f_2, f_3, f_4 = sch.split( + fused, factors=[reduce_k, num_ty, num_tz, None, warp_size, vec_len]) + sch.bind(f_3, "threadIdx.x") + f_0 = f_1 = sch.fuse(f_0, f_1) + sch.bind(f_0, "threadIdx.y") + sch.bind(f_r, "threadIdx.z") + else: + f_0, f_1, f_2, f_3, f_4 = sch.split( + fused, factors=[num_ty, num_tz, None, warp_size, vec_len]) + sch.bind(f_3, "threadIdx.x") + sch.bind(f_1, "threadIdx.z") + sch.bind(f_0, "threadIdx.y") - sch.bind(f_3, "threadIdx.x") - sch.bind(f_1, "threadIdx.z") - sch.bind(f_0, "threadIdx.y") sch.vectorize(f_4) sch.unroll(f_2) # Apply Swizzling @@ -557,7 +576,7 @@ def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False, tra vec_len=list(config.vectorize.values())[0], can_swizzle=can_swizzle_a, is_smooth=intrin_info.smooth_a, - trans=intrin_info.trans_a, + reduce_k=reduce_k, ) b_g2s = fetch_to_shared( block_outer, @@ -565,7 +584,7 @@ def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False, tra vec_len=list(config.vectorize.values())[1], can_swizzle=can_swizzle_b, is_smooth=intrin_info.smooth_b, - trans=intrin_info.trans_b, + reduce_k=reduce_k, ) # rewrite global smooth layout @@ -666,376 +685,16 @@ def inverse_permutation(i, j, ii, jj): i0, i1 = sch.split(i, factors=[None, b_lr[0]]) j0, j1 = sch.split(j, factors=[None, b_lr[1]]) 
sch.reorder(i0, j0, i1, j1) - bb = sch.blockize(i1) - sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) - sch.tensorize(bb, intrin_group["load_b"]) - - def tensorize_init_store_compute(): - sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) - sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) - sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) - - tensorize_init_store_compute() - - if stage > 1: - sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1]) - sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) - if use_async: - sch.annotate(k0, "software_pipeline_async_stages", [0]) - - # plan rasteration - if not isinstance(config.rasterization_plan, NoRasterization): - device_func, invoke_func = config.rasterization_plan.get_code() - import_source.append(device_func) - sch.annotate( - sch.get_loops(block_init_c)[-2], - ann_key="inject_customized_code_prepend", - ann_val=invoke_func, - ) - # plan import source - if len(import_source) > 0: - sch.annotate( - thread_idz, - ann_key="pragma_import_c", - ann_val=("\n").join(import_source), - ) - return sch - - def apply_block_reduction_with_config( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - config: Hint, - ) -> Optional[tir.Schedule]: - if "dequantize_info" in func.attrs: - dequantize_rule = MatmulTensorizationMMAWithDequantizeInfo() - return dequantize_rule.sch_shared_memory_prefetch_block_reduction_with_config( - func, config) - - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_mma_intrin_group,) - - import_source: List[str] = [] - - sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): - return None - - reduction_blocks = get_reduction_blocks(sch, blocks) - if reduction_blocks is None: - return None - - main_block = reduction_blocks[0] - - output_blocks = [sch.get(block) for block in sch.get_output_blocks(root_block)] - - def check_require_cache(func: tir.PrimFunc, config): - conditions: List[bool] = [] - - # check if has dynamic symbolic - def check_has_dynamic(func: tir.PrimFunc): - for param in func.params: - if param not in func.buffer_map: - continue - arg = func.buffer_map[param] - for i in arg.shape: - if isinstance(i, tir.Var): - return True - return False - - conditions.append(check_has_dynamic(func)) - # check if has post process - conditions.append(sch.get(main_block) not in output_blocks) - # check if not use async copy - conditions.append(config.use_async is False) - return any(conditions) - - cache_write_required = check_require_cache(func, config=config) - - # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] - if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): - sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) - - shared_scope = config.shared_scope - - intrin_info = config.intrin_info - intrin_group = get_mma_intrin_group( - load_scope=shared_scope, - store_scope=shared_scope if cache_write_required else "global", - a_dtype=intrin_info.in_dtype, - b_dtype=intrin_info.in_dtype, - out_dtype=intrin_info.out_dtype, - trans_a=intrin_info.trans_a, - trans_b=intrin_info.trans_b, - smooth_a=intrin_info.smooth_a, - smooth_b=intrin_info.smooth_b, - not_use_mma_store_intrinic=False, - ) - - # Start Schedule - # Step 0. 
Get schedule config. - warp_row_tiles = config.warp[0] - warp_col_tiles = config.warp[1] - block_row_warps = config.block[0] // warp_row_tiles - block_col_warps = config.block[1] // warp_col_tiles - stage = config.pipeline_stage - use_async = config.use_async - assert (config.block_reduction_depth is not None), "block_reduction_depth is required" - reduce_k = config.block_reduction_depth - chunk = config.rstep[0] // reduce_k - - # tensor core intrinsic size - micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] - - # get the axis for layout transform - def get_axis(l, r, trans): # noqa: E741 - return (r, l) if trans else (l, r) # noqa: E741 - - a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) - b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) - - def can_enable_swizzle(dtype: str, smooth: bool): - # inject_permuted_layout only support float16 currently - if dtype == "float16" or dtype == "int8": - # introduce the constraint of reduce_k because reduce_k will doubling the size of - if (chunk * reduce_k) * DataType(dtype).bits != (512): - # currently the swizzle rule only support 512 bit. - return False - # if we use smooth layout, we don't need to do swizzling - return not smooth - return False - - can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a) - can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b) - - warp_size = 32 - - i_factors, j_factors, k_factors = ( - [None, 1, block_row_warps, warp_row_tiles // micro_size_x], - [1, None, block_col_warps, warp_col_tiles // micro_size_y], - [None, (reduce_k * chunk) // micro_size_k], - ) - - num_ty = i_factors[2] - num_tz = j_factors[2] - x_pad_factor = i_factors[2] * i_factors[3] - y_pad_factor = j_factors[2] * j_factors[3] - k_pad_factor = k_factors[1] - - # Step 2. Padding for dynamic shape kernels - sch.pad_einsum( - main_block, - [ - 1, - micro_size_x * x_pad_factor, - micro_size_y * y_pad_factor, - micro_size_k * k_pad_factor, - ], - ) - - # Step 3. 
Schedule matmul to use tensor core - block = main_block - - batch, i, j, k = sch.get_loops(block) - - # inner loops for tensor core computation - i, i_inner = sch.split(i, factors=[None, micro_size_x]) - j, j_inner = sch.split(j, factors=[None, micro_size_y]) - k, k_inner = sch.split(k, factors=[None, micro_size_k]) - - sch.reorder(i, j, k, i_inner, j_inner, k_inner) - - block_inner = block - block_outer = sch.blockize(i_inner) - - i0, i1, i2, i3 = sch.split(i, factors=i_factors) - j0, j1, j2, j3 = sch.split(j, factors=j_factors) - k0, k1 = sch.split(k, k_factors) - k0, kr = sch.split(k0, [None, reduce_k]) - - sch.reorder(i0, j0, i1, j1, i2, j2, kr, i3, j3, k0, k1) - - block_idy = sch.fuse(i0, j0) - block_idx = sch.fuse(i1, j1) - thread_idy = i2 - thread_idz = j2 - - sch.bind(batch, "blockIdx.z") - sch.bind(block_idx, "blockIdx.x") - sch.bind(block_idy, "blockIdx.y") - thread_idz = j2 = thread_idy = sch.fuse(thread_idy, thread_idz) - sch.bind(thread_idy, "threadIdx.y") - sch.bind(kr, "threadIdx.z") - - # rewrite smooth layout of shared memory - def smooth_smem_layout_rewrite(block, scope, l=16, r=16, enable=True): # noqa: E741 - if not enable: - return - sch.transform_layout( - block, - scope, - lambda b, i, j: ( - b, - i // l, - j // r, - i % l, - j % r, - ), - ) - - smooth_smem_layout_rewrite( - block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a) - smooth_smem_layout_rewrite( - block_outer, ("read", 1), *b_lr, enable=intrin_info.inter_transform_b) - smooth_smem_layout_rewrite(block_outer, ("write", 0), enable=True) - - def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False, trans=False): - block_read = sch.cache_read(block, idx, shared_scope) - sch.compute_at(block_read, k0, preserve_unit_loops=True) - ndim = len(sch.get(block_read).iter_vars) - fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - - f_r, f_0, f_1, f_2, f_3, f_4 = sch.split( - fused, factors=[reduce_k, num_ty, num_tz, None, warp_size, vec_len]) - - sch.bind(f_3, "threadIdx.x") - f_0 = f_1 = sch.fuse(f_0, f_1) - sch.bind(f_0, "threadIdx.y") - sch.bind(f_r, "threadIdx.z") - sch.vectorize(f_4) - sch.unroll(f_2) - # Apply Swizzling - sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) - # if not, apply padding to alleviate bank conflict - if not (can_swizzle or is_smooth): - pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 - sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) - sch.annotate(f_2, "pragma_unroll_explicit", False) - return block_read - - if len(config.vectorize.values()) < 2: - return None - - a_g2s = fetch_to_shared( - block_outer, - 0, - vec_len=list(config.vectorize.values())[0], - can_swizzle=can_swizzle_a, - is_smooth=intrin_info.smooth_a, - trans=intrin_info.trans_a, - ) - b_g2s = fetch_to_shared( - block_outer, - 1, - vec_len=list(config.vectorize.values())[1], - can_swizzle=can_swizzle_b, - is_smooth=intrin_info.smooth_b, - trans=intrin_info.trans_b, - ) - - # rewrite global smooth layout - def smooth_gmem_layout_rewrite(sch, block, enable=True, trans=False, matrix_name="A"): - if not enable: - return - # step1: find the first producer block - # Notes: we assume the layout propagate happens in the first producer block - # otherwise, the layout transform will have no effect as it will transform both - # read and write buffer - producers = _collect_producers(sch, block) - g2s_block = a_g2s if matrix_name == "A" else b_g2s - propagate_block: tir.Block = (producers[-1] if len(producers) > 0 else g2s_block) - - # 
step2: transform the layout with inverse permutation - intra_indexmap, _ = get_propagate_map( - trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) - - def inverse_permutation(i, j, ii, jj): - return (i, j, *intra_indexmap.map_indices([ii, jj])) - - sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) - - smooth_gmem_layout_rewrite( - sch, a_g2s, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A") - smooth_gmem_layout_rewrite( - sch, b_g2s, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B") - auto_inline_producers(sch, a_g2s) - auto_inline_producers(sch, b_g2s) - - # create read cache to load matrix from shared memory to wmma fragments - A_mat = sch.cache_read(block_outer, 0, "warp") - B_mat = sch.cache_read(block_outer, 1, "warp") - sch.compute_at(A_mat, k1) - sch.compute_at(B_mat, k1) - - # create write cache to store matrix from wmma fragments to shared memory and global memory - if cache_write_required: - accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope) - - store = sch.cache_write(block_outer, 0, "warp") - sch.reverse_compute_at(store, j2) - - # split the store loop to match hardware intrinsic pattern - i, j = sch.get_loops(store)[-2:] - i0, i1 = sch.split(i, factors=[None, micro_size_x], preserve_unit_iters=False) - j0, j1 = sch.split(j, factors=[None, micro_size_y], preserve_unit_iters=False) - sch.reorder(i0, j0, i1, j1) - - if cache_write_required: - auto_inline_consumer_chain(sch, accumulator_shared_to_global) - sch.reverse_compute_at( - accumulator_shared_to_global, - sch.get_loops(store)[-5], - preserve_unit_loops=True, - ) - vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) - fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:]) + if weight_transform_kind >= TransformKind.LDMatrixTransform: + fused = sch.fuse(i1, j1) + vec_len = get_coalesced_veclen(sch.get(B_mat)) f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len]) sch.bind(f1, "threadIdx.x") sch.vectorize(f2) - sch.unroll(f0) - sch.annotate(f0, "pragma_unroll_explicit", False) else: - auto_inline_consumer_chain(sch, store) - - block_init_c = sch.decompose_reduction(block_outer, k0) - block_init_c_inner = sch.get_child_blocks(block_init_c)[0] - - # Tensorization by hardware intrinsics - index_map_a, index_map_b, index_map_c = intrin_group["index_map"] - - sch.transform_layout( - A_mat, - ("write", 0), - get_warp_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a), - ) - sch.transform_layout( - B_mat, - ("write", 0), - get_warp_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b), - ) - sch.transform_layout( - store, - ("read", 0), - get_warp_index_map(index_map_c, is_5d=True), - ) - - i, j = sch.get_loops(A_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, a_lr[0]]) - j0, j1 = sch.split(j, factors=[None, a_lr[1]]) - sch.reorder(i0, j0, i1, j1) - ba = sch.blockize(i1) - # sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a) - sch.tensorize(ba, intrin_group["load_a"]) - - i, j = sch.get_loops(B_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, b_lr[0]]) - j0, j1 = sch.split(j, factors=[None, b_lr[1]]) - sch.reorder(i0, j0, i1, j1) - bb = sch.blockize(i1) - # sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) - sch.tensorize(bb, intrin_group["load_b"]) + bb = sch.blockize(i1) + sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) + sch.tensorize(bb, intrin_group["load_b"]) def tensorize_init_store_compute(): 
sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py index 3a8cf048a..b9a0167ab 100644 --- a/bitblas/gpu/matmul_mma_dequantize.py +++ b/bitblas/gpu/matmul_mma_dequantize.py @@ -7,14 +7,14 @@ from contextlib import suppress from tvm import tir, DataType - -from bitblas.base.roller.hint import Hint, IntrinInfo from tvm.target import Target -from bitblas.base.roller.rasterization import NoRasterization -from bitblas.base import analysis -from bitblas.gpu.base import GPUScheduleRule -from bitblas.base.analysis import get_coalesced_veclen -from bitblas.gpu.matmul_analysis import ( + +from ..base.roller.hint import Hint, IntrinInfo +from ..base.roller.rasterization import NoRasterization +from ..base import analysis +from .base import GPUScheduleRule +from ..base.analysis import get_coalesced_veclen +from .matmul_analysis import ( auto_inline_consumer_chain, auto_inline_producers, get_reduction_blocks, diff --git a/bitblas/ops/general_matmul/tirscript/matmul_impl.py b/bitblas/ops/general_matmul/tirscript/matmul_impl.py index b093f0d9c..6a3e1de2d 100644 --- a/bitblas/ops/general_matmul/tirscript/matmul_impl.py +++ b/bitblas/ops/general_matmul/tirscript/matmul_impl.py @@ -5,6 +5,7 @@ from tvm import te from bitblas.gpu.matmul_analysis import get_propagate_map from bitblas.ops.operator import TransformKind +from typing import Union def matmul_nn( @@ -307,9 +308,14 @@ def select_implementation( accum_dtype="float16", with_bias=False, layout="nt", - propagate_a: TransformKind = TransformKind.NonTransform, - propagate_b: TransformKind = TransformKind.NonTransform, + propagate_a: Union[int, TransformKind] = TransformKind.NonTransform, + propagate_b: Union[int, TransformKind] = TransformKind.NonTransform, ): + if isinstance(propagate_a, int): + propagate_a = TransformKind(propagate_a) + if isinstance(propagate_b, int): + propagate_b = TransformKind(propagate_b) + if layout == "nn": if propagate_a or propagate_b: raise ValueError( diff --git a/bitblas/ops/ladder_permutate/ladder_permutate_impl.py b/bitblas/ops/ladder_permutate/ladder_permutate_impl.py index 76b5a01fb..2dcf79bcb 100644 --- a/bitblas/ops/ladder_permutate/ladder_permutate_impl.py +++ b/bitblas/ops/ladder_permutate/ladder_permutate_impl.py @@ -1,6 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from bitblas.gpu.matmul_analysis import get_propagate_map +from bitblas.gpu.matmul_analysis import ( + get_propagate_map, + get_ladder_stage3_map, +) from typing import Literal from tvm import te, IRModule, DataType from tvm.tir import IndexMap @@ -28,6 +31,8 @@ def select_implementation( intra_index_map, _ = get_propagate_map( transpose_matrix, dtype=datatype, matrix_name=propagate_kind) + ladder_stage3_map, _ = get_ladder_stage3_map(dtype=datatype) + target_dtype = DataType(datatype) scaling_factor = 1 if dequantize_bits > 0 and dequantize_bits < target_dtype.bits: @@ -46,6 +51,20 @@ def select_implementation( None, ) + initial_indices = ladder_stage3_map.initial_indices + scaling_final_indices = ladder_stage3_map.map_indices( + initial_indices[:-1] + [initial_indices[-1] * scaling_factor]) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + ladder_stage3_map = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + ladder_stage3_map_inverse = ladder_stage3_map.inverse([l, r]) + inp = te.placeholder((M, N // scaling_factor), name="inp", dtype=storage_dtype) args = [inp] @@ -75,6 +94,22 @@ def fcompute(*args): name="intra_warp_permutate", ) args.append(intra_warp) + if transform_kind >= 3: + arg = args[-1] + + def fcompute(*args): + warp_i, warp_j = args[-2:] + spatial_args = args[:-2] + permutate_i, permutate_j = ladder_stage3_map_inverse.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, permutate_i, permutate_j) + return arg[new_index] + + out = te.compute( + (M // l, (N // scaling_factor) // r, l, r), + fcompute, + name="permutate", + ) + args.append(out) args = [args[0], args[-1]] func = te.create_prim_func(args) diff --git a/bitblas/ops/lop3_permutate/__init__.py b/bitblas/ops/lop3_permutate/__init__.py index 7715be471..10c452b3d 100644 --- a/bitblas/ops/lop3_permutate/__init__.py +++ b/bitblas/ops/lop3_permutate/__init__.py @@ -13,7 +13,7 @@ class LOP3PermutateConfig: M: int N: int datatype: Literal["float16", "int8"] = "float16" - storage_dtype: Literal["int8", "uint8", "int32", "uint32"] = "int32" + storage_dtype: Literal["int8", "uint8", "int32", "uint32"] = "int8" dequantize_bits: int = 4 @@ -28,10 +28,10 @@ def __init__( # consider to warp the arguments to MatmulConfig super().__init__(name, config, target) + target = self.target if target.kind.name != "llvm": raise ValueError("Currently only support llvm target for Permutation") - self.target = target self._build_runtime_module(target) def _select_implementation(self): diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index d35476ee5..1ca5cf0a3 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -10,7 +10,7 @@ import ctypes from typing import List, Dict, Any, Optional import numpy as np -from ..base import fast_tune, fast_tune_with_dynamic_range +from bitblas.base import fast_tune, fast_tune_with_dynamic_range from copy import deepcopy from bitblas.base.arch import get_arch from bitblas.utils.tensor_adapter import tvm_tensor_to_torch @@ -27,6 +27,7 @@ class TransformKind(IntEnum): NonTransform = 0 InterWarpTransform = 1 IntraWarpTransform = 2 + LDMatrixTransform = 3 @dataclass(frozen=True) diff --git a/bitblas/utils/__init__.py b/bitblas/utils/__init__.py index d4ded65e6..6f2f95e3b 100644 --- a/bitblas/utils/__init__.py +++ b/bitblas/utils/__init__.py @@ -4,6 +4,7 @@ from .tensor_adapter import tvm_tensor_to_torch, lazy_tvm_tensor_to_torch, lazy_torch_to_tvm_tensor # noqa: F401 from .target_detector import 
get_all_nvidia_targets, auto_detect_nvidia_target # noqa: F401 from .rtmod_analysis import get_annotated_device_mod # noqa: F401 +from .weight_propagate import apply_transform_on_input # noqa: F401 import os import subprocess diff --git a/bitblas/utils/weight_propagate.py b/bitblas/utils/weight_propagate.py new file mode 100644 index 000000000..90cbe3274 --- /dev/null +++ b/bitblas/utils/weight_propagate.py @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from bitblas import tvm +from tvm import te +from tvm.tir import IndexMap +from tvm.contrib.dlpack import to_pytorch_func +import torch + + +def apply_transform_on_input(input: torch.Tensor, index_map: IndexMap) -> torch.Tensor: + dtype = str(input.dtype).split(".")[1] + inp = te.placeholder(input.shape, name="inp", dtype=dtype) + args = [inp] + arg = args[-1] + + def fcompute(*args): + warp_i, warp_j = args[-2:] + spatial_args = args[:-2] + permutate_i, permutate_j = index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, permutate_i, permutate_j) + return arg[new_index] + + out = te.compute( + input.shape, + fcompute, + name="permutate", + ) + args.append(out) + func = te.create_prim_func(args) + rt_mod = tvm.build(func, target="llvm", name="permutate") + output = torch.zeros_like(input) + torch_func = to_pytorch_func(rt_mod) + torch_func(input, output) + + return output diff --git a/requirements-test.txt b/requirements-test.txt new file mode 100644 index 000000000..194cb1ba8 --- /dev/null +++ b/requirements-test.txt @@ -0,0 +1,33 @@ +# formatting +yapf==0.32.0 +toml==0.10.2 +tomli==2.0.1 +ruff==0.1.5 +codespell==2.3.0 + +cffi +cpplint +Cython +decorator +docutils +dtlib +numpy>=1.23.5 +pytest>=6.2.4 +pytest_xdist>=2.2.1 +packaging>=21.0 +PyYAML +tqdm>=4.62.3 +typing_extensions>=4.10.0 +requests +attrs +cloudpickle +ml_dtypes +psutil +scipy +tornado +torch +thefuzz +tabulate +wheel +setuptools +auto-gptq diff --git a/testing/python/module/test_repack_from_gptq.py b/testing/python/module/test_repack_from_gptq.py new file mode 100644 index 000000000..f613acc93 --- /dev/null +++ b/testing/python/module/test_repack_from_gptq.py @@ -0,0 +1,73 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+import bitblas +import torch + +try: + import auto_gptq # noqa +except ImportError as e: + raise ImportError("Please install auto-gptq by running `pip install auto-gptq`") from e + +from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import ( + QuantLinear as CudaOldQuantLinear,) + +torch.manual_seed(0) +bitblas.set_log_level("DEBUG") + + +def assert_output_with_gptq(m, in_features, out_features, group_size): + _, linear, s, _ = bitblas.quantization.gen_quant4(in_features, out_features, group_size) + + zeros = torch.full((in_features // group_size, out_features), 7, dtype=torch.int32) + + cuda_old_linear = CudaOldQuantLinear( + bits=4, + group_size=group_size, + infeatures=in_features, + outfeatures=out_features, + bias=False, + ) + + cuda_old_linear.pack(linear, s.T, zeros.T, g_idx=None) + + bitblas_linear = bitblas.Linear( + opt_M=m, + in_features=in_features, + out_features=out_features, + bias=False, + A_dtype="float16", # activation A dtype + W_dtype="uint4", # weight W dtype + accum_dtype="float16", # accumulation dtype + out_dtype="float16", # output dtype + # configs for weight only quantization + group_size=group_size, # setting for grouped quantization + with_scaling=True, # setting for scaling factor + with_zeros=True, # setting for zeros + zeros_mode="quantized", # setting for how to calculating zeros + ) + # Repack weights from CudaOldQuantLinear to BitBLAS linear module + bitblas_linear.repack_from_gptq(cuda_old_linear) + + # Prepare input data + inp = torch.rand(m, in_features, dtype=torch.float16, device="cuda") + + # Move models to CUDA for execution + cuda_old_linear = cuda_old_linear.to("cuda") + bitblas_linear = bitblas_linear.to("cuda") + + # Perform inference without gradient calculations + with torch.no_grad(): + res_cuda_old = cuda_old_linear(inp) + res_bitblas = bitblas_linear(inp) + + # Verify the outputs are close within specified tolerances + torch.testing.assert_close(res_bitblas, res_cuda_old, rtol=1e-0, atol=1e-1) + + +def test_assert_output_with_gptq(): + assert_output_with_gptq(1, 256, 256, 64) + assert_output_with_gptq(1, 256, 256, -1) + + +if __name__ == "__main__": + bitblas.testing.main() diff --git a/testing/python/operators/test_general_matmul_splitk_ops.py b/testing/python/operators/test_general_matmul_splitk_ops.py index fcdf90239..3183efb8f 100644 --- a/testing/python/operators/test_general_matmul_splitk_ops.py +++ b/testing/python/operators/test_general_matmul_splitk_ops.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-import pytest + import bitblas from bitblas.ops.general_matmul_splitk import MatmulWithSplitK, MatmulConfigWithSplitK @@ -11,8 +11,8 @@ def get_codegen_result(ops): # fmt: off -def matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, - with_bias, group_size, with_scaling, with_zeros, zeros_mode): +def matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, + group_size, with_scaling, with_zeros, zeros_mode): matmul_config = MatmulConfigWithSplitK( M=M, @@ -36,15 +36,15 @@ def matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, la def test_matmul_codegen_default(): - matmul_codegen_default(1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, - None) - matmul_codegen_default(16, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, - None) + matmul_codegen_default(1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, + -1, False, False, None) + matmul_codegen_default(16, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, + -1, False, False, None) def matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, - layout, with_bias, group_size, with_scaling, with_zeros, - zeros_mode): + layout, with_bias, group_size, with_scaling, with_zeros, + zeros_mode): import torch torch.random.manual_seed(0) matmul_config = MatmulConfigWithSplitK( @@ -77,15 +77,16 @@ def matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accum_dty output_torch = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1]) torch.testing.assert_close(output_bitblas, output_torch, rtol=1e-2, atol=1e-1) + def test_matmul_torch_forward_consistent(): - matmul_torch_forward_consistent(1, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, - False, None) - matmul_torch_forward_consistent(4, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, - False, None) - -def matmul_torch_forward_fp8e4m3(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, - layout, with_bias, group_size, with_scaling, with_zeros, - zeros_mode): + matmul_torch_forward_consistent(1, 1, 4096, 12800, "float16", "float16", "float16", "float16", + "nt", False, -1, False, False, None) + matmul_torch_forward_consistent(4, 1, 4096, 12800, "float16", "float16", "float16", "float16", + "nt", False, -1, False, False, None) + + +def matmul_torch_forward_fp8e4m3(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode): import torch torch.random.manual_seed(0) matmul_config = MatmulConfigWithSplitK( @@ -135,7 +136,7 @@ def map_torch_type(intype): matmul.forward(torch_a, torch_b) print("torch_ref_out", ref_out) print("bitblas_out", bitblas_out) - + matmul.forward(torch_a, torch_b, output=bitblas_out) print("torch_ref_out", ref_out) print("bitblas_out", bitblas_out) @@ -146,12 +147,14 @@ def map_torch_type(intype): torch.testing.assert_close(bitblas_out, ref_out, rtol=1e0, atol=1e-1) + @bitblas.testing.requires_cuda_compute_version(8, 9) def test_matmul_torch_forward_fp8e4m3(): - matmul_torch_forward_fp8e4m3(1, 16, 4096, 12800, "e4m3_float8", "e4m3_float8", "float32", "float16", "nt", False, -1, False, - False, None) - matmul_torch_forward_fp8e4m3(4, 16, 4096, 12800, "e4m3_float8", "e4m3_float8", "float32", "float16", "nt", False, -1, False, - False, None) + 
matmul_torch_forward_fp8e4m3(1, 16, 4096, 12800, "e4m3_float8", "e4m3_float8", "float32", + "float16", "nt", False, -1, False, False, None) + matmul_torch_forward_fp8e4m3(4, 16, 4096, 12800, "e4m3_float8", "e4m3_float8", "float32", + "float16", "nt", False, -1, False, False, None) + # fmt: on if __name__ == "__main__": diff --git a/testing/python/operators/test_general_matmul_tile_schedule.py b/testing/python/operators/test_general_matmul_tile_schedule.py new file mode 100644 index 000000000..5e7859075 --- /dev/null +++ b/testing/python/operators/test_general_matmul_tile_schedule.py @@ -0,0 +1,226 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import bitblas +from bitblas import tvm +from bitblas.ops.general_matmul.tirscript import ( + matmul_select_implementation,) +import logging +from bitblas import set_log_level + +set_log_level(logging.DEBUG) + + +def assert_correctness_with_block_reduce( + M=None, + N=256, + K=256, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + propagate_a=0, + propagate_b=0, +): + matmul_func = matmul_select_implementation( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + propagate_a=propagate_a, + propagate_b=propagate_b)["main"] + target = bitblas.auto_detect_nvidia_target() + intrin_info = bitblas.base.hint.IntrinInfo( + in_dtype=in_dtype, + out_dtype=accum_dtype, + trans_b=True, + input_transform_kind=propagate_a, + weight_transform_kind=propagate_b, + ) + arch = bitblas.base.CUDA(target=target) + ref_sch = bitblas.gpu.MatmulTensorizationMMA().apply_config( + matmul_func, + config=bitblas.base.Hint.from_dict({ + "arch": arch, + "block": [16, 128], + "warp": [16, 32], + "rstep": [128], + "pipeline_stage": 4, + "use_async": True, + "intrin_info": intrin_info, + "shared_scope": "shared.dyn", + "vectorize": { + "b": 8, + "a": 8 + }, + }), + ) + with tvm.transform.PassContext(config={ + "tir.use_async_copy": True, + "tir.merge_static_smem": False + }): + ref_rt_mod = tvm.build(ref_sch.mod, target=target) + + block_reduce_sch = bitblas.gpu.MatmulTensorizationMMA().apply_config( + matmul_func, + config=bitblas.base.Hint.from_dict({ + "arch": arch, + "block": [16, 128], + "warp": [16, 32], + "rstep": [128], + "pipeline_stage": 4, + "use_async": True, + "intrin_info": intrin_info, + "shared_scope": "shared.dyn", + "vectorize": { + "b": 8, + "a": 8 + }, + "block_reduction_depth": 2, + }), + ) + with tvm.transform.PassContext(config={ + "tir.use_async_copy": True, + "tir.merge_static_smem": False + }): + block_reduce_rt_mod = tvm.build(block_reduce_sch.mod, target=target) + + # Check correctness + import numpy as np + tvm_a = tvm.nd.array(np.random.randn(M, K).astype(in_dtype), device=tvm.cuda()) + tvm_b = tvm.nd.array(np.random.randn(N, K).astype(in_dtype), device=tvm.cuda()) + tvm_c = tvm.nd.array(np.random.randn(M, N).astype(out_dtype), device=tvm.cuda()) + tvm_c_ref = tvm.nd.array(np.zeros((M, N)).astype(out_dtype), device=tvm.cuda()) + + ref_rt_mod(tvm_a, tvm_b, tvm_c_ref) + + block_reduce_rt_mod(tvm_a, tvm_b, tvm_c) + np.testing.assert_allclose(tvm_c.asnumpy(), tvm_c_ref.asnumpy(), rtol=1e-3, atol=1e-3) + + +def test_assert_correctness_with_block_reduce(): + assert_correctness_with_block_reduce( + M=256, + N=256, + K=256, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + propagate_a=0, + propagate_b=0) + assert_correctness_with_block_reduce( + M=256, + N=256, + K=256, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + 
propagate_a=0, + propagate_b=2) + assert_correctness_with_block_reduce( + M=256, + N=256, + K=256, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + propagate_a=2, + propagate_b=2) + + +def assert_correctness_with_ladder_ldmatrix_propagate( + M=None, + N=256, + K=256, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + block_reduction_depth=1, +): + matmul_func = matmul_select_implementation( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + propagate_a=0, + propagate_b=3)["main"] + propagate_b = 3 + target = bitblas.auto_detect_nvidia_target() + intrin_info = bitblas.base.hint.IntrinInfo( + in_dtype=in_dtype, + out_dtype=accum_dtype, + trans_b=True, + input_transform_kind=0, + weight_transform_kind=propagate_b, + ) + arch = bitblas.base.CUDA(target=target) + block_reduce_sch = bitblas.gpu.MatmulTensorizationMMA().apply_config( + matmul_func, + config=bitblas.base.Hint.from_dict({ + "arch": arch, + "block": [16, 128], + "warp": [16, 32], + "rstep": [128], + "pipeline_stage": 4, + "use_async": True, + "intrin_info": intrin_info, + "shared_scope": "shared.dyn", + "vectorize": { + "b": 8, + "a": 8 + }, + "block_reduction_depth": block_reduction_depth, + }), + ) + with tvm.transform.PassContext(config={ + "tir.use_async_copy": True, + "tir.merge_static_smem": False + }): + block_reduce_rt_mod = tvm.build(block_reduce_sch.mod, target=target) + + # Evaluate the correctness + import numpy as np + a = np.random.randn(M, K).astype(np.float16 if in_dtype == "float16" else "int8") + b = np.random.randn(N, K).astype(np.float16 if in_dtype == "float16" else "int8") + c = np.random.randn(M, N).astype(np.float16 if in_dtype == "float16" else "int8") + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=K, + datatype=in_dtype, + storage_dtype=in_dtype, + transform_kind=propagate_b, + transpose_matrix=True, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + tvm_b = tvm.nd.array(b) + tvm_transformed_b = ladder_permutate.get_profile_tensors()[-1] + ladder_permutate.rt_mod(tvm_b, tvm_transformed_b) + ladder_transformed_b = tvm_transformed_b.asnumpy() + + # transformed_b = b + tvm_a = tvm.nd.array(a, device=tvm.cuda(0)) + tvm_b = tvm.nd.array(ladder_transformed_b, device=tvm.cuda(0)) + tvm_c = tvm.nd.array(c, device=tvm.cuda(0)) + + block_reduce_rt_mod(tvm_a, tvm_b, tvm_c) + print("removed ldmatrix b output is \n", tvm_c) + + np_c = np.dot(a, b.T) + print("numpy output is \n", np_c) + + +def test_assert_correctness_with_ladder_ldmatrix_propagate(): + assert_correctness_with_ladder_ldmatrix_propagate( + M=256, N=256, K=256, in_dtype="float16", out_dtype="float16", accum_dtype="float16") + assert_correctness_with_ladder_ldmatrix_propagate( + M=256, N=256, K=256, in_dtype="int8", out_dtype="int8", accum_dtype="int32") + + +# fmt: on +if __name__ == "__main__": + bitblas.testing.main() diff --git a/testing/python/operators/test_ladder_permutate_ops.py b/testing/python/operators/test_ladder_permutate_ops.py index 8fa54a4ca..e9fe452eb 100644 --- a/testing/python/operators/test_ladder_permutate_ops.py +++ b/testing/python/operators/test_ladder_permutate_ops.py @@ -1,6 +1,5 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-import pytest import bitblas from bitblas.ops.ladder_permutate import LadderPermutate, LadderPermutateConfig from bitblas import tvm @@ -41,9 +40,14 @@ def ladder_permutate_profile_latency( def test_ladder_permutate_profile_latency(): - ladder_permutate_profile_latency(1024, 1024, "float16", -1, "float16", "B", True, 1, "nvidia-mma") - ladder_permutate_profile_latency(1024, 1024, "float16", -1, "float16", "B", True, 2, "nvidia-mma") + ladder_permutate_profile_latency(1024, 1024, "float16", -1, "float16", "B", True, 1, + "nvidia-mma") + ladder_permutate_profile_latency(1024, 1024, "float16", -1, "float16", "B", True, 2, + "nvidia-mma") + ladder_permutate_profile_latency(1024, 1024, "float16", -1, "float16", "B", True, 3, + "nvidia-mma") ladder_permutate_profile_latency(1024, 1024, "float16", 4, "uint32", "B", True, 2, "nvidia-mma") + ladder_permutate_profile_latency(1024, 1024, "float16", 4, "uint32", "B", True, 3, "nvidia-mma") def ladder_permutate_profile_latency_cuda( @@ -73,16 +77,19 @@ def ladder_permutate_profile_latency_cuda( config=ladder_permutate_config, target="cuda", ) - # ladder_permutate.hardware_aware_finetune() latency = ladder_permutate.profile_latency() - print(latency) assert latency def test_ladder_permutate_profile_latency_cuda(): - ladder_permutate_profile_latency_cuda(1024, 1024, "float16", -1, "float16", "A", True, 1, "nvidia-mma") - ladder_permutate_profile_latency_cuda(1024, 1024, "float16", -1, "float16", "A", True, 2, "nvidia-mma") - ladder_permutate_profile_latency_cuda(1024, 1024, "float16", 4, "uint32", "A", True, 2, "nvidia-mma") + ladder_permutate_profile_latency_cuda(1024, 1024, "float16", -1, "float16", "A", True, 1, + "nvidia-mma") + ladder_permutate_profile_latency_cuda(1024, 1024, "float16", -1, "float16", "A", True, 2, + "nvidia-mma") + ladder_permutate_profile_latency_cuda(1024, 1024, "float16", 4, "uint32", "A", True, 2, + "nvidia-mma") + + # fmt: on if __name__ == "__main__": diff --git a/testing/python/weight_transform/test_ladder_transform_stage3.py b/testing/python/weight_transform/test_ladder_transform_stage3.py new file mode 100644 index 000000000..1b9001cd0 --- /dev/null +++ b/testing/python/weight_transform/test_ladder_transform_stage3.py @@ -0,0 +1,90 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+import bitblas +import torch +from bitblas.gpu.matmul_analysis import (get_ladder_stage3_map) + +torch.manual_seed(0) + + +def compare_propagate_with_torch_iter_4_fp16(M, N, inner_m=16, inner_n=16): + b = torch.randn(M // inner_m, N // inner_n, inner_m, inner_n, dtype=torch.float16) + + def shared_32x8_to_mma_32x8_layout(i, j): + thread_id = (i % 8) * 4 + (j // 2) + local_id = (i // 8) * 2 + (j % 2) + return thread_id, local_id + + torch_transformed_b = torch.zeros((M // inner_m, N // inner_n, inner_m, inner_n), + dtype=torch.float16) + for i in range(M // inner_m): + for j in range(N // inner_n): + for ii in range(inner_m): + for jj in range(inner_n): + dummy_ii = (ii * inner_n + jj) // 8 + dummy_jj = (ii * inner_n + jj) % 8 + new_dummy_ii, new_dummy_jj = shared_32x8_to_mma_32x8_layout(dummy_ii, dummy_jj) + new_ii = (new_dummy_ii * 8 + new_dummy_jj) // inner_n + new_jj = (new_dummy_ii * 8 + new_dummy_jj) % inner_n + torch_transformed_b[i, j, new_ii, new_jj] = b[i, j, ii, jj] + + ladder_stage3_map, ladder_stage3_map_inverse = get_ladder_stage3_map(dtype="float16") + bitblas_transformed_b = bitblas.apply_transform_on_input(b, ladder_stage3_map_inverse) + + # assert cpu simulated and ladder compiled results are close + torch.testing.assert_close(torch_transformed_b, bitblas_transformed_b, rtol=1e-2, atol=1e-2) + + torch_recovered_b = bitblas.apply_transform_on_input(bitblas_transformed_b, ladder_stage3_map) + + # assert recovered results are close to original + torch.testing.assert_close(torch_recovered_b, b, rtol=1e-2, atol=1e-2) + + +def test_compare_propagate_with_torch_iter_4_fp16(): + compare_propagate_with_torch_iter_4_fp16(16, 16) + compare_propagate_with_torch_iter_4_fp16(32, 32) + compare_propagate_with_torch_iter_4_fp16(64, 64) + + +def compare_propagate_with_torch_iter_4_int8(M, N, inner_m=16, inner_n=32): + b = torch.randint(-127, 127, (M // inner_m, N // inner_n, inner_m, inner_n), dtype=torch.int8) + + def shared_32x16_to_mma_32x16_layout(i, j): + thread_id = (i % 8) * 4 + (j // 4) + local_id = (i // 8) * 4 + (j % 4) + return thread_id, local_id + + torch_transformed_b = torch.zeros((M // inner_m, N // inner_n, inner_m, inner_n), + dtype=torch.int8) + for i in range(M // inner_m): + for j in range(N // inner_n): + for ii in range(inner_m): + for jj in range(inner_n): + dummy_ii = (ii * inner_n + jj) // 16 + dummy_jj = (ii * inner_n + jj) % 16 + new_dummy_ii, new_dummy_jj = shared_32x16_to_mma_32x16_layout( + dummy_ii, dummy_jj) + new_ii = (new_dummy_ii * 16 + new_dummy_jj) // inner_n + new_jj = (new_dummy_ii * 16 + new_dummy_jj) % inner_n + torch_transformed_b[i, j, new_ii, new_jj] = b[i, j, ii, jj] + + ladder_stage3_map, ladder_stage3_map_inverse = get_ladder_stage3_map(dtype="int8") + bitblas_transformed_b = bitblas.apply_transform_on_input(b, ladder_stage3_map_inverse) + + # assert cpu simulated and ladder compiled results are close + torch.testing.assert_close(torch_transformed_b, bitblas_transformed_b, rtol=1e-2, atol=1e-2) + + torch_recovered_b = bitblas.apply_transform_on_input(bitblas_transformed_b, ladder_stage3_map) + + # assert recovered results are close to original + torch.testing.assert_close(torch_recovered_b, b, rtol=1e-2, atol=1e-2) + + +def test_compare_propagate_with_torch_iter_4_int8(): + compare_propagate_with_torch_iter_4_int8(16, 32) + compare_propagate_with_torch_iter_4_int8(32, 64) + compare_propagate_with_torch_iter_4_int8(64, 128) + + +if __name__ == "__main__": + bitblas.testing.main()
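
A note on the ladder-permutate select_implementation changes above (the new transform_kind >= 3 branch): the stage-3 map is applied as a gather, i.e. out[..., i, j] = inp[..., map(i, j)], using the inverse of the scaling-adjusted stage-3 IndexMap. The standalone sketch below reproduces that gather pattern with a toy 8x8 transpose map standing in for the real ladder layout; the toy map, shapes and dtypes are illustrative assumptions, not BitBLAS internals.

from bitblas import tvm
from tvm import te
from tvm.tir import IndexMap

# Toy stand-in for ladder_stage3_map: swap the two axes of an 8x8 tile.
toy_map = IndexMap.from_func(lambda i, j: [j, i])
# Same call shape as `ladder_stage3_map.inverse([l, r])` in the diff above.
toy_map_inverse = toy_map.inverse([8, 8])

inp = te.placeholder((8, 8), name="inp", dtype="float16")


def fcompute(i, j):
    # map_indices returns the source coordinates to read from, so the compute is a gather.
    permutate_i, permutate_j = toy_map_inverse.map_indices([i, j])
    return inp[permutate_i, permutate_j]


out = te.compute((8, 8), fcompute, name="permutate")
func = te.create_prim_func([inp, out])
rt_mod = tvm.build(func, target="llvm")  # a host build is enough to sanity-check the layout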
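
testing/python/module/test_repack_from_gptq.py above exercises bitblas.Linear.repack_from_gptq on a single CudaOldQuantLinear. As a follow-up, here is a hedged sketch of applying the same path to every GPTQ layer of a torch model; the helper name, the 4-bit/no-bias assumption and the zeros_mode="quantized" configuration are carried over from the test and are assumptions about the general case, not an official API.

import torch
import bitblas
from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import QuantLinear as CudaOldQuantLinear


def repack_gptq_linears(model: torch.nn.Module, opt_M: int = 1) -> torch.nn.Module:
    """Swap every CudaOldQuantLinear in `model` for a repacked bitblas.Linear (4-bit, no bias)."""
    for parent in list(model.modules()):
        for child_name, child in list(parent.named_children()):
            if not isinstance(child, CudaOldQuantLinear):
                continue
            bitblas_linear = bitblas.Linear(
                opt_M=opt_M,
                in_features=child.infeatures,
                out_features=child.outfeatures,
                bias=False,            # as in the test; bias handling is out of scope here
                A_dtype="float16",
                W_dtype="uint4",       # assumes child.bits == 4, as in the test
                accum_dtype="float16",
                out_dtype="float16",
                group_size=child.group_size,
                with_scaling=True,
                with_zeros=True,
                zeros_mode="quantized",
            )
            bitblas_linear.repack_from_gptq(child)
            setattr(parent, child_name, bitblas_linear)
    return model

The caller is still responsible for moving the resulting model to CUDA before inference, as the test does with .to("cuda").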
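
The new assert_correctness_with_ladder_ldmatrix_propagate in testing/python/operators/test_general_matmul_tile_schedule.py builds the kernel, pre-permutes B through LadderPermutate (transform_kind=3) and runs it, but then only prints the kernel output and the numpy reference without comparing them. For the float16 configuration, a closing check along the following lines would make the test self-verifying; the tolerances are an assumption, chosen loosely because the kernel accumulates in fp16 while numpy uses float64.

import numpy as np


def assert_matmul_close(tvm_c, np_c, rtol=1e-1, atol=1e-1):
    # tvm_c: tvm.nd.NDArray holding the kernel output; np_c: numpy reference from np.dot.
    np.testing.assert_allclose(
        tvm_c.asnumpy().astype("float32"),
        np_c.astype("float32"),
        rtol=rtol,
        atol=atol,
    )

Calling assert_matmul_close(tvm_c, np_c) right after the two prints would cover the float16 variant; the int8 variant needs extra care (both the numpy reference and the int8 kernel output can wrap), so it is not covered by this sketch.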
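
Finally, a condensed usage sketch of the new bitblas/utils/weight_propagate.py helper, mirroring the round trip asserted in testing/python/weight_transform/test_ladder_transform_stage3.py above. The tile count (4x4 tiles of 16x16) is an arbitrary illustrative choice, and the tensor stays on CPU because the helper builds its permutation kernel for the llvm target.

import torch
import bitblas
from bitblas.gpu.matmul_analysis import get_ladder_stage3_map

stage3_map, stage3_map_inverse = get_ladder_stage3_map(dtype="float16")

# Weights laid out as (M // 16, N // 16, 16, 16) tiles, as in the stage-3 tests.
b = torch.randn(4, 4, 16, 16, dtype=torch.float16)

b_permuted = bitblas.apply_transform_on_input(b, stage3_map_inverse)   # apply the stage-3 permutation
b_restored = bitblas.apply_transform_on_input(b_permuted, stage3_map)  # undo it

torch.testing.assert_close(b_restored, b)  # the pure permutation round-trips exactly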