Improve thread-correctness of Task/TaskFiber.
- Task.join and Task.interrupt are now thread-safe - TaskFiber.task returns Task.init if no task is running (avoids bogus resumes of the TaskFiber by the scheduler) To enable thread-safe join/interrupt, the task counter is now stored together with the necessary flags within a single shared ulong that is manipulated atomically.
This commit is contained in:
parent
a202d33b3e
commit
f734b4a142
|
@ -10,6 +10,7 @@ module vibe.core.task;
|
|||
import vibe.core.log;
|
||||
import vibe.core.sync;
|
||||
|
||||
import core.atomic : atomicOp, atomicLoad, cas;
|
||||
import core.thread;
|
||||
import std.exception;
|
||||
import std.traits;
|
||||
|
@ -58,7 +59,7 @@ struct Task {
|
|||
auto tfiber = cast(TaskFiber)fiber;
|
||||
if (!tfiber) return Task.init;
|
||||
// FIXME: returning a non-.init handle for a finished task might break some layered logic
|
||||
return () @trusted { return Task(tfiber, tfiber.m_taskCounter); } ();
|
||||
return Task(tfiber, tfiber.getTaskStatusFromOwnerThread().counter);
|
||||
}
|
||||
|
||||
nothrow {
|
||||
|
@ -69,13 +70,15 @@ struct Task {
|
|||
|
||||
/** Determines if the task is still running or scheduled to be run.
|
||||
*/
|
||||
@property bool running() // FIXME: this is NOT thread safe
|
||||
@property bool running()
|
||||
const @trusted {
|
||||
assert(m_fiber !is null, "Invalid task handle");
|
||||
try if (this.taskFiber.state == Fiber.State.TERM) return false; catch (Throwable) {}
|
||||
if (this.taskFiber.m_taskCounter != m_taskCounter)
|
||||
auto tf = this.taskFiber;
|
||||
try if (tf.state == Fiber.State.TERM) return false; catch (Throwable) {}
|
||||
auto st = m_fiber.getTaskStatus();
|
||||
if (st.counter != m_taskCounter)
|
||||
return false;
|
||||
return this.taskFiber.m_running || this.taskFiber.m_taskFunc.func !is null;
|
||||
return st.initialized;
|
||||
}
|
||||
|
||||
package @property ref ThreadInfo tidInfo() @system { return m_fiber ? taskFiber.tidInfo : s_tidInfo; } // FIXME: this is not thread safe!
|
||||
|
@ -91,9 +94,9 @@ struct Task {
|
|||
|
||||
T opCast(T)() const @safe nothrow if (is(T == bool)) { return m_fiber !is null; }
|
||||
|
||||
void join() @trusted { if (running) taskFiber.join!true(m_taskCounter); } // FIXME: this is NOT thread safe
|
||||
void joinUninterruptible() @trusted nothrow { if (running) taskFiber.join!false(m_taskCounter); } // FIXME: this is NOT thread safe
|
||||
void interrupt() @trusted nothrow { if (running) taskFiber.interrupt(m_taskCounter); } // FIXME: this is NOT thread safe
|
||||
void join() @trusted { if (m_fiber) m_fiber.join!true(m_taskCounter); }
|
||||
void joinUninterruptible() @trusted nothrow { if (m_fiber) m_fiber.join!false(m_taskCounter); }
|
||||
void interrupt() @trusted nothrow { if (m_fiber) m_fiber.interrupt(m_taskCounter); }
|
||||
|
||||
string toString() const @safe { import std.string; return format("%s:%s", () @trusted { return cast(void*)m_fiber; } (), m_taskCounter); }
|
||||
|
||||
|
@ -295,6 +298,15 @@ final package class TaskFiber : Fiber {
|
|||
static if ((void*).sizeof >= 8) enum defaultTaskStackSize = 16*1024*1024;
|
||||
else enum defaultTaskStackSize = 512*1024;
|
||||
|
||||
private enum Flags {
|
||||
running = 1UL << 0,
|
||||
initialized = 1UL << 1,
|
||||
interrupt = 1UL << 2,
|
||||
|
||||
shiftAmount = 3,
|
||||
flagsMask = (1<<shiftAmount) - 1
|
||||
}
|
||||
|
||||
private {
|
||||
import std.concurrency : ThreadInfo;
|
||||
import std.bitmanip : BitArray;
|
||||
|
@ -305,8 +317,7 @@ final package class TaskFiber : Fiber {
|
|||
|
||||
Thread m_thread;
|
||||
ThreadInfo m_tidInfo;
|
||||
shared size_t m_taskCounter;
|
||||
shared bool m_running;
|
||||
shared ulong m_taskCounterAndFlags = 0; // bits 0-Flags.shiftAmount are flags
|
||||
bool m_shutdown = false;
|
||||
|
||||
shared(ManualEvent) m_onExit;
|
||||
|
@ -315,7 +326,6 @@ final package class TaskFiber : Fiber {
|
|||
BitArray m_flsInit;
|
||||
void[] m_fls;
|
||||
|
||||
bool m_interrupt; // Task.interrupt() is progress
|
||||
package int m_yieldLockCount;
|
||||
|
||||
static TaskFiber ms_globalDummyFiber;
|
||||
|
@ -370,12 +380,14 @@ final package class TaskFiber : Fiber {
|
|||
if (m_shutdown) return;
|
||||
}
|
||||
|
||||
debug assert(Thread.getThis() is m_thread, "Fiber moved between threads!?");
|
||||
|
||||
TaskFuncInfo task;
|
||||
swap(task, m_taskFunc);
|
||||
Task handle = this.task;
|
||||
try {
|
||||
m_running = true;
|
||||
scope(exit) m_running = false;
|
||||
atomicOp!"|="(m_taskCounterAndFlags, Flags.running); // set running
|
||||
scope(exit) atomicOp!"&="(m_taskCounterAndFlags, ~Flags.flagsMask); // clear running/initialized
|
||||
|
||||
thisTid; // force creation of a message box
|
||||
|
||||
|
@ -387,6 +399,9 @@ final package class TaskFiber : Fiber {
|
|||
}
|
||||
task.call();
|
||||
debug if (ms_taskEventCallback) ms_taskEventCallback(TaskEvent.end, handle);
|
||||
|
||||
debug if (() @trusted { return (cast(shared)this); } ().getTaskStatus().interrupt)
|
||||
logDebug("Task exited while an interrupt was in flight.");
|
||||
} catch (Exception e) {
|
||||
debug if (ms_taskEventCallback) ms_taskEventCallback(TaskEvent.fail, handle);
|
||||
import std.encoding;
|
||||
|
@ -394,10 +409,7 @@ final package class TaskFiber : Fiber {
|
|||
logDebug("Full error: %s", e.toString().sanitize());
|
||||
}
|
||||
|
||||
if (m_interrupt) {
|
||||
logDebug("Task exited while an interrupt was in flight.");
|
||||
m_interrupt = false;
|
||||
}
|
||||
debug assert(Thread.getThis() is m_thread, "Fiber moved?");
|
||||
|
||||
this.tidInfo.ident = Tid.init; // clear message box
|
||||
|
||||
|
@ -418,6 +430,8 @@ final package class TaskFiber : Fiber {
|
|||
|
||||
assert(!m_queue, "Fiber done but still scheduled to be resumed!?");
|
||||
|
||||
debug assert(Thread.getThis() is m_thread, "Fiber moved between threads!?");
|
||||
|
||||
// make the fiber available for the next task
|
||||
recycleFiber(this);
|
||||
}
|
||||
|
@ -443,17 +457,23 @@ final package class TaskFiber : Fiber {
|
|||
|
||||
/** Returns the handle of the current Task running on this fiber.
|
||||
*/
|
||||
@property Task task() @safe nothrow { return Task(this, m_taskCounter); }
|
||||
@property Task task()
|
||||
@safe nothrow {
|
||||
auto ts = getTaskStatusFromOwnerThread();
|
||||
if (!ts.initialized) return Task.init;
|
||||
return Task(this, ts.counter);
|
||||
}
|
||||
|
||||
@property ref inout(ThreadInfo) tidInfo() inout @safe nothrow { return m_tidInfo; }
|
||||
|
||||
@property size_t taskCounter() const @safe nothrow { return m_taskCounter; }
|
||||
|
||||
/** Shuts down the task handler loop.
|
||||
*/
|
||||
void shutdown()
|
||||
@safe nothrow {
|
||||
assert(!m_running);
|
||||
debug assert(Thread.getThis() is m_thread);
|
||||
|
||||
assert(!() @trusted { return cast(shared)this; } ().getTaskStatus().initialized);
|
||||
|
||||
m_shutdown = true;
|
||||
while (state != Fiber.State.TERM)
|
||||
() @trusted {
|
||||
|
@ -465,9 +485,12 @@ final package class TaskFiber : Fiber {
|
|||
/** Blocks until the task has ended.
|
||||
*/
|
||||
void join(bool interruptiple)(size_t task_counter)
|
||||
@trusted {
|
||||
shared @trusted {
|
||||
auto cnt = m_onExit.emitCount;
|
||||
while ((m_running || m_taskFunc.func !is null) && m_taskCounter == task_counter) {
|
||||
while (true) {
|
||||
auto st = getTaskStatus();
|
||||
if (!st.initialized || st.counter != task_counter)
|
||||
break;
|
||||
static if (interruptiple)
|
||||
cnt = m_onExit.wait(cnt);
|
||||
else
|
||||
|
@ -478,47 +501,108 @@ final package class TaskFiber : Fiber {
|
|||
/** Throws an InterruptExeption within the task as soon as it calls an interruptible function.
|
||||
*/
|
||||
void interrupt(size_t task_counter)
|
||||
@safe nothrow {
|
||||
shared @safe nothrow {
|
||||
import vibe.core.core : taskScheduler;
|
||||
|
||||
if (m_taskCounter != task_counter)
|
||||
return;
|
||||
auto caller = () @trusted { return cast(shared)TaskFiber.getThis(); } ();
|
||||
|
||||
auto caller = Task.getThis();
|
||||
if (caller != Task.init) {
|
||||
assert(caller != this.task, "A task cannot interrupt itself.");
|
||||
assert(caller.thread is this.thread, "Interrupting tasks in different threads is not yet supported.");
|
||||
} else assert(() @trusted { return Thread.getThis(); } () is this.thread, "Interrupting tasks in different threads is not yet supported.");
|
||||
debug (VibeTaskLog) logTrace("Resuming task with interrupt flag.");
|
||||
m_interrupt = true;
|
||||
assert(caller !is this, "A task cannot interrupt itself.");
|
||||
|
||||
auto defer = TaskFiber.getThis().m_yieldLockCount > 0 ? Yes.defer : No.defer;
|
||||
taskScheduler.switchTo(this.task, defer);
|
||||
while (true) {
|
||||
auto tcf = atomicLoad(m_taskCounterAndFlags);
|
||||
auto st = getTaskStatus(tcf);
|
||||
if (!st.initialized || st.interrupt || st.counter != task_counter)
|
||||
return;
|
||||
auto tcf_int = tcf | Flags.interrupt;
|
||||
if (cas(&m_taskCounterAndFlags, tcf, tcf_int))
|
||||
break;
|
||||
}
|
||||
|
||||
if (caller.m_thread is m_thread) {
|
||||
auto thisus = () @trusted { return cast()this; } ();
|
||||
debug (VibeTaskLog) logTrace("Resuming task with interrupt flag.");
|
||||
auto defer = caller.m_yieldLockCount > 0 ? Yes.defer : No.defer;
|
||||
taskScheduler.switchTo(thisus.task, defer);
|
||||
} else {
|
||||
debug (VibeTaskLog) logTrace("Set interrupt flag on task without resuming.");
|
||||
}
|
||||
}
|
||||
|
||||
/** Sets the fiber to initialized state and increments the task counter.
|
||||
|
||||
Note that the task information needs to be set up first.
|
||||
*/
|
||||
void bumpTaskCounter()
|
||||
@safe nothrow {
|
||||
import core.atomic : atomicOp;
|
||||
() @trusted { atomicOp!"+="(this.m_taskCounter, 1); } ();
|
||||
debug {
|
||||
auto ts = atomicLoad(m_taskCounterAndFlags);
|
||||
assert((ts & Flags.flagsMask) == 0, "bumpTaskCounter() called on fiber with non-zero flags");
|
||||
assert(m_taskFunc.func !is null, "bumpTaskCounter() called without initializing the task function");
|
||||
}
|
||||
|
||||
() @trusted { atomicOp!"+="(m_taskCounterAndFlags, (1 << Flags.shiftAmount) + Flags.initialized); } ();
|
||||
}
|
||||
|
||||
private auto getTaskStatus()
|
||||
shared const @safe nothrow {
|
||||
return getTaskStatus(atomicLoad(m_taskCounterAndFlags));
|
||||
}
|
||||
|
||||
private auto getTaskStatusFromOwnerThread()
|
||||
const @safe nothrow {
|
||||
debug assert(Thread.getThis() is m_thread);
|
||||
return getTaskStatus(atomicLoad(m_taskCounterAndFlags));
|
||||
}
|
||||
|
||||
private static auto getTaskStatus(ulong counter_and_flags)
|
||||
@safe nothrow {
|
||||
static struct S {
|
||||
size_t counter;
|
||||
bool running;
|
||||
bool initialized;
|
||||
bool interrupt;
|
||||
}
|
||||
S ret;
|
||||
ret.counter = cast(size_t)(counter_and_flags >> Flags.shiftAmount);
|
||||
ret.running = (counter_and_flags & Flags.running) != 0;
|
||||
ret.initialized = (counter_and_flags & Flags.initialized) != 0;
|
||||
ret.interrupt = (counter_and_flags & Flags.interrupt) != 0;
|
||||
return ret;
|
||||
}
|
||||
|
||||
package void handleInterrupt(scope void delegate() @safe nothrow on_interrupt)
|
||||
@safe nothrow {
|
||||
assert(() @trusted { return Task.getThis().fiber; } () is this, "Handling interrupt outside of the corresponding fiber.");
|
||||
if (m_interrupt && on_interrupt) {
|
||||
assert(() @trusted { return Task.getThis().fiber; } () is this,
|
||||
"Handling interrupt outside of the corresponding fiber.");
|
||||
if (getTaskStatusFromOwnerThread().interrupt && on_interrupt) {
|
||||
debug (VibeTaskLog) logTrace("Handling interrupt flag.");
|
||||
m_interrupt = false;
|
||||
clearInterruptFlag();
|
||||
on_interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
package void handleInterrupt()
|
||||
@safe {
|
||||
if (m_interrupt) {
|
||||
m_interrupt = false;
|
||||
assert(() @trusted { return Task.getThis().fiber; } () is this,
|
||||
"Handling interrupt outside of the corresponding fiber.");
|
||||
if (getTaskStatusFromOwnerThread().interrupt) {
|
||||
clearInterruptFlag();
|
||||
throw new InterruptException;
|
||||
}
|
||||
}
|
||||
|
||||
private void clearInterruptFlag()
|
||||
@safe nothrow {
|
||||
auto tcf = atomicLoad(m_taskCounterAndFlags);
|
||||
auto st = getTaskStatus(tcf);
|
||||
while (true) {
|
||||
assert(st.initialized);
|
||||
if (!st.interrupt) break;
|
||||
auto tcf_int = tcf & ~Flags.interrupt;
|
||||
if (cas(&m_taskCounterAndFlags, tcf, tcf_int))
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
package struct TaskFuncInfo {
|
||||
|
@ -843,7 +927,9 @@ package struct TaskScheduler {
|
|||
auto t = m_taskQueue.front;
|
||||
m_taskQueue.popFront();
|
||||
debug (VibeTaskLog) logTrace("resuming task");
|
||||
resumeTask(t.task);
|
||||
auto task = t.task;
|
||||
if (task != Task.init)
|
||||
resumeTask(t.task);
|
||||
debug (VibeTaskLog) logTrace("task out");
|
||||
|
||||
assert(!m_taskQueue.empty, "Marker task got removed from tasks queue!?");
|
||||
|
@ -863,6 +949,8 @@ package struct TaskScheduler {
|
|||
nothrow {
|
||||
import std.encoding : sanitize;
|
||||
|
||||
assert(t != Task.init, "Resuming null task");
|
||||
|
||||
debug (VibeTaskLog) logTrace("task fiber resume");
|
||||
auto uncaught_exception = () @trusted nothrow { return t.fiber.call!(Fiber.Rethrow.no)(); } ();
|
||||
debug (VibeTaskLog) logTrace("task fiber yielded");
|
||||
|
|
Loading…
Reference in a new issue