WebSocket server with pings using Rust and tokio-tungstenite

I’m new to Rust, so I’d appreciate a code review of this snippet. It is a simple WebSocket server written in Rust with tokio-tungstenite. In particular, I wanted the server to detect when a client stops responding to pings within a given timeout; the crate doesn’t provide that directly, so I implemented it myself. So far I’ve only tested it with a simple WebSocket client in Python. I think it works, but I’m hoping the experts here can spot any issues.

Thanks!

```rust
use anyhow::{Context, Result};
use async_channel::{bounded, Sender};
use async_trait::async_trait;
use futures::{SinkExt, StreamExt};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{oneshot, Mutex};
use tokio::time::{timeout, Instant};
use tungstenite::Message;

fn get_random_buf() -> Result<(u8; 4), getrandom::Error> {
    let mut buf = (0u8; 4);
    getrandom::getrandom(&mut buf)?;
    Ok(buf)
}

#(async_trait)
trait Handler {
    async fn on_connect(&self) -> Option<Message>;
    async fn on_msg(&self, msg: Message) -> Option<Message>;
}

async fn accept_connection(
    stream: TcpStream,
    handler: Arc<(dyn Handler + Send + Sync)>,
) -> Result<()> {
    let ping_interval = Duration::from_millis(1000);
    let ping_timeout = Duration::from_millis(2000);

    let mut ws_stream = tokio_tungstenite::accept_async(stream)
        .await
        .expect("Error during the websocket handshake occurred");

    let mut ping_ticker = tokio::time::interval_at(Instant::now() + ping_interval, ping_interval);
    let (hangup_tx, hangup_rx) = bounded(1);
    let pongs: Arc<Mutex<HashMap<_, oneshot::Sender<()>>>> = Arc::new(Mutex::new(HashMap::new()));

    async fn await_pongs(
        pongs: oneshot::Receiver<()>,
        hangup_tx: Sender<()>,
        ping_timeout: Duration,
    ) -> Result<()> {
        if let Err(_) = timeout(ping_timeout, async move {
            if let Ok(pong) = pongs.await {
                return Some(pong);
            }
            None
        })
        .await
        {
            hangup_tx.send(()).await?;
        }
        Ok(())
    }

    // Run the on-connect handler
    if let Some(resp) = handler.on_connect().await {
        ws_stream.send(resp).await?;
    }

    loop {
        tokio::select! {
            msg = ws_stream.next() => {
                match msg {
                    Some(msg) => {
                        let msg = msg?;
                        if msg.is_text() || msg.is_binary() {
                            let handler = handler.clone();
                            if let Some(resp) = handler.on_msg(msg).await {
                                ws_stream.send(resp).await?;
                            }
                        } else if msg.is_pong() {
                            let data = msg.into_data();
                            let mut guard = pongs.lock().await;
                            if let Some(channel) = guard.remove(&data) {
                                channel.send(()).expect("failed to send");
                            }
                        } else if msg.is_close() {
                            break;
                        }
                    }
                    None => break,
                }
            }
            _ = ping_ticker.tick() => {
                let data = get_random_buf()?.to_vec();
                let (pong_tx, pong_rx) = oneshot::channel();
                ws_stream.send(Message::Ping(data.clone())).await?;
                pongs.lock().await.insert(data, pong_tx);
                tokio::spawn(await_pongs(pong_rx, hangup_tx.clone(), ping_timeout.clone()));
            },
            _ = hangup_rx.recv() => {
                break;
            },
        }
    }

    Ok(())
}

struct MyHandler {}

#(async_trait)
impl Handler for MyHandler {
    async fn on_connect(&self) -> Option<Message> {
        println!("on connect");
        Some(Message::Text("Responding to onconnect".to_owned()))
    }
    async fn on_msg(&self, msg: Message) -> Option<Message> {
        println!("on msg {:#?}", msg);
        Some(Message::Close(None))
    }
}

#(tokio::main)
async fn main() -> Result<()> {
    let try_socket = TcpListener::bind("127.0.0.1:8080").await;
    let handler = Arc::new(MyHandler {});
    let listener = try_socket.expect("Failed to bind");
    while let Ok((stream, _)) = listener.accept().await {
        let handler = handler.clone();
        tokio::spawn(accept_connection(stream, handler));
    }

    Ok(())
}
```