Improve thread-correctness of Task/TaskFiber.

- Task.join and Task.interrupt are now thread-safe
- TaskFiber.task returns Task.init if no task is running (avoids bogus resumes of the TaskFiber by the scheduler)

To enable thread-safe join/interrupt, the task counter is now stored together with the necessary flags within a single shared ulong that is manipulated atomically.
This commit is contained in:
Sönke Ludwig 2019-04-13 17:12:00 +02:00
parent a202d33b3e
commit f734b4a142

View file

@ -10,6 +10,7 @@ module vibe.core.task;
import vibe.core.log;
import vibe.core.sync;
import core.atomic : atomicOp, atomicLoad, cas;
import core.thread;
import std.exception;
import std.traits;
@ -58,7 +59,7 @@ struct Task {
auto tfiber = cast(TaskFiber)fiber;
if (!tfiber) return Task.init;
// FIXME: returning a non-.init handle for a finished task might break some layered logic
return () @trusted { return Task(tfiber, tfiber.m_taskCounter); } ();
return Task(tfiber, tfiber.getTaskStatusFromOwnerThread().counter);
}
nothrow {
@ -69,13 +70,15 @@ struct Task {
/** Determines if the task is still running or scheduled to be run.
*/
@property bool running() // FIXME: this is NOT thread safe
@property bool running()
const @trusted {
assert(m_fiber !is null, "Invalid task handle");
try if (this.taskFiber.state == Fiber.State.TERM) return false; catch (Throwable) {}
if (this.taskFiber.m_taskCounter != m_taskCounter)
auto tf = this.taskFiber;
try if (tf.state == Fiber.State.TERM) return false; catch (Throwable) {}
auto st = m_fiber.getTaskStatus();
if (st.counter != m_taskCounter)
return false;
return this.taskFiber.m_running || this.taskFiber.m_taskFunc.func !is null;
return st.initialized;
}
package @property ref ThreadInfo tidInfo() @system { return m_fiber ? taskFiber.tidInfo : s_tidInfo; } // FIXME: this is not thread safe!
@ -91,9 +94,9 @@ struct Task {
T opCast(T)() const @safe nothrow if (is(T == bool)) { return m_fiber !is null; }
void join() @trusted { if (running) taskFiber.join!true(m_taskCounter); } // FIXME: this is NOT thread safe
void joinUninterruptible() @trusted nothrow { if (running) taskFiber.join!false(m_taskCounter); } // FIXME: this is NOT thread safe
void interrupt() @trusted nothrow { if (running) taskFiber.interrupt(m_taskCounter); } // FIXME: this is NOT thread safe
void join() @trusted { if (m_fiber) m_fiber.join!true(m_taskCounter); }
void joinUninterruptible() @trusted nothrow { if (m_fiber) m_fiber.join!false(m_taskCounter); }
void interrupt() @trusted nothrow { if (m_fiber) m_fiber.interrupt(m_taskCounter); }
string toString() const @safe { import std.string; return format("%s:%s", () @trusted { return cast(void*)m_fiber; } (), m_taskCounter); }
@ -295,6 +298,15 @@ final package class TaskFiber : Fiber {
static if ((void*).sizeof >= 8) enum defaultTaskStackSize = 16*1024*1024;
else enum defaultTaskStackSize = 512*1024;
private enum Flags {
running = 1UL << 0,
initialized = 1UL << 1,
interrupt = 1UL << 2,
shiftAmount = 3,
flagsMask = (1<<shiftAmount) - 1
}
private {
import std.concurrency : ThreadInfo;
import std.bitmanip : BitArray;
@ -305,8 +317,7 @@ final package class TaskFiber : Fiber {
Thread m_thread;
ThreadInfo m_tidInfo;
shared size_t m_taskCounter;
shared bool m_running;
shared ulong m_taskCounterAndFlags = 0; // bits 0-Flags.shiftAmount are flags
bool m_shutdown = false;
shared(ManualEvent) m_onExit;
@ -315,7 +326,6 @@ final package class TaskFiber : Fiber {
BitArray m_flsInit;
void[] m_fls;
bool m_interrupt; // Task.interrupt() is progress
package int m_yieldLockCount;
static TaskFiber ms_globalDummyFiber;
@ -370,12 +380,14 @@ final package class TaskFiber : Fiber {
if (m_shutdown) return;
}
debug assert(Thread.getThis() is m_thread, "Fiber moved between threads!?");
TaskFuncInfo task;
swap(task, m_taskFunc);
Task handle = this.task;
try {
m_running = true;
scope(exit) m_running = false;
atomicOp!"|="(m_taskCounterAndFlags, Flags.running); // set running
scope(exit) atomicOp!"&="(m_taskCounterAndFlags, ~Flags.flagsMask); // clear running/initialized
thisTid; // force creation of a message box
@ -387,6 +399,9 @@ final package class TaskFiber : Fiber {
}
task.call();
debug if (ms_taskEventCallback) ms_taskEventCallback(TaskEvent.end, handle);
debug if (() @trusted { return (cast(shared)this); } ().getTaskStatus().interrupt)
logDebug("Task exited while an interrupt was in flight.");
} catch (Exception e) {
debug if (ms_taskEventCallback) ms_taskEventCallback(TaskEvent.fail, handle);
import std.encoding;
@ -394,10 +409,7 @@ final package class TaskFiber : Fiber {
logDebug("Full error: %s", e.toString().sanitize());
}
if (m_interrupt) {
logDebug("Task exited while an interrupt was in flight.");
m_interrupt = false;
}
debug assert(Thread.getThis() is m_thread, "Fiber moved?");
this.tidInfo.ident = Tid.init; // clear message box
@ -418,6 +430,8 @@ final package class TaskFiber : Fiber {
assert(!m_queue, "Fiber done but still scheduled to be resumed!?");
debug assert(Thread.getThis() is m_thread, "Fiber moved between threads!?");
// make the fiber available for the next task
recycleFiber(this);
}
@ -443,17 +457,23 @@ final package class TaskFiber : Fiber {
/** Returns the handle of the current Task running on this fiber.
*/
@property Task task() @safe nothrow { return Task(this, m_taskCounter); }
@property Task task()
@safe nothrow {
auto ts = getTaskStatusFromOwnerThread();
if (!ts.initialized) return Task.init;
return Task(this, ts.counter);
}
@property ref inout(ThreadInfo) tidInfo() inout @safe nothrow { return m_tidInfo; }
@property size_t taskCounter() const @safe nothrow { return m_taskCounter; }
/** Shuts down the task handler loop.
*/
void shutdown()
@safe nothrow {
assert(!m_running);
debug assert(Thread.getThis() is m_thread);
assert(!() @trusted { return cast(shared)this; } ().getTaskStatus().initialized);
m_shutdown = true;
while (state != Fiber.State.TERM)
() @trusted {
@ -465,9 +485,12 @@ final package class TaskFiber : Fiber {
/** Blocks until the task has ended.
*/
void join(bool interruptiple)(size_t task_counter)
@trusted {
shared @trusted {
auto cnt = m_onExit.emitCount;
while ((m_running || m_taskFunc.func !is null) && m_taskCounter == task_counter) {
while (true) {
auto st = getTaskStatus();
if (!st.initialized || st.counter != task_counter)
break;
static if (interruptiple)
cnt = m_onExit.wait(cnt);
else
@ -478,47 +501,108 @@ final package class TaskFiber : Fiber {
/** Throws an InterruptExeption within the task as soon as it calls an interruptible function.
*/
void interrupt(size_t task_counter)
@safe nothrow {
shared @safe nothrow {
import vibe.core.core : taskScheduler;
if (m_taskCounter != task_counter)
return;
auto caller = () @trusted { return cast(shared)TaskFiber.getThis(); } ();
auto caller = Task.getThis();
if (caller != Task.init) {
assert(caller != this.task, "A task cannot interrupt itself.");
assert(caller.thread is this.thread, "Interrupting tasks in different threads is not yet supported.");
} else assert(() @trusted { return Thread.getThis(); } () is this.thread, "Interrupting tasks in different threads is not yet supported.");
debug (VibeTaskLog) logTrace("Resuming task with interrupt flag.");
m_interrupt = true;
assert(caller !is this, "A task cannot interrupt itself.");
auto defer = TaskFiber.getThis().m_yieldLockCount > 0 ? Yes.defer : No.defer;
taskScheduler.switchTo(this.task, defer);
while (true) {
auto tcf = atomicLoad(m_taskCounterAndFlags);
auto st = getTaskStatus(tcf);
if (!st.initialized || st.interrupt || st.counter != task_counter)
return;
auto tcf_int = tcf | Flags.interrupt;
if (cas(&m_taskCounterAndFlags, tcf, tcf_int))
break;
}
if (caller.m_thread is m_thread) {
auto thisus = () @trusted { return cast()this; } ();
debug (VibeTaskLog) logTrace("Resuming task with interrupt flag.");
auto defer = caller.m_yieldLockCount > 0 ? Yes.defer : No.defer;
taskScheduler.switchTo(thisus.task, defer);
} else {
debug (VibeTaskLog) logTrace("Set interrupt flag on task without resuming.");
}
}
/** Sets the fiber to initialized state and increments the task counter.
Note that the task information needs to be set up first.
*/
void bumpTaskCounter()
@safe nothrow {
import core.atomic : atomicOp;
() @trusted { atomicOp!"+="(this.m_taskCounter, 1); } ();
debug {
auto ts = atomicLoad(m_taskCounterAndFlags);
assert((ts & Flags.flagsMask) == 0, "bumpTaskCounter() called on fiber with non-zero flags");
assert(m_taskFunc.func !is null, "bumpTaskCounter() called without initializing the task function");
}
() @trusted { atomicOp!"+="(m_taskCounterAndFlags, (1 << Flags.shiftAmount) + Flags.initialized); } ();
}
private auto getTaskStatus()
shared const @safe nothrow {
return getTaskStatus(atomicLoad(m_taskCounterAndFlags));
}
private auto getTaskStatusFromOwnerThread()
const @safe nothrow {
debug assert(Thread.getThis() is m_thread);
return getTaskStatus(atomicLoad(m_taskCounterAndFlags));
}
private static auto getTaskStatus(ulong counter_and_flags)
@safe nothrow {
static struct S {
size_t counter;
bool running;
bool initialized;
bool interrupt;
}
S ret;
ret.counter = cast(size_t)(counter_and_flags >> Flags.shiftAmount);
ret.running = (counter_and_flags & Flags.running) != 0;
ret.initialized = (counter_and_flags & Flags.initialized) != 0;
ret.interrupt = (counter_and_flags & Flags.interrupt) != 0;
return ret;
}
package void handleInterrupt(scope void delegate() @safe nothrow on_interrupt)
@safe nothrow {
assert(() @trusted { return Task.getThis().fiber; } () is this, "Handling interrupt outside of the corresponding fiber.");
if (m_interrupt && on_interrupt) {
assert(() @trusted { return Task.getThis().fiber; } () is this,
"Handling interrupt outside of the corresponding fiber.");
if (getTaskStatusFromOwnerThread().interrupt && on_interrupt) {
debug (VibeTaskLog) logTrace("Handling interrupt flag.");
m_interrupt = false;
clearInterruptFlag();
on_interrupt();
}
}
package void handleInterrupt()
@safe {
if (m_interrupt) {
m_interrupt = false;
assert(() @trusted { return Task.getThis().fiber; } () is this,
"Handling interrupt outside of the corresponding fiber.");
if (getTaskStatusFromOwnerThread().interrupt) {
clearInterruptFlag();
throw new InterruptException;
}
}
private void clearInterruptFlag()
@safe nothrow {
auto tcf = atomicLoad(m_taskCounterAndFlags);
auto st = getTaskStatus(tcf);
while (true) {
assert(st.initialized);
if (!st.interrupt) break;
auto tcf_int = tcf & ~Flags.interrupt;
if (cas(&m_taskCounterAndFlags, tcf, tcf_int))
break;
}
}
}
package struct TaskFuncInfo {
@ -843,7 +927,9 @@ package struct TaskScheduler {
auto t = m_taskQueue.front;
m_taskQueue.popFront();
debug (VibeTaskLog) logTrace("resuming task");
resumeTask(t.task);
auto task = t.task;
if (task != Task.init)
resumeTask(t.task);
debug (VibeTaskLog) logTrace("task out");
assert(!m_taskQueue.empty, "Marker task got removed from tasks queue!?");
@ -863,6 +949,8 @@ package struct TaskScheduler {
nothrow {
import std.encoding : sanitize;
assert(t != Task.init, "Resuming null task");
debug (VibeTaskLog) logTrace("task fiber resume");
auto uncaught_exception = () @trusted nothrow { return t.fiber.call!(Fiber.Rethrow.no)(); } ();
debug (VibeTaskLog) logTrace("task fiber yielded");