Spark aggregations – using one aggregation result as an input to another aggregation (within the same groupBy)

I have a huge dataset (billions of rows) that summarizes user behavior: for example, an event type and the number of times a user performed that action.

Sample data looks like this:
|user ID| event_type | count|
|-------|------------|-------|
|user_1|prefix1_event1|1|
|user_1|prefix2_event1|2|
|user_1|prefix1_event2|1|
|user_1|prefix2_event2|2|
|user_2|prefix1_event1|1|
|user_2|prefix2_event1|2|
|user_2|prefix1_event2|1|
|user_2|prefix2_event2|2|

The event types have a standard suffix (the suffixes are fixed and form a small, finite list) with many prefixes. For each user and each event suffix, I need to find which event the user performed most often and how many times.
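
For reproducibility, here is a minimal version of the sample data as a DataFrame (the local SparkSession and the `user_id` column name are just for illustration):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df = Seq(
  ("user_1", "prefix1_event1", 1),
  ("user_1", "prefix2_event1", 2),
  ("user_1", "prefix1_event2", 1),
  ("user_1", "prefix2_event2", 2),
  ("user_2", "prefix1_event1", 1),
  ("user_2", "prefix2_event1", 2),
  ("user_2", "prefix1_event2", 1),
  ("user_2", "prefix2_event2", 2)
).toDF("user_id", "event_type", "count")
```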

So the result will look like this:
|user ID| event_type | count|
|-------|------------|-------|
|user_1|prefix2_event1|2|
|user_1|prefix2_event2|2|
|user_2|prefix2_event1|2|
|user_2|prefix2_event2|2|

I am struggling to do this as part of one aggregation. I can find the max count for each suffix for each user with something like this:

```scala
max(when(col("event_type").endsWith("_event1"), col("count")))
```
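
In full context, that per-suffix conditional max looks roughly like this (a sketch, assuming the `df` built above and a hard-coded pair of suffixes):

```scala
import org.apache.spark.sql.functions._

// One conditional max per suffix; rows that do not match the suffix
// contribute null, which max ignores.
val maxPerSuffix = df.groupBy(col("user_id")).agg(
  max(when(col("event_type").endsWith("_event1"), col("count"))).as("max_event1"),
  max(when(col("event_type").endsWith("_event2"), col("count"))).as("max_event2")
)
```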

I am unsure how to derive the corresponding event_type in the same aggregation. I tried something like this:

```scala
// This nests max(...) inside collect_set(...), which Spark rejects (see error below):
collect_set(
  when(
    col("event_type").endsWith("_event1") &&
      col("count") === max(when(col("event_type").endsWith("_event1"), col("count"))),
    col("event_type")
  )
).getItem(0).as(colname)
```

Basically, I tried to reuse the expression that computes the max value as an input to the next aggregate expression, but it looks like Spark does not allow that: an aggregate function cannot appear inside another aggregate function in the same pass. I get this error:

```
org.apache.spark.sql.AnalysisException: It is not allowed to use an aggregate function in the argument of another aggregate function. Please use the inner aggregate function in a sub-query
```
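
To make the goal concrete, here is an untested sketch of one direction I have seen suggested: derive the suffix as its own column, then take the max of a struct so the winning event_type is carried along with its count. The `substring_index` call assumes the suffix is everything after the last underscore.

```scala
import org.apache.spark.sql.functions._

// Spark compares structs field by field, so max(struct(count, event_type))
// keeps the row with the highest count (ties broken by event_type), and the
// matching event_type rides along with it.
val top = df
  .withColumn("suffix", substring_index(col("event_type"), "_", -1))
  .groupBy(col("user_id"), col("suffix"))
  .agg(max(struct(col("count"), col("event_type"))).as("top"))
  .select(
    col("user_id"),
    col("top.event_type").as("event_type"),
    col("top.count").as("count")
  )
```

I have not verified that this is correct or efficient at billions of rows, and it groups by (user ID, suffix) rather than by user ID alone.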

Any idea how I can achieve this as part of one aggregation (grouping by user ID)? P.S.: I am using Scala.