52
52
import java .util .concurrent .TimeUnit ;
53
53
import java .util .concurrent .atomic .AtomicReference ;
54
54
55
- import static com .mongodb .assertions .Assertions .assertFalse ;
56
55
import static com .mongodb .assertions .Assertions .assertTrue ;
57
56
import static com .mongodb .assertions .Assertions .isTrue ;
58
57
import static com .mongodb .internal .connection .ServerAddressHelper .getSocketAddresses ;
@@ -100,35 +99,34 @@ public void close() {
100
99
group .shutdown ();
101
100
}
102
101
102
+ /**
103
+ * Monitors `OP_CONNECT` events for socket connections.
104
+ */
103
105
private static class SelectorMonitor implements Closeable {
104
106
105
107
static final class SocketRegistration {
106
108
private final SocketChannel socketChannel ;
107
- private final Runnable attachment ;
108
- private final AtomicReference <ConnectionRegistrationState > connectionRegistrationState ;
109
+ private final AtomicReference <Runnable > afterConnectAction ;
109
110
110
- enum ConnectionRegistrationState {
111
- CONNECTING ,
112
- CONNECTED ,
113
- TIMEOUT_OUT
111
+ SocketRegistration (final SocketChannel socketChannel , final Runnable afterConnectAction ) {
112
+ this .socketChannel = socketChannel ;
113
+ this .afterConnectAction = new AtomicReference <>(afterConnectAction );
114
114
}
115
115
116
- private SocketRegistration (final SocketChannel socketChannel , final Runnable attachment ) {
117
- this .socketChannel = socketChannel ;
118
- this .attachment = attachment ;
119
- this .connectionRegistrationState = new AtomicReference <>(ConnectionRegistrationState .CONNECTING );
116
+ boolean tryCancelPendingConnection () {
117
+ return tryTakeAction () != null ;
120
118
}
121
119
122
- public boolean markConnectionEstablishmentTimedOut () {
123
- return connectionRegistrationState .compareAndSet (
124
- ConnectionRegistrationState .CONNECTING ,
125
- ConnectionRegistrationState .TIMEOUT_OUT );
120
+ void runAfterConnectActionIfNotCanceled () {
121
+ Runnable afterConnectActionToExecute = tryTakeAction ();
122
+ if (afterConnectActionToExecute != null ) {
123
+ afterConnectActionToExecute .run ();
124
+ }
126
125
}
127
126
128
- public boolean markConnectionEstablishmentCompleted () {
129
- return connectionRegistrationState .compareAndSet (
130
- ConnectionRegistrationState .CONNECTING ,
131
- ConnectionRegistrationState .CONNECTED );
127
+ @ Nullable
128
+ private Runnable tryTakeAction () {
129
+ return afterConnectAction .getAndSet (null );
132
130
}
133
131
}
134
132
@@ -144,7 +142,6 @@ public boolean markConnectionEstablishmentCompleted() {
144
142
}
145
143
}
146
144
147
- // Monitors OP_CONNECT events.
148
145
void start () {
149
146
Thread selectorThread = new Thread (() -> {
150
147
try {
@@ -153,13 +150,7 @@ void start() {
153
150
selector .select ();
154
151
for (SelectionKey selectionKey : selector .selectedKeys ()) {
155
152
selectionKey .cancel ();
156
- SocketRegistration socketRegistration = (SocketRegistration ) selectionKey .attachment ();
157
-
158
- boolean markedCompleted = socketRegistration .markConnectionEstablishmentCompleted ();
159
- if (markedCompleted ) {
160
- Runnable runnable = socketRegistration .attachment ;
161
- runnable .run ();
162
- }
153
+ ((SocketRegistration ) selectionKey .attachment ()).runAfterConnectActionIfNotCanceled ();
163
154
}
164
155
165
156
for (Iterator <SocketRegistration > iter = pendingRegistrations .iterator (); iter .hasNext ();) {
@@ -228,15 +219,14 @@ public void openAsync(final OperationContext operationContext, final AsyncComple
228
219
if (getSettings ().getSendBufferSize () > 0 ) {
229
220
socketChannel .setOption (StandardSocketOptions .SO_SNDBUF , getSettings ().getSendBufferSize ());
230
221
}
231
-
222
+ //getConnectTimeoutMs MUST be called before connection attempt, as it might throw MongoOperationTimeout exception.
223
+ int connectTimeoutMs = operationContext .getTimeoutContext ().getConnectTimeoutMs ();
232
224
socketChannel .connect (getSocketAddresses (getServerAddress (), inetAddressResolver ).get (0 ));
233
-
234
225
SelectorMonitor .SocketRegistration socketRegistration = new SelectorMonitor .SocketRegistration (
235
226
socketChannel , () -> initializeTslChannel (handler , socketChannel ));
236
227
237
- int connectTimeoutMs = getSettings ().getConnectTimeout (TimeUnit .MILLISECONDS );
238
228
if (connectTimeoutMs > 0 ) {
239
- scheduleTimeoutInterruption (handler , socketRegistration , socketChannel , connectTimeoutMs );
229
+ scheduleTimeoutInterruption (handler , socketRegistration , connectTimeoutMs );
240
230
}
241
231
selectorMonitor .register (socketRegistration );
242
232
} catch (IOException e ) {
@@ -248,12 +238,10 @@ public void openAsync(final OperationContext operationContext, final AsyncComple
248
238
249
239
private void scheduleTimeoutInterruption (final AsyncCompletionHandler <Void > handler ,
250
240
final SelectorMonitor .SocketRegistration socketRegistration ,
251
- final SocketChannel socketChannel ,
252
241
final int connectTimeoutMs ) {
253
242
group .getTimeoutExecutor ().schedule (() -> {
254
- boolean markedTimedOut = socketRegistration .markConnectionEstablishmentTimedOut ();
255
- if (markedTimedOut ) {
256
- closeAndTimeout (handler , socketChannel );
243
+ if (socketRegistration .tryCancelPendingConnection ()) {
244
+ closeAndTimeout (handler , socketRegistration .socketChannel );
257
245
}
258
246
}, connectTimeoutMs , TimeUnit .MILLISECONDS );
259
247
}
0 commit comments