From 95c2919d2bfd034e1df179081a0f4cee44374650 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B6nke=20Ludwig?= Date: Fri, 22 Feb 2019 00:28:15 +0100 Subject: [PATCH] Introduce TaggedUnion as the low-level basis for TaggedAlgebraic. TaggedUnion has a number of convenience features compared to TaggedAlgebraic that are possible because of the missing dynamic dispatch functionality. If the latter is not required, TaggedUnion provides a much less complex and more robust way to store a fixed set of types/kinds. --- source/taggedalgebraic.d | 494 +++++++++++++++++++++++++++------------ 1 file changed, 340 insertions(+), 154 deletions(-) diff --git a/source/taggedalgebraic.d b/source/taggedalgebraic.d index 82447ad..a981431 100644 --- a/source/taggedalgebraic.d +++ b/source/taggedalgebraic.d @@ -1,12 +1,13 @@ /** - * Algebraic data type implementation based on a tagged union. + * Generic tagged union and algebraic data type implementations. * - * Copyright: Copyright 2015-2016, Sönke Ludwig. + * Copyright: Copyright 2015-2019, Sönke Ludwig. * License: $(WEB www.boost.org/LICENSE_1_0.txt, Boost License 1.0). * Authors: Sönke Ludwig */ module taggedalgebraic; +import std.algorithm.mutation : move, swap; import std.typetuple; import std.traits : Unqual, isInstanceOf; @@ -14,6 +15,301 @@ import std.traits : Unqual, isInstanceOf; // - distinguish between @property and non@-property methods. // - verify that static methods are handled properly + +/** Implements a generic tagged union type. + + This struct takes a `union` or `struct` declaration as an input and builds + an algebraic data type from its fields, using an automatically generated + `Kind` enumeration to identify which field of the union is currently used. + Multiple fields with the same value are supported. + + For each field defined by `U` a number of convenience members are generated. + For a given field "foo", these fields are: + + $(UL + $(LI `static foo(value)`) - returns a new tagged union with the specified value) + $(LI `isFoo` - equivalent to `kind == Kind.foo`) + $(LI `setFoo(value)` - equivalent to `set!(Kind.foo)(value)`) + $(LI `getFoo` - equivalent to `get!(Kind.foo)`) + ) +*/ +struct TaggedUnion(U) if (is(U == union) || is(U == struct)) +{ + import std.traits : FieldTypeTuple, FieldNameTuple, Largest, + hasElaborateCopyConstructor, hasElaborateDestructor, isCopyable; + import std.ascii : toUpper; + + alias FieldTypes = FieldTypeTuple!U; + alias fieldNames = FieldNameTuple!U; + + static assert(FieldTypes.length > 0, "The TaggedUnions's union type must have at least one field."); + static assert(FieldTypes.length == fieldNames.length); + + /// A type enum that identifies the type of value currently stored. + alias Kind = TypeEnum!U; + + private { + static if (isUnionType!(FieldTypes[0]) || __VERSION__ < 2072) { + void[Largest!FieldTypes.sizeof] m_data; + } else { + union Dummy { + FieldTypes[0] initField; + void[Largest!FieldTypes.sizeof] data; + alias data this; + } + Dummy m_data = { initField: FieldTypes[0].init }; + } + Kind m_kind; + } + + this(TaggedUnion other) + { + rawSwap(this, other); + } + + void opAssign(TaggedUnion other) + { + rawSwap(this, other); + } + + // disable default construction if first type is not a null/Void type + static if (!isUnionType!(FieldTypes[0]) && __VERSION__ < 2072) { + @disable this(); + } + + // postblit constructor + static if (!allSatisfy!(isCopyable, FieldTypes)) { + @disable this(this); + } else static if (anySatisfy!(hasElaborateCopyConstructor, FieldTypes)) { + this(this) + { + switch (m_kind) { + default: break; + foreach (i, tname; fieldNames) { + alias T = FieldTypes[i]; + static if (hasElaborateCopyConstructor!T) + { + case __traits(getMember, Kind, tname): + typeid(T).postblit(cast(void*)&trustedGet!T()); + return; + } + } + } + } + } + + // destructor + static if (anySatisfy!(hasElaborateDestructor, FieldTypes)) { + ~this() + { + final switch (m_kind) { + foreach (i, tname; fieldNames) { + alias T = FieldTypes[i]; + case __traits(getMember, Kind, tname): + static if (hasElaborateDestructor!T) { + .destroy(trustedGet!T); + } + return; + } + } + } + } + + /// Enables conversion or extraction of the stored value. + T opCast(T)() + { + import std.conv : to; + + final switch (m_kind) { + foreach (i, FT; FieldTypes) { + case __traits(getMember, Kind, fieldNames[i]): + static if (is(typeof(trustedGet!FT) : T)) + return trustedGet!FT; + else static if (is(typeof(to!T(trustedGet!FT)))) { + return to!T(trustedGet!FT); + } else { + assert(false, "Cannot cast a " ~ fieldNames[i] + ~ " value of type " ~ FT.stringof ~ " to " ~ T.stringof); + } + } + } + assert(false); // never reached + } + /// ditto + T opCast(T)() const + { + // this method needs to be duplicated because inout doesn't work with to!() + import std.conv : to; + + final switch (m_kind) { + foreach (i, FT; FieldTypes) { + case __traits(getMember, Kind, fieldNames[i]): + static if (is(typeof(trustedGet!FT) : T)) + return trustedGet!FT; + else static if (is(typeof(to!T(trustedGet!FT)))) { + return to!T(trustedGet!FT); + } else { + assert(false, "Cannot cast a " ~ fieldNames[i] + ~ " value of type" ~ FT.stringof ~ " to " ~ T.stringof); + } + } + } + assert(false); // never reached + } + + /// Enables equality comparison with the stored value. + bool opEquals()(auto ref inout(TaggedUnion) other) + inout { + if (this.kind != other.kind) return false; + + final switch (this.kind) { + foreach (i, fname; TaggedUnion!U.fieldNames) + case __traits(getMember, Kind, fname): + return trustedGet!(FieldTypes[i]) == other.trustedGet!(FieldTypes[i]); + } + assert(false); // never reached + } + + /// The type ID of the currently stored value. + @property Kind kind() const { return m_kind; } + + static foreach (i, name; fieldNames) { + // NOTE: using getX/setX here because using just x would be prone to + // misuse (attempting to "get" a value for modification when + // a different kind is set instead of assigning a new value) + mixin("alias get"~pascalCase(name)~" = get!(Kind."~name~");"); + mixin("alias set"~pascalCase(name)~" = set!(Kind."~name~");"); + mixin("@property bool is"~pascalCase(name)~"() const { return m_kind == Kind."~name~"; }"); + + static if (!isUnionType!(FieldTypes[i])) { + mixin("static TaggedUnion "~name~"(FieldTypes["~i.stringof~"] value)" + ~ "{ TaggedUnion tu; tu.set!(Kind."~name~")(move(value)); return tu; }"); + + // TODO: define assignment operator for unique types + } else { + mixin("static @property TaggedUnion "~name~"() { TaggedUnion tu; tu.set!(Kind."~name~"); return tu; }"); + } + + } + + ref inout(FieldTypes[kind]) get(Kind kind)() + inout { + if (this.kind != kind) { + enum msg(.string k_is) = "Attempt to get kind "~kind.stringof~" from tagged union with kind "~k_is; + final switch (this.kind) { + static foreach (i, n; fieldNames) + case __traits(getMember, Kind, n): + assert(false, msg!n); + } + } + //return trustedGet!(FieldTypes[kind]); + return *() @trusted { return cast(const(FieldTypes[kind])*)m_data.ptr; } (); + } + + + ref inout(T) get(T)() inout + if (staticIndexOf!(T, FieldTypes) >= 0) + { + final switch (this.kind) { + static foreach (n; fieldNames) { + case __traits(getMember, Kind, n): + static if (is(FieldTypes[__traits(getMember, Kind, n)] == T)) + return trustedGet!T; + else assert(false, "Attempting to get type "~T.stringof + ~ " from a TaggedUnion with type " + ~ FieldTypes[__traits(getMember, Kind, n)].stringof); + } + } + } + + ref FieldTypes[kind] set(Kind kind)(FieldTypes[kind] value) + if (!isUnionType!(FieldTypes[kind])) + { + if (m_kind != kind) { + destroy(this); + m_data.rawEmplace(value); + } else { + rawSwap(trustedGet!(FieldTypes[kind]), value); + } + m_kind = kind; + + return trustedGet!(FieldTypes[kind]); + } + + void set(Kind kind)() + if (isUnionType!(FieldTypes[kind])) + { + if (m_kind != kind) { + destroy(this); + } + m_kind = kind; + } + + private @trusted @property ref inout(T) trustedGet(T)() inout { return *cast(inout(T)*)m_data.ptr; } +} + +/// +@safe nothrow unittest { + union Kinds { + int count; + string text; + } + alias TU = TaggedUnion!Kinds; + + // default initialized to the first field defined + TU tu; + assert(tu.kind == TU.Kind.count); + assert(tu.isCount); // qequivalent to the line above + assert(!tu.isText); + assert(tu.get!(TU.Kind.count) == int.init); + + // set to a specific count + tu.setCount(42); + assert(tu.isCount); + assert(tu.getCount() == 42); + assert(tu.get!(TU.Kind.count) == 42); + assert(tu.get!int == 42); // can also get by type + assert(tu.getCount() == 42); + + // assign a new tagged algebraic value + tu = TU.count(43); + + // test equivalence with other tagged unions + assert(tu == TU.count(43)); + assert(tu != TU.count(42)); + assert(tu != TU.text("hello")); + + // modify by reference + tu.getCount()++; + assert(tu.getCount() == 44); + + // set the second field + tu.setText("hello"); + assert(!tu.isCount); + assert(tu.isText); + assert(tu.kind == TU.Kind.text); + assert(tu.getText() == "hello"); +} + +unittest { // test for name clashes + union U { .string string; } + alias TU = TaggedUnion!U; + TU tu; + tu = TU.string("foo"); + assert(tu.isString); + assert(tu.getString() == "foo"); +} + +enum isUnionType(T) = is(T == Void) || is(T == void) || is(T == typeof(null)); + +private string pascalCase(string camel_case) +{ + if (!__ctfe) assert(false); + import std.ascii : toUpper; + return camel_case[0].toUpper ~ camel_case[1 .. $]; +} + + /** Implements a generic algebraic type using an enum to identify the stored type. This struct takes a `union` or `struct` declaration as an input and builds @@ -43,40 +339,20 @@ struct TaggedAlgebraic(U) if (is(U == union) || is(U == struct)) { import std.algorithm : among; import std.string : format; - import std.traits : FieldTypeTuple, FieldNameTuple, Largest, hasElaborateCopyConstructor, hasElaborateDestructor; /// Alias of the type used for defining the possible storage types/kinds. alias Union = U; - private alias FieldTypes = FieldTypeTuple!U; - private alias fieldNames = FieldNameTuple!U; - - static assert(FieldTypes.length > 0, "The TaggedAlgebraic's union type must have at least one field."); - static assert(FieldTypes.length == fieldNames.length); - - - private { - static if (is(FieldTypes[0] == typeof(null)) || is(FieldTypes[0] == Void) || __VERSION__ < 2072) { - void[Largest!FieldTypes.sizeof] m_data; - } else { - union Dummy { - FieldTypes[0] initField; - void[Largest!FieldTypes.sizeof] data; - alias data this; - } - Dummy m_data = { initField: FieldTypes[0].init }; - } - Kind m_kind; - } + private TaggedUnion!U m_union; /// A type enum that identifies the type of value currently stored. - alias Kind = TypeEnum!U; + alias Kind = TaggedUnion!U.Kind; /// Compatibility alias deprecated("Use 'Kind' instead.") alias Type = Kind; /// The type ID of the currently stored value. - @property Kind kind() const { return m_kind; } + @property Kind kind() const { return m_union.kind; } // Compatibility alias deprecated("Use 'kind' instead.") @@ -96,85 +372,10 @@ struct TaggedAlgebraic(U) if (is(U == union) || is(U == struct)) rawSwap(this, other); } - // postblit constructor - static if (anySatisfy!(hasElaborateCopyConstructor, FieldTypes)) - { - this(this) - { - switch (m_kind) { - default: break; - foreach (i, tname; fieldNames) { - alias T = typeof(__traits(getMember, U, tname)); - static if (hasElaborateCopyConstructor!T) - { - case __traits(getMember, Kind, tname): - typeid(T).postblit(cast(void*)&trustedGet!tname()); - return; - } - } - } - } - } - - // destructor - static if (anySatisfy!(hasElaborateDestructor, FieldTypes)) - { - ~this() - { - final switch (m_kind) { - foreach (i, tname; fieldNames) { - alias T = typeof(__traits(getMember, U, tname)); - case __traits(getMember, Kind, tname): - static if (hasElaborateDestructor!T) { - .destroy(trustedGet!tname); - } - return; - } - } - } - } - /// Enables conversion or extraction of the stored value. - T opCast(T)() - { - import std.conv : to; - - final switch (m_kind) { - foreach (i, FT; FieldTypes) { - case __traits(getMember, Kind, fieldNames[i]): - static if (is(typeof(trustedGet!(fieldNames[i])) : T)) - return trustedGet!(fieldNames[i]); - else static if (is(typeof(to!T(trustedGet!(fieldNames[i]))))) { - return to!T(trustedGet!(fieldNames[i])); - } else { - assert(false, "Cannot cast a " ~ fieldNames[i] - ~ " value of type " ~ FT.stringof ~ " to " ~ T.stringof); - } - } - } - assert(false); // never reached - } + T opCast(T)() { return cast(T)m_union; } /// ditto - T opCast(T)() const - { - // this method needs to be duplicated because inout doesn't work with to!() - import std.conv : to; - - final switch (m_kind) { - foreach (i, FT; FieldTypes) { - case __traits(getMember, Kind, fieldNames[i]): - static if (is(typeof(trustedGet!(fieldNames[i])) : T)) - return trustedGet!(fieldNames[i]); - else static if (is(typeof(to!T(trustedGet!(fieldNames[i]))))) { - return to!T(trustedGet!(fieldNames[i])); - } else { - assert(false, "Cannot cast a " ~ fieldNames[i] - ~ " value of type" ~ FT.stringof ~ " to " ~ T.stringof); - } - } - } - assert(false); // never reached - } + T opCast(T)() const { return cast(T)m_union; } /// Uses `cast(string)`/`to!string` to return a string representation of the enclosed value. string toString() const { return cast(string)this; } @@ -193,12 +394,7 @@ struct TaggedAlgebraic(U) if (is(U == union) || is(U == struct)) if (is(Unqual!T == TaggedAlgebraic) || hasOp!(TA, OpKind.binary, "==", T)) { static if (is(Unqual!T == TaggedAlgebraic)) { - if (this.kind != other.kind) return false; - final switch (this.kind) - foreach (i, fname; fieldNames) - case __traits(getMember, Kind, fname): - return trustedGet!fname == other.trustedGet!fname; - assert(false); // never reached + return m_union == other.m_union; } else return implementOp!(OpKind.binary, "==")(this, other); } /// Enables relational comparisons with the stored value. @@ -219,9 +415,6 @@ struct TaggedAlgebraic(U) if (is(U == union) || is(U == struct)) auto opIndexAssign(this TA, ARGS...)(auto ref ARGS args) if (hasOp!(TA, OpKind.indexAssign, null, ARGS)) { return implementOp!(OpKind.indexAssign, null)(this, args); } /// Enables call syntax operations on the stored value. auto opCall(this TA, ARGS...)(auto ref ARGS args) if (hasOp!(TA, OpKind.call, null, ARGS)) { return implementOp!(OpKind.call, null)(this, args); } - - private @trusted @property ref inout(typeof(__traits(getMember, U, f))) trustedGet(string f)() inout { return trustedGet!(inout(typeof(__traits(getMember, U, f)))); } - private @trusted @property ref inout(T) trustedGet(T)() inout { return *cast(inout(T)*)m_data.ptr; } } /// @@ -688,7 +881,7 @@ unittest { // issue #13 */ bool hasType(T, U)(in ref TaggedAlgebraic!U ta) { - alias Fields = Filter!(fieldMatchesType!(U, T), ta.fieldNames); + alias Fields = Filter!(fieldMatchesType!(U, T), ta.m_union.fieldNames); static assert(Fields.length > 0, "Type "~T.stringof~" cannot be stored in a "~(TaggedAlgebraic!U).stringof~"."); switch (ta.kind) { @@ -779,16 +972,12 @@ static if (__VERSION__ >= 2072) { */ ref inout(T) get(T, U)(ref inout(TaggedAlgebraic!U) ta) { - import std.format : format; - assert(hasType!(T, U)(ta), "Type mismatch!"); - return ta.trustedGet!T; + return ta.m_union.get!T; } /// ditto inout(T) get(T, U)(inout(TaggedAlgebraic!U) ta) { - import std.format : format; - assert(hasType!(T, U)(ta), "Type mismatch!"); - return ta.trustedGet!T; + return ta.m_union.get!T; } @nogc @safe nothrow unittest { @@ -820,9 +1009,9 @@ auto apply(alias handler, TA)(TA ta) if (isInstanceOf!(TaggedAlgebraic, TA)) { final switch (ta.kind) { - foreach (i, fn; TA.fieldNames) { + foreach (i, fn; TA.m_union.fieldNames) { case __traits(getMember, ta.Kind, fn): - return handler(get!(TA.FieldTypes[i])(ta)); + return handler(get!(TA.m_union.FieldTypes[i])(ta)); } } static if (__VERSION__ <= 2068) assert(false); @@ -956,26 +1145,26 @@ private static auto implementOp(OpKind kind, string name, T, ARGS...)(ref T self //pragma(msg, typeof(T.Union.tupleof)); //import std.meta : staticMap; pragma(msg, staticMap!(isMatchingUniqueType!(T.Union), info.ReturnTypes)); - switch (self.m_kind) { + switch (self.kind) { enum assert_msg = "Operator "~name~" ("~kind.stringof~") can only be used on values of the following types: "~[info.fields].join(", "); default: assert(false, assert_msg); foreach (i, f; info.fields) { alias FT = typeof(__traits(getMember, T.Union, f)); case __traits(getMember, T.Kind, f): static if (NoDuplicates!(info.ReturnTypes).length == 1) - return info.perform(self.trustedGet!FT, args); + return info.perform(self.m_union.trustedGet!FT, args); else static if (allSatisfy!(isMatchingUniqueType!(T.Union), info.ReturnTypes)) - return TaggedAlgebraic!(T.Union)(info.perform(self.trustedGet!FT, args)); + return TaggedAlgebraic!(T.Union)(info.perform(self.m_union.trustedGet!FT, args)); else static if (allSatisfy!(isNoVariant, info.ReturnTypes)) { alias Alg = Algebraic!(NoDuplicates!(info.ReturnTypes)); - info.ReturnTypes[i] ret = info.perform(self.trustedGet!FT, args); + info.ReturnTypes[i] ret = info.perform(self.m_union.trustedGet!FT, args); import std.traits : isInstanceOf; return Alg(ret); } else static if (is(FT == Variant)) - return info.perform(self.trustedGet!FT, args); + return info.perform(self.m_union.trustedGet!FT, args); else - return Variant(info.perform(self.trustedGet!FT, args)); + return Variant(info.perform(self.m_union.trustedGet!FT, args)); } } @@ -1168,49 +1357,46 @@ private string generateConstructors(U)() string ret; - static if (__VERSION__ < 2072) { - // disable default construction if first type is not a null/Void type - static if (!is(FieldTypeTuple!U[0] == typeof(null)) && !is(FieldTypeTuple!U[0] == Void)) - { - ret ~= q{ - @disable this(); - }; - } - } // normal type constructors foreach (tname; UniqueTypeFields!U) ret ~= q{ - this(typeof(U.%s) value) + this(typeof(U.%1$s) value) { - m_data.rawEmplace(value); - m_kind = Kind.%s; + static if (isUnionType!(typeof(U.%1$s))) + m_union.set!(Kind.%1$s)(); + else + m_union.set!(Kind.%1$s)(value); } - void opAssign(typeof(U.%s) value) + void opAssign(typeof(U.%1$s) value) { - if (m_kind != Kind.%s) { - // NOTE: destroy(this) doesn't work for some opDispatch-related reason - static if (is(typeof(&this.__xdtor))) - this.__xdtor(); - m_data.rawEmplace(value); - } else { - trustedGet!"%s" = value; - } - m_kind = Kind.%s; + static if (isUnionType!(typeof(U.%1$s))) + m_union.set!(Kind.%1$s)(); + else + m_union.set!(Kind.%1$s)(value); } - }.format(tname, tname, tname, tname, tname, tname); + }.format(tname); // type constructors with explicit type tag foreach (tname; TypeTuple!(UniqueTypeFields!U, AmbiguousTypeFields!U)) ret ~= q{ - this(typeof(U.%s) value, Kind type) + this(typeof(U.%1$s) value, Kind type) { - assert(type.among!(%s), format("Invalid type ID for type %%s: %%s", typeof(U.%s).stringof, type)); - m_data.rawEmplace(value); - m_kind = type; + switch (type) { + default: assert(false, format("Invalid type ID for type %%s: %%s", typeof(U.%1$s).stringof, type)); + foreach (i, n; TaggedUnion!U.fieldNames) { + static if (is(typeof(U.%1$s) == typeof(__traits(getMember, U, n)))) { + case __traits(getMember, Kind, n): + static if (isUnionType!(m_union.FieldTypes[i])) + m_union.set!(__traits(getMember, Kind, n))(); + else m_union.set!(__traits(getMember, Kind, n))(value); + return; + } + } + } } - }.format(tname, [SameTypeFields!(U, tname)].map!(f => "Kind."~f).join(", "), tname); + }.format(tname); return ret; }