Improve thread-correctness of Task/TaskFiber.
- Task.join and Task.interrupt are now thread-safe - TaskFiber.task returns Task.init if no task is running (avoids bogus resumes of the TaskFiber by the scheduler) To enable thread-safe join/interrupt, the task counter is now stored together with the necessary flags within a single shared ulong that is manipulated atomically.
This commit is contained in:
parent
a202d33b3e
commit
f734b4a142
|
@ -10,6 +10,7 @@ module vibe.core.task;
|
||||||
import vibe.core.log;
|
import vibe.core.log;
|
||||||
import vibe.core.sync;
|
import vibe.core.sync;
|
||||||
|
|
||||||
|
import core.atomic : atomicOp, atomicLoad, cas;
|
||||||
import core.thread;
|
import core.thread;
|
||||||
import std.exception;
|
import std.exception;
|
||||||
import std.traits;
|
import std.traits;
|
||||||
|
@ -58,7 +59,7 @@ struct Task {
|
||||||
auto tfiber = cast(TaskFiber)fiber;
|
auto tfiber = cast(TaskFiber)fiber;
|
||||||
if (!tfiber) return Task.init;
|
if (!tfiber) return Task.init;
|
||||||
// FIXME: returning a non-.init handle for a finished task might break some layered logic
|
// FIXME: returning a non-.init handle for a finished task might break some layered logic
|
||||||
return () @trusted { return Task(tfiber, tfiber.m_taskCounter); } ();
|
return Task(tfiber, tfiber.getTaskStatusFromOwnerThread().counter);
|
||||||
}
|
}
|
||||||
|
|
||||||
nothrow {
|
nothrow {
|
||||||
|
@ -69,13 +70,15 @@ struct Task {
|
||||||
|
|
||||||
/** Determines if the task is still running or scheduled to be run.
|
/** Determines if the task is still running or scheduled to be run.
|
||||||
*/
|
*/
|
||||||
@property bool running() // FIXME: this is NOT thread safe
|
@property bool running()
|
||||||
const @trusted {
|
const @trusted {
|
||||||
assert(m_fiber !is null, "Invalid task handle");
|
assert(m_fiber !is null, "Invalid task handle");
|
||||||
try if (this.taskFiber.state == Fiber.State.TERM) return false; catch (Throwable) {}
|
auto tf = this.taskFiber;
|
||||||
if (this.taskFiber.m_taskCounter != m_taskCounter)
|
try if (tf.state == Fiber.State.TERM) return false; catch (Throwable) {}
|
||||||
|
auto st = m_fiber.getTaskStatus();
|
||||||
|
if (st.counter != m_taskCounter)
|
||||||
return false;
|
return false;
|
||||||
return this.taskFiber.m_running || this.taskFiber.m_taskFunc.func !is null;
|
return st.initialized;
|
||||||
}
|
}
|
||||||
|
|
||||||
package @property ref ThreadInfo tidInfo() @system { return m_fiber ? taskFiber.tidInfo : s_tidInfo; } // FIXME: this is not thread safe!
|
package @property ref ThreadInfo tidInfo() @system { return m_fiber ? taskFiber.tidInfo : s_tidInfo; } // FIXME: this is not thread safe!
|
||||||
|
@ -91,9 +94,9 @@ struct Task {
|
||||||
|
|
||||||
T opCast(T)() const @safe nothrow if (is(T == bool)) { return m_fiber !is null; }
|
T opCast(T)() const @safe nothrow if (is(T == bool)) { return m_fiber !is null; }
|
||||||
|
|
||||||
void join() @trusted { if (running) taskFiber.join!true(m_taskCounter); } // FIXME: this is NOT thread safe
|
void join() @trusted { if (m_fiber) m_fiber.join!true(m_taskCounter); }
|
||||||
void joinUninterruptible() @trusted nothrow { if (running) taskFiber.join!false(m_taskCounter); } // FIXME: this is NOT thread safe
|
void joinUninterruptible() @trusted nothrow { if (m_fiber) m_fiber.join!false(m_taskCounter); }
|
||||||
void interrupt() @trusted nothrow { if (running) taskFiber.interrupt(m_taskCounter); } // FIXME: this is NOT thread safe
|
void interrupt() @trusted nothrow { if (m_fiber) m_fiber.interrupt(m_taskCounter); }
|
||||||
|
|
||||||
string toString() const @safe { import std.string; return format("%s:%s", () @trusted { return cast(void*)m_fiber; } (), m_taskCounter); }
|
string toString() const @safe { import std.string; return format("%s:%s", () @trusted { return cast(void*)m_fiber; } (), m_taskCounter); }
|
||||||
|
|
||||||
|
@ -295,6 +298,15 @@ final package class TaskFiber : Fiber {
|
||||||
static if ((void*).sizeof >= 8) enum defaultTaskStackSize = 16*1024*1024;
|
static if ((void*).sizeof >= 8) enum defaultTaskStackSize = 16*1024*1024;
|
||||||
else enum defaultTaskStackSize = 512*1024;
|
else enum defaultTaskStackSize = 512*1024;
|
||||||
|
|
||||||
|
private enum Flags {
|
||||||
|
running = 1UL << 0,
|
||||||
|
initialized = 1UL << 1,
|
||||||
|
interrupt = 1UL << 2,
|
||||||
|
|
||||||
|
shiftAmount = 3,
|
||||||
|
flagsMask = (1<<shiftAmount) - 1
|
||||||
|
}
|
||||||
|
|
||||||
private {
|
private {
|
||||||
import std.concurrency : ThreadInfo;
|
import std.concurrency : ThreadInfo;
|
||||||
import std.bitmanip : BitArray;
|
import std.bitmanip : BitArray;
|
||||||
|
@ -305,8 +317,7 @@ final package class TaskFiber : Fiber {
|
||||||
|
|
||||||
Thread m_thread;
|
Thread m_thread;
|
||||||
ThreadInfo m_tidInfo;
|
ThreadInfo m_tidInfo;
|
||||||
shared size_t m_taskCounter;
|
shared ulong m_taskCounterAndFlags = 0; // bits 0-Flags.shiftAmount are flags
|
||||||
shared bool m_running;
|
|
||||||
bool m_shutdown = false;
|
bool m_shutdown = false;
|
||||||
|
|
||||||
shared(ManualEvent) m_onExit;
|
shared(ManualEvent) m_onExit;
|
||||||
|
@ -315,7 +326,6 @@ final package class TaskFiber : Fiber {
|
||||||
BitArray m_flsInit;
|
BitArray m_flsInit;
|
||||||
void[] m_fls;
|
void[] m_fls;
|
||||||
|
|
||||||
bool m_interrupt; // Task.interrupt() is progress
|
|
||||||
package int m_yieldLockCount;
|
package int m_yieldLockCount;
|
||||||
|
|
||||||
static TaskFiber ms_globalDummyFiber;
|
static TaskFiber ms_globalDummyFiber;
|
||||||
|
@ -370,12 +380,14 @@ final package class TaskFiber : Fiber {
|
||||||
if (m_shutdown) return;
|
if (m_shutdown) return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
debug assert(Thread.getThis() is m_thread, "Fiber moved between threads!?");
|
||||||
|
|
||||||
TaskFuncInfo task;
|
TaskFuncInfo task;
|
||||||
swap(task, m_taskFunc);
|
swap(task, m_taskFunc);
|
||||||
Task handle = this.task;
|
Task handle = this.task;
|
||||||
try {
|
try {
|
||||||
m_running = true;
|
atomicOp!"|="(m_taskCounterAndFlags, Flags.running); // set running
|
||||||
scope(exit) m_running = false;
|
scope(exit) atomicOp!"&="(m_taskCounterAndFlags, ~Flags.flagsMask); // clear running/initialized
|
||||||
|
|
||||||
thisTid; // force creation of a message box
|
thisTid; // force creation of a message box
|
||||||
|
|
||||||
|
@ -387,6 +399,9 @@ final package class TaskFiber : Fiber {
|
||||||
}
|
}
|
||||||
task.call();
|
task.call();
|
||||||
debug if (ms_taskEventCallback) ms_taskEventCallback(TaskEvent.end, handle);
|
debug if (ms_taskEventCallback) ms_taskEventCallback(TaskEvent.end, handle);
|
||||||
|
|
||||||
|
debug if (() @trusted { return (cast(shared)this); } ().getTaskStatus().interrupt)
|
||||||
|
logDebug("Task exited while an interrupt was in flight.");
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
debug if (ms_taskEventCallback) ms_taskEventCallback(TaskEvent.fail, handle);
|
debug if (ms_taskEventCallback) ms_taskEventCallback(TaskEvent.fail, handle);
|
||||||
import std.encoding;
|
import std.encoding;
|
||||||
|
@ -394,10 +409,7 @@ final package class TaskFiber : Fiber {
|
||||||
logDebug("Full error: %s", e.toString().sanitize());
|
logDebug("Full error: %s", e.toString().sanitize());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (m_interrupt) {
|
debug assert(Thread.getThis() is m_thread, "Fiber moved?");
|
||||||
logDebug("Task exited while an interrupt was in flight.");
|
|
||||||
m_interrupt = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
this.tidInfo.ident = Tid.init; // clear message box
|
this.tidInfo.ident = Tid.init; // clear message box
|
||||||
|
|
||||||
|
@ -418,6 +430,8 @@ final package class TaskFiber : Fiber {
|
||||||
|
|
||||||
assert(!m_queue, "Fiber done but still scheduled to be resumed!?");
|
assert(!m_queue, "Fiber done but still scheduled to be resumed!?");
|
||||||
|
|
||||||
|
debug assert(Thread.getThis() is m_thread, "Fiber moved between threads!?");
|
||||||
|
|
||||||
// make the fiber available for the next task
|
// make the fiber available for the next task
|
||||||
recycleFiber(this);
|
recycleFiber(this);
|
||||||
}
|
}
|
||||||
|
@ -443,17 +457,23 @@ final package class TaskFiber : Fiber {
|
||||||
|
|
||||||
/** Returns the handle of the current Task running on this fiber.
|
/** Returns the handle of the current Task running on this fiber.
|
||||||
*/
|
*/
|
||||||
@property Task task() @safe nothrow { return Task(this, m_taskCounter); }
|
@property Task task()
|
||||||
|
@safe nothrow {
|
||||||
|
auto ts = getTaskStatusFromOwnerThread();
|
||||||
|
if (!ts.initialized) return Task.init;
|
||||||
|
return Task(this, ts.counter);
|
||||||
|
}
|
||||||
|
|
||||||
@property ref inout(ThreadInfo) tidInfo() inout @safe nothrow { return m_tidInfo; }
|
@property ref inout(ThreadInfo) tidInfo() inout @safe nothrow { return m_tidInfo; }
|
||||||
|
|
||||||
@property size_t taskCounter() const @safe nothrow { return m_taskCounter; }
|
|
||||||
|
|
||||||
/** Shuts down the task handler loop.
|
/** Shuts down the task handler loop.
|
||||||
*/
|
*/
|
||||||
void shutdown()
|
void shutdown()
|
||||||
@safe nothrow {
|
@safe nothrow {
|
||||||
assert(!m_running);
|
debug assert(Thread.getThis() is m_thread);
|
||||||
|
|
||||||
|
assert(!() @trusted { return cast(shared)this; } ().getTaskStatus().initialized);
|
||||||
|
|
||||||
m_shutdown = true;
|
m_shutdown = true;
|
||||||
while (state != Fiber.State.TERM)
|
while (state != Fiber.State.TERM)
|
||||||
() @trusted {
|
() @trusted {
|
||||||
|
@ -465,9 +485,12 @@ final package class TaskFiber : Fiber {
|
||||||
/** Blocks until the task has ended.
|
/** Blocks until the task has ended.
|
||||||
*/
|
*/
|
||||||
void join(bool interruptiple)(size_t task_counter)
|
void join(bool interruptiple)(size_t task_counter)
|
||||||
@trusted {
|
shared @trusted {
|
||||||
auto cnt = m_onExit.emitCount;
|
auto cnt = m_onExit.emitCount;
|
||||||
while ((m_running || m_taskFunc.func !is null) && m_taskCounter == task_counter) {
|
while (true) {
|
||||||
|
auto st = getTaskStatus();
|
||||||
|
if (!st.initialized || st.counter != task_counter)
|
||||||
|
break;
|
||||||
static if (interruptiple)
|
static if (interruptiple)
|
||||||
cnt = m_onExit.wait(cnt);
|
cnt = m_onExit.wait(cnt);
|
||||||
else
|
else
|
||||||
|
@ -478,47 +501,108 @@ final package class TaskFiber : Fiber {
|
||||||
/** Throws an InterruptExeption within the task as soon as it calls an interruptible function.
|
/** Throws an InterruptExeption within the task as soon as it calls an interruptible function.
|
||||||
*/
|
*/
|
||||||
void interrupt(size_t task_counter)
|
void interrupt(size_t task_counter)
|
||||||
@safe nothrow {
|
shared @safe nothrow {
|
||||||
import vibe.core.core : taskScheduler;
|
import vibe.core.core : taskScheduler;
|
||||||
|
|
||||||
if (m_taskCounter != task_counter)
|
auto caller = () @trusted { return cast(shared)TaskFiber.getThis(); } ();
|
||||||
|
|
||||||
|
assert(caller !is this, "A task cannot interrupt itself.");
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
auto tcf = atomicLoad(m_taskCounterAndFlags);
|
||||||
|
auto st = getTaskStatus(tcf);
|
||||||
|
if (!st.initialized || st.interrupt || st.counter != task_counter)
|
||||||
return;
|
return;
|
||||||
|
auto tcf_int = tcf | Flags.interrupt;
|
||||||
auto caller = Task.getThis();
|
if (cas(&m_taskCounterAndFlags, tcf, tcf_int))
|
||||||
if (caller != Task.init) {
|
break;
|
||||||
assert(caller != this.task, "A task cannot interrupt itself.");
|
|
||||||
assert(caller.thread is this.thread, "Interrupting tasks in different threads is not yet supported.");
|
|
||||||
} else assert(() @trusted { return Thread.getThis(); } () is this.thread, "Interrupting tasks in different threads is not yet supported.");
|
|
||||||
debug (VibeTaskLog) logTrace("Resuming task with interrupt flag.");
|
|
||||||
m_interrupt = true;
|
|
||||||
|
|
||||||
auto defer = TaskFiber.getThis().m_yieldLockCount > 0 ? Yes.defer : No.defer;
|
|
||||||
taskScheduler.switchTo(this.task, defer);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (caller.m_thread is m_thread) {
|
||||||
|
auto thisus = () @trusted { return cast()this; } ();
|
||||||
|
debug (VibeTaskLog) logTrace("Resuming task with interrupt flag.");
|
||||||
|
auto defer = caller.m_yieldLockCount > 0 ? Yes.defer : No.defer;
|
||||||
|
taskScheduler.switchTo(thisus.task, defer);
|
||||||
|
} else {
|
||||||
|
debug (VibeTaskLog) logTrace("Set interrupt flag on task without resuming.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Sets the fiber to initialized state and increments the task counter.
|
||||||
|
|
||||||
|
Note that the task information needs to be set up first.
|
||||||
|
*/
|
||||||
void bumpTaskCounter()
|
void bumpTaskCounter()
|
||||||
@safe nothrow {
|
@safe nothrow {
|
||||||
import core.atomic : atomicOp;
|
debug {
|
||||||
() @trusted { atomicOp!"+="(this.m_taskCounter, 1); } ();
|
auto ts = atomicLoad(m_taskCounterAndFlags);
|
||||||
|
assert((ts & Flags.flagsMask) == 0, "bumpTaskCounter() called on fiber with non-zero flags");
|
||||||
|
assert(m_taskFunc.func !is null, "bumpTaskCounter() called without initializing the task function");
|
||||||
|
}
|
||||||
|
|
||||||
|
() @trusted { atomicOp!"+="(m_taskCounterAndFlags, (1 << Flags.shiftAmount) + Flags.initialized); } ();
|
||||||
|
}
|
||||||
|
|
||||||
|
private auto getTaskStatus()
|
||||||
|
shared const @safe nothrow {
|
||||||
|
return getTaskStatus(atomicLoad(m_taskCounterAndFlags));
|
||||||
|
}
|
||||||
|
|
||||||
|
private auto getTaskStatusFromOwnerThread()
|
||||||
|
const @safe nothrow {
|
||||||
|
debug assert(Thread.getThis() is m_thread);
|
||||||
|
return getTaskStatus(atomicLoad(m_taskCounterAndFlags));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static auto getTaskStatus(ulong counter_and_flags)
|
||||||
|
@safe nothrow {
|
||||||
|
static struct S {
|
||||||
|
size_t counter;
|
||||||
|
bool running;
|
||||||
|
bool initialized;
|
||||||
|
bool interrupt;
|
||||||
|
}
|
||||||
|
S ret;
|
||||||
|
ret.counter = cast(size_t)(counter_and_flags >> Flags.shiftAmount);
|
||||||
|
ret.running = (counter_and_flags & Flags.running) != 0;
|
||||||
|
ret.initialized = (counter_and_flags & Flags.initialized) != 0;
|
||||||
|
ret.interrupt = (counter_and_flags & Flags.interrupt) != 0;
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
package void handleInterrupt(scope void delegate() @safe nothrow on_interrupt)
|
package void handleInterrupt(scope void delegate() @safe nothrow on_interrupt)
|
||||||
@safe nothrow {
|
@safe nothrow {
|
||||||
assert(() @trusted { return Task.getThis().fiber; } () is this, "Handling interrupt outside of the corresponding fiber.");
|
assert(() @trusted { return Task.getThis().fiber; } () is this,
|
||||||
if (m_interrupt && on_interrupt) {
|
"Handling interrupt outside of the corresponding fiber.");
|
||||||
|
if (getTaskStatusFromOwnerThread().interrupt && on_interrupt) {
|
||||||
debug (VibeTaskLog) logTrace("Handling interrupt flag.");
|
debug (VibeTaskLog) logTrace("Handling interrupt flag.");
|
||||||
m_interrupt = false;
|
clearInterruptFlag();
|
||||||
on_interrupt();
|
on_interrupt();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
package void handleInterrupt()
|
package void handleInterrupt()
|
||||||
@safe {
|
@safe {
|
||||||
if (m_interrupt) {
|
assert(() @trusted { return Task.getThis().fiber; } () is this,
|
||||||
m_interrupt = false;
|
"Handling interrupt outside of the corresponding fiber.");
|
||||||
|
if (getTaskStatusFromOwnerThread().interrupt) {
|
||||||
|
clearInterruptFlag();
|
||||||
throw new InterruptException;
|
throw new InterruptException;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void clearInterruptFlag()
|
||||||
|
@safe nothrow {
|
||||||
|
auto tcf = atomicLoad(m_taskCounterAndFlags);
|
||||||
|
auto st = getTaskStatus(tcf);
|
||||||
|
while (true) {
|
||||||
|
assert(st.initialized);
|
||||||
|
if (!st.interrupt) break;
|
||||||
|
auto tcf_int = tcf & ~Flags.interrupt;
|
||||||
|
if (cas(&m_taskCounterAndFlags, tcf, tcf_int))
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
package struct TaskFuncInfo {
|
package struct TaskFuncInfo {
|
||||||
|
@ -843,6 +927,8 @@ package struct TaskScheduler {
|
||||||
auto t = m_taskQueue.front;
|
auto t = m_taskQueue.front;
|
||||||
m_taskQueue.popFront();
|
m_taskQueue.popFront();
|
||||||
debug (VibeTaskLog) logTrace("resuming task");
|
debug (VibeTaskLog) logTrace("resuming task");
|
||||||
|
auto task = t.task;
|
||||||
|
if (task != Task.init)
|
||||||
resumeTask(t.task);
|
resumeTask(t.task);
|
||||||
debug (VibeTaskLog) logTrace("task out");
|
debug (VibeTaskLog) logTrace("task out");
|
||||||
|
|
||||||
|
@ -863,6 +949,8 @@ package struct TaskScheduler {
|
||||||
nothrow {
|
nothrow {
|
||||||
import std.encoding : sanitize;
|
import std.encoding : sanitize;
|
||||||
|
|
||||||
|
assert(t != Task.init, "Resuming null task");
|
||||||
|
|
||||||
debug (VibeTaskLog) logTrace("task fiber resume");
|
debug (VibeTaskLog) logTrace("task fiber resume");
|
||||||
auto uncaught_exception = () @trusted nothrow { return t.fiber.call!(Fiber.Rethrow.no)(); } ();
|
auto uncaught_exception = () @trusted nothrow { return t.fiber.call!(Fiber.Rethrow.no)(); } ();
|
||||||
debug (VibeTaskLog) logTrace("task fiber yielded");
|
debug (VibeTaskLog) logTrace("task fiber yielded");
|
||||||
|
|
Loading…
Reference in a new issue