|
| 1 | +""" |
| 2 | +Simple ADTs and tagged-union matching in Python, |
| 3 | +Plus immutable records (product types) |
| 4 | +
|
| 5 | +Tip of the hat to [union-type](https://github.com/paldelpind/union-type), a |
| 6 | +javascript library with similar aims and syntax. |
| 7 | +
|
| 8 | +Usage: |
| 9 | +
|
| 10 | + Point = Type("Point", [int, int]) |
| 11 | + Rectangle = Type("Rectangle", [Point, Point]) |
| 12 | + Circle = Type("Circle", [int, Point]) |
| 13 | + Triangle = Type("Triangle", [Point, Point, int]) |
| 14 | + |
| 15 | + Shape = [ Rectangle, Circle, Triangle ] |
| 16 | + |
| 17 | + area = match(Shape, { |
| 18 | + Rectangle: (lambda (t,l), (b,r): (b - t) * (r - l)), |
| 19 | + Circle: (lambda r, (x,y): math.pi * (r**2)), |
| 20 | + Triangle: (lambda (x1,y1), (x2,y2), h: (((x2 - x1) + (y2 - y1)) * h)/2) |
| 21 | + }) |
| 22 | + |
| 23 | + rect = Rectangle( Point(0,0), Point(100,100) ) |
| 24 | + area(rect) # => 10000 |
| 25 | + |
| 26 | + circ = Circle( 5, Point(0,0) ) |
| 27 | + area(circ) # => 78.539816... |
| 28 | + |
| 29 | + tri = Triangle( Point(0,0), Point(100,100), 5 ) |
| 30 | + area(tri) # => 500 |
| 31 | +
|
| 32 | +
|
| 33 | + # Composing with records works transparently: |
| 34 | +
|
| 35 | + Point = Record("Point", {'x': int, 'y': int}) |
| 36 | + Rectangle = Type("Rectangle", [Point, Point]) |
| 37 | + |
| 38 | + p1 = Point(x=1,y=2) |
| 39 | + p2 = Point(x=4,y=6) |
| 40 | + rect = Rectangle( p1, p2 ) |
| 41 | +
|
| 42 | +
|
| 43 | +""" |
| 44 | +from f import curry_n |
| 45 | + |
| 46 | +def construct_type_instance(tag, specs, args): |
| 47 | + return construct_type(tag, specs)(*args) |
| 48 | + |
| 49 | +def construct_type(tag, specs): |
| 50 | + return Type(tag,specs) |
| 51 | + |
| 52 | +def construct_record_instance(tag, specs, attrs): |
| 53 | + return construct_record(tag, specs)(**attrs) |
| 54 | + |
| 55 | +def construct_record(tag, specs): |
| 56 | + return Record(tag,specs) |
| 57 | + |
| 58 | +def Type(tag, specs): |
| 59 | + class _tagged_tuple(tuple): |
| 60 | + def __eq__(self,other): |
| 61 | + return ( |
| 62 | + self.__class__.__name__ == other.__class__.__name__ and |
| 63 | + super(_tagged_tuple,self).__eq__(other) |
| 64 | + ) |
| 65 | + |
| 66 | + # Note: only eval()-able if constructors are in scope with same name as tags |
| 67 | + def __repr__(self): |
| 68 | + return ( |
| 69 | + self.__class__.__name__ + |
| 70 | + "( " + ", ".join(repr(p) for p in self) + " )" |
| 71 | + ) |
| 72 | + |
| 73 | + # For pickling |
| 74 | + def __reduce__(self): |
| 75 | + nospecs = [ anything for s in specs ] |
| 76 | + return ( construct_type_instance, (tag, nospecs, tuple(v for v in self)) ) |
| 77 | + |
| 78 | + _tagged_tuple.__name__ = tag |
| 79 | + |
| 80 | + @curry_n(len(specs)) |
| 81 | + def _bind(*vals): |
| 82 | + nvals = len(vals) |
| 83 | + nspecs = len(specs) |
| 84 | + if nvals > nspecs: |
| 85 | + raise TypeError( "%s: Expected %d values, given %d" % (tag, nspecs, nvals)) |
| 86 | + |
| 87 | + for (i,(s,v)) in enumerate(zip(specs,vals)): |
| 88 | + ok, err = validate(s,v) |
| 89 | + if not ok: |
| 90 | + msg = "%s: Invalid type in field %d: %s" % (tag,i,repr(v)) |
| 91 | + if not (err is None): |
| 92 | + msg = "%s\n %s" % (msg, err) |
| 93 | + raise TypeError(msg) |
| 94 | + |
| 95 | + return _tagged_tuple(vals) |
| 96 | + |
| 97 | + _bind.__name__ = "construct_%s" % tag |
| 98 | + _bind.__adt_class__ = _tagged_tuple |
| 99 | + return _bind |
| 100 | + |
| 101 | + |
| 102 | +def Record(tag,specs): |
| 103 | + |
| 104 | + class _record(object): |
| 105 | + __slots__ = specs.keys() |
| 106 | + |
| 107 | + def __eq__(self,other): |
| 108 | + return ( |
| 109 | + self.__class__.__name__ == other.__class__.__name__ and |
| 110 | + all([ |
| 111 | + getattr(self,k) == getattr(other,k) |
| 112 | + for k in self.__class__.__slots__ |
| 113 | + ]) |
| 114 | + ) |
| 115 | + |
| 116 | + def __repr__(self): |
| 117 | + return ( |
| 118 | + self.__class__.__name__ + |
| 119 | + "( " + |
| 120 | + ", ".join([ |
| 121 | + "%s=%s" % (k, repr(getattr(self,k))) |
| 122 | + for k in self.__class__.__slots__ |
| 123 | + ]) + |
| 124 | + " )" |
| 125 | + ) |
| 126 | + |
| 127 | + # For pickling |
| 128 | + def __reduce__(self): |
| 129 | + nospecs = dict([(k,anything) for k in specs.keys()]) |
| 130 | + attrs = dict([(k,getattr(self,k)) for k in self.__class__.__slots__]) |
| 131 | + return ( construct_record_instance, (tag, nospecs, attrs) ) |
| 132 | + |
| 133 | + def __init__(self,**vals): |
| 134 | + for (k,v) in vals.items(): |
| 135 | + setattr(self.__class__,k,v) |
| 136 | + |
| 137 | + _record.__name__ = tag |
| 138 | + |
| 139 | + def _bind(**vals): |
| 140 | + extras = [ ("'%s'" % k) for k in vals.keys() if not specs.has_key(k) ] |
| 141 | + if len(extras) > 0: |
| 142 | + raise TypeError("%s: Unexpected values given: %s" % (tag, ", ".join(extras))) |
| 143 | + |
| 144 | + for (name,s) in specs.items(): |
| 145 | + if not vals.has_key(name): |
| 146 | + raise TypeError("%s: Expected value for '%s', none given" % (tag, name)) |
| 147 | + ok, err = validate(s,vals[name]) |
| 148 | + if not ok: |
| 149 | + msg = "%s: Invalid type in field '%s': %s" % (tag,name,repr(vals[name])) |
| 150 | + if not (err is None): |
| 151 | + msg = "%s\n %s" % (msg, err) |
| 152 | + raise TypeError(msg) |
| 153 | + |
| 154 | + return _record(**vals) |
| 155 | + |
| 156 | + _bind.__name__ = "construct_%s" % tag |
| 157 | + _bind.__adt_class__ = _record |
| 158 | + return _bind |
| 159 | + |
| 160 | + |
| 161 | +def anything(x): |
| 162 | + return True |
| 163 | + |
| 164 | +def typeof(adt): |
| 165 | + if not hasattr(adt, '__adt_class__'): |
| 166 | + raise TypeError("Not an ADT constructor") |
| 167 | + return adt.__adt_class__ |
| 168 | + |
| 169 | +@curry_n(2) |
| 170 | +def seq_of(t,xs): |
| 171 | + return ( |
| 172 | + hasattr(xs,'__iter__') and all( validate(t,x)[0] for x in xs ) |
| 173 | + ) |
| 174 | + |
| 175 | +@curry_n(2) |
| 176 | +def tuple_of(ts,xs): |
| 177 | + return ( |
| 178 | + all( validate(t,x)[0] for (t,x) in zip(ts,xs) ) |
| 179 | + ) |
| 180 | + |
| 181 | +@curry_n(2) |
| 182 | +def one_of(ts,x): |
| 183 | + return any( validate(t,x)[0] for t in ts ) |
| 184 | + |
| 185 | +def validate(s,v): |
| 186 | + try: |
| 187 | + return ( isinstance(v,s), None ) |
| 188 | + except TypeError: |
| 189 | + try: |
| 190 | + return ( |
| 191 | + ( ( type(v) == s ) or |
| 192 | + ( hasattr(s,"__adt_class__") and isinstance(v,typeof(s)) ) or |
| 193 | + ( callable(s) and s(v) == True ) |
| 194 | + ), |
| 195 | + None |
| 196 | + ) |
| 197 | + except Exception as e: |
| 198 | + return (False, e) |
| 199 | + |
| 200 | + |
| 201 | +@curry_n(3) |
| 202 | +def match(adts, cases, target): |
| 203 | + |
| 204 | + assert target.__class__ in [ typeof(adt) for adt in adts ], \ |
| 205 | + "%s is not in union" % target.__class__.__name__ |
| 206 | + |
| 207 | + missing = [ |
| 208 | + t.__adt_class__.__name__ for t in adts \ |
| 209 | + if not (cases.has_key(type(None)) or cases.has_key(t)) |
| 210 | + ] |
| 211 | + assert len(missing) == 0, \ |
| 212 | + "No case found for the following type(s): %s" % ", ".join(missing) |
| 213 | + |
| 214 | + fn = None |
| 215 | + wildcard = False |
| 216 | + try: |
| 217 | + fn = ( |
| 218 | + next( |
| 219 | + cases[constr] for constr in cases \ |
| 220 | + if not constr == type(None) and isinstance(target,typeof(constr)) |
| 221 | + ) |
| 222 | + ) |
| 223 | + |
| 224 | + except StopIteration: |
| 225 | + fn = cases.get(type(None),None) |
| 226 | + wildcard = not fn is None |
| 227 | + |
| 228 | + # note should never happen due to type assertions above |
| 229 | + if fn is None: |
| 230 | + raise TypeError("No cases match %s" % target.__class__.__name__) |
| 231 | + |
| 232 | + assert callable(fn), \ |
| 233 | + "Matched case is not callable; check your cases" |
| 234 | + |
| 235 | + return fn() if wildcard else fn( *(slot for slot in target) ) |
| 236 | + |
0 commit comments