A few months ago, I shared that I had built an AI-powered personalized news reader, which I use (and still do) on a near-daily basis. Since that post, I’ve made a couple of major improvements, which I have just reflected in my public GitHub repository.
Switching to JAX
I previously chose Keras 3 for my deep learning algorithm architecture because of its ease of use as well as the advertised ability to shift between AI/ML backends (at least between TensorFlow, JAX, and PyTorch). With Keras creator François Chollet noting significant speed-ups just from switching backends to JAX, I decided to give the JAX backend a shot.
Thankfully, Keras 3 lived up to its multi-backend promise and made switching to JAX remarkably easy. For my code, I simply had to make three sets of tweaks.
First, I had to change the definition of my container images. Instead of starting from Tensorflow’s official Docker images, I instead installed JAX and Keras on Modal’s default Debian image and set the appropriate environmental variables to configure Keras to use JAX as a backend:
jax_image = (
    modal.Image.debian_slim(python_version='3.11')
    .pip_install('jax[cuda12]==0.4.35', extra_options="-U")
    .pip_install('keras==3.6')
    .pip_install('keras-hub==0.17')
    .env({"KERAS_BACKEND": "jax"})  # sets Keras backend to JAX
    .env({"XLA_PYTHON_CLIENT_MEM_FRACTION": "1.0"})  # lets JAX use all available GPU memory
)
Second, because tf.data pipelines convert everything to TensorFlow tensors, I had to switch my preprocessing pipelines from Keras’s ops library (which, because I was using JAX as a backend, expected JAX tensors) to native TensorFlow operations:
ds = ds.map(
    lambda i, j, k, l:
    (
        preprocessor(i),
        j,
        2*k - 1,  # rescale {0, 1} labels to {-1, 1}
        loglength_norm_layer(tf.math.log(tf.cast(l, dtype=tf.float32) + 1))  # normalized log(1 + length)
    ),
    num_parallel_calls=tf.data.AUTOTUNE
)
Lastly, I had a few lines of code that assumed TensorFlow tensors (where getting the underlying value required a .numpy() call). Now that I was using JAX as a backend, I had to remove the .numpy() calls for the code to work.
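For illustration, here is a before/after sketch of that last change (the variable names are mine, not from the repo):

```python
import numpy as np

# With the TensorFlow backend, pulling a plain Python value out of a
# model output needed an explicit .numpy() call:
#     score = predictions[0].numpy()
# With the JAX backend, that call fails, so it is simply dropped:
#     score = predictions[0]
# Backend-agnostic code can instead normalize with np.asarray, which
# accepts TensorFlow EagerTensors, JAX arrays, and NumPy values alike:
prediction = np.float32(0.87)          # stand-in for a model output
score = float(np.asarray(prediction))  # works regardless of backend
```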
Everything else (the rest of the tf.data preprocessing pipeline, the code to train the model, the code to serve it, the previously saved model weights, and the code to save and load them) remained the same! Considering that the training time per epoch and the time the model took to evaluate (a measure of inference time) both seemed to improve by 20-40%, this simple switch to JAX was well worth it!
Model Architecture Improvements
There were two major improvements I made in the model architecture over the past few months.
First, having run my news reader for the better part of a year now, I have accumulated enough data that my strategy of simultaneously training on two related tasks (predicting the human rating and predicting the length of an article) no longer requires separate inputs. This reduced the memory requirements and simplified the data pipeline for training (see architecture diagram below).
Second, I was able to successfully train a version of my algorithm that uses dot products natively. Not only did this let me remove several layers from my previous model architecture (see architecture diagram below), but, because the Supabase Postgres database I’m using supports pgvector, it means I can even compute ratings for articles through a SQL query:
UPDATE articleuser
SET
ai_rating = 0.5 + 0.5 * (1 - (a.embedding <=> u.embedding)),
rating_timestamp = NOW(),
updated_at = NOW()
FROM
articles a,
users u
WHERE
articleuser.article_id = a.id
AND articleuser.user_id = u.id
AND articleuser.ai_rating IS NULL;
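To double-check that formula outside the database, here is a minimal NumPy mirror of the SQL above (pgvector’s `<=>` operator is cosine distance; the embeddings below are made up for illustration):

```python
import numpy as np

def ai_rating(article_emb, user_emb):
    """Mirror of the SQL: 0.5 + 0.5 * (1 - cosine_distance)."""
    a = np.asarray(article_emb, dtype=np.float64)
    u = np.asarray(user_emb, dtype=np.float64)
    cosine_similarity = a @ u / (np.linalg.norm(a) * np.linalg.norm(u))
    cosine_distance = 1.0 - cosine_similarity   # pgvector's <=> operator
    return 0.5 + 0.5 * (1.0 - cosine_distance)  # maps similarity [-1, 1] to [0, 1]

# Identical embeddings give the maximum rating (~1.0); orthogonal ones give 0.5.
rating = ai_rating([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])
```

This makes it easy to sanity-check the database-computed ratings against the model’s own dot-product output.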
The result is a much simpler architecture as well as greater operational flexibility, as I can now update ratings directly from the database in addition to serving a deep neural network from my serverless backend.
Making Sources a First-Class Citizen
As I used the news reader, I realized early on that the ability to view sorted content from just one source (i.e., a particular blog or news site) would be valuable. To add this, I created and populated a new sources table within the database, linked to the articles table, to track these independently (see database design diagram below).
I then modified my scrapers to insert the identifier for each source alongside each new article, and made sure my fetch calls all JOINed on and pulled the relevant source information.
With the data infrastructure in place, I added support for a source parameter in the core fetch URLs to enable single- (or multi-) source feeds. I then added a quick element at the top of the feed interface (see below) to let users know when the feed they’re seeing is limited to a given source. I also made all the source links in the feed clickable so that they take the user to the corresponding single-source feed.
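As a hedged sketch of the fetch side, a source parameter can simply toggle a filter clause in the feed query (the function, column, and table aliases here are illustrative, not taken from the actual codebase):

```python
def build_feed_query(source_ids=None):
    """Build the feed SQL, optionally limited to one or more sources."""
    query = (
        "SELECT a.id, a.title, s.name AS source_name "
        "FROM articles a JOIN sources s ON a.source_id = s.id"
    )
    params = []
    if source_ids:  # single- or multi-source feed
        placeholders = ", ".join(["%s"] * len(source_ids))
        query += f" WHERE a.source_id IN ({placeholders})"
        params = list(source_ids)
    query += " ORDER BY a.published_at DESC LIMIT 30"
    return query, params
```

Parameterized placeholders (rather than string interpolation of the IDs) keep the query safe from injection while supporting any number of sources.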
<div class="feed-container">
<div class="controls-container">
<div class="controls">
${source_names && source_names.length > 0 && html`
<div class="source-info">
Showing articles from: ${source_names.join(', ')}
</div>
<div>
<a href="/">Return to Main Feed</a>
</div>
`}
</div>
</div>
</div>
Performance Speed-Up
One recurring issue I noticed in my use of the news reader was slow load times. While some of this can be attributed to the “cold start” problem that serverless applications face, much of it was due to how the news reader fetched pertinent articles from the database: it decided at the moment of the fetch request what was most relevant to send over, calculating all the pertinent scores and rank-ordering them on the spot. As the article database grew, this computation became increasingly expensive.
To address this, I decided to move to a “pre-calculated” ranking system. That way, the system would know what to fetch in advance of a fetch request (and hence return much faster). Couple that with a database index (which effectively “pre-sorts” the results to make retrieval even faster), and I saw visually noticeable improvements in load times.
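As a sketch of what this looks like in practice (the score column and index names here are hypothetical, not from the actual schema), the pre-calculated approach amounts to a stored score plus a matching index, turning each fetch into a simple indexed scan:

```python
# Illustrative DDL and fetch query for a pre-calculated ranking scheme.
# The composite index matches the fetch's WHERE + ORDER BY, so Postgres
# can read rows in ranked order without sorting at request time.
CREATE_INDEX = """
CREATE INDEX IF NOT EXISTS idx_articleuser_user_score
ON articleuser (user_id, score DESC);
"""

FETCH_FEED = """
SELECT article_id, score
FROM articleuser
WHERE user_id = %s
ORDER BY score DESC
LIMIT 30;
"""
```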
But with any pre-calculated score scheme, the most important question is how and when re-calculation should happen. Too often and too broadly and you incur unnecessary computing costs. Too infrequently and you risk the scores becoming stale.
The compromise I reached derives from the three ways articles are ranked in my system:
- The AI’s rating of an article plays the most important role (60%)
- How recently the article was published (20%)
- How similar the article is to the 10 articles the user most recently read (20%)
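Taken together, these weights imply a simple blended score. A minimal sketch (the function name and the assumption that each signal is pre-normalized to [0, 1] are mine):

```python
def composite_score(ai_rating, recency, similarity):
    """Blend the three ranking signals using the 60/20/20 weights.

    Each input is assumed to be pre-normalized to [0, 1]:
      ai_rating  - the model's predicted rating (60%)
      recency    - freshness, decaying from 1 toward 0 (20%)
      similarity - closeness to the user's 10 most recent reads (20%)
    """
    return 0.6 * ai_rating + 0.2 * recency + 0.2 * similarity
```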
These factors lent themselves to very different natural update cadences:
- Newly scraped articles have their AI ratings and composite scores computed at the time they enter the database
- AI ratings for the most recent and the previously highest-scoring articles are re-computed after model training updates
- Each article’s score is recomputed daily (capturing the change in article recency)
- The article similarity for unread articles is re-evaluated after a user reads 10 articles
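For the daily recency update, one plausible sketch is an exponential decay on article age (the half-life form and its 3-day constant are my assumptions, not the reader’s actual formula):

```python
from datetime import datetime, timedelta, timezone

HALF_LIFE_DAYS = 3.0  # assumed half-life; the real decay function may differ

def recency_score(published_at, now=None):
    """Map article age to (0, 1], halving every HALF_LIFE_DAYS."""
    now = now or datetime.now(timezone.utc)
    age_days = max((now - published_at).total_seconds() / 86400.0, 0.0)
    return 0.5 ** (age_days / HALF_LIFE_DAYS)

# e.g. an article published exactly 3 days ago scores 0.5
now = datetime.now(timezone.utc)
three_days_old = recency_score(now - timedelta(days=3), now=now)
```

Because this only depends on the publication timestamp, a daily batch job can recompute it for every article without touching the model.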
This required modifying the reader’s existing scraper and post-training processes to update the appropriate scores after scraping runs and model updates. It also meant tracking article reads on the users table (and modifying the /read endpoint to update these scores at the right intervals). Finally, it meant adding a recurring cleanUp function, set to run every 24 hours, to perform this update as well as others.
Next Steps
With some of these performance and architecture improvements in place, my priorities are now focused on finding ways to systematically improve the underlying algorithms as well as increase the platform’s usability as a true news tool. To that end, some of the top priorities for next steps in my mind include:
- Testing new backbone models — The core ranking algorithm relies on RoBERTa, a model released five years ago, before large language models entered common parlance. Keras Hub makes it incredibly easy to incorporate and fine-tune newer models like Meta’s Llama 2 & 3, OpenAI’s GPT-2, Microsoft’s Phi-3, and Google’s Gemma.
- Solving the “all good articles” problem — Because the point of the news reader is to surface content it considers good, users will not readily see lower quality content, nor will they see content the algorithm struggles to rank (i.e. new content very different from what the user has seen before). This makes it difficult to get the full range of data needed to help preserve the algorithm’s usefulness.
- Creating topic and author feeds — Given that many people think in terms of topics and authors of interest, expanding what I’ve already done with Sources to cover topic and author feeds sounds like a high-value next step.
I also endeavor to make more regular updates to the public GitHub repository (instead of aggregating many changes into two large updates, as I did here). This should make the updates more manageable and hopefully help anyone out there who’s interested in building a similar product.