[FRONTEND] Complete rewrite of the runtime #644


Merged
merged 48 commits into master from phil/new-runtime on Sep 18, 2022
Commits
089e839
some work
ptillet Sep 11, 2022
5462d32
Merge branch 'master' into phil/new-runtime
ptillet Sep 11, 2022
9282bb3
seems to work
ptillet Sep 11, 2022
88c87aa
more work
ptillet Sep 11, 2022
0683209
more work
ptillet Sep 11, 2022
d67e29c
more work
ptillet Sep 11, 2022
1d9ee1e
more progress
ptillet Sep 11, 2022
03be2ec
test-core passes
ptillet Sep 11, 2022
ba6151a
.
ptillet Sep 11, 2022
60c2f64
.
ptillet Sep 11, 2022
f0d7a2d
.
ptillet Sep 12, 2022
fb5ef0a
.
ptillet Sep 12, 2022
f8020e9
.
ptillet Sep 12, 2022
f9fd4bc
.
ptillet Sep 12, 2022
7ea2217
.
ptillet Sep 12, 2022
dec0add
.
ptillet Sep 12, 2022
c96f102
.
ptillet Sep 12, 2022
b74c06a
.
ptillet Sep 12, 2022
ac763a9
.
ptillet Sep 12, 2022
5f90262
.
ptillet Sep 12, 2022
a8a7f2a
.
ptillet Sep 12, 2022
87ede81
.
ptillet Sep 12, 2022
e2d4e3f
.
ptillet Sep 12, 2022
1c3e741
debug
ptillet Sep 12, 2022
ac18830
.
ptillet Sep 13, 2022
d4e1c96
.
ptillet Sep 13, 2022
8c91efd
.
ptillet Sep 13, 2022
c3ed8db
.
ptillet Sep 13, 2022
542092a
.
ptillet Sep 13, 2022
ec7e7d9
.
ptillet Sep 13, 2022
0e88ded
handle do_not_specialize
ptillet Sep 13, 2022
2dd3fd7
.
ptillet Sep 13, 2022
7f70933
.
ptillet Sep 13, 2022
527bb36
fixup
ptillet Sep 13, 2022
8165af8
fix memory leak
ptillet Sep 13, 2022
5a6bdbb
Merge branch 'master' into phil/new-runtime
ptillet Sep 14, 2022
7c90059
Now using current_device and set_device again
ptillet Sep 14, 2022
0d9ed9a
Fixed initialization issue
ptillet Sep 14, 2022
f2075d7
Fix more bugs
ptillet Sep 17, 2022
adc65de
Merge branch 'master' into HEAD
ptillet Sep 17, 2022
43dddc3
.
ptillet Sep 17, 2022
25bbd59
style
ptillet Sep 17, 2022
df4d57b
fixup hook
ptillet Sep 18, 2022
b157f03
some cleaning
ptillet Sep 18, 2022
5ed921c
.
ptillet Sep 18, 2022
ee772da
style
ptillet Sep 18, 2022
3f3cc8c
fixup
ptillet Sep 18, 2022
cfa8d18
Merge branch 'master' into phil/new-runtime
ptillet Sep 18, 2022
1 change: 1 addition & 0 deletions include/triton/codegen/extern_lib.h
@@ -3,6 +3,7 @@

#include <memory>
#include <string>
#include <map>

#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
1 change: 0 additions & 1 deletion include/triton/ir/module.h
@@ -87,7 +87,6 @@ class module {

// Functions
const functions_list_t &get_function_list() const { return functions_; }
functions_list_t &get_function_list() { return functions_; }
function *get_function(const std::string& name) {
if(symbols_.find(name) == symbols_.end())
throw std::runtime_error("function " + name + " is not declared");
2 changes: 1 addition & 1 deletion lib/codegen/pass.cc
@@ -106,11 +106,11 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
// run passes
inliner.run(ir);
dce.run(ir);
// ir.print(std::cout);
peephole.run(ir);
dce.run(ir);
pipeline.run(ir);
dce.run(ir);
// ir.print(std::cout);
disassociate.run(ir);
dce.run(ir);
align.run(ir);
16 changes: 15 additions & 1 deletion python/src/triton.cc
@@ -574,6 +574,19 @@ void init_triton_codegen(py::module &&m) {
assert(backend == ROCM);
return hip_load_binary(name, asm_map, n_shared_bytes, dev);
}, py::return_value_policy::take_ownership);


struct InstanceDescriptor
{
std::unordered_set<int> divisibleBy16;
std::unordered_set<int> equalTo1;
};

py::class_<InstanceDescriptor>(m, "instance_descriptor")
.def(py::init<>())
.def(py::init<std::unordered_set<int>, std::unordered_set<int>>())
.def_readonly("divisible_by_16", &InstanceDescriptor::divisibleBy16)
.def_readonly("equal_to_1", &InstanceDescriptor::equalTo1);
}


@@ -758,10 +771,11 @@ void init_triton_ir(py::module &&m) {
.def("get", &ir::struct_type::get, ret::reference)
.def_property_readonly("num_types", &ir::struct_type::get_num_types);

py::class_<ir::module>(m, "module")
py::class_<ir::module>(m, "module", py::dynamic_attr())
.def(py::init<std::string, ir::builder &>())
.def("has_function", &ir::module::has_function)
.def("get_function", &ir::module::get_function, ret::reference)
.def("get_functions", &ir::module::get_function_list, ret::reference)
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
.def("print", [](ir::module *self) {
self->print(std::cout);
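
The instance_descriptor added above carries constexpr specialization hints (which arguments are divisible by 16, which are equal to 1) from Python into codegen, and the module binding gains dynamic attributes plus a get_functions() accessor. A minimal usage sketch, assuming the extension module is imported as in the tests below; the argument indices are purely illustrative:

import triton._C.libtriton.triton as _triton

# Hypothetical hints: arguments 0 and 2 are divisible by 16, argument 3 equals 1.
desc = _triton.instance_descriptor({0, 2}, {3})
print(desc.divisible_by_16)  # {0, 2}
print(desc.equal_to_1)       # {3}
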
18 changes: 8 additions & 10 deletions python/test/unit/language/test_core.py
@@ -11,7 +11,7 @@
import triton
import triton._C.libtriton.triton as _triton
import triton.language as tl
from triton.code_gen import JITFunction, TensorWrapper, reinterpret
from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret

int_dtypes = ['int8', 'int16', 'int32', 'int64']
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
@@ -273,7 +273,7 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
elif (op in ('%', '/') and
((dtype_x in int_dtypes and dtype_y in uint_dtypes) or
(dtype_x in uint_dtypes and dtype_y in int_dtypes))):
with pytest.raises(triton.code_gen.CompilationError) as exc_info:
with pytest.raises(triton.CompilationError) as exc_info:
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
assert re.match('Cannot use .* because they have different signedness', str(exc_info.value.__cause__))
else:
@@ -311,7 +311,7 @@ def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
else:
numpy_expr = None
if 'float' in dtype_x + dtype_y:
with pytest.raises(triton.code_gen.CompilationError) as exc_info:
with pytest.raises(triton.CompilationError) as exc_info:
_test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
# The CompilationError must have been caused by a C++ exception with this text.
assert re.match('invalid operands of type', str(exc_info.value.__cause__))
@@ -500,7 +500,7 @@ def generate_kernel(shape_x, shape_z):
def catch_compilation_error(kernel):
try:
kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
except triton.code_gen.CompilationError as e:
except triton.CompilationError as e:
np.testing.assert_(True)
except BaseException:
np.testing.assert_(False)
@@ -1209,7 +1209,7 @@ def _kernel(dst, src, CACHE: tl.constexpr):
assert 'ld.global.cg' not in ptx


@pytest.mark.parametrize("N", [8, 10, 11, 1024])
@pytest.mark.parametrize("N", [16, 10, 11, 1024])
def test_vectorization(N):
src = torch.empty(1024, device='cuda')
dst = torch.empty(1024, device='cuda')
@@ -1221,10 +1221,8 @@ def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
tl.store(dst + offsets, x, mask=offsets < N)
pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
ptx = pgm.asm["ptx"]
if N % 4 == 0:
if N % 16 == 0:
assert "ld.global.v4.b32" in ptx
elif N % 2 == 0:
assert "ld.global.v2.b32" in ptx
else:
assert "ld.global.b32" in ptx
# triton.testing.assert_almost_equal(dst, src[:N])
@@ -1292,7 +1290,7 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> None:

def cache_hook(*args, **kwargs):
nonlocal spec_type
spec_type = kwargs["compile"]["arg_types"][0][1]
spec_type = kwargs["compile"]["signature"][0]
JITFunction.cache_hook = cache_hook

@triton.jit
@@ -1319,7 +1317,7 @@ def kernel(VALUE, X):
x = torch.tensor([3.14159], device='cuda')

if overflow:
with pytest.raises(RuntimeError, match='integer overflow'):
with pytest.raises(OverflowError):
kernel[(1, )](value, x)
else:
kernel[(1, )](value, x)
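
One behavioral note on the hook change above: the cache-hook payload in test_value_specialization now reads the specialized type of the first argument from kwargs["compile"]["signature"] instead of kwargs["compile"]["arg_types"]. A hedged sketch of a hook written against that payload (any key layout beyond what the test touches is an assumption):

from triton.runtime.jit import JITFunction

def cache_hook(*args, **kwargs):
    # The signature entry holds type strings such as "i32" or "u64";
    # index 0 refers to the kernel's first argument.
    first_arg_type = kwargs["compile"]["signature"][0]
    print(first_arg_type)

JITFunction.cache_hook = cache_hook
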
7 changes: 3 additions & 4 deletions python/test/unit/operators/test_matmul.py
@@ -78,10 +78,9 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_()
configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)]
kernel = triton.ops._matmul.kernel
decorators = kernel.kernel_decorators
kernel.kernel_decorators = []
triton.autotune(configs, [])(kernel)
kernel.kernel_decorators += decorators[1:]
kernel.configs = configs
# kernel.run = kernel.run.run.run

# get matrix shape
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
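
With the decorator-chain manipulation gone, the matmul test assumes the autotuner now lives on the kernel object itself and that its search space can be replaced by assigning to configs, as done above. A sketch under that assumption (the block sizes are illustrative, not taken from the PR):

import triton

# Pin the autotuned matmul kernel to a single hypothetical configuration.
cfg = triton.Config(kwargs={'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1},
                    num_warps=4, num_stages=2)
triton.ops._matmul.kernel.configs = [cfg]
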
18 changes: 9 additions & 9 deletions python/test/unit/runtime/test_cache.py
@@ -7,7 +7,7 @@

import triton
import triton.language as tl
from triton.code_gen import JITFunction
from triton.runtime.jit import JITFunction

tmpdir = ".tmp"

@@ -99,16 +99,16 @@ def inc_counter(*args, **kwargs):
reset_tmp_dir()
x = torch.empty(1, dtype=torch.int32, device='cuda')
function = {'enable': kernel, 'disable': kernel_nospec}[mode]
target = {'enable': 5, 'disable': 1}[mode]
target = {'enable': 3, 'disable': 1}[mode]
for i in [1, 2, 4, 8, 16, 32]:
function[(1,)](x, i, BLOCK=512)
assert counter == target


@pytest.mark.parametrize("value, value_type", [
(-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'),
(2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'),
(2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64')
(-1, 'i32'), (0, 'i32'), (1, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
(2**32, 'i64'), (2**63 - 1, 'i64'), (-2**63, 'i64'),
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**63, 'u64'), (2**64 - 1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:

@@ -120,14 +120,14 @@ def kernel(VALUE, X):

def get_cache_str(*args, **kwargs):
nonlocal cache_str
cache_str = kwargs['key'].split('-')
triton.code_gen.JITFunction.cache_hook = get_cache_str
cache_str = kwargs["repr"]
triton.JITFunction.cache_hook = get_cache_str
reset_tmp_dir()
x = torch.tensor([3.14159], device='cuda')
kernel[(1, )](value, x)
triton.code_gen.JITFunction.cache_hook = None
triton.JITFunction.cache_hook = None

cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1])
cache_str_match = re.match(r".*VALUE: (\w+).*", cache_str)
spec_type = None if cache_str_match is None else cache_str_match.group(1)
assert spec_type == value_type

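
The updated cache test keys off the human-readable repr handed to the cache hook rather than parsing the cache-key string. A minimal sketch of such a hook, assuming only the kwargs the test itself reads:

import re
from triton.runtime.jit import JITFunction

def record_spec(*args, **kwargs):
    # For a kernel with a scalar parameter named VALUE, the repr contains a
    # fragment like "VALUE: i32" describing its specialized type.
    m = re.match(r".*VALUE: (\w+).*", kwargs["repr"])
    if m:
        print(m.group(1))

JITFunction.cache_hook = record_spec
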
7 changes: 4 additions & 3 deletions python/triton/__init__.py
@@ -6,9 +6,10 @@
# or pybind11 shows `munmap_chunk(): invalid pointer`
import torch
# submodules
from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, \
JITFunction, Config, Autotuner, reinterpret
from .utils import *
from .runtime import Config, autotune, heuristics, JITFunction, KernelInterface
from .runtime.jit import jit
from .compiler import compile, CompilationError
from . import language
from . import code_gen
from . import testing
from . import ops
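
For reference, a short sketch of the import surface after this change; the canonical locations below are taken directly from the new __init__.py and the updated tests:

import triton
from triton.runtime.jit import JITFunction, jit        # new home of the JIT machinery
from triton.compiler import compile, CompilationError  # compiler entry points

# The same names remain re-exported at the top level, e.g. triton.jit,
# triton.autotune, triton.heuristics, triton.Config, triton.JITFunction.
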