Using the JAX version

As mentioned in the README, there is also a JAX version of the code. JAX is a high-performance Python framework that supports JIT compilation, GPU/TPU acceleration, automatic differentiation, and parallelization.

This version was implemented for use within NIFTy, but it can also be used to synthesise traces on its own. We generally do not recommend this, however, because:

  • for most use cases, the numpy version is fast enough,

  • installing jax on your machine can be cumbersome, especially if you want to use its GPU features,

  • it may run into memory issues for large origin showers (a fix is in progress!),

  • some features may not be directly available.

Nevertheless, this notebook shows how to run the JAX version and compare it with the numpy version.

import jax

# IMPORTANT! set the device and enable x64 for precision
# These flags are set immediately after importing jax, before jax.numpy is
# imported and before any JAX array is created further down.
jax.config.update("jax_enable_x64", True)  # enable 64-bit (double) precision types
jax.config.update("jax_platform_name", "cpu")  # run JAX on the CPU backend

import jax.numpy as jnp  # numpy API from jax
import numpy as np

import matplotlib.pyplot as plt
import smiet.numpy as smiet_np
import smiet.jax as smiet_jax  # will fail if JAX is not installed

from smiet import units

Paths and setup

import os

# Selections (to be put in main)
ANT_DEBUG = None  # placeholder; not referenced by the code in this section
# Passed as `freq_ar` to TemplateSynthesis. Presumably [f_min, f_max, f_ref] in
# MHz (the target traces are filtered to [30, 500] MHz) -- TODO confirm.
FREQ = [30, 500, 100]

# Variables
# Directory containing the SIM<id>.hdf5 files of the origin library. The
# default is the author's local path; set the SMIET_ORIGIN_LIBRARY environment
# variable to point the notebook at your own copy without editing it.
shower_path = os.environ.get(
    "SMIET_ORIGIN_LIBRARY",
    '/home/kwatanabe/Projects/radio-ift/radio_resources/template_synthesis/origin_library/base',
)
shower_origin = "000100"  # simulation id of the origin (template) shower
shower_target = "000111"  # simulation id of the target shower

Synthesising a single origin shower from a SlicedShower with JAX

# Load the origin and target showers (in that order) from the origin library.
origin_jax, target_jax = (
    smiet_jax.SlicedShower(f'{shower_path}/SIM{sim_id}.hdf5')
    for sim_id in (shower_origin, shower_target)
)

# Build the template from the origin shower, then map it onto the target to
# obtain the synthesised geomagnetic and charge-excess traces.
synthesis_jax = smiet_jax.TemplateSynthesis(freq_ar=FREQ)
synthesis_jax.make_template(origin_jax)
synth_geo_jax, synth_ce_jax = synthesis_jax.map_template(target_jax)

# Distance of each synthesised antenna to the core, from the first two
# components of its shower-plane position.
showerplane_positions = synthesis_jax.antenna_information["position_showerplane"]
synth_dcore_jax = np.linalg.norm(showerplane_positions[:, :2], axis=1)

Get the target CoREAS slice traces from the JAX version

# The traces returned here are already bandpass filtered to [30, 500] MHz;
# each component is collapsed by summing over its last axis.
geo_target, ce_target = (
    np.sum(component, axis=-1) for component in target_jax.get_traces_geoce()
)

Do the same with the NumPy version

# Same workflow as the JAX cell, using the NumPy implementation: load the
# origin and target showers in that order.
origin_np, target_np = (
    smiet_np.SlicedShower(f'{shower_path}/SIM{sim_id}.hdf5')
    for sim_id in (shower_origin, shower_target)
)

# Template from the origin shower, mapped onto the target shower.
synthesis_np = smiet_np.TemplateSynthesis(freq_ar=FREQ)
synthesis_np.make_template(origin_np)
synth_geo_np, synth_ce_np = synthesis_np.map_template(target_np)

Compare both implementations

# Per-antenna comparison of the CoREAS target traces with the JAX and NumPy
# template-synthesis results. Antenna-independent quantities are computed once.
field_unit = units.microvolt / units.m
dis_to_core = target_jax.antenna_array["dis_to_core"]
time_axes_jax = synthesis_jax.get_time_axis()
time_axes_np = synthesis_np.get_time_axis()

# Map NumPy antenna names to their index, since the antenna ordering may not
# match between the two implementations (first occurrence wins, like list.index).
ant_index_np = {}
for idx, name in enumerate(synthesis_np.get_antenna_names()):
    ant_index_np.setdefault(name, idx)

for ant_idx, ant_name in enumerate(synthesis_jax.get_antenna_names()):

    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    ax = ax.flatten()

    # find the antenna index comparing the distance to the core,
    # since the antenna positions may not exactly match
    ant_idx_target = np.argmin(np.abs(dis_to_core - synth_dcore_jax[ant_idx]))

    time_axis = (
        np.arange(geo_target.shape[1]) * target_jax.delta_t
        + target_jax.trace_times[ant_idx_target, 0]
    )

    ant_idx_np = ant_index_np[ant_name]

    # (axes, CoREAS traces, JAX traces, NumPy traces) for each component
    panels = (
        (ax[0], geo_target, synth_geo_jax, synth_geo_np),
        (ax[1], ce_target, synth_ce_jax, synth_ce_np),
    )
    for panel, target, jax_traces, np_traces in panels:
        panel.plot(
            time_axis,
            target[ant_idx_target, :] / field_unit,
            c='k', label='CoREAS',
            alpha=0.5,
        )
        panel.plot(
            time_axes_jax[ant_idx],
            jax_traces[ant_idx] / field_unit,
            '--',
            c='r', label='TS - JAX',
            alpha=0.8,
        )
        panel.plot(
            time_axes_np[ant_idx_np],
            np_traces[ant_idx_np] / field_unit,
            '--',
            c='b', label='TS - NumPy',
            alpha=0.8,
        )
        panel.set_xlabel('Time [ns]')
        panel.legend(fontsize=14)

    ax[0].set_ylabel(r'$E [\mu \mathrm{V}/\mathrm{m}]$', size=16)

    ax[0].set_title('Geomagnetic component')
    ax[1].set_title('Charge-excess component')

    # Express the core distance in metres, consistent with the fluence plot.
    d_core = dis_to_core[ant_idx_target] / units.m
    fig.suptitle(fr'Signals for antenna {ant_name}, $d_\mathrm{{core}} = {d_core:.2f}$ m' + '\n'
                 r'$X^{\mathrm{origin}}_{\mathrm{max}}$ = ' + f' {origin_jax.xmax:.1f} ' + r'$\mathrm{g}/\mathrm{cm}^2$' + ' - '
                 r' $X^{\mathrm{target}}_{\mathrm{max}}$ = ' + f' {target_jax.xmax:.1f} ' + r'$\mathrm{g}/\mathrm{cm}^2$'
                 '\n', y=1.05, size=17)

    # Display the figure and release it. Without closing, every figure stays
    # alive and matplotlib warns once more than 20 are open.
    plt.show()
    plt.close(fig)
/tmp/ipykernel_225867/4190965798.py:3: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (matplotlib.pyplot.figure) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam figure.max_open_warning). Consider using matplotlib.pyplot.close().
  fig, ax = plt.subplots(1, 2, figsize=(12, 6))
../../_images/using_jax_13_1.png ../../_images/using_jax_13_2.png ../../_images/using_jax_13_3.png ../../_images/using_jax_13_4.png ../../_images/using_jax_13_5.png ../../_images/using_jax_13_6.png ../../_images/using_jax_13_7.png ../../_images/using_jax_13_8.png ../../_images/using_jax_13_9.png ../../_images/using_jax_13_10.png ../../_images/using_jax_13_11.png ../../_images/using_jax_13_12.png ../../_images/using_jax_13_13.png ../../_images/using_jax_13_14.png ../../_images/using_jax_13_15.png ../../_images/using_jax_13_16.png ../../_images/using_jax_13_17.png ../../_images/using_jax_13_18.png ../../_images/using_jax_13_19.png ../../_images/using_jax_13_20.png ../../_images/using_jax_13_21.png ../../_images/using_jax_13_22.png ../../_images/using_jax_13_23.png ../../_images/using_jax_13_24.png

Comparing the fluence distribution

def get_fluence(traces, delta_t):
    """Return the time-integrated energy fluence of each antenna.

    The squared traces are summed over the polarisation and time axes,
    multiplied by the sampling interval, and converted to eV/m**2.

    Args:
        traces (ndarray): traces of shape (n_pol, n_antennas, n_times).
        delta_t (float): sampling interval in nanoseconds.

    Returns:
        ndarray: fluence per antenna, shape (n_antennas,).
    """
    # V**2/m**2 * s -> J/m**2 (numerically close to epsilon_0 * c in SI units)
    field_to_joule = 2.65441729e-3
    # J -> eV
    joule_to_ev = 6.24150934e18
    to_ev_per_m2 = field_to_joule * joule_to_ev

    power_sum = jnp.sum(traces**2, axis=(0, 2))
    return power_sum * delta_t / units.s * to_ev_per_m2
# Stack the geomagnetic and charge-excess traces along a new leading axis.
target_traces = np.array([geo_target, ce_target])
synth_traces_jax, synth_traces_np = (
    np.squeeze(np.array(components))
    for components in ([synth_geo_jax, synth_ce_jax], [synth_geo_np, synth_ce_np])
)
# get the distance to the core from the shower-plane (first two) coordinates
synth_dcore_jax, synth_dcore_np = (
    np.linalg.norm(synthesis.antenna_information["position_showerplane"][:, :2], axis=1)
    for synthesis in (synthesis_jax, synthesis_np)
)
fig, ax = plt.subplots(figsize=(7,4))

# One marker series per dataset, drawn in this order: CoREAS, JAX, NumPy.
# Each entry: (core distances, traces, sampling interval, marker, label, colour)
fluence_series = [
    (target_jax.antenna_array['dis_to_core'], target_traces, target_jax.delta_t, 'o', 'CoREAS', 'k'),
    (synth_dcore_jax, synth_traces_jax, synthesis_jax.delta_t, 's', 'JAX', 'b'),
    (synth_dcore_np, synth_traces_np, synthesis_np.delta_t, '^', 'NumPy', 'r'),
]
for series_dcore, series_traces, series_dt, series_marker, series_label, series_colour in fluence_series:
    ax.plot(
        series_dcore / units.m,
        get_fluence(series_traces, delta_t=series_dt / units.ns) / (units.eV / units.m**2),
        series_marker,
        label=series_label,
        alpha=0.8,
        color=series_colour,
        ms=6,
        lw=2,
    )

ax.set_xlabel('Distance to core [m]', size=16)
ax.set_ylabel(r'Fluence [$\mathrm{eV}/\mathrm{m}^2$]', size=16)
ax.tick_params(axis='both', which='major', labelsize=14)
ax.legend(fontsize=14)

# `top` is the same upper limit that the `ymax` keyword sets.
ax.set_ylim(top=10)

fluence_title = (
    r'$X^{\mathrm{origin}}_{\mathrm{max}}$ = ' + f' {origin_np.xmax:.1f} ' + r'$\mathrm{g}/\mathrm{cm}^2$' + ' - '
    r' $X^{\mathrm{target}}_{\mathrm{max}}$ = ' + f' {target_np.xmax:.1f} ' + r'$\mathrm{g}/\mathrm{cm}^2$' + '\n'
    r'$\theta^{\mathrm{origin}}$ = ' + f' {np.rad2deg(origin_np.zenith):.1f} ' + r'$^\circ$' + ' - '
    r'$\theta^{\mathrm{target}}$ = ' + f' {np.rad2deg(target_np.zenith):.1f} ' + r'$^\circ$'
)
ax.set_title(fluence_title, size=17)
Text(0.5, 1.0, '$X^{\mathrm{origin}}_{\mathrm{max}}$ =  611.9 $\mathrm{g}/\mathrm{cm}^2$ -  $X^{\mathrm{target}}_{\mathrm{max}}$ =  663.1 $\mathrm{g}/\mathrm{cm}^2$n$\theta^{\mathrm{origin}}$ =  3.0 $^\circ$ - $\theta^{\mathrm{target}}$ =  3.0 $^\circ$')
../../_images/using_jax_18_1.png