diff --git a/source/vibe/core/net.d b/source/vibe/core/net.d index 19ed62d..47088a2 100644 --- a/source/vibe/core/net.d +++ b/source/vibe/core/net.d @@ -88,10 +88,12 @@ TCPListener listenTCP(ushort port, TCPConnectionDelegate connection_callback, st auto addr = resolveHost(address); addr.port = port; assert(options == TCPListenOptions.defaults, "TODO"); - auto sock = eventDriver.sockets.listenStream(addr.toUnknownAddress, (StreamListenSocketFD ls, StreamSocketFD s) @safe nothrow { - import vibe.core.core : runTask; - runTask(connection_callback, TCPConnection(s)); - }); + auto sock = eventDriver.sockets.listenStream(addr.toUnknownAddress, + (StreamListenSocketFD ls, StreamSocketFD s, scope RefAddress addr) @safe nothrow { + import vibe.core.core : runTask; + auto conn = TCPConnection(s, addr); + runTask(connection_callback, conn); + }); return TCPListener(sock); } @@ -147,11 +149,8 @@ TCPConnection connectTCP(NetworkAddress addr, NetworkAddress bind_address = anyA enforce(addr.family == bind_address.family, "Destination address and bind address have different address families."); return () @trusted { // scope - scope uaddr = new UnknownAddress; - addr.toUnknownAddress(uaddr); - - scope baddr = new UnknownAddress; - bind_address.toUnknownAddress(baddr); + scope uaddr = new RefAddress(addr.sockAddr, addr.sockAddrLen); + scope baddr = new RefAddress(addr.sockAddr, addr.sockAddrLen); // FIXME: make this interruptible auto result = asyncAwaitUninterruptible!(ConnectCallback, @@ -159,7 +158,8 @@ TCPConnection connectTCP(NetworkAddress addr, NetworkAddress bind_address = anyA //cb => eventDriver.sockets.cancelConnect(cb) ); enforce(result[1] == ConnectStatus.connected, "Failed to connect to "~addr.toString()~": "~result[1].to!string); - return TCPConnection(result[0]); + + return TCPConnection(result[0], uaddr); } (); } @@ -377,6 +377,12 @@ struct TCPConnection { struct Context { BatchBuffer!ubyte readBuffer; + bool tcpNoDelay = false; + bool keepAlive = false; + Duration readTimeout = Duration.max; + NetworkAddress remoteAddress; + NetworkAddress localAddress; + string remoteAddressString; } private { @@ -384,11 +390,18 @@ struct TCPConnection { Context* m_context; } - private this(StreamSocketFD socket) + private this(StreamSocketFD socket, scope RefAddress remote_address) nothrow { + import std.exception : enforce; + m_socket = socket; m_context = () @trusted { return &eventDriver.core.userData!Context(socket); } (); m_context.readBuffer.capacity = 4096; + try m_context.remoteAddress = NetworkAddress(remote_address); + catch (Exception e) { logWarn("Failed to get remote address for TCP connection: %s", e.msg); } + scope laddr = new RefAddress(m_context.localAddress.sockAddr, m_context.localAddress.sockAddrLen); + if (!eventDriver.sockets.getLocalAddress(socket, laddr)) + logWarn("Failed to get local address for TCP connection."); } this(this) @@ -405,15 +418,15 @@ struct TCPConnection { bool opCast(T)() const nothrow if (is(T == bool)) { return m_socket != StreamSocketFD.invalid; } - @property void tcpNoDelay(bool enabled) { eventDriver.sockets.setTCPNoDelay(m_socket, enabled); } - @property bool tcpNoDelay() const { assert(false); } - @property void keepAlive(bool enable) { assert(false); } - @property bool keepAlive() const { assert(false); } - @property void readTimeout(Duration duration) { } - @property Duration readTimeout() const { assert(false); } - @property string peerAddress() const { return ""; } - @property NetworkAddress localAddress() const { return NetworkAddress.init; } - @property NetworkAddress remoteAddress() const { return NetworkAddress.init; } + @property void tcpNoDelay(bool enabled) { eventDriver.sockets.setTCPNoDelay(m_socket, enabled); m_context.tcpNoDelay = enabled; } + @property bool tcpNoDelay() const { return m_context.tcpNoDelay; } + @property void keepAlive(bool enabled) { eventDriver.sockets.setKeepAlive(m_socket, enabled); m_context.keepAlive = enabled; } + @property bool keepAlive() const { return m_context.keepAlive; } + @property void readTimeout(Duration duration) { m_context.readTimeout = duration; } + @property Duration readTimeout() const { return m_context.readTimeout; } + @property string peerAddress() const { return m_context.remoteAddress.toString(); } + @property NetworkAddress localAddress() const { return localAddress; } + @property NetworkAddress remoteAddress() const { return remoteAddress; } @property bool connected() const { if (m_socket == StreamSocketFD.invalid) return false; @@ -438,20 +451,26 @@ struct TCPConnection { bool waitForData(Duration timeout = Duration.max) { mixin(tracer); - // TODO: timeout!! if (m_context.readBuffer.length > 0) return true; auto mode = timeout <= 0.seconds ? IOMode.immediate : IOMode.once; - auto res = asyncAwait!(IOCallback, + + Waitable!( cb => eventDriver.sockets.read(m_socket, m_context.readBuffer.peekDst(), mode, cb), - cb => eventDriver.sockets.cancelRead(m_socket) - ); - logTrace("Socket %s, read %s bytes: %s", res[0], res[2], res[1]); + cb => eventDriver.sockets.cancelRead(m_socket), + IOCallback + ) waiter; + + asyncAwaitAny!true(timeout, waiter); + + if (waiter.cancelled) return false; + + logTrace("Socket %s, read %s bytes: %s", waiter.results[0], waiter.results[2], waiter.results[1]); assert(m_context.readBuffer.length == 0); - m_context.readBuffer.putN(res[2]); - switch (res[1]) { + m_context.readBuffer.putN(waiter.results[2]); + switch (waiter.results[1]) { default: - logInfo("read status %s", res[1]); + logInfo("read status %s", waiter.results[1]); throw new Exception("Error reading data from socket."); case IOStatus.ok: break; case IOStatus.wouldBlock: assert(mode == IOMode.immediate); break; @@ -467,25 +486,28 @@ mixin(tracer); { import std.algorithm.comparison : min; - while (count > 0) { - waitForData(); + m_context.readTimeout.loopWithTimeout!((remaining) { + waitForData(remaining); auto n = min(count, m_context.readBuffer.length); m_context.readBuffer.popFrontN(n); count -= n; - } + return count == 0; + }); } void read(ubyte[] dst) { mixin(tracer); import std.algorithm.comparison : min; - while (dst.length > 0) { + if (!dst.length) return; + m_context.readTimeout.loopWithTimeout!((remaining) { enforce(waitForData(), "Reached end of stream while reading data."); assert(m_context.readBuffer.length > 0); auto l = min(dst.length, m_context.readBuffer.length); m_context.readBuffer.read(dst[0 .. l]); dst = dst[l .. $]; - } + return dst.length == 0; + }); } void write(in ubyte[] bytes) @@ -547,6 +569,29 @@ mixin(tracer); mixin validateConnectionStream!TCPConnection; +private void loopWithTimeout(alias LoopBody, ExceptionType = Exception)(Duration timeout) +{ + import core.time : seconds; + import std.datetime : Clock, SysTime, UTC; + + SysTime now; + if (timeout != Duration.max) + now = Clock.currTime(UTC()); + + do { + if (LoopBody(timeout)) + return; + + if (timeout != Duration.max) { + auto prev = now; + now = Clock.currTime(UTC()); + if (now > prev) timeout -= now - prev; + } + } while (timeout > 0.seconds); + + throw new ExceptionType("Operation timed out."); +} + /** Represents a listening TCP socket. @@ -554,6 +599,7 @@ mixin validateConnectionStream!TCPConnection; struct TCPListener { private { StreamListenSocketFD m_socket; + NetworkAddress m_bindAddress; } this(StreamListenSocketFD socket) @@ -566,7 +612,7 @@ struct TCPListener { /// The local address at which TCP connections are accepted. @property NetworkAddress bindAddress() { - assert(false); + return m_bindAddress; } /// Stops listening and closes the socket. diff --git a/source/vibe/core/sync.d b/source/vibe/core/sync.d index 1f1885f..27138ac 100644 --- a/source/vibe/core/sync.d +++ b/source/vibe/core/sync.d @@ -1093,7 +1093,8 @@ private struct ThreadLocalWaiter { Waitable!( cb => w.wait(cb), - cb => w.cancel() + cb => w.cancel(), + typeof(Waiter.notifier) ) waitable; void removeWaiter() @@ -1124,7 +1125,7 @@ private struct ThreadLocalWaiter { } }, cb => eventDriver.events.cancelWait(evt, cb), - EventID + EventCallback ) ewaitable; asyncAwaitAny!interruptible(timeout, waitable, ewaitable); } else { diff --git a/source/vibe/internal/async.d b/source/vibe/internal/async.d index 9bd2c03..2b1d2bd 100644 --- a/source/vibe/internal/async.d +++ b/source/vibe/internal/async.d @@ -10,14 +10,14 @@ import core.time : Duration, seconds; auto asyncAwait(Callback, alias action, alias cancel, string func = __FUNCTION__)() if (!is(Object == Duration)) { - Waitable!(action, cancel, ParameterTypeTuple!Callback) waitable; + Waitable!(action, cancel, Callback) waitable; asyncAwaitAny!(true, func)(waitable); return tuple(waitable.results); } auto asyncAwait(Callback, alias action, alias cancel, string func = __FUNCTION__)(Duration timeout) { - Waitable!(action, cancel, ParameterTypeTuple!Callback) waitable; + Waitable!(action, cancel, Callback) waitable; asyncAwaitAny!(true, func)(timeout, waitable); static struct R { bool completed; @@ -30,23 +30,24 @@ auto asyncAwaitUninterruptible(Callback, alias action, string func = __FUNCTION_ nothrow { static if (is(typeof(action(Callback.init)) == void)) void cancel(Callback) { assert(false, "Action cannot be cancelled."); } else void cancel(Callback, typeof(action(Callback.init))) { assert(false, "Action cannot be cancelled."); } - Waitable!(action, cancel, ParameterTypeTuple!Callback) waitable; + Waitable!(action, cancel, Callback) waitable; asyncAwaitAny!(false, func)(waitable); return tuple(waitable.results); } auto asyncAwaitUninterruptible(Callback, alias action, alias cancel, string func = __FUNCTION__)(Duration timeout) nothrow { - Waitable!(action, cancel, ParameterTypeTuple!Callback) waitable; + Waitable!(action, cancel, Callback) waitable; asyncAwaitAny!(false, func)(timeout, waitable); return tuple(waitable.results); } -struct Waitable(alias wait, alias cancel, Results...) { +struct Waitable(alias wait, alias cancel, CB) { import std.traits : ReturnType; - alias Callback = void delegate(Results) @safe nothrow; - Results results; + alias Callback = CB; + + ParameterTypeTuple!Callback results; bool cancelled; auto waitCallback(Callback cb) nothrow { return wait(cb); } @@ -68,7 +69,7 @@ void asyncAwaitAny(bool interruptible, string func = __FUNCTION__, Waitables...) Waitable!( cb => eventDriver.timers.wait(tm, cb), cb => eventDriver.timers.cancelWait(tm), - TimerID + TimerCallback ) timerwaitable; asyncAwaitAny!(interruptible, func)(timerwaitable, waitables); } @@ -79,6 +80,8 @@ void asyncAwaitAny(bool interruptible, string func = __FUNCTION__, Waitables...) { import std.meta : staticMap; import std.algorithm.searching : any; + import std.format : format; + import std.meta : AliasSeq; import std.traits : ReturnType; /*scope*/ staticMap!(CBDel, Waitables) callbacks; // FIXME: avoid heap delegates @@ -96,15 +99,17 @@ void asyncAwaitAny(bool interruptible, string func = __FUNCTION__, Waitables...) () @trusted { logDebugV("si %x", &still_inside); } (); foreach (i, W; Waitables) { - /*scope*/auto cb = (typeof(Waitables[i].results) results) @safe nothrow { - () @trusted { logDebugV("siw %x", &still_inside); } (); - debug(VibeAsyncLog) logDebugV("Waitable %s in %s fired (istask=%s).", i, func, t != Task.init); + alias PTypes = ParameterTypeTuple!(CBDel!W); + /*scope*/auto cb = mixin(q{(%s) @safe nothrow { + () @trusted { logDebugV("siw %%x", &still_inside); } (); + debug(VibeAsyncLog) logDebugV("Waitable %%s in %%s fired (istask=%%s).", i, func, t != Task.init); assert(still_inside, "Notification fired after asyncAwait had already returned!"); fired[i] = true; any_fired = true; - waitables[i].results = results; + static if (PTypes.length) + waitables[i].results = AliasSeq!(%s); if (t != Task.init) switchToTask(t); - }; + }}.format(generateParamDecls!(CBDel!W), generateParamNames!(CBDel!W))); callbacks[i] = cb; debug(VibeAsyncLog) logDebugV("Starting operation %s", i); @@ -158,13 +163,13 @@ void asyncAwaitAny(bool interruptible, string func = __FUNCTION__, Waitables...) debug(VibeAsyncLog) logDebugV("Return result for %s.", func); } -private alias CBDel(Waitable) = void delegate(typeof(Waitable.results)) @safe nothrow; +private alias CBDel(Waitable) = Waitable.Callback; private struct ScopeGuard { @safe nothrow: void delegate() op; ~this() { if (op !is null) op(); } } @safe nothrow /*@nogc*/ unittest { int cnt = 0; - auto ret = asyncAwaitUninterruptible!(void delegate(int), (cb) { cnt++; cb(42); }); + auto ret = asyncAwaitUninterruptible!(void delegate(int) @safe nothrow, (cb) { cnt++; cb(42); }); assert(ret[0] == 42); assert(cnt == 1); } @@ -174,12 +179,12 @@ private struct ScopeGuard { @safe nothrow: void delegate() op; ~this() { if (op Waitable!( (cb) { a++; cb(42); }, (cb) { assert(false); }, - int + void delegate(int) @safe nothrow ) w1; Waitable!( (cb) { b++; }, (cb) { c++; }, - int + void delegate(int) @safe nothrow ) w2; asyncAwaitAny!false(w1, w2); @@ -190,3 +195,34 @@ private struct ScopeGuard { @safe nothrow: void delegate() op; ~this() { if (op assert(w1.results[0] == 42 && w2.results[0] == 0); assert(a == 2 && b == 1 && c == 1); } + +private string generateParamDecls(Fun)() +{ + import std.format : format; + import std.traits : ParameterTypeTuple, ParameterStorageClass, ParameterStorageClassTuple; + + alias Types = ParameterTypeTuple!Fun; + alias SClasses = ParameterStorageClassTuple!Fun; + string ret; + foreach (i, T; Types) { + static if (i > 0) ret ~= ", "; + static if (SClasses[i] & ParameterStorageClass.lazy_) ret ~= "lazy "; + static if (SClasses[i] & ParameterStorageClass.scope_) ret ~= "scope "; + static if (SClasses[i] & ParameterStorageClass.out_) ret ~= "out "; + static if (SClasses[i] & ParameterStorageClass.ref_) ret ~= "ref "; + ret ~= format("PTypes[%s] param_%s", i, i); + } + return ret; +} + +private string generateParamNames(Fun)() +{ + import std.format : format; + + string ret; + foreach (i, T; ParameterTypeTuple!Fun) { + static if (i > 0) ret ~= ", "; + ret ~= format("param_%s", i); + } + return ret; +}