simple load balancer as an exercise
Inspired by an interview task shared by a friend of mine, I decided to write a simple load balancer in Python using different concurrency models. It seemed more interesting than solving yet another LeetCode problem. After looking at it more closely, though, I realized this exercise is rich in detail and can be a great recruitment tool: a candidate can show knowledge of concurrency models, networking, and the HTTP protocol, and still has to write non-trivial code in 50 minutes. The solution is also a great starting point for further discussion about design decisions, trade-offs, infrastructure integration, etc. Initially, I thought about writing a solution for each concurrency model, but soon realized it would make this post too long. Instead, I'll focus on an asyncio solution and briefly mention other approaches, leaving them for another post.
Problem description
Create a load balancing server to handle HTTP requests and distribute them to microservices. The server should offer the following endpoints:
/register
- Receives parameters for URL path, IP address, and port. Upon receiving this request, the load balancer will start sending requests to the corresponding microservice.
Requests to other endpoints are forwarded to microservices based on a Round Robin load balancing scheme. Replies are then sent back to the client.
Example
- Microservice A registers: http://LOAD_BALANCER_HOST/register “/test”, ip address, port
- Microservice B registers: http://LOAD_BALANCER_HOST/register “/test2”, ip address, port
- Microservice C registers: http://LOAD_BALANCER_HOST/register “/test”, ip address, port
- Microservice D registers: http://LOAD_BALANCER_HOST/register “/test”, ip address, port
When a client calls http://LOAD_BALANCER_HOST/test, the request will be forwarded to either A, C, or D.
Guidelines:
- Time limit: 50 minutes
- Focus on load balancer code; microservices implementation is not required.
- Emphasize code structure, readability, and production-ready quality.
- In-memory data structures (no need for a database).
- You can use any Python web framework you're comfortable with (but do not use Django).
Design space
- A load balancer needs to handle many concurrent requests as efficiently as possible, so our solution needs to use some concurrency model. There are a couple of options in Python, e.g., asyncio, threads, and processes.
- Asyncio is single-threaded but uses non-blocking I/O and cooperative multitasking. It should do pretty well but will utilize only one CPU core. Context switching between coroutines should be cheaper than context switching between threads, and coroutines are also more memory-efficient than threads.
- We could also use processes. There are two approaches: spawning a process per request, or keeping a pool of worker processes, a so-called pre-fork server. The first option is easy to implement but inefficient, since starting a new process has noticeable overhead. The latter is well established in the form of Gunicorn or Uvicorn workers.
- A thread-based solution should perform OK, maybe even on par with an asyncio solution if crafted carefully (a load balancer is network-bound, while the Global Interpreter Lock (GIL) is most problematic for CPU-bound applications), but threads consume more memory than coroutines and introduce more context switching, putting more load on the CPU. For a load balancer that may handle many concurrent connections, it's not the best choice.
- A load balancer should be able to register new microservices on the same port the HTTP server is listening on, which makes things a bit harder, since each request needs to be inspected and forwarded to a target only if it's not a registration request. This requires an L7 (application layer) proxy rather than an L4 (transport layer) proxy, the latter being easier to implement.
- Load balancing with the round-robin algorithm is quite simple, but shared state is involved, so it needs special attention: not only is the list of targets modified at runtime, but the round-robin rotation itself has to be implemented in a way that guarantees correct behavior in a concurrent environment (see the sketch after this list).
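To make the shared-state point concrete, here is a minimal sketch (class and variable names are mine, not part of the task) of a naive round-robin counter. Under threads, the read-modify-write on the index can interleave between requests, so it needs a lock; under asyncio the same code is safe as long as there is no await between reading and updating the index.

import threading

class NaiveRoundRobin:
    def __init__(self, targets):
        self._targets = targets
        self._index = 0
        self._lock = threading.Lock()

    def next_unsafe(self):
        # Racy with threads: two threads can read the same _index
        # before either of them writes the incremented value back.
        target = self._targets[self._index % len(self._targets)]
        self._index += 1
        return target

    def next_locked(self):
        # Thread-safe variant: the read-modify-write happens under a lock.
        with self._lock:
            target = self._targets[self._index % len(self._targets)]
            self._index += 1
            return target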
Solution
The idea is to use Starlette on top of Uvicorn (an ASGI web server), set up routing for registration and for all other paths, and forward the latter to upstreams in round-robin order. To be honest, I have a problem with this solution, since it requires the load balancer to parse each HTTP request: the request is rewritten and sent using the httpx library (headers serialized), and when the response comes back, httpx does some parsing again (headers deserialized) before the response is forwarded to the client. This seems inefficient; it would be better to just peek into the headers and then redirect the rest of the stream to the target. On the other hand, the Starlette approach is easy to implement and doable in 50 minutes. I'm curious, though, where an implementation based purely on loop.create_server and asyncio streams would take me, but I'll leave it for another post.
Round-robin algorithm and data structure
First, we keep a list of targets per path (remember, we need to support multiple targets per path). When the next target is needed, we create an infinite generator with itertools.cycle and just call next() on it. One subtlety: cycle snapshots the elements it has already iterated over, so a target registered after the first full cycle would never be picked up. That's why the per-path cycle is recreated whenever a new target is added, restarting the rotation, which is fine for this spec. Since asyncio is single-threaded, we don't have to think much about synchronization; the order in which next() and add() calls execute is not that important for the current spec.
from itertools import cycle
from typing import Dict, Iterator, List, Optional

import pydantic


class UpstreamTarget(pydantic.BaseModel):
    ip_address: str
    port: int
    path: str = "/"


class RoundRobinTargets:
    def __init__(self, targets: Optional[List[UpstreamTarget]] = None):
        # Targets are grouped per path; each path gets its own rotation.
        self._targets: Dict[str, List[UpstreamTarget]] = {}
        self._cycles: Dict[str, Iterator[UpstreamTarget]] = {}
        for target in targets or []:
            self.add(target)

    def _reset_cycle(self, path: str):
        # itertools.cycle snapshots the elements it has seen, so the per-path
        # cycle has to be recreated whenever the target list changes.
        self._cycles[path] = cycle(self._targets[path])

    def get_next(self, path: str) -> UpstreamTarget:
        if path not in self._targets:
            raise ValueError(f"No targets registered for path: {path}")
        path_cycle = self._cycles.setdefault(path, cycle(self._targets[path]))
        return next(path_cycle)

    def add(self, target: UpstreamTarget):
        per_path_targets = self._targets.setdefault(target.path, [])
        if target not in per_path_targets:
            per_path_targets.append(target)
            # Restart the rotation so the new target is included right away.
            self._reset_cycle(target.path)

    def __repr__(self):
        return f"{self.__class__.__name__}({self._targets})"
Receiving and forwarding requests
First, we create the Starlette application. Whenever a registration request comes in, we call a registration handler, which updates the RoundRobinTargets instance (a sketch of such a handler follows the snippet).
from starlette.applications import Starlette
from starlette.routing import Route


def create_app():
    allowed_methods = ["GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "PATCH"]
    app = Starlette(
        routes=[
            Route("/", handle_proxy, methods=allowed_methods),
            Route("/register", handle_registration, methods=["POST"]),
            # Catch-all route: everything else is proxied to the upstreams.
            Route("/{path:path}", handle_proxy, methods=allowed_methods),
        ],
        on_startup=[app_startup],
        on_shutdown=[app_shutdown],
    )
    return app
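The registration handler itself is not part of the snippet above; a minimal sketch, assuming the registration payload is a JSON object matching the UpstreamTarget fields, could look like this:

from starlette.requests import Request
from starlette.responses import JSONResponse


async def handle_registration(request: Request):
    # Expecting a body like {"path": "/test", "ip_address": "10.0.0.1", "port": 8001}.
    try:
        payload = await request.json()
        target = UpstreamTarget(**payload)
    except (ValueError, pydantic.ValidationError):
        return JSONResponse({"error": "invalid registration payload"}, status_code=400)
    lb_targets.add(target)
    return JSONResponse({"status": "registered"}, status_code=201)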
For all other requests, we call a proxy handler that forwards the request to the upstream target and returns the response to the client. Before the request is forwarded, we need to rewrite it: change the hostname, port, and path, and replace the host header with the value of the upstream target. It's worth removing the connection header as well, since it's not going to be valid for the upstream (a sketch of a header-sanitizing helper follows).
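The get_sanitized_headers helper used below is not shown in the handler code; a minimal sketch, assuming we only need to drop hop-by-hop headers and return a mutable dict, might be:

from starlette.datastructures import Headers

# Hop-by-hop headers apply only to a single connection and should not be forwarded.
HOP_BY_HOP_HEADERS = {
    "connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
    "te", "trailers", "transfer-encoding", "upgrade",
}


def get_sanitized_headers(headers: Headers) -> dict:
    # Copy everything else; the caller overwrites "host" afterwards.
    return {key: value for key, value in headers.items() if key.lower() not in HOP_BY_HOP_HEADERS}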
To fully leverage asyncio, we use an asynchronous HTTP client to forward the request to the upstream. Whenever a request comes with a body, our load balancer streams that content to the target instead of buffering it all in memory. The same happens when the upstream replies with data: we stream it back to the client. This way, we can handle many concurrent requests with a small memory footprint.
While streaming the request/response data sounds good in theory, in practice it takes some time to find the right way to do it with httpx and Starlette Request/Response objects, and it might be overkill for this exercise. I was willing to take this route since I had played with Starlette and streaming before. If we know for sure that the LB is going to forward traffic with bigger payloads, then streaming is worth considering; for file transfers, it's a must-have. Otherwise, it's better to keep it simple and just forward the data without streaming.
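For comparison, a buffered (non-streaming) variant is only a few lines; forward_buffered and its parameters here are illustrative, not part of the final code:

async def forward_buffered(client: httpx.AsyncClient, incoming_req: Request, target_url: str) -> Response:
    # Read the whole request body into memory and send it upstream in one go.
    body = await incoming_req.body()
    upstream = await client.request(
        incoming_req.method,
        target_url,
        headers=get_sanitized_headers(incoming_req.headers),
        content=body or None,
    )
    # The upstream response is also fully buffered before it is returned to the client.
    return Response(content=upstream.content, status_code=upstream.status_code, headers=upstream.headers)

The streaming version I ended up with is shown next.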
import logging

import httpx
from starlette.datastructures import URL
from starlette.requests import Request

HttpClient = httpx.AsyncClient  # thin alias; the wrapper mentioned under "Optimizations" can be skipped

logger = logging.getLogger(__name__)


async def get_proxied_response(client: HttpClient, incoming_req: Request, target: UpstreamTarget):
    # Rebuild the incoming URL so it points at the upstream target (the path is rewritten to "/").
    target_url = URL(hostname=target.ip_address, port=target.port)
    scheme = target_url.scheme or "http"
    target_url = incoming_req.url.replace(hostname=target_url.hostname, scheme=scheme, port=target_url.port, path="/")
    logger.debug("Forwarding to: %s", target_url)
    # Create a new request to the target server
    endpoint = target_url.hostname
    headers = get_sanitized_headers(incoming_req.headers)
    headers["host"] = endpoint
    has_content = int(headers.get("content-length", 0)) > 0
    # Stream the request body to the upstream instead of buffering it in memory.
    data = incoming_req.stream() if has_content else None
    proxy_request = client.build_request(incoming_req.method, str(target_url), headers=headers, content=data)
    response = await client.send(proxy_request)
    return response
from starlette.exceptions import HTTPException
from starlette.responses import Response, StreamingResponse


async def handle_proxy(request: Request):
    global lb_targets
    global http_client
    logger.debug("Handling connection to: %s", request.url.path)
    try:
        target = lb_targets.get_next(request.url.path)
    except ValueError:
        raise HTTPException(status_code=503, detail="No targets available")
    try:
        # Reuse the shared client created in app_startup so keep-alive
        # connections to upstreams survive across requests.
        response = await get_proxied_response(http_client, request, target)
    except httpx.ConnectTimeout:
        raise HTTPException(status_code=504, detail="Connection to target timed out")
    # Create a streaming response to send back to the client
    if response.content:
        return StreamingResponse(response.aiter_bytes(), status_code=response.status_code, headers=response.headers)
    else:
        return Response(status_code=response.status_code, headers=response.headers)
Optimizations
Since the HTTP client is used for every request, I decided to configure the httpx AsyncClient once, at application startup (the http_client created in app_startup below). This way we can reuse connections to upstreams and avoid the overhead of establishing a new connection for each request. The most important parameters are max_keepalive_connections and max_connections: the first limits the number of idle keep-alive connections kept in the connection pool, the second limits the total number of connections in the pool. Finding the right parameters to get optimal saturation of network I/O and CPU is another topic; the easiest way to go for now would be to monitor the app for request latency without setting any max connection limits. I'm also defining a connect timeout to support the Gateway Timeout error: if the connection to the upstream takes too long, we just return 504 to the client. Introducing a wrapper around httpx.AsyncClient can be skipped, but one should be ready to talk about the consequences of this approach during the interview.
async def app_startup():
    global http_client
    global lb_targets
    limits = httpx.Limits(max_keepalive_connections=100, max_connections=1000)
    timeout = httpx.Timeout(None, connect=5)
    http_client = httpx.AsyncClient(follow_redirects=True, limits=limits, timeout=timeout)
    lb_targets = RoundRobinTargets()


async def app_shutdown():
    global http_client
    global lb_targets
    await http_client.aclose()
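To try the whole thing locally, the app can be served with Uvicorn; the port below is arbitrary:

import uvicorn

if __name__ == "__main__":
    # One process, one event loop; Uvicorn drives the ASGI app created above.
    uvicorn.run(create_app(), host="0.0.0.0", port=8080)

A microservice then registers itself by POSTing its path, IP address, and port to /register, and clients simply call http://LOAD_BALANCER_HOST:8080/test as in the problem description.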
What’s next
In future posts, I'd like to implement solutions based on forking and threads. I'd also like to explore an implementation using asyncio's low-level APIs, minimizing the need for external libraries. I'm also curious what a solution in Rust would look like. Finally, I'd like to compare the performance of all the solutions and see how they behave under load. Stay tuned.