c++ – Boost.Asio Server and RAII


I am trying to implement a network server application in C++ using Boost.Asio.

Here are the requirements I am trying to meet:

  • The application creates only one instance of boost::asio::io_context.
  • That single io_context is run() by a shared thread pool; the number of threads is not fixed.
  • The application can instantiate multiple Server objects; new Servers can be created and destroyed at any time.
  • Each Server can handle connections from multiple clients.

I am trying to implement the RAII pattern for the Server class. What I want to guarantee is that when a Server is destroyed, all of its connections have been completely closed. There are three ways a connection can be closed:

  1. Client responds and there is no more work to be done in a connection.
  2. Server is being deallocated and causes all alive connections to close.
  3. Connection is killed manually by invoking stop() method.

I have arrived at a solution that seems to meet all of the criteria above, but since Boost.Asio is still quite new to me I wanted to verify that what I am doing is correct. There are also a couple of things I was not 100% sure about:

  • I was trying to remove the mutex from the Server class and instead use a strand for all of the synchronisation but I couldn’t find a clear way to do it.
  • Because the thread pool may consist of a single thread, and that thread may be the one calling the Server destructor, I had to invoke io_context::poll_one() from the destructor. This gives all pending connections a chance to complete their shutdown and prevents a potential deadlock.
  • I would welcome any other suggestions for improvements you could think of.

Anyway, here’s the code with some unit tests (live version on Coliru: http://coliru.stacked-crooked.com/a/1afb0dc34dd09008):

#include <cassert>
#include <cstddef>
#include <functional>
#include <iostream>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <boost/asio/io_context.hpp>
#include <boost/asio/io_context_strand.hpp>
#include <boost/asio/executor.hpp>
#include <boost/asio/deadline_timer.hpp>
#include <boost/asio/dispatch.hpp>

using namespace std;
using namespace boost::asio;
using namespace std::placeholders;


class Connection;


/// Callback interface through which a Connection reports back to its owner.
///
/// An implementation names the executor its callbacks must run on, and is
/// told (on that executor) when a connection has received its response.
class ConnectionDelegate
{
public:
    virtual ~ConnectionDelegate() = default;

    /// Executor used to dispatch didReceiveResponse().
    /// (The elaborated `class executor` is needed because this member
    /// function's name hides the boost::asio::executor type.)
    virtual class executor executor() const = 0;

    /// Notification that `connection` got its response; the shared_ptr keeps
    /// the connection alive for the duration of the call.
    virtual void didReceiveResponse(shared_ptr<Connection> connection) = 0;
};


/// A single (simulated) client connection.
///
/// Must be owned by a shared_ptr, because start() and handleResponse() call
/// shared_from_this(). The pending asynchronous wait holds a strong reference,
/// so the connection stays alive until its completion handler has run.
class Connection: public enable_shared_from_this<Connection>
{
public:
    Connection(string name, io_context& ioContext)
    : _name(std::move(name))
    , _ioContext(ioContext)
    , _timer(ioContext)
    {
    }

    const string& name() const
    {
        return _name;
    }

    /// Sets the (non-owned) object notified when the response arrives.
    /// It must outlive every pending handler of this connection.
    void setDelegate(ConnectionDelegate *delegate)
    {
        _delegate = delegate;
    }

    /// Starts the simulated network request: a 3-second timer whose handler
    /// captures a strong reference to this connection.
    void start()
    {
        _timer.expires_from_now(boost::posix_time::seconds(3));
        auto self = shared_from_this();
        _timer.async_wait([self](const boost::system::error_code& errorCode)
        {
            self->handleResponse(errorCode);
        });
    }

    /// Cancels the timer. A wait that is still pending then completes with
    /// error::operation_aborted, which handleResponse() ignores.
    void stop()
    {
        _timer.cancel();
    }

private:
    string _name;
    io_context& _ioContext;                 // NOTE(review): not read by any member shown here
    boost::asio::deadline_timer _timer;
    ConnectionDelegate *_delegate = nullptr; // non-owning; null until setDelegate()

    /// Timer completion handler. Ignores cancellation; otherwise forwards the
    /// response to the delegate on the delegate's executor. The dispatched
    /// handler holds a strong reference so the connection survives the callback.
    void handleResponse(const boost::system::error_code& errorCode)
    {
        if (errorCode == error::operation_aborted)
        {
            return;
        }
        if (_delegate == nullptr)
        {
            // No delegate registered: there is nobody to notify.
            return;
        }
        auto delegate = _delegate;
        auto self = shared_from_this();
        dispatch(delegate->executor(), [delegate, self]
        {
            delegate->didReceiveResponse(self);
        });
    }
};


class Server: public ConnectionDelegate
{
public:
    Server(string name, io_context& ioContext)
    : _name(name)
    , _ioContext(ioContext)
    , _strand(_ioContext)
    {
    }
    ~Server()
    {
        stop();
        assert(_connections.empty());
        assert(_connectionIterators.empty());
    }
    weak_ptr<Connection> addConnection(string name)
    {
        auto connection = shared_ptr<Connection>(new Connection(name, _ioContext), bind(&Server::deleteConnection, this, _1));
        {
            lock_guard<mutex> lock(_mutex);
            _connectionIterators(connection.get()) = _connections.insert(_connections.end(), connection);
        }
        connection->setDelegate(this);
        connection->start();
        return connection;
    }
    
    vector<shared_ptr<Connection>> connections()
    {
        lock_guard<mutex> lock(_mutex);
        
        vector<shared_ptr<Connection>> connections;
        for (auto weakConnection: _connections)
        {
            if (auto connection = weakConnection.lock())
            {
                connections.push_back(connection);
            }
        }
        return connections;
    }
    void stop()
    {
        auto connectionsCount = 0;
        for (auto connection: connections())
        {
            ++connectionsCount;
            connection->stop();
        }
        
        while (connectionsCount != 0)
        {
            _ioContext.poll_one();
            connectionsCount = connections().size();
        }
    }
    
    // MARK: - ConnectionDelegate
    class executor executor() const override
    {
        return _strand;
    }
    void didReceiveResponse(shared_ptr<Connection> connection) override
    {
        // Strand to protect shared resourcess to be accessed by this method.
        assert(_strand.running_in_this_thread());
        
        // Here I plan to execute some business logic and I need both Server & Connection to be alive.
        std::cout << "didReceiveResponse - server: " << _name << ", connection: " << connection->name() << endl;
    }
    
private:
    typedef list<weak_ptr<Connection>> ConnectionsList;
    typedef unordered_map<Connection*, ConnectionsList::iterator> ConnectionsIteratorMap;
    
    string _name;
    io_context& _ioContext;
    io_context::strand _strand;
    ConnectionsList _connections;
    ConnectionsIteratorMap _connectionIterators;
    mutex _mutex;
    
    void deleteConnection(Connection *connection)
    {
        {
            lock_guard<mutex> lock(_mutex);
            auto iterator = _connectionIterators(connection);
            _connections.erase(iterator);
            _connectionIterators.erase(connection);
        }
        default_delete<Connection>()(connection);
    }
};


void testConnectionClosedByTheServer()
{
    io_context ioContext;
    auto server = make_unique<Server>("server1", ioContext);
    
    auto weakConnection = server->addConnection("connection1");
    assert(weakConnection.expired() == false);
    assert(server->connections().size() == 1);
    
    server.reset();
    assert(weakConnection.expired() == true);
}

// A connection whose response has arrived is released and unregistered
// without any explicit action from the server.
void testConnectionClosedAfterResponse()
{
    io_context ioContext;
    auto server = make_unique<Server>("server1", ioContext);

    auto weakConnection = server->addConnection("connection1");
    assert(weakConnection.expired() == false);
    assert(server->connections().size() == 1);

    // run_one() blocks until a handler has run. The previous poll_one() loop
    // spun a CPU core for the whole duration of the simulated request.
    while (!weakConnection.expired())
    {
        if (ioContext.run_one() == 0)
        {
            break; // out of work: the connection can no longer expire
        }
    }
    assert(weakConnection.expired());
    assert(server->connections().size() == 0);
}

void testConnectionClosedManually()
{
    io_context ioContext;
    auto server = make_unique<Server>("server1", ioContext);
    
    auto weakConnection = server->addConnection("connection1");
    assert(weakConnection.expired() == false);
    assert(server->connections().size() == 1);
    
    weakConnection.lock()->stop();
    ioContext.run();
    
    assert(weakConnection.expired() == true);
    assert(server->connections().size() == 0);
}

void testMultipleServers()
{
    io_context ioContext;
    auto server1 = make_unique<Server>("server1", ioContext);
    auto server2 = make_unique<Server>("server2", ioContext);

    auto weakConnection1 = server1->addConnection("connection1");
    auto weakConnection2 = server2->addConnection("connection2");

    server1.reset();
    assert(weakConnection1.expired() == true);
    assert(weakConnection2.expired() == false);
}

void testDeadLock()
{
    io_context ioContext;
    auto server = make_unique<Server>("server1", ioContext);
    
    auto weakConnection = server->addConnection("connection1");
    assert(weakConnection.expired() == false);
    assert(server->connections().size() == 1);
    
    auto connection = weakConnection.lock();
    server.reset(); // <-- deadlock, but that's OK, i will try to prevent it by other means
}


int main()
{
    testConnectionClosedByTheServer();
    testConnectionClosedAfterResponse();
    testConnectionClosedManually();
    // testDeadLock();
}

Kind Regards,
Marek