import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
[docs]
def general_temp(num_row, num_col, size_x, size_y, pad=(0.2, 0.5), *args, **kwargs):
    """
    Create a standardized matplotlib figure with a grid of subplots.

    Applies the global style (``rcParams``) first, then creates the figure
    with ``plt.subplots`` and applies consistent tick and spine styling to
    every axes. Intended as a base template for publication-ready figures.

    Parameters
    ----------
    num_row : int
        Number of subplot rows.
    num_col : int
        Number of subplot columns.
    size_x : float
        Figure width in inches.
    size_y : float
        Figure height in inches.
    pad : tuple of float, optional
        ``(wspace, hspace)`` forwarded to ``Figure.subplots_adjust``: the
        horizontal and vertical spacing between subplots. Default
        ``(0.2, 0.5)``.
    *args
        Additional positional arguments forwarded to ``plt.subplots``
        (placed after ``num_row`` and ``num_col``).
    **kwargs
        Additional keyword arguments forwarded to ``plt.subplots``
        (e.g., ``sharex``, ``sharey``, ``subplot_kw``, ``squeeze``).

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure object.
    axs : matplotlib.axes.Axes or numpy.ndarray of Axes
        A single ``Axes`` object when ``plt.subplots`` returns one
        (``num_row * num_col == 1`` with the default ``squeeze=True``),
        otherwise a flattened 1-D array of ``Axes`` objects.

    Notes
    -----
    - Major ticks: inward, width=2, length=8.
    - Minor ticks: inward, width=1.5, length=4.
    - Spine linewidth: 2.
    - Font: Arial, size 18; mathtext default: 'regular'.
    - Default line width: 2.0; scatter marker: 'o'; marker size 6.
    - ``savefig.dpi``: 300.

    The rcParams listed above are modified *globally*, so they also apply to
    every figure created afterwards in the same session.

    Examples
    --------
    >>> fig, axs = general_temp(2, 3, 12, 8)
    >>> axs[0].plot([1, 2, 3], [4, 5, 6])
    >>> fig.savefig("output.png")
    """
    # The style must be in place before the figure exists: tick labels built
    # while creating/styling the axes resolve their font family and relative
    # font size from rcParams at creation time. Setting them afterwards left
    # the first figure's tick labels in the previous font.
    mpl.rcParams.update({
        'font.family': "Arial",
        'font.size': 18,
        'mathtext.default': 'regular',
        'lines.linewidth': 2.0,
        'scatter.marker': 'o',
        'lines.markersize': 6,
        'savefig.dpi': 300,
    })

    fig, axs = plt.subplots(num_row, num_col, *args, figsize=(size_x, size_y), **kwargs)
    fig.subplots_adjust(wspace=pad[0], hspace=pad[1])

    # plt.subplots returns a bare Axes for a squeezed 1x1 grid, an ndarray
    # otherwise (including 1x1 with squeeze=False).
    if isinstance(axs, np.ndarray):
        axs = axs.ravel()
        axes_to_style = axs
    else:
        axes_to_style = (axs,)

    for ax in axes_to_style:
        ax.tick_params(axis="both", which='major', direction='in', width=2, length=8.0)
        ax.tick_params(axis="both", which='minor', direction='in', width=1.5, length=4.0)
        plt.setp(ax.spines.values(), linewidth=2)
    return fig, axs
[docs]
def set_grid(ax, *args, **kwargs):
    """
    Add a styled major grid to an axis.

    Draws dashed major gridlines with reduced opacity, suitable for
    background reference without overwhelming plotted data.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The target axes on which the grid will be drawn.
    *args
        Additional positional arguments forwarded to ``ax.grid``.
    **kwargs
        Additional keyword arguments forwarded to ``ax.grid``
        (e.g., ``color``, ``axis``). Any of the default style keys
        (``which``, ``linestyle``/``ls``, ``dashes``, ``linewidth``/``lw``,
        ``alpha``) may be given to override the default.

    Notes
    -----
    Grid style defaults: which='major', linestyle='--', dashes=(5, 5),
    linewidth=1, alpha=0.5. The default dash pattern is dropped when the
    caller chooses a ``linestyle``, because a dash pattern would otherwise
    override it.

    Examples
    --------
    >>> fig, axs = general_temp(1, 1, 6, 4)
    >>> set_grid(axs)
    >>> set_grid(axs, alpha=0.2, lw=0.5)
    """
    # Normalize the short aliases so a caller's ``lw=``/``ls=`` replaces the
    # default instead of clashing with it.
    aliases = {'ls': 'linestyle', 'lw': 'linewidth'}
    overrides = {aliases.get(key, key): value for key, value in kwargs.items()}

    style = {'which': 'major', 'linestyle': '--', 'linewidth': 1, 'alpha': 0.5}
    if 'linestyle' not in overrides:
        style['dashes'] = (5, 5)
    style.update(overrides)
    ax.grid(*args, **style)
[docs]
def set_legend(ax, *args, **kwargs):
    """
    Add a styled legend to an axis.

    Applies a white semi-transparent background with a white edge, keeping
    the legend readable without obstructing the plot.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The target axes on which the legend will be placed.
    *args
        Additional positional arguments forwarded to ``ax.legend``.
    **kwargs
        Additional keyword arguments forwarded to ``ax.legend``
        (e.g., ``loc``, ``ncol``, ``fontsize``). ``facecolor``,
        ``framealpha`` and ``edgecolor`` override the defaults
        ('white', 0.7, 'white').

    Returns
    -------
    matplotlib.legend.Legend
        The legend returned by ``ax.legend``.

    Examples
    --------
    >>> ax.plot([1, 2], [3, 4], label='Data A')
    >>> set_legend(ax, loc='upper left')
    """
    style = {'facecolor': 'white', 'framealpha': 0.7, 'edgecolor': 'white'}
    style.update(kwargs)  # caller-supplied values win over the defaults
    return ax.legend(*args, **style)
[docs]
def merge_legend(ax, order=None, *args, **kwargs):
    """
    Deduplicate legend entries and optionally reorder them.

    Useful when multiple plot calls share the same label (e.g., in a loop)
    and only one legend entry per label is desired. An explicit ``order``
    list controls the final display sequence.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The target axes whose legend will be rebuilt.
    order : list of str, optional
        Ordered list of label strings defining the desired legend sequence.
        All labels in ``order`` must already exist among the plotted artists.
        If ``None``, unique labels are displayed in their first-occurrence order.
    *args
        Additional positional arguments forwarded to ``ax.legend``.
    **kwargs
        Additional keyword arguments forwarded to ``ax.legend``
        (e.g., ``loc``, ``ncol``). ``facecolor``, ``framealpha`` and
        ``edgecolor`` override the defaults ('white', 0.7, 'white').

    Returns
    -------
    matplotlib.legend.Legend
        The legend returned by ``ax.legend``.

    Raises
    ------
    KeyError
        If a label in ``order`` does not match any plotted artist label.

    Notes
    -----
    For a duplicated label, the handle of its *last* plotted artist is used
    as the representative; the label keeps the position of its first
    occurrence.

    Examples
    --------
    >>> for val in data:
    ...     ax.plot(x, val, color='blue', label='Series A')
    >>> merge_legend(ax, order=['Series A', 'Series B'])
    """
    handles, labels = ax.get_legend_handles_labels()
    # A dict keeps first-insertion order of keys while later duplicates
    # overwrite the stored handle.
    unique = dict(zip(labels, handles))

    if order is None:
        legend_labels = list(unique)
    else:
        legend_labels = list(order)
        missing = [label for label in legend_labels if label not in unique]
        if missing:
            raise KeyError(f"Labels not found among plotted artists: {missing}")
    legend_handles = [unique[label] for label in legend_labels]

    style = {'facecolor': 'white', 'framealpha': 0.7, 'edgecolor': 'white'}
    style.update(kwargs)
    return ax.legend(legend_handles, legend_labels, *args, **style)
[docs]
def set_unique_legend(ax, *args, **kwargs):
    """
    Add a legend to ``ax`` with one entry per unique label.

    Collects handles and labels from ``ax`` itself, deduplicates them by
    label string, and renders the styled legend on the same ``ax``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes whose artists are collected and on which the legend is placed.
    *args
        Additional positional arguments forwarded to ``ax.legend``.
    **kwargs
        Additional keyword arguments forwarded to ``ax.legend``
        (e.g., ``loc``, ``ncol``). ``facecolor``, ``framealpha`` and
        ``edgecolor`` override the defaults ('white', 0.7, 'white').

    Returns
    -------
    matplotlib.legend.Legend
        The legend returned by ``ax.legend``.

    Notes
    -----
    Labels appear in first-occurrence order; for a duplicated label the
    handle of its last plotted artist is shown.

    Examples
    --------
    >>> ax.plot(x, y1, label='Experiment')
    >>> ax.plot(x, y2, label='Experiment')  # duplicate label
    >>> set_unique_legend(ax)
    """
    # Read from ``ax`` (not the pyplot "current" axes) so multi-axes figures
    # get the legend entries that actually belong to this axes.
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))

    style = {'facecolor': 'white', 'framealpha': 0.7, 'edgecolor': 'white'}
    style.update(kwargs)
    return ax.legend(list(by_label.values()), list(by_label.keys()), *args, **style)
[docs]
def number2letter(number, style=1):
    """
    Map an integer in [1, 26] to the matching letter of the Latin alphabet.

    Parameters
    ----------
    number : int
        Position in the alphabet, 1-based (1 -> 'A', 26 -> 'Z').
    style : {1, 2}, optional
        ``1`` (default) returns an uppercase letter, ``2`` a lowercase one.

    Returns
    -------
    str
        A single letter.

    Raises
    ------
    ValueError
        If ``style`` is not 1 or 2 (checked first), or if ``number`` lies
        outside [1, 26].

    Examples
    --------
    >>> number2letter(3)
    'C'
    >>> number2letter(3, style=2)
    'c'
    """
    if style == 1:
        first_letter = 'A'
    elif style == 2:
        first_letter = 'a'
    else:
        raise ValueError("Style out of range. Please enter a style of 1 or 2.")

    if not 1 <= number <= 26:
        raise ValueError("Number out of range. Please enter a number between 1 and 26.")

    # Letters are contiguous code points, so offset from the first letter.
    return chr(ord(first_letter) + number - 1)
[docs]
def set_label(axs, starting=1, style=1, x=-0.2, y=1.05, **kwargs):
    """
    Annotate a sequence of axes with alphabetic panel labels.

    Places a letter label (e.g. 'A', 'B', ... or '(a)', '(b)', ...) at
    ``(x, y)`` in axis-relative coordinates (``ax.transAxes``) of each axis,
    in size 19 bold by default. Commonly used for multi-panel figures.

    Parameters
    ----------
    axs : iterable of matplotlib.axes.Axes
        Sequence of axes to label, processed in order.
    starting : int, optional
        Integer index of the first label (default ``1`` -> 'A' or '(a)').
    style : {1, 2}, optional
        Label format:
        - ``1`` (default): uppercase letter only (e.g., 'A', 'B', 'C').
        - ``2``: lowercase letter in parentheses (e.g., '(a)', '(b)', '(c)').
    x : float, optional
        Horizontal position in axis-relative coordinates (default ``-0.2``).
    y : float, optional
        Vertical position in axis-relative coordinates (default ``1.05``).
    **kwargs
        Additional keyword arguments forwarded to ``ax.text``
        (e.g., ``color``, ``fontsize``, ``fontweight``). ``fontsize``/``size``
        and ``fontweight``/``weight`` replace the defaults (19, 'bold');
        ``transform`` replaces ``ax.transAxes``.

    Raises
    ------
    ValueError
        If ``style`` is not 1 or 2, or if a label index falls outside
        [1, 26] (raised by ``number2letter``).

    Examples
    --------
    >>> fig, axs = general_temp(1, 3, 12, 4)
    >>> set_label(axs)                  # labels: A, B, C
    >>> set_label(axs, style=2)         # labels: (a), (b), (c)
    >>> set_label(axs, starting=4)      # labels: D, E, F
    >>> set_label(axs, fontsize=14)     # smaller labels
    """
    if style == 1:
        template = '{}'
    elif style == 2:
        template = '({})'
    else:
        raise ValueError("Style out of range. Please enter a style of 1 or 2.")

    # Apply a default only when the caller gave neither spelling: passing an
    # alias pair such as ``size`` + ``fontsize`` together makes recent
    # matplotlib versions raise a TypeError.
    defaults = {}
    if 'size' not in kwargs and 'fontsize' not in kwargs:
        defaults['size'] = 19
    if 'weight' not in kwargs and 'fontweight' not in kwargs:
        defaults['weight'] = 'bold'

    for i, ax in enumerate(axs):
        label = template.format(number2letter(i + starting, style=style))
        text_kwargs = {'transform': ax.transAxes, **defaults, **kwargs}
        ax.text(x, y, label, **text_kwargs)
[docs]
def color_cycle(id):
    """
    Return the hex color at a 1-based position of a fixed 10-color palette.

    The palette matches matplotlib's default ``tab10`` cycle, so series can
    be colored consistently by index.

    Parameters
    ----------
    id : int
        Palette position in [1, 10]:
        1 '#1f77b4' (muted blue), 2 '#ff7f0e' (safety orange),
        3 '#2ca02c' (green), 4 '#d62728' (brick red),
        5 '#9467bd' (muted purple), 6 '#8c564b' (chestnut brown),
        7 '#e377c2' (raspberry pink), 8 '#7f7f7f' (middle gray),
        9 '#bcbd22' (curry yellow-green), 10 '#17becf' (blue-teal).

    Returns
    -------
    str
        Hex color string, or ``'#000000'`` (black) when ``id`` is not one
        of the positions above.

    Examples
    --------
    >>> color_cycle(1)
    '#1f77b4'
    >>> color_cycle(42)
    '#000000'
    """
    palette = (
        '#1f77b4',  # 1  muted blue
        '#ff7f0e',  # 2  safety orange
        '#2ca02c',  # 3  cooked asparagus green
        '#d62728',  # 4  brick red
        '#9467bd',  # 5  muted purple
        '#8c564b',  # 6  chestnut brown
        '#e377c2',  # 7  raspberry pink
        '#7f7f7f',  # 8  middle gray
        '#bcbd22',  # 9  curry yellow-green
        '#17becf',  # 10 blue-teal
    )
    index_to_hex = dict(enumerate(palette, start=1))
    return index_to_hex.get(id, '#000000')
[docs]
def savefig(fig, filename, **kwargs):
    """
    Save a figure to disk at publication-quality resolution.

    Wraps ``fig.savefig`` with ``dpi=600`` and ``bbox_inches='tight'`` so
    that all artists (titles, labels, legends) are included without clipping.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The figure object to save.
    filename : str or path-like
        Output file path, including extension (e.g., ``'fig1.png'``,
        ``'fig1.pdf'``, ``'fig1.svg'``). The format is inferred from
        the extension.
    **kwargs
        Additional keyword arguments forwarded to ``fig.savefig``
        (e.g., ``transparent=True``). ``dpi`` and ``bbox_inches`` given
        here override the defaults.

    Examples
    --------
    >>> fig, axs = general_temp(1, 2, 10, 4)
    >>> savefig(fig, 'results/figure1.png')
    >>> savefig(fig, 'results/figure1_web.png', dpi=150)
    """
    options = {'dpi': 600, 'bbox_inches': 'tight'}
    options.update(kwargs)
    fig.savefig(filename, **options)
[docs]
def insert_image(ax, image_path, x, y, zoom=1.0, rotation=0, order=3, **kwargs):
    """
    Embed a raster image into a matplotlib axis at a given coordinate.

    Reads an image file, optionally rotates it, and places it as an
    ``AnnotationBbox`` artist anchored at ``(x, y)``.
    Useful for inset schematics, icons, or experimental photographs.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The target axes into which the image will be inserted.
    image_path : str or path-like
        Path to the image file (any format supported by
        ``matplotlib.image.imread``, e.g., PNG, JPEG).
    x : float
        Horizontal anchor position (data coordinates unless ``xycoords``
        is passed in ``kwargs``).
    y : float
        Vertical anchor position (same coordinate system as ``x``).
    zoom : float, optional
        Scaling factor applied to the image (default ``1.0``; values
        below 1 shrink, above 1 enlarge).
    rotation : float, optional
        Rotation angle in degrees (default ``0``: the image is used as read,
        without interpolation). Non-zero angles use
        ``scipy.ndimage.rotate`` with ``reshape=True``, so the canvas grows
        to fit the rotated image.
    order : int, optional
        Spline interpolation order for the rotation (default ``3``, scipy's
        default; ``0`` gives nearest-neighbor).
    **kwargs
        Additional keyword arguments forwarded to ``AnnotationBbox``
        (e.g., ``xycoords='axes fraction'``, ``box_alignment``).
        ``frameon`` defaults to ``False``.

    Returns
    -------
    matplotlib.offsetbox.AnnotationBbox
        The artist added to ``ax``.

    Notes
    -----
    Rotation fills the uncovered corners with zeros. For RGB images they
    appear black; for RGBA images the alpha channel is 0 there, so they are
    transparent. Float images (PNGs read by matplotlib are in [0, 1]) are
    clipped back to [0, 1] after rotation, because higher-order spline
    interpolation can overshoot.

    Examples
    --------
    >>> fig, axs = general_temp(1, 1, 6, 6)
    >>> insert_image(axs, 'schematic.png', x=0.5, y=0.5, zoom=0.3, rotation=45)
    """
    from matplotlib.offsetbox import OffsetImage, AnnotationBbox
    import matplotlib.image as mimg

    img = mimg.imread(image_path)
    if rotation:
        from scipy.ndimage import rotate
        img = rotate(img, rotation, order=order)
        if np.issubdtype(img.dtype, np.floating):
            img = np.clip(img, 0.0, 1.0)

    imagebox = OffsetImage(img, zoom=zoom)
    kwargs.setdefault('frameon', False)
    ab = AnnotationBbox(imagebox, (x, y), **kwargs)
    ax.add_artist(ab)
    return ab
[docs]
def add_break(ax, xranges=None, yranges=None):
    """
    Collapse an x- and/or y-interval of an axis and mark it with a break.

    For each axis whose range is given, ``break_axes.scale_axes`` is called
    with the interval spec ``(start, end, 0.01)``, which compresses the
    interval to near-zero extent. Then ``break_axes.broken_and_clip_axes`` draws a break
    at ``start`` with ``which='lower'``. The x-axis is handled before the
    y-axis.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The target axes on which the break will be applied.
    xranges : tuple of float, optional
        ``(x_start, x_end)`` interval to collapse on the x-axis;
        ``None`` (default) leaves the x-axis untouched.
    yranges : tuple of float, optional
        ``(y_start, y_end)`` interval to collapse on the y-axis;
        ``None`` (default) leaves the y-axis untouched.

    Notes
    -----
    Requires the third-party ``break_axes`` package, imported on every call.

    Examples
    --------
    >>> fig, axs = general_temp(1, 1, 6, 4)
    >>> axs.plot([0, 1, 10, 11], [0, 1, 2, 3])
    >>> add_break(axs, xranges=(2, 9))   # hide x in [2, 9]
    """
    from break_axes import broken_and_clip_axes
    from break_axes import scale_axes

    for axis_name, span in (('x', xranges), ('y', yranges)):
        if span is None:
            continue
        start, end = span[0], span[1]
        scale_axes(ax, **{f'{axis_name}_interval': [(start, end, 0.01)]})
        broken_and_clip_axes(ax, **{axis_name: start}, which='lower')
import numpy as np
import matplotlib as mpl
from matplotlib.colors import LogNorm, LinearSegmentedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable
[docs]
def plot_hist2d(ax, x, y, bins=50, log_scale=False, contours=False, **kwargs):
    """
    Draw a 2D density histogram with optional contours and colorbar.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes object to plot on.
    x : array-like
        X data, any shape; it is flattened before binning.
    y : array-like
        Y data, with the same number of elements as ``x``.
    bins : int or array-like, default=50
        Number of bins (or bin edges) used for *both* axes.
    log_scale : bool, default=False
        Whether to use a logarithmic color scale (``LogNorm``). Falls back
        to a linear scale when the non-empty bins hold fewer than two
        distinct values.
    contours : bool, default=False
        Whether to overlay contour lines of the density.
    **kwargs
        Additional customization options:

        colorbar : bool, default=False
        cbar_size : str, default='3%'
        cbar_pad : float, default=0.1
        cmap : str or Colormap, default='viridis'
            A copy is used, so registered/user colormaps are never mutated.
        bad_color : color, default='none'
            Color of empty bins (``'none'`` draws them transparent).
        alpha : float, default=1.0
        n_contours : int, default=5
        contour_colors : str, default='black'
        contour_alpha : float, default=0.6
        contour_linewidths : float, default=1.0
        contour_labels : bool, default=False
        override_contour_levels : array-like, optional
        grid : bool, default=True

    Returns
    -------
    hist : numpy.ndarray, shape (nx, ny)
        Density-normalized histogram (``density=True``); the first axis
        indexes x bins.
    xedges, yedges : numpy.ndarray
        Bin edges.
    cbar_ax : matplotlib.axes.Axes or None
        Colorbar axes (``None`` if ``colorbar=False``).

    Notes
    -----
    Sample pairs in which either coordinate is NaN are dropped before
    binning. If the colorbar cannot be created, a ``RuntimeWarning`` is
    emitted instead of raising.
    """
    import warnings
    from matplotlib.colors import Colormap, LogNorm
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    # Accept lists as well as arrays; flatten multi-dimensional input.
    x_flat = np.asarray(x, dtype=float).ravel()
    y_flat = np.asarray(y, dtype=float).ravel()
    valid = ~(np.isnan(x_flat) | np.isnan(y_flat))
    x_clean = x_flat[valid]
    y_clean = y_flat[valid]

    hist, xedges, yedges = np.histogram2d(x_clean, y_clean, bins=[bins, bins], density=True)

    # Empty bins -> NaN so they are drawn with the colormap's "bad" color.
    hist_plot = hist.copy()
    hist_plot[hist_plot == 0] = np.nan

    norm = None
    min_val = max_val = None
    if log_scale:
        # All non-NaN entries are positive densities at this point.
        positive = hist_plot[hist_plot > 0]
        min_val = positive.min() if positive.size else 1
        max_val = positive.max() if positive.size else 1
        if max_val > min_val:
            norm = LogNorm(vmin=min_val, vmax=max_val)
        else:
            log_scale = False

    cmap = kwargs.get('cmap', 'viridis')
    if isinstance(cmap, str):
        cmap = mpl.colormaps[cmap]
    if isinstance(cmap, Colormap):
        cmap = cmap.copy()
        cmap.set_bad(kwargs.get('bad_color', 'none'))

    X, Y = np.meshgrid(xedges, yedges)
    im = ax.pcolormesh(X, Y, hist_plot.T,
                       cmap=cmap,
                       norm=norm,
                       shading='flat',
                       alpha=kwargs.get('alpha', 1.0))

    if contours and np.any(hist > 0):
        n_contours = kwargs.get('n_contours', 5)
        override_levels = kwargs.get('override_contour_levels')
        if override_levels is not None:
            contour_levels = override_levels
        else:
            if norm is not None:
                contour_levels = np.logspace(np.log10(min_val), np.log10(max_val), n_contours)
            else:
                contour_levels = np.linspace(np.nanmin(hist_plot), np.nanmax(hist_plot), n_contours)
            # Contour requires strictly increasing levels; a single-valued
            # histogram would otherwise produce repeated levels and fail.
            contour_levels = np.unique(contour_levels)
        x_centers = (xedges[:-1] + xedges[1:]) / 2
        y_centers = (yedges[:-1] + yedges[1:]) / 2
        X_contour, Y_contour = np.meshgrid(x_centers, y_centers)
        cs = ax.contour(X_contour, Y_contour, hist_plot.T,
                        levels=contour_levels,
                        colors=kwargs.get('contour_colors', 'black'),
                        alpha=kwargs.get('contour_alpha', 0.6),
                        linewidths=kwargs.get('contour_linewidths', 1.0))
        if kwargs.get('contour_labels', False):
            ax.clabel(cs, inline=True, fontsize=8, fmt='%.1f')

    if kwargs.get('grid', True):
        ax.grid(True, alpha=0.3)

    cbar_ax = None
    if kwargs.get('colorbar', False):
        try:
            divider = make_axes_locatable(ax)
            cbar_ax = divider.append_axes("right",
                                          size=kwargs.get('cbar_size', '3%'),
                                          pad=kwargs.get('cbar_pad', 0.1))
            ax.get_figure().colorbar(im, cax=cbar_ax)
        except Exception as err:  # best-effort: keep the plot, report why
            warnings.warn(f"plot_hist2d: could not add colorbar ({err})",
                          RuntimeWarning, stacklevel=2)

    return hist, xedges, yedges, cbar_ax
[docs]
def plot_hist2d_contour(ax, x, y, levels=5, fill_color="steelblue", fill_alpha=0.4,
                        line_color="steelblue", line_alpha=0.9, linewidths=1.5,
                        gridsize=100, bw_method="scott",
                        ):
    """
    Plot a Gaussian-KDE density of (x, y) as filled contours plus contour lines.

    The density is evaluated on a ``gridsize`` x ``gridsize`` grid spanning
    the data range padded by 5% on each side, so the outermost contour can
    close.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    x, y : array-like of shape (N,)
        Sample coordinates.
    levels : int, float, or array-like
        - integer (Python or NumPy) -> number of auto-spaced levels.
        - float -> single density threshold; fills the region above it.
          The threshold must be below the peak density.
        - array -> explicit iso-density boundaries (must have >= 2 values).
    fill_color : str or RGB tuple
        Color of the filled regions.
    fill_alpha : float
        Opacity of the filled regions.
    line_color : str or RGB tuple
        Color of the contour lines.
    line_alpha : float
        Opacity of the contour lines.
    linewidths : float or sequence of float
        Contour line widths.
    gridsize : int
        Number of grid points per axis for the KDE evaluation.
    bw_method : str, scalar, or callable
        Bandwidth method passed to ``scipy.stats.gaussian_kde``.

    Returns
    -------
    cf : QuadContourSet
        The filled contour set.
    cl : QuadContourSet
        The contour-line set (drawn at ``cf.levels``).

    Raises
    ------
    ValueError
        If a scalar float threshold is not below the maximum KDE density.
    """
    import numpy as np
    from scipy.stats import gaussian_kde

    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    # Grid with a small padding so the outermost contour closes cleanly.
    pad_x = (x.max() - x.min()) * 0.05
    pad_y = (y.max() - y.min()) * 0.05
    xi = np.linspace(x.min() - pad_x, x.max() + pad_x, gridsize)
    yi = np.linspace(y.min() - pad_y, y.max() + pad_y, gridsize)
    Xi, Yi = np.meshgrid(xi, yi)

    kde = gaussian_kde(np.vstack([x, y]), bw_method=bw_method)
    Zi = kde(np.vstack([Xi.ravel(), Yi.ravel()])).reshape(Xi.shape)

    # Normalize scalar `levels` by *type*, not value: 1.0 is a threshold,
    # 1 is a level count.
    if np.ndim(levels) == 0:
        scalar = np.asarray(levels)
        if np.issubdtype(scalar.dtype, np.integer):
            levels = int(scalar)  # auto spacing, handled by matplotlib
        else:
            threshold = float(scalar)
            peak = float(Zi.max())
            if threshold >= peak:
                raise ValueError(
                    f"Density threshold {threshold!r} must be below the peak "
                    f"KDE density {peak!r}."
                )
            levels = [threshold, peak]
    # array-like: passed through unchanged (caller ensures >= 2 values)

    cf = ax.contourf(
        Xi, Yi, Zi,
        levels=levels,
        colors=[fill_color],
        alpha=fill_alpha,
    )
    cl = ax.contour(
        Xi, Yi, Zi,
        levels=cf.levels,
        colors=[line_color],
        alpha=line_alpha,
        linewidths=linewidths,
    )
    return cf, cl
[docs]
class DualYAxis:
    """
    A wrapper around a matplotlib Axes that supports independent y-axis
    coloring for dual-axis (twinx) figures.

    Do not instantiate directly. Use :func:`dualY` to create a pair.
    Attribute reads that the wrapper does not define itself (``plot``,
    ``set_ylabel``, ``set_xlim``, ...) are forwarded to the underlying axes
    via ``__getattr__``. *All* attribute assignments are forwarded as well.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The underlying axes to wrap.
    side : {'left', 'right'}
        Which y-axis spine this object owns. ``'left'`` for the original
        axes, ``'right'`` for the twinx axes.

    Raises
    ------
    ValueError
        If ``side`` is not ``'left'`` or ``'right'``.

    Examples
    --------
    >>> fig, axs = general_temp(1, 1, 8, 5)
    >>> ax1, ax2 = dualY(axs)
    >>> ax1.plot(x, y1, color=color_cycle(1))
    >>> ax2.plot(x, y2, color=color_cycle(2))
    >>> ax1.set_ylabel('Temperature (°C)')
    >>> ax2.set_ylabel('Pressure (Pa)')
    >>> ax1.set_color(color_cycle(1))
    >>> ax2.set_color(color_cycle(2))
    """

    def __init__(self, ax, side):
        if side not in ('left', 'right'):
            raise ValueError(f"side must be 'left' or 'right', got {side!r}.")
        # object.__setattr__ bypasses the forwarding __setattr__ below, so
        # these two fields are stored on the wrapper itself.
        object.__setattr__(self, '_ax', ax)
        object.__setattr__(self, '_side', side)

    def set_color(self, color):
        """
        Apply one color to this axis's y-spine, y ticks, y tick labels and y label.

        Parameters
        ----------
        color : color-like
            Any matplotlib-compatible color string or tuple
            (e.g. ``'#1f77b4'``, ``'red'``, ``(0.1, 0.5, 0.9)``).

        Examples
        --------
        >>> ax1.set_color('#1f77b4')
        >>> ax2.set_color('#ff7f0e')
        """
        ax = object.__getattribute__(self, '_ax')
        side = object.__getattribute__(self, '_side')
        ax.spines[side].set_edgecolor(color)
        ax.tick_params(axis='y', colors=color)
        ax.yaxis.label.set_color(color)

    def __getattr__(self, name):
        # Only invoked when normal lookup on the wrapper fails.
        ax = object.__getattribute__(self, '_ax')
        return getattr(ax, name)

    def __setattr__(self, name, value):
        ax = object.__getattribute__(self, '_ax')
        setattr(ax, name, value)

    def __repr__(self):
        ax = object.__getattribute__(self, '_ax')
        side = object.__getattribute__(self, '_side')
        return f"{type(self).__name__}(ax={ax!r}, side={side!r})"
[docs]
def dualY(ax):
    """
    Add a right-hand twin y-axis to ``ax`` and return both y-axes wrapped.

    The original ``ax`` keeps the left y-axis, and a new ``ax.twinx()``
    provides the right one. The right spine of ``ax`` and the left spine of
    the twin are hidden, so each side shows exactly one spine. The twin gets
    inward ticks (major: width 2, length 8; minor: width 1.5, length 4) and
    spine linewidth 2, matching the styling of :func:`general_temp`.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        A single axes object, typically one element of the array returned by
        :func:`general_temp`.

    Returns
    -------
    ax1 : DualYAxis
        Wrapper around ``ax``; owns the **left** y-axis.
    ax2 : DualYAxis
        Wrapper around the new twin axes; owns the **right** y-axis.

    Notes
    -----
    - Both wrappers forward ordinary ``Axes`` calls, so ``ax1.plot(...)``,
      ``ax1.set_xlim(...)``, ``ax2.set_ylabel(...)`` etc. work as usual.
    - Call :meth:`~DualYAxis.set_color` *after* setting the ylabel so the
      label color is applied correctly.
    - The x-axis ticks and spine color are not modified; style them via
      ``ax1.tick_params(axis='x', ...)``.

    Examples
    --------
    >>> fig, axs = general_temp(1, 1, 8, 5)
    >>> ax1, ax2 = dualY(axs)
    >>> ax1.plot(x, temp, color=color_cycle(1), label='Temperature')
    >>> ax2.plot(x, pressure, color=color_cycle(2), label='Pressure')
    >>> ax1.set_xlabel('Time (s)')
    >>> ax1.set_ylabel('Temperature (°C)')
    >>> ax2.set_ylabel('Pressure (Pa)')
    >>> ax1.set_color(color_cycle(1))
    >>> ax2.set_color(color_cycle(2))
    """
    twin_axes = ax.twinx()

    hidden_spines = ((ax, 'right'), (twin_axes, 'left'))
    for owner, spine_name in hidden_spines:
        owner.spines[spine_name].set_visible(False)

    tick_styles = (('major', 2, 8.0), ('minor', 1.5, 4.0))
    for which, width, length in tick_styles:
        twin_axes.tick_params(axis='both', which=which, direction='in', width=width, length=length)
    plt.setp(twin_axes.spines.values(), linewidth=2)

    left_wrapper = DualYAxis(ax, side='left')
    right_wrapper = DualYAxis(twin_axes, side='right')
    return left_wrapper, right_wrapper
import matplotlib.pyplot as plt
[docs]
def color_text(fig, ax, x, y, parts, colors, sep="", sep_color="black", **kwargs):
    """
    Draw multi-colored text segments on a Matplotlib Axes object.

    Each string in *parts* is rendered consecutively on the same baseline.
    Every segment is shifted right (in points) by the measured rendered width
    of everything drawn before it. This makes it easy to style sub-components
    of a compound label (e.g. ``"aaa-bbb"``) with independent colors.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The Figure that contains *ax*. Used to obtain a renderer for text
        extent measurement and for the point-based offset transforms.
    ax : matplotlib.axes.Axes
        The Axes on which to draw the text.
    x : float
        x-coordinate of the first segment's anchor, in the coordinates of
        ``transform`` (data coordinates by default).
    y : float
        y-coordinate of the anchor, same coordinate system as ``x``.
    parts : sequence of str
        The text segments to render, in left-to-right order.
        Example: ``("aaa", "-", "bbb")``
    colors : sequence of color
        A color for each segment in *parts*; must be the same length as
        *parts*. Accepts any color valid for ``Axes.text``.
        Example: ``("red", "grey", "blue")``
    sep : str, optional
        Separator string inserted between every pair of consecutive
        segments. Default ``""`` (no separator).
    sep_color : color, optional
        Color of the separator. Default ``"black"``.
    **kwargs
        Additional keyword arguments forwarded to every ``Axes.text`` call
        (e.g. ``fontsize``, ``fontweight``, ``va``, ``transform``). Keep the
        default left horizontal alignment: each segment is anchored at its
        own offset, so ``ha='center'`` or ``'right'`` misaligns the segments.

    Returns
    -------
    list of matplotlib.text.Text
        The created text artists (segments and separators) in drawing order.

    Raises
    ------
    ValueError
        If *parts* and *colors* are not the same length.

    Examples
    --------
    >>> fig, ax = plt.subplots()
    >>> color_text(
    ...     fig, ax, x=2, y=5,
    ...     parts=("aaa", "-", "bbb"),
    ...     colors=("red", "grey", "blue"),
    ...     fontsize=16,
    ...     fontweight="bold",
    ... )
    >>> plt.show()
    """
    from matplotlib import transforms

    if len(parts) != len(colors):
        raise ValueError(
            f"'parts' and 'colors' must be the same length, "
            f"got {len(parts)} and {len(colors)}."
        )

    transform = kwargs.pop("transform", ax.transData)
    artists = []
    if not parts:
        return artists

    # One full draw guarantees the figure has a renderer; afterwards each new
    # Text can be measured without redrawing the whole figure per segment.
    fig.canvas.draw()

    offset = 0.0

    def _add(s, color):
        nonlocal offset
        text = ax.text(
            x, y, s,
            color=color,
            transform=transforms.offset_copy(transform, fig=fig, x=offset, y=0, units="points"),
            **kwargs,
        )
        artists.append(text)
        # Window extent is in display pixels; convert to points (72 per inch).
        offset += text.get_window_extent().width / fig.dpi * 72

    last = len(parts) - 1
    for i, (part, color) in enumerate(zip(parts, colors)):
        _add(part, color)
        if sep and i < last:
            _add(sep, sep_color)
    return artists