Using JAX with QuTiP Core Metrics and Entropy

This guide demonstrates how to utilize JAX’s grad and jit functionalities with QuTiP’s core.metrics and entropy modules by changing the backend to JAX.

Setting Up the JAX Backend

To enable JAX as the backend for QuTiP, you need to set the backend to jax using the use_jax_backend function. This allows you to use JAX’s grad and jit with QuTiP functions.

Note

This feature is not available in a released version of QuTiP. It is only available on an experimental development branch called dev-major in QuTiP.

import qutip
import qutip_jax

# Use JAX as the backend
qutip_jax.set_as_default()

Using jax.jit with QuTiP

jax.jit compiles your function to make it run faster. Here’s how to use jax.jit with functions from qutip.core.metrics and qutip.entropy:

### Example with fidelity from qutip.core.metrics

import jax
from qutip import basis
from qutip.core.metrics import fidelity
import qutip_jax

# Use JAX as the backend
qutip_jax.set_as_default()

# Define states
psi = basis(2, 0).to("jax")
phi = basis(2, 1).to("jax")

# JIT compile the fidelity function
jit_fidelity = jax.jit(fidelity)

# Compute fidelity using JIT compiled function
result = jit_fidelity(psi, phi)
print("Fidelity:", result)

### Example with entropy_vn from qutip.entropy

from qutip import ket2dm
from qutip.entropy import entropy_vn
import qutip_jax

# Use JAX as the backend
qutip_jax.set_as_default()

# Define a density matrix
rho = ket2dm(psi).to("jax")

# JIT compile the entropy_vn function
jit_entropy_vn = jax.jit(entropy_vn)

# Compute von Neumann entropy using JIT compiled function
result = jit_entropy_vn(rho)
print("Von Neumann Entropy:", result)

Using jax.grad with QuTiP

jax.grad computes the gradient of a function. Here’s how to use jax.grad with functions from qutip.core.metrics and qutip.entropy:

### Example with fidelity from qutip.core.metrics

To compute the gradient, you need a function that returns a scalar. Note that jax.grad for fidelity does not support oper states.

#### Gradient of fidelity for Ket/Bra States

import jax
from qutip import basis, fidelity
import qutip_jax

# Use JAX as the backend
qutip_jax.set_as_default()

# Define bra and ket states
bra_state = basis(2, 0).dag()
ket_state = basis(2, 0)

# Convert to JAX objects
bra_state_jax = bra_state.to("jax")
ket_state_jax = ket_state.to("jax")

# Define a fidelity function
def fidelity_jax(state1, state2):
    return fidelity(state1, state2)

# Compute the gradient of the fidelity function with respect to the first argument
grad_fidelity = jax.grad(fidelity_jax, argnums=0)

# Calculate the gradient
grad_result = grad_fidelity(bra_state_jax, ket_state_jax)
print("Gradient of Fidelity:", grad_result)

### Example with trace_dist from qutip.core.metrics

The trace_dist function supports oper states for gradient computation.

from qutip import rand_dm
from qutip.core.metrics import trace_dist
import qutip_jax

# Use JAX as the backend
qutip_jax.set_as_default()

# Define an operator state
oper_state = rand_dm(2)
ket_state = basis(2, 0)

# Convert to JAX object
oper_state_jax = oper_state.to("jax")
ket_state_jax = ket_state.to("jax")

# Define a trace distance function
def trace_dist_jax(state1, state2):
    return trace_dist(state1, state2)

# Compute the gradient of the trace distance function with respect to the first argument
grad_trace_dist = jax.grad(trace_dist_jax, argnums=0)

# Calculate the gradient
grad_result = grad_trace_dist(oper_state_jax, ket_state_jax)
print("Gradient of Trace Distance:", grad_result)

Changing Back to Default Backend

If you want to switch back to the default backend (NumPy), use the following:

qutip_jax.set_as_default(revert = True)