-
Notifications
You must be signed in to change notification settings - Fork 175
/
setup.py
147 lines (125 loc) · 4.69 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
"""MegaBlocks package setup."""
import os
import warnings
from typing import Any, Dict, Mapping
from setuptools import find_packages, setup
# We require torch in setup.py to build cpp extensions "ahead of time"
# More info here: https://pytorch.org/tutorials/advanced/cpp_extension.html
try:
import torch
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
except ModuleNotFoundError as e:
raise ModuleNotFoundError("No module named 'torch'. `torch` is required to install `MegaBlocks`.",) from e
# Name handed to `setup(name=...)` and the directory, relative to the
# repository root, that is joined onto the root to locate the package.
_PACKAGE_NAME = 'megablocks'
_PACKAGE_DIR = 'megablocks'

# Symlink-resolved directory containing this file, and the package directory
# inside it.
_REPO_REAL_PATH = os.path.split(os.path.realpath(__file__))[0]
_PACKAGE_REAL_PATH = os.path.join(_REPO_REAL_PATH, _PACKAGE_DIR)
# Read the package version
# The library is not installed yet, so its `.__version__` cannot be imported.
# Instead, execute `_version.py` in throwaway namespaces and read the value
# back out of them.
version_path = os.path.join(_PACKAGE_REAL_PATH, '_version.py')
version_globals: Dict[str, Any] = {}
version_locals: Mapping[str, object] = {}
with open(version_path, encoding='utf-8') as f:
    content = f.read()
# Names assigned at the top level of the executed file land in `version_locals`.
exec(content, version_globals, version_locals)
repo_version = version_locals['__version__']
# Load the README to use as the package's long description.
# The path is resolved relative to this file (like the version file above it
# in this script) so the result does not depend on the current working
# directory.
with open(os.path.join(_REPO_REAL_PATH, 'README.md'), 'r', encoding='utf-8') as fh:
    long_description = fh.read()

# Hide the content between <!-- SETUPTOOLS_LONG_DESCRIPTION_HIDE_BEGIN --> and
# <!-- SETUPTOOLS_LONG_DESCRIPTION_HIDE_END --> tags in the README
start_tag = '<!-- SETUPTOOLS_LONG_DESCRIPTION_HIDE_BEGIN -->'
end_tag = '<!-- SETUPTOOLS_LONG_DESCRIPTION_HIDE_END -->'
while True:
    start = long_description.find(start_tag)
    if start == -1:
        # No opening tag left: a leftover closing tag means the markers are
        # unbalanced. Raise explicitly (not `assert`) so the check still runs
        # under `python -O`.
        if end_tag in long_description:
            raise ValueError('there should be a balanced number of start and ends')
        break
    # Look for the closing tag only *after* the opening tag. Searching from
    # the beginning would pair this opening tag with an earlier stray closing
    # tag, duplicating text and never terminating.
    end = long_description.find(end_tag, start + len(start_tag))
    if end == -1:
        raise ValueError('there should be a balanced number of start and ends')
    # Drop everything from the opening tag through the end of the closing tag.
    long_description = long_description[:start] + long_description[end + len(end_tag):]
# Trove classifiers passed to `setup(classifiers=...)`.
classifiers = [
    'Programming Language :: Python :: 3',
    'Programming Language :: Python :: 3.9',
    'Programming Language :: Python :: 3.10',
    'Programming Language :: Python :: 3.11',
    # Matches the `SPDX-License-Identifier: Apache-2.0` header of this file;
    # the previous value advertised a BSD license.
    'License :: OSI Approved :: Apache Software License',
    'Operating System :: Unix',
]
# Runtime dependencies, passed to `setup(install_requires=...)`.
install_requires = [
    'numpy>=1.21.5,<2.1.0',
    'packaging>=21.3.0,<24.2',
    'torch>=2.4.0,<2.4.1',
    'triton>=2.1.0',
    'stanford-stk==0.7.1',
]

# Optional dependency groups, passed to `setup(extras_require=...)`.
extra_deps = {}

extra_deps['gg'] = [
    'grouped_gemm==0.1.6',
]

extra_deps['dev'] = [
    'absl-py',  # TODO: delete when finish removing all absl tests
    'coverage[toml]==7.4.4',
    'pytest_codeblocks>=0.16.1,<0.17',
    'pytest-cov>=4,<5',
    'pytest>=7.2.1,<8',
    'pre-commit>=3.4.0,<4',
]

extra_deps['testing'] = [
    'mosaicml>=0.24.1',
]

# 'all' is the de-duplicated union of every group except 'testing'.
# It is sorted so the generated metadata is identical from one build to the
# next; a bare `list(set)` of strings varies in order between interpreter runs.
extra_deps['all'] = sorted({dep for key, deps in extra_deps.items() if key != 'testing' for dep in deps})
# Defaults for a build without CUDA extensions; both are handed to `setup()`.
cmdclass = {}
ext_modules = []

# Only install CUDA extensions if available
if 'cu' in torch.__version__ and CUDA_HOME is not None:
    cmdclass = {'build_ext': BuildExtension}
    nvcc_flags = ['--ptxas-options=-v', '--optimize=2']

    # An empty string means "add no explicit architecture flag".
    device_capability = ''
    if not os.environ.get('TORCH_CUDA_ARCH_LIST'):
        # With TORCH_CUDA_ARCH_LIST set, the PyTorch builder chooses the
        # target devices. Otherwise, target the local GPU.
        if torch.cuda.is_available():
            device_capability_tuple = torch.cuda.get_device_capability()
            device_capability = f'{device_capability_tuple[0]}{device_capability_tuple[1]}'
        else:
            # Querying the device capability without a visible GPU raises,
            # which would abort every setup.py command, including
            # metadata-only ones. Warn and continue instead.
            warnings.warn(
                'TORCH_CUDA_ARCH_LIST is not set and no CUDA device is visible; ' +
                'no explicit GPU architecture will be passed to nvcc. ' +
                'Set TORCH_CUDA_ARCH_LIST to choose the target architectures.',
            )

    if device_capability:
        nvcc_flags.append(f'--generate-code=arch=compute_{device_capability},code=sm_{device_capability}')

    ext_modules = [
        CUDAExtension(
            'megablocks_ops',
            ['csrc/ops.cu'],
            include_dirs=['csrc'],
            extra_compile_args={
                'cxx': ['-fopenmp'],
                'nvcc': nvcc_flags,
            },
        ),
    ]
elif CUDA_HOME is None:
    warnings.warn(
        'Attempted to install CUDA extensions, but CUDA_HOME was None. ' +
        'Please install CUDA and ensure that the CUDA_HOME environment ' +
        'variable points to the installation location.',
    )
else:
    warnings.warn('Warning: No CUDA devices; cuda code will not be compiled.')
# Hand the metadata, dependency lists and (possibly empty) extension
# configuration assembled above to setuptools.
setup(
    name=_PACKAGE_NAME,
    version=repo_version,
    author='Trevor Gale',
    # NOTE(review): this value looks like a placeholder left by e-mail
    # obfuscation rather than a real address -- restore the author's actual
    # address before publishing.
    author_email='[email protected]',
    description='MegaBlocks',
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/databricks/megablocks',
    classifiers=classifiers,
    # Keep tests, vendored code, configs and experiments out of the distribution.
    packages=find_packages(exclude=['tests*', 'third_party*', 'yamls*', 'exp*', '.github*']),
    ext_modules=ext_modules,
    cmdclass=cmdclass,
    install_requires=install_requires,
    extras_require=extra_deps,
    python_requires='>=3.9',
    # Ship the `py.typed` file alongside the package.
    package_data={_PACKAGE_NAME: ['py.typed']},
)