Scripts

Run scripts outside the pipeline

You can run workflow helper scripts directly via the CLI without running the full Snakemake pipeline.

General form:

episodic script run <script_name> --env <env_name> -- <script_args>

Example (show script help):

episodic script run plot_partition_local_rate_posteriors --env python -- --help

Notes:

  • --env selects the workflow conda environment declared for that script (for example python, phylo, or ggtree).
  • The extra -- separates episodic script run options from the script's own arguments.
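
A full run of the same script, passing its own arguments after the -- separator (file names here are placeholders):

episodic script run plot_partition_local_rate_posteriors --env python -- run1.log run2.log partition-rates.svg partition-rates.csv --burnin 0.1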

date_to_decimal_year(date_str)

Converts a date in the format '%Y-%m-%d' to a decimal year.

Parameters:

  • date_str (str): The date in the format '%Y-%m-%d'. Required.

Returns:

  • float: The decimal year.

Examples:

>>> date_to_decimal_year('2020-07-02')
2020.5
Source code in src/episodic/workflow/utils.py
def date_to_decimal_year(date_str):
    """
    Converts a date in the format '%Y-%m-%d' to a decimal year.

    Args:
      date_str (str): The date in the format '%Y-%m-%d'.

    Returns:
      float: The decimal year.

    Examples:
      >>> date_to_decimal_year('2020-07-02')
      2020.5
    """
    date = datetime.strptime(date_str, '%Y-%m-%d')
    year = date.year
    start_of_year = datetime(year, 1, 1)
    end_of_year = datetime(year + 1, 1, 1)
    days_in_year = (end_of_year - start_of_year).days
    days_passed = (date - start_of_year).days
    decimal_year = year + days_passed / days_in_year
    return decimal_year

decimal_year_to_date(decimal_year)

Converts a decimal year to a date in the format '%Y-%m-%d'.

Parameters:

  • decimal_year (float): The decimal year to convert. Required.

Returns:

  • str: The date in the format '%Y-%m-%d'.

Examples:

>>> decimal_year_to_date(2020.5)
'2020-07-02'
Source code in src/episodic/workflow/utils.py
def decimal_year_to_date(decimal_year):
    """
    Converts a decimal year to a date in the format '%Y-%m-%d'.

    Args:
      decimal_year (float): The decimal year to convert.

    Returns:
      str: The date in the format '%Y-%m-%d'.

    Examples:
      >>> decimal_year_to_date(2020.5)
      '2020-07-02'
    """
    year = int(decimal_year)
    remainder = decimal_year - year
    start_of_year = datetime(year, 1, 1)
    end_of_year = datetime(year + 1, 1, 1)
    days_in_year = (end_of_year - start_of_year).days
    days = remainder * days_in_year
    date = start_of_year + timedelta(days=days)
    return date.strftime('%Y-%m-%d')

analyze_rates(trees_path=typer.Argument(..., help='Path to the BEAST output trees file'), groups_file=typer.Option(..., '--groups-file', help='TSV mapping taxa to group labels'), output_plot_path=typer.Option(..., '--output-plot', help='Output path for the plot file'), output_csv_path=typer.Option(..., '--output-csv', help='Output path for the CSV file'), burnin=typer.Option(0.1, '--burnin', '-b', help='Fraction of trees to discard as burn-in'))

Analyzes rates from a given BEAST output trees file and generates a plot and CSV file.

Parameters:

  • trees_path (str): The path to the BEAST output trees file. Required.
  • groups_file (Path): TSV mapping taxa to group labels. Required.
  • output_plot_path (str): The output path for the plot file. Required.
  • output_csv_path (str): The output path for the CSV file. Required.
  • burnin (float): The fraction of trees to discard as burn-in. Default: 0.1.

Returns:

  • None

Examples:

>>> analyze_rates('trees.nexus', Path('groups.tsv'), 'plot.png', 'stats.csv', 0.1)
Source code in src/episodic/workflow/scripts/phylo_rate_quantile_analysis.py
@app.command()
def analyze_rates(
    trees_path: str = typer.Argument(..., help="Path to the BEAST output trees file"),
    groups_file: Path = typer.Option(..., "--groups-file", help="TSV mapping taxa to group labels"),
    output_plot_path: str = typer.Option(..., "--output-plot", help="Output path for the plot file"),
    output_csv_path: str = typer.Option(..., "--output-csv", help="Output path for the CSV file"),
    burnin: float = typer.Option(0.1, "--burnin", "-b", help="Fraction of trees to discard as burn-in"),
):
    """
    Analyzes rates from a given BEAST output trees file and generates a plot and CSV file.

    Args:
      trees_path (str): The path to the BEAST output trees file.
      groups_file (Path): TSV mapping taxa to group labels.
      output_plot_path (str): The output path for the plot file.
      output_csv_path (str): The output path for the CSV file.
      burnin (float): The fraction of trees to discard as burn-in.

    Returns:
      None

    Examples:
      >>> analyze_rates('trees.nexus', Path('groups.tsv'), 'plot.png', 'stats.csv', 0.1)
    """
    group_members = read_group_members(groups_file)
    groups = list(group_members)

    # time how long it takes to count the trees
    now = datetime.now()
    tree_yielder = dendropy.Tree.yield_from_files(files=[trees_path], schema="nexus", preserve_underscores=True)
    total_trees = sum(1 for _ in tree_yielder)  # Count total trees
    burnin_count = int(total_trees * burnin)
    total_time = datetime.now() - now
    print(f"Total time to count trees: {total_time}")

    tree_yielder = dendropy.Tree.yield_from_files(  # Reinitialize generator
        files=[trees_path], schema="nexus", preserve_underscores=True
    )
    group_stats: Dict[str, Dict[str, List]] = {g: {"ranks": [], "quantiles": []} for g in groups}

    with typer.progressbar(
            tree_yielder,
            length=total_trees,
            label="Processing trees",
            show_pos=True,
            show_percent=True
        ) as progress:
        for tree_idx, tree in enumerate(progress):
            if tree_idx < burnin_count:
                continue
            analyze_tree(tree, group_members, group_stats)

    csv_data = [
        [
            "Group",
            "Mean Rank",
            "Rank Credible Interval",
            "Mean Quantile",
            "Quantile Credible Interval",
        ]
    ]

    plt.figure(figsize=(15, 5 * len(groups)))

    for i, group in enumerate(groups, start=1):
        ranks = group_stats[group]["ranks"]
        quantiles = group_stats[group]["quantiles"]

        mean_rank = np.mean(ranks)
        rank_credible_interval = (np.percentile(ranks, 2.5), np.percentile(ranks, 97.5))

        mean_quantile = np.mean(quantiles)
        quantile_credible_interval = (
            np.percentile(quantiles, 2.5),
            np.percentile(quantiles, 97.5),
        )

        csv_data.append(
            [
                group,
                f"{mean_rank:.4f}",
                f"[{rank_credible_interval[0]:.0f}, {rank_credible_interval[1]:.0f}]",
                f"{mean_quantile:.4f}",
                f"[{quantile_credible_interval[0]:.2f}, {quantile_credible_interval[1]:.2f}]",
            ]
        )

        plt.subplot(len(groups), 2, 2 * i - 1)
        plt.hist(ranks, bins=30, alpha=0.7, color="blue")
        rank_credible_interval_str = f"[{rank_credible_interval[0]:.0f}, {rank_credible_interval[1]:.0f}]"
        plt.title(f"Ranks for {group} - Mean: {mean_rank:.2f}, 95% CI: {rank_credible_interval_str}")
        plt.xlabel("Rank")
        plt.ylabel("Frequency")

        plt.subplot(len(groups), 2, 2 * i)
        plt.hist(quantiles, bins=30, alpha=0.7, color="green")
        quantile_credible_interval_str = f"[{quantile_credible_interval[0]:.2f}, {quantile_credible_interval[1]:.2f}]"
        plt.title(f"Quantiles for {group} - Mean: {mean_quantile:.2f}, 95% CI: {quantile_credible_interval_str}")
        plt.xlabel("Quantile")
        plt.ylabel("Frequency")

    plt.tight_layout()

    plt.savefig(output_plot_path)

    with open(output_csv_path, "w", newline="") as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerows(csv_data)
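
A hypothetical CLI invocation via the script runner (file names are placeholders; the script is assumed to be registered under its module name, as in the example at the top of this page):

episodic script run phylo_rate_quantile_analysis --env phylo -- posterior.trees --groups-file groups.tsv --output-plot ranks.png --output-csv ranks.csv --burnin 0.1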

analyze_tree(tree, group_members, group_stats)

Analyzes a given tree and updates group statistics.

Parameters:

  • tree (dendropy.Tree): The tree to analyze. Required.
  • group_members (Dict[str, List[str]]): Taxa assigned to each analyzed group. Required.
  • group_stats (Dict[str, Dict[str, List]]): A dictionary containing group statistics. Required.

Returns:

  • None

Examples:

>>> analyze_tree(tree, {'A': ['t1'], 'B': ['t2']}, {'A': {'ranks': [], 'quantiles': []}, 'B': {'ranks': [], 'quantiles': []}})
Source code in src/episodic/workflow/scripts/phylo_rate_quantile_analysis.py
def analyze_tree(tree, group_members, group_stats):
    """
    Analyzes a given tree and updates group statistics.

    Args:
      tree (dendropy.Tree): The tree to analyze.
      group_members (Dict[str, List[str]]): Taxa assigned to each analyzed group.
      group_stats (Dict[str, Dict[str, List]]): A dictionary containing group statistics.

    Returns:
      None

    Examples:
      >>> analyze_tree(tree, {'A': ['t1'], 'B': ['t2']}, {'A': {'ranks': [], 'quantiles': []}, 'B': {'ranks': [], 'quantiles': []}})
    """
    # Assuming sorted_rates is generated here for each tree passed to this function
    sorted_rates = extract_and_sort_rates(tree)
    tree_taxa = {node.taxon.label for node in tree if node.taxon is not None}
    for group, taxa in group_members.items():
        taxon_labels = [taxon for taxon in taxa if taxon in tree_taxa]
        if not taxon_labels:
            msg = f"Group '{group}' has no taxa present in the tree."
            raise ValueError(msg)

        mrca = tree.mrca(taxon_labels=taxon_labels)
        group_rate = float(mrca.annotations.get_value("rate"))

        # Use bisect_left for efficient rank finding in a sorted list
        rank = bisect_left(sorted_rates, group_rate) + 1
        group_stats[group]["ranks"].append(rank)

        quantile = rank / len(sorted_rates)
        group_stats[group]["quantiles"].append(quantile)

extract_and_sort_rates(tree)

Extracts and sorts rates from a given tree.

Parameters:

  • tree (dendropy.Tree): The tree to extract rates from. Required.

Returns:

  • List[float]: A list of sorted rates.

Examples:

>>> extract_and_sort_rates(tree)
[0.1, 0.2, 0.3]
Source code in src/episodic/workflow/scripts/phylo_rate_quantile_analysis.py
def extract_and_sort_rates(tree):
    """
    Extracts and sorts rates from a given tree.

    Args:
      tree (dendropy.Tree): The tree to extract rates from.

    Returns:
      List[float]: A list of sorted rates.

    Examples:
      >>> extract_and_sort_rates(tree)
      [0.1, 0.2, 0.3]
    """
    rates = [
        float(node.annotations.get_value("rate"))
        for node in tree
        if node.annotations.get_value("rate")
    ]
    sorted_rates = sorted(rates)
    return sorted_rates

compare(logs=typer.Argument(..., help='BEAST log files'), output_prefix=typer.Option(..., help='Prefix for output files'), gamma_shape=typer.Option(..., help='Shape parameter for the gamma prior'), gamma_scale=typer.Option(..., help='Scale parameter for the gamma prior'), baseline_rate=typer.Option(None, help='Rate parameter used as baseline for posterior contrasts. Defaults to the first detected rate column.'), burnin=typer.Option(0.1, help='Fraction of the chain to discard as burnin'))

Generate comparison visualizations for model rate parameters.

Source code in src/episodic/workflow/scripts/arviz_output.py
@app.command()
def compare(
    logs: List[Path] = typer.Argument(..., help="BEAST log files"),
    output_prefix: Path = typer.Option(..., help="Prefix for output files"),
    gamma_shape: float = typer.Option(..., help="Shape parameter for the gamma prior"),
    gamma_scale: float = typer.Option(..., help="Scale parameter for the gamma prior"),
    baseline_rate: Optional[str] = typer.Option(
        None,
        help="Rate parameter used as baseline for posterior contrasts. Defaults to the first detected rate column.",
    ),
    burnin: float = typer.Option(0.1, help="Fraction of the chain to discard as burnin"),
):
    """Generate comparison visualizations for model rate parameters."""
    df = load_log_files(logs, burnin=burnin)
    var_names = rate_columns(df)
    if not var_names:
        raise typer.BadParameter("No rate columns found in logs.")

    posterior_df = df[var_names]
    display_map = wrapped_label_map(var_names)
    display_names = [display_map[name] for name in var_names]

    display_df = posterior_df.rename(columns=display_map)
    xdata = xr.Dataset.from_dataframe(display_df)
    dataset = az.InferenceData(posterior=xdata)

    # 1) Forest plot (median and HDI intervals)
    az.plot_forest(
        dataset,
        var_names=display_names,
        combined=True,
        hdi_prob=0.95,
        figsize=(12, max(4, len(var_names) * 0.8)),
    )
    plt.tight_layout()
    plt.savefig(f"{output_prefix}-forest.svg")
    plt.close()

    # 2) Prior vs posterior density overlays
    n_rates = len(var_names)
    fig, axes = plt.subplots(n_rates, 1, figsize=(12, max(4, n_rates * 2.4)), squeeze=False)
    prior_values = np.random.gamma(gamma_shape, gamma_scale, len(posterior_df))

    for idx, column in enumerate(var_names):
        ax = axes[idx, 0]
        posterior_values = posterior_df[column].to_numpy()
        post_x, post_y = density_xy(posterior_values)
        prior_x, prior_y = density_xy(prior_values)

        ax.plot(post_x, post_y, label="posterior", linewidth=2)
        ax.plot(prior_x, prior_y, label="prior", linewidth=1.8, linestyle="--")
        ax.fill_between(post_x, post_y, alpha=0.2)
        ax.set_title(display_map[column], fontsize=10)
        ax.set_ylabel("density")
        if idx == n_rates - 1:
            ax.set_xlabel("rate")
        ax.grid(alpha=0.25)
        ax.legend(loc="upper right")

    plt.tight_layout()
    plt.savefig(f"{output_prefix}-prior-vs-posterior.svg")
    plt.close()

    # 3) Posterior contrasts to baseline
    baseline = baseline_rate or var_names[0]
    if baseline not in var_names:
        raise typer.BadParameter(
            f"baseline_rate '{baseline}' not found in detected rate columns: {', '.join(var_names)}"
        )

    contrast_columns = [column for column in var_names if column != baseline]
    if contrast_columns:
        contrast_df = pd.DataFrame(
            {
                wrap_label(
                    f"{display_rate_label(column)} - {display_rate_label(baseline)}",
                    width=32,
                ): posterior_df[column] - posterior_df[baseline]
                for column in contrast_columns
            },
            index=posterior_df.index,
        )
        contrast_xdata = xr.Dataset.from_dataframe(contrast_df)
        contrast_dataset = az.InferenceData(posterior=contrast_xdata)

        axs = az.plot_violin(
            contrast_dataset,
            figsize=(max(8, len(contrast_df.columns) * 3.5), 8),
            textsize=14,
            sharey=True,
            sharex=False,
            rug=False,
            grid=(1, len(contrast_df.columns)),
        )
        for ax in np.ravel(np.array(axs)):
            ax.axhline(0, color="black", linewidth=1, linestyle="--", alpha=0.7)

        plt.tight_layout()
        plt.savefig(f"{output_prefix}-contrast-violin.svg")
        plt.close()
    else:
        typer.echo("Only one rate column found; skipping contrast plot.")

density_xy(values, bins=100)

Compute density x/y coordinates from histogram bins.

Source code in src/episodic/workflow/scripts/arviz_output.py
def density_xy(values: np.ndarray, bins: int = 100) -> tuple[np.ndarray, np.ndarray]:
    """Compute density x/y coordinates from histogram bins."""
    hist, edges = np.histogram(values, bins=bins, density=True)
    centers = (edges[:-1] + edges[1:]) / 2
    return centers, hist
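
A minimal usage sketch (synthetic values, for illustration only):

>>> import numpy as np
>>> x, y = density_xy(np.random.normal(size=1000), bins=50)
>>> len(x), len(y)
(50, 50)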

display_rate_label(label)

Convert BEAST rate column IDs into concise plot labels.

Source code in src/episodic/workflow/scripts/arviz_output.py
def display_rate_label(label: str) -> str:
    """Convert BEAST rate column IDs into concise plot labels."""
    for suffix in (".clock.rate", ".ucgd.mean", ".rate"):
        if label.endswith(suffix):
            label = label[: -len(suffix)]
            break

    parts = label.split(".")
    if len(parts) >= 4 and parts[0] == parts[2] and parts[-1] in {"stem", "clade", "stem_and_clade"}:
        parts.pop(2)
    return ".".join(parts)

load_log_files(logs, burnin=0.1)

Loads BEAST log files into a single pandas DataFrame.

Parameters:

  • logs (List[Path]): A list of paths to the log files. Required.
  • burnin (float): The fraction of the chain to discard as burn-in. Default: 0.1.

Returns:

  • pd.DataFrame: A DataFrame containing the log data.

Source code in src/episodic/workflow/scripts/arviz_output.py
def load_log_files(logs: List[Path], burnin: float = 0.1) -> pd.DataFrame:
    """
    Loads BEAST log files into a single pandas DataFrame.

    Args:
      logs (List[Path]): A list of paths to the log files.
      burnin (float): The fraction of the chain to discard as burnin.

    Returns:
      pd.DataFrame: A DataFrame containing the log data.
    """
    model_groups = defaultdict(list)
    for path in logs:
        model = path.parent.parent.name  # assumes the model name is the grandparent directory of the log file
        model_groups[model].append(path)

    dfs = []
    for model, paths in model_groups.items():
        model_df = pd.DataFrame()
        chain_count = 0
        for trace_log in paths:
            print(trace_log)
            chain_df = pd.read_csv(trace_log, sep="\t", comment="#").rename(columns={"state": "draw"})
            posterior_df = chain_df.truncate(before=burnin * len(chain_df))
            posterior_df["chain"] = chain_count
            model_df = pd.concat([model_df, posterior_df])
            chain_count += 1
        if len(model_groups) > 1:
            # rename the columns to include the model name when there are multiple models
            var_names = [c for c in model_df.columns if c not in ("chain", "draw")]
            rename_dict = {col: f"{model}.{col}" for col in var_names}
            model_df.rename(columns=rename_dict, inplace=True)
        if len(dfs):
            # drop the chain and draw columns from all but the first model
            model_df.drop(columns=["chain", "draw"], inplace=True)
        dfs.append(model_df)
    df = pd.concat(dfs, axis=1)
    df = df.set_index(["chain", "draw"])
    return df
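
A minimal usage sketch (paths are hypothetical; the grandparent directory supplies the model name):

>>> logs = [Path("flc/run1/trace.log"), Path("flc/run2/trace.log")]
>>> df = load_log_files(logs, burnin=0.1)
>>> df.index.names  # chains are concatenated and indexed by (chain, draw)
FrozenList(['chain', 'draw'])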

no_browser()

A context manager that temporarily replaces the Bokeh show function with a dummy one.

Yields:

  • None: Allows the code inside the 'with' block to run.

Notes:

  This context manager prevents the Bokeh show function from opening a browser window when saving a plot.

Source code in src/episodic/workflow/scripts/arviz_output.py
@contextmanager
def no_browser():
    """
    A context manager that temporarily replaces the Bokeh show function with a dummy one.

    Yields:
      None: Allows the code inside the 'with' block to run.

    Notes:
      This context manager is used to prevent the Bokeh show function from opening a browser window when saving a plot.
    """
    # Save the original show function
    original_show = bokeh.io.showing._show_file_with_state

    # Define a custom show function that saves the file instead of opening a browser
    def dummy_show(obj, state, *args, **kwargs):
        """
        A replacement show function that saves the file without opening a browser.

        Args:
          obj (object): The object to show.
          state (object): The state of the object.
          *args: Additional positional arguments.
          **kwargs: Additional keyword arguments.

        Returns:
          None: Does not return anything.
        """
        save(obj, state=state)

    # Replace the Bokeh show function with the dummy one
    bokeh.io.showing._show_file_with_state = dummy_show

    try:
        yield  # This allows the code inside the 'with' block to run
    finally:
        # Restore the original show function
        bokeh.io.showing._show_file_with_state = original_show

rate_columns(df)

Return rate-like columns from a loaded trace dataframe.

Source code in src/episodic/workflow/scripts/arviz_output.py
def rate_columns(df: pd.DataFrame) -> List[str]:
    """Return rate-like columns from a loaded trace dataframe."""
    return [c for c in df.columns if c.endswith(".rate") or c.endswith(".ucgd.mean")]

rates(logs=typer.Argument(..., help='BEAST log files'), output_prefix=typer.Option(..., help='Prefix for output files'), gamma_shape=typer.Option(..., help='Shape parameter for the gamma prior'), gamma_scale=typer.Option(..., help='Scale parameter for the gamma prior'), burnin=typer.Option(0.1, help='Fraction of the chain to discard as burnin'))

Plots the rates from BEAST log files.

Parameters:

  • logs (List[Path]): A list of paths to the log files. Required.
  • output_prefix (Path): The prefix for the output files. Required.
  • gamma_shape (float): The shape parameter for the gamma prior. Required.
  • gamma_scale (float): The scale parameter for the gamma prior. Required.
  • burnin (float): The fraction of the chain to discard as burn-in. Default: 0.1.

Returns:

  • None: Does not return anything.

Source code in src/episodic/workflow/scripts/arviz_output.py
@app.command()
def rates(
    logs: List[Path] = typer.Argument(..., help="BEAST log files"),
    output_prefix: Path = typer.Option(..., help="Prefix for output files"),
    gamma_shape: float = typer.Option(..., help="Shape parameter for the gamma prior"),
    gamma_scale: float = typer.Option(..., help="Scale parameter for the gamma prior"),
    burnin: float = typer.Option(0.1, help="Fraction of the chain to discard as burnin"),
):
    """
    Plots the rates from BEAST log files.

    Args:
      logs (List[Path]): A list of paths to the log files.
      output_prefix (Path): The prefix for the output files.
      gamma_shape (float): The shape parameter for the gamma prior.
      gamma_scale (float): The scale parameter for the gamma prior.
      burnin (float): The fraction of the chain to discard as burnin.

    Returns:
      None: Does not return anything.
    """
    df = load_log_files(logs, burnin=burnin)

    # extract the rate columns
    var_names = rate_columns(df)
    df = df[var_names]
    display_map = wrapped_label_map(var_names)
    df = df.rename(columns=display_map)


    # add a prior rate column
    prior_rate = np.random.gamma(gamma_shape, gamma_scale, len(df))
    df["Prior"] = prior_rate

    # convert to xarray
    xdata = xr.Dataset.from_dataframe(df)
    dataset = az.InferenceData(posterior=xdata)

    for rug in (True, False):
        rug_str = "violin-rug" if rug else "violin"
        # plot the rates
        axs = az.plot_violin(
            dataset,
            figsize=(len(df.columns) * 5, 12),
            textsize=16,
            sharey=True,
            sharex=False,
            rug=rug,
            grid=(1, len(df.columns)),
        )
        plt.savefig(f"{output_prefix}-{rug_str}.svg")

        selected_columns = df.columns[~df.columns.isin(["Prior", "draw", "chain"])]

        for ax in axs.flatten():
            ymax = df[selected_columns].max().max() + df[selected_columns].min().std()
            ymin = df[selected_columns].min().min() - df[selected_columns].min().std()
            ax.set_ylim(ymin, ymax)

        plt.savefig(f"{output_prefix}-{rug_str}-trimmed.svg")

    # plot the trace
    az.plot_trace(
        dataset,
        figsize=(12, len(df.columns) * 4),
        )
    plt.tight_layout(h_pad=2.0)
    plt.subplots_adjust(hspace=0.5)
    plt.savefig(f"{output_prefix}-trace.svg")

    # plot interactive trace
    output_file(filename=f"{output_prefix}-trace.html")
    with no_browser():
        # hack to save the html file without opening it in the browser
        # must set show=True to save the file
        az.plot_trace(dataset, backend="bokeh", show=True)

summary(logs=typer.Argument(..., help='BEAST log files'), output=typer.Argument(..., help='Output csv file'), burnin=0.1)

Generates a summary of the BEAST log files and saves it to a csv file.

Parameters:

  • logs (List[Path]): A list of paths to the log files. Required.
  • output (Path): The output CSV file. Required.
  • burnin (float): The fraction of the chain to discard as burn-in. Default: 0.1.

Returns:

  • None: Does not return anything.

Source code in src/episodic/workflow/scripts/arviz_output.py
@app.command()
def summary(
        logs: List[Path] = typer.Argument(..., help="BEAST log files"),
        output: Path = typer.Argument(..., help="Output csv file"),
        burnin: float = 0.1,
    ):
    """
    Generates a summary of the BEAST log files and saves it to a csv file.

    Args:
      logs (List[Path]): A list of paths to the log files.
      output (Path): The output csv file.
      burnin (float): The fraction of the chain to discard as burnin.

    Returns:
      None: Does not return anything.
    """

    df = load_log_files(logs, burnin=burnin)

    xdata = xr.Dataset.from_dataframe(df)
    dataset = az.InferenceData(posterior=xdata)
    summary = az.summary(dataset, round_to=6)

    # save the summary to csv
    summary.to_csv(output)
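
A hypothetical invocation (the script is assumed to be registered under its module name):

episodic script run arviz_output --env python -- summary flc/run1/trace.log flc/run2/trace.log summary.csv --burnin 0.1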

trace(logs=typer.Argument(..., help='BEAST log files'), directory=typer.Argument(..., help='Output directory'), burnin=typer.Option(0.1, help='Fraction of the chain to discard as burnin'))

Plots the trace from BEAST log files.

Parameters:

  • logs (List[Path]): A list of paths to the log files. Required.
  • directory (Path): The output directory. Required.
  • burnin (float): The fraction of the chain to discard as burn-in. Default: 0.1.

Returns:

  • None: Does not return anything.

Source code in src/episodic/workflow/scripts/arviz_output.py
@app.command()
def trace(
        logs: List[Path] = typer.Argument(..., help="BEAST log files"),
        directory: Path = typer.Argument(..., help="Output directory"),
        burnin: float = typer.Option(0.1, help="Fraction of the chain to discard as burnin"),
):
    """
    Plots the trace from BEAST log files.

    Args:
      logs (List[Path]): A list of paths to the log files.
      directory (Path): The output directory.
      burnin (float): The fraction of the chain to discard as burnin.

    Returns:
      None: Does not return anything.
    """
    df = load_log_files(logs, burnin=burnin)

    # convert to xarray
    xdata = xr.Dataset.from_dataframe(df)
    dataset = az.InferenceData(posterior=xdata)

    output_file(filename=directory / f"{directory.name}-trace.html", title="Static HTML file")

    with no_browser():
        # hack to save the html file without opening it in the browser
        # must set show=True to save the file
        with az.rc_context(rc={'plot.max_subplots': None}):
            az.plot_trace(dataset, backend="bokeh", show=True)

wrap_label(label, width=28)

Wrap long parameter labels to reduce overlap in plots.

Source code in src/episodic/workflow/scripts/arviz_output.py
def wrap_label(label: str, width: int = 28) -> str:
    """Wrap long parameter labels to reduce overlap in plots."""
    label = label.replace("_", " ")
    label = re.sub(r"(?<!\d)\.|\.(?!\d)", " ", label)
    wrapped = textwrap.fill(label, width=width, break_long_words=False, break_on_hyphens=False)
    return wrapped
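
For example (dots between digits, such as decimal points, are preserved by the regex):

>>> wrap_label("geneA.stem_and_clade.rate", width=12)
'geneA stem\nand clade\nrate'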

wrapped_label_map(labels, width=28)

Build a wrapped display-name mapping while preserving uniqueness.

Source code in src/episodic/workflow/scripts/arviz_output.py
def wrapped_label_map(labels: List[str], width: int = 28) -> dict[str, str]:
    """Build a wrapped display-name mapping while preserving uniqueness."""
    mapping = {}
    used = set()
    for original in labels:
        candidate = wrap_label(display_rate_label(original), width=width)
        if candidate in used:
            idx = 2
            alt = f"{candidate}\n({idx})"
            while alt in used:
                idx += 1
                alt = f"{candidate}\n({idx})"
            candidate = alt
        mapping[original] = candidate
        used.add(candidate)
    return mapping

calculate_log_diff(pos_odds, p_odds)

Calculates the log difference of two odds ratios.

Parameters:

  • pos_odds (float): The posterior odds ratio. Required.
  • p_odds (float): The prior odds ratio. Required.

Returns:

  • float: The log difference of the two odds ratios.

Examples:

>>> calculate_log_diff(2, 1)
0.6931471805599453
Source code in src/episodic/workflow/scripts/calculate_odds.py
def calculate_log_diff(pos_odds, p_odds):
    """
    Calculates the log difference of two odds ratios.

    Args:
      pos_odds (float): The posterior odds ratio.
      p_odds (float): The prior odds ratio.

    Returns:
      float: The log difference of the two odds ratios.

    Examples:
      >>> calculate_log_diff(2, 1)
      0.6931471805599453
    """
    log_pos_odds = safe_log(pos_odds)
    log_p_odds = safe_log(p_odds)

    if log_pos_odds is not None and log_p_odds is not None:
        log_diff = log_pos_odds - log_p_odds
    else:
        # Handle the case where one or both of the odds are not valid for log operation
        log_diff = float("inf")  # Or choose another appropriate response

    return log_diff

calculate_odds(logs=typer.Argument(..., help='The path to the log CSV file'), output_file=typer.Argument(..., help='The path to the output CSV file where the results will be saved'), gamma_shape=typer.Option(..., help='The shape parameter for the gamma distribution'), gamma_scale=typer.Option(..., help='The scale parameter for the gamma distribution'), foreground_label=typer.Option(None, help='Optional foreground/local-rate label prefix.'), background_label=typer.Option(None, help='Optional background-rate label prefix.'), burnin=typer.Option(0.1, '--burnin', '-b', help='Fraction of trees to discard as burn-in'))

Calculates the odds and log differences for a given set of data.

Parameters:

  • logs (List[Path]): The paths to the log CSV files. Required.
  • output_file (str): The path to the output CSV file where the results will be saved. Required.
  • gamma_shape (float): The shape parameter for the gamma distribution. Required.
  • gamma_scale (float): The scale parameter for the gamma distribution. Required.
  • foreground_label (Optional[str]): Optional foreground/local-rate label prefix. Default: None.
  • background_label (Optional[str]): Optional background-rate label prefix. Default: None.
  • burnin (float): Fraction of trees to discard as burn-in. Default: 0.1.

Returns:

  • None

Examples:

>>> calculate_odds(['log1.csv', 'log2.csv'], 'results.csv', 2, 1, burnin=0.1)
Calculated odds and log differences saved to results.csv
Source code in src/episodic/workflow/scripts/calculate_odds.py
@app.command()
def calculate_odds(
    logs: List[Path] = typer.Argument(..., help="The path to the log CSV file"),
    output_file: str = typer.Argument(..., help="The path to the output CSV file where the results will be saved"),
    gamma_shape: float = typer.Option(..., help="The shape parameter for the gamma distribution"),
    gamma_scale: float = typer.Option(..., help="The scale parameter for the gamma distribution"),
    foreground_label: Optional[str] = typer.Option(None, help="Optional foreground/local-rate label prefix."),
    background_label: Optional[str] = typer.Option(None, help="Optional background-rate label prefix."),
    burnin: float = typer.Option(0.1, "--burnin", "-b", help="Fraction of trees to discard as burn-in"),
):
    """
    Calculates the odds and log differences for a given set of data.

    Args:
      logs (List[Path]): The paths to the log CSV files.
      output_file (str): The path to the output CSV file where the results will be saved.
      gamma_shape (float): The shape parameter for the gamma distribution.
      gamma_scale (float): The scale parameter for the gamma distribution.
      burnin (float): Fraction of trees to discard as burn-in.

    Returns:
      None

    Examples:
      >>> calculate_odds(['log1.csv', 'log2.csv'], 'results.csv', 2, 1, burnin=0.1)
      Calculated odds and log differences saved to results.csv
    """
    # Read each log file into a DataFrame
    dfs = []
    for log_path in logs:
        log_df = pd.read_csv(log_path, sep="\t", comment="#")
        # discard burn-in
        dfs.append(log_df[int(burnin * len(log_df)) :])

    df = pd.concat(dfs)

    # Identify local clock rate columns and compare them to the matching
    # partition background rate. For backward compatibility, unpartitioned
    # analyses still use `clock.rate`.
    partition_backgrounds = {
        _rate_key(column, ".clock.rate", background_label or "background"): column
        for column in df.columns
        if column.endswith(".clock.rate")
    }
    default_background = "clock.rate" if "clock.rate" in df.columns else None
    if default_background is None and len(partition_backgrounds) == 1:
        default_background = next(iter(partition_backgrounds.values()))

    rate_columns = []
    for col in df.columns:
        if not col.endswith(".rate"):
            continue
        if col == default_background or col.endswith(".clock.rate"):
            continue
        rate_columns.append(col)

    # Generate gamma-distributed random variables for p_fg and p_bg
    n_samples = len(df)
    p_fg = np.random.gamma(gamma_shape, gamma_scale, n_samples)
    p_bg = np.random.gamma(gamma_shape, gamma_scale, n_samples)

    # Calculate the prior odds (p_odds)
    p_p = np.mean(p_fg > p_bg)
    p_odds = p_p / (1 - p_p) if p_p < 1 else float("inf")

    # Results dictionary
    results = []

    # Calculate the odds for each .rate column against its matching background clock rate
    for rate_column in rate_columns:
        pos_fg = df[rate_column]

        rate_key = _rate_key(rate_column, ".rate", foreground_label)
        local_parts = rate_key.split(".")

        background_column = partition_backgrounds.get(rate_key, default_background)
        while background_column is None and local_parts:
            local_parts = local_parts[:-1]
            background_column = partition_backgrounds.get(".".join(local_parts))

        if background_column is None:
            raise ValueError(
                f"Could not find a background clock rate column for '{rate_column}'."
            )
        pos_bg = df[background_column]

        # Calculate the posterior odds (pos_odds)
        pos_p = np.mean(pos_fg > pos_bg)
        pos_odds = pos_p / (1 - pos_p) if pos_p < 1 else float("inf")

        # Calculate the log difference of the odds ratios, if possible
        log_diff = calculate_log_diff(pos_odds, p_odds)

        # Store the results
        results.append(
            {
                "Rate Column": rate_column,
                "Background Column": background_column,
                "p_p": p_p,
                "p_odds": p_odds,
                "pos_p": pos_p,
                "pos_odds": pos_odds,
                "bf": log_diff,
            }
        )

    # Create a DataFrame from the results dictionary
    results_df = pd.DataFrame(results)
    results_df.to_csv(output_file, index=False)

    typer.echo(f"Calculated odds and log differences saved to {output_file}")

safe_log(x)

Returns the logarithm of x if x is positive, otherwise returns None.

Source code in src/episodic/workflow/scripts/calculate_odds.py
def safe_log(x):
    """Returns the logarithm of x if x is positive, otherwise returns None."""
    if x > 0:
        return np.log(x)
    else:
        # Handle the case where x is not positive
        return None

Plot partition-level background vs foreground posterior substitution-rate histograms.

This script consumes one or more posterior sample tables (typically BEAST .log files) and writes:

  • an overlaid frequency histogram of per-partition background and foreground rates
  • a long-format CSV table used to generate the plot

Supported input formats:

  1. Wide BEAST-style columns, e.g.:

    • <partition>.clock.rate (background)
    • <partition>.clade.rate, <partition>.stem.rate, <partition>.stem_and_clade.rate, or equivalent
  2. Long columns:

    • partition, background_rate, clade_rate

For current FLC logs used by Episodic, foreground-rate columns are treated as absolute rates.
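
For illustration, a minimal long-format input table (values are placeholders):

partition,background_rate,clade_rate
geneA,0.0012,0.0034
geneB,0.0011,0.0031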

plot_partition_foreground_rates(posterior_samples=typer.Argument(..., help='Posterior samples files (BEAST .log or CSV/TSV).'), output_plot=typer.Argument(..., help='Output SVG/PDF/PNG path for the frequency histogram.'), output_table=typer.Argument(..., help='Output CSV for long-format rates used in plotting.'), burnin=typer.Option(0.1, '--burnin', '-b', help='Fraction of each chain to discard as burn-in.'), foreground_label=typer.Option(None, help='Optional foreground/foreground-rate label for plot legends.'), background_label=typer.Option(None, help='Optional background-rate label for plot legends.'), bins=typer.Option(220, '--bins', help='Number of histogram bins.'))

Generate partitioned posterior frequency histograms for background vs foreground rates.

Parameters:

  • posterior_samples (List[Path]): One or more posterior sample files (.log, .tsv, .csv). Required.
  • output_plot (Path): Output path for the overlaid frequency histogram (.svg, .png, .pdf). Required.
  • output_table (Path): Output path for the long-format CSV with columns partition, clock_state, and rate. Required.
  • burnin (float): Fraction of each input chain to discard from the start. Default: 0.1.
  • foreground_label (Optional[str]): Optional foreground-rate label for plot legends. Default: None.
  • background_label (Optional[str]): Optional background-rate label for plot legends. Default: None.
  • bins (int): Number of histogram bins. Default: 220.

Raises:

  • typer.BadParameter: If burnin is outside [0, 1).
  • ValueError: If required rate columns are missing or no posterior samples remain.

Source code in src/episodic/workflow/scripts/plot_partition_local_rate_posteriors.py
@app.command()
def plot_partition_foreground_rates(
    posterior_samples: List[Path] = typer.Argument(..., help="Posterior samples files (BEAST .log or CSV/TSV)."),
    output_plot: Path = typer.Argument(..., help="Output SVG/PDF/PNG path for the frequency histogram."),
    output_table: Path = typer.Argument(..., help="Output CSV for long-format rates used in plotting."),
    burnin: float = typer.Option(0.1, "--burnin", "-b", help="Fraction of each chain to discard as burn-in."),
    foreground_label: Optional[str] = typer.Option(None, help="Optional foreground/foreground-rate label for plot legends."),
    background_label: Optional[str] = typer.Option(None, help="Optional background-rate label for plot legends."),
    bins: int = typer.Option(220, "--bins", help="Number of histogram bins."),
):
    """Generate partitioned posterior frequency histograms for background vs foreground rates.

    Args:
        posterior_samples: One or more posterior sample files (`.log`, `.tsv`, `.csv`).
        output_plot: Output path for overlaid frequency histogram (`.svg`, `.png`, `.pdf`).
        output_table: Output path for long-format CSV with columns
            `partition`, `clock_state`, and `rate`.
        burnin: Fraction of each input chain to discard from the start.

    Raises:
        typer.BadParameter: If `burnin` is outside `[0, 1)`.
        ValueError: If required rate columns are missing or no posterior samples remain.
    """
    if burnin < 0 or burnin >= 1:
        raise typer.BadParameter("--burnin must be in [0, 1).")
    if bins < 1:
        raise typer.BadParameter("--bins must be >= 1.")

    dfs = []
    for path in posterior_samples:
        raw = _read_table(path)
        dfs.append(_drop_burnin(raw, burnin=burnin))

    combined = pd.concat(dfs, ignore_index=True)

    if {"partition", "background_rate", "clade_rate"}.issubset(combined.columns):
        plot_df = _extract_long(combined)
    else:
        plot_df = _extract_wide(
            combined,
            foreground_label=foreground_label,
            background_label=background_label,
        )

    plot_df = plot_df.dropna(subset=["partition", "clock_state", "rate"])
    if plot_df.empty:
        raise ValueError("No posterior rate samples available after filtering and burn-in.")

    plot_df["partition"] = plot_df["partition"].astype(str)
    plot_df["clock_state"] = plot_df["clock_state"].map(
        lambda state: _display_state(str(state), foreground_label, background_label)
    )
    background_state = _display_state("background", foreground_label, background_label)
    foreground_states = sorted(state for state in plot_df["clock_state"].unique() if state != background_state)
    state_order = [background_state, *foreground_states]
    plot_df["clock_state"] = pd.Categorical(plot_df["clock_state"], categories=state_order, ordered=True)

    partitions = sorted(plot_df["partition"].unique())
    palette = _build_palette(partitions, states=state_order)

    sns.set_style("whitegrid")
    fig, ax = plt.subplots(figsize=(11, 7.5))
    annotations = []
    finite_rates = plot_df["rate"].to_numpy(dtype=float)
    finite_rates = finite_rates[np.isfinite(finite_rates)]
    if finite_rates.size == 0:
        raise ValueError("No finite posterior rate samples available for plotting.")
    histogram_bins = _histogram_bins(finite_rates, bins=bins)

    for state in state_order:
        for partition in partitions:
            subset = plot_df[
                (plot_df["partition"] == partition) & (plot_df["clock_state"] == state)
            ]
            if subset.empty:
                continue
            samples = subset["rate"].to_numpy(dtype=float)
            samples = samples[np.isfinite(samples)]
            if samples.size == 0:
                continue
            weights = np.ones(samples.size, dtype=float) / samples.size
            ax.hist(
                samples,
                bins=histogram_bins,
                weights=weights,
                alpha=0.35 if state == background_state else 0.42,
                linewidth=0.18,
                color=palette[(partition, state)],
                edgecolor=palette[(partition, state)],
                label=f"{partition} ({state})",
            )

            peak_x, peak_frequency = _estimate_frequency_peak(samples, histogram_bins)
            annotations.append((partition, state, peak_x, peak_frequency))

    ymin, ymax = ax.get_ylim()
    y_span = max(ymax - ymin, 1e-9)
    for idx, (partition, state, peak_x, peak_frequency) in enumerate(annotations):
        y_text = peak_frequency + y_span * (0.02 + (idx % 3) * 0.015)
        ax.text(
            peak_x,
            y_text,
            partition,
            color=palette[(partition, state)],
            fontsize=12,
            fontweight="bold",
            ha="center",
            va="bottom",
        )

    ax.set_xlabel("substitutions/site/year", fontsize=19, fontweight="bold")
    ax.set_ylabel("Frequency", fontsize=16)
    ax.set_title("Partitioned fixed local clock posterior rates", fontsize=16)
    if np.all(finite_rates > 0):
        ax.set_xscale("log")
        ax.set_xlim(float(histogram_bins[0]), float(histogram_bins[-1]))
        ticks = _rate_ticks(float(histogram_bins[0]), float(histogram_bins[-1]))
        if ticks:
            ax.set_xticks(ticks)
    ax.xaxis.set_major_formatter(FuncFormatter(_scientific_tick))
    y_limit, y_ticks, y_tick_labels = _frequency_ticks(ax.get_ylim()[1])
    ax.set_ylim(0, y_limit)
    ax.set_yticks(y_ticks)
    ax.set_yticklabels(y_tick_labels)
    ax.tick_params(axis="both", labelsize=13)
    ax.grid(True, which="major", alpha=0.22, linewidth=1.2)
    ax.grid(True, which="minor", axis="x", alpha=0.05)
    ax.set_axisbelow(True)
    for spine in ["top", "right"]:
        ax.spines[spine].set_visible(False)
    handles, labels = _legend_entries(plot_df, partitions, state_order, palette, background_state)
    legend = ax.legend(handles, labels, frameon=True, fontsize=9)
    for text in legend.get_texts():
        if text.get_text() in state_order:
            text.set_fontweight("bold")

    output_plot.parent.mkdir(parents=True, exist_ok=True)
    output_table.parent.mkdir(parents=True, exist_ok=True)

    fig.tight_layout()
    fig.savefig(output_plot)
    plt.close(fig)

    plot_df.to_csv(output_table, index=False)

Plot posterior GTR substitution-model relative rates from FLC analyses.

The current Episodic BEAST template estimates one GTR relative-rate vector per alignment partition. FLC models change branch clock rates, not branch-specific substitution-model relative rates, so this plot shows the logged GTR relative rates that are actually estimated for each partition.

plot_flc_substitution_model_rates(logs=typer.Argument(..., help='BEAST log files from a partitioned FLC analysis.'), output_plot=typer.Argument(..., help='Output SVG/PDF/PNG path.'), output_table=typer.Argument(..., help='Output CSV path for long-format logged relative rates.'), burnin=typer.Option(0.1, '--burnin', '-b', help='Fraction of each chain to discard as burn-in.'))

Plot posterior log GTR relative rates by partition and substitution type.

Source code in src/episodic/workflow/scripts/plot_flc_substitution_model_rates.py
@app.command()
def plot_flc_substitution_model_rates(
    logs: List[Path] = typer.Argument(..., help="BEAST log files from a partitioned FLC analysis."),
    output_plot: Path = typer.Argument(..., help="Output SVG/PDF/PNG path."),
    output_table: Path = typer.Argument(..., help="Output CSV path for long-format logged relative rates."),
    burnin: float = typer.Option(0.1, "--burnin", "-b", help="Fraction of each chain to discard as burn-in."),
):
    """Plot posterior log GTR relative rates by partition and substitution type."""
    if burnin < 0 or burnin >= 1:
        raise typer.BadParameter("--burnin must be in [0, 1).")

    df = pd.concat([_read_log(path, burnin=burnin) for path in logs], ignore_index=True)
    rates = _extract_relative_rates(df)

    output_plot.parent.mkdir(parents=True, exist_ok=True)
    output_table.parent.mkdir(parents=True, exist_ok=True)
    rates.to_csv(output_table, index=False)

    sns.set_style("whitegrid")
    partitions = sorted(rates["partition"].unique())
    ncols = min(3, len(partitions))
    nrows = math.ceil(len(partitions) / ncols)
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(max(4, ncols * 3.4), max(3.4, nrows * 3.2)),
        squeeze=False,
    )

    for ax in axes.flatten():
        ax.set_visible(False)

    for ax, partition in zip(axes.flatten(), partitions):
        ax.set_visible(True)
        _plot_partition(ax, rates[rates["partition"] == partition], partition)

    fig.suptitle("GTR substitution-model relative rates", fontsize=13)
    fig.tight_layout()
    fig.savefig(output_plot)
    plt.close(fig)

add_posterior_summary(ax, mean, median, hdi_low, hdi_high)

Add a compact summary annotation to the histogram panel.

Source code in src/episodic/workflow/scripts/plot_traces.py
def add_posterior_summary(ax: plt.Axes, mean: float, median: float, hdi_low: float, hdi_high: float) -> None:
    """Add a compact summary annotation to the histogram panel."""
    summary = (
        f"mean   {mean:.3g}\n"
        f"median {median:.3g}\n"
        f"95% CI [{hdi_low:.3g}, {hdi_high:.3g}]"
    )
    ax.text(
        0.98,
        0.98,
        summary,
        transform=ax.transAxes,
        ha="right",
        va="top",
        fontsize=9,
        color=TEXT_COLOR,
        bbox={"boxstyle": "round,pad=0.35", "facecolor": "white", "edgecolor": "#d1d5db", "alpha": 0.95},
    )

camel_to_title_case(column)

Transform camelCase or dotted names into readable title case.

Source code in src/episodic/workflow/scripts/plot_traces.py
def camel_to_title_case(column: str) -> str:
    """Transform camelCase or dotted names into readable title case."""
    column = column.replace(".", " ")
    return re.sub(r"((?<=[a-z])[A-Z]|(?<!\A)[A-Z](?=[a-z]))", r" \1", column).title()

debug_log(enabled, message)

Print a debug message when enabled.

Source code in src/episodic/workflow/scripts/plot_traces.py
def debug_log(enabled: bool, message: str) -> None:
    """Print a debug message when enabled."""
    if enabled:
        typer.echo(f"[debug] {message}")

downsample_xy(x, y, max_points)

Evenly downsample paired x/y series for fast plotting.

Source code in src/episodic/workflow/scripts/plot_traces.py
def downsample_xy(x: pd.Series, y: pd.Series, max_points: int) -> tuple[np.ndarray, np.ndarray]:
    """Evenly downsample paired x/y series for fast plotting."""
    if max_points <= 0 or len(x) <= max_points:
        return x.to_numpy(), y.to_numpy()

    indices = np.linspace(0, len(x) - 1, num=max_points, dtype=int)
    indices = np.unique(indices)
    return x.iloc[indices].to_numpy(), y.iloc[indices].to_numpy()

get_axis_limits(values)

Compute padded y-axis limits for a series.

Source code in src/episodic/workflow/scripts/plot_traces.py
def get_axis_limits(values: pd.Series) -> tuple[float, float]:
    """Compute padded y-axis limits for a series."""
    value_min = float(values.min())
    value_max = float(values.max())
    padding = 0.1 * abs(value_max - value_min)

    if padding == 0:
        padding = 0.1 * abs(value_max) if value_max != 0 else 0.1

    return value_min - padding, value_max + padding

plot_single_variable(output_path, variable, burnin_df, posterior_df, max_points, output_format, dpi, debug)

Render and save one trace plot.

Source code in src/episodic/workflow/scripts/plot_traces.py
def plot_single_variable(
    output_path: Path,
    variable: str,
    burnin_df: pd.DataFrame,
    posterior_df: pd.DataFrame,
    max_points: int,
    output_format: str,
    dpi: int,
    debug: bool,
) -> None:
    """Render and save one trace plot."""
    variable_start = time.perf_counter()
    variable_title = camel_to_title_case(variable)
    typer.echo(f"Plotting {variable_title}")
    debug_log(debug, f"[{variable}] start")

    stats_start = time.perf_counter()
    posterior_series = posterior_df[variable]
    burnin_series = burnin_df[variable]
    burnin_y_min, burnin_y_max = get_axis_limits(burnin_series) if not burnin_series.empty else get_axis_limits(posterior_series)
    posterior_y_min, posterior_y_max = get_axis_limits(posterior_series)
    mean = float(posterior_series.mean())
    median = float(posterior_series.median())
    hdi_low, hdi_high = np.quantile(posterior_series.to_numpy(), [0.025, 0.975])
    debug_log(debug, f"[{variable}] computed stats in {time.perf_counter() - stats_start:.2f}s")

    sample_start = time.perf_counter()
    burnin_x, burnin_y = downsample_xy(burnin_df["state"], burnin_series, max_points)
    posterior_x, posterior_y = downsample_xy(posterior_df["state"], posterior_series, max_points)
    debug_log(debug, f"[{variable}] downsampled data in {time.perf_counter() - sample_start:.2f}s")

    figure_start = time.perf_counter()
    fig, axes = plt.subplots(
        1,
        3,
        figsize=(FIGURE_WIDTH, FIGURE_HEIGHT),
        gridspec_kw={"width_ratios": [1, 3, 1]},
        constrained_layout=True,
    )
    burnin_ax, posterior_ax, hist_ax = axes

    fig.patch.set_facecolor("white")
    fig.suptitle(variable_title, fontsize=18, color=TEXT_COLOR, fontweight="semibold")

    for ax in axes:
        style_axis(ax)

    burnin_ax.set_ylim(burnin_y_min, burnin_y_max)
    posterior_ax.set_ylim(posterior_y_min, posterior_y_max)
    hist_ax.set_ylim(posterior_y_min, posterior_y_max)

    burnin_ax.axhspan(burnin_y_min, burnin_y_max, color=POSTERIOR_FILL, zorder=0)
    posterior_ax.axhspan(posterior_y_min, posterior_y_max, color=POSTERIOR_FILL, zorder=0)
    posterior_ax.axhspan(hdi_low, hdi_high, color="#c6dbef", alpha=0.65, zorder=0)

    if len(burnin_x) > 0:
        burnin_ax.axvspan(float(burnin_x.min()), float(burnin_x.max()), color=BURNIN_COLOR, alpha=0.12, zorder=0)
        burnin_ax.plot(burnin_x, burnin_y, color=BURNIN_COLOR, linewidth=1.1, alpha=0.95)
        if len(burnin_x) <= 1000:
            burnin_ax.scatter(burnin_x, burnin_y, color=BURNIN_COLOR, s=7, alpha=0.45, linewidths=0)

    posterior_ax.plot(posterior_x, posterior_y, color=POSTERIOR_COLOR, linewidth=1.1, alpha=0.95)
    if len(posterior_x) <= 1000:
        posterior_ax.scatter(posterior_x, posterior_y, color=POSTERIOR_COLOR, s=7, alpha=0.4, linewidths=0)

    posterior_ax.axhline(mean, color=MEAN_COLOR, linewidth=1.2, linestyle="--", alpha=0.9)
    posterior_ax.axhline(median, color=MEDIAN_COLOR, linewidth=1.2, linestyle=":", alpha=0.9)

    hist_ax.hist(
        posterior_series.to_numpy(),
        bins=60,
        orientation="horizontal",
        density=True,
        color=POSTERIOR_COLOR,
        alpha=0.78,
        edgecolor="white",
        linewidth=0.6,
    )
    hist_ax.axhline(mean, color=MEAN_COLOR, linewidth=1.2, linestyle="--")
    hist_ax.axhline(median, color=MEDIAN_COLOR, linewidth=1.2, linestyle=":")
    hist_ax.axhspan(hdi_low, hdi_high, color="#c6dbef", alpha=0.35, zorder=0)
    add_posterior_summary(hist_ax, mean, median, hdi_low, hdi_high)
    hist_ax.set_xlim(left=0)

    burnin_ax.set_title("Burn-in", fontsize=12, color=TEXT_COLOR, fontweight="semibold")
    posterior_ax.set_title("Posterior", fontsize=12, color=TEXT_COLOR, fontweight="semibold")
    hist_ax.set_title("Posterior density", fontsize=12, color=TEXT_COLOR, fontweight="semibold")

    burnin_ax.set_xlabel("MCMC state", color=TEXT_COLOR)
    posterior_ax.set_xlabel("MCMC state", color=TEXT_COLOR)
    hist_ax.set_xlabel("Density", color=TEXT_COLOR)
    burnin_ax.set_ylabel(variable_title)
    posterior_ax.set_ylabel("")
    hist_ax.set_ylabel("")

    burnin_ax.text(
        0.02,
        0.98,
        f"n={len(burnin_df):,}",
        transform=burnin_ax.transAxes,
        ha="left",
        va="top",
        fontsize=9,
        color=TEXT_COLOR,
    )
    posterior_ax.text(
        0.02,
        0.98,
        f"n={len(posterior_df):,}",
        transform=posterior_ax.transAxes,
        ha="left",
        va="top",
        fontsize=9,
        color=TEXT_COLOR,
    )

    debug_log(debug, f"[{variable}] built matplotlib figure in {time.perf_counter() - figure_start:.2f}s")

    write_start = time.perf_counter()
    save_kwargs = {"bbox_inches": "tight"}
    if output_format == "png":
        save_kwargs["dpi"] = dpi
    fig.savefig(output_path, **save_kwargs)
    plt.close(fig)
    debug_log(debug, f"[{variable}] wrote figure in {time.perf_counter() - write_start:.2f}s")
    debug_log(debug, f"[{variable}] total {time.perf_counter() - variable_start:.2f}s")

plot_traces(trace_log=TRACE_LOG_ARGUMENT, output=OUTPUT_ARGUMENT, burnin=0.1, max_points=MAX_POINTS_OPTION, dpi=DPI_OPTION, debug=DEBUG_OPTION, output_format=FORMAT_OPTION)

Produce fast publication-ready trace plots using Matplotlib.

Source code in src/episodic/workflow/scripts/plot_traces.py
def plot_traces(
    trace_log: Path = TRACE_LOG_ARGUMENT,
    output: Path = OUTPUT_ARGUMENT,
    burnin: float = 0.1,
    max_points: int = MAX_POINTS_OPTION,
    dpi: int = DPI_OPTION,
    debug: bool = DEBUG_OPTION,
    output_format: str = FORMAT_OPTION,
) -> None:
    """Produce fast publication-ready trace plots using Matplotlib."""
    start_time = time.perf_counter()
    output.mkdir(exist_ok=True, parents=True)

    output_format = output_format.lower()
    if output_format not in VALID_FORMATS:
        msg = f"Unsupported output format '{output_format}'. Choose from: {', '.join(sorted(VALID_FORMATS))}."
        raise typer.BadParameter(msg)

    debug_log(debug, f"Output directory: {output}")
    debug_log(debug, f"Output format: {output_format}")

    read_start = time.perf_counter()
    df = pd.read_csv(trace_log, sep="\t", comment="#")
    debug_log(debug, f"Loaded trace log with shape {df.shape} in {time.perf_counter() - read_start:.2f}s")

    if "state" not in df.columns:
        msg = "Trace log must contain a 'state' column."
        raise typer.BadParameter(msg)

    split_start = time.perf_counter()
    burnin_df, posterior_df = split_trace(df, burnin)
    debug_log(
        debug,
        (
            f"Split trace in {time.perf_counter() - split_start:.2f}s: "
            f"burn-in rows={len(burnin_df)}, posterior rows={len(posterior_df)}"
        ),
    )

    if posterior_df.empty:
        msg = "Posterior trace is empty after burn-in removal. Reduce burn-in or provide a longer chain."
        raise typer.BadParameter(msg)

    variables = [column for column in df.columns if column != "state"]
    debug_log(debug, f"Preparing plots for {len(variables)} variables")

    for variable in variables:
        output_path = output / f"{variable}.{output_format}"
        plot_single_variable(
            output_path=output_path,
            variable=variable,
            burnin_df=burnin_df,
            posterior_df=posterior_df,
            max_points=max_points,
            output_format=output_format,
            dpi=dpi,
            debug=debug,
        )

    debug_log(debug, f"Finished all plots in {time.perf_counter() - start_time:.2f}s")

split_trace(df, burnin)

Split a trace dataframe into burn-in and posterior segments.

Source code in src/episodic/workflow/scripts/plot_traces.py
def split_trace(df: pd.DataFrame, burnin: float) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Split a trace dataframe into burn-in and posterior segments."""
    burnin_rows = int(len(df) * burnin)
    burnin_rows = min(max(burnin_rows, 0), len(df))
    return df.iloc[:burnin_rows], df.iloc[burnin_rows:]
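
Example: the split is purely positional, so with 1,000 rows and burnin=0.1 the first int(1000 * 0.1) = 100 rows land in the burn-in frame.

import pandas as pd

from episodic.workflow.scripts.plot_traces import split_trace

df = pd.DataFrame({"state": range(1000), "likelihood": [0.0] * 1000})
burnin_df, posterior_df = split_trace(df, burnin=0.1)
assert len(burnin_df) == 100
assert len(posterior_df) == 900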

style_axis(ax)

Apply a publication-style theme to an axis.

Source code in src/episodic/workflow/scripts/plot_traces.py
def style_axis(ax: plt.Axes) -> None:
    """Apply a publication-style theme to an axis."""
    ax.set_facecolor("white")
    ax.grid(True, color=GRID_COLOR, linewidth=0.8)
    ax.set_axisbelow(True)

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["left"].set_linewidth(1.0)
    ax.spines["bottom"].set_linewidth(1.0)
    ax.spines["left"].set_color(TEXT_COLOR)
    ax.spines["bottom"].set_color(TEXT_COLOR)

    ax.tick_params(direction="out", colors=TEXT_COLOR, labelsize=10)
    ax.xaxis.set_major_locator(MaxNLocator(nbins=5))
    ax.yaxis.set_major_locator(MaxNLocator(nbins=5))
    formatter = ScalarFormatter(useMathText=True)
    formatter.set_powerlimits((-3, 4))
    ax.yaxis.set_major_formatter(formatter)
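
Example: the theme is not tied to the trace layout; a minimal sketch applying it to a standalone axis (the output file name is hypothetical).

import matplotlib.pyplot as plt

from episodic.workflow.scripts.plot_traces import style_axis

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot([0, 1, 2], [1.0e-4, 3.0e-4, 2.5e-4])
style_axis(ax)  # light grid, hidden top/right spines, scientific y-axis labels
fig.savefig("styled_axis.png", dpi=200, bbox_inches="tight")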

get_schema(path)

Gets the schema for a given file.

Parameters:

Name Type Description Default
path Path

The path to the file.

required

Returns:

Name Type Description
str

The schema for the file.

Raises:

Type Description
ValueError

If the schema cannot be determined from the file extension.

Examples:

>>> get_schema(Path('file.nexus'))
'nexus'
Source code in src/episodic/workflow/scripts/tree_converter.py
def get_schema(path: Path):
    """
    Gets the schema for a given file.

    Args:
      path (Path): The path to the file.

    Returns:
      str: The schema for the file.

    Raises:
      ValueError: If the schema cannot be determined from the file extension.

    Examples:
      >>> get_schema(Path('file.nexus'))
      'nexus'
    """
    if path.suffix in [".nxs", ".nexus", ".treefile"]:
        return "nexus"
    if path.suffix in [".newick", ".nwk"]:
        return "newick"
    raise Exception(f"Cannot get schema for file {path}")

tree_converter(input=typer.Argument(..., help='The path to the input tree file.'), output=typer.Argument(..., help='The path to write the converted tree file.'), input_schema=typer.Option('', help='The input file schema. If empty, the schema is inferred from the file extension.'), output_schema=typer.Option('', help='The output file schema. If empty, the schema is inferred from the file extension.'), node_label=typer.Option('', help='Label the nodes from an annotation if present.'))

Converts a tree file from one format to another.

Parameters:

Name Type Description Default
input Path

The path to the input tree file.

typer.Argument(..., help='The path to the input tree file.')
output Path

The path to write the converted tree file.

typer.Argument(..., help='The path to write the converted tree file.')
input_schema str

The input file schema. If empty, the schema is inferred from the file extension.

typer.Option('', help='The input file schema. If empty, the schema is inferred from the file extension.')
output_schema str

The output file schema. If empty, the schema is inferred from the file extension.

typer.Option('', help='The output file schema. If empty, the schema is inferred from the file extension.')
node_label str

Label the nodes from an annotation if present.

typer.Option('', help='Label the nodes from an annotation if present.')

Returns:

Type Description

None

Examples:

>>> tree_converter(Path('input.nwk'), Path('output.nexus'), input_schema='newick', output_schema='nexus', node_label='label')
Source code in src/episodic/workflow/scripts/tree_converter.py
def tree_converter(
    input: Path = typer.Argument(..., help="The path to the input tree file."),
    output: Path = typer.Argument(..., help="The path to write the converted tree file."),
    input_schema: str = typer.Option(
        "", help="The input file schema. If empty, the schema is inferred from the file extension."
    ),
    output_schema: str = typer.Option(
        "", help="The output file schema. If empty, the schema is inferred from the file extension."
    ),
    node_label: str = typer.Option("", help="Label the nodes from an annotation if present."),
):
    """
    Converts a tree file from one format to another.

    Args:
      input (Path): The path to the input tree file.
      output (Path): The path to write the converted tree file.
      input_schema (str): The input file schema. If empty, the schema is inferred from the file extension.
      output_schema (str): The output file schema. If empty, the schema is inferred from the file extension.
      node_label (str): Label the nodes from an annotation if present.

    Returns:
      None

    Examples:
      >>> tree_converter(Path('input.nwk'), Path('output.nexus'), input_schema='newick', output_schema='nexus', node_label='label')
    """
    if not input_schema:
        input_schema = get_schema(input)
    if not output_schema:
        output_schema = get_schema(output)

    input_str = input.read_text()

    # Replace quote char because dendropy nexus tokenizer only uses single quotes by default
    input_str = input_str.replace('"', "'")

    tree = dendropy.Tree.get(data=input_str, schema=input_schema)

    if node_label:
        for node in tree:
            value = node.annotations.get_value(node_label)
            try:
                # Format numeric annotations to two significant figures
                node.label = "%.2g" % float(value)
            except (TypeError, ValueError):
                # Fall back to the raw annotation value (string or None)
                node.label = value

    tree.write(path=output, schema=output_schema, suppress_rooting=True)
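
Example (direct Python call): when bypassing Typer, pass the string options explicitly; an empty string triggers schema inference from the file extension. The three-taxon tree is hypothetical.

from pathlib import Path

from episodic.workflow.scripts.tree_converter import tree_converter

Path("example.nwk").write_text("((A:1.0,B:1.0):0.5,C:1.5);")

tree_converter(
    Path("example.nwk"),
    Path("example.nexus"),
    input_schema="",   # infer 'newick' from .nwk
    output_schema="",  # infer 'nexus' from .nexus
    node_label="",     # leave node labels untouched
)
print(Path("example.nexus").read_text())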

Log dataclass

Dataclass representing a log file.

Attributes:

Name Type Description
log_every int

The frequency at which to log.

file_name str

The name of the log file.

Source code in src/episodic/workflow/scripts/populate_beast_template.py
@dataclass
class Log:
    """
    Dataclass representing a log file.

    Attributes:
      log_every (int): The frequency at which to log.
      file_name (str): The name of the log file.
    """
    log_every: int
    file_name: str

MLE dataclass

Bases: Log

Dataclass representing a marginal likelihood estimator.

Attributes:

Name Type Description
log_every int

The frequency at which to log.

file_name str

The name of the log file.

results_file_name str

The name of the results file.

chain_length int

The length of the MCMC chain.

path_steps int

The number of path steps for the MLE.

Source code in src/episodic/workflow/scripts/populate_beast_template.py
@dataclass
class MLE(Log):
    """
    Dataclass representing a marginal likelihood estimator.

    Attributes:
      log_every (int): The frequency at which to log.
      file_name (str): The name of the log file.
      results_file_name (str): The name of the results file.
      chain_length (int): The length of the MCMC chain.
      path_steps (int): The number of path steps for the MLE.
    """
    results_file_name: str
    chain_length: int
    path_steps: int

Partition dataclass

Dataclass representing an alignment partition.

Attributes:

Name Type Description
prefix str

XML-safe prefix used for non-rate IDs.

background_prefix str

XML-safe prefix used for background-rate IDs.

foreground_prefix str

XML-safe prefix used for foreground/local-rate IDs.

taxa List[Taxon]

Partition sequences keyed by shared taxon IDs.

Source code in src/episodic/workflow/scripts/populate_beast_template.py
@dataclass
class Partition:
    """
    Dataclass representing an alignment partition.

    Attributes:
      prefix (str): XML-safe prefix used for non-rate IDs.
      background_prefix (str): XML-safe prefix used for background-rate IDs.
      foreground_prefix (str): XML-safe prefix used for foreground/local-rate IDs.
      taxa (List[Taxon]): Partition sequences keyed by shared taxon IDs.
    """

    prefix: str
    background_prefix: str
    foreground_prefix: str
    taxa: List[Taxon]

Taxon dataclass

Dataclass representing a taxon.

Attributes:

Name Type Description
id str

The id of the taxon.

sequence str

The sequence of the taxon.

date float

The date of the taxon.

uncertainty float

The uncertainty of the taxon's date.

Source code in src/episodic/workflow/scripts/populate_beast_template.py
@dataclass
class Taxon:
    """
    Dataclass representing a taxon.

    Attributes:
      id (str): The id of the taxon.
      sequence (str): The sequence of the taxon.
      date (float): The date of the taxon.
      uncertainty (float): The uncertainty of the taxon's date.
    """
    id: str
    sequence: str
    date: float
    uncertainty: float = 0.0
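
Example: constructing the dataclasses directly; uncertainty defaults to zero for exactly dated taxa, and MLE simply extends Log with the estimator-specific settings. All values here are illustrative.

from episodic.workflow.scripts.populate_beast_template import MLE, Log, Taxon

taxon = Taxon(id="virusA|2020.5", sequence="ACGT", date=2020.5)
assert taxon.uncertainty == 0.0

trace_log = Log(log_every=10_000, file_name="demo.log")
mle_log = MLE(
    log_every=10_000,
    file_name="demo.mle.log",
    results_file_name="demo.mle.results.log",
    chain_length=1_000_000,
    path_steps=100,
)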

build_partition_prefix(alignment_path, used_prefixes)

Create a unique XML-safe prefix for a partition.

Source code in src/episodic/workflow/scripts/populate_beast_template.py
def build_partition_prefix(alignment_path: Path, used_prefixes: set) -> str:
    """Create a unique XML-safe prefix for a partition."""
    prefix = re.sub(r"[^0-9A-Za-z_.-]+", "_", alignment_path.stem)
    if not prefix:
        prefix = "partition"
    if not prefix[0].isalpha():
        prefix = f"partition_{prefix}"

    candidate = prefix
    index = 2
    while candidate in used_prefixes:
        candidate = f"{prefix}_{index}"
        index += 1
    used_prefixes.add(candidate)
    return candidate
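
Example: stems are sanitised to XML-safe characters, deduplicated with a numeric suffix, and prefixed when they do not start with a letter. The file paths are hypothetical and need not exist.

from pathlib import Path

from episodic.workflow.scripts.populate_beast_template import build_partition_prefix

used = set()
build_partition_prefix(Path("run1/HA genes.fasta"), used)  # 'HA_genes'
build_partition_prefix(Path("run2/HA genes.fasta"), used)  # 'HA_genes_2' (collision)
build_partition_prefix(Path("2021-samples.fasta"), used)   # 'partition_2021-samples'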

build_partitions(alignment_paths, date_delimiter, date_index, foreground_label=None, background_label=None)

Parse and validate multiple FASTA alignments into BEAST partitions.

Source code in src/episodic/workflow/scripts/populate_beast_template.py
def build_partitions(
    alignment_paths: List[Path],
    date_delimiter: str,
    date_index: int,
    foreground_label: Optional[str] = None,
    background_label: Optional[str] = None,
) -> List[Partition]:
    """Parse and validate multiple FASTA alignments into BEAST partitions."""
    if not alignment_paths:
        msg = "At least one alignment partition must be provided."
        raise ValueError(msg)

    used_prefixes = set()
    partitions: List[Partition] = []
    reference_taxa: Optional[List[Taxon]] = None
    reference_by_id = {}

    multiple_partitions = len(alignment_paths) > 1

    for alignment_path in alignment_paths:
        taxa = taxa_from_fasta(alignment_path, date_delimiter=date_delimiter, date_index=date_index)
        if not taxa:
            msg = f"Alignment partition '{alignment_path}' does not contain any taxa."
            raise ValueError(msg)

        if reference_taxa is None:
            reference_taxa = taxa
            reference_by_id = {taxon.id: taxon for taxon in reference_taxa}
            ordered_taxa = taxa
        else:
            partition_by_id = {taxon.id: taxon for taxon in taxa}
            if set(partition_by_id) != set(reference_by_id):
                msg = (
                    "All alignment partitions must contain the same set of taxon headers. "
                    f"Partition '{alignment_path}' does not match the first alignment."
                )
                raise ValueError(msg)

            ordered_taxa = []
            for reference_taxon in reference_taxa:
                taxon = partition_by_id[reference_taxon.id]
                if taxon.date != reference_taxon.date or taxon.uncertainty != reference_taxon.uncertainty:
                    msg = (
                        "All alignment partitions must encode identical sampling dates for each taxon. "
                        f"Mismatch found for taxon '{reference_taxon.id}' in '{alignment_path}'."
                    )
                    raise ValueError(msg)
                ordered_taxa.append(taxon)

        prefix = build_partition_prefix(alignment_path, used_prefixes)
        partitions.append(
            Partition(
                prefix=prefix,
                background_prefix=labeled_rate_prefix(
                    background_label,
                    prefix,
                    multiple_partitions,
                    default_label="background",
                ),
                foreground_prefix=labeled_rate_prefix(foreground_label, prefix, multiple_partitions),
                taxa=ordered_taxa,
            )
        )

    return partitions
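
Example: a minimal sketch with two hypothetical single-gene partitions. Both files must share the same headers (and therefore the same sampling dates); taxa in every partition are returned in the order of the first alignment.

from pathlib import Path

from episodic.workflow.scripts.populate_beast_template import build_partitions

Path("gene1.fasta").write_text(">A|2020.5\nACGT\n>B|2021.0\nACGA\n")
Path("gene2.fasta").write_text(">A|2020.5\nTTGT\n>B|2021.0\nTTGA\n")

partitions = build_partitions(
    [Path("gene1.fasta"), Path("gene2.fasta")],
    date_delimiter="|",
    date_index=-1,
)
assert [p.prefix for p in partitions] == ["gene1", "gene2"]
assert [t.id for t in partitions[0].taxa] == [t.id for t in partitions[1].taxa]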

labeled_rate_prefix(label, partition_prefix, multiple_partitions, default_label=None)

Build the rate-parameter prefix for a partition and optional state label.

Source code in src/episodic/workflow/scripts/populate_beast_template.py
def labeled_rate_prefix(
    label: Optional[str],
    partition_prefix: str,
    multiple_partitions: bool,
    default_label: Optional[str] = None,
) -> str:
    """Build the rate-parameter prefix for a partition and optional state label."""
    safe_label = xml_safe_label(label) or xml_safe_label(default_label)
    if safe_label is None:
        return partition_prefix
    if multiple_partitions:
        return f"{partition_prefix}.{safe_label}"
    return safe_label
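
Example: how the label, partition prefix, and partition count combine (all names hypothetical).

from episodic.workflow.scripts.populate_beast_template import labeled_rate_prefix

labeled_rate_prefix("delta variant", "HA", multiple_partitions=False)
# -> 'delta_variant' (single partition: the sanitised label stands alone)
labeled_rate_prefix("delta variant", "HA", multiple_partitions=True)
# -> 'HA.delta_variant' (label namespaced under the partition prefix)
labeled_rate_prefix(None, "HA", multiple_partitions=True)
# -> 'HA' (no label, no default: fall back to the partition prefix)
labeled_rate_prefix(None, "HA", multiple_partitions=False, default_label="background")
# -> 'background' (the default label, as used for background rates)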

populate_beast_template(work_dir, name, template_path, alignment_paths, clock, groups=None, groups_file=None, rate_gamma_prior_shape=0.5, rate_gamma_prior_scale=0.1, chain_length=100000000, samples=10000, mle_chain_length=1000000, mle_path_steps=100, mle_log_every=10000, date_delimiter='|', date_index=-1, fixed_tree=None, foreground_label=None, background_label=None, *, trace=True, trees=True, mle=True)

Populates a Beast XML template with one or more alignment partitions.

Parameters:

Name Type Description Default
work_dir Path

The path to the working directory.

required
name str

The name of the output file.

required
template_path Path

The path to the input Beast template file.

required
alignment_paths List[Path]

The paths to the input alignment partitions.

required
groups list

A list of groups to include in the analysis.

None
groups_file Path

Optional TSV mapping taxa to groups.

None
clock str

The clock model to use in the analysis.

required
rate_gamma_prior_shape float

The shape parameter of the gamma prior on the rate.

0.5
rate_gamma_prior_scale float

The scale parameter of the gamma prior on the rate.

0.1
chain_length int

The length of the MCMC chain.

100000000
samples int

The number of samples to draw from the MCMC chain.

10000
mle_chain_length int

The length of the MCMC chain for the marginal likelihood estimator.

1000000
mle_path_steps int

The number of path steps for the marginal likelihood estimator.

100
mle_log_every int

How often the marginal likelihood estimator writes to its log file.

10000
date_delimiter str

The delimiter for the date in the fasta header.

'|'
date_index int

The index of the date in the fasta header.

-1
fixed_tree Path

The path to the fixed tree file.

None
foreground_label str

Optional label prefix for foreground/local-rate parameters.

None
background_label str

Optional label prefix for background clock-rate parameters.

None

Other Parameters:

Name Type Description
trace bool

Whether to enable the trace log.

trees bool

Whether to enable the trees log.

mle bool

Whether to run the marginal likelihood estimator.

Returns:

Name Type Description
str

The rendered Beast XML template.

Examples:

>>> populate_beast_template(
...     work_dir=Path("output"),
...     name="my_analysis",
...     template_path=Path("template.xml"),
...     alignment_paths=[Path("alignment.fasta")],
...     groups=["group1", "group2"],
...     clock="strict",
...     rate_gamma_prior_shape=0.5,
...     rate_gamma_prior_scale=0.1,
...     chain_length=100000000,
...     samples=10000,
...     mle_chain_length=1000000,
...     mle_path_steps=100,
...     mle_log_every=10000,
...     date_delimiter="|",
...     date_index=-1,
...     fixed_tree=Path("fixed_tree.nwk"),
...     trace=True,
...     trees=True,
...     mle=True,
... )
<Rendered Beast XML template>

Source code in src/episodic/workflow/scripts/populate_beast_template.py
def populate_beast_template(
    work_dir: Path,
    name: str,
    template_path: Path,
    alignment_paths: List[Path],
    clock: str,
    groups: Optional[List[str]] = None,
    groups_file: Optional[Path] = None,
    rate_gamma_prior_shape: float = 0.5,
    rate_gamma_prior_scale: float = 0.1,
    chain_length: int = 100000000,
    samples: int = 10000,
    mle_chain_length: int = 1000000,
    mle_path_steps: int = 100,
    mle_log_every: int = 10000,
    date_delimiter="|",
    date_index=-1,
    fixed_tree: Optional[Path] = None,
    foreground_label: Optional[str] = None,
    background_label: Optional[str] = None,
    *,
    trace: bool = True,
    trees: bool = True,
    mle: bool = True,
):
    """
    Populates a Beast XML template with one or more alignment partitions.

    Args:
      work_dir (Path): The path to the working directory.
      name (str): The name of the output file.
      template_path (Path): The path to the input Beast template file.
      alignment_paths (List[Path]): The paths to the input alignment partitions.
      groups (list): A list of groups to include in the analysis.
      groups_file (Path): Optional TSV mapping taxa to groups.
      clock (str): The clock model to use in the analysis.
      rate_gamma_prior_shape (float): The shape parameter of the gamma prior on the rate.
      rate_gamma_prior_scale (float): The scale parameter of the gamma prior on the rate.
      chain_length (int): The length of the MCMC chain.
      samples (int): The number of samples to draw from the MCMC chain.
      mle_chain_length (int): The length of the MCMC chain for the marginal likelihood estimator.
      mle_path_steps (int): The number of path steps for the marginal likelihood estimator.
      mle_log_every (int): How often the marginal likelihood estimator writes to its log file.
      date_delimiter (str): The delimiter for the date in the fasta header.
      date_index (int): The index of the date in the fasta header.
      fixed_tree (Path): The path to the fixed tree file.
      foreground_label (str): Optional label prefix for foreground/local-rate parameters.
      background_label (str): Optional label prefix for background clock-rate parameters.

    Keyword Args:
      trace (bool): Whether to enable the trace log.
      trees (bool): Whether to enable the trees log.
      mle (bool): Whether to run the marginal likelihood estimator.

    Returns:
      str: The rendered Beast XML template.

    Examples:
      >>> populate_beast_template(
      ...     work_dir=Path("output"),
      ...     name="my_analysis",
      ...     template_path=Path("template.xml"),
      ...     alignment_paths=[Path("alignment.fasta")],
      ...     groups=["group1", "group2"],
      ...     clock="strict",
      ...     rate_gamma_prior_shape=0.5,
      ...     rate_gamma_prior_scale=0.1,
      ...     chain_length=100000000,
      ...     samples=10000,
      ...     mle_chain_length=1000000,
      ...     mle_path_steps=100,
      ...     mle_log_every=10000,
      ...     date_delimiter="|",
      ...     date_index=-1,
      ...     fixed_tree=Path("fixed_tree.nwk"),
      ...     trace=True,
      ...     trees=True,
      ...     mle=True,
      ... )
      <Rendered Beast XML template>
    """
    # Load the template
    template = Template(template_path.read_text(), undefined=StrictUndefined)

    # Parse alignment partitions into Taxon objects
    partitions = build_partitions(
        alignment_paths,
        date_delimiter=date_delimiter,
        date_index=date_index,
        foreground_label=foreground_label,
        background_label=background_label,
    )
    taxa = partitions[0].taxa
    taxon_ids = [taxon.id for taxon in taxa]

    if groups_file is not None:
        group_members = read_group_members(groups_file)
        groups = list(group_members)
    elif groups is not None:
        group_members = build_group_members(taxon_ids, groups)
    else:
        msg = "Either groups or groups_file must be provided."
        raise ValueError(msg)

    missing_taxa = {
        taxon_id
        for members in group_members.values()
        for taxon_id in members
        if taxon_id not in taxon_ids
    }
    if missing_taxa:
        missing_taxa_str = ", ".join(sorted(missing_taxa))
        msg = f"Group mapping references taxa not present in the alignment: {missing_taxa_str}"
        raise ValueError(msg)

    if fixed_tree is not None:
        fixed_tree = fixed_tree.read_text()

    log_every = max(1, chain_length // samples)

    trace_log = None
    if trace:
        trace_log = Log(
            log_every=log_every,
            file_name=f"{name}.log",
        )

    tree_log = None
    if trees:
        tree_log = Log(
            log_every=log_every,
            file_name=f"{name}.trees",
        )

    mle_log = None
    if mle:
        mle_log = MLE(
            log_every=mle_log_every,
            file_name=f"{name}.mle.log",
            results_file_name=work_dir / f"{name}.mle.results.log",
            chain_length=mle_chain_length,
            path_steps=mle_path_steps,
        )

    # Render the template
    rendered_template = template.render(
        taxa=taxa,
        partitions=partitions,
        groups=groups,
        groupMembers=group_members,
        clock=clock,
        fixedTree=fixed_tree,
        rateGammaPriorShape=rate_gamma_prior_shape,
        rateGammaPriorScale=rate_gamma_prior_scale,
        chainLength=chain_length,
        screenLogEvery=log_every,
        traceLog=trace_log,
        treeLog=tree_log,
        marginalLikelihoodEstimator=mle_log,
    )

    # Return the rendered template; the caller is responsible for writing it to disk
    return rendered_template
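
Example: a minimal sketch of the rendering step in isolation. populate_beast_template hands the names shown in template.render() above to a Jinja2 template; the one-line template here is hypothetical and only touches two of them.

from jinja2 import StrictUndefined, Template

template = Template(
    "<mcmc chainLength='{{ chainLength }}' logEvery='{{ screenLogEvery }}'/>",
    undefined=StrictUndefined,
)
chain_length, samples = 100_000_000, 10_000
log_every = max(1, chain_length // samples)  # 10,000 states between samples
print(template.render(chainLength=chain_length, screenLogEvery=log_every))
# <mcmc chainLength='100000000' logEvery='10000'/>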

taxa_from_fasta(fasta_path, date_delimiter='|', date_index=-1)

Parses a fasta file into a list of Taxon objects.

Parameters:

Name Type Description Default
fasta_path Path

The path to the fasta file.

required
date_delimiter str

The delimiter for the date in the fasta header.

'|'
date_index int

The index of the date in the fasta header.

-1

Returns:

Type Description
List[Taxon]

List[Taxon]: A list of Taxon objects representing the taxa in the fasta file.

Raises:

Type Description
ValueError

If the fasta file is invalid.

Source code in src/episodic/workflow/scripts/populate_beast_template.py
def taxa_from_fasta(fasta_path, date_delimiter="|", date_index=-1) -> List[Taxon]:
    """
    Parses a fasta file into a list of Taxon objects.

    Args:
      fasta_path (Path): The path to the fasta file.
      date_delimiter (str): The delimiter for the date in the fasta header.
      date_index (int): The index of the date in the fasta header.

    Returns:
      List[Taxon]: A list of Taxon objects representing the taxa in the fasta file.

    Raises:
      ValueError: If the fasta file is invalid.
    """
    # Read the fasta file
    with open(fasta_path) as fasta_file:
        fasta_lines = fasta_file.readlines()

    # check if valid fasta
    if not fasta_lines:
        msg = "Invalid fasta file."
        raise ValueError(msg)
    if not fasta_lines[0].startswith(">"):
        msg = "Invalid fasta file."
        raise ValueError(msg)
    # Parse the fasta file into Taxon objects. Support multi-line sequences.
    taxa = []
    for line in fasta_lines:
        if line.startswith(">"):
            header = line[1:].strip()
            # A trailing '/N' encodes date uncertainty: '1992/1' means 1992 with a one-year range (1992 to 1993)
            date_with_uncertainty = header.split(date_delimiter)[date_index]
            date, *uncertainty = date_with_uncertainty.split("/")
            if uncertainty:
                uncertainty = float(uncertainty[0])
            else:
                uncertainty = 0.0
            try:
                date = float(date)
            except ValueError:
                date = date_to_decimal_year(date)
            taxa.append(Taxon(id=header, sequence="", date=float(date), uncertainty=uncertainty))
        else:
            taxa[-1].sequence += line.strip()

    return taxa
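
Example: the last '|' field of each header carries the date, either as a decimal year (optionally with '/N' uncertainty) or as a '%Y-%m-%d' calendar date. The alignment below is hypothetical.

from pathlib import Path

from episodic.workflow.scripts.populate_beast_template import taxa_from_fasta

Path("demo.fasta").write_text(
    ">virusA|1992/1\nACGT\nACGA\n"    # 1992 with a one-year range (1992 to 1993)
    ">virusB|2020-07-02\nTTGTTTGA\n"  # calendar date -> decimal year 2020.5
)

taxa = taxa_from_fasta(Path("demo.fasta"))
assert taxa[0].date == 1992.0 and taxa[0].uncertainty == 1.0
assert abs(taxa[1].date - 2020.5) < 1e-6
assert taxa[0].sequence == "ACGTACGA"  # multi-line sequences are concatenated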

xml_safe_label(label)

Return a BEAST/XML-safe label prefix, or None when no label is set.

Source code in src/episodic/workflow/scripts/populate_beast_template.py
def xml_safe_label(label: Optional[str]) -> Optional[str]:
    """Return a BEAST/XML-safe label prefix, or None when no label is set."""
    if label is None:
        return None

    safe_label = re.sub(r"[^0-9A-Za-z_.-]+", "_", label.strip())
    safe_label = safe_label.strip("._-")
    if not safe_label:
        return None
    if not safe_label[0].isalpha():
        safe_label = f"label_{safe_label}"
    return safe_label
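
Example: junk characters collapse to underscores, leading/trailing separators are trimmed, and labels that cannot be salvaged come back as None.

from episodic.workflow.scripts.populate_beast_template import xml_safe_label

xml_safe_label("delta variant!")  # -> 'delta_variant'
xml_safe_label("2021")            # -> 'label_2021' (must start with a letter)
xml_safe_label("  .-_  ")         # -> None (nothing usable remains)
xml_safe_label(None)              # -> None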