Module:Set

From The Satanic Wiki
Jump to navigation Jump to search

Documentation for this module may be created at Module:Set/doc

-- <nowiki>
--------------------------------------------------------------------------------
-- Library for building and manipulating sets.
--
-- @author [[User:DarthKitty]]
-- @version 0.5.0
--
-- @TODO Consider adding a `clone` function, since `mw.clone` doesn't work.
-- @TODO Consider using `pairs` internally, to cut down on (en|de)coding.
-- @TODO Consider using `Module::Inspect` for `tostring`.
--------------------------------------------------------------------------------
local p = {}

--------------------------------------------------------------------------------
-- Stand-in keys for the two values Lua will not accept as table keys: `nil`
-- and NaN. Each one is a distinct empty function, so it is unique by identity
-- and, unlike a table, cannot have fields written to it.
--------------------------------------------------------------------------------
local placeholders = {}
placeholders["nil"] = function () end
placeholders["NaN"] = function () end

--------------------------------------------------------------------------------
-- Private per-instance state, keyed by set instance. The keys are weak
-- (`__mode = "k"`), so this table alone does not keep a set alive. It plays
-- the role of JavaScript's `WeakMap`, and avoids building a fresh closure or
-- metatable for every instance.
--
-- Reading a key that was never registered raises an error instead of
-- returning `nil`; `is` depends on that to recognise sets.
--
-- NOTE(review): on Lua 5.1 a weak-keyed table is not a true ephemeron table,
-- so a set that contains itself may never be collected. Confirm which Lua
-- version this runs on.
--
-- @see <http://fitzgeraldnick.com/2014/01/13/hiding-implementation-details-with-e6-weakmaps.html>
--------------------------------------------------------------------------------
local function rejectUnregistered()
    error("This set has not been initialized correctly. Please use `Module:Set.init` to construct new instances")
end

local internals = setmetatable({}, {
    __index = rejectUnregistered,
    __mode = "k",
})

--------------------------------------------------------------------------------
-- Metatable shared by every set instance. `__index` starts out empty; the
-- public functions in this module register themselves on it as methods, and
-- operator metamethods are attached the same way. Instances reject new fields
-- and hide their metatable (`getmetatable` on an instance returns `false`).
--------------------------------------------------------------------------------
local Set = {}
Set.__index = {}
Set.__newindex = function ()
    error("cannot change a protected table")
end
Set.__metatable = false

--------------------------------------------------------------------------------
-- Converts a value into something that can be stored as a table key. `nil`
-- and NaN are swapped for their placeholder functions; every other value is
-- returned unchanged.
--
-- @param {*} value
--     The value to convert.
-- @returns {*}
--     A value that is safe to use as a table key.
--------------------------------------------------------------------------------
local function encode(value)
    if value == nil then
        return placeholders["nil"]
    end

    -- NaN is the only value that compares unequal to itself. Tables cannot
    -- produce a false positive here, because `__eq` is never consulted when
    -- both operands are primitively equal.
    if value ~= value then
        return placeholders["NaN"]
    end

    return value
end

--------------------------------------------------------------------------------
-- Reverses `encode`: turns a placeholder key back into `nil` or NaN and passes
-- every other key through untouched.
--
-- @param {*} key
--     A key previously produced by `encode`.
-- @returns {*}
--     The original value.
--------------------------------------------------------------------------------
local function decode(key)
    if key == placeholders["NaN"] then
        -- Lua has no NaN literal; 0/0 produces one.
        return 0 / 0
    elseif key == placeholders["nil"] then
        return nil
    end

    return key
end

--------------------------------------------------------------------------------
-- Creates an empty set and registers its private state in `internals`.
--
-- @returns {table}
--     A new, empty set.
--------------------------------------------------------------------------------
local function init()
    local state = {
        elements = {},      -- encoded element -> true
        size = 0,           -- number of entries in `elements`
        prevElement = nil,  -- encoded key last returned by `nextElement`
    }

    local instance = setmetatable({}, Set)
    internals[instance] = state

    return instance
end

p.init = init

--------------------------------------------------------------------------------
-- Reports whether a value is a set created by this library.
--
-- @param {*} value
--     Any value, including `nil`.
-- @returns {boolean}
--     `true` if `value` was registered by `init`, `false` otherwise.
--------------------------------------------------------------------------------
local function is(value)
    -- Looking up an unregistered value in `internals` raises an error, so a
    -- lookup that completes without error means `value` is a set.
    local registered = pcall(function ()
        return internals[value]
    end)

    return registered
end

p.is = is

--------------------------------------------------------------------------------
-- Returns how many elements a set holds (its cardinality). The count is kept
-- up to date by `add`, `remove` and `clear`, so this does not walk the set.
--
-- @see [[wikipedia:Cardinality]]
--
-- @param {table} set
--     A set.
-- @returns {number}
--     The number of elements in the set.
--------------------------------------------------------------------------------
local function size(set)
    if is(set) then
        return internals[set].size
    end

    error("bad argument #1 to 'Module:Set.size' (set expected, got " .. type(set) .. ")", 2)
end

p.size = size
Set.__index.size = size

--------------------------------------------------------------------------------
-- Tests whether a set has a given element. `nil` and NaN are valid elements.
--
-- @param {table} set
--     The set to look in.
-- @param {*} element
--     The element to look for.
-- @returns {boolean}
--     `true` if the element is a member of the set.
--------------------------------------------------------------------------------
local function contains(set, element)
    if not is(set) then
        error("bad argument #1 to 'Module:Set.contains' (set expected, got " .. type(set) .. ")", 2)
    end

    local elements = internals[set].elements
    local key = encode(element)

    return elements[key] ~= nil
end

p.contains = contains
Set.__index.contains = contains

--------------------------------------------------------------------------------
-- Returns the element after the one this function returned last time, looping
-- back to the first element once the last one has been handed out. "First",
-- "after" and "last" follow the order in which the built-in `next` visits the
-- internal table; sets themselves are unordered.
--
-- Why not a regular iterator? Lua's generic `for` stops as soon as the
-- iterator yields `nil`, and `nil` is a legal element of these sets, so such a
-- loop could end early. Pairing this function with `Module:Set.size` gives a
-- counted loop instead, in the same spirit as storing an explicit `"n"` field
-- on sequences that contain holes:
--
--     for i = 1, set:size() do
--         local element = set:nextElement()
--         -- ...
--     end
--
-- Because the cursor wraps around, calling this `size` times in a row visits
-- every element once, wherever the cursor started, provided the set is not
-- modified in between.
--
-- @param {table} set
--     A set.
-- @returns {*}
--     An element of the set (which may itself be `nil`), or `nil` when the
--     set is empty.
--------------------------------------------------------------------------------
local function nextElement(set)
    if not is(set) then
        error("bad argument #1 to 'Module:Set.nextElement' (set expected, got " .. type(set) .. ")", 2)
    end

    local state = internals[set]
    local key = next(state.elements, state.prevElement)

    if key == nil then
        -- Ran off the end: wrap around to the first key, if there is one.
        key = next(state.elements)
    end

    -- Remember where we stopped so the next call resumes after this key.
    state.prevElement = key

    return decode(key)
end

p.nextElement = nextElement
Set.__index.nextElement = nextElement

--------------------------------------------------------------------------------
-- Adds one element to a set. Adding an element that is already present does
-- nothing. `nil` and NaN may be added like any other value.
--
-- @param {table} set
--     The set to add to.
-- @param {*} element
--     The element to add.
-- @returns {table}
--     The same set instance, to allow chaining.
--------------------------------------------------------------------------------
local function add(set, element)
    if not is(set) then
        error("bad argument #1 to 'Module:Set.add' (set expected, got " .. type(set) .. ")", 2)
    end

    if not contains(set, element) then
        local state = internals[set]

        -- `nextElement` resumes with `next(elements, prevElement)`. If that
        -- key has since been removed, inserting a new key may reuse or rehash
        -- its slot. Lua leaves `next` undefined in that situation, and in
        -- practice it can fail with "invalid key to 'next'". Restart the
        -- traversal from the beginning instead.
        if state.prevElement ~= nil and state.elements[state.prevElement] == nil then
            state.prevElement = nil
        end

        state.elements[encode(element)] = true
        state.size = state.size + 1
    end

    return set
end

p.add = add
Set.__index.add = add

--------------------------------------------------------------------------------
-- Removes one element from a set. Removing an element that is not present
-- does nothing.
--
-- @param {table} set
--     The set to remove from.
-- @param {*} element
--     The element to remove.
-- @returns {table}
--     The same set instance, to allow chaining.
--------------------------------------------------------------------------------
local function remove(set, element)
    if not is(set) then
        error("bad argument #1 to 'Module:Set.remove' (set expected, got " .. type(set) .. ")", 2)
    end

    if not contains(set, element) then
        return set
    end

    local state = internals[set]
    state.size = state.size - 1
    state.elements[encode(element)] = nil

    return set
end

p.remove = remove
Set.__index.remove = remove

--------------------------------------------------------------------------------
-- Removes all elements from a set and resets its `nextElement` cursor.
--
-- @param {table} set
--     The set to empty.
-- @returns {table}
--     The same set instance, to allow chaining.
--------------------------------------------------------------------------------
local function clear(set)
    if not is(set) then
        error("bad argument #1 to 'Module:Set.clear' (set expected, got " .. type(set) .. ")", 2)
    end

    -- Swap in a fresh table instead of removing elements one by one. This
    -- takes constant time. It also resets the cursor: a cursor left on a
    -- deleted key would let a later `add` followed by `nextElement` pass
    -- `next` a key that is no longer in the table.
    local state = internals[set]
    state.elements = {}
    state.size = 0
    state.prevElement = nil

    return set
end

p.clear = clear
Set.__index.clear = clear

--------------------------------------------------------------------------------
-- Builds a set from its arguments. Duplicates collapse into one element, and
-- `nil` arguments (including trailing ones) are kept as elements.
--
-- @param {...*} ...
--     The elements to add.
-- @returns {table}
--     A new set.
--------------------------------------------------------------------------------
local function of(...)
    -- `select("#", ...)` counts trailing nils, which `#` on `{...}` would not.
    local count = select("#", ...)
    local args = { ... }
    local set = init()

    for index = 1, count do
        add(set, args[index])
    end

    return set
end

p.of = of

--------------------------------------------------------------------------------
-- Builds a set from the keys of a table, keeping only keys whose values are
-- truthy (anything other than `false` and `nil`).
--
-- @param {table} tbl
--     The table whose key-value pairs are read.
-- @returns {table}
--     A new set.
--------------------------------------------------------------------------------
local function fromPairs(tbl)
    if type(tbl) ~= "table" then
        error("bad argument #1 to 'Module:Set.fromPairs' (table expected, got " .. type(tbl) .. ")", 2)
    end

    local set = init()

    for key, value in pairs(tbl) do
        if value then add(set, key) end
    end

    return set
end

p.fromPairs = fromPairs

--------------------------------------------------------------------------------
-- Builds a new set containing the elements that two sets have in common.
-- Also installed as `__mul`, so `a * b` is the intersection of `a` and `b`.
--
-- @see [[wikipedia:Intersection (set theory)]]
--
-- @param {table} oldSet1
--     A set. Its elements are walked with `nextElement`.
-- @param {table} oldSet2
--     A set.
-- @returns {table}
--     A new set holding every element found in both arguments.
--
-- Raises an error if either argument is not a set.
--------------------------------------------------------------------------------
local function intersection(oldSet1, oldSet2)
    if not is(oldSet1) then
        error("bad argument #1 to 'Module:Set.intersection' (set expected, got " .. type(oldSet1) .. ")", 2)
    end

    -- Validate the second argument too; otherwise a non-set would be silently
    -- accepted whenever `oldSet1` is empty.
    if not is(oldSet2) then
        error("bad argument #2 to 'Module:Set.intersection' (set expected, got " .. type(oldSet2) .. ")", 2)
    end

    local newSet = init()

    for _ = 1, size(oldSet1) do
        local element = nextElement(oldSet1)

        if contains(oldSet2, element) then
            add(newSet, element)
        end
    end

    return newSet
end

p.intersection = intersection
Set.__index.intersection = intersection
Set.__mul = intersection

--------------------------------------------------------------------------------
-- Builds a new set containing every element of either input set. Also
-- installed as `__add`, so `a + b` is the union of `a` and `b`.
--
-- @see [[wikipedia:Union (set theory)]]
--
-- @param {table} oldSet1
--     A set.
-- @param {table} oldSet2
--     A set.
-- @returns {table}
--     A new set holding every element found in at least one argument.
--------------------------------------------------------------------------------
local function union(oldSet1, oldSet2)
    if not is(oldSet1) then
        error("bad argument #1 to 'Module:Set.union' (set expected, got " .. type(oldSet1) .. ")", 2)
    end

    if not is(oldSet2) then
        error("bad argument #2 to 'Module:Set.union' (set expected, got " .. type(oldSet2) .. ")", 2)
    end

    local newSet = init()

    -- Copy the first set, then the second; `add` drops duplicates.
    for _ = 1, size(oldSet1) do
        add(newSet, nextElement(oldSet1))
    end

    for _ = 1, size(oldSet2) do
        add(newSet, nextElement(oldSet2))
    end

    return newSet
end

p.union = union
Set.__index.union = union
Set.__add = union

--------------------------------------------------------------------------------
-- Builds a new set containing the elements of the first set that are not in
-- the second (the relative complement). Also installed as `__sub`, so `a - b`
-- is the difference of `a` and `b`.
--
-- @see [[wikipedia:Complement (set theory)#Relative complement]]
--
-- @param {table} oldSet1
--     The set to take elements from. Its elements are walked with
--     `nextElement`.
-- @param {table} oldSet2
--     The set whose elements are excluded.
-- @returns {table}
--     A new set.
--
-- Raises an error if either argument is not a set.
--------------------------------------------------------------------------------
local function difference(oldSet1, oldSet2)
    if not is(oldSet1) then
        error("bad argument #1 to 'Module:Set.difference' (set expected, got " .. type(oldSet1) .. ")", 2)
    end

    -- Validate the second argument too; otherwise a non-set would be silently
    -- accepted whenever `oldSet1` is empty.
    if not is(oldSet2) then
        error("bad argument #2 to 'Module:Set.difference' (set expected, got " .. type(oldSet2) .. ")", 2)
    end

    local newSet = init()

    for _ = 1, size(oldSet1) do
        local element = nextElement(oldSet1)

        if not contains(oldSet2, element) then
            add(newSet, element)
        end
    end

    return newSet
end

p.difference = difference
Set.__index.difference = difference
Set.__sub = difference

--------------------------------------------------------------------------------
-- Builds a new set containing the elements that belong to exactly one of the
-- two input sets. Also installed as `__pow`, so `a ^ b` is the symmetric
-- difference of `a` and `b`.
--
-- @see [[wikipedia:Symmetric difference]]
--
-- @param {table} oldSet1
--     A set.
-- @param {table} oldSet2
--     A set.
-- @returns {table}
--     A new set.
--
-- Raises an error if either argument is not a set.
--------------------------------------------------------------------------------
local function symmetricDifference(oldSet1, oldSet2)
    if not is(oldSet1) then
        error("bad argument #1 to 'Module:Set.symmetricDifference' (set expected, got " .. type(oldSet1) .. ")", 2)
    end

    if not is(oldSet2) then
        error("bad argument #2 to 'Module:Set.symmetricDifference' (set expected, got " .. type(oldSet2) .. ")", 2)
    end

    local newSet = init()

    -- Toggle membership for each element seen. An element that occurs in only
    -- one input is added once and stays. An element that occurs in both is
    -- added and then removed again.
    for _, oldSet in ipairs{oldSet1, oldSet2} do
        for _ = 1, size(oldSet) do
            -- Pass the set being walked; a bare `nextElement()` always errors.
            local element = nextElement(oldSet)

            if contains(newSet, element) then
                remove(newSet, element)
            else
                add(newSet, element)
            end
        end
    end

    return newSet
end

p.symmetricDifference = symmetricDifference
Set.__index.symmetricDifference = symmetricDifference
Set.__pow = symmetricDifference

--------------------------------------------------------------------------------
-- Checks whether two sets share no elements.
--
-- @see [[wikipedia:Disjoint sets]]
--
-- @param {table} set1
--     A set. Its elements are walked with `nextElement`.
-- @param {table} set2
--     A set.
-- @returns {boolean}
--     `true` if no element of `set1` is in `set2`.
--
-- Raises an error if either argument is not a set.
--------------------------------------------------------------------------------
local function isDisjointFrom(set1, set2)
    if not is(set1) then
        error("bad argument #1 to 'Module:Set.isDisjointFrom' (set expected, got " .. type(set1) .. ")", 2)
    end

    -- Validate the second argument too; otherwise an empty `set1` would
    -- report `true` for any non-set `set2`.
    if not is(set2) then
        error("bad argument #2 to 'Module:Set.isDisjointFrom' (set expected, got " .. type(set2) .. ")", 2)
    end

    for _ = 1, size(set1) do
        if contains(set2, nextElement(set1)) then
            return false
        end
    end

    return true
end

p.isDisjointFrom = isDisjointFrom
Set.__index.isDisjointFrom = isDisjointFrom

--------------------------------------------------------------------------------
-- Checks whether every element of one set is also an element of another.
-- Also installed as `__le`, so `a <= b` tests whether `a` is a subset of `b`.
--
-- @see [[wikipedia:Subset]]
--
-- @param {table} set1
--     The candidate subset. Its elements are walked with `nextElement`.
-- @param {table} set2
--     The candidate superset.
-- @returns {boolean}
--     `true` if every element of `set1` is in `set2`.
--
-- Raises an error if either argument is not a set.
--------------------------------------------------------------------------------
local function isSubsetOf(set1, set2)
    if not is(set1) then
        error("bad argument #1 to 'Module:Set.isSubsetOf' (set expected, got " .. type(set1) .. ")", 2)
    end

    -- Validate the second argument too; otherwise an empty `set1` would
    -- report `true` for any non-set `set2`.
    if not is(set2) then
        error("bad argument #2 to 'Module:Set.isSubsetOf' (set expected, got " .. type(set2) .. ")", 2)
    end

    for _ = 1, size(set1) do
        if not contains(set2, nextElement(set1)) then
            return false
        end
    end

    return true
end

p.isSubsetOf = isSubsetOf
Set.__index.isSubsetOf = isSubsetOf
Set.__le = isSubsetOf

--------------------------------------------------------------------------------
-- Checks whether one set is a subset of another and strictly smaller than it,
-- i.e. a subset but not equal. Also installed as `__lt`, so `a < b` tests
-- whether `a` is a proper subset of `b`.
--
-- @see [[wikipedia:Subset]]
--
-- @param {table} set1
--     The candidate proper subset.
-- @param {table} set2
--     The candidate superset.
-- @returns {boolean}
--     `true` if every element of `set1` is in `set2` and `set2` has more
--     elements.
--
-- Raises an error if either argument is not a set.
--------------------------------------------------------------------------------
local function isProperSubsetOf(set1, set2)
    if not is(set1) then
        error("bad argument #1 to 'Module:Set.isProperSubsetOf' (set expected, got " .. type(set1) .. ")", 2)
    end

    -- Validate the second argument here so a bad #2 is reported against this
    -- function rather than against `Module:Set.size`.
    if not is(set2) then
        error("bad argument #2 to 'Module:Set.isProperSubsetOf' (set expected, got " .. type(set2) .. ")", 2)
    end

    return size(set1) < size(set2) and isSubsetOf(set1, set2)
end

p.isProperSubsetOf = isProperSubsetOf
Set.__index.isProperSubsetOf = isProperSubsetOf
Set.__lt = isProperSubsetOf

--------------------------------------------------------------------------------
-- Checks whether two sets contain exactly the same elements. Also installed
-- as `__eq`, so `a == b` compares two sets by content.
--
-- @param {table} set1
--     A set.
-- @param {table} set2
--     A set.
-- @returns {boolean}
--     `true` if the sets have the same size and every element of `set1` is
--     in `set2`.
--
-- Raises an error if either argument is not a set.
--------------------------------------------------------------------------------
local function equals(set1, set2)
    if not is(set1) then
        error("bad argument #1 to 'Module:Set.equals' (set expected, got " .. type(set1) .. ")", 2)
    end

    -- Validate the second argument here so a bad #2 is reported against this
    -- function rather than against `Module:Set.size`.
    if not is(set2) then
        error("bad argument #2 to 'Module:Set.equals' (set expected, got " .. type(set2) .. ")", 2)
    end

    -- Equal sizes plus one-way inclusion imply equality.
    return size(set1) == size(set2) and isSubsetOf(set1, set2)
end

p.equals = equals
Set.__index.equals = equals
Set.__eq = equals

--------------------------------------------------------------------------------
-- Renders a set as `Set { a, b, c }`, converting each element with the
-- built-in `tostring`. Element order follows `nextElement`. Installed as the
-- `__tostring` metamethod.
--
-- @param {table} set
--     A set.
-- @returns {string}
--     A human-readable description of the set.
--------------------------------------------------------------------------------
local function tostring_(set)
    local count = size(set)
    local parts = {}

    for index = 1, count do
        parts[index] = tostring(nextElement(set))
    end

    return string.format("Set { %s }", table.concat(parts, ", "))
end

Set.__tostring = tostring_


-- Calling the module table itself is shorthand for `.of`, so
-- `require("Module:Set")(a, b, c)` builds the same set as `.of(a, b, c)`.
local moduleMeta = {
    __call = function (module, ...)
        return module.of(...)
    end,
}

return setmetatable(p, moduleMeta)