Querying better with FastAPI and SQLAlchemy

Deepjyoti Barman @deepjyoti30
Dec 14, 2021 • 2:39 PM UTC

I have been working with FastAPI and SQLAlchemy for a while now. Both frameworks are pretty good and work well together. One thing that keeps coming up when I start building an API with them is handling queries that end up with no results.

There are times when somebody sends a search query with a term that doesn't match anything, so the query returns None. In that case, the idiomatic behaviour for a FastAPI endpoint is to raise a 404 error. Now think of that with 12 different items. Since the models differ, we would need 12 separate lookups, and every one of them would repeat the same None check and the same 404 error.

After looking around for a while, I came up with a solution for just that.

What is it?

Here's an idea that came to my mind. Rather than repeating the same code of running a query (through SQLAlchemy), checking whether the item is None and then raising a 404 error, what if we could do all of that right from the query call itself?

My initial thought was to just get rid of all the redundant code that raises a 404 exception every time a query ends up with a None result. That is hard to achieve because every query can run on a different model and use different conditions to decide which results are returned.

This is when I started looking into SQLAlchemy's source code (yet another reason to go Open Source ;-)).

What I found

It turns out we can pretty easily subclass the Query class and add our own query methods to it. One of the most typical places for a 404 error is when the first() method is called on a query. This method is expected to return the first result; however, if the query doesn't match any rows, it just returns None.

Now that is well and good, but how about a method that actually raises a 404 Not Found error if the item is not found?

We can do just that with the following piece of code:

"""
Import the Query class.
Import the FastAPI exception to raise.
"""

from sqlalchemy.orm import Query
from fastapi.exceptions import HTTPException


class FastAPIQuery(Query):

    # Define a custom method that raises a 404 exception if
    # the item fetched is None
    def first_with_error(self):
        item = self.first()

        if item is None:
            raise HTTPException(
                status_code=404,
                detail="Item not found!"
            )

        return item

In the above code, the original Query class is imported and a new custom query class that inherits from it is created. This new class exposes a public method, first_with_error, that returns the first result of the query. If that result is None, a FastAPI HTTPException with status code 404 is raised instead.

This means that if the item is not found, the query itself raises the 404, with no extra None check needed at the call site. When this happens inside a path operation, FastAPI turns the exception into a 404 response.

Okay. How do I add it to the SQLAlchemy session?

Good question! Queries are usually made in the following way:

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from local_model_file import CustomModel

engine = create_engine("SQL connection URI")
SessionLocal = sessionmaker(
    autocommit=False, autoflush=False, bind=engine)

# sessionmaker() returns a session factory; call it to get a session
db = SessionLocal()

query = db.query(CustomModel).filter(CustomModel.name == "nana")

# Print the class of the query
print(type(query))  # <class 'sqlalchemy.orm.query.Query'>

item_found = query.first()

In the above, there is no mention of a query class. This is because Session.query() builds the Query object for us, using SQLAlchemy's default Query class.

Fear not, this doesn't mean we cannot use our custom class. The developers of SQLAlchemy anticipated this kind of customization: Session accepts a query_cls parameter, and sessionmaker passes it on to every session it creates.

The query_cls parameter takes a class, and whenever the session's query() method is called, it returns an instance of that class instead of the default Query.

So if we want to add our custom FastAPIQuery class to the above code, we can do something like this:

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from local_model_file import CustomModel
from custom_file import FastAPIQuery

engine = create_engine("SQL connection URI")
SessionLocal = sessionmaker(
    autocommit=False, autoflush=False, bind=engine, query_cls=FastAPIQuery)

# Every session created by this factory builds FastAPIQuery objects
db = SessionLocal()

query = db.query(CustomModel).filter(CustomModel.name == "nana")

# Print the class of the query
print(type(query))  # <class 'custom_file.FastAPIQuery'>

# Now we can use our custom method that will
# raise a 404 error if the item is not found.
item_found = query.first_with_error()
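A quick way to check that query_cls really sticks is the self-contained sketch below. It assumes an in-memory SQLite database, a hypothetical Item model and a do-nothing MyQuery subclass. Methods such as filter() return a copy of the query that keeps the same class, so custom methods stay available after chaining.

```python
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Query, declarative_base, sessionmaker

Base = declarative_base()


class Item(Base):
    # Hypothetical model used only for this demo
    __tablename__ = "items"
    id = Column(Integer, primary_key=True)
    name = Column(String)


class MyQuery(Query):
    """Hypothetical Query subclass; it adds nothing and only tags the type."""


engine = create_engine("sqlite://")  # in-memory SQLite database
Base.metadata.create_all(engine)

# sessionmaker forwards query_cls to every Session it creates
SessionLocal = sessionmaker(bind=engine, query_cls=MyQuery)
db = SessionLocal()

query = db.query(Item)
filtered = query.filter(Item.name == "nana")

# Both the initial query and the chained one use the custom class
print(type(query).__name__, type(filtered).__name__)  # MyQuery MyQuery
```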

Yes, it is as simple as that. OOP is pretty fun, especially when it opens the door to possibilities like the one above.

Bonus: Add query table name to error

Okay, so the above code works fine. But let's go back to the example from the beginning: we have 12 different endpoints, each querying a different table. In that case, it would be better to return a 404 error that says what was not found, right?

For example, if a query is run on the user table, the error message should say Item 'user' not found!. Or even better, if the query selects from more than one table (e.g. db.query(User, Order) on the user and order tables), the error should say Item 'user, order' not found!. I know, this looks cool.

How do we do that?

Note: This part is kind of a hack!

SQLAlchemy, of course, keeps track of the tables the query runs on. We will take advantage of that to build a nice 404 message.

The Query class has a private attribute named _raw_columns that holds the entities passed to query(). For a mapped model, each entry is the model's Table object, which keeps the name of the table along with other details such as its metadata. Tables pulled in only through join() are not part of this list, and since the attribute is private, it may change between SQLAlchemy versions.

We will iterate over this attribute to extract the table names from the query:

from typing import List


class FastAPIQuery(Query):
    ...
    def __get_table_names(self) -> List[str]:
        """
        Get the names of the tables the query runs on.

        Just return the list of names as is. They
        can also be joined into a string by the
        caller method.
        """
        return [table.name for table in self._raw_columns]
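Here is a small sketch to check what _raw_columns yields, assuming SQLAlchemy 1.4 or later and hypothetical User and Order models. The helper is exposed as a public table_names method (a hypothetical name) only so it can be called from outside the class; no database round trip happens because the queries are never executed.

```python
from typing import List

from sqlalchemy import Column, ForeignKey, Integer, create_engine
from sqlalchemy.orm import Query, Session, declarative_base

Base = declarative_base()


class User(Base):
    # Hypothetical model used only for this demo
    __tablename__ = "user"
    id = Column(Integer, primary_key=True)


class Order(Base):
    # Hypothetical model used only for this demo
    __tablename__ = "order"
    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey("user.id"))


class FastAPIQuery(Query):
    def table_names(self) -> List[str]:
        # Relies on the private _raw_columns attribute; may change between versions
        return [table.name for table in self._raw_columns]


session = Session(create_engine("sqlite://"), query_cls=FastAPIQuery)

print(", ".join(session.query(User).table_names()))          # user
print(", ".join(session.query(User, Order).table_names()))   # user, order
# A table added through join() is not one of the selected entities
print(", ".join(session.query(User).join(Order).table_names()))  # user
```

The last line shows the main caveat of this hack: only the entities passed to query() are listed, so a table that appears only in a join() does not show up in the error message.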

And then we update the first_with_error method in the following way to return a better message:

class FastAPIQuery(Query):
    ...
    def first_with_error(self):
        item = self.first()

        if item is None:
            table_names = ", ".join(self.__get_table_names())
            raise HTTPException(
                status_code=404,
                detail=f"Item '{table_names}' not found!"
            )

        return item

As can be seen above, the error message now contains the names of the tables the query ran on. This is a nice hack to use when there are a lot of tables. The first_with_error method comes in handy for getting the first item, and other methods can be wrapped in the same way to make them work better with FastAPI.

It is always a good idea to add a wrapper instead of overriding the original method, since other code may still rely on first() returning None when nothing matches!
