1
- use std:: ops:: Deref ;
2
-
3
1
use itertools:: Itertools ;
4
2
use rustc_hash:: { FxBuildHasher , FxHashMap } ;
5
3
@@ -72,79 +70,83 @@ impl AlwaysFixableViolation for RepeatedEqualityComparison {
72
70
73
71
/// PLR1714
74
72
pub ( crate ) fn repeated_equality_comparison ( checker : & mut Checker , bool_op : & ast:: ExprBoolOp ) {
75
- if bool_op
76
- . values
77
- . iter ( )
78
- . any ( |value| !is_allowed_value ( bool_op. op , value, checker. semantic ( ) ) )
79
- {
80
- return ;
81
- }
82
-
83
73
// Map from expression hash to (starting offset, number of comparisons, list
84
- let mut value_to_comparators: FxHashMap < HashableExpr , ( TextSize , Vec < & Expr > , Vec < & Expr > ) > =
74
+ let mut value_to_comparators: FxHashMap < HashableExpr , ( TextSize , Vec < & Expr > , Vec < usize > ) > =
85
75
FxHashMap :: with_capacity_and_hasher ( bool_op. values . len ( ) * 2 , FxBuildHasher ) ;
86
76
87
- for value in & bool_op. values {
88
- // Enforced via `is_allowed_value`.
89
- let Expr :: Compare ( ast:: ExprCompare {
90
- left, comparators, ..
91
- } ) = value
92
- else {
93
- return ;
94
- } ;
95
-
96
- // Enforced via `is_allowed_value`.
97
- let [ right] = & * * comparators else {
98
- return ;
77
+ for ( i, value) in bool_op. values . iter ( ) . enumerate ( ) {
78
+ let Some ( ( left, right) ) = to_allowed_value ( bool_op. op , value, checker. semantic ( ) ) else {
79
+ continue ;
99
80
} ;
100
81
101
- if matches ! ( left. as_ref ( ) , Expr :: Name ( _) | Expr :: Attribute ( _) ) {
102
- let ( _, left_matches, value_matches ) = value_to_comparators
103
- . entry ( left. deref ( ) . into ( ) )
82
+ if matches ! ( left, Expr :: Name ( _) | Expr :: Attribute ( _) ) {
83
+ let ( _, left_matches, index_matches ) = value_to_comparators
84
+ . entry ( left. into ( ) )
104
85
. or_insert_with ( || ( left. start ( ) , Vec :: new ( ) , Vec :: new ( ) ) ) ;
105
86
left_matches. push ( right) ;
106
- value_matches . push ( value ) ;
87
+ index_matches . push ( i ) ;
107
88
}
108
89
109
90
if matches ! ( right, Expr :: Name ( _) | Expr :: Attribute ( _) ) {
110
- let ( _, right_matches, value_matches ) = value_to_comparators
91
+ let ( _, right_matches, index_matches ) = value_to_comparators
111
92
. entry ( right. into ( ) )
112
93
. or_insert_with ( || ( right. start ( ) , Vec :: new ( ) , Vec :: new ( ) ) ) ;
113
94
right_matches. push ( left) ;
114
- value_matches . push ( value ) ;
95
+ index_matches . push ( i ) ;
115
96
}
116
97
}
117
98
118
- for ( value, ( start , comparators, values ) ) in value_to_comparators
99
+ for ( value, ( _ , comparators, indices ) ) in value_to_comparators
119
100
. iter ( )
120
101
. sorted_by_key ( |( _, ( start, _, _) ) | * start)
121
102
{
122
- if comparators. len ( ) > 1 {
103
+ // If there's only one comparison, there's nothing to merge.
104
+ if comparators. len ( ) == 1 {
105
+ continue ;
106
+ }
107
+
108
+ // Break into sequences of consecutive comparisons.
109
+ let mut sequences: Vec < ( Vec < usize > , Vec < & Expr > ) > = Vec :: new ( ) ;
110
+ let mut last = None ;
111
+ for ( index, comparator) in indices. iter ( ) . zip ( comparators. iter ( ) ) {
112
+ if last. is_some_and ( |last| last + 1 == * index) {
113
+ let ( indices, comparators) = sequences. last_mut ( ) . unwrap ( ) ;
114
+ indices. push ( * index) ;
115
+ comparators. push ( * comparator) ;
116
+ } else {
117
+ sequences. push ( ( vec ! [ * index] , vec ! [ * comparator] ) ) ;
118
+ }
119
+ last = Some ( * index) ;
120
+ }
121
+
122
+ for ( indices, comparators) in sequences {
123
+ if indices. len ( ) == 1 {
124
+ continue ;
125
+ }
126
+
123
127
let mut diagnostic = Diagnostic :: new (
124
128
RepeatedEqualityComparison {
125
129
expression : SourceCodeSnippet :: new ( merged_membership_test (
126
130
value. as_expr ( ) ,
127
131
bool_op. op ,
128
- comparators,
132
+ & comparators,
129
133
checker. locator ( ) ,
130
134
) ) ,
131
135
} ,
132
136
bool_op. range ( ) ,
133
137
) ;
134
138
135
139
// Grab the remaining comparisons.
136
- let ( before , after ) = bool_op
137
- . values
138
- . iter ( )
139
- . filter ( |value| ! values. contains ( value ) )
140
- . partition :: < Vec < _ > , _ > ( |value| value . start ( ) < * start ) ;
140
+ let [ first , .. , last ] = indices . as_slice ( ) else {
141
+ unreachable ! ( "Indices should have at least two elements" )
142
+ } ;
143
+ let before = bool_op . values . iter ( ) . take ( * first ) . cloned ( ) ;
144
+ let after = bool_op . values . iter ( ) . skip ( last + 1 ) . cloned ( ) ;
141
145
142
146
diagnostic. set_fix ( Fix :: unsafe_edit ( Edit :: range_replacement (
143
147
checker. generator ( ) . expr ( & Expr :: BoolOp ( ast:: ExprBoolOp {
144
148
op : bool_op. op ,
145
149
values : before
146
- . into_iter ( )
147
- . cloned ( )
148
150
. chain ( std:: iter:: once ( Expr :: Compare ( ast:: ExprCompare {
149
151
left : Box :: new ( value. as_expr ( ) . clone ( ) ) ,
150
152
ops : match bool_op. op {
@@ -159,7 +161,7 @@ pub(crate) fn repeated_equality_comparison(checker: &mut Checker, bool_op: &ast:
159
161
} ) ] ) ,
160
162
range : bool_op. range ( ) ,
161
163
} ) ) )
162
- . chain ( after. into_iter ( ) . cloned ( ) )
164
+ . chain ( after)
163
165
. collect ( ) ,
164
166
range : bool_op. range ( ) ,
165
167
} ) ) ,
@@ -174,39 +176,43 @@ pub(crate) fn repeated_equality_comparison(checker: &mut Checker, bool_op: &ast:
174
176
/// Return `true` if the given expression is compatible with a membership test.
175
177
/// E.g., `==` operators can be joined with `or` and `!=` operators can be
176
178
/// joined with `and`.
177
- fn is_allowed_value ( bool_op : BoolOp , value : & Expr , semantic : & SemanticModel ) -> bool {
179
+ fn to_allowed_value < ' a > (
180
+ bool_op : BoolOp ,
181
+ value : & ' a Expr ,
182
+ semantic : & SemanticModel ,
183
+ ) -> Option < ( & ' a Expr , & ' a Expr ) > {
178
184
let Expr :: Compare ( ast:: ExprCompare {
179
185
left,
180
186
ops,
181
187
comparators,
182
188
..
183
189
} ) = value
184
190
else {
185
- return false ;
191
+ return None ;
186
192
} ;
187
193
188
194
// Ignore, e.g., `foo == bar == baz`.
189
195
let [ op] = & * * ops else {
190
- return false ;
196
+ return None ;
191
197
} ;
192
198
193
199
if match bool_op {
194
200
BoolOp :: Or => !matches ! ( op, CmpOp :: Eq ) ,
195
201
BoolOp :: And => !matches ! ( op, CmpOp :: NotEq ) ,
196
202
} {
197
- return false ;
203
+ return None ;
198
204
}
199
205
200
206
// Ignore self-comparisons, e.g., `foo == foo`.
201
207
let [ right] = & * * comparators else {
202
- return false ;
208
+ return None ;
203
209
} ;
204
210
if ComparableExpr :: from ( left) == ComparableExpr :: from ( right) {
205
- return false ;
211
+ return None ;
206
212
}
207
213
208
214
if contains_effect ( value, |id| semantic. has_builtin_binding ( id) ) {
209
- return false ;
215
+ return None ;
210
216
}
211
217
212
218
// Ignore `sys.version_info` and `sys.platform` comparisons, which are only
@@ -221,10 +227,10 @@ fn is_allowed_value(bool_op: BoolOp, value: &Expr, semantic: &SemanticModel) ->
221
227
)
222
228
} )
223
229
} ) {
224
- return false ;
230
+ return None ;
225
231
}
226
232
227
- true
233
+ Some ( ( left , right ) )
228
234
}
229
235
230
236
/// Generate a string like `obj in (a, b, c)` or `obj not in (a, b, c)`.
0 commit comments