Question
Dictionary indexing with NumPy/JAX
I'm writing an interpolation routine and have a dictionary that stores the function values at the fitting points. Ideally, the dictionary keys would be NumPy arrays of the 2D fitting-point coordinates, `np.array([x, y])`, but since NumPy arrays aren't hashable, these are converted to tuples for the keys.
```python
# fit_pt_coords: (n_pts, n_dims) array
# fn_vals: (n_pts,) array
def fit(fit_pt_coords, fn_vals):
    pt_map = {tuple(k): v for k, v in zip(fit_pt_coords, fn_vals)}
    ...
```
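A minimal sketch of the hashing constraint described above, run eagerly with NumPy and no jit. The four fitting points and their values are made-up example data, not from the original code. An array used directly as a key raises `TypeError`, while a tuple-keyed map built the same way as in `fit` can be queried with any concrete point.

```python
import numpy as np

# Made-up example data: four 2D fitting points and their function values
fit_pt_coords = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
fn_vals = np.array([10.0, 20.0, 30.0, 40.0])

# An ndarray can't be a dict key because it is unhashable
try:
    {fit_pt_coords[0]: fn_vals[0]}
    array_key_ok = True
except TypeError:
    array_key_ok = False

# A tuple of its elements is hashable, so build the map with tuple keys
pt_map = {tuple(k): v for k, v in zip(fit_pt_coords, fn_vals)}

# Outside jit, a concrete query point converts to a tuple and the lookup works
query = np.array([1.0, 0.0])
print(pt_map[tuple(query)])  # 30.0
```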
Later in the code I need to get the function values using coordinates as keys in order to do the interpolation fitting. I'd like this to be inside `@jax.jit`-decorated code, but there the coordinate values are of type `<class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>`, which can't be converted to a tuple usable as a key. I've tried other things, like building the dictionary key as `(x + y, x - y)`, but again this requires concrete values, and calling `.item()` results in a `ConcretizationTypeError`.
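A small sketch reproducing the `.item()` failure, assuming a recent JAX release; `make_key` is a hypothetical helper that builds the `(x + y, x - y)` key. Called eagerly it returns Python floats, but under `jax.jit` the input is an abstract tracer that carries only a shape and dtype, so asking it for a concrete scalar fails while the function is being traced.

```python
import jax
import jax.numpy as jnp


def make_key(p):
    # Hypothetical helper: the (x + y, x - y) key, turned into Python floats
    x, y = p[0], p[1]
    return ((x + y).item(), (x - y).item())


p = jnp.array([1.0, 2.0])

# Eagerly (no jit) p holds concrete values, so .item() works
eager_key = make_key(p)
print(eager_key)  # (3.0, -1.0)

# Under jit, p is a tracer with no value yet, so .item() raises during tracing
try:
    jax.jit(make_key)(p)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True
print(jit_failed)  # True
```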
At the moment I've applied `@jax.jit` to all of the code I can and left this part un-jitted, but it would be great if I could jit it as well. Are there better ways to do the dictionary indexing (or better JAX-compatible data structures) that would allow all of the code to be jitted? I'm new to JAX and still understanding how it works, so I'm sure there must be better ways of doing it...
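One possible jit-compatible alternative, offered as a sketch rather than a definitive answer: keep the fitting data as two aligned arrays and replace the dictionary lookup with a vectorised comparison. `lookup` and `batched_lookup` are hypothetical names, the data is made up, and the sketch assumes query points exactly equal stored fitting points (coordinates produced by arithmetic may need `jnp.isclose` instead of `==`).

```python
import jax
import jax.numpy as jnp

# Made-up example data: row i of coords pairs with vals[i]
coords = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
vals = jnp.array([10.0, 20.0, 30.0, 40.0])


@jax.jit
def lookup(coords, vals, query):
    """Hypothetical jit-friendly stand-in for pt_map[tuple(query)]."""
    # True for every row whose coordinates all equal the query point
    matches = jnp.all(coords == query, axis=-1)
    # Position of the first matching row (argmax gives 0 if none match)
    idx = jnp.argmax(matches)
    # KeyError isn't possible in traced code, so signal a miss with NaN
    return jnp.where(jnp.any(matches), vals[idx], jnp.nan)


# Look up many points at once by mapping over the first axis of `queries`
batched_lookup = jax.vmap(lookup, in_axes=(None, None, 0))

queries = jnp.array([[1.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
print(batched_lookup(coords, vals, queries))  # [30. 40. nan]
```

Each lookup compares the query against every stored point, so it does O(n_pts) work instead of a dictionary's average O(1), but it uses only fixed-shape array operations, which is what `jax.jit` and `jax.vmap` need.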