Coverage for functions \ flipdare \ analysis \ plotter.py: 92%
131 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-05-08 12:22 +1000
« prev ^ index » next coverage.py v7.13.0, created at 2026-05-08 12:22 +1000
1#!/usr/bin/env python
2# Copyright (c) 2026 Flipdare Pty Ltd. All rights reserved.
3#
4# This file is part of Flipdare's proprietary software and contains
5# confidential and copyrighted material. Unauthorised copying,
6# modification, distribution, or use of this file is strictly
7# prohibited without prior written permission from Flipdare Pty Ltd.
8#
9# This software includes third-party components licensed under MIT,
10# BSD, and Apache 2.0 licences. See THIRD_PARTY_NOTICES for details.
11#
13from typing import ClassVar, TypeIs
14from dataclasses import dataclass
15from enum import StrEnum
16import io
17import pandas as pd
18import seaborn as sns
19import numpy as np
20import matplotlib.pyplot as plt
21from matplotlib.axes import Axes
22from flipdare.app_log import LOG
23from flipdare.app_types import AnalysisArrayType, AnalysisDataType
24from flipdare.constants import GRAPH_LOG_SCALE_THRESHOLD, IS_TRACE
26__all__ = ["Plotter", "ScatterData"]
28# default theme
29sns.set_theme(style="whitegrid")
@dataclass
class ScatterData:
    """Scatter overlay attached to one plotted series.

    ``points`` are the y-values and ``indices`` the matching x-positions
    handed to the scatter plot; ``label`` is the legend entry. Both lists
    must have the same length.
    """

    points: list[float]
    indices: list[int]
    label: str

    def __post_init__(self) -> None:
        """Validate that ``points`` and ``indices`` line up one-to-one.

        Raises:
            ValueError: If the two lists differ in length.
        """
        n_points, n_indices = len(self.points), len(self.indices)
        if n_points == n_indices:
            return
        raise ValueError(
            f"Inconsistent array lengths: {n_points} points != {n_indices} indices."
        )
44class PlotStrategy(StrEnum):
45 LOG_SCALE = "log_scale"
46 NORMAL_SCALE = "normal_scale"
49def _is_nested(
50 val: AnalysisDataType | AnalysisArrayType,
51) -> TypeIs[AnalysisDataType]:
52 """Precisely narrows type in BOTH if and else branches."""
53 return bool(val and isinstance(val[0], list))
56class Plotter:
57 __slots__ = (
58 "_data",
59 "_graph_title",
60 "_legend_labels",
61 "_log_scale_threshold",
62 "_scatter_data",
63 "_x_label",
64 "_x_values",
65 "_y_label",
66 )
68 # fmt: off
69 _MARKERS:ClassVar[list[str]] = ["s", "o", "^", "D", "v", "p", "*", "X", "P"]
71 _COLORS:ClassVar[list[str]] = [
72 "orange", "teal", "green",
73 "purple", "brown", "pink",
74 "gray", "olive", "cyan"]
76 _SCATTER_COLOR:ClassVar[list[str]] = [
77 "coral", "mediumturquoise", "lightgreen",
78 "mediumpurple", "rosybrown", "lightpink",
79 "silver", "olivedrab", "lightcyan"]
80 # fmt: on
82 def __init__(
83 self,
84 title: str,
85 x_label: str,
86 y_label: str,
87 data: list[AnalysisArrayType] | AnalysisArrayType,
88 legend_labels: list[str],
89 x_values: list[str] | None = None,
90 scatter_data: ScatterData | list[ScatterData] | None = None,
91 log_scale_threshold: float = GRAPH_LOG_SCALE_THRESHOLD,
92 ) -> None:
93 self._graph_title = title
94 self._x_label = x_label
95 self._y_label = y_label
96 if not data:
97 self._data = []
98 elif _is_nested(data):
99 self._data = data
100 else:
101 self._data = [data]
103 self._scatter_data = scatter_data
104 self._x_values = x_values
105 self._legend_labels = legend_labels
106 self._log_scale_threshold = log_scale_threshold
108 @property
109 def plot_strategy(self) -> PlotStrategy:
110 # Calculate the max value for each individual line
111 data = self._data
112 threshold = self._log_scale_threshold
114 arr = np.array(data, dtype=float)
115 min_v, max_v = np.nanmin(arr), np.nanmax(arr)
116 # Calculate the ratio between the absolute highest and lowest peaks
117 if IS_TRACE:
118 LOG().trace(f"Data range for plot: min={min_v}, max={max_v}")
120 min_v = 1 if min_v == 0 else min_v # Avoid division by zero in ratio calculation
121 scale_ratio = max_v / min_v
122 strategy = PlotStrategy.LOG_SCALE if scale_ratio > threshold else PlotStrategy.NORMAL_SCALE
124 if IS_TRACE:
125 msg = f"Data scale ratio {scale_ratio:.2f} exceeds threshold {threshold}. Using {strategy.value} strategy."
126 LOG().trace(msg)
128 return strategy
130 def create(self) -> io.BytesIO:
131 raw_data = self._data
132 if len(raw_data) == 0:
133 raise ValueError("Data list is empty. Cannot create plot.")
135 x_label = self._x_label
136 y_label = self._y_label
137 fig, ax = plt.subplots(figsize=(12, 6))
139 data = self._data
140 for idx in range(len(data)):
141 data_entry = np.array(data[idx], dtype=float)
143 # 1. Construct DataFrame with BOTH columns to satisfy Seaborn's lookup
144 # We use range(len()) to create the X-axis points (0, 1, 2...)
145 df = pd.DataFrame({x_label: range(len(data_entry)), y_label: data_entry})
147 # 2. Interpolate NaNs to bridge gaps in the line
148 # 'linear' creates a straight line between the points you DO have
149 df[y_label] = df[y_label].interpolate(method="linear")
151 # Note: interpolate() doesn't fill NaNs at the very start or very end.
152 # If you want to fill those too, chain it with bfill() and ffill()
153 df[y_label] = df[y_label].bfill().ffill()
155 color_idx = idx % len(self._COLORS)
156 marker = self._MARKERS[color_idx]
157 line_color = self._COLORS[color_idx]
158 scatter_color = self._SCATTER_COLOR[color_idx]
160 # plotting
161 self._plot_main_trend(ax, df, idx, marker, line_color)
162 self._plot_scatter_points(ax, idx, scatter_color)
163 # self._plot_missing_markers(ax, df)
165 self._set_plot_defaults(ax)
166 fig.tight_layout()
168 buffer = io.BytesIO()
169 plt.savefig(buffer, format="png")
170 plt.close()
171 buffer.seek(0)
173 return buffer
175 def _plot_main_trend(
176 self,
177 ax: Axes,
178 df: pd.DataFrame,
179 idx: int,
180 marker_color: str,
181 line_color: str,
182 ) -> None:
183 # 1. Plot the main trend
184 x_label = self._x_label
185 y_label = self._y_label
186 legend_labels = self._legend_labels
188 line_label = legend_labels[idx] if legend_labels and idx < len(legend_labels) else y_label
189 sns.lineplot(
190 data=df,
191 ax=ax,
192 x=x_label,
193 y=y_label,
194 marker=marker_color,
195 color=line_color,
196 label=line_label,
197 linewidth=max(3 - (idx * 0.8), 1.4),
198 alpha=0.6,
199 sort=False, # Preserve original order of x-values
200 estimator=None, # Don't aggregate; plot raw data
201 )
203 def _plot_scatter_points(self, ax: Axes, idx: int, scatter_color: str) -> None:
204 scatter_data: list[ScatterData] = []
205 if self._scatter_data:
206 if isinstance(self._scatter_data, list):
207 scatter_data = self._scatter_data
208 else:
209 scatter_data = [self._scatter_data]
211 scatter_values = scatter_data[idx] if scatter_data and idx < len(scatter_data) else None
212 if scatter_values is None:
213 return
215 # 2. Plot scatter points if provided (e.g. outliers)
216 sns.scatterplot(
217 x=scatter_values.indices,
218 y=scatter_values.points,
219 ax=ax, # Target the specific axes
220 color=scatter_color,
221 s=120, # Size
222 label=scatter_values.label,
223 zorder=5, # Ensures markers sit on top of the line
224 marker="o", # Or any marker style you prefer
225 )
227 def _plot_missing_markers(self, ax: Axes, df: pd.DataFrame) -> None:
228 # 1.b generate a scatter plot for missing indices.
229 x_label = self._x_label
230 y_label = self._y_label
232 missing_mask = df[y_label].isna()
234 if missing_mask.any():
235 # 2. Create a temporary Series where NaNs are filled by
236 # drawing a straight line between the surrounding points
237 y_interpolated = df[y_label].interpolate(method="linear")
239 # 3. Create a DataFrame containing ONLY the missing points
240 # but with their new interpolated Y values
241 df_missing = pd.DataFrame(
242 {x_label: df.loc[missing_mask, x_label], y_label: y_interpolated[missing_mask]}
243 )
245 sns.scatterplot(
246 data=df_missing,
247 x=x_label,
248 y=y_label,
249 ax=ax,
250 marker="x",
251 color="black",
252 s=50,
253 linewidth=1.4,
254 legend=False,
255 zorder=3, # Ensures it sits on top of the line
256 )
258 def _set_plot_defaults(self, ax: Axes) -> None:
259 # 1. Handle X-Axis Ticks and Rotation
260 if self._x_values:
261 # Set the numeric positions first, then the string labels
262 ax.set_xticks(range(len(self._x_values)))
263 ax.set_xticklabels(self._x_values, rotation=45)
264 else:
265 # Use tick_params for rotation if using default numeric indices
266 ax.tick_params(axis="x", labelrotation=45)
268 # 2. Handle Scale and Labels via Match Case
269 graph_title = self._graph_title
270 x_label = self._x_label
271 y_label = self._y_label
273 match self.plot_strategy:
274 case PlotStrategy.LOG_SCALE:
275 graph_title += " (Log Scale)"
276 ax.set_yscale("log")
277 ax.set_ylabel(f"{y_label} (Log Scale)")
278 case PlotStrategy.NORMAL_SCALE:
279 ax.set_yscale("linear")
280 ax.set_ylabel(y_label)
282 ax.set_title(graph_title, fontsize=14)
283 ax.set_xlabel(x_label)
284 ax.set_ylabel(y_label)
286 # Note: lineplot sets labels automatically from the DataFrame columns
287 ax.legend(
288 loc="upper left", # Align the TOP LEFT of the legend...
289 bbox_to_anchor=(1.02, 1), # ...to just right (1.02) and top (1) of the plot
290 borderaxespad=0, # Remove padding between anchor and legend
291 frameon=False, # Clean look
292 fontsize="small", # Keeps the box compact
293 )