-
Notifications
You must be signed in to change notification settings - Fork 45
/
Copy pathstubs.py
47 lines (34 loc) · 1.46 KB
/
stubs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import sys
from importlib import import_module
from importlib.util import find_spec
from pathlib import Path
from types import FunctionType, ModuleType
from typing import Dict, List
# Names this module exposes to ``from stubs import *``.
__all__ = ["category_to_funcs", "array", "extension_to_funcs"]

# The API specification is looked up relative to this file, inside the
# ``array-api`` directory (presumably a git submodule, going by the hint in
# the assertion message below).
spec_dir = Path(__file__).parent.joinpath("array-api", "spec", "API_specification")
assert spec_dir.exists(), f"{spec_dir} not found - try `git pull --recurse-submodules`"
sigs_dir = spec_dir.joinpath("signatures")
assert sigs_dir.exists()

# Put the spec directory on sys.path so that ``signatures`` resolves as a
# top-level package, and check that the import system can actually find it.
spec_abs_path: str = str(spec_dir.resolve())
sys.path.append(spec_abs_path)
assert find_spec("signatures") is not None
# Import every ``*.py`` file found in the signatures directory as
# ``signatures.<name>`` and index the resulting modules by ``<name>``.
name_to_mod: Dict[str, ModuleType] = {}
# Sorted so that the dict (and everything derived from it) has the same
# order on every run; directory listing order is filesystem-dependent.
for path in sorted(sigs_dir.glob("*.py")):
    # ``Path.stem`` strips only the final ".py" suffix, whereas
    # ``str.replace(".py", "")`` would also delete ".py" inside the name.
    name = path.stem
    name_to_mod[name] = import_module(f"signatures.{name}")
# Group the stubs by category: a module named ``<category>_functions``
# contributes every object listed in its ``__all__`` under ``<category>``.
_FUNCS_SUFFIX = "_functions"
category_to_funcs: Dict[str, List[FunctionType]] = {}
for name, mod in name_to_mod.items():
    if name.endswith(_FUNCS_SUFFIX):
        # Slice off only the trailing suffix; ``str.replace`` would also
        # remove "_functions" if it appeared elsewhere in the module name.
        category = name[: -len(_FUNCS_SUFFIX)]
        # A dedicated comprehension variable keeps the module ``name`` of the
        # enclosing loop from being visually shadowed.
        objects = [getattr(mod, func_name) for func_name in mod.__all__]
        # Sanity check: a ``*_functions`` module must export plain functions only.
        assert all(isinstance(o, FunctionType) for o in objects)
        category_to_funcs[category] = objects
# Re-export the ``array`` attribute of the imported ``array_object`` stub
# module (presumably the array class stub -- confirm against the spec).
array = name_to_mod["array_object"].array
# Extension names; each one must match the name of an imported stub module.
EXTENSIONS = ["linalg"]

# Map every extension name to the objects listed in its module's ``__all__``.
extension_to_funcs: Dict[str, List[FunctionType]] = {}
for ext in EXTENSIONS:
    mod = name_to_mod[ext]
    objects = [getattr(mod, attr) for attr in mod.__all__]
    # An extension module is expected to export nothing but plain functions.
    assert not any(not isinstance(obj, FunctionType) for obj in objects)
    extension_to_funcs[ext] = objects
# The signature modules have been imported by this point, so take the spec
# directory (appended to sys.path earlier in this file) back off the path.
sys.path.remove(spec_abs_path)