From f8835076f85e654663cfaa6755eb0a8f2d43e7ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B6nke=20Ludwig?= Date: Mon, 17 Aug 2015 15:49:32 +0200 Subject: [PATCH] Improve operator forwarding. - Added @disableIndex to disable opIndex forwarding for certain fields - Added field/properly forwarding - Fixed binary operator overloading - Allow implicit conversions when trying to match return types to TaggedAlgebraic during operator return type deduction --- source/taggedalgebraic.d | 302 ++++++++++++++++++++++++++++++--------- 1 file changed, 234 insertions(+), 68 deletions(-) diff --git a/source/taggedalgebraic.d b/source/taggedalgebraic.d index ffe5b6d..a5ac339 100644 --- a/source/taggedalgebraic.d +++ b/source/taggedalgebraic.d @@ -42,8 +42,9 @@ struct TaggedAlgebraic(U) if (is(U == union) || is(U == struct)) { import std.algorithm : among; import std.string : format; - import std.traits : CopyTypeQualifiers, FieldTypeTuple, FieldNameTuple, Largest, hasElaborateCopyConstructor, hasElaborateDestructor; + import std.traits : FieldTypeTuple, FieldNameTuple, Largest, hasElaborateCopyConstructor, hasElaborateDestructor; + private alias Union = U; private alias FieldTypes = FieldTypeTuple!U; private alias fieldNames = FieldNameTuple!U; @@ -52,7 +53,7 @@ struct TaggedAlgebraic(U) if (is(U == union) || is(U == struct)) private { - void[Largest!FieldTypes.sizeof] m_data; + void[Largest!FieldTypes.sizeof] m_data = void; Type m_type; } @@ -66,6 +67,12 @@ struct TaggedAlgebraic(U) if (is(U == union) || is(U == struct)) //pragma(msg, generateConstructors!U()); mixin(generateConstructors!U); + this(TaggedAlgebraic other) + { + import std.algorithm : swap; + swap(this, other); + } + void opAssign(TaggedAlgebraic other) { import std.algorithm : swap; @@ -136,6 +143,8 @@ struct TaggedAlgebraic(U) if (is(U == union) || is(U == struct)) /// Enables the invocation of methods of the stored value. auto opDispatch(string name, this TA, ARGS...)(auto ref ARGS args) if (hasOp!(TA, OpKind.method, name, ARGS)) { return implementOp!(OpKind.method, name)(this, args); } + /// Enables accessing properties/fields of the stored value. + @property auto opDispatch(string name, this TA, ARGS...)(auto ref ARGS args) if (hasOp!(TA, OpKind.field, name, ARGS) && !hasOp!(TA, OpKind.method, name, ARGS)) { return implementOp!(OpKind.field, name)(this, args); } /// Enables equality comparison with the stored value. auto opEquals(T, this TA)(auto ref T other) if (hasOp!(TA, OpKind.binary, "==", T)) { return implementOp!(OpKind.binary, "==")(this, other); } /// Enables relational comparisons with the stored value. @@ -143,7 +152,7 @@ struct TaggedAlgebraic(U) if (is(U == union) || is(U == struct)) /// Enables the use of unary operators with the stored value. auto opUnary(string op, this TA)() if (hasOp!(TA, OpKind.unary, op)) { return implementOp!(OpKind.unary, op)(this); } /// Enables the use of binary operators with the stored value. - auto opBinary(string op, T, this TA)(auto ref T other) inout if (hasOp!(TA, OpKind.binary, op, T)) { return implementOp!(OpKind.binary, op)(this, other); } + auto opBinary(string op, T, this TA)(auto ref T other) if (hasOp!(TA, OpKind.binary, op, T)) { return implementOp!(OpKind.binary, op)(this, other); } /// Enables operator assignments on the stored value. auto opOpAssign(string op, T, this TA)(auto ref T other) if (hasOp!(TA, OpKind.binary, op~"=", T)) { return implementOp!(OpKind.binary, op~"=")(this, other); } /// Enables indexing operations on the stored value. @@ -153,41 +162,6 @@ struct TaggedAlgebraic(U) if (is(U == union) || is(U == struct)) /// Enables call syntax operations on the stored value. auto opCall(this TA, ARGS...)(auto ref ARGS args) if (hasOp!(TA, OpKind.call, null, ARGS)) { return implementOp!(OpKind.call, null)(this, args); } - private template hasOp(TA, OpKind kind, string name, ARGS...) - { - alias UQ = CopyTypeQualifiers!(TA, U); - enum hasOp = .hasOp!(UQ, kind, name, ARGS); - } - - private static auto implementOp(OpKind kind, string name, T, ARGS...)(ref T self, auto ref ARGS args) - { - import std.array : join; - import std.variant : Algebraic, Variant; - alias UQ = CopyTypeQualifiers!(T, U); - - alias info = OpInfo!(UQ, kind, name, ARGS); - - switch (self.m_type) { - default: assert(false, "Operator "~name~" ("~kind.stringof~") can only be used on values of the following types: "~[info.fields].join(", ")); - foreach (i, f; info.fields) { - alias FT = FieldTypes[i]; - case __traits(getMember, Type, f): - static if (NoDuplicates!(info.ReturnTypes).length == 1) - return info.perform(self.trustedGet!f, args); - else static if (allSatisfy!(isMatchingUniqueType!U, info.ReturnTypes)) - return TaggedAlgebraic(info.perform(self.trustedGet!f, args)); - else static if (allSatisfy!(isNoVariant, info.ReturnTypes)) - return Algebraic!(NoDuplicates!(info.ReturnTypes))(info.perform(self.trustedGet!f, args)); - else static if (is(FT == Variant)) - return info.perform(self.trustedGet!f, args); - else - return Variant(info.perform(self.trustedGet!f, args)); - } - } - - assert(false); // never reached - } - private @trusted @property ref inout(typeof(__traits(getMember, U, f))) trustedGet(string f)() inout { return trustedGet!(inout(typeof(__traits(getMember, U, f)))); } private @trusted @property ref inout(T) trustedGet(T)() inout { return *cast(inout(T)*)m_data.ptr; } } @@ -457,13 +431,49 @@ unittest { // Ambiguous binary op between two TaggedAlgebraic values TA a = 1, b = 2; static assert(is(typeof(a + b) == TA)); + assert((a + b).typeID == TA.Type.i); assert(a + b == 3); } +unittest { + struct S { + union U { + @disableIndex string str; + S[] array; + S[string] object; + } + alias TA = TaggedAlgebraic!U; + TA payload; + alias payload this; + } + + S a = S(S.TA("hello")); + S b = S(S.TA(["foo": a])); + S c = S(S.TA([a])); + assert(b["foo"] == a); + assert(b["foo"] == "hello"); + assert(c[0] == a); + assert(c[0] == "hello"); +} + /// Convenience type that can be used for union fields that have no value (`void` is not allowed). struct Void {} -private enum hasOp(U, OpKind kind, string name, ARGS...) = TypeTuple!(OpInfo!(U, kind, name, ARGS).fields).length > 0; +/// User-defined attibute to disable `opIndex` forwarding for a particular tagged union member. +@property auto disableIndex() { assert(__ctfe, "disableIndex must only be used as an attribute."); return DisableOpAttribute(OpKind.index, null); } + +private struct DisableOpAttribute { + OpKind kind; + string name; +} + + +private template hasOp(TA, OpKind kind, string name, ARGS...) +{ + import std.traits : CopyTypeQualifiers; + alias UQ = CopyTypeQualifiers!(TA, TA.Union); + enum hasOp = TypeTuple!(OpInfo!(UQ, kind, name, ARGS).fields).length > 0; +} unittest { static struct S { @@ -473,21 +483,124 @@ unittest { } static union U { int i; string s; S st; } + alias TA = TaggedAlgebraic!U; - static assert(hasOp!(U, OpKind.binary, "+", int)); - static assert(hasOp!(U, OpKind.binary, "~", string)); - static assert(hasOp!(U, OpKind.binary, "==", int)); - static assert(hasOp!(U, OpKind.binary, "==", string)); - static assert(hasOp!(U, OpKind.binary, "==", int)); - static assert(hasOp!(U, OpKind.binary, "==", S)); - static assert(hasOp!(U, OpKind.method, "m", int)); - static assert(hasOp!(U, OpKind.binary, "+=", int)); - static assert(!hasOp!(U, OpKind.binary, "~", int)); - static assert(!hasOp!(U, OpKind.binary, "~", int)); - static assert(!hasOp!(U, OpKind.method, "m", string)); - static assert(!hasOp!(U, OpKind.method, "m")); - static assert(!hasOp!(const(U), OpKind.binary, "+=", int)); - static assert(!hasOp!(const(U), OpKind.method, "m", int)); + static assert(hasOp!(TA, OpKind.binary, "+", int)); + static assert(hasOp!(TA, OpKind.binary, "~", string)); + static assert(hasOp!(TA, OpKind.binary, "==", int)); + static assert(hasOp!(TA, OpKind.binary, "==", string)); + static assert(hasOp!(TA, OpKind.binary, "==", int)); + static assert(hasOp!(TA, OpKind.binary, "==", S)); + static assert(hasOp!(TA, OpKind.method, "m", int)); + static assert(hasOp!(TA, OpKind.binary, "+=", int)); + static assert(!hasOp!(TA, OpKind.binary, "~", int)); + static assert(!hasOp!(TA, OpKind.binary, "~", int)); + static assert(!hasOp!(TA, OpKind.method, "m", string)); + static assert(!hasOp!(TA, OpKind.method, "m")); + static assert(!hasOp!(const(TA), OpKind.binary, "+=", int)); + static assert(!hasOp!(const(TA), OpKind.method, "m", int)); +} + +unittest { + struct S { + union U { + string s; + S[] arr; + S[string] obj; + } + alias TA = TaggedAlgebraic!(S.U); + TA payload; + alias payload this; + } + static assert(hasOp!(S.TA, OpKind.index, null, size_t)); + static assert(hasOp!(S.TA, OpKind.index, null, int)); + static assert(hasOp!(S.TA, OpKind.index, null, string)); + static assert(hasOp!(S.TA, OpKind.field, "length")); +} + +private static auto implementOp(OpKind kind, string name, T, ARGS...)(ref T self, auto ref ARGS args) +{ + import std.array : join; + import std.traits : CopyTypeQualifiers; + import std.variant : Algebraic, Variant; + alias UQ = CopyTypeQualifiers!(T, T.Union); + + alias info = OpInfo!(UQ, kind, name, ARGS); + + static assert(hasOp!(T, kind, name, ARGS)); + + static assert(info.fields.length > 0, "Implementing operator that has no valid implementation for any supported type."); + + //pragma(msg, "Fields for "~kind.stringof~" "~name~", "~T.stringof~": "~info.fields.stringof); + //pragma(msg, "Return types for "~kind.stringof~" "~name~", "~T.stringof~": "~info.ReturnTypes.stringof); + //pragma(msg, typeof(T.Union.tupleof)); + //import std.meta : staticMap; pragma(msg, staticMap!(isMatchingUniqueType!(T.Union), info.ReturnTypes)); + + switch (self.m_type) { + default: assert(false, "Operator "~name~" ("~kind.stringof~") can only be used on values of the following types: "~[info.fields].join(", ")); + foreach (i, f; info.fields) { + alias FT = typeof(__traits(getMember, T.Union, f)); + case __traits(getMember, T.Type, f): + static if (NoDuplicates!(info.ReturnTypes).length == 1) + return info.perform(self.trustedGet!FT, args); + else static if (allSatisfy!(isMatchingUniqueType!(T.Union), info.ReturnTypes)) + return TaggedAlgebraic!(T.Union)(info.perform(self.trustedGet!FT, args)); + else static if (allSatisfy!(isNoVariant, info.ReturnTypes)) { + alias Alg = Algebraic!(NoDuplicates!(info.ReturnTypes)); + info.ReturnTypes[i] ret = info.perform(self.trustedGet!FT, args); + import std.traits : isInstanceOf; + static if (isInstanceOf!(TaggedAlgebraic, typeof(ret))) return Alg(ret.payload); + else return Alg(ret); + } + else static if (is(FT == Variant)) + return info.perform(self.trustedGet!FT, args); + else + return Variant(info.perform(self.trustedGet!FT, args)); + } + } + + assert(false); // never reached +} + +unittest { // opIndex on recursive TA with closed return value set + static struct S { + union U { + char ch; + string str; + S[] arr; + } + alias TA = TaggedAlgebraic!U; + TA payload; + alias payload this; + + this(T)(T t) { this.payload = t; } + } + S a = S("foo"); + S s = S([a]); + + assert(implementOp!(OpKind.field, "length")(s.payload) == 1); + static assert(is(typeof(implementOp!(OpKind.index, null)(s.payload, 0)) == S.TA)); + assert(implementOp!(OpKind.index, null)(s.payload, 0) == "foo"); +} + +unittest { // opIndex on recursive TA with closed return value set using @disableIndex + static struct S { + union U { + @disableIndex string str; + S[] arr; + } + alias TA = TaggedAlgebraic!U; + TA payload; + alias payload this; + + this(T)(T t) { this.payload = t; } + } + S a = S("foo"); + S s = S([a]); + + assert(implementOp!(OpKind.field, "length")(s.payload) == 1); + static assert(is(typeof(implementOp!(OpKind.index, null)(s.payload, 0)) == S)); + assert(implementOp!(OpKind.index, null)(s.payload, 0) == "foo"); } @@ -496,6 +609,7 @@ private auto performOpRaw(U, OpKind kind, string name, T, ARGS...)(ref T value, static if (kind == OpKind.binary) return mixin("value "~name~" args[0]"); else static if (kind == OpKind.unary) return mixin("name "~value); else static if (kind == OpKind.method) return __traits(getMember, value, name)(args); + else static if (kind == OpKind.field) return __traits(getMember, value, name); else static if (kind == OpKind.index) return value[args]; else static if (kind == OpKind.indexAssign) return value[args[1 .. $]] = args[0]; else static if (kind == OpKind.call) return value(args); @@ -555,15 +669,30 @@ unittest { private template OpInfo(U, OpKind kind, string name, ARGS...) { - import std.traits : FieldTypeTuple, FieldNameTuple, ReturnType; + import std.traits : CopyTypeQualifiers, FieldTypeTuple, FieldNameTuple, ReturnType; - alias FieldTypes = FieldTypeTuple!U; - alias fieldNames = FieldNameTuple!U; + private alias FieldTypes = FieldTypeTuple!U; + private alias fieldNames = FieldNameTuple!U; + + private template isOpEnabled(string field) + { + alias attribs = TypeTuple!(__traits(getAttributes, __traits(getMember, U, field))); + template impl(size_t i) { + static if (i < attribs.length) { + static if (is(typeof(attribs[i]) == DisableOpAttribute)) { + static if (kind == attribs[i].kind && name == attribs[i].name) + enum impl = false; + else enum impl = impl!(i+1); + } else enum impl = impl!(i+1); + } else enum impl = true; + } + enum isOpEnabled = impl!0; + } template fieldsImpl(size_t i) { static if (i < FieldTypes.length) { - static if (is(typeof(&performOp!(U, kind, name, FieldTypes[i], ARGS)))) { + static if (isOpEnabled!(fieldNames[i]) && is(typeof(&performOp!(U, kind, name, FieldTypes[i], ARGS)))) { alias fieldsImpl = TypeTuple!(fieldNames[i], fieldsImpl!(i+1)); } else alias fieldsImpl = fieldsImpl!(i+1); } else alias fieldsImpl = TypeTuple!(); @@ -571,11 +700,9 @@ private template OpInfo(U, OpKind kind, string name, ARGS...) alias fields = fieldsImpl!0; template ReturnTypesImpl(size_t i) { - static if (i < FieldTypes.length) { - static if (is(typeof(&performOp!(U, kind, name, FieldTypes[i], ARGS)))) { - alias T = ReturnType!(performOp!(U, kind, name, FieldTypes[i], ARGS)); - alias ReturnTypesImpl = TypeTuple!(T, ReturnTypesImpl!(i+1)); - } else alias ReturnTypesImpl = ReturnTypesImpl!(i+1); + static if (i < fields.length) { + alias FT = CopyTypeQualifiers!(U, typeof(__traits(getMember, U, fields[i]))); + alias ReturnTypesImpl = TypeTuple!(ReturnType!(performOp!(U, kind, name, FT, ARGS)), ReturnTypesImpl!(i+1)); } else alias ReturnTypesImpl = TypeTuple!(); } alias ReturnTypes = ReturnTypesImpl!0; @@ -583,10 +710,21 @@ private template OpInfo(U, OpKind kind, string name, ARGS...) static auto perform(T)(ref T value, auto ref ARGS args) { return performOp!(U, kind, name)(value, args); } } +private template ImplicitUnqual(T) { + import std.traits : Unqual, hasAliasing; + static if (is(T == void)) alias ImplicitUnqual = void; + else { + private static struct S { T t; } + static if (hasAliasing!S) alias ImplicitUnqual = T; + else alias ImplicitUnqual = Unqual!T; + } +} + private enum OpKind { binary, unary, method, + field, index, indexAssign, call @@ -727,16 +865,44 @@ private template isMatchingType(U) { } private template isMatchingUniqueType(U) { - import std.traits : FieldTypeTuple; + import std.traits : staticMap; + alias UniqueTypes = staticMap!(FieldTypeOf!U, UniqueTypeFields!U); template isMatchingUniqueType(T) { - alias Types = FieldTypeTuple!U; - enum idx = staticIndexOf!(T, Types); - static if (idx < 0) enum isMatchingUniqueType = false; - else static if (staticIndexOf!(T, Types[idx+1 .. $]) >= 0) enum isMatchingUniqueType = false; - else enum isMatchingUniqueType = true; + static if (is(T : TaggedAlgebraic!U)) enum isMatchingUniqueType = true; + else enum isMatchingUniqueType = staticIndexOfImplicit!(T, UniqueTypes) >= 0; } } +private template FieldTypeOf(U) { + template FieldTypeOf(string name) { + alias FieldTypeOf = typeof(__traits(getMember, U, name)); + } +} + +private template staticIndexOfImplicit(T, Types...) { + template impl(size_t i) { + static if (i < Types.length) { + static if (is(T : Types[i])) { + pragma(msg, "YEPP "~T.stringof~" "~Types[i].stringof); + enum impl = i; + } else { + pragma(msg, "NOPE "~T.stringof~" "~Types[i].stringof); + enum impl = impl!(i+1); + } + } else enum impl = -1; + } + enum staticIndexOfImplicit = impl!0; +} + +unittest { + static assert(staticIndexOfImplicit!(immutable(char), char) == 0); + static assert(staticIndexOfImplicit!(int, long) == 0); + static assert(staticIndexOfImplicit!(long, int) < 0); + static assert(staticIndexOfImplicit!(int, int, double) == 0); + static assert(staticIndexOfImplicit!(double, int, double) == 1); +} + + private template isNoVariant(T) { import std.variant : Variant; enum isNoVariant = !is(T == Variant);