Skip to content

Commit a67cfc2

Browse files
committed
initial commit
0 parents  commit a67cfc2

File tree

5 files changed

+336
-0
lines changed

5 files changed

+336
-0
lines changed

.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
*.pyc
2+
scratch
3+
*.egg-info
4+
dist

adt/__init__.py

Whitespace-only changes.

adt/adt.py

+236
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
"""
2+
Simple ADTs and tagged-union matching in Python,
3+
Plus immutable records (product types)
4+
5+
Tip of the hat to [union-type](https://github.com/paldelpind/union-type), a
6+
javascript library with similar aims and syntax.
7+
8+
Usage:
9+
10+
Point = Type("Point", [int, int])
11+
Rectangle = Type("Rectangle", [Point, Point])
12+
Circle = Type("Circle", [int, Point])
13+
Triangle = Type("Triangle", [Point, Point, int])
14+
15+
Shape = [ Rectangle, Circle, Triangle ]
16+
17+
area = match(Shape, {
18+
Rectangle: (lambda (t,l), (b,r): (b - t) * (r - l)),
19+
Circle: (lambda r, (x,y): math.pi * (r**2)),
20+
Triangle: (lambda (x1,y1), (x2,y2), h: (((x2 - x1) + (y2 - y1)) * h)/2)
21+
})
22+
23+
rect = Rectangle( Point(0,0), Point(100,100) )
24+
area(rect) # => 10000
25+
26+
circ = Circle( 5, Point(0,0) )
27+
area(circ) # => 78.539816...
28+
29+
tri = Triangle( Point(0,0), Point(100,100), 5 )
30+
area(tri) # => 500
31+
32+
33+
# Composing with records works transparently:
34+
35+
Point = Record("Point", {'x': int, 'y': int})
36+
Rectangle = Type("Rectangle", [Point, Point])
37+
38+
p1 = Point(x=1,y=2)
39+
p2 = Point(x=4,y=6)
40+
rect = Rectangle( p1, p2 )
41+
42+
43+
"""
44+
from f import curry_n
45+
46+
def construct_type_instance(tag, specs, args):
47+
return construct_type(tag, specs)(*args)
48+
49+
def construct_type(tag, specs):
50+
return Type(tag,specs)
51+
52+
def construct_record_instance(tag, specs, attrs):
53+
return construct_record(tag, specs)(**attrs)
54+
55+
def construct_record(tag, specs):
56+
return Record(tag,specs)
57+
58+
def Type(tag, specs):
59+
class _tagged_tuple(tuple):
60+
def __eq__(self,other):
61+
return (
62+
self.__class__.__name__ == other.__class__.__name__ and
63+
super(_tagged_tuple,self).__eq__(other)
64+
)
65+
66+
# Note: only eval()-able if constructors are in scope with same name as tags
67+
def __repr__(self):
68+
return (
69+
self.__class__.__name__ +
70+
"( " + ", ".join(repr(p) for p in self) + " )"
71+
)
72+
73+
# For pickling
74+
def __reduce__(self):
75+
nospecs = [ anything for s in specs ]
76+
return ( construct_type_instance, (tag, nospecs, tuple(v for v in self)) )
77+
78+
_tagged_tuple.__name__ = tag
79+
80+
@curry_n(len(specs))
81+
def _bind(*vals):
82+
nvals = len(vals)
83+
nspecs = len(specs)
84+
if nvals > nspecs:
85+
raise TypeError( "%s: Expected %d values, given %d" % (tag, nspecs, nvals))
86+
87+
for (i,(s,v)) in enumerate(zip(specs,vals)):
88+
ok, err = validate(s,v)
89+
if not ok:
90+
msg = "%s: Invalid type in field %d: %s" % (tag,i,repr(v))
91+
if not (err is None):
92+
msg = "%s\n %s" % (msg, err)
93+
raise TypeError(msg)
94+
95+
return _tagged_tuple(vals)
96+
97+
_bind.__name__ = "construct_%s" % tag
98+
_bind.__adt_class__ = _tagged_tuple
99+
return _bind
100+
101+
102+
def Record(tag,specs):
103+
104+
class _record(object):
105+
__slots__ = specs.keys()
106+
107+
def __eq__(self,other):
108+
return (
109+
self.__class__.__name__ == other.__class__.__name__ and
110+
all([
111+
getattr(self,k) == getattr(other,k)
112+
for k in self.__class__.__slots__
113+
])
114+
)
115+
116+
def __repr__(self):
117+
return (
118+
self.__class__.__name__ +
119+
"( " +
120+
", ".join([
121+
"%s=%s" % (k, repr(getattr(self,k)))
122+
for k in self.__class__.__slots__
123+
]) +
124+
" )"
125+
)
126+
127+
# For pickling
128+
def __reduce__(self):
129+
nospecs = dict([(k,anything) for k in specs.keys()])
130+
attrs = dict([(k,getattr(self,k)) for k in self.__class__.__slots__])
131+
return ( construct_record_instance, (tag, nospecs, attrs) )
132+
133+
def __init__(self,**vals):
134+
for (k,v) in vals.items():
135+
setattr(self.__class__,k,v)
136+
137+
_record.__name__ = tag
138+
139+
def _bind(**vals):
140+
extras = [ ("'%s'" % k) for k in vals.keys() if not specs.has_key(k) ]
141+
if len(extras) > 0:
142+
raise TypeError("%s: Unexpected values given: %s" % (tag, ", ".join(extras)))
143+
144+
for (name,s) in specs.items():
145+
if not vals.has_key(name):
146+
raise TypeError("%s: Expected value for '%s', none given" % (tag, name))
147+
ok, err = validate(s,vals[name])
148+
if not ok:
149+
msg = "%s: Invalid type in field '%s': %s" % (tag,name,repr(vals[name]))
150+
if not (err is None):
151+
msg = "%s\n %s" % (msg, err)
152+
raise TypeError(msg)
153+
154+
return _record(**vals)
155+
156+
_bind.__name__ = "construct_%s" % tag
157+
_bind.__adt_class__ = _record
158+
return _bind
159+
160+
161+
def anything(x):
162+
return True
163+
164+
def typeof(adt):
165+
if not hasattr(adt, '__adt_class__'):
166+
raise TypeError("Not an ADT constructor")
167+
return adt.__adt_class__
168+
169+
@curry_n(2)
170+
def seq_of(t,xs):
171+
return (
172+
hasattr(xs,'__iter__') and all( validate(t,x)[0] for x in xs )
173+
)
174+
175+
@curry_n(2)
176+
def tuple_of(ts,xs):
177+
return (
178+
all( validate(t,x)[0] for (t,x) in zip(ts,xs) )
179+
)
180+
181+
@curry_n(2)
182+
def one_of(ts,x):
183+
return any( validate(t,x)[0] for t in ts )
184+
185+
def validate(s,v):
186+
try:
187+
return ( isinstance(v,s), None )
188+
except TypeError:
189+
try:
190+
return (
191+
( ( type(v) == s ) or
192+
( hasattr(s,"__adt_class__") and isinstance(v,typeof(s)) ) or
193+
( callable(s) and s(v) == True )
194+
),
195+
None
196+
)
197+
except Exception as e:
198+
return (False, e)
199+
200+
201+
@curry_n(3)
202+
def match(adts, cases, target):
203+
204+
assert target.__class__ in [ typeof(adt) for adt in adts ], \
205+
"%s is not in union" % target.__class__.__name__
206+
207+
missing = [
208+
t.__adt_class__.__name__ for t in adts \
209+
if not (cases.has_key(type(None)) or cases.has_key(t))
210+
]
211+
assert len(missing) == 0, \
212+
"No case found for the following type(s): %s" % ", ".join(missing)
213+
214+
fn = None
215+
wildcard = False
216+
try:
217+
fn = (
218+
next(
219+
cases[constr] for constr in cases \
220+
if not constr == type(None) and isinstance(target,typeof(constr))
221+
)
222+
)
223+
224+
except StopIteration:
225+
fn = cases.get(type(None),None)
226+
wildcard = not fn is None
227+
228+
# note should never happen due to type assertions above
229+
if fn is None:
230+
raise TypeError("No cases match %s" % target.__class__.__name__)
231+
232+
assert callable(fn), \
233+
"Matched case is not callable; check your cases"
234+
235+
return fn() if wildcard else fn( *(slot for slot in target) )
236+

adt/f.py

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from functools import partial, wraps, update_wrapper
2+
from inspect import getargspec
3+
4+
"""
5+
Derived from [fn.py](https://github.com/kachayev/fn.py) function 'curried'
6+
Amended to fix wrapping error: cf. https://github.com/kachayev/fn.py/pull/75
7+
8+
Copyright 2013 Alexey Kachayev
9+
Under the Apache License, Version 2.0
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
"""
13+
def curry(func):
14+
"""A decorator that makes the function curried
15+
16+
Usage example:
17+
18+
>>> @curry
19+
... def sum5(a, b, c, d, e):
20+
... return a + b + c + d + e
21+
...
22+
>>> sum5(1)(2)(3)(4)(5)
23+
15
24+
>>> sum5(1, 2, 3)(4, 5)
25+
15
26+
"""
27+
@wraps(func)
28+
def _curry(*args, **kwargs):
29+
f = func
30+
count = 0
31+
while isinstance(f, partial):
32+
if f.args:
33+
count += len(f.args)
34+
f = f.func
35+
36+
spec = getargspec(f)
37+
38+
if count == len(spec.args) - len(args):
39+
return func(*args, **kwargs)
40+
41+
para_func = partial(func, *args, **kwargs)
42+
update_wrapper(para_func, f)
43+
return curry(para_func)
44+
45+
return _curry
46+
47+
48+
def curry_n(n):
49+
def _curry_n(func):
50+
@wraps(func)
51+
def _curry(*args, **kwargs):
52+
f = func
53+
54+
count = 0
55+
while isinstance(f, partial) and count < n:
56+
if f.args:
57+
count += len(f.args)
58+
f = f.func
59+
60+
if count >= n - len(args):
61+
return func(*args, **kwargs)
62+
63+
para_func = partial(func, *args, **kwargs)
64+
update_wrapper(para_func, f)
65+
return _curry_n(para_func)
66+
67+
return _curry
68+
69+
return _curry_n
70+
71+

setup.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from setuptools import setup
2+
3+
setup(
4+
name = "adt.py",
5+
version = "0.0.1",
6+
author = "Eric Gjertsen",
7+
author_email = "[email protected]",
8+
description = (
9+
"Tagged-union types with simple pattern matching, and immutable records (product types)"
10+
),
11+
license = "MIT",
12+
keywords = "adt types immutable functional",
13+
url = "https://github.com/ericgj/adt.py",
14+
packages = ["adt"],
15+
classifiers = [
16+
"Development Status :: 3 - Alpha",
17+
"Intended Audience :: Developers",
18+
"License :: OSI Approved :: MIT License",
19+
"Programming Language :: Python :: 2",
20+
"Programming Language :: Python :: 3",
21+
"Topic :: Utilities"
22+
]
23+
)
24+
25+

0 commit comments

Comments
 (0)