Source code for resistics.plot

"""
Module with helper functions for plotting various data
"""
from typing import List, Dict, Tuple, Optional, Union
import numpy as np
import pandas as pd
import lttbc
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Name of the plotly template applied to every figure built in this module
PLOTLY_TEMPLATE = "seaborn"
# Figure margins: none on the left, right and bottom; some space kept at the
# top (presumably for titles). Callers copy this dict before passing it on.
PLOTLY_MARGIN = {"l": 0, "r": 0, "b": 0, "t": 50}


def lttb_downsample(
    x: np.ndarray, y: np.ndarray, max_pts: int = 5_000
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Downsample x, y for visualisation using LTTB (via the lttbc package)

    The arrays are handed to lttbc as float64 and the results are cast back
    to the original dtypes of x and y.

    Parameters
    ----------
    x : np.ndarray
        x array
    y : np.ndarray
        y array, must have the same size as x
    max_pts : int, optional
        Maximum number of points after downsampling, by default 5000. If
        max_pts is greater than or equal to the size of x, the inputs are
        returned unchanged.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        (new_x, new_y), the downsampled x and y arrays with the same dtypes
        as the inputs

    Raises
    ------
    ValueError
        If the size of x does not match the size of y
    """
    if x.size != y.size:
        raise ValueError(f"x size {x.size} must equal y size {y.size}")
    if max_pts >= x.size:
        # nothing to do, avoid any copies
        return x, y
    # Use float64 rather than float32: float64 represents every integer up to
    # 2**53 exactly, whereas float32 only manages up to 2**24 (~16.7 million).
    # Integer x values such as sample indices of long recordings would
    # otherwise come back rounded.
    new_x, new_y = lttbc.downsample(
        x.astype(np.float64),
        y.astype(np.float64),
        max_pts,
    )
    # restore the caller's dtypes (e.g. integer indices)
    return new_x.astype(x.dtype), new_y.astype(y.dtype)


def apply_lttb(
    data: np.ndarray, max_pts: Union[int, None]
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Optionally downsample data with LTTB before plotting

    The sample indices of data (0 .. data.size - 1) are used as the x values
    for downsampling. The returned indices are therefore positions in the
    original data array.

    Parameters
    ----------
    data : np.ndarray
        The data to downsample
    max_pts : Union[int, None]
        The maximum number of points to keep. If None, no downsampling is
        performed and every index is returned with the unchanged data.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Indices and data selected for plotting
    """
    sample_idx = np.arange(data.size)
    if max_pts is not None:
        return lttb_downsample(sample_idx, data, max_pts)
    return sample_idx, data


def plot_timeline(
    df: pd.DataFrame,
    y_col: str,
    title: str = "Timeline",
    ref_time: Optional[pd.Timestamp] = None,
) -> go.Figure:
    """
    Plot a timeline of horizontal bars, one per row of df

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame with one row per bar. It must contain the columns
        "first_time" and "last_time" (start and end of each bar), "fs"
        (used to colour the bars) and the column named by y_col.
    y_col : str
        The column to use for the y axis
    title : str, optional
        The title for the plot, by default "Timeline"
    ref_time : Optional[pd.Timestamp], optional
        The reference time, by default None. When given, it is drawn as a
        dashed red vertical line. The x axis range is widened so the line is
        visible whether it falls before or after the bars.

    Returns
    -------
    go.Figure
        Plotly figure
    """
    # x axis range covers every bar and, when given, the reference time
    min_time = df["first_time"].min()
    max_time = df["last_time"].max()
    if ref_time is not None:
        if ref_time < min_time:
            min_time = ref_time
        if ref_time > max_time:
            max_time = ref_time
    # pad each side by 10% of the span so bars do not touch the plot edges
    pad = 0.1 * (max_time - min_time)
    min_time = min_time - pad
    max_time = max_time + pad

    # sort so bars are ordered by y_col and then by start time
    df = df.sort_values([y_col, "first_time"])
    fig = px.timeline(
        df, x_start="first_time", x_end="last_time", y=y_col, color="fs", title=title
    )
    if ref_time is not None:
        fig.add_vline(x=ref_time, line_width=3, line_dash="dash", line_color="red")
    fig.update_layout(
        template=PLOTLY_TEMPLATE,
        # copy so the shared module-level margin dict is never mutated
        margin=dict(PLOTLY_MARGIN),
        # disable hiding/isolating traces by clicking legend entries
        legend=dict(itemclick=False, itemdoubleclick=False),
    )
    fig.update_xaxes(range=[min_time, max_time])
    return fig


def get_calibration_fig() -> go.Figure:
    """
    Create an empty two-row figure for calibration data

    The top row is for magnitude (log frequency and log magnitude axes). The
    bottom row is for phase (log frequency axis). The rows share their x
    axes, and only the bottom x axis carries the frequency title.

    Returns
    -------
    go.Figure
        Plotly figure
    """
    fig = make_subplots(
        rows=2,
        cols=1,
        shared_xaxes=True,
        subplot_titles=["Magnitude", "Phase"],
        vertical_spacing=0.05,
    )
    # (row, x axis settings, y axis settings), applied top row first
    axis_settings = (
        (1, {"type": "log"}, {"title_text": "Magnitude, nT/mV", "type": "log"}),
        (2, {"title_text": "Frequency, Hz", "type": "log"}, {"title_text": "Phase, radians"}),
    )
    for row, x_settings, y_settings in axis_settings:
        fig.update_xaxes(row=row, col=1, **x_settings)
        fig.update_yaxes(row=row, col=1, **y_settings)
    fig.update_layout(template=PLOTLY_TEMPLATE, margin=dict(PLOTLY_MARGIN))
    return fig


def get_time_fig(chans: List[str], y_axis_label: Dict[str, str]) -> go.Figure:
    """
    Create an empty figure with one row per channel for time data

    All rows share their x axes. Each row is titled "Channel <name>".

    Parameters
    ----------
    chans : List[str]
        The channels to plot, one subplot row each, top to bottom
    y_axis_label : Dict[str, str]
        Mapping from channel to y axis label. Every channel in chans must be
        present, otherwise a KeyError is raised.

    Returns
    -------
    go.Figure
        Plotly figure
    """
    n_rows = len(chans)
    row_titles = [f"Channel {chan}" for chan in chans]
    fig = make_subplots(
        rows=n_rows,
        cols=1,
        shared_xaxes=True,
        subplot_titles=row_titles,
        vertical_spacing=0.05,
    )
    # subplot rows are 1-based
    for row, chan in enumerate(chans, start=1):
        fig.update_yaxes(title_text=y_axis_label[chan], row=row, col=1)
    fig.update_layout(template=PLOTLY_TEMPLATE, margin=dict(PLOTLY_MARGIN))
    return fig


def get_spectra_stack_fig(chans: List[str], y_axis_label: Dict[str, str]) -> go.Figure:
    """
    Create an empty figure with one row per channel for spectra stack data

    All rows share log-scaled x axes, and every y axis is log scaled. Only
    the bottom row's x axis is titled "Frequency, Hz".

    Parameters
    ----------
    chans : List[str]
        The channels to plot, one subplot row each, top to bottom
    y_axis_label : Dict[str, str]
        Mapping from channel to y axis label. Every channel in chans must be
        present, otherwise a KeyError is raised.

    Returns
    -------
    go.Figure
        Plotly figure
    """
    fig = make_subplots(
        rows=len(chans),
        cols=1,
        shared_xaxes=True,
        subplot_titles=[f"Channel {chan}" for chan in chans],
        vertical_spacing=0.05,
    )
    # Without row/col this applies to every x axis. Set it once rather than
    # repeating the identical update for each channel.
    fig.update_xaxes(type="log")
    # subplot rows are 1-based
    for row, chan in enumerate(chans, start=1):
        fig.update_yaxes(title_text=y_axis_label[chan], type="log", row=row, col=1)
    fig.update_xaxes(title_text="Frequency, Hz", row=len(chans), col=1)
    fig.update_layout(template=PLOTLY_TEMPLATE, margin=dict(PLOTLY_MARGIN))
    return fig


def get_spectra_section_fig(chans: List[str]) -> go.Figure:
    """
    Create an empty figure with one row per channel for spectra sections

    Instead of per-row axis titles, the figure carries a single "Date" x
    title and a single "Frequency, Hz" y title. Unlike the other figure
    helpers here, shared_xaxes is not requested.

    Parameters
    ----------
    chans : List[str]
        The channels to plot, one subplot row each, top to bottom

    Returns
    -------
    go.Figure
        Plotly figure
    """
    subplot_options = dict(
        rows=len(chans),
        cols=1,
        subplot_titles=[f"Channel {chan}" for chan in chans],
        vertical_spacing=0.05,
        x_title="Date",
        y_title="Frequency, Hz",
    )
    fig = make_subplots(**subplot_options)
    fig.update_layout(template=PLOTLY_TEMPLATE, margin=dict(PLOTLY_MARGIN))
    return fig