v0.4.0 - compilation and QoL

This release is mostly about efficiency improvements, in both
compilation time and execution time.

- IMPORTANT !!! ProjectedGradients now runs on GPU. This was made
  possible by a whole suite of improvements to the value function
  calculation, compilation graph simplification (no more decisions on
  the hot path about what to execute), and some use of
  jax.lax.stop_gradient to better indicate what we do and don't need
  to differentiate through.
- Both versions of ProjectedGradients (with or without early stopping)
  benefit from this. In particular, the early-stopping version now also
  stops at a maximum number of iterations - this was not possible
  before since I didn't know about jnp.logical_and (JAX and Python's
  `and` do not get along well).
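
For reference, here is a minimal sketch of the early-stopping pattern
described above: a tolerance check combined with an iteration cap via
jnp.logical_and inside jax.lax.while_loop. It is not the actual
ProjectedGradients code; the function name, the plain gradient step (no
projection), and the parameters `step_size`, `tol` and `max_iters` are
all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def descend_with_early_stopping(grad_fn, x0, step_size=0.1, tol=1e-6, max_iters=1000):
    """Hypothetical gradient loop: stop on small gradient norm OR at max_iters."""

    def cond(state):
        _, grad_norm, i = state
        # Python's `and` would call bool() on traced values and fail during
        # tracing; jnp.logical_and keeps the condition as a traced boolean.
        return jnp.logical_and(grad_norm > tol, i < max_iters)

    def body(state):
        x, _, i = state
        g = grad_fn(x)
        return x - step_size * g, jnp.linalg.norm(g), i + 1

    # The carry must keep the same dtypes on every iteration, so the initial
    # norm and counter are given explicit dtypes.
    init = (x0, jnp.asarray(jnp.inf, dtype=x0.dtype), jnp.asarray(0, dtype=jnp.int32))
    return jax.lax.while_loop(cond, body, init)


# Toy objective (an assumption for the example): f(x) = 0.5 * ||x||^2.
grad_fn = jax.grad(lambda x: 0.5 * jnp.sum(x ** 2))
x0 = jnp.array([1.0, -2.0])

# Hits the iteration cap long before the tolerance is reached.
_, _, iters_capped = descend_with_early_stopping(grad_fn, x0, max_iters=5)

# Converges on the tolerance well before the cap.
_, norm_tol, iters_tol = descend_with_early_stopping(grad_fn, x0, tol=1e-3)
```

The loop exits as soon as either condition fails, so the same code path
covers both "converged" and "timed out".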

TODO:

- Standardize notation: move away from projected_gradients towards
  matg_gdm.
- Clean up the return-types of both versions - they are identical for
  now, but move to a dictionary for readability.
- Testing via pytest. The end goal is identical execution across both
  versions. For now there are minor differences in the iteration at
  which they stop, but nothing too serious.
- Documentation and CI/CD.