Skip to content

Visualization tools

hierarchical_clustering(df, vmin=None, vmax=None, figsize=(8, 8), top_height=2, left_width=2, xmaxticks=None, ymaxticks=None, metric='cosine', cmap=None)

Perform and plot hierarchical clustering on a dataframe.

Parameters:

Name Type Description Default
df DataFrame

Input data in DataFrame format.

required
vmin Optional[float]

Minimum value to anchor the colormap. If None, inferred from data.

None
vmax Optional[float]

Maximum value to anchor the colormap. If None, inferred from data.

None
figsize Tuple[int, int]

Size of the main figure in inches.

(8, 8)
top_height int

Height of the top dendrogram.

2
left_width int

Width of the left dendrogram.

2
xmaxticks Optional[int]

Maximum number of x-ticks to display.

None
ymaxticks Optional[int]

Maximum number of y-ticks to display.

None
metric Union[str, Tuple[str, str]]

Distance metric to use. Either a string to use the same metric for both axes, or a tuple of two strings for different metrics for each axis.

'cosine'
cmap Optional[str]

Matplotlib colormap name. If None, uses "coolwarm".

None

Returns:

Type Description
Tuple[DataFrame, Figure, List[int], List[int]]

A tuple containing: - The clustered DataFrame (reordered according to clustering) - The matplotlib Figure object - The indices of rows in their clustered order - The indices of columns in their clustered order

Source code in src/ms_mint/matplotlib_tools.py
def hierarchical_clustering(
    df: pd.DataFrame,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    figsize: Tuple[int, int] = (8, 8),
    top_height: int = 2,
    left_width: int = 2,
    xmaxticks: Optional[int] = None,
    ymaxticks: Optional[int] = None,
    metric: Union[str, Tuple[str, str]] = "cosine",
    cmap: Optional[str] = None,
) -> Tuple[pd.DataFrame, Figure, List[int], List[int]]:
    """Perform and plot hierarchical clustering on a dataframe.

    Rows and columns are clustered independently with complete linkage on the
    pairwise distances computed by ``scipy.spatial.distance.pdist``. Missing
    values are replaced by 0 for the distance computation and the heatmap; the
    returned DataFrame keeps the original (unfilled) values.

    Args:
        df: Input data in DataFrame format.
        vmin: Minimum value to anchor the colormap. If None, inferred from data.
        vmax: Maximum value to anchor the colormap. If None, inferred from data.
        figsize: Size of the main figure in inches.
        top_height: Height of the top dendrogram, in inches of ``figsize``.
        left_width: Width of the left dendrogram, in inches of ``figsize``.
        xmaxticks: Maximum number of x-ticks to display. If None, derived from
            the width of the matrix panel (5 ticks per inch).
        ymaxticks: Maximum number of y-ticks to display. If None, derived from
            the height of the matrix panel (5 ticks per inch).
        metric: Distance metric passed to ``pdist``. Either a string to use the
            same metric for both axes, or a tuple ``(column_metric, row_metric)``.
            ``None`` falls back to ``"euclidean"``.
        cmap: Matplotlib colormap name. If None, uses "coolwarm".

    Returns:
        A tuple containing:
            - The clustered DataFrame (reordered according to clustering)
            - The matplotlib Figure object
            - The indices of rows in their clustered order
            - The indices of columns in their clustered order

    Raises:
        ValueError: If ``metric`` is not a string, a tuple of two strings, or None.
    """
    if metric is None:
        # pdist() does not accept metric=None; use its default metric instead.
        metric_x = metric_y = "euclidean"
    elif isinstance(metric, str):
        metric_x = metric_y = metric
    elif (
        isinstance(metric, tuple)
        and len(metric) == 2
        and all(isinstance(m, str) for m in metric)
    ):
        metric_x, metric_y = metric
    else:
        raise ValueError("Metric must be a string or a tuple of two strings")

    df = df.copy()

    # Panel geometry in figure-relative coordinates (0..1).
    total_width, total_height = figsize
    main_h = 1 - (top_height / total_height)
    main_w = 1 - (left_width / total_width)
    # 0.1 inch gap between the dendrograms and the matrix.
    gap_x = 0.1 / total_width
    gap_y = 0.1 / total_height
    left_h, left_w = main_h, 1 - main_w
    top_h, top_w = 1 - main_h, main_w

    if xmaxticks is None:
        xmaxticks = int(5 * main_w * total_width)
    if ymaxticks is None:
        ymaxticks = int(5 * main_h * total_height)

    dm = df.fillna(0).values
    # linkage() expects a *condensed* distance vector (pdist's output). Passing
    # the square form would make SciPy treat each row of the distance matrix as
    # an observation and re-cluster it with Euclidean distance, so the
    # requested metric would effectively be ignored.
    row_dist = pdist(dm, metric=metric_y)
    col_dist = pdist(dm.T, metric=metric_x)

    fig = plt.figure(figsize=figsize)
    fig.set_layout_engine("tight")

    # Left dendrogram (rows).
    ax1 = fig.add_axes([0, 0, left_w - gap_x, left_h], frameon=False)
    Z1 = dendrogram(
        linkage(row_dist, method="complete"),
        orientation="left",
        color_threshold=0,
        above_threshold_color="k",
        ax=ax1,
    )
    ax1.set_xticks([])
    ax1.set_yticks([])

    # Top dendrogram (columns).
    ax2 = fig.add_axes([left_w, main_h + gap_y, top_w, top_h - gap_y], frameon=False)
    Z2 = dendrogram(
        linkage(col_dist, method="complete"),
        color_threshold=0,
        above_threshold_color="k",
        ax=ax2,
    )
    ax2.set_xticks([])
    ax2.set_yticks([])

    # Heatmap, reordered by the dendrogram leaves.
    axmatrix = fig.add_axes([left_w, 0, main_w, main_h])
    idx1 = Z1["leaves"]
    idx2 = Z2["leaves"]
    reordered = dm[idx1, :][:, idx2]
    axmatrix.matshow(
        reordered[::-1],
        aspect="auto",
        cmap=cmap if cmap is not None else "coolwarm",
        vmin=vmin,
        vmax=vmax,
    )

    # Rows are reversed so that they line up with the left dendrogram.
    clustered = df.iloc[idx1[::-1], idx2]

    # Evenly spaced tick positions. int() truncation repeats positions when
    # there are more ticks than rows/columns, so deduplicate them.
    ndx_y = sorted({int(i) for i in np.linspace(0, len(clustered.index) - 1, ymaxticks)})
    ndx_x = sorted({int(i) for i in np.linspace(0, len(clustered.columns) - 1, xmaxticks)})

    axmatrix.yaxis.tick_right()
    axmatrix.xaxis.tick_bottom()
    axmatrix.set_yticks(ndx_y)
    axmatrix.set_yticklabels([str(clustered.index[i]) for i in ndx_y], fontsize=8)
    axmatrix.set_xticks(ndx_x)
    axmatrix.set_xticklabels(
        [str(clustered.columns[i]) for i in ndx_x], rotation=45, ha="right", fontsize=8
    )
    axmatrix.tick_params(axis="both", which="both", length=3)

    return clustered, fig, idx1[::-1], idx2

plot_metabolomics_hist2d(df, figsize=(4, 2.5), dpi=300, set_dim=True, cmap='jet', rt_range=None, mz_range=None, mz_bins=100, **kwargs)

Create a 2D histogram of metabolomics data.

Parameters:

Name Type Description Default
df DataFrame

DataFrame containing metabolomics data with scan_time, mz, and intensity columns.

required
figsize Tuple[float, float]

Size of the figure in inches (width, height).

(4, 2.5)
dpi int

Resolution of the figure in dots per inch.

300
set_dim bool

Whether to create a new figure with the given figsize and dpi. If False, the histogram is drawn into the current figure.

True
cmap str

Colormap name to use for the plot.

'jet'
rt_range Optional[Tuple[float, float]]

Retention time range (min, max) to display. If None, uses data range.

None
mz_range Optional[Tuple[float, float]]

M/Z range (min, max) to display. If None, uses data range.

None
mz_bins int

Number of bins to use for the m/z axis.

100
**kwargs

Additional keyword arguments passed to plt.hist2d.

{}

Returns:

Type Description
Tuple[ndarray, ndarray, ndarray, Any]

The result of plt.hist2d, which is a tuple containing: - The histogram array - The edges of the bins along the x-axis - The edges of the bins along the y-axis - The QuadMesh image drawn on the axes

Source code in src/ms_mint/matplotlib_tools.py
def plot_metabolomics_hist2d(
    df: pd.DataFrame,
    figsize: Tuple[float, float] = (4, 2.5),
    dpi: int = 300,
    set_dim: bool = True,
    cmap: str = "jet",
    rt_range: Optional[Tuple[float, float]] = None,
    mz_range: Optional[Tuple[float, float]] = None,
    mz_bins: int = 100,
    **kwargs,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Any]:
    """Create a 2D histogram of metabolomics data.

    Each point is weighted by ``log1p(intensity)``. The scan-time axis uses
    roughly one bin per 2 units of scan time (seconds, per the axis label),
    with at least one bin.

    Args:
        df: DataFrame containing metabolomics data with scan_time, mz, and intensity columns.
        figsize: Size of the figure in inches (width, height). Only used if ``set_dim``.
        dpi: Resolution of the figure in dots per inch. Only used if ``set_dim``.
        set_dim: Whether to create a new figure with ``figsize`` and ``dpi``.
            If False, the histogram is drawn into the current figure.
        cmap: Colormap name to use for the plot.
        rt_range: Retention time range (min, max) to display. If None, uses data range.
        mz_range: M/Z range (min, max) to display. If None, uses data range.
        mz_bins: Number of bins to use for the m/z axis.
        **kwargs: Additional keyword arguments passed to plt.hist2d. They
            override the defaults ``vmin=1``, ``vmax=1e3``, ``cmap`` and ``range``.

    Returns:
        The result of plt.hist2d, which is a tuple containing:
            - The histogram array
            - The edges of the bins along the x-axis (scan time)
            - The edges of the bins along the y-axis (m/z)
            - The QuadMesh image drawn on the axes
    """
    if set_dim:
        plt.figure(figsize=figsize, dpi=dpi)

    if mz_range is None:
        mz_range = (df.mz.min(), df.mz.max())

    if rt_range is None:
        rt_range = (df.scan_time.min(), df.scan_time.max())

    # One bin per ~2 s of scan time. Clamp to at least one bin: ranges shorter
    # than 2 would otherwise give 0 bins, which hist2d rejects.
    rt_bins = max(1, int((rt_range[1] - rt_range[0]) / 2))

    params = dict(vmin=1, vmax=1e3, cmap=cmap, range=(rt_range, mz_range))
    params.update(kwargs)

    result = plt.hist2d(
        df["scan_time"],
        df["mz"],
        weights=np.log1p(df["intensity"]),
        bins=[rt_bins, mz_bins],
        **params,
    )

    plt.xlabel("Scan time [s]")
    plt.ylabel("m/z")
    plt.gca().ticklabel_format(useOffset=False, style="plain")
    return result

plot_peak_shapes(mint_results, mint_metadata=None, fns=None, peak_labels=None, height=3, aspect=1.5, legend=False, col_wrap=4, hue='ms_file_label', title=None, dpi=None, sharex=False, sharey=False, kind='line', **kwargs)

Plot peak shapes from MS-MINT results.

Parameters:

Name Type Description Default
mint_results DataFrame

DataFrame in Mint results format.

required
mint_metadata Optional[DataFrame]

DataFrame in Mint metadata format for additional sample information.

None
fns Optional[List[str]]

Filenames to include. If None, includes all files.

None
peak_labels Optional[Union[str, List[str]]]

Peak label(s) to include. If None, includes all peak labels.

None
height int

Height of each figure facet in inches.

3
aspect float

Aspect ratio (width/height) of each figure facet.

1.5
legend bool

Whether to display a legend.

False
col_wrap int

Number of columns for subplots.

4
hue str

Column name to use for color grouping.

'ms_file_label'
title Optional[str]

Title to add to the figure.

None
dpi Optional[int]

Resolution of generated image.

None
sharex bool

Whether to share x-axis range between subplots.

False
sharey bool

Whether to share y-axis range between subplots.

False
kind str

Type of seaborn relplot ('line', 'scatter', etc.).

'line'
**kwargs

Additional keyword arguments passed to seaborn's relplot.

{}

Returns:

Type Description
Optional[FacetGrid]

A seaborn FacetGrid object containing the plot, or None if no peak shapes remain after filtering.

Source code in src/ms_mint/matplotlib_tools.py
def plot_peak_shapes(
    mint_results: pd.DataFrame,
    mint_metadata: Optional[pd.DataFrame] = None,
    fns: Optional[List[str]] = None,
    peak_labels: Optional[Union[str, List[str]]] = None,
    height: int = 3,
    aspect: float = 1.5,
    legend: bool = False,
    col_wrap: int = 4,
    hue: str = "ms_file_label",
    title: Optional[str] = None,
    dpi: Optional[int] = None,
    sharex: bool = False,
    sharey: bool = False,
    kind: str = "line",
    **kwargs,
) -> Optional[sns.FacetGrid]:
    """Plot peak shapes from MS-MINT results.

    Only results with ``peak_area > 0`` and ``peak_n_datapoints > 1`` are
    drawn. The comma-separated ``peak_shape_rt`` / ``peak_shape_int`` strings
    are expanded into one curve per result row, with one facet per peak label.

    Args:
        mint_results: DataFrame in Mint results format.
        mint_metadata: DataFrame in Mint metadata format for additional sample
            information. Its index is left-joined on ``ms_file_label``.
        fns: Filenames to include (matched against the ``ms_file`` column).
            If None, includes all files.
        peak_labels: Peak label(s) to include. If None, includes all peak labels.
        height: Height of each figure facet in inches.
        aspect: Aspect ratio (width/height) of each figure facet.
        legend: Whether to display a legend.
        col_wrap: Number of columns for subplots.
        hue: Column name to use for color grouping.
        title: Title to add to the figure.
        dpi: Resolution of generated image. Currently not used by the implementation.
        sharex: Whether to share x-axis range between subplots.
        sharey: Whether to share y-axis range between subplots.
        kind: Type of seaborn relplot ('line', 'scatter', etc.).
        **kwargs: Additional keyword arguments passed to seaborn's relplot.
            A ``facet_kws`` entry is merged into the sharex/sharey settings.

    Returns:
        A seaborn FacetGrid object containing the plot, or None if no peak
        shapes remain after filtering.
    """
    # Boolean indexing already returns a new frame; nothing below mutates it.
    R = mint_results[mint_results.peak_area > 0]

    if peak_labels is not None:
        if isinstance(peak_labels, str):
            peak_labels = [peak_labels]
        R = R[R.peak_label.isin(peak_labels)]
    else:
        peak_labels = R.peak_label.drop_duplicates().values

    if fns is not None:
        R = R[R.ms_file.isin(fns)]

    # A drawable shape needs at least two data points.
    R = R[R.peak_n_datapoints > 1]

    dfs = []
    for peak_label in peak_labels:
        for _, row in R[R.peak_label == peak_label].iterrows():
            dfs.append(
                pd.DataFrame(
                    {
                        "Scan time [s]": [float(v) for v in row.peak_shape_rt.split(",")],
                        "Intensity": [float(v) for v in row.peak_shape_int.split(",")],
                        "ms_file_label": row.ms_file_label,
                        "peak_label": peak_label,
                        "Expected Scan time [s]": row.rt,
                    }
                )
            )

    if not dfs:
        return None

    df = pd.concat(dfs, ignore_index=True)

    # Add metadata
    if mint_metadata is not None:
        df = pd.merge(df, mint_metadata, left_on="ms_file_label", right_index=True, how="left")

    _facet_kws = dict(sharex=sharex, sharey=sharey)
    if "facet_kws" in kwargs:
        _facet_kws.update(kwargs.pop("facet_kws"))

    g = sns.relplot(
        data=df,
        x="Scan time [s]",
        y="Intensity",
        hue=hue,
        col="peak_label",
        col_order=peak_labels,
        kind=kind,
        col_wrap=col_wrap,
        height=height,
        aspect=aspect,
        facet_kws=_facet_kws,
        legend=legend,
        **kwargs,
    )

    g.set_titles(row_template="{row_name}", col_template="{col_name}")

    for ax in g.axes.flatten():
        ax.ticklabel_format(style="sci", scilimits=(0, 0), axis="y")

    if title is not None:
        g.fig.suptitle(title, y=1.01)

    return g

plot_peaks(series, peaks=None, highlight=None, expected_rt=None, weights=None, legend=True, label=None, **kwargs)

Plot time series data with peak annotations.

Parameters:

Name Type Description Default
series Series

Time series data with time as index and intensity as values.

required
peaks Optional[DataFrame]

DataFrame containing peak information.

None
highlight Optional[List[int]]

List of peak indices to highlight.

None
expected_rt Optional[float]

Expected retention time to mark on the plot.

None
weights Optional[ndarray]

Array of weight values (e.g., for Gaussian weighting).

None
legend bool

Whether to display the legend.

True
label Optional[str]

Label for the time series data.

None
**kwargs

Additional keyword arguments passed to the plot function.

{}

Returns:

Type Description
Figure

Matplotlib Figure containing the plot.

Source code in src/ms_mint/matplotlib_tools.py
def plot_peaks(
    series: pd.Series,
    peaks: Optional[pd.DataFrame] = None,
    highlight: Optional[List[int]] = None,
    expected_rt: Optional[float] = None,
    weights: Optional[np.ndarray] = None,
    legend: bool = True,
    label: Optional[str] = None,
    **kwargs,
) -> Figure:
    """Plot time series data with peak annotations.

    Everything is drawn on the current matplotlib axes.

    Args:
        series: Time series data with time as index and intensity as values.
        peaks: DataFrame containing peak information. It must have an ``ndxs``
            column of positional indices into ``series`` and exactly seven
            columns per row, where the 4th is the peak base height and the
            6th/7th are the start/end of the peak (rt_min, rt_max).
        highlight: Index labels of ``peaks`` whose span is shaded as "Selected".
        expected_rt: Expected retention time, marked as a 1-unit wide span.
        weights: Array of weight values (e.g., for Gaussian weighting). Plotted
            against its positional index, not against ``series.index``.
        legend: Whether to display the legend. Duplicate labels are collapsed.
        label: Label for the time series data. Defaults to "Intensity".
        **kwargs: Additional keyword arguments passed to the plot function.

    Returns:
        Matplotlib Figure containing the plot.
    """
    if highlight is None:
        highlight = []
    ax = plt.gca()
    ax.plot(
        series.index,
        series.values,
        label=label if label is not None else "Intensity",
        **kwargs,
    )
    if peaks is not None:
        series.iloc[peaks.ndxs].plot(label="Peaks", marker="x", lw=0, ax=ax)
        for i, (
            ndx,
            (_, _, _, peak_base_height, _, rt_min, rt_max),
        ) in enumerate(peaks.iterrows()):
            if ndx in highlight:
                ax.axvspan(rt_min, rt_max, color="green", alpha=0.25, label="Selected")
            ax.hlines(
                peak_base_height,
                rt_min,
                rt_max,
                color="orange",
                # Label only the first width line so the legend has one entry.
                label="Peak width" if i == 0 else None,
            )
    if expected_rt is not None:
        ax.axvspan(expected_rt, expected_rt + 1, color="blue", alpha=1, label="Expected Rt")
    if weights is not None:
        ax.plot(weights, linestyle="--", label="Gaussian weight")
    ax.set_ylabel("Intensity")
    ax.set_xlabel("Scan time [s]")
    ax.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
    if legend:
        # Several highlighted peaks each add a "Selected" entry; keep one per label.
        handles, labels = ax.get_legend_handles_labels()
        unique = dict(zip(labels, handles))
        ax.legend(list(unique.values()), list(unique.keys()))
    elif ax.get_legend() is not None:
        # get_legend() is None when no legend exists; removing it would crash.
        ax.get_legend().remove()
    return ax.figure

options: show_root_heading: true show_root_full_path: true show_submodules: true members_order: source

get_palette_colors(palette_name, num_colors)

Get a list of colors from a specific colorlover palette.

Parameters:

Name Type Description Default
palette_name str

Name of the color palette.

required
num_colors int

Number of colors to extract.

required

Returns:

Type Description
List[str]

List of color strings in the requested palette.

Source code in src/ms_mint/plotly_tools.py
def get_palette_colors(palette_name: str, num_colors: int) -> List[str]:
    """Get a list of colors from a specific colorlover palette.

    colorlover only ships palettes for a limited range of sizes (3 up to the
    largest key in ``cl.scales``). Requests below that range return the
    3-color palette. Requests above it use the largest available palette and
    repeat its colors cyclically, so at least ``num_colors`` colors are
    returned. If ``palette_name`` is not found, the "Paired" qualitative
    palette is used instead.

    Args:
        palette_name: Name of the color palette.
        num_colors: Number of colors to extract.

    Returns:
        List of color strings in the requested palette (a new list).
    """
    # Categories in the colorlover package
    categories = ["qual", "seq", "div"]

    largest_size = max(int(size) for size in cl.scales)
    # Clamp into the sizes colorlover provides; above the maximum, indexing
    # cl.scales would raise KeyError.
    scale_size = min(max(num_colors, 3), largest_size)
    scales = cl.scales[f"{scale_size}"]

    colors = None
    for category in categories:
        palettes = scales.get(category, {})
        if palette_name in palettes:
            colors = palettes[palette_name]
            break
    if colors is None:
        # If palette not found in any category, fall back to a default one.
        colors = scales["qual"]["Paired"]

    if num_colors > len(colors):
        # Repeat colors so that every requested item gets one.
        return [colors[i % len(colors)] for i in range(num_colors)]
    # Copy so callers cannot mutate colorlover's shared palette lists.
    return list(colors)

plotly_heatmap(df, normed_by_cols=False, transposed=False, clustered=False, add_dendrogram=False, name='', x_tick_colors=None, height=None, width=None, correlation=False, call_show=False, verbose=False)

Create an interactive heatmap from a dense-formatted dataframe.

Parameters:

Name Type Description Default
df DataFrame

Input data in DataFrame format.

required
normed_by_cols bool

Whether to normalize column vectors.

False
transposed bool

Whether to transpose the generated image.

False
clustered bool

Whether to apply hierarchical clustering on rows.

False
add_dendrogram bool

Whether to show a dendrogram (only when clustered=True).

False
name str

Name to use in figure title.

''
x_tick_colors Optional[str]

Color of x-ticks. Currently not used by the implementation.

None
height Optional[int]

Image height in pixels.

None
width Optional[int]

Image width in pixels.

None
correlation bool

Whether to convert the table to a correlation matrix.

False
call_show bool

Whether to display the figure immediately.

False
verbose bool

Whether to print additional information. Currently not used by the implementation.

False

Returns:

Type Description
Optional[Figure]

A Plotly Figure object, or None if call_show is True.

Source code in src/ms_mint/plotly_tools.py
def plotly_heatmap(
    df: pd.DataFrame,
    normed_by_cols: bool = False,
    transposed: bool = False,
    clustered: bool = False,
    add_dendrogram: bool = False,
    name: str = "",
    x_tick_colors: Optional[str] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    correlation: bool = False,
    call_show: bool = False,
    verbose: bool = False,
) -> Optional[PlotlyFigure]:
    """Create an interactive heatmap from a dense-formatted dataframe.

    Rows whose maximum equals 0 are dropped before any other processing.

    Args:
        df: Input data in DataFrame format.
        normed_by_cols: Whether to normalize column vectors (divide by column max).
        transposed: Whether to transpose the generated image.
        clustered: Whether to apply hierarchical clustering on rows.
        add_dendrogram: Whether to show a dendrogram (only when clustered=True).
        name: Name to use in figure title. If empty, no title is shown.
        x_tick_colors: Color of x-ticks. Currently not used by the implementation.
        height: Image height in pixels.
        width: Image width in pixels.
        correlation: Whether to convert the table to a correlation matrix.
        call_show: Whether to display the figure immediately.
        verbose: Whether to print additional information. Currently not used
            by the implementation.

    Returns:
        A Plotly Figure object, or None if call_show is True.
    """
    # Positional boolean mask: selecting by labels would duplicate rows when
    # the index contains repeated labels.
    df = df[(df.max(axis=1) != 0).to_numpy()]

    colorscale = "Bluered"
    plot_attributes = []

    if normed_by_cols:
        df = df.divide(df.max()).fillna(0)
        plot_attributes.append("normalized")

    if transposed:
        df = df.T

    if correlation:
        plot_type = "Correlation"
        df = df.corr()
        colorscale = [
            [0.0, "rgb(165,0,38)"],
            [0.1111111111111111, "rgb(215,48,39)"],
            [0.2222222222222222, "rgb(244,109,67)"],
            [0.3333333333333333, "rgb(253,174,97)"],
            [0.4444444444444444, "rgb(254,224,144)"],
            [0.5555555555555556, "rgb(224,243,248)"],
            [0.6666666666666666, "rgb(171,217,233)"],
            [0.7777777777777778, "rgb(116,173,209)"],
            [0.8888888888888888, "rgb(69,117,180)"],
            [1.0, "rgb(49,54,149)"],
        ]
    else:
        plot_type = "Heatmap"

    if clustered:
        dendro_side = ff.create_dendrogram(
            df,
            orientation="right",
            labels=df.index.to_list(),
            color_threshold=0,
            colorscale=["black"] * 8,
        )
        # Row labels in dendrogram leaf order.
        dendro_leaves = dendro_side["layout"]["yaxis"]["ticktext"]
        df = df.loc[dendro_leaves, :]
        if correlation:
            # Keep the correlation matrix symmetric: order columns like rows.
            df = df[df.index]
        y = dendro_leaves
    else:
        y = df.index.to_list()

    heatmap = go.Heatmap(x=df.columns, y=y, z=df.values, colorscale=colorscale)

    if name == "":
        title = ""
    else:
        # Join only non-empty parts so an empty attribute list doesn't leave
        # a double space in the title.
        parts = (f"{plot_type} of", ",".join(plot_attributes), name)
        title = " ".join(part for part in parts if part)

    if not (add_dendrogram and clustered):
        # Figure without side-dendrogram
        fig = go.Figure(heatmap)
        fig.update_layout(
            {"title_x": 0.5},
            title={"text": title},
            yaxis={"title": "", "tickmode": "array", "automargin": True},
        )
        fig.update_layout({"height": height, "width": width, "hovermode": "closest"})
    else:
        # Figure with side-dendrogram on a second x-axis.
        fig = go.Figure()
        for trace in dendro_side["data"]:
            trace["xaxis"] = "x2"
            fig.add_trace(trace)

        # Place heatmap rows at the dendrogram's leaf positions so they align.
        y_labels = heatmap["y"]
        heatmap["y"] = dendro_side["layout"]["yaxis"]["tickvals"]
        fig.add_trace(heatmap)

        fig.update_layout(
            {
                "height": height,
                "width": width,
                "showlegend": False,
                "hovermode": "closest",
                "paper_bgcolor": "white",
                "plot_bgcolor": "white",
                "title_x": 0.5,
            },
            title={"text": title},
            # X-axis of main figure
            xaxis={
                "domain": [0.11, 1],
                "mirror": False,
                "showgrid": False,
                "showline": False,
                "zeroline": False,
                "showticklabels": True,
                "ticks": "",
            },
            # X-axis of side-dendrogram
            xaxis2={
                "domain": [0, 0.1],
                "mirror": False,
                "showgrid": True,
                "showline": False,
                "zeroline": False,
                "showticklabels": False,
                "ticks": "",
            },
            # Y-axis of main figure
            yaxis={
                "domain": [0, 1],
                "mirror": False,
                "showgrid": False,
                "showline": False,
                "zeroline": False,
                "showticklabels": False,
            },
        )

        fig["layout"]["yaxis"]["ticktext"] = np.asarray(y_labels)
        fig["layout"]["yaxis"]["tickvals"] = np.asarray(dendro_side["layout"]["yaxis"]["tickvals"])

    fig.update_layout(
        autosize=True,
        hovermode="closest",
    )

    fig.update_yaxes(automargin=True)
    fig.update_xaxes(automargin=True)

    if call_show:
        fig.show(config={"displaylogo": False})
        return None
    return fig

plotly_peak_shapes(mint_results, mint_metadata=None, color='ms_file_label', fns=None, col_wrap=1, peak_labels=None, legend=True, verbose=False, legend_orientation='v', call_show=False, palette='Plasma')

Plot peak shapes from mint results as interactive Plotly figure.

Parameters:

Name Type Description Default
mint_results DataFrame

DataFrame in Mint results format.

required
mint_metadata Optional[DataFrame]

DataFrame in Mint metadata format.

None
color str

Column name determining color-coding of plots.

'ms_file_label'
fns Optional[List[str]]

Filenames to include. If None, all files are used.

None
col_wrap int

Maximum number of subplot columns.

1
peak_labels Optional[Union[str, List[str]]]

Peak-labels to include. If None, all peaks are used.

None
legend bool

Whether to display legend.

True
verbose bool

If True, prints additional details. Currently not used by the implementation.

False
legend_orientation str

Legend orientation ('v' for vertical, 'h' for horizontal).

'v'
call_show bool

If True, displays the plot immediately.

False
palette str

Color palette to use.

'Plasma'

Returns:

Type Description
Optional[Figure]

A Plotly Figure object, or None if call_show is True.

Source code in src/ms_mint/plotly_tools.py
def plotly_peak_shapes(
    mint_results: pd.DataFrame,
    mint_metadata: Optional[pd.DataFrame] = None,
    color: str = "ms_file_label",
    fns: Optional[List[str]] = None,
    col_wrap: int = 1,
    peak_labels: Optional[Union[str, List[str]]] = None,
    legend: bool = True,
    verbose: bool = False,
    legend_orientation: str = "v",
    call_show: bool = False,
    palette: str = "Plasma",
) -> Optional[PlotlyFigure]:
    """Plot peak shapes from mint results as interactive Plotly figure.

    One subplot is created per peak label and one trace per file within it.
    Only results with ``peak_max > 0`` are drawn.

    Args:
        mint_results: DataFrame in Mint results format.
        mint_metadata: DataFrame in Mint metadata format. Its index is joined
            on ``ms_file_label`` (inner join, so files without metadata are dropped).
        color: Column name determining color-coding of plots. If falsy,
            plotly's default trace colors are used.
        fns: Filenames to include. If None, all files are used.
        col_wrap: Maximum number of subplot columns (values below 1 are treated as 1).
        peak_labels: Peak-labels to include. If None, all peaks are used.
        legend: Whether to display legend.
        verbose: If True, prints additional details. Currently not used by the
            implementation.
        legend_orientation: Legend orientation ('v' for vertical, 'h' for horizontal).
        call_show: If True, displays the plot immediately.
        palette: Color palette to use.

    Returns:
        A Plotly Figure object, or None if call_show is True.
    """
    mint_results = mint_results.copy()
    # Guard against division by zero in the grid layout below.
    col_wrap = max(1, col_wrap)

    # Merge with metadata if provided
    if mint_metadata is not None:
        mint_results = pd.merge(
            mint_results, mint_metadata, left_on="ms_file_label", right_index=True
        )

    # Filter by filenames
    if fns is not None:
        fns = [fn_to_label(fn) for fn in fns]
        mint_results = mint_results[mint_results.ms_file_label.isin(fns)]
    else:
        fns = mint_results.ms_file_label.unique()

    # Filter by peak_labels
    if peak_labels is not None:
        if isinstance(peak_labels, str):
            peak_labels = [peak_labels]
        mint_results = mint_results[mint_results.peak_label.isin(peak_labels)]
    else:
        peak_labels = mint_results.peak_label.unique()
    peak_labels = list(peak_labels)

    # One marker color per entry of ``fns``; None lets plotly choose.
    hue_column = [None] * len(fns)
    if color:
        unique_hues = mint_results[color].unique()
        colors = get_palette_colors(palette, len(unique_hues))
        color_mapping = dict(zip(unique_hues, colors))

        if color == "ms_file_label":
            # .get: a requested file may be absent from the results.
            hue_column = [color_mapping.get(fn) for fn in fns]
        else:
            # Each file takes the color of its (first) value in ``color``.
            per_file = (
                mint_results.drop_duplicates("ms_file_label")
                .set_index("ms_file_label")[color]
                .map(color_mapping)
                .reindex(fns)
                .tolist()
            )
            # Files missing from the results reindex to NaN, which is not a valid color.
            hue_column = [c if isinstance(c, str) else None for c in per_file]

    res = mint_results[mint_results.peak_max > 0]
    res = res.set_index(["peak_label", "ms_file_label"]).sort_index()

    # Enough rows to hold one subplot per requested peak label (ceil division).
    n_rows = max(1, -(-len(peak_labels) // col_wrap))

    fig = make_subplots(rows=n_rows, cols=col_wrap, subplot_titles=peak_labels)

    for label_i, label in enumerate(peak_labels):
        ndx_r = label_i // col_wrap + 1
        ndx_c = label_i % col_wrap + 1

        for file_i, fn in enumerate(fns):
            try:
                x, y = res.loc[(label, fn), ["peak_shape_rt", "peak_shape_int"]]
            except KeyError as e:
                logging.warning(e)
                continue

            # Missing shapes are stored as non-iterable values (e.g. NaN).
            if not isinstance(x, Iterable):
                continue
            if isinstance(x, str):
                x = x.split(",")
                y = y.split(",")

            fig.add_trace(
                go.Scattergl(
                    x=x,
                    y=y,
                    name=P(fn).name,
                    mode="markers",
                    legendgroup=file_i,
                    # Only the first subplot contributes legend entries.
                    showlegend=(label_i == 0),
                    marker_color=hue_column[file_i],
                    text=fn,
                    fill="tozeroy",
                    marker=dict(size=3),
                ),
                row=ndx_r,
                col=ndx_c,
            )

            fig.update_xaxes(title_text="Scan time [s]", row=ndx_r, col=ndx_c)
            fig.update_yaxes(title_text="Intensity", row=ndx_r, col=ndx_c)

    # Layout updates
    if legend:
        fig.update_layout(legend_orientation=legend_orientation)

    fig.update_layout(showlegend=legend)
    fig.update_layout(height=400 * n_rows, title_text="Peak Shapes")

    if call_show:
        fig.show(config={"displaylogo": False})
        return None
    return fig

set_template()

Set a default template for plotly figures.

Creates a "draft" template with smaller font size and sets it as the default template for all plotly figures.

Source code in src/ms_mint/plotly_tools.py
def set_template() -> None:
    """Register a compact "draft" plotly template and make it the global default.

    The template only overrides the layout font size (10); after this call,
    every plotly figure created without an explicit template uses it via
    ``pio.templates.default``.
    """
    draft_template = go.layout.Template(layout=dict(font={"size": 10}))
    pio.templates["draft"] = draft_template
    pio.templates.default = "draft"

options: show_root_heading: true show_root_full_path: true show_submodules: true members_order: source

PCA_Plotter

Class for visualizing PCA results from MS-MINT analysis.

This class provides methods to create various plots of PCA results, including cumulative variance plots, pairplots, and loading plots.

Attributes:

Name Type Description
pca

The PrincipalComponentsAnalyser instance containing results to visualize.

Source code in src/ms_mint/pca.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
class PCA_Plotter:
    """Class for visualizing PCA results from MS-MINT analysis.

    This class provides methods to create various plots of PCA results,
    including cumulative variance plots, 2-D score scatter plots, pairplots,
    and loading plots. Each public plot method has a static (Matplotlib /
    Seaborn) variant and an interactive (Plotly) variant, selected with the
    ``interactive`` flag.

    Attributes:
        pca: The PrincipalComponentsAnalyser instance containing results to visualize.
    """

    def __init__(self, pca: PrincipalComponentsAnalyser) -> None:
        """Initialize a PCA_Plotter instance.

        Args:
            pca: PrincipalComponentsAnalyser instance with results to visualize.
        """
        self.pca = pca

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_hue(hue: Optional[Union[str, List[str]]]) -> Optional[str]:
        """Return the name of the column that holds the colouring labels.

        ``_prepare_data`` stores non-string ``hue`` values (e.g. a list of
        labels) in a column named ``"Label"``. A string ``hue`` names an
        existing metadata column and is returned unchanged.

        Args:
            hue: Column name, a sequence of labels, or None.

        Returns:
            The column name to colour by, or None if no colouring applies.
        """
        if isinstance(hue, str):
            return hue
        # Mirror the truthiness test in ``_prepare_data``: only a non-empty
        # sequence of labels results in a "Label" column.
        return "Label" if hue else None

    @staticmethod
    def _pc_columns(df: pd.DataFrame) -> List[str]:
        """Return the principal-component columns (named ``PC-<n>``) of ``df``."""
        # Metadata merged onto the scores may have non-string column names.
        return [c for c in df.columns if isinstance(c, str) and c.startswith("PC-")]

    @staticmethod
    def _color_column(df: pd.DataFrame, color_by: Optional[str]) -> Optional[str]:
        """Return ``color_by`` if it can be used to colour ``df``, else None.

        ``None``, ``""`` and the sentinel string ``"none"`` disable colouring.
        So does a column name that is not present in ``df``.
        """
        if color_by and color_by != "none" and color_by in df.columns:
            return color_by
        return None

    def _component_variance(self, component: int) -> float:
        """Return the explained variance of a single 1-indexed component.

        ``cum_expl_var`` is cumulative, so the share of component ``k`` is the
        difference between consecutive cumulative values.
        """
        cum_var = self.pca.results["cum_expl_var"]
        if component == 1:
            return cum_var[0]
        return cum_var[component - 1] - cum_var[component - 2]

    def _projected_with_meta(
        self, x_component: int, y_component: int, color_by: Optional[str]
    ):
        """Return ``(scores, x_col, y_col)`` with the optional ``color_by`` column merged in.

        Raises:
            ValueError: If either requested component is not in the projection.
        """
        df = self.pca.results["df_projected"].copy()
        x_col = f"PC-{x_component}"
        y_col = f"PC-{y_component}"

        if x_col not in df.columns or y_col not in df.columns:
            raise ValueError(f"Components {x_component} or {y_component} not available")

        if color_by and color_by != "none":
            meta = self.pca.mint.meta.dropna(axis=1, how="all")
            if color_by in meta.columns:
                # A left join keeps every projected sample, even if its metadata is missing.
                df = pd.merge(
                    df, meta[[color_by]], left_index=True, right_index=True, how="left"
                )
        return df, x_col, y_col

    # ------------------------------------------------------------------
    # Cumulative variance
    # ------------------------------------------------------------------

    def cumulative_variance(
        self, interactive: bool = False, **kwargs
    ) -> Union[Figure, PlotlyFigure]:
        """Plot the cumulative explained variance of principal components.

        Args:
            interactive: If True, returns a Plotly interactive figure.
                If False, returns a static Matplotlib figure.
            **kwargs: Additional keyword arguments passed to the underlying plotting functions.

        Returns:
            Either a Matplotlib figure or a Plotly figure depending on the interactive parameter.
        """
        if interactive:
            return self.cumulative_variance_px(**kwargs)
        else:
            return self.cumulative_variance_sns(**kwargs)

    def cumulative_variance_px(self, **kwargs) -> PlotlyFigure:
        """Create an interactive Plotly plot of cumulative explained variance.

        Args:
            **kwargs: Additional keyword arguments passed to px.bar.

        Returns:
            Plotly figure showing cumulative explained variance.
        """
        n_components = self.pca.results["n_components"]
        cum_expl_var = self.pca.results["cum_expl_var"]
        df = pd.DataFrame(
            {
                "Principal Component": np.arange(n_components) + 1,
                "Explained variance [%]": cum_expl_var,
            }
        )
        fig = px.bar(
            df,
            x="Principal Component",
            y="Explained variance [%]",
            title="Cumulative explained variance",
            labels={
                "Principal Component": "Principal Component",
                "Explained variance [%]": "Explained variance [%]",
            },
            **kwargs,
        )
        fig.update_layout(autosize=True, showlegend=False)
        return fig

    def cumulative_variance_sns(self, **kwargs) -> Figure:
        """Create a static Matplotlib plot of cumulative explained variance.

        Args:
            **kwargs: Additional keyword arguments for figure customization.
                'aspect': Width-to-height ratio of the figure (default: 1).
                'height': Height of the figure in inches (default: 5).

        Returns:
            Matplotlib figure showing cumulative explained variance.
        """
        # Set default values for aspect and height
        aspect = kwargs.get("aspect", 1)
        height = kwargs.get("height", 5)

        n_components = self.pca.results["n_components"]
        cum_expl_var = self.pca.results["cum_expl_var"]

        # The width follows from the requested height and aspect ratio only.
        width = height * aspect

        fig, ax = plt.subplots(figsize=(width, height))
        ax.bar(
            np.arange(n_components) + 1,
            cum_expl_var,
            facecolor="grey",
            edgecolor="none",
        )
        ax.set_xlabel("Principal Component")
        ax.set_ylabel("Explained variance [%]")
        ax.set_title("Cumulative explained variance")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.set_xticks(range(1, len(cum_expl_var) + 1))
        return fig

    # ------------------------------------------------------------------
    # Score scatter plots
    # ------------------------------------------------------------------

    def scatter(
        self,
        x_component: int = 1,
        y_component: int = 2,
        color_by: Optional[str] = None,
        interactive: bool = False,
        **kwargs,
    ) -> Union[Figure, PlotlyFigure]:
        """Create a scatter plot of two principal components.

        Args:
            x_component: Principal component number for x-axis (1-indexed).
            y_component: Principal component number for y-axis (1-indexed).
            color_by: Metadata column to use for coloring points.
            interactive: If True, returns a Plotly interactive figure.
                If False, returns a static Matplotlib figure.
            **kwargs: Additional keyword arguments passed to plotting functions.

        Returns:
            Either a Matplotlib figure or a Plotly figure depending on interactive.
        """
        if interactive:
            return self.scatter_plotly(x_component, y_component, color_by, **kwargs)
        else:
            return self.scatter_sns(x_component, y_component, color_by, **kwargs)

    def scatter_sns(
        self,
        x_component: int = 1,
        y_component: int = 2,
        color_by: Optional[str] = None,
        **kwargs,
    ) -> Figure:
        """Create a static scatter plot of two principal components.

        Samples whose ``color_by`` value is missing are drawn in light grey
        under the legend entry "n/a" instead of being omitted.

        Args:
            x_component: Principal component number for x-axis (1-indexed).
            y_component: Principal component number for y-axis (1-indexed).
            color_by: Metadata column to use for coloring points.
            **kwargs: Additional keyword arguments for figure customization
                ('width' default 8, 'height' default 6, in inches).

        Returns:
            Matplotlib figure showing the scatter plot.

        Raises:
            ValueError: If either component is not available in the projection.
        """
        df, x_col, y_col = self._projected_with_meta(x_component, y_component, color_by)
        color_col = self._color_column(df, color_by)

        height = kwargs.get("height", 6)
        width = kwargs.get("width", 8)

        fig, ax = plt.subplots(figsize=(width, height))

        if color_col is not None:
            categories = df[color_col].dropna().unique()
            colors = plt.cm.tab10(np.linspace(0, 1, len(categories)))
            for cat, color in zip(categories, colors):
                mask = df[color_col] == cat
                ax.scatter(
                    df.loc[mask, x_col],
                    df.loc[mask, y_col],
                    c=[color],
                    label=str(cat),
                    alpha=0.7,
                    s=50,
                )
            # Without this, samples lacking a metadata value would silently vanish.
            missing = df[color_col].isna()
            if missing.any():
                ax.scatter(
                    df.loc[missing, x_col],
                    df.loc[missing, y_col],
                    c="lightgrey",
                    label="n/a",
                    alpha=0.7,
                    s=50,
                )
            ax.legend(title=color_col, bbox_to_anchor=(1.02, 1), loc="upper left")
        else:
            ax.scatter(df[x_col], df[y_col], alpha=0.7, s=50, c="steelblue")

        var_x = self._component_variance(x_component)
        var_y = self._component_variance(y_component)

        ax.set_xlabel(f"{x_col} ({var_x:.1f}%)")
        ax.set_ylabel(f"{y_col} ({var_y:.1f}%)")
        ax.set_title(f"PCA: {x_col} vs {y_col}")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        # Lay out the figure we created rather than whatever pyplot considers "current".
        fig.tight_layout()
        return fig

    def scatter_plotly(
        self,
        x_component: int = 1,
        y_component: int = 2,
        color_by: Optional[str] = None,
        **kwargs,
    ) -> PlotlyFigure:
        """Create an interactive Plotly scatter plot of two principal components.

        Args:
            x_component: Principal component number for x-axis (1-indexed).
            y_component: Principal component number for y-axis (1-indexed).
            color_by: Metadata column to use for coloring points.
            **kwargs: Additional keyword arguments passed to px.scatter.

        Returns:
            Plotly figure showing the scatter plot.

        Raises:
            ValueError: If either component is not available in the projection.
        """
        df, x_col, y_col = self._projected_with_meta(x_component, y_component, color_by)

        var_x = self._component_variance(x_component)
        var_y = self._component_variance(y_component)

        color_col = self._color_column(df, color_by)

        plot_df = df.reset_index()
        # reset_index names the new column after the index, or "index" if it is
        # unnamed. Using the real name keeps sample labels in the hover text.
        index_col = df.index.name if df.index.name is not None else "index"
        hover_col = index_col if index_col in plot_df.columns else None

        fig = px.scatter(
            plot_df,
            x=x_col,
            y=y_col,
            color=color_col,
            hover_name=hover_col,
            labels={
                x_col: f"{x_col} ({var_x:.1f}%)",
                y_col: f"{y_col} ({var_y:.1f}%)",
            },
            title=f"PCA: {x_col} vs {y_col}",
            **kwargs,
        )
        fig.update_layout(autosize=True)
        return fig

    # ------------------------------------------------------------------
    # Pairplots
    # ------------------------------------------------------------------

    def _prepare_data(
        self, n_components: int = 3, hue: Optional[Union[str, List[str]]] = None
    ) -> pd.DataFrame:
        """Prepare data for pairplot visualization.

        The first ``n_components`` projected columns are inner-joined on the
        index with the non-empty metadata columns. Only samples present in
        both are kept.

        Args:
            n_components: Number of principal components to include.
            hue: Labels used for coloring points. If a string, data is taken from
                the mint.meta DataFrame. If a non-empty sequence, its values are
                stored (as strings) in a "Label" column; it must have one entry
                per row remaining after the join.

        Returns:
            DataFrame containing the prepared data for visualization.
        """
        df = self.pca.results["df_projected"].copy()
        cols = df.columns.to_list()[:n_components]
        df = df[cols]

        df = pd.merge(
            df, self.pca.mint.meta.dropna(axis=1, how="all"), left_index=True, right_index=True
        )

        if hue and (not isinstance(hue, str)):
            df["Label"] = hue
            df["Label"] = df["Label"].astype(str)

        return df

    def pairplot(
        self,
        n_components: int = 3,
        hue: Optional[Union[str, List[str]]] = None,
        fig_kws: Optional[Dict[str, Any]] = None,
        interactive: bool = False,
        **kwargs,
    ) -> Union[sns.axisgrid.PairGrid, PlotlyFigure]:
        """Create a pairplot of principal components.

        Args:
            n_components: Number of principal components to include in the plot.
            hue: Labels used for coloring points. If a string, data is taken from
                the mint.meta DataFrame. If a list, values are used directly.
            fig_kws: Keyword arguments passed to plt.figure if using seaborn.
            interactive: If True, returns a Plotly interactive figure.
                If False, returns a static Seaborn PairGrid.
            **kwargs: Additional keyword arguments passed to the underlying plotting functions.

        Returns:
            Either a Seaborn PairGrid or a Plotly figure depending on the interactive parameter.
        """
        df = self._prepare_data(n_components=n_components, hue=hue)

        # List-like hues live in the "Label" column created by _prepare_data
        # (capital L: column lookups are case-sensitive).
        hue = self._resolve_hue(hue)

        if interactive:
            return self.pairplot_plotly(df, color_col=hue, **kwargs)
        else:
            return self.pairplot_sns(df, fig_kws=fig_kws, hue=hue, **kwargs)

    def pairplot_sns(
        self, df: pd.DataFrame, fig_kws: Optional[Dict[str, Any]] = None, **kwargs
    ) -> sns.axisgrid.PairGrid:
        """Create a static Seaborn pairplot of principal components.

        Args:
            df: DataFrame containing the data to visualize.
            fig_kws: Keyword arguments passed to plt.figure.
            **kwargs: Additional keyword arguments passed to sns.pairplot.
                If 'vars' is not given, only "PC-*" columns are plotted.

        Returns:
            Seaborn PairGrid object.
        """
        if fig_kws is None:
            fig_kws = {}
        # Only plot PC columns, not merged metadata columns.
        if "vars" not in kwargs:
            kwargs["vars"] = self._pc_columns(df)
        # NOTE(review): sns.pairplot appears to create its own figure, so this
        # one may stay empty. Confirm how fig_kws is meant to take effect.
        plt.figure(**fig_kws)
        g = sns.pairplot(df, **kwargs)
        return g

    def pairplot_plotly(
        self, df: pd.DataFrame, color_col: Optional[str] = None, **kwargs
    ) -> PlotlyFigure:
        """Create an interactive Plotly pairplot of principal components.

        Args:
            df: DataFrame containing the data to visualize.
            color_col: Column name to use for coloring points.
            **kwargs: Additional keyword arguments passed to ff.create_scatterplotmatrix.

        Returns:
            Plotly figure object.
        """
        # Select PC columns explicitly rather than with a regex. This avoids matching
        # metadata names that merely contain "PC" and avoids regex-escaping color_col.
        columns = self._pc_columns(df)
        if color_col is not None and color_col in df.columns and color_col not in columns:
            columns.append(color_col)
        fig = ff.create_scatterplotmatrix(
            df[columns], index=color_col, hovertext=df.index, **kwargs
        )
        # Set the legendgroup equal to the marker color so categories toggle together.
        # Only a single colour string is a valid legendgroup; sequential colouring
        # yields per-point colour arrays.
        for t in fig.data:
            color = t.marker.color
            if isinstance(color, str):
                t.legendgroup = color
        return fig

    # ------------------------------------------------------------------
    # Loadings
    # ------------------------------------------------------------------

    def loadings(
        self, interactive: bool = False, **kwargs
    ) -> Union[sns.axisgrid.FacetGrid, PlotlyFigure]:
        """Plot PCA loadings (feature contributions to principal components).

        Args:
            interactive: If True, returns a Plotly interactive figure.
                If False, returns a static Seaborn FacetGrid.
            **kwargs: Additional keyword arguments passed to the underlying plotting functions.

        Returns:
            Either a Seaborn FacetGrid or a Plotly figure depending on the interactive parameter.
        """
        if interactive:
            return self.loadings_plotly(**kwargs)
        else:
            return self.loadings_sns(**kwargs)

    def loadings_sns(self, **kwargs) -> sns.axisgrid.FacetGrid:
        """Create a static Seaborn plot of PCA loadings.

        Args:
            **kwargs: Additional keyword arguments passed to sns.catplot.
                If 'row' is not specified, it defaults to 'PC'.

        Returns:
            Seaborn FacetGrid object.
        """
        if "row" not in kwargs:
            kwargs["row"] = "PC"
        g = sns.catplot(
            data=self.pca.results["feature_contributions"],
            x="peak_label",
            y="Coefficient",
            kind="bar",
            **kwargs,
        )
        plt.tight_layout()
        return g

    def loadings_plotly(self, **kwargs) -> PlotlyFigure:
        """Create an interactive Plotly plot of PCA loadings.

        Args:
            **kwargs: Additional keyword arguments passed to px.bar.
                If 'facet_row' is not specified, it defaults to 'PC'.

        Returns:
            Plotly figure object.
        """
        if "facet_row" not in kwargs:
            kwargs["facet_row"] = "PC"
        fig = px.bar(
            self.pca.results["feature_contributions"],
            x="peak_label",
            y="Coefficient",
            barmode="group",
            **kwargs,
        )
        return fig

__init__(pca)

Initialize a PCA_Plotter instance.

Parameters:

Name Type Description Default
pca PrincipalComponentsAnalyser

PrincipalComponentsAnalyser instance with results to visualize.

required
Source code in src/ms_mint/pca.py
def __init__(self, pca: PrincipalComponentsAnalyser) -> None:
    """Attach the PCA analyser whose results this plotter will draw.

    Args:
        pca: PrincipalComponentsAnalyser instance with results to visualize.
    """
    # Every plotting method reads its data from ``self.pca.results`` at call time.
    self.pca = pca

cumulative_variance(interactive=False, **kwargs)

Plot the cumulative explained variance of principal components.

Parameters:

Name Type Description Default
interactive bool

If True, returns a Plotly interactive figure. If False, returns a static Matplotlib figure.

False
**kwargs

Additional keyword arguments passed to the underlying plotting functions.

{}

Returns:

Type Description
Union[matplotlib.figure.Figure, plotly.graph_objects.Figure]

Either a Matplotlib figure or a Plotly figure depending on the interactive parameter.

Source code in src/ms_mint/pca.py
def cumulative_variance(
    self, interactive: bool = False, **kwargs
) -> Union[Figure, PlotlyFigure]:
    """Dispatch the cumulative-explained-variance plot to the chosen backend.

    Args:
        interactive: True selects the Plotly implementation; False selects the
            static Matplotlib implementation.
        **kwargs: Forwarded unchanged to the selected implementation.

    Returns:
        A Plotly figure when ``interactive`` is true, otherwise a Matplotlib figure.
    """
    backend = self.cumulative_variance_px if interactive else self.cumulative_variance_sns
    return backend(**kwargs)

cumulative_variance_px(**kwargs)

Create an interactive Plotly plot of cumulative explained variance.

Parameters:

Name Type Description Default
**kwargs

Additional keyword arguments passed to px.bar.

{}

Returns:

Type Description
Figure

Plotly figure showing cumulative explained variance.

Source code in src/ms_mint/pca.py
def cumulative_variance_px(self, **kwargs) -> PlotlyFigure:
    """Build a Plotly bar chart of cumulative explained variance per component.

    Args:
        **kwargs: Extra keyword arguments forwarded to ``px.bar``.

    Returns:
        Plotly figure with one bar per principal component (numbered from 1).
    """
    results = self.pca.results
    n_components = results["n_components"]
    cumulative = results["cum_expl_var"]

    x_name = "Principal Component"
    y_name = "Explained variance [%]"
    frame = pd.DataFrame({x_name: np.arange(n_components) + 1, y_name: cumulative})

    figure = px.bar(
        frame,
        x=x_name,
        y=y_name,
        title="Cumulative explained variance",
        labels={x_name: x_name, y_name: y_name},
        **kwargs,
    )
    figure.update_layout(autosize=True, showlegend=False)
    return figure

cumulative_variance_sns(**kwargs)

Create a static Matplotlib plot of cumulative explained variance.

Parameters:

Name Type Description Default
**kwargs

Additional keyword arguments for figure customization. 'aspect': Width-to-height ratio of the figure (default: 1). 'height': Height of the figure in inches (default: 5).

{}

Returns:

Type Description
Figure

Matplotlib figure showing cumulative explained variance.

Source code in src/ms_mint/pca.py
def cumulative_variance_sns(self, **kwargs) -> Figure:
    """Draw a static Matplotlib bar chart of cumulative explained variance.

    Args:
        **kwargs: Figure customization options:
            'aspect': Width-to-height ratio of the figure (default: 1).
            'height': Height of the figure in inches (default: 5).
            Any other keys are ignored.

    Returns:
        Matplotlib figure showing cumulative explained variance.
    """
    height = kwargs.get("height", 5)
    aspect = kwargs.get("aspect", 1)

    results = self.pca.results
    n_components = results["n_components"]
    cumulative = results["cum_expl_var"]

    # The width depends only on the requested height and aspect ratio.
    fig, ax = plt.subplots(figsize=(height * aspect, height))
    ax.bar(np.arange(n_components) + 1, cumulative, facecolor="grey", edgecolor="none")
    ax.set_xlabel("Principal Component")
    ax.set_ylabel("Explained variance [%]")
    ax.set_title("Cumulative explained variance")
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    ax.set_xticks(range(1, len(cumulative) + 1))
    return fig

loadings(interactive=False, **kwargs)

Plot PCA loadings (feature contributions to principal components).

Parameters:

Name Type Description Default
interactive bool

If True, returns a Plotly interactive figure. If False, returns a static Seaborn FacetGrid.

False
**kwargs

Additional keyword arguments passed to the underlying plotting functions.

{}

Returns:

Type Description
Union[seaborn.FacetGrid, plotly.graph_objects.Figure]

Either a Seaborn FacetGrid or a Plotly figure depending on the interactive parameter.

Source code in src/ms_mint/pca.py
def loadings(
    self, interactive: bool = False, **kwargs
) -> Union[sns.axisgrid.FacetGrid, PlotlyFigure]:
    """Dispatch the PCA loadings plot to the Plotly or Seaborn implementation.

    Args:
        interactive: True selects ``loadings_plotly``; False selects
            ``loadings_sns``.
        **kwargs: Forwarded unchanged to the selected implementation.

    Returns:
        A Plotly figure when ``interactive`` is true, otherwise a Seaborn FacetGrid.
    """
    renderer = self.loadings_plotly if interactive else self.loadings_sns
    return renderer(**kwargs)

loadings_plotly(**kwargs)

Create an interactive Plotly plot of PCA loadings.

Parameters:

Name Type Description Default
**kwargs

Additional keyword arguments passed to px.bar. If 'facet_row' is not specified, it defaults to 'PC'.

{}

Returns:

Type Description
Figure

Plotly figure object.

Source code in src/ms_mint/pca.py
def loadings_plotly(self, **kwargs) -> PlotlyFigure:
    """Build an interactive Plotly bar chart of PCA loadings.

    Args:
        **kwargs: Extra keyword arguments forwarded to ``px.bar``. When
            'facet_row' is absent, one facet row per 'PC' value is drawn.

    Returns:
        Plotly figure object.
    """
    kwargs.setdefault("facet_row", "PC")
    contributions = self.pca.results["feature_contributions"]
    return px.bar(
        contributions,
        x="peak_label",
        y="Coefficient",
        barmode="group",
        **kwargs,
    )

loadings_sns(**kwargs)

Create a static Seaborn plot of PCA loadings.

Parameters:

Name Type Description Default
**kwargs

Additional keyword arguments passed to sns.catplot. If 'row' is not specified, it defaults to 'PC'.

{}

Returns:

Type Description
FacetGrid

Seaborn FacetGrid object.

Source code in src/ms_mint/pca.py
def loadings_sns(self, **kwargs) -> sns.axisgrid.FacetGrid:
    """Draw a static Seaborn bar plot of PCA loadings.

    Args:
        **kwargs: Extra keyword arguments forwarded to ``sns.catplot``. When
            'row' is absent, one facet row per 'PC' value is drawn.

    Returns:
        Seaborn FacetGrid object.
    """
    kwargs.setdefault("row", "PC")
    contributions = self.pca.results["feature_contributions"]
    grid = sns.catplot(
        data=contributions,
        x="peak_label",
        y="Coefficient",
        kind="bar",
        **kwargs,
    )
    plt.tight_layout()
    return grid

pairplot(n_components=3, hue=None, fig_kws=None, interactive=False, **kwargs)

Create a pairplot of principal components.

Parameters:

Name Type Description Default
n_components int

Number of principal components to include in the plot.

3
hue Optional[Union[str, List[str]]]

Labels used for coloring points. If a string, data is taken from the mint.meta DataFrame. If a list, values are used directly.

None
fig_kws Optional[Dict[str, Any]]

Keyword arguments passed to plt.figure if using seaborn.

None
interactive bool

If True, returns a Plotly interactive figure. If False, returns a static Seaborn PairGrid.

False
**kwargs

Additional keyword arguments passed to the underlying plotting functions.

{}

Returns:

Type Description
Union[seaborn.PairGrid, plotly.graph_objects.Figure]

Either a Seaborn PairGrid or a Plotly figure depending on the interactive parameter.

Source code in src/ms_mint/pca.py
def pairplot(
    self,
    n_components: int = 3,
    hue: Optional[Union[str, List[str]]] = None,
    fig_kws: Optional[Dict[str, Any]] = None,
    interactive: bool = False,
    **kwargs,
) -> Union[sns.axisgrid.PairGrid, PlotlyFigure]:
    """Create a pairplot of principal components.

    Args:
        n_components: Number of principal components to include in the plot.
        hue: Labels used for coloring points. If a string, data is taken from
            the mint.meta DataFrame. If a list (or other non-string sequence),
            its values are used directly via the "Label" column that
            ``_prepare_data`` creates.
        fig_kws: Keyword arguments passed to plt.figure if using seaborn.
        interactive: If True, returns a Plotly interactive figure.
            If False, returns a static Seaborn PairGrid.
        **kwargs: Additional keyword arguments passed to the underlying plotting functions.

    Returns:
        Either a Seaborn PairGrid or a Plotly figure depending on the interactive parameter.
    """
    df = self._prepare_data(n_components=n_components, hue=hue)

    # _prepare_data stores non-string hues in a column named "Label" (capital L).
    # Column lookups are case-sensitive, so the colouring column must be
    # "Label" exactly. An empty or missing hue means no colouring.
    if not isinstance(hue, str):
        hue = "Label" if hue else None

    if interactive:
        return self.pairplot_plotly(df, color_col=hue, **kwargs)
    else:
        return self.pairplot_sns(df, fig_kws=fig_kws, hue=hue, **kwargs)

pairplot_plotly(df, color_col=None, **kwargs)

Create an interactive Plotly pairplot of principal components.

Parameters:

Name Type Description Default
df DataFrame

DataFrame containing the data to visualize.

required
color_col Optional[str]

Column name to use for coloring points.

None
**kwargs

Additional keyword arguments passed to ff.create_scatterplotmatrix.

{}

Returns:

Type Description
Figure

Plotly figure object.

Source code in src/ms_mint/pca.py
def pairplot_plotly(
    self, df: pd.DataFrame, color_col: Optional[str] = None, **kwargs
) -> PlotlyFigure:
    """Create an interactive Plotly pairplot of principal components.

    Only columns named ``PC-<n>`` are plotted, plus ``color_col`` (when given
    and present) as the colouring index.

    Args:
        df: DataFrame containing the data to visualize.
        color_col: Column name to use for coloring points.
        **kwargs: Additional keyword arguments passed to ff.create_scatterplotmatrix.

    Returns:
        Plotly figure object.
    """
    # Select columns explicitly instead of with regex "PC|^<color_col>$". That
    # regex also matched metadata names merely containing "PC" and a literal
    # "None" column, and it did not escape regex characters in color_col.
    columns = [c for c in df.columns if isinstance(c, str) and c.startswith("PC-")]
    if color_col is not None and color_col in df.columns and color_col not in columns:
        columns.append(color_col)
    fig = ff.create_scatterplotmatrix(
        df[columns], index=color_col, hovertext=df.index, **kwargs
    )
    # Set the legendgroup equal to the marker color so a category toggles in all
    # panels. Only a single colour string is a usable group key; sequential
    # colouring produces per-point colour arrays.
    for t in fig.data:
        color = t.marker.color
        if isinstance(color, str):
            t.legendgroup = color
    return fig

pairplot_sns(df, fig_kws=None, **kwargs)

Create a static Seaborn pairplot of principal components.

Parameters:

Name Type Description Default
df DataFrame

DataFrame containing the data to visualize.

required
fig_kws Optional[Dict[str, Any]]

Keyword arguments passed to plt.figure.

None
**kwargs

Additional keyword arguments passed to sns.pairplot.

{}

Returns:

Type Description
PairGrid

Seaborn PairGrid object.

Source code in src/ms_mint/pca.py
def pairplot_sns(
    self, df: pd.DataFrame, fig_kws: Optional[Dict[str, Any]] = None, **kwargs
) -> sns.axisgrid.PairGrid:
    """Draw a static Seaborn pairplot of the principal-component columns.

    Args:
        df: DataFrame containing the data to visualize (PC scores plus any
            merged metadata).
        fig_kws: Keyword arguments passed to plt.figure.
        **kwargs: Extra keyword arguments forwarded to ``sns.pairplot``. When
            'vars' is absent, only columns whose name starts with "PC-" are
            plotted.

    Returns:
        Seaborn PairGrid object.
    """
    # Restrict the grid to PC scores so merged metadata columns are not plotted.
    pc_columns = [name for name in df.columns if name.startswith("PC-")]
    kwargs.setdefault("vars", pc_columns)
    # NOTE(review): sns.pairplot appears to build its own figure, so the one
    # created here may stay empty. Confirm how fig_kws is meant to apply.
    plt.figure(**(fig_kws or {}))
    return sns.pairplot(df, **kwargs)

scatter(x_component=1, y_component=2, color_by=None, interactive=False, **kwargs)

Create a scatter plot of two principal components.

Parameters:

Name Type Description Default
x_component int

Principal component number for x-axis (1-indexed).

1
y_component int

Principal component number for y-axis (1-indexed).

2
color_by Optional[str]

Metadata column to use for coloring points.

None
interactive bool

If True, returns a Plotly interactive figure. If False, returns a static Matplotlib figure.

False
**kwargs

Additional keyword arguments passed to plotting functions.

{}

Returns:

Type Description
Union[matplotlib.figure.Figure, plotly.graph_objects.Figure]

Either a Matplotlib figure or a Plotly figure depending on interactive.

Source code in src/ms_mint/pca.py
def scatter(
    self,
    x_component: int = 1,
    y_component: int = 2,
    color_by: Optional[str] = None,
    interactive: bool = False,
    **kwargs,
) -> Union[Figure, PlotlyFigure]:
    """Dispatch a two-component PCA score scatter plot to the chosen backend.

    Args:
        x_component: Principal component number for x-axis (1-indexed).
        y_component: Principal component number for y-axis (1-indexed).
        color_by: Metadata column to use for coloring points.
        interactive: True selects the Plotly implementation; False selects the
            static Matplotlib implementation.
        **kwargs: Forwarded unchanged to the selected implementation.

    Returns:
        A Plotly figure when ``interactive`` is true, otherwise a Matplotlib figure.
    """
    renderer = self.scatter_plotly if interactive else self.scatter_sns
    return renderer(x_component, y_component, color_by, **kwargs)

scatter_plotly(x_component=1, y_component=2, color_by=None, **kwargs)

Create an interactive Plotly scatter plot of two principal components.

Parameters:

Name Type Description Default
x_component int

Principal component number for x-axis (1-indexed).

1
y_component int

Principal component number for y-axis (1-indexed).

2
color_by Optional[str]

Metadata column to use for coloring points.

None
**kwargs

Additional keyword arguments passed to px.scatter.

{}

Returns:

Type Description
Figure

Plotly figure showing the scatter plot.

Source code in src/ms_mint/pca.py
def scatter_plotly(
    self,
    x_component: int = 1,
    y_component: int = 2,
    color_by: Optional[str] = None,
    **kwargs,
) -> PlotlyFigure:
    """Create an interactive Plotly scatter plot of two principal components.

    Args:
        x_component: Principal component number for x-axis (1-indexed).
        y_component: Principal component number for y-axis (1-indexed).
        color_by: Metadata column to use for coloring points. Ignored if it is
            None, "none", or not a column of the (non-empty) metadata.
        **kwargs: Additional keyword arguments passed to px.scatter. They
            override the defaults chosen here (``color``, ``hover_name``,
            ``title``); a ``labels`` mapping is merged over the default axis
            labels.

    Returns:
        Plotly figure showing the scatter plot.

    Raises:
        ValueError: If either requested component is not in the PCA results.
    """
    df = self.pca.results["df_projected"].copy()
    x_col = f"PC-{x_component}"
    y_col = f"PC-{y_component}"

    if x_col not in df.columns or y_col not in df.columns:
        raise ValueError(f"Components {x_component} or {y_component} not available")

    # Merge with metadata if color_by is specified
    if color_by and color_by != "none":
        meta = self.pca.mint.meta.dropna(axis=1, how="all")
        if color_by in meta.columns:
            df = pd.merge(df, meta[[color_by]], left_index=True, right_index=True, how="left")

    # Per-component explained variance (%) recovered from the cumulative curve.
    cum_var = self.pca.results["cum_expl_var"]

    def _explained(component: int) -> float:
        previous = cum_var[component - 2] if component > 1 else 0
        return cum_var[component - 1] - previous

    var_x = _explained(x_component)
    var_y = _explained(y_component)

    color_col = color_by if (color_by and color_by != "none" and color_by in df.columns) else None

    plot_df = df.reset_index()
    # reset_index() inserts the former index as the first column, named after
    # the index (or "index" when the index is unnamed). Use that column for
    # hover text whatever its name is.
    hover_col = plot_df.columns[0]

    labels = {
        x_col: f"{x_col} ({var_x:.1f}%)",
        y_col: f"{y_col} ({var_y:.1f}%)",
    }
    labels.update(kwargs.pop("labels", None) or {})

    # Defaults first, then caller kwargs, so user-supplied options win instead
    # of raising "got multiple values for keyword argument".
    options = {
        "color": color_col,
        "hover_name": hover_col,
        "title": f"PCA: {x_col} vs {y_col}",
    }
    options.update(kwargs)

    fig = px.scatter(plot_df, x=x_col, y=y_col, labels=labels, **options)
    fig.update_layout(autosize=True)
    return fig

scatter_sns(x_component=1, y_component=2, color_by=None, **kwargs)

Create a static scatter plot of two principal components.

Parameters:

Name Type Description Default
x_component int

Principal component number for x-axis (1-indexed).

1
y_component int

Principal component number for y-axis (1-indexed).

2
color_by Optional[str]

Metadata column to use for coloring points.

None
**kwargs

Figure customization options. Only height (default 6) and width (default 8), used as the figure size in inches, are read; other keys are ignored.

{}

Returns:

Type Description
Figure

Matplotlib figure showing the scatter plot.

Source code in src/ms_mint/pca.py
def scatter_sns(
    self,
    x_component: int = 1,
    y_component: int = 2,
    color_by: Optional[str] = None,
    **kwargs,
) -> Figure:
    """Create a static scatter plot of two principal components.

    Args:
        x_component: Principal component number for x-axis (1-indexed).
        y_component: Principal component number for y-axis (1-indexed).
        color_by: Metadata column to use for coloring points. Ignored if it is
            None, "none", or not a column of the (non-empty) metadata. Points
            with no value in that column are drawn in grey as "missing".
        **kwargs: Figure customization. Only ``height`` (default 6) and
            ``width`` (default 8), used as the figure size in inches, are
            read; other keys are ignored.

    Returns:
        Matplotlib figure showing the scatter plot.

    Raises:
        ValueError: If either requested component is not in the PCA results.
    """
    df = self.pca.results["df_projected"].copy()
    x_col = f"PC-{x_component}"
    y_col = f"PC-{y_component}"

    if x_col not in df.columns or y_col not in df.columns:
        raise ValueError(f"Components {x_component} or {y_component} not available")

    # Merge with metadata if color_by is specified
    if color_by and color_by != "none":
        meta = self.pca.mint.meta.dropna(axis=1, how="all")
        if color_by in meta.columns:
            df = pd.merge(df, meta[[color_by]], left_index=True, right_index=True, how="left")

    height = kwargs.get("height", 6)
    width = kwargs.get("width", 8)

    fig, ax = plt.subplots(figsize=(width, height))

    if color_by and color_by != "none" and color_by in df.columns:
        categories = df[color_by].dropna().unique()
        n_categories = len(categories)
        # tab10/tab20 have only 10/20 distinct colours; sampling them at more
        # points gives several categories the same colour, so switch to a
        # continuous map for larger category counts.
        if n_categories <= 10:
            cmap = plt.cm.tab10
        elif n_categories <= 20:
            cmap = plt.cm.tab20
        else:
            cmap = plt.cm.viridis
        colors = cmap(np.linspace(0, 1, n_categories))
        for cat, color in zip(categories, colors):
            mask = df[color_by] == cat
            ax.scatter(df.loc[mask, x_col], df.loc[mask, y_col],
                       c=[color], label=str(cat), alpha=0.7, s=50)
        # NaN never equals a category, so rows without a value (e.g. samples
        # absent from the metadata after the left merge) would otherwise be
        # silently left out of the plot.
        missing = df[color_by].isna()
        if missing.any():
            ax.scatter(df.loc[missing, x_col], df.loc[missing, y_col],
                       c="lightgrey", label="missing", alpha=0.7, s=50)
        ax.legend(title=color_by, bbox_to_anchor=(1.02, 1), loc='upper left')
    else:
        ax.scatter(df[x_col], df[y_col], alpha=0.7, s=50, c='steelblue')

    # Per-component explained variance (%) recovered from the cumulative curve.
    cum_var = self.pca.results["cum_expl_var"]

    def _explained(component: int) -> float:
        previous = cum_var[component - 2] if component > 1 else 0
        return cum_var[component - 1] - previous

    var_x = _explained(x_component)
    var_y = _explained(y_component)

    ax.set_xlabel(f"{x_col} ({var_x:.1f}%)")
    ax.set_ylabel(f"{y_col} ({var_y:.1f}%)")
    ax.set_title(f"PCA: {x_col} vs {y_col}")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    # Lay out this figure explicitly rather than whatever pyplot considers current.
    fig.tight_layout()
    return fig

PrincipalComponentsAnalyser

Class for applying PCA to MS-MINT analysis results.

This class provides functionality to perform Principal Component Analysis on MS-MINT metabolomics data and store the results for visualization.

Attributes:

Name Type Description
mint

The Mint instance containing the data to analyze.

results Optional[Dict[str, Any]]

Dictionary containing PCA results after running the analysis; None until run() has been called.

plot

PCA_Plotter instance for visualizing the PCA results.

Source code in src/ms_mint/pca.py
class PrincipalComponentsAnalyser:
    """Class for applying PCA to MS-MINT analysis results.

    This class provides functionality to perform Principal Component Analysis on
    MS-MINT metabolomics data and store the results for visualization.

    Attributes:
        mint: The Mint instance containing the data to analyze.
        results: Dictionary containing PCA results after running the analysis;
            None until ``run`` has been called.
        plot: PCA_Plotter instance for visualizing the PCA results.
    """

    def __init__(self, mint: Optional["ms_mint.Mint.Mint"] = None) -> None:
        """Initialize a PrincipalComponentsAnalyser instance.

        Args:
            mint: Mint instance containing the data to analyze.
        """
        self.mint = mint
        self.results: Optional[Dict[str, Any]] = None
        self.plot = PCA_Plotter(self)

    def run(
        self,
        n_components: int = 3,
        on: Optional[str] = None,
        var_name: str = "peak_max",
        fillna: Union[str, float] = "median",
        apply: Optional[str] = None,
        groupby: Optional[Union[str, List[str]]] = None,
        scaler: str = "standard",
        peak_labels: Optional[List[str]] = None,
    ) -> None:
        """Run Principal Component Analysis on the current results.

        Builds a table with ``self.mint.crosstab``, fills missing values, fits
        a PCA and stores the outcome in ``self.results``.

        Args:
            n_components: Number of PCA components to calculate. Capped at the
                smaller dimension of the data table.
            on: Deprecated, use var_name instead.
            var_name: Column name from results to use for PCA.
            fillna: Method to fill missing values. One of "median", "mean", "zero",
                or a fill value accepted by ``DataFrame.fillna`` (e.g. a number).
            apply: Transformation to apply to the data before PCA.
            groupby: Column(s) to group by before analysis.
            scaler: Method to scale the data. One of "standard", "robust", "minmax".
            peak_labels: List of peak labels to include. If None, all peaks are used.

        Raises:
            ValueError: If ``fillna`` is an unknown strategy string or the data
                table is empty.

        Warns:
            DeprecationWarning: If the deprecated 'on' parameter is used.
        """
        if on is not None:
            warnings.warn("on is deprecated, use var_name instead", DeprecationWarning)
            var_name = on

        df = self.mint.crosstab(var_name=var_name, apply=apply, scaler=scaler, groupby=groupby, peak_labels=peak_labels)

        # Resolve named strategies. isinstance() keeps this safe when a Series or
        # dict of per-column fill values is passed (``==`` would be elementwise).
        if isinstance(fillna, str):
            if fillna == "median":
                fillna = df.median()
            elif fillna == "mean":
                fillna = df.mean()
            elif fillna == "zero":
                fillna = 0
            else:
                raise ValueError(
                    f"Unknown fillna strategy {fillna!r}; expected 'median', 'mean', "
                    "'zero' or a fill value."
                )

        df = df.fillna(fillna)

        min_dim = min(df.shape)
        if min_dim == 0:
            raise ValueError("No data available for PCA: the crosstab is empty.")
        n_components = min(n_components, min_dim)
        pc_names = [f"PC-{i + 1}" for i in range(n_components)]

        pca = PCA(n_components)
        X_projected = pca.fit_transform(df)
        # Projected samples, one column per component (PC-1, PC-2, ...)
        df_projected = pd.DataFrame(X_projected, index=df.index.get_level_values(0), columns=pc_names)

        # Calculate cumulative explained variance in percent
        explained_variance = pca.explained_variance_ratio_ * 100
        cum_expl_var = np.cumsum(explained_variance)

        # Feature loadings: row "PC-i" holds each feature's weight on that
        # component. Read components_ directly; inverse_transform of unit vectors
        # would add the per-feature mean (pca.mean_) to every coefficient.
        dfc = pd.DataFrame(pca.components_, index=pc_names, columns=df.columns)
        dfc.index.name = "PC"
        # convert to long format
        dfc = dfc.stack().reset_index().rename(columns={0: "Coefficient"})

        self.results = {
            "df_projected": df_projected,
            "cum_expl_var": cum_expl_var,
            "n_components": n_components,
            "type": "PCA",
            "feature_contributions": dfc,
            "class": pca,
        }

__init__(mint=None)

Initialize a PrincipalComponentsAnalyser instance.

Parameters:

Name Type Description Default
mint Optional['ms_mint.Mint.Mint']

Mint instance containing the data to analyze.

None
Source code in src/ms_mint/pca.py
def __init__(self, mint: Optional["ms_mint.Mint.Mint"] = None) -> None:
    """Create an analyser, optionally bound to a Mint instance.

    Nothing is computed here; ``results`` holds None until ``run`` fills it.

    Args:
        mint: Mint instance that supplies the data to analyze. Defaults to None.
    """
    self.results: Optional[Dict[str, Any]] = None  # populated by run()
    self.mint = mint
    # Hand this analyser to the plotter so its plots can read the PCA results.
    self.plot = PCA_Plotter(self)

run(n_components=3, on=None, var_name='peak_max', fillna='median', apply=None, groupby=None, scaler='standard', peak_labels=None)

Run Principal Component Analysis on the current results.

Performs PCA on the data and stores results in self.results.

Parameters:

Name Type Description Default
n_components int

Number of PCA components to calculate.

3
on Optional[str]

Deprecated, use var_name instead.

None
var_name str

Column name from results to use for PCA.

'peak_max'
fillna Union[str, float]

Method to fill missing values. One of "median", "mean", "zero", or a numeric value.

'median'
apply Optional[str]

Transformation to apply to the data before PCA.

None
groupby Optional[Union[str, List[str]]]

Column(s) to group by before analysis.

None
scaler str

Method to scale the data. One of "standard", "robust", "minmax".

'standard'
peak_labels Optional[List[str]]

List of peak labels to include. If None, all peaks are used.

None

Warns:

Type Description
DeprecationWarning

If the deprecated 'on' parameter is used. The warning is issued with warnings.warn; it is not raised as an exception.

Source code in src/ms_mint/pca.py
def run(
    self,
    n_components: int = 3,
    on: Optional[str] = None,
    var_name: str = "peak_max",
    fillna: Union[str, float] = "median",
    apply: Optional[str] = None,
    groupby: Optional[Union[str, List[str]]] = None,
    scaler: str = "standard",
    peak_labels: Optional[List[str]] = None,
) -> None:
    """Run Principal Component Analysis on the current results.

    Builds a table with ``self.mint.crosstab``, fills missing values, fits
    a PCA and stores the outcome in ``self.results``.

    Args:
        n_components: Number of PCA components to calculate. Capped at the
            smaller dimension of the data table.
        on: Deprecated, use var_name instead.
        var_name: Column name from results to use for PCA.
        fillna: Method to fill missing values. One of "median", "mean", "zero",
            or a fill value accepted by ``DataFrame.fillna`` (e.g. a number).
        apply: Transformation to apply to the data before PCA.
        groupby: Column(s) to group by before analysis.
        scaler: Method to scale the data. One of "standard", "robust", "minmax".
        peak_labels: List of peak labels to include. If None, all peaks are used.

    Raises:
        ValueError: If ``fillna`` is an unknown strategy string or the data
            table is empty.

    Warns:
        DeprecationWarning: If the deprecated 'on' parameter is used.
    """
    if on is not None:
        warnings.warn("on is deprecated, use var_name instead", DeprecationWarning)
        var_name = on

    df = self.mint.crosstab(var_name=var_name, apply=apply, scaler=scaler, groupby=groupby, peak_labels=peak_labels)

    # Resolve named strategies. isinstance() keeps this safe when a Series or
    # dict of per-column fill values is passed (``==`` would be elementwise).
    if isinstance(fillna, str):
        if fillna == "median":
            fillna = df.median()
        elif fillna == "mean":
            fillna = df.mean()
        elif fillna == "zero":
            fillna = 0
        else:
            raise ValueError(
                f"Unknown fillna strategy {fillna!r}; expected 'median', 'mean', "
                "'zero' or a fill value."
            )

    df = df.fillna(fillna)

    min_dim = min(df.shape)
    if min_dim == 0:
        raise ValueError("No data available for PCA: the crosstab is empty.")
    n_components = min(n_components, min_dim)
    pc_names = [f"PC-{i + 1}" for i in range(n_components)]

    pca = PCA(n_components)
    X_projected = pca.fit_transform(df)
    # Projected samples, one column per component (PC-1, PC-2, ...)
    df_projected = pd.DataFrame(X_projected, index=df.index.get_level_values(0), columns=pc_names)

    # Calculate cumulative explained variance in percent
    explained_variance = pca.explained_variance_ratio_ * 100
    cum_expl_var = np.cumsum(explained_variance)

    # Feature loadings: row "PC-i" holds each feature's weight on that
    # component. Read components_ directly; inverse_transform of unit vectors
    # would add the per-feature mean (pca.mean_) to every coefficient.
    dfc = pd.DataFrame(pca.components_, index=pc_names, columns=df.columns)
    dfc.index.name = "PC"
    # convert to long format
    dfc = dfc.stack().reset_index().rename(columns={0: "Coefficient"})

    self.results = {
        "df_projected": df_projected,
        "cum_expl_var": cum_expl_var,
        "n_components": n_components,
        "type": "PCA",
        "feature_contributions": dfc,
        "class": pca,
    }

options: show_root_heading: true show_root_full_path: true show_submodules: true members_order: source