performance tuning – Efficient way to partition a list into equivalence classes given an equivalence class function (not an equivalence relation)?

Suppose I have a list of elements and a function f which maps each element of list to an equivalence class of elements (which should be a sublist of list). What’s an efficient way to partition list into these equivalence classes?

(Note that this is not the same question as this one, as that assumes we start with an equivalence relation equiv mapping pairs of elements of list to True or False.)

A naive way would be to simply use GatherBy[list, f], but if the equivalence classes are big or f is hard to compute, this wastes a lot of computation: f gets evaluated on every element. We should be able to check whether each successive element of list is in an already-produced equivalence class, and skip it if so.

Here’s one attempt that naively implements that idea (and performs no checks to make sure f behaves the way you say it does):

EquivalenceClasses0[list_List, f_] := Module[{classes = {}},
  Scan[x |-> If[! MemberQ[classes, x, {2}], AppendTo[classes, f[x]]], list];
  classes]

Here’s an approach that does it more functionally:

EquivalenceClasses1[list_List, f_] :=
  Fold[{classes, x} |-> If[! MemberQ[classes, x, {2}], Append[classes, f[x]], classes], {}, list]

These seem to perform very similarly. But, obviously, I don’t like the use of Append and AppendTo. However, Sow and Reap can’t be used on the top level, because each computation depends on the previous ones, so we’d have to continuously Reap (which would defeat the purpose, I think).

I tried a few alternative approaches: using a Flat head instead of Append, customizing the MemberQ function and using nested heads of the form w[x_w, y_], then collapsing them at the end, as well as naively “allocating memory” for the list via a constant array of the same length as list, keeping track of how big each class is, and then using TakeList to split it up correctly. But I feel as though I’m missing a “nice” or “canonical” way to do this. Any ideas? (I’ll provide an example problem to benchmark things and my attempts so far below, if you’re interested!)


Examples, benchmarking functions & attempts so far

So, here’s an example of the sort of f I’m talking about, which gets the set of all row and column permutations of a given matrix (i.e. the orbit under the standard action of $S_n \times S_k$, given bases):

MatrixPermutations[m_?MatrixQ] := 
 Sort @ DeleteDuplicates @ Flatten[Map[Transpose]@*Permutations@*Transpose /@ Permutations[m], 1]

(* Example: *)
MatrixForm /@ MatrixPermutations[{{1,0},{1,1}}]

(* Row and column permutations of 2x2 0-1 matrices: *)
Map[MatrixForm] /@ EquivalenceClasses1[Tuples[{0, 1}, {2, 2}], MatrixPermutations]
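For a concrete picture of what the partition looks like, here is a small sketch. It repeats the two definitions so it can be pasted on its own; orbit and classes are just illustrative names, and the outputs in the comments are worked out by hand for the 2x2 case.

```wolfram
(* Definitions repeated from the post so the block is self-contained. *)
MatrixPermutations[m_?MatrixQ] :=
  Sort @ DeleteDuplicates @ Flatten[Map[Transpose]@*Permutations@*Transpose /@ Permutations[m], 1]

EquivalenceClasses1[list_List, f_] :=
  Fold[{classes, x} |-> If[! MemberQ[classes, x, {2}], Append[classes, f[x]], classes], {}, list]

(* Orbit of one matrix: 2 row orders times 2 column orders, all distinct here. *)
orbit = MatrixPermutations[{{1, 0}, {1, 1}}]
(* {{{0,1},{1,1}}, {{1,0},{1,1}}, {{1,1},{0,1}}, {{1,1},{1,0}}} *)

(* Partition all 16 binary 2x2 matrices and look at the class sizes. *)
classes = EquivalenceClasses1[Tuples[{0, 1}, {2, 2}], MatrixPermutations];
Length /@ classes
(* {1, 4, 2, 2, 2, 4, 1} *)

(* Sanity check: the classes together contain each input matrix exactly once. *)
Sort[Join @@ classes] === Sort[Tuples[{0, 1}, {2, 2}]]
(* True *)
```

So the 16 matrices fall into 7 classes, which means f only needs to be evaluated 7 times if already-covered elements are skipped.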

Here are all the EquivalenceClass functions I came up with in a block, including those above.

(* Procedural, AppendTo: *)

EquivalenceClasses0[list_List, f_] := Module[{classes = {}},
  Scan[x |-> If[! MemberQ[classes, x, {2}], AppendTo[classes, f[x]]], list];
  classes]

(* Functional, Append: *)

EquivalenceClasses1[list_List, f_] :=
  Fold[{classes, x} |-> If[! MemberQ[classes, x, {2}], Append[classes, f[x]], classes], {}, list]

(* Functional, nesting wrapper with new MemberQ: *)

wMemberQ[w[], elem_] = False;

wMemberQ[w[x_, y_], elem_] := MemberQ[y, elem, 1] || wMemberQ[x, elem]

EquivalenceClasses2[list_List, f_] := 
 Block[{w}, (w = Sequence; #) & @ {Fold[{wclasses, x} |-> 
      If[! wMemberQ[wclasses, x], w[wclasses, f[x]], wclasses], w[], list]}]

(* Functional, Flat attribute for appending: *)

SetAttributes[ww, Flat]

EquivalenceClasses3[list_List, f_] := 
 List @@ 
  Fold[{wclasses, x} |-> If[! MemberQ[wclasses, x, {2}], ww[wclasses, f[x]], wclasses], 
       ww[], list]

(* Procedural, line up results in pre-allocated list with Set and Span, then TakeList: *)

EquivalenceClasses4[list_List, f_] := 
 Module[{v, s, n = 1, n1 = 0, i = 1, nlist = ConstantArray[0, Length[list]]},
  s = ConstantArray[v, Length[list]]; 
  Scan[x |-> 
    If[!MemberQ[s, x, 1],
       (class |-> (s[[#1 ;; #2]] = class) &[n, n + (n1 = Length[class]) - 1])@f[x];
         n = n + n1; nlist[[i]] = n1; i++;],
    list];
  TakeList[s, Take[nlist, i - 1]]]

(* Functional, line up results in pre-allocated list with SubsetMap and Range, then TakeList: *)

EquivalenceClasses5[list_List, f_] := 
 Module[{n = 1, n1 = 0, i = 1, nlist = ConstantArray[0, Length[list]]}, 
  TakeList[#, Take[nlist, i - 1]] & @
   Fold[{s, x} |-> 
     If[!MemberQ[s, x, 1],
        (n = n + n1; nlist[[i]] = n1; i++; #) &[
          (class |-> SubsetMap[class &, s, Range[n, n + (n1 = Length[class]) - 1]])@f[x]], s], 
    ConstantArray[0, Length[list]],
    list]]

(* Procedural, use individual definitions to hold values by Set and MapIndexed, then TakeList: *)

EquivalenceClasses6[list_List, f_] := 
 Module[{v, s, nminus = 0, n1 = 0, i = 1, nlist = ConstantArray[0, Length[list]]},
  s = Array[v, Length[list]]; 
  Scan[x |->
    If[!MemberQ[s, x, 1],
       MapIndexed[
         Construct[Set, v[# + nminus] & @@ #2, #1] &, (n1 = Length[#]; #) &@f[x],
         {1}]; nminus = nminus + n1; nlist[[i]] = n1; i++;], 
    list];
  TakeList[s, Take[nlist, i - 1]]]

(* Procedural, use individual definitions to hold values by broadcasted Set, then TakeList: *)

EquivalenceClasses7[list_List, f_] := 
 Module[{v, s, n = 1, n1 = 0, i = 1, nlist = ConstantArray[0, Length[list]]},
  s = Array[v, Length[list]]; 
  Scan[x |-> 
    If[!MemberQ[s, x, 1], 
       Construct[Set, Array[v, n1 = Length[#], n], #] &@f[x];
       n = n + n1; nlist[[i]] = n1; i++;],
    list];
  TakeList[s, Take[nlist, i - 1]]]

Check they agree on some input:

heads = {EquivalenceClasses0, EquivalenceClasses1, 
   EquivalenceClasses2, EquivalenceClasses3, EquivalenceClasses4, 
   EquivalenceClasses5, EquivalenceClasses6, EquivalenceClasses7};

t = Tuples[{0, 1}, {3, 3}];

SameQ @@ Through[heads[t, MatrixPermutations]]

(* Output should be True *)

I timed these and visualized the times with the following code, plotting input size along the x axis and time on the y axis (both on a log scale).

(* Function that enumerates pairs of integers up to symmetry
   (for use in the second argument of Tuples): *)

symmetricGrid2DStep[1] = {1, 1};

symmetricGrid2DStep[n_Integer] := 
 symmetricGrid2DStep[n] =
   (If[#1 + 1 > #2 - 1, {1, #1 + #2}, {#1 + 1, #2 - 1}] & @@ symmetricGrid2DStep[n - 1])
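To make the enumeration order concrete, this sketch (repeating the definition so it runs on its own; shapes and sizes are illustrative names) lists the first few steps, and then the matrix shapes that pass the default filters of the timing function below, i.e. no dimension equal to 1 and at most 20000 matrices of that shape. The outputs in the comments are worked out by hand.

```wolfram
(* Definition repeated from the post so the block is self-contained. *)
symmetricGrid2DStep[1] = {1, 1};
symmetricGrid2DStep[n_Integer] :=
  symmetricGrid2DStep[n] =
    (If[#1 + 1 > #2 - 1, {1, #1 + #2}, {#1 + 1, #2 - 1}] & @@ symmetricGrid2DStep[n - 1])

(* Pairs {a, b} with a <= b, walked diagonal by diagonal (constant a + b). *)
Table[symmetricGrid2DStep[n], {n, 1, 9}]
(* {{1,1},{1,2},{1,3},{2,2},{1,4},{2,3},{1,5},{2,4},{3,3}} *)

(* Shapes among steps 4..20 with no dimension 1 and at most 20000 binary matrices. *)
shapes = Select[Table[symmetricGrid2DStep[n], {n, 4, 20}],
   FreeQ[#, 1] && 2^(Times @@ #) <= 20000 &]
(* {{2,2},{2,3},{2,4},{3,3},{2,5},{3,4},{2,6},{2,7}} *)

(* Number of binary matrices of each shape, i.e. the input lengths on the x axis. *)
sizes = 2^(Times @@ #) & /@ shapes
(* {16, 64, 256, 512, 1024, 4096, 4096, 16384} *)
```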

(* Function which produces lists of pairs {<length of input>, <time spent>} for each head: *)

Get01MatrixTimings[heads : {__}, sizecutoff_Integer : 20000, stepcutoff_Integer : 20, retime : (True | False) : False] :=
  (If[retime, 
   Quiet[Get01MatrixTimings[heads, sizecutoff, stepcutoff, False] =., Unset::norep]];
   
  Get01MatrixTimings[heads, sizecutoff, stepcutoff, False] =

   Module[{t, step, n = 4, counting = False},
     First /@ Last @ Reap[
       While[n <= stepcutoff,
        step = symmetricGrid2DStep[n]; 
        If[! MemberQ[step, 1] && 2^(Times @@ step) <= sizecutoff,
         counting = False; t = Tuples[{0, 1}, symmetricGrid2DStep[n]];
         Do[
          With[{head = heads[[p]]}, 
           Sow[{2^(Times @@ step), First@RepeatedTiming[head[t, MatrixPermutations];]}, 
             head];],
          {p, Length[heads]}],
         If[First@step == 1, If[counting, Break[], counting = True]]];
         n++;];
       If[n == stepcutoff, Print["stepcutoff met"]];,
       heads]
    ])

LogLogPlot01MatrixTimings[heads : {__}, sizecutoff_Integer : 20000, stepcutoff_Integer : 20, retime : (True | False) : False, opts___Rule] := 
 ListLogLogPlot[
  Get01MatrixTimings[heads, sizecutoff, stepcutoff, retime], 
  PlotLegends -> heads, Joined -> True, PlotMarkers -> Automatic, 
  opts]

Which produced the following plot (don’t run this—it’ll take a while!):

LogLogPlot01MatrixTimings[heads, True, PlotRange -> All, ImageSize -> Large]

[Figure: log-log plot of the timings. Functions 0, 1, 2, and 3 have middle-of-the-road performance, functions 6 and 7 are slower, and functions 4 and 5 are fastest, with 4 winning by a hair.]

So it seems that of my attempts, EquivalenceClasses4 is the winner by a hair! That was

EquivalenceClasses4[list_List, f_] := 
 Module[{v, s, n = 1, n1 = 0, i = 1, 
   nlist = ConstantArray[0, Length[list]]},
  s = ConstantArray[v, Length[list]]; 
  Scan[x |-> 
    If[!MemberQ[s, x, 1],
       (class |-> (s[[#1 ;; #2]] = class) &[n, n + (n1 = Length[class]) - 1])@f[x];
         n = n + n1; nlist[[i]] = n1; i++;],
    list];
  TakeList[s, Take[nlist, i - 1]]]

It’s interesting that the procedural one beats the functional version ever so slightly. Note that it does manage to beat GatherBy by a little, despite not being nice and C-optimized:

SameQ @@ Through[{EquivalenceClasses4, GatherBy}[t, MatrixPermutations]]

(* Output: True *)

LogLogPlot01MatrixTimings[{EquivalenceClasses4, GatherBy}, PlotRange -> All, ImageSize -> Large]

[Figure: the same kind of log-log plot, showing only GatherBy and EquivalenceClasses4; EquivalenceClasses4 is slightly faster.]
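Part of the gap is presumably down to how often f gets evaluated. As a rough sketch, one can wrap f in a counter and compare. The definitions are repeated so the block is self-contained; countedF, calls and small are hypothetical names, not part of the code above, and the simpler EquivalenceClasses1 stands in for the skipping strategy.

```wolfram
(* Definitions repeated from the post so the block is self-contained. *)
MatrixPermutations[m_?MatrixQ] :=
  Sort @ DeleteDuplicates @ Flatten[Map[Transpose]@*Permutations@*Transpose /@ Permutations[m], 1]

EquivalenceClasses1[list_List, f_] :=
  Fold[{classes, x} |-> If[! MemberQ[classes, x, {2}], Append[classes, f[x]], classes], {}, list]

(* Hypothetical wrapper: bumps a global counter every time f is evaluated. *)
calls = 0;
countedF[m_] := (calls++; MatrixPermutations[m])

small = Tuples[{0, 1}, {2, 2}];  (* all 16 binary 2x2 matrices *)

(* Skipping strategy: f is only evaluated on elements not yet covered. *)
calls = 0; EquivalenceClasses1[small, countedF];
callsSkipping = calls  (* 7: one evaluation per equivalence class *)

(* GatherBy needs the value of f for every element in order to group them. *)
calls = 0; GatherBy[small, countedF];
callsGatherBy = calls
```

For the 2x2 case the skipping version needs 7 evaluations, one per class, whereas GatherBy has to evaluate f across all 16 elements. How much this helps in practice will depend on the ratio of classes to elements and on the cost of the MemberQ scans.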

(These performance graphs could be improved; I’m sure there’s some dependence on how big the equivalence classes are (and thus how many of them there are).)

Still, EquivalenceClasses4 looks pretty ugly and ad-hoc. Can it be improved? And am I missing some super-obvious implementation? 👀

(To easily check your own function against my best one, redefine heads to include it, e.g. heads = {EquivalenceClasses4, yourFunction};, check that it agrees (t = Tuples[{0, 1}, {3, 3}]; SameQ @@ Through[heads[t, MatrixPermutations]]), and then run some RepeatedTimings on it, or run LogLogPlot01MatrixTimings[heads, True, PlotRange -> All, ImageSize -> Large].)