import os
import gc
import uuid
import glob
import matplotlib
import sunpy.map
import numpy as np
from tqdm import tqdm
import astropy.units as u
from astropy.wcs import WCS
from astropy.time import Time
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from scipy.ndimage import uniform_filter
from astropy.coordinates import SkyCoord
from datetime import datetime, timedelta
from scipy.interpolate import RegularGridInterpolator
import sunpy.visualization.colormaps.cm as sunpy_cm
from matplotlib.ticker import FuncFormatter, FixedLocator, AutoMinorLocator
from sunpy.coordinates import SphericalScreen, Helioprojective, propagate_with_solar_surface
from scipy.optimize import curve_fit
import matplotlib.patheffects as pe
matplotlib.use('TkAgg')
matplotlib.rcParams.update({'font.size': 18, 'font.weight': 'bold'})

# ---- Run configuration -------------------------------------------------------
data_dir = '/home/JSOC_20250404_004534'   # directory holding the FITS files
wavelength = 304                          # selects the FITS files and the sdoaia colormap
chunk_size = 100                          # files handled per outer (progress-bar) step
npz_file = f'aia{wavelength}_maps.npz'    # cache of the processed map cube
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# Timestamp + random suffix keeps successive runs from overwriting each other's output.
output_npz_file = f'aia{wavelength}_timeslice_{timestamp}_{uuid.uuid4().hex[:8]}.npz'
click_counter = {'count': 0}

files = sorted(glob.glob(os.path.join(data_dir, f'aia.lev1_euv_12s.*.{wavelength}.image.fits')))
if not files:
    raise FileNotFoundError(f"No AIA {wavelength} FITS files found in {data_dir!r}.")
if len(files) <= 182:
    # Frame 182 is the preferred initial display frame, but the initial-map
    # setup falls back to frame 0 when it does not exist, so warn instead of
    # aborting the whole run.
    print(f"Warning: File index 182 out of bounds. Only {len(files)} images found; "
          "falling back to the first image for the initial display.")
print(f"Number of images: {len(files)}")
cmap = getattr(sunpy_cm, f'sdoaia{wavelength}')

# ---- Decide whether the cached NPZ cube can be reused -----------------------
# The cache is rebuilt when it is missing, unreadable, lacks a required key,
# holds no frames, or stores WCS headers in the old string format.
regenerate_npz = True
if os.path.exists(npz_file):
    try:
        with np.load(npz_file, allow_pickle=True) as npz:
            if 'map_times' not in npz or 'map_3d' not in npz or 'map_wcs' not in npz:
                print("Warning: 'map_times', 'map_3d', or 'map_wcs' missing in npz file. Regenerating.")
            elif len(npz['map_times']) == 0:
                # An empty cube is never useful; rebuild it from the FITS files.
                print("Warning: npz file contains no maps. Regenerating.")
            elif isinstance(npz['map_wcs'][0], str):
                print("Warning: 'map_wcs' in npz file is in old string format. Regenerating.")
            else:
                regenerate_npz = False
    except Exception as e:
        # Any read failure (corrupt archive, unpicklable object, ...) means rebuild.
        print(f"Error reading npz file: {e}. Regenerating.")
if regenerate_npz:
    print("Regenerating NPZ file from FITS data...")
    # Reference grid: the first file cropped to the region of interest.
    # Every frame is reprojected onto this grid so pixel coordinates agree.
    ref_map = sunpy.map.Map(files[0])
    ref_observer = ref_map.observer_coordinate
    if ref_observer is None:
        raise ValueError("Reference map observer coordinate missing.")
    ref_hpc_frame = Helioprojective(obstime=ref_map.date, observer=ref_observer)
    with SphericalScreen(ref_observer):
        ref_map = ref_map.submap(
            SkyCoord(700 * u.arcsec, -300 * u.arcsec, frame=ref_hpc_frame),
            top_right=SkyCoord(1190 * u.arcsec, 200 * u.arcsec, frame=ref_hpc_frame))
    ref_shape = np.array(ref_map.data.shape) * u.pix
    ref_wcs = ref_map.wcs
    all_maps = []
    for i in tqdm(range(0, len(files), chunk_size)):
        chunk_files = files[i:i + chunk_size]
        for file in chunk_files:
            try:
                smap = sunpy.map.Map(file)
                smap_hpc_frame = Helioprojective(obstime=smap.date, observer=smap.observer_coordinate)
                # propagate_with_solar_surface only matters when transforming
                # between different obstimes, i.e. during reproject_to, so the
                # reprojection must run inside the context.
                # NOTE(review): confirm SphericalScreen + solar-surface
                # propagation is the intended combination for off-limb pixels.
                with SphericalScreen(smap.observer_coordinate), propagate_with_solar_surface():
                    smap_sub = smap.submap(
                        SkyCoord(700 * u.arcsec, -300 * u.arcsec, frame=smap_hpc_frame),
                        top_right=SkyCoord(1190 * u.arcsec, 200 * u.arcsec, frame=smap_hpc_frame))
                    repro_map = smap_sub.reproject_to(ref_wcs)
                repro_map = repro_map.resample(ref_shape, method='linear')
                frame_data = repro_map.data.astype(np.float32)  # astype returns a copy
                # Reprojection leaves non-finite values where no input data
                # maps; replace them with a number (0, floored to 1e-6 below)
                # so the percentile clip and log display stay well defined.
                frame_data[~np.isfinite(frame_data)] = 0.0
                frame_data[frame_data <= 0] = 1e-6
                # Clip the brightest 0.5% to limit the dynamic range.
                max_val = np.percentile(frame_data[frame_data > 0], 99.5)
                frame_data[frame_data > max_val] = max_val
                print(f"File {file}: Reprojected map data min={np.nanmin(frame_data)}, max={np.nanmax(frame_data)}, NaN count={np.sum(np.isnan(frame_data))}")
                all_maps.append(sunpy.map.Map(frame_data, repro_map.meta))
            except Exception as e:
                # Best effort: one bad file must not abort the whole cube.
                print(f"Error processing file {file}: {e}")
                continue
    if not all_maps:
        # Never cache an empty cube: it would be reused silently on the next run.
        raise RuntimeError("No FITS files could be processed; the NPZ cache was not written.")
    map_3d_arr = np.array([m.data for m in all_maps])
    map_wcs_headers_dict = np.array([dict(m.wcs.to_header()) for m in all_maps], dtype=object)
    map_times_dt = np.array([m.date.datetime for m in all_maps])
    np.savez_compressed(npz_file, map_3d=map_3d_arr, map_wcs=map_wcs_headers_dict, map_times=map_times_dt)
    del map_3d_arr, map_wcs_headers_dict, map_times_dt
    gc.collect()
else:
    print(f"Loading data from existing {npz_file}...")
    # Arrays inside an .npz archive are read into memory on access
    # (mmap_mode is not applied to archives), so the file can be closed as
    # soon as they have been extracted. allow_pickle is needed for the
    # object arrays (WCS header dicts, datetimes) written by this script.
    with np.load(npz_file, allow_pickle=True) as npz:
        loaded_map_3d = npz['map_3d']
        loaded_map_wcs_headers = npz['map_wcs']
        try:
            loaded_map_times_dt = npz['map_times']
        except KeyError:
            print("Warning: 'map_times' not found in npz file. Generating times with 12-second cadence.")
            loaded_map_times_dt = np.array([datetime(2014, 2, 19, 9, 45, 6) + timedelta(seconds=12 * i) for i in range(len(loaded_map_3d))])
    all_maps = []
    for i in range(len(loaded_map_3d)):
        try:
            current_wcs = WCS(loaded_map_wcs_headers[i])
            current_map = sunpy.map.Map(loaded_map_3d[i], current_wcs.to_header())
            current_map.meta['date_obs'] = loaded_map_times_dt[i].isoformat()
            all_maps.append(current_map)
        except Exception as e:
            print(f"Error reconstructing map {i} from NPZ: {e}")
            continue
    del loaded_map_3d, loaded_map_wcs_headers, loaded_map_times_dt
    gc.collect()
if len(all_maps) > 0:
    print(f"First frame time: {all_maps[0].date.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Last frame time: {all_maps[-1].date.strftime('%Y-%m-%d %H:%M:%S')}")
def calculate_cumulative_distances(points, smap):
    """Return the running path length, in arcsec, along a polyline of pixel points.

    Each vertex ``(x, y)`` is converted with ``smap.pixel_to_world`` and its
    ``Tx``/``Ty`` coordinates are taken in arcsec; consecutive vertices are
    joined by straight (Euclidean) segments in the Tx-Ty plane.

    Returns an array with one entry per point, starting at 0. When fewer
    than two points are given, an error is printed and an empty array is
    returned.
    """
    if len(points) < 2:
        print("Error: Insufficient points for distance calculation.")
        return np.array([])
    vertices = np.array(points)
    sky = smap.pixel_to_world(vertices[:, 0] * u.pixel, vertices[:, 1] * u.pixel)
    tx = sky.Tx.to(u.arcsec).value
    ty = sky.Ty.to(u.arcsec).value
    step = np.sqrt(np.diff(tx) ** 2 + np.diff(ty) ** 2)
    running_total = np.cumsum(step)
    return np.concatenate(([0], running_total))
def interpolate_along_path(smap, path_x_pixels, path_y_pixels):
    """Bilinearly sample ``smap.data`` at fractional pixel positions.

    Positions are first clamped into the image (x to ``[0, nx-1]``, y to
    ``[0, ny-1]``), so samples beyond an edge take that edge's value; the
    interpolator's ``fill_value`` is therefore never reached. A warning is
    printed when every sampled value is exactly zero.

    Returns an array of sampled values, one per input position.
    """
    image = smap.data
    n_rows, n_cols = image.shape
    xs = np.clip(path_x_pixels, 0, n_cols - 1)
    ys = np.clip(path_y_pixels, 0, n_rows - 1)
    # Grid axes are (row, column) = (y, x), matching NumPy's data layout.
    sampler = RegularGridInterpolator(
        (np.arange(n_rows), np.arange(n_cols)),
        image,
        method='linear',
        bounds_error=False,
        fill_value=0,
    )
    samples = sampler((ys, xs))
    if not np.any(samples != 0):
        print("Warning: All interpolated values are zero. Check slit path or map data.")
    return samples
def generate_time_distance_data(maps, cut_points_pixel, slit_width_arcsec=None, boxcar_filter_size=None):
    """Sample every map along a polyline slit to build a time-distance array.

    The clicked vertices are resampled to about two samples per pixel of path
    length. For each map, the image is optionally boxcar-smoothed and then
    bilinearly sampled along the path. It is also sampled along parallel
    copies of the path shifted by ``-k..+k`` pixels perpendicular to it; the
    copies are averaged with ``np.nanmean``.

    Parameters
    ----------
    maps : sequence of sunpy maps
        Frames sharing one pixel grid. ``maps[0]`` supplies the WCS used for
        the pixel scale and for the distance axis.
    cut_points_pixel : sequence of (x, y)
        Slit vertices in pixel coordinates of that grid. Consecutive
        duplicate vertices are ignored.
    slit_width_arcsec : float, optional
        Converted to ``k = round(width / pixel_scale)`` pixels and used as the
        perpendicular offset on EACH side, so ``2k + 1`` parallel lines are
        averaged. ``None`` samples only the centre line.
    boxcar_filter_size : int, optional
        Size of ``scipy.ndimage.uniform_filter`` (zero padding at the edges)
        applied before sampling; ignored unless > 1.

    Returns
    -------
    stackplot_data : ndarray, shape (n_samples, n_maps)
    cumulative_distances : ndarray, shape (n_samples,)
        Distance along the path in arcsec (``calculate_cumulative_distances``).
    map_times : list of datetime
    On invalid input a message is printed and
    ``(np.array([]), np.array([]), [])`` is returned.
    """
    if not maps or len(cut_points_pixel) < 2:
        print("Error: No maps or insufficient cut points for time-distance data.")
        return np.array([]), np.array([]), []
    x_pixels = np.array([p[0] for p in cut_points_pixel], dtype=float)
    y_pixels = np.array([p[1] for p in cut_points_pixel], dtype=float)
    # Drop consecutive duplicate vertices (e.g. double clicks). Zero-length
    # segments add nothing to the path and would make the arc-length
    # parameterisation below non-strictly increasing (or 0/0 if all coincide).
    raw_segments = np.sqrt(np.diff(x_pixels) ** 2 + np.diff(y_pixels) ** 2)
    keep = np.concatenate(([True], raw_segments > 0))
    x_pixels, y_pixels = x_pixels[keep], y_pixels[keep]
    if len(x_pixels) < 2:
        print("Error: Cut points do not span a non-zero distance; cannot build a slit.")
        return np.array([]), np.array([]), []
    segment_lengths_pix = np.sqrt(np.diff(x_pixels) ** 2 + np.diff(y_pixels) ** 2)
    cumulative_lengths_pix = np.concatenate(([0.0], np.cumsum(segment_lengths_pix)))
    total_length_pix = cumulative_lengths_pix[-1]
    # ~2 samples per pixel of path length; at least 2 so a direction exists.
    num_points_interp = max(int(total_length_pix * 2) + 1, 2)
    t = np.linspace(0.0, 1.0, num_points_interp)
    normalized_lengths = cumulative_lengths_pix / total_length_pix
    # Piecewise-linear interpolation of x AND y against normalised arc length.
    dense_x_pix = np.interp(t, normalized_lengths, x_pixels)
    dense_y_pix = np.interp(t, normalized_lengths, y_pixels)

    first_map = maps[0]
    wcs = first_map.wcs
    cumulative_distances = calculate_cumulative_distances(list(zip(dense_x_pix, dense_y_pix)), first_map)
    map_times = [m.date.datetime for m in maps]
    stackplot_data = np.zeros((num_points_interp, len(maps)))

    # Pixel scale (arcsec/pixel) from the WCS, averaged over both axes.
    origin = wcs.pixel_to_world(0 * u.pixel, 0 * u.pixel)
    pixel_scale_x = np.abs(wcs.pixel_to_world(1 * u.pixel, 0 * u.pixel).Tx.to_value(u.arcsec) - origin.Tx.to_value(u.arcsec))
    pixel_scale_y = np.abs(wcs.pixel_to_world(0 * u.pixel, 1 * u.pixel).Ty.to_value(u.arcsec) - origin.Ty.to_value(u.arcsec))
    avg_pixel_scale = (pixel_scale_x + pixel_scale_y) / 2.0
    half_width_pixels = 0
    if slit_width_arcsec is not None:
        # Clamp at 0: a negative width would give an empty offset range and
        # an all-NaN result from nanmean.
        half_width_pixels = max(0, int(np.round(slit_width_arcsec / avg_pixel_scale)))

    # Unit normal at every sample: central-difference tangent (one-sided at
    # the two ends), rotated so that (n_x, n_y) = (t_y, -t_x).
    tangent_x = np.gradient(dense_x_pix)
    tangent_y = np.gradient(dense_y_pix)
    tangent_len = np.sqrt(tangent_x ** 2 + tangent_y ** 2)
    unit_px = np.zeros_like(dense_x_pix)
    unit_py = np.zeros_like(dense_y_pix)
    nonzero = tangent_len > 0
    unit_px[nonzero] = tangent_y[nonzero] / tangent_len[nonzero]
    unit_py[nonzero] = -tangent_x[nonzero] / tangent_len[nonzero]

    offsets = range(-half_width_pixels, half_width_pixels + 1)
    for i, smap in enumerate(maps):
        map_data = smap.data.copy()
        if boxcar_filter_size is not None and boxcar_filter_size > 1:
            map_data = uniform_filter(map_data, size=boxcar_filter_size, mode='constant')
        temp_map = sunpy.map.Map(map_data, smap.meta)
        slit_intensities = [
            interpolate_along_path(temp_map, dense_x_pix + k * unit_px, dense_y_pix + k * unit_py)
            for k in offsets
        ]
        stackplot_data[:, i] = np.nanmean(np.array(slit_intensities), axis=0)
    if np.all(stackplot_data == 0):
        print("Warning: Stackplot data is all zeros. Check if slit path lies in valid data region.")
    elif np.any(np.isnan(stackplot_data)):
        print(f"Warning: Stackplot data contains {np.sum(np.isnan(stackplot_data))} NaN values.")
    return stackplot_data, cumulative_distances, map_times
def linear_function(x, m, c):
    """Evaluate the straight line ``m * x + c`` (element-wise for NumPy arrays).

    NOTE(review): presumably the model passed to ``scipy.optimize.curve_fit``
    for the speed fits; confirm at the call site.
    """
    slope_term = m * x
    return slope_term + c
# Frame shown in the left panel: index 182 when it exists, else the first file.
if len(files) > 182:
    aia_map_initial = sunpy.map.Map(files[182])
elif files:
    print("Warning: Not enough FITS files to use index 182. Using the first available map for initial display.")
    aia_map_initial = sunpy.map.Map(files[0])
else:
    raise ValueError("No FITS files found to create an initial map. Check 'data_dir'.")
observer = aia_map_initial.observer_coordinate
if observer is None:
    raise ValueError("Observer coordinate missing. Verify data_dir file metadata.")
# Crop to the same helioprojective box used for the map cube.
hpc_frame = Helioprojective(obstime=aia_map_initial.date, observer=observer)
bottom_left = SkyCoord(700 * u.arcsec, -300 * u.arcsec, frame=hpc_frame)
top_right = SkyCoord(1190 * u.arcsec, 200 * u.arcsec, frame=hpc_frame)
with SphericalScreen(observer):
    aia_map_initial = aia_map_initial.submap(bottom_left, top_right=top_right)
# Work in float32 (as the cube preprocessing does). With integer image data
# the 1e-6 floor below would truncate to 0 and break the log colour scale.
data_initial = aia_map_initial.data.astype(np.float32)  # astype returns a copy
data_initial[~np.isfinite(data_initial)] = 0.0  # keep percentiles finite
data_initial[data_initial <= 0] = 1e-6
max_val_initial = np.percentile(data_initial[data_initial > 0], 99.5)
data_initial[data_initial > max_val_initial] = max_val_initial
aia_map_initial = sunpy.map.Map(data_initial, aia_map_initial.meta)
# NOTE(review): this display map is not reprojected onto the cube's reference
# grid, yet points clicked on it are used as pixel coordinates of the cube
# maps; confirm the two pixel grids coincide.
data_nonzero_initial = aia_map_initial.data[aia_map_initial.data > 0]
vmin_initial = np.percentile(data_nonzero_initial, 1.5)
vmax_initial = np.percentile(data_nonzero_initial, 98)
if len(all_maps) <= 182:
    print(f"Warning: 'initial_map_index' (182) is out of bounds for the loaded maps ({len(all_maps)} maps). Setting to 0.")
    initial_plot_index = 0
else:
    initial_plot_index = 182
vmin = vmin_initial
vmax = vmax_initial
class Plotter:
def __init__(self, all_maps, initial_map_index=0):
    """Create the two-panel interactive figure.

    Parameters
    ----------
    all_maps : sequence of sunpy maps
        Frames used for the time-distance plot; must be non-empty.
    initial_map_index : int
        Index stored as ``current_map_index`` (used to clip clicked pixels).

    Raises
    ------
    ValueError
        If ``all_maps`` is empty.

    The left axes use the WCS of the module-level ``aia_map_initial``. The
    right axes hold the time-distance plot and start as a placeholder. Mouse
    clicks are routed through ``on_click``.
    """
    if not all_maps:
        raise ValueError("No SunPy Map objects provided to Plotter.")
    self.maps = all_maps
    self.current_map_index = initial_map_index

    self.fig = plt.figure(figsize=(30, 30))
    layout = self.fig.add_gridspec(1, 2)
    self.ax1 = self.fig.add_subplot(layout[0, 0], projection=aia_map_initial.wcs)
    self.ax2 = self.fig.add_subplot(layout[0, 1])

    # Slit definition and its drawn arrows.
    self.cut_points_pixel = []
    self.cut_lines = []
    # Results of the most recent time-distance computation.
    self.stackplot_data = None
    self.cumulative_distances = None
    self.map_times = None
    # Speed-fit state.
    self.upward_points = []
    self.downward_points = []
    self.fitted_lines = []
    self.upward_speed_kms = None
    self.downward_speed_kms = None
    self.speed_text_artists = []
    # Display handle and processing options.
    self.map_image = None
    self.slit_width_arcsec = 2.0
    self.boxcar_filter_size = None

    self.plot_initial_first_panel()
    self.plot_second_panel_placeholder()
    self.connect_events()
def connect_events(self):
    """Register ``self.on_click`` for mouse-button presses on the figure canvas.

    The connection id returned by ``mpl_connect`` is stored in ``self.cid_click``.
    """
    canvas = self.fig.canvas
    self.cid_click = canvas.mpl_connect('button_press_event', self.on_click)
def on_click(self, event):
    """Dispatch a mouse click to the handler of the panel it landed in.

    Clicks outside both panels are ignored.
    """
    clicked_axes = event.inaxes
    if clicked_axes == self.ax1:
        self.on_click_first_panel(event)
        return
    if clicked_axes == self.ax2:
        self.on_click_second_panel(event)
def on_click_first_panel(self, event):
    """Handle a mouse click on the image panel (``self.ax1``).

    * Left button (1): clip the clicked position to the pixel bounds of the
      current map, append it to ``self.cut_points_pixel`` and redraw the cut.
      When the 16th point is added, the stack plot is generated.
    * Right button (3): discard all cut points, remove the drawn arrows and
      every collection on ``self.ax1`` (the point markers), and redraw.
    """
    if event.button == 1:
        x_pixel, y_pixel = event.xdata, event.ydata
        if x_pixel is None or y_pixel is None:
            return  # no data coordinates for this click
        map_shape = self.maps[self.current_map_index].data.shape
        x_pixel = np.clip(x_pixel, 0, map_shape[1] - 1)
        y_pixel = np.clip(y_pixel, 0, map_shape[0] - 1)
        self.cut_points_pixel.append((x_pixel, y_pixel))
        print(f"Click {len(self.cut_points_pixel)}: ({x_pixel:.2f}, {y_pixel:.2f}) pixels")
        self.update_cut_line_and_points()
        # The stack plot is generated automatically once exactly 16 points exist.
        if len(self.cut_points_pixel) == 16:
            self.generate_and_plot_stackplot()
    elif event.button == 3:
        self.cut_points_pixel = []
        for artist in self.cut_lines:
            artist.remove()
        self.cut_lines = []
        # Iterate over a snapshot: Artist.remove() mutates ax.collections, and
        # removing from a sequence while iterating it skips elements.
        for collection in list(self.ax1.collections):
            collection.remove()
        if self.map_image is None:
            self.plot_initial_first_panel()
        else:
            self.fig.canvas.draw_idle()
        print("Cut points cleared.")
def plot_initial_first_panel(self):
    """Draw ``aia_map_initial`` into the left panel and style its WCS axes.

    The image uses a log colour scale between the module-level
    ``vmin_initial`` and ``vmax_initial``, and its observation time is
    written in the upper-left corner. The image handle returned by
    ``Map.plot`` is kept in ``self.map_image``. The world coordinates of the
    image's outer corners are printed for reference.
    """
    current_map = aia_map_initial
    self.map_image = current_map.plot(axes=self.ax1, norm=LogNorm(vmin=vmin_initial, vmax=vmax_initial), cmap=cmap, title='')
    timestamp_for_display = current_map.date.strftime("%Y-%m-%d %H:%M:%S")
    self.ax1.text(0.05, 0.95, f"{timestamp_for_display}",
                  transform=self.ax1.transAxes, color='white', fontsize=16,
                  path_effects=[pe.withStroke(linewidth=1, foreground="black")],
                  zorder=12)
    self.ax1.grid(False)
    self.ax1.tick_params(axis='both', direction='in', width=1.6, length=7, labelsize=18, which='major', labelcolor='black')
    self.ax1.set_xlabel('Solar-X (arcsec)', fontsize=20, fontweight='bold')
    self.ax1.set_ylabel('Solar-Y (arcsec)', fontsize=20, fontweight='bold')
    if hasattr(self.ax1, 'coords'):
        # WCSAxes: 50" major ticks, 5 minor intervals, plain decimal arcsec labels.
        for coord in (self.ax1.coords[0], self.ax1.coords[1]):
            coord.set_ticks(spacing=50 * u.arcsec)
            coord.set_format_unit(u.arcsec, decimal=True, show_decimal_unit=False)
            coord.display_minor_ticks(True)
            coord.tick_params(which='minor', length=4)
            coord.set_minor_frequency(5)
            coord.set_ticklabel(visible=True, weight='bold', color='black')
            coord.set_ticklabel(rotation=0)
        self.ax1.axis('auto')
    else:
        print("Warning: ax1.coords not found. WCS specific settings might not be applied.")
    # Pixel centres are at integer (0-based) positions, so the image's outer
    # edges lie at -0.5 and n - 0.5.
    ny, nx = current_map.data.shape
    bottom_left_pix = current_map.pixel_to_world(-0.5 * u.pix, -0.5 * u.pix)
    top_right_pix = current_map.pixel_to_world((nx - 0.5) * u.pix, (ny - 0.5) * u.pix)
    print(f"WCS coordinate range: Bottom-left = ({bottom_left_pix.Tx.value:.1f}, {bottom_left_pix.Ty.value:.1f}) arcsec, "
          f"Top-right = ({top_right_pix.Tx.value:.1f}, {top_right_pix.Ty.value:.1f}) arcsec")
    self.fig.canvas.draw_idle()
def update_cut_line_and_points(self):
    """Redraw the cut on the image panel.

    Removes the previously drawn arrows and every collection on ``self.ax1``.
    It then draws one white arrow per consecutive pair of cut points and a
    black marker at every point, and schedules a canvas redraw.
    """
    for artist in self.cut_lines:
        artist.remove()
    self.cut_lines = []
    # Snapshot first: Artist.remove() mutates ax.collections, and removing
    # from a sequence while iterating it skips elements.
    for collection in list(self.ax1.collections):
        collection.remove()
    points = self.cut_points_pixel
    # zip over neighbours yields nothing for fewer than two points.
    for (start_x, start_y), (end_x, end_y) in zip(points[:-1], points[1:]):
        arrow = self.ax1.arrow(start_x, start_y,
                               end_x - start_x, end_y - start_y,
                               color='white', width=2, head_width=10, head_length=10, zorder=10)
        self.cut_lines.append(arrow)
    if points:
        self.ax1.scatter([p[0] for p in points], [p[1] for p in points], color='black', s=50, zorder=11)
    self.fig.canvas.draw_idle()
def generate_and_plot_stackplot(self):
    """Compute the time-distance array for the current cut, save it, and plot it.

    Requires at least two cut points; otherwise a message is printed and
    nothing happens. Diagnostic information is printed. The results are
    stored on ``self.stackplot_data``, ``self.cumulative_distances`` and
    ``self.map_times``. When the array is non-empty, it is written with
    ``np.savez_compressed`` to the module-level ``output_npz_file``, together
    with the cut vertices, the ISO times and the current speed values (NaN
    when unset). ``self.plot_second_panel()`` is then called.
    """
    points = self.cut_points_pixel
    if len(points) < 2:
        print("Please select at least two points for the cut.")
        return
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]

    print(f"Generating stackplot with slit_width_arcsec={self.slit_width_arcsec} "
          f"and boxcar_filter_size={self.boxcar_filter_size}...")
    print(f"Number of cut points: {len(points)}")
    print(f"Number of maps: {len(self.maps)}")
    print(f"Slit points (pixels): {points}")
    vertices_world = self.maps[0].pixel_to_world(np.array(xs) * u.pixel, np.array(ys) * u.pixel)
    vertices_arcsec = [(c.Tx.value, c.Ty.value) for c in vertices_world]
    print(f"Slit points (arcsec): {vertices_arcsec}")

    result = generate_time_distance_data(self.maps, points,
                                         slit_width_arcsec=self.slit_width_arcsec,
                                         boxcar_filter_size=self.boxcar_filter_size)
    self.stackplot_data, self.cumulative_distances, self.map_times = result

    stack = self.stackplot_data
    distances = self.cumulative_distances
    times = self.map_times
    has_data = stack is not None and stack.size > 0

    shape_report = 'None' if stack is None else stack.shape
    print(f"Stackplot data shape: {shape_report}")
    valid_report = np.any(stack > 0) if has_data else 'Empty or None'
    print(f"Stackplot data valid: {valid_report}")
    distance_report = 'None' if distances is None else distances[:5]
    print(f"Cumulative distances: {distance_report}")
    times_report = times[:5] if times else 'None'
    print(f"Map times: {times_report}")

    if has_data:
        # Unset speeds are stored as NaN so the array stays numeric.
        speeds = [np.nan if value is None else value
                  for value in (self.upward_speed_kms, self.downward_speed_kms)]
        payload = dict(
            timeslice=stack.astype(np.float32),
            cut_x_pixel=np.array(xs, dtype=np.float64),
            cut_y_pixel=np.array(ys, dtype=np.float64),
            distance_along_slice=distances.astype(np.float64),
            map_times=np.array([moment.isoformat() for moment in times], dtype=str),
            speeds=np.array(speeds, dtype=np.float32),
            units='timeslice: intensity (DN), cut_x/y_pixel: pixels, distance_along_slice: arcsec (Euclidean), map_times: ISO format, speeds: [upward, downward] km/s',
            description=f'AIA {wavelength} time-distance plot, generated {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}',
        )
        try:
            np.savez_compressed(output_npz_file, **payload)
            print(f"Saved time-distance plot to {output_npz_file}")
        except Exception as e:
            # Saving is best effort; plotting proceeds regardless.
            print(f"Error saving .npz file: {e}")
    self.plot_second_panel()
def plot_second_panel_placeholder(self):
    """Reset the right-hand panel to an empty, labelled placeholder.

    Clears ``self.ax2`` and sets the time / position-along-cut axis labels.
    Writes a grey "No Stackplot Generated Yet" message in the centre, applies
    inward major and minor tick styling, turns minor ticks on, and schedules
    a redraw.
    """
    ax = self.ax2
    ax.clear()
    ax.set_xlabel("Time (UT)", fontsize=20, fontweight='bold')
    ax.set_ylabel("Position along Cut (arcsec)", fontsize=20, fontweight='bold')
    ax.text(0.5, 0.5, "No Stackplot Generated Yet",
            horizontalalignment='center', verticalalignment='center',
            transform=ax.transAxes, fontsize=18, color='gray')
    major_style = dict(which='major', labelsize=18, width=1.6, length=7, direction='in', labelcolor='black')
    for axis_name in ('x', 'y'):
        ax.tick_params(axis=axis_name, **major_style)
    ax.tick_params(axis='both', which='minor', length=4, width=1, direction='in', colors='black')
    ax.minorticks_on()
    self.fig.canvas.draw_idle()