diff --git a/source/vibe/core/task.d b/source/vibe/core/task.d index 70bddb3..dfbbb61 100644 --- a/source/vibe/core/task.d +++ b/source/vibe/core/task.d @@ -10,6 +10,7 @@ module vibe.core.task; import vibe.core.log; import vibe.core.sync; +import core.atomic : atomicOp, atomicLoad, cas; import core.thread; import std.exception; import std.traits; @@ -58,7 +59,7 @@ struct Task { auto tfiber = cast(TaskFiber)fiber; if (!tfiber) return Task.init; // FIXME: returning a non-.init handle for a finished task might break some layered logic - return () @trusted { return Task(tfiber, tfiber.m_taskCounter); } (); + return Task(tfiber, tfiber.getTaskStatusFromOwnerThread().counter); } nothrow { @@ -69,13 +70,15 @@ struct Task { /** Determines if the task is still running or scheduled to be run. */ - @property bool running() // FIXME: this is NOT thread safe + @property bool running() const @trusted { assert(m_fiber !is null, "Invalid task handle"); - try if (this.taskFiber.state == Fiber.State.TERM) return false; catch (Throwable) {} - if (this.taskFiber.m_taskCounter != m_taskCounter) + auto tf = this.taskFiber; + try if (tf.state == Fiber.State.TERM) return false; catch (Throwable) {} + auto st = m_fiber.getTaskStatus(); + if (st.counter != m_taskCounter) return false; - return this.taskFiber.m_running || this.taskFiber.m_taskFunc.func !is null; + return st.initialized; } package @property ref ThreadInfo tidInfo() @system { return m_fiber ? taskFiber.tidInfo : s_tidInfo; } // FIXME: this is not thread safe! @@ -91,9 +94,9 @@ struct Task { T opCast(T)() const @safe nothrow if (is(T == bool)) { return m_fiber !is null; } - void join() @trusted { if (running) taskFiber.join!true(m_taskCounter); } // FIXME: this is NOT thread safe - void joinUninterruptible() @trusted nothrow { if (running) taskFiber.join!false(m_taskCounter); } // FIXME: this is NOT thread safe - void interrupt() @trusted nothrow { if (running) taskFiber.interrupt(m_taskCounter); } // FIXME: this is NOT thread safe + void join() @trusted { if (m_fiber) m_fiber.join!true(m_taskCounter); } + void joinUninterruptible() @trusted nothrow { if (m_fiber) m_fiber.join!false(m_taskCounter); } + void interrupt() @trusted nothrow { if (m_fiber) m_fiber.interrupt(m_taskCounter); } string toString() const @safe { import std.string; return format("%s:%s", () @trusted { return cast(void*)m_fiber; } (), m_taskCounter); } @@ -295,6 +298,15 @@ final package class TaskFiber : Fiber { static if ((void*).sizeof >= 8) enum defaultTaskStackSize = 16*1024*1024; else enum defaultTaskStackSize = 512*1024; + private enum Flags { + running = 1UL << 0, + initialized = 1UL << 1, + interrupt = 1UL << 2, + + shiftAmount = 3, + flagsMask = (1< 0 ? Yes.defer : No.defer; - taskScheduler.switchTo(this.task, defer); + while (true) { + auto tcf = atomicLoad(m_taskCounterAndFlags); + auto st = getTaskStatus(tcf); + if (!st.initialized || st.interrupt || st.counter != task_counter) + return; + auto tcf_int = tcf | Flags.interrupt; + if (cas(&m_taskCounterAndFlags, tcf, tcf_int)) + break; + } + + if (caller.m_thread is m_thread) { + auto thisus = () @trusted { return cast()this; } (); + debug (VibeTaskLog) logTrace("Resuming task with interrupt flag."); + auto defer = caller.m_yieldLockCount > 0 ? Yes.defer : No.defer; + taskScheduler.switchTo(thisus.task, defer); + } else { + debug (VibeTaskLog) logTrace("Set interrupt flag on task without resuming."); + } } + /** Sets the fiber to initialized state and increments the task counter. + + Note that the task information needs to be set up first. + */ void bumpTaskCounter() @safe nothrow { - import core.atomic : atomicOp; - () @trusted { atomicOp!"+="(this.m_taskCounter, 1); } (); + debug { + auto ts = atomicLoad(m_taskCounterAndFlags); + assert((ts & Flags.flagsMask) == 0, "bumpTaskCounter() called on fiber with non-zero flags"); + assert(m_taskFunc.func !is null, "bumpTaskCounter() called without initializing the task function"); + } + + () @trusted { atomicOp!"+="(m_taskCounterAndFlags, (1 << Flags.shiftAmount) + Flags.initialized); } (); + } + + private auto getTaskStatus() + shared const @safe nothrow { + return getTaskStatus(atomicLoad(m_taskCounterAndFlags)); + } + + private auto getTaskStatusFromOwnerThread() + const @safe nothrow { + debug assert(Thread.getThis() is m_thread); + return getTaskStatus(atomicLoad(m_taskCounterAndFlags)); + } + + private static auto getTaskStatus(ulong counter_and_flags) + @safe nothrow { + static struct S { + size_t counter; + bool running; + bool initialized; + bool interrupt; + } + S ret; + ret.counter = cast(size_t)(counter_and_flags >> Flags.shiftAmount); + ret.running = (counter_and_flags & Flags.running) != 0; + ret.initialized = (counter_and_flags & Flags.initialized) != 0; + ret.interrupt = (counter_and_flags & Flags.interrupt) != 0; + return ret; } package void handleInterrupt(scope void delegate() @safe nothrow on_interrupt) @safe nothrow { - assert(() @trusted { return Task.getThis().fiber; } () is this, "Handling interrupt outside of the corresponding fiber."); - if (m_interrupt && on_interrupt) { + assert(() @trusted { return Task.getThis().fiber; } () is this, + "Handling interrupt outside of the corresponding fiber."); + if (getTaskStatusFromOwnerThread().interrupt && on_interrupt) { debug (VibeTaskLog) logTrace("Handling interrupt flag."); - m_interrupt = false; + clearInterruptFlag(); on_interrupt(); } } package void handleInterrupt() @safe { - if (m_interrupt) { - m_interrupt = false; + assert(() @trusted { return Task.getThis().fiber; } () is this, + "Handling interrupt outside of the corresponding fiber."); + if (getTaskStatusFromOwnerThread().interrupt) { + clearInterruptFlag(); throw new InterruptException; } } + + private void clearInterruptFlag() + @safe nothrow { + auto tcf = atomicLoad(m_taskCounterAndFlags); + auto st = getTaskStatus(tcf); + while (true) { + assert(st.initialized); + if (!st.interrupt) break; + auto tcf_int = tcf & ~Flags.interrupt; + if (cas(&m_taskCounterAndFlags, tcf, tcf_int)) + break; + } + } } package struct TaskFuncInfo { @@ -843,7 +927,9 @@ package struct TaskScheduler { auto t = m_taskQueue.front; m_taskQueue.popFront(); debug (VibeTaskLog) logTrace("resuming task"); - resumeTask(t.task); + auto task = t.task; + if (task != Task.init) + resumeTask(t.task); debug (VibeTaskLog) logTrace("task out"); assert(!m_taskQueue.empty, "Marker task got removed from tasks queue!?"); @@ -863,6 +949,8 @@ package struct TaskScheduler { nothrow { import std.encoding : sanitize; + assert(t != Task.init, "Resuming null task"); + debug (VibeTaskLog) logTrace("task fiber resume"); auto uncaught_exception = () @trusted nothrow { return t.fiber.call!(Fiber.Rethrow.no)(); } (); debug (VibeTaskLog) logTrace("task fiber yielded");