Merge pull request #132 from vibe-d/connect_callback_fix

Ensure that only the connect or the connection error callback is ever triggered.
merged-on-behalf-of: Leonid Kramer <l-kramer@users.noreply.github.com>
This commit is contained in:
The Dlang Bot 2020-03-16 08:34:10 +01:00 committed by GitHub
commit 608f60237f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 130 additions and 106 deletions

View file

@ -153,3 +153,4 @@ test_script:
- echo %PATH%
- '%DC% --version'
- dub test --arch=%Darch% --compiler=%DC% --config=%CONFIG%
- for %%i in (tests\*.d) do echo %%i && dub --single %%i --arch=%Darch% --compiler=%DC% --override-config eventcore/%CONFIG% || exit /B 1

View file

@ -19,6 +19,7 @@ import eventcore.drivers.threadedfile;
import eventcore.internal.consumablequeue : ConsumableQueue;
import eventcore.internal.utils;
import core.time : MonoTime;
import std.algorithm.comparison : among, min, max;
version (Posix) {
@ -221,17 +222,17 @@ final class PosixEventDriverCore(Loop : PosixEventLoop, Timers : EventDriverTime
if (timeout <= 0.seconds) {
got_events = m_loop.doProcessEvents(0.seconds);
m_timers.process(currStdTime);
m_timers.process(MonoTime.currTime);
} else {
long now = currStdTime;
auto now = MonoTime.currTime;
do {
auto nextto = max(min(m_timers.getNextTimeout(now), timeout), 0.seconds);
got_events = m_loop.doProcessEvents(nextto);
long prev_step = now;
now = currStdTime;
auto prev_step = now;
now = MonoTime.currTime;
got_events |= m_timers.process(now);
if (timeout != Duration.max)
timeout -= (now - prev_step).hnsecs;
timeout -= now - prev_step;
} while (timeout > 0.seconds && !m_exit && !got_events);
}

View file

@ -17,7 +17,7 @@ version (Posix) {
import core.sys.posix.netinet.tcp;
import core.sys.posix.sys.un;
import core.sys.posix.unistd : close, read, write;
import core.stdc.errno : errno, EAGAIN, EINPROGRESS;
import core.stdc.errno : errno, EAGAIN, EINPROGRESS, ECONNREFUSED;
import core.sys.posix.fcntl;
import core.sys.posix.sys.socket;
@ -88,6 +88,7 @@ version (Windows) {
import core.sys.windows.winsock2;
alias sockaddr_storage = SOCKADDR_STORAGE;
alias EAGAIN = WSAEWOULDBLOCK;
alias ECONNREFUSED = WSAECONNREFUSED;
enum SHUT_RDWR = SD_BOTH;
enum SHUT_RD = SD_RECEIVE;
enum SHUT_WR = SD_SEND;
@ -138,9 +139,7 @@ final class PosixEventDriverSockets(Loop : PosixEventLoop) : EventDriverSockets
}
m_loop.initFD(sock, FDFlags.none, StreamSocketSlot.init);
m_loop.registerFD(sock, EventMask.read|EventMask.write|EventMask.status);
m_loop.setNotifyCallback!(EventType.status)(sock, &onConnectError);
releaseRef(sock); // onConnectError callback is weak reference
m_loop.registerFD(sock, EventMask.read|EventMask.write);
auto ret = () @trusted { return connect(cast(sock_t)sock, address.name, address.nameLen); } ();
if (ret == 0) {
@ -155,10 +154,10 @@ final class PosixEventDriverSockets(Loop : PosixEventLoop) : EventDriverSockets
}
m_loop.setNotifyCallback!(EventType.write)(sock, &onConnect);
} else {
m_loop.unregisterFD(sock, EventMask.read|EventMask.write);
m_loop.clearFD!StreamSocketSlot(sock);
m_loop.unregisterFD(sock, EventMask.read|EventMask.write|EventMask.status);
invalidateSocket();
on_connect(StreamSocketFD.invalid, ConnectStatus.unknownError);
on_connect(StreamSocketFD.invalid, determineConnectStatus(err));
return StreamSocketFD.invalid;
}
}
@ -175,11 +174,7 @@ final class PosixEventDriverSockets(Loop : PosixEventLoop) : EventDriverSockets
"Unable to cancel connect on the socket that is not in connecting state");
state = ConnectionState.closed;
connectCallback = null;
m_loop.setNotifyCallback!(EventType.status)(sock, null);
m_loop.setNotifyCallback!(EventType.write)(sock, null);
m_loop.clearFD!StreamSocketSlot(sock);
m_loop.unregisterFD(sock, EventMask.read|EventMask.write|EventMask.status);
closeSocket(cast(sock_t)sock.value);
}
}
@ -190,7 +185,7 @@ final class PosixEventDriverSockets(Loop : PosixEventLoop) : EventDriverSockets
return StreamSocketFD.invalid;
setSocketNonBlocking(fd);
m_loop.initFD(fd, FDFlags.none, StreamSocketSlot.init);
m_loop.registerFD(fd, EventMask.read|EventMask.write|EventMask.status);
m_loop.registerFD(fd, EventMask.read|EventMask.write);
return fd;
}
@ -199,22 +194,32 @@ final class PosixEventDriverSockets(Loop : PosixEventLoop) : EventDriverSockets
auto sock = cast(StreamSocketFD)fd;
auto l = lockHandle(sock);
m_loop.setNotifyCallback!(EventType.write)(sock, null);
ConnectStatus status = ConnectStatus.unknownError;
int err;
socklen_t errlen = err.sizeof;
if (() @trusted { return getsockopt(cast(sock_t)fd, SOL_SOCKET, SO_ERROR, &err, &errlen); } () == 0)
status = determineConnectStatus(err);
with (m_loop.m_fds[sock].streamSocket) {
state = ConnectionState.connected;
assert(state == ConnectionState.connecting);
state = status == ConnectStatus.connected
? ConnectionState.connected
: ConnectionState.closed;
auto cb = connectCallback;
connectCallback = null;
if (cb) cb(sock, ConnectStatus.connected);
if (cb) cb(cast(StreamSocketFD)sock, status);
}
}
private void onConnectError(FD sock)
private ConnectStatus determineConnectStatus(int sock_err)
{
// FIXME: determine the correct kind of error!
with (m_loop.m_fds[sock].streamSocket) {
state = ConnectionState.closed;
auto cb = connectCallback;
connectCallback = null;
if (cb) cb(cast(StreamSocketFD)sock, ConnectStatus.refused);
switch (sock_err) {
default: return ConnectStatus.unknownError;
case 0: return ConnectStatus.connected;
case ECONNREFUSED: return ConnectStatus.refused;
}
}
@ -287,9 +292,7 @@ final class PosixEventDriverSockets(Loop : PosixEventLoop) : EventDriverSockets
auto fd = cast(StreamSocketFD)sockfd;
m_loop.initFD(fd, FDFlags.none, StreamSocketSlot.init);
m_loop.m_fds[fd].streamSocket.state = ConnectionState.connected;
m_loop.registerFD(fd, EventMask.read|EventMask.write|EventMask.status);
m_loop.setNotifyCallback!(EventType.status)(fd, &onConnectError);
releaseRef(fd); // onConnectError callback is weak reference
m_loop.registerFD(fd, EventMask.read|EventMask.write);
//print("accept %d", sockfd);
scope RefAddress addrc = new RefAddress(() @trusted { return cast(sockaddr*)&addr; } (), addr_len);
m_loop.m_fds[listenfd].streamListen.acceptCallback(cast(StreamListenSocketFD)listenfd, fd, addrc);
@ -656,7 +659,7 @@ final class PosixEventDriverSockets(Loop : PosixEventLoop) : EventDriverSockets
}
m_loop.initFD(sock, is_internal ? FDFlags.internal : FDFlags.none, DgramSocketSlot.init);
m_loop.registerFD(sock, EventMask.read|EventMask.write|EventMask.status);
m_loop.registerFD(sock, EventMask.read|EventMask.write);
return sock;
}
@ -673,7 +676,7 @@ final class PosixEventDriverSockets(Loop : PosixEventLoop) : EventDriverSockets
return DatagramSocketFD.init;
setSocketNonBlocking(fd, close_on_exec);
m_loop.initFD(fd, is_internal ? FDFlags.internal : FDFlags.none, DgramSocketSlot.init);
m_loop.registerFD(fd, EventMask.read|EventMask.write|EventMask.status);
m_loop.registerFD(fd, EventMask.read|EventMask.write);
return fd;
}
@ -881,7 +884,7 @@ final class PosixEventDriverSockets(Loop : PosixEventLoop) : EventDriverSockets
// listening sockets have an incremented the reference count because of setNotifyCallback
int base_refcount = slot.specific.hasType!StreamListenSocketSlot ? 1 : 0;
if (--slot.common.refCount == base_refcount) {
m_loop.unregisterFD(fd, EventMask.read|EventMask.write|EventMask.status);
m_loop.unregisterFD(fd, EventMask.read|EventMask.write);
switch (slot.specific.kind) with (slot.specific.Kind) {
default: assert(false, "File descriptor slot is not a socket.");
case streamSocket:

View file

@ -7,7 +7,7 @@ import eventcore.driver;
import eventcore.internal.consumablequeue;
import eventcore.internal.dlist;
import eventcore.internal.utils : mallocT, freeT, nogc_assert;
import core.time : Duration, MonoTime, hnsecs;
final class LoopTimeoutTimerDriver : EventDriverTimers {
import std.experimental.allocator.building_blocks.free_list;
@ -17,7 +17,6 @@ final class LoopTimeoutTimerDriver : EventDriverTimers {
import std.container.array;
import std.datetime : Clock;
import std.range : SortedRange, assumeSorted, take;
import core.time : hnsecs, Duration;
import core.memory : GC;
private {
@ -46,24 +45,22 @@ final class LoopTimeoutTimerDriver : EventDriverTimers {
package @property size_t pendingCount() const @safe nothrow { return m_timerQueue.length; }
final package Duration getNextTimeout(long stdtime)
final package Duration getNextTimeout(MonoTime time)
@safe nothrow {
if (m_timerQueue.empty) return Duration.max;
return (m_timerQueue.front.timeout - stdtime).hnsecs;
return m_timerQueue.front.timeout - time;
}
final package bool process(long stdtime)
final package bool process(MonoTime time)
@trusted nothrow {
assert(m_firedTimers.length == 0);
if (m_timerQueue.empty) return false;
TimerSlot ts = void;
ts.timeout = stdtime+1;
foreach (tm; m_timerQueue[]) {
if (tm.timeout > stdtime) break;
if (tm.repeatDuration > 0) {
if (tm.timeout > time) break;
if (tm.repeatDuration > Duration.zero) {
do tm.timeout += tm.repeatDuration;
while (tm.timeout <= stdtime);
while (tm.timeout <= time);
} else tm.pending = false;
m_firedTimers.put(tm);
}
@ -72,7 +69,7 @@ final class LoopTimeoutTimerDriver : EventDriverTimers {
foreach (tm; processed_timers) {
m_timerQueue.remove(tm);
if (tm.repeatDuration > 0)
if (tm.repeatDuration > Duration.zero)
enqueueTimer(tm);
}
@ -98,7 +95,7 @@ final class LoopTimeoutTimerDriver : EventDriverTimers {
GC.addRange(tm, TimerSlot.sizeof, typeid(TimerSlot));
tm.id = id;
tm.refCount = 1;
tm.timeout = long.max;
tm.timeout = MonoTime.max;
m_timers[id] = tm;
return id;
}
@ -108,8 +105,8 @@ final class LoopTimeoutTimerDriver : EventDriverTimers {
scope (failure) assert(false);
auto tm = m_timers[timer];
if (tm.pending) stop(timer);
tm.timeout = Clock.currStdTime + timeout.total!"hnsecs";
tm.repeatDuration = repeat.total!"hnsecs";
tm.timeout = MonoTime.currTime + timeout;
tm.repeatDuration = repeat;
tm.pending = true;
enqueueTimer(tm);
}
@ -137,7 +134,7 @@ final class LoopTimeoutTimerDriver : EventDriverTimers {
final override bool isPeriodic(TimerID descriptor)
{
return m_timers[descriptor].repeatDuration > 0;
return m_timers[descriptor].repeatDuration > Duration.zero;
}
final override void wait(TimerID timer, TimerCallback2 callback)
@ -239,8 +236,8 @@ struct TimerSlot {
TimerID id;
uint refCount;
bool pending;
long timeout; // stdtime
long repeatDuration;
MonoTime timeout;
Duration repeatDuration;
TimerCallback2 callback; // TODO: use a list with small-value optimization
DataInitializer userDataDestructor;

View file

@ -96,7 +96,7 @@ final class WinAPIEventDriverCore : EventDriverCore {
override ExitReason processEvents(Duration timeout = Duration.max)
{
import std.algorithm : min;
import core.time : hnsecs, seconds;
import core.time : MonoTime, seconds;
if (m_exit) {
m_exit = false;
@ -106,12 +106,12 @@ final class WinAPIEventDriverCore : EventDriverCore {
if (!waiterCount) return ExitReason.outOfWaiters;
bool got_event;
long now = currStdTime;
auto now = MonoTime.currTime;
do {
auto nextto = min(m_timers.getNextTimeout(now), timeout);
got_event |= doProcessEvents(nextto);
long prev_step = now;
now = currStdTime;
auto prev_step = now;
now = MonoTime.currTime;
got_event |= m_timers.process(now);
if (m_exit) {
@ -119,7 +119,7 @@ final class WinAPIEventDriverCore : EventDriverCore {
return ExitReason.exited;
} else if (got_event) break;
if (timeout != Duration.max)
timeout -= (now - prev_step).hnsecs;
timeout -= now - prev_step;
} while (timeout > 0.seconds);
if (!waiterCount) return ExitReason.outOfWaiters;

View file

@ -85,7 +85,6 @@ final class WinAPIEventDriverSockets : EventDriverSockets {
}
m_core.addWaiter();
addRef(sock);
return sock;
} else {
clearSocketSlot(sock);
@ -103,8 +102,9 @@ final class WinAPIEventDriverSockets : EventDriverSockets {
assert(state == ConnectionState.connecting,
"Must be in 'connecting' state when calling cancelConnection.");
clearSocketSlot(sock);
() @trusted { closesocket(sock); } ();
state = ConnectionState.closed;
connectCallback = null;
m_core.removeWaiter();
}
}
@ -249,6 +249,7 @@ final class WinAPIEventDriverSockets : EventDriverSockets {
{
auto slot = () @trusted { return &m_sockets[socket].streamSocket(); } ();
slot.read.buffer = buffer;
slot.read.bytesTransferred = 0;
slot.read.mode = mode;
slot.read.wsabuf[0].len = buffer.length;
slot.read.wsabuf[0].buf = () @trusted { return buffer.ptr; } ();
@ -306,7 +307,7 @@ final class WinAPIEventDriverSockets : EventDriverSockets {
}
if (slot.streamSocket.read.mode == IOMode.once || !slot.streamSocket.read.buffer.length) {
invokeCallback(IOStatus.ok, cbTransferred);
invokeCallback(IOStatus.ok, slot.streamSocket.read.bytesTransferred);
return;
}
@ -332,6 +333,7 @@ final class WinAPIEventDriverSockets : EventDriverSockets {
{
auto slot = () @trusted { return &m_sockets[socket].streamSocket(); } ();
slot.write.buffer = buffer;
slot.write.bytesTransferred = 0;
slot.write.mode = mode;
slot.write.wsabuf[0].len = buffer.length;
slot.write.wsabuf[0].buf = () @trusted { return cast(ubyte*)buffer.ptr; } ();
@ -381,7 +383,7 @@ final class WinAPIEventDriverSockets : EventDriverSockets {
}
if (slot.streamSocket.write.mode == IOMode.once || !slot.streamSocket.write.buffer.length) {
invokeCallback(IOStatus.ok, cbTransferred);
invokeCallback(IOStatus.ok, slot.streamSocket.write.bytesTransferred);
return;
}
@ -550,8 +552,11 @@ final class WinAPIEventDriverSockets : EventDriverSockets {
}
}
if (mode == IOMode.immediate)
if (mode == IOMode.immediate) {
() @trusted { CancelIoEx(cast(HANDLE)cast(SOCKET)socket, cast(LPOVERLAPPED)&slot.read.overlapped); } ();
on_read_finish(socket, IOStatus.wouldBlock, 0, null);
return;
}
slot.read.callback = on_read_finish;
addRef(socket);
@ -644,8 +649,11 @@ final class WinAPIEventDriverSockets : EventDriverSockets {
}
}
if (mode == IOMode.immediate)
if (mode == IOMode.immediate) {
() @trusted { CancelIoEx(cast(HANDLE)cast(SOCKET)socket, cast(LPOVERLAPPED)&slot.write.overlapped); } ();
on_write_finish(socket, IOStatus.wouldBlock, 0, null);
return;
}
slot.write.callback = on_write_finish;
addRef(socket);
@ -845,6 +853,8 @@ final class WinAPIEventDriverSockets : EventDriverSockets {
default: break;
case FD_CONNECT:
auto cb = slot.streamSocket.connectCallback;
if (!cb) break; // cancelled connect?
slot.streamSocket.connectCallback = null;
slot.common.driver.m_core.removeWaiter();
if (err) {
@ -852,7 +862,6 @@ final class WinAPIEventDriverSockets : EventDriverSockets {
cb(cast(StreamSocketFD)sock, ConnectStatus.refused);
} else {
slot.streamSocket.state = ConnectionState.connected;
if (slot.common.driver.releaseRef(cast(StreamSocketFD)sock))
cb(cast(StreamSocketFD)sock, ConnectStatus.connected);
}
break;

View file

@ -4,8 +4,10 @@
+/
module test;
import eventcore.core;
import std.stdio : writefln;
version (Linux) {
import eventcore.core;
import core.stdc.signal;
import core.sys.posix.signal : SIGUSR1;
import core.time : Duration, msecs;
@ -14,9 +16,6 @@ bool s_done;
void main()
{
version (OSX) writefln("Signals are not yet supported on macOS. Skipping test.");
else {
auto id = eventDriver.signals.listen(SIGUSR1, (id, status, sig) {
assert(!s_done);
assert(status == SignalStatus.ok);
@ -37,6 +36,10 @@ void main()
assert(er == ExitReason.outOfWaiters);
assert(s_done);
s_done = false;
}
} else {
void main()
{
writefln("Signals are not yet supported on macOS/Windows. Skipping test.");
}
}

View file

@ -28,6 +28,7 @@ void main()
eventDriver.timers.wait(tm, (tm) {
assert(eventDriver.sockets.getConnectionState(sock) == ConnectionState.connecting);
eventDriver.sockets.cancelConnectStream(sock);
eventDriver.sockets.releaseRef(sock);
s_done = true;
});

View file

@ -14,11 +14,7 @@ bool s_done;
void main()
{
version (OSX) {
import std.stdio;
writeln("This doesn't work on macOS. Skipping this test until it is determined that this special case should stay supported.");
return;
} else {
version (Linux) {
static ubyte[] pack1 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
@ -68,5 +64,9 @@ void main()
assert(s_done);
s_done = false;
} else {
import std.stdio;
writeln("This doesn't work on macOS/Windows. Skipping this test until it is determined that this special case should stay supported.");
return;
}
}

View file

@ -38,12 +38,16 @@ void main()
assert(bts == pack1.length);
assert(s_rbuf[0 .. pack1.length] == pack1);
auto tmw = eventDriver.timers.create();
eventDriver.timers.set(tmw, 20.msecs, 0.msecs);
eventDriver.timers.wait(tmw, (tmw) {
print("Second write");
client.write!((status, bytes) {
print("Second write done");
assert(status == IOStatus.ok);
assert(bytes == pack2.length);
})(pack2, IOMode.once);
});
print("Second read");
incoming.read!((status, bts) {

View file

@ -26,7 +26,7 @@ void main()
}
try {
assert(dur > 1200.msecs, (dur - 1200.msecs).toString());
assert(dur > 1200.msecs - 2.msecs, (dur - 1200.msecs).toString());
assert(dur < 1300.msecs, (dur - 1200.msecs).toString());
} catch (Exception e) assert(false, e.msg);
@ -38,7 +38,7 @@ void main()
try {
auto dur = MonoTime.currTime() - s_startTime;
s_cnt++;
assert(dur > 300.msecs * s_cnt, (dur - 300.msecs * s_cnt).toString());
assert(dur > 300.msecs * s_cnt - 2.msecs, (dur - 300.msecs * s_cnt).toString());
assert(dur < 300.msecs * s_cnt + 100.msecs, (dur - 300.msecs * s_cnt).toString());
assert(s_cnt <= 3);

View file

@ -20,6 +20,11 @@ void main()
static ubyte[] pack1 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
static ubyte[] pack2 = [4, 3, 2, 1, 0];
// Windows can not provide "immediate" semantics using the overlapped
// I/O API that is used
version (Windows) enum mode_immediate = IOMode.once;
else enum mode_immediate = IOMode.immediate;
auto baddr = new InternetAddress(0x7F000001, 40001);
auto anyaddr = new InternetAddress(0x7F000001, 0);
s_baseSocket = createDatagramSocket(baddr);
@ -55,14 +60,14 @@ void main()
destroy(s_connectedSocket);
s_done = true;
log("done.");
})(s_rbuf, IOMode.immediate);
})(s_rbuf, mode_immediate);
});
})(s_rbuf, IOMode.once);
s_connectedSocket.send!((status, bytes) {
log("send1: %s %s", status, bytes);
assert(status == IOStatus.ok);
assert(bytes == 10);
})(pack1, IOMode.immediate);
})(pack1, mode_immediate);
ExitReason er;
do er = eventDriver.core.processEvents(Duration.max);

View file

@ -44,7 +44,7 @@ void main()
auto dur = MonoTime.currTime() - s_startTime;
s_cnt++;
assert(dur > 300.msecs * s_cnt, (dur - 300.msecs * s_cnt).toString());
assert(dur > 300.msecs * s_cnt - 2.msecs, (dur - 300.msecs * s_cnt).toString());
assert(dur < 300.msecs * s_cnt + 100.msecs, (dur - 300.msecs * s_cnt).toString());
assert(s_cnt <= 5);