Using JAX with QuTiP Core Metrics and Entropy
This guide shows how to use JAX's grad and jit transformations with functions from QuTiP's qutip.core.metrics and qutip.entropy modules after switching QuTiP's backend to JAX.
Setting Up the JAX Backend
To enable JAX as QuTiP's backend, import qutip_jax and call qutip_jax.set_as_default(). This lets you use JAX's grad and jit with QuTiP functions.
Note
This feature is not yet in a released version of QuTiP; it is available only on the experimental dev-major development branch.
import qutip
import qutip_jax
# Use JAX as the backend
qutip_jax.set_as_default()
Using jax.jit with QuTiP
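jax.jit traces a Python function and compiles it with XLA. The first call pays the compilation cost; later calls with inputs of the same shape and dtype reuse the compiled code. Here is how to use jax.jit with functions from qutip.core.metrics and qutip.entropy: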
### Example with fidelity from qutip.core.metrics
import jax
from qutip import basis
from qutip.core.metrics import fidelity
import qutip_jax
# Use JAX as the backend
qutip_jax.set_as_default()
# Define states
psi = basis(2, 0).to("jax")
phi = basis(2, 1).to("jax")
# JIT compile the fidelity function
jit_fidelity = jax.jit(fidelity)
# Compute fidelity using JIT compiled function
result = jit_fidelity(psi, phi)
print("Fidelity:", result)
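To see what is being compiled in the ket case: for two pure states, QuTiP's fidelity reduces to the modulus of the overlap, |<a|b>|. The sketch below writes that quantity directly in JAX on raw arrays. It is an illustration, not QuTiP's implementation, and the helper name pure_fidelity is hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: fidelity of two pure states given as 1-D arrays.
# jnp.vdot conjugates its first argument, so this computes |<a|b>|.
def pure_fidelity(a, b):
    return jnp.abs(jnp.vdot(a, b))

# jax.jit traces pure_fidelity on the first call and reuses the compiled
# version for later calls with the same input shapes and dtypes.
jit_pure_fidelity = jax.jit(pure_fidelity)

zero = jnp.array([1.0, 0.0], dtype=jnp.complex64)  # |0>
one = jnp.array([0.0, 1.0], dtype=jnp.complex64)   # |1>
plus = (zero + one) / jnp.sqrt(2.0)                # (|0> + |1>)/sqrt(2)

print(jit_pure_fidelity(zero, one))   # orthogonal states: 0.0
print(jit_pure_fidelity(zero, plus))  # approximately 1/sqrt(2)
```

As in the QuTiP example, basis states |0> and |1> are orthogonal, so their fidelity is zero.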
### Example with entropy_vn from qutip.entropy
import jax
from qutip import basis, ket2dm
from qutip.entropy import entropy_vn
import qutip_jax
# Use JAX as the backend
qutip_jax.set_as_default()
# Define a density matrix from the pure state |0>
psi = basis(2, 0)
rho = ket2dm(psi).to("jax")
# JIT compile the entropy_vn function
jit_entropy_vn = jax.jit(entropy_vn)
# Compute von Neumann entropy using JIT compiled function
result = jit_entropy_vn(rho)
print("Von Neumann Entropy:", result)
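The quantity being compiled here is S(ρ) = -Σ λ_i ln λ_i over the eigenvalues λ_i of ρ, with 0 ln 0 taken as 0. A detail worth knowing when writing such functions yourself: dropping zero eigenvalues with a boolean mask (for example vals[vals != 0]) yields an array whose shape depends on the data, which jax.jit cannot trace. The plain-JAX sketch below masks them with jnp.where instead. The names vn_entropy and eps are assumptions for illustration, not QuTiP's API.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: von Neumann entropy (natural log) of a density matrix.
def vn_entropy(rho, eps=1e-12):
    # eigvalsh assumes rho is Hermitian and returns real eigenvalues
    lam = jnp.linalg.eigvalsh(rho)
    keep = lam > eps
    # Substitute 1.0 for dropped eigenvalues before the log so no -inf/NaN
    # appears, then zero their contribution (0 * ln 0 is taken as 0).
    safe = jnp.where(keep, lam, 1.0)
    return -jnp.sum(jnp.where(keep, lam * jnp.log(safe), 0.0))

jit_vn_entropy = jax.jit(vn_entropy)

pure = jnp.array([[1.0, 0.0], [0.0, 0.0]], dtype=jnp.complex64)  # |0><0|
mixed = jnp.eye(2, dtype=jnp.complex64) / 2                       # I/2

print(jit_vn_entropy(pure))   # pure state: entropy 0
print(jit_vn_entropy(mixed))  # maximally mixed qubit: ln 2
```

The output shape is fixed (a scalar) regardless of how many eigenvalues are zero, which is what makes the function traceable.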
Using jax.grad with QuTiP
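jax.grad transforms a scalar-valued function into a function that returns its gradient with respect to the arguments selected by argnums (the first argument by default). Here is how to use jax.grad with functions from qutip.core.metrics: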
### Example with fidelity from qutip.core.metrics
jax.grad only accepts functions that return a real scalar, which fidelity does. Note that gradients of fidelity are not supported for oper (density-matrix) inputs; use ket or bra states.
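The scalar-output requirement can be checked with plain JAX. In the sketch below, the function names vector_valued and scalar_valued are hypothetical; jax.grad raises a TypeError for the vector-valued one and works once the output is reduced to a single real number.

```python
import jax
import jax.numpy as jnp

# Returns a length-2 vector, so it is not a valid target for jax.grad.
def vector_valued(x):
    return jnp.stack([x, 2.0 * x])

# Reducing the output to one real number makes it differentiable with jax.grad.
def scalar_valued(x):
    return jnp.sum(vector_valued(x))

try:
    jax.grad(vector_valued)(1.0)
    grad_rejected = False
except TypeError as err:
    grad_rejected = True
    print("jax.grad refused:", err)

print(jax.grad(scalar_valued)(1.0))  # d/dx (x + 2x) = 3.0
```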
#### Gradient of fidelity for Ket/Bra States
import jax
from qutip import basis, fidelity
import qutip_jax
# Use JAX as the backend
qutip_jax.set_as_default()
# Define bra and ket states
bra_state = basis(2, 0).dag()
ket_state = basis(2, 0)
# Convert to JAX objects
bra_state_jax = bra_state.to("jax")
ket_state_jax = ket_state.to("jax")
# Define a fidelity function
def fidelity_jax(state1, state2):
    return fidelity(state1, state2)
# Compute the gradient of the fidelity function with respect to the first argument
grad_fidelity = jax.grad(fidelity_jax, argnums=0)
# Calculate the gradient
grad_result = grad_fidelity(bra_state_jax, ket_state_jax)
print("Gradient of Fidelity:", grad_result)
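In practice the gradient is usually taken with respect to real parameters that define a state. The plain-JAX sketch below (not QuTiP's code) uses a hypothetical one-parameter family |ψ(θ)> = cos(θ/2)|0> + sin(θ/2)|1>. Its fidelity with |0> is cos(θ/2) for 0 ≤ θ < π, so the gradient should equal -sin(θ/2)/2.

```python
import jax
import jax.numpy as jnp

# Hypothetical one-parameter family of qubit states (real amplitudes).
def ket_from_angle(theta):
    amps = jnp.stack([jnp.cos(theta / 2), jnp.sin(theta / 2)])
    return amps.astype(jnp.complex64)

def pure_fidelity(a, b):
    # |<a|b>| for pure states; jnp.vdot conjugates its first argument
    return jnp.abs(jnp.vdot(a, b))

target = jnp.array([1.0, 0.0], dtype=jnp.complex64)  # |0>

def fidelity_of_angle(theta):
    # Real scalar output, as jax.grad requires
    return pure_fidelity(ket_from_angle(theta), target)

grad_fidelity = jax.grad(fidelity_of_angle)

theta = jnp.pi / 2
print("F:", fidelity_of_angle(theta))       # cos(pi/4)
print("dF/dtheta:", grad_fidelity(theta))   # -sin(pi/4)/2
```

Intermediate values may be complex; JAX only requires the input being differentiated and the final output to be real.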
### Example with trace_dist from qutip.core.metrics
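Unlike fidelity, trace_dist supports oper (density-matrix) inputs in gradient computations. In this example the first argument is a random density matrix and the second is a ket.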
import jax
from qutip import basis, rand_dm
from qutip.core.metrics import trace_dist
import qutip_jax
# Use JAX as the backend
qutip_jax.set_as_default()
# Define an operator state
oper_state = rand_dm(2)
ket_state = basis(2, 0)
# Convert to JAX object
oper_state_jax = oper_state.to("jax")
ket_state_jax = ket_state.to("jax")
# Define a trace distance function
def trace_dist_jax(state1, state2):
    return trace_dist(state1, state2)
# Compute the gradient of the trace distance function with respect to the first argument
grad_trace_dist = jax.grad(trace_dist_jax, argnums=0)
# Calculate the gradient
grad_result = grad_trace_dist(oper_state_jax, ket_state_jax)
print("Gradient of Trace Distance:", grad_result)
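For operators, the trace distance is D(ρ, σ) = ½ Tr|ρ - σ| = ½ Σ_i |λ_i|, where λ_i are the eigenvalues of the Hermitian matrix ρ - σ. The plain-JAX sketch below (an illustration, not QuTiP's implementation) differentiates it for a hypothetical family ρ(p) = diag(p, 1 - p) against σ = |0><0|. Here D = 1 - p, so the gradient should be -1. The sketch also shows that jax.grad and jax.jit compose.

```python
import jax
import jax.numpy as jnp

def trace_distance(rho, sigma):
    # For a Hermitian difference, the absolute eigenvalues are its singular
    # values, so this is half the trace norm of rho - sigma.
    lam = jnp.linalg.eigvalsh(rho - sigma)
    return 0.5 * jnp.sum(jnp.abs(lam))

# Hypothetical one-parameter family of diagonal qubit density matrices.
def rho_of_p(p):
    return jnp.diag(jnp.stack([p, 1.0 - p]))

sigma = jnp.diag(jnp.array([1.0, 0.0]))  # |0><0|

def distance_of_p(p):
    return trace_distance(rho_of_p(p), sigma)

# Gradient with respect to p, then compiled with jit.
grad_distance = jax.jit(jax.grad(distance_of_p))

p = 0.3
print("D:", distance_of_p(p))       # 1 - p
print("dD/dp:", grad_distance(p))   # -1
```

The eigenvalue gradient is well defined here because the two eigenvalues of ρ - σ are distinct; degenerate eigenvalues would need more care.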
Changing Back to Default Backend
To switch back to QuTiP's default (non-JAX) backend, call:
qutip_jax.set_as_default(revert=True)