Hacking Observable Plot into Python widgets
I recently discovered Observable Plot. It's a great (!) plotting library, and it even has a nice Python port. But before I learned about that Python library, I was looking at building a Python port myself.
This is the implementation that I ended up with:
import srsly
import anywidget
import traitlets
from jinja2 import Template
from uuid import uuid4
def plotty(plot_logic, **kwargs):
    """
    Wrap a snippet of Observable Plot JavaScript in an anywidget.

    Parameters
    ----------
    plot_logic : str
        A JavaScript expression (for example a ``Plot.plot({...})`` call) whose
        value is appended to the widget's element. Every keyword argument is
        available to it as a ``const`` with the same name.
    **kwargs
        Values exposed to ``plot_logic``. Polars and pandas DataFrames are
        passed through ``serialize(..., renderer="jsdom")`` (base64-encoded
        Arrow IPC) and decoded in the browser with ``arrow.tableFromIPC``.
        Every other value is embedded as a JSON literal, so it must be
        JSON-serializable.

    Returns
    -------
    anywidget.AnyWidget
        An instance of a freshly created ``AnyWidget`` subclass whose ``_esm``
        is the rendered JavaScript module.

    Raises
    ------
    TypeError
        If a non-DataFrame keyword argument is not JSON-serializable.
    """
    import json

    dataframes = {}
    other = {}
    for key, value in kwargs.items():
        if isinstance(value, (pl.DataFrame, pd.DataFrame)):
            dataframes[key] = serialize(value, renderer="jsdom")
        else:
            # json.dumps produces a valid JS literal: strings are quoted *and*
            # escaped, True/False/None become true/false/null, and lists/dicts
            # become array/object literals.
            other[key] = json.dumps(value)

    # The module below defines a `render` hook for anywidget. Plain values are
    # inlined as literals; DataFrames are inlined as base64 strings and decoded
    # into Arrow tables, which Plot can consume directly.
    template_str = """
import * as Plot from "https://cdn.jsdelivr.net/npm/@observablehq/plot@0.6/+esm";
import * as d3 from "https://cdn.jsdelivr.net/npm/d3@7/+esm";
import * as arrow from 'https://cdn.jsdelivr.net/npm/apache-arrow@latest/+esm';
function readPolarsDataFrame(base64Value) {
// Decode base64 to ArrayBuffer
const binaryString = atob(base64Value);
const bytes = new Uint8Array(binaryString.length);
for (let i = 0; i < binaryString.length; i++) {
bytes[i] = binaryString.charCodeAt(i);
}
const arrayBuffer = bytes.buffer;
// The correct way to create a table from a buffer
const table = arrow.tableFromIPC(arrayBuffer);
return table;
}
function render({ model, el }) {
{% for key, value in other.items() %}
const {{ key }} = {{ value }};
{% endfor %}
{% for key, value in dataframes.items() %}
const {{ key }} = readPolarsDataFrame("{{ value }}");
{% endfor %}
const plot = {{plot_logic}};
el.append(plot);
}
export default { render };
"""
    esm = Template(template_str).render(
        plot_logic=plot_logic.strip(),
        dataframes=dataframes,
        other=other,
    )

    # A new subclass per call, so each widget carries its own module source.
    class Widget(anywidget.AnyWidget):
        _esm = esm

    return Widget()
This implementation lets you write the same JS you would normally write for Observable Plot, pass it to a Python function, and let this hacky snippet take care of the rest. That means that you could write code that looks like this:
# Load the BLS metro unemployment data with Polars and parse the "date"
# column from strings into a proper Date dtype.
bls = pl.read_csv("bls-metro-unemployment.csv")
bls = bls.with_columns(pl.col("date").str.to_date().alias("date"))
# The plot itself is plain Observable Plot JavaScript; because `bls` is passed
# by keyword, the JS can refer to it under that same name.
plotty("""
Plot.plot({
y: {
grid: true,
label: "↑ Unemployment (%)"
},
marks: [
Plot.ruleY([0]),
Plot.lineY(bls, {x: "date", y: "unemployment", z: "division"})
]
})
""", bls=bls)
Here's what's kind of nice about this implementation:
- We can serialize the Polars DataFrame to a base64-encoded Arrow representation, which Observable Plot natively supports. So no hacky "dataframe to JSON" conversion is ever needed here.
- There's no translation needed between Python and JS. We just write the JS as we normally would.
- Because we wrap it all with an anywidget, you're also able to use it in notebooks directly.
I am happy that I did this exercise, and I really wanted to make sure that I do not forget about the implementation, but I have not tested all the edge cases. If you're keen to explore this tool in a notebook, this Python lib is probably your best bet.
I also took the `serialize` function from said library and applied it here. This is the code:
import base64
import io
from datetime import date
from typing import Any
import pandas as pd
import polars as pl
def serialize(data: Any, renderer: str) -> Any:
    """
    Turn a data object into something that can be shipped to the renderer.

    Polars and pandas DataFrames are converted to Arrow IPC bytes. For the
    ``"jsdom"`` renderer those bytes are additionally base64-encoded into an
    ASCII string. Any other object is returned unchanged.

    Parameters
    ----------
    data : Any
        Object to serialize.
    renderer : str
        Renderer type; only ``"jsdom"`` triggers base64 encoding.

    Returns
    -------
    Any
        Base64 ``str`` or raw ``bytes`` for DataFrames, otherwise ``data``.
    """
    if isinstance(data, pl.DataFrame):
        payload = pl_to_arrow(data)
    elif isinstance(data, pd.DataFrame):
        payload = pd_to_arrow(data)
    else:
        # Not a DataFrame: pass through untouched.
        return data

    if renderer != "jsdom":
        return payload
    return base64.standard_b64encode(payload).decode("ascii")
def pd_to_arrow(df: pd.DataFrame) -> bytes:
    """
    Convert a pandas DataFrame to Arrow IPC bytes.

    Columns whose first non-null value is a ``datetime.date`` (or subclass)
    are converted to datetime64 on a best-effort basis. All naive datetime64
    columns are then cast to millisecond precision so that Plot detects them
    as datetimes. The input DataFrame is not modified; the work is done on a
    copy.

    Parameters
    ----------
    df : pd.DataFrame
        pandas DataFrame to convert.

    Returns
    -------
    bytes
        Arrow IPC bytes, written with ``DataFrame.to_feather`` (uncompressed).
    """
    # Work on a copy so the caller's DataFrame is left untouched.
    df = df.copy()

    # Convert columns holding date objects to timestamps.
    for colname in df.columns:
        non_null = df[colname].dropna()
        # Positional access: after dropna() the label 0 may no longer exist,
        # and an all-null/empty column has no first value at all.
        if non_null.empty or not isinstance(non_null.iloc[0], date):
            continue
        try:
            df[colname] = pd.to_datetime(df[colname])
        except (ValueError, TypeError):
            # Mixed or unparseable column: deliberately leave it as-is.
            pass

    # Convert timestamps to millisecond units so that
    # Plot will detect them as datetimes
    datetime_columns = df.select_dtypes(include=["datetime64"]).columns
    if len(datetime_columns) > 0:
        df[datetime_columns] = df[datetime_columns].astype("datetime64[ms]")

    f = io.BytesIO()
    df.to_feather(f, compression="uncompressed")
    return f.getvalue()
def pl_to_arrow(df: pl.DataFrame) -> bytes:
    """
    Serialize a polars DataFrame into Arrow IPC bytes.

    Temporal columns are normalized first, so that Plot treats them as
    datetimes:

    - Datetime columns are cast to millisecond precision.
    - Date columns are promoted to millisecond Datetimes.

    The result is converted to pandas and written with ``to_feather``
    (uncompressed).

    Parameters
    ----------
    df : pl.DataFrame
        polars DataFrame to serialize.

    Returns
    -------
    bytes
        Arrow IPC bytes.
    """
    normalized = (
        df.with_columns(pl.col(pl.Datetime).cast(pl.Datetime("ms")))
        .with_columns(pl.col(pl.Date).cast(pl.Datetime("ms")))
    )
    with io.BytesIO() as buffer:
        normalized.to_pandas().to_feather(buffer, compression="uncompressed")
        return buffer.getvalue()