Source code for counting.count_alleles_sc

"""Single-cell allele counting functions."""

from __future__ import annotations

import logging
import timeit
from collections import defaultdict
from collections.abc import Iterator

import anndata as ad
import numpy as np
import pandas as pd
import polars as pl
from pysam.libcalignmentfile import AlignmentFile
from scipy.sparse import csr_matrix

# Local imports
from .count_alleles import find_read_aln_pos

logger = logging.getLogger(__name__)


class CountStatsSC:
    """Container for mutable single-cell counting statistics.

    Tracks allele counts and metadata per chromosome during counting.

    Attributes
    ----------
    ref_count, alt_count, other_count : defaultdict[tuple[int, int], int]
        Read tallies keyed by a pair of integer indices. The counting code in
        this module uses ``(SNP row index, barcode index)`` as the key.
    num_snps, num_barcodes, reads_counted : defaultdict[str, int]
        Per-chromosome totals, keyed by chromosome name.
    reads_skipped_* : defaultdict[str, int]
        Per-chromosome tallies of reads that were not counted, one counter
        per skip reason.
    """

    # Names of the per-chromosome counters exported by ``stats_to_df``,
    # in output column order.
    _STAT_ATTRIBUTES = (
        "num_snps",
        "num_barcodes",
        "reads_counted",
        "reads_skipped_no_barcode",
        "reads_skipped_barcode_no_index",
        "reads_skipped_prev_counted",
        "reads_skipped_no_sequence",
        "reads_skipped_no_aln_pos",
        "reads_skipped_seq_error",
    )

    def __init__(self) -> None:
        """Create an empty set of counters; every counter defaults to 0."""
        # Allele tallies, keyed by a pair of integer indices.
        self.ref_count: defaultdict[tuple[int, int], int] = defaultdict(int)
        self.alt_count: defaultdict[tuple[int, int], int] = defaultdict(int)
        self.other_count: defaultdict[tuple[int, int], int] = defaultdict(int)

        # Per-chromosome metadata.
        self.num_snps: defaultdict[str, int] = defaultdict(int)
        self.num_barcodes: defaultdict[str, int] = defaultdict(int)
        self.reads_counted: defaultdict[str, int] = defaultdict(int)

        # Per-chromosome tallies of reads that were not counted.
        self.reads_skipped_no_barcode: defaultdict[str, int] = defaultdict(int)
        self.reads_skipped_barcode_no_index: defaultdict[str, int] = defaultdict(int)
        self.reads_skipped_prev_counted: defaultdict[str, int] = defaultdict(int)
        self.reads_skipped_no_sequence: defaultdict[str, int] = defaultdict(int)
        self.reads_skipped_no_aln_pos: defaultdict[str, int] = defaultdict(int)
        self.reads_skipped_seq_error: defaultdict[str, int] = defaultdict(int)

    def stats_to_df(self) -> pd.DataFrame:
        """Convert statistics to a pandas DataFrame.

        Returns
        -------
        pd.DataFrame
            One row per chromosome seen by any counter: a ``chrom`` column
            followed by one ``int64`` column per statistic. A statistic that
            was never recorded for a chromosome is reported as ``0``.
        """
        stat_df = pd.DataFrame({key: getattr(self, key) for key in self._STAT_ATTRIBUTES})

        # Each counter only holds the chromosomes it was incremented for, and
        # the frame is indexed by the union of all keys, so gaps arrive as NaN
        # and turn the column into floats. A missing tally means zero.
        stat_df = stat_df.fillna(0).astype("int64")

        return stat_df.reset_index(names="chrom")
def _counts_to_csr(
    counts: dict[tuple[int, int], int],
    shape: tuple[int, int],
) -> csr_matrix:
    """Build a ``uint8`` CSR matrix from a ``{(row, col): count}`` mapping."""
    if not counts:
        # zip(*keys) yields nothing for an empty mapping, and csr_matrix
        # rejects a (data, ij) tuple that has no row/column arrays.
        return csr_matrix(shape, dtype=np.uint8)

    rows, cols = zip(*counts.keys())
    # NOTE: values are stored as uint8, so a single (row, col) tally above 255
    # cannot be represented.
    return csr_matrix((list(counts.values()), (rows, cols)), shape=shape, dtype=np.uint8)


def make_count_matrix(
    bam_file: str,
    df: pl.DataFrame,
    bc_dict: dict[str, int],
    include_samples: list[str] | None = None,
    include_features: list[str] | None = None,
) -> ad.AnnData:
    """Create sparse count matrix from BAM and variant data.

    Parameters
    ----------
    bam_file : str
        Path to BAM file with cell barcodes.
    df : pl.DataFrame
        DataFrame with variant positions from intersection. Must contain the
        columns ``chrom``, ``pos``, ``ref`` and ``alt``; a ``region`` column
        is used when present.
    bc_dict : dict[str, int]
        Mapping of cell barcodes to integer indices. The indices are used as
        matrix column positions and the keys are assigned as ``var_names`` in
        dict iteration order.
        NOTE(review): this only lines up if the values are ``0..len-1`` in
        insertion order -- confirm against the caller.
    include_samples : list[str] | None, optional
        Sample columns to include from variant data, by default None.
    include_features : list[str] | None, optional
        Feature columns to include, by default None. Currently unused; only
        the ``region`` column is handled.

    Returns
    -------
    ad.AnnData
        AnnData object with one row (``obs``) per unique SNP and one column
        (``var``) per barcode. ``layers`` holds the ``ref``, ``alt`` and
        ``other`` count matrices and ``X`` is their sum. ``uns`` carries
        ``count_stats`` and, when applicable, ``samples`` and ``feature``.
    """
    chrom_list = df.get_column("chrom").unique(maintain_order=True)

    # Add genotype annotations.
    # Maybe do this automatically and parse feature col instead?
    snp_df_cols = ["chrom", "pos", "ref", "alt"]
    if include_samples is not None:
        snp_df_cols.extend(include_samples)

    # The row index added here ("index") is the matrix row of each SNP.
    snp_df = df.select(snp_df_cols).unique(maintain_order=True).with_row_index()

    sc_counts = CountStatsSC()  # Holds total count data across chromosomes

    with AlignmentFile(bam_file, "rb") as bam:
        for chrom in chrom_list:
            chrom_df = snp_df.filter(pl.col("chrom") == chrom)

            start = timeit.default_timer()

            try:
                count_bc_snp_alleles(
                    bam=bam,
                    bc_dict=bc_dict,
                    chrom=chrom,
                    snp_list=chrom_df.select(["index", "pos", "ref", "alt"]).iter_rows(),
                    sc_counts=sc_counts,
                )
            except ValueError:
                logger.warning("Skipping %s: Contig not found!", chrom)
            else:
                logger.info(
                    "%s: Counted %d SNPs in %.2f seconds",
                    chrom,
                    chrom_df.height,
                    timeit.default_timer() - start,
                )

    # Create sparse matrices (rows: SNPs, columns: barcodes). An allele class
    # that received no reads at all still yields an all-zero matrix.
    matrix_shape = (snp_df.shape[0], len(bc_dict))
    sparse_ref = _counts_to_csr(sc_counts.ref_count, matrix_shape)
    sparse_alt = _counts_to_csr(sc_counts.alt_count, matrix_shape)
    sparse_other = _counts_to_csr(sc_counts.other_count, matrix_shape)

    # Create anndata with the total as X.
    adata = ad.AnnData(
        X=sparse_ref + sparse_alt + sparse_other,
        layers={"ref": sparse_ref, "alt": sparse_alt, "other": sparse_other},
    )

    # Annotate adata: figure out what to add to adata here vs later.
    adata.obs = snp_df.to_pandas()

    # Per-SNP totals across all barcodes. Accumulate in uint32: a row of uint8
    # cells can sum to 255 * n_barcodes, which overflows uint16 quickly.
    adata.obs["ref_count"] = adata.layers["ref"].sum(axis=1, dtype=np.uint32).T.A1
    adata.obs["alt_count"] = adata.layers["alt"].sum(axis=1, dtype=np.uint32).T.A1

    # Add barcode names.
    adata.var_names = list(bc_dict)

    # Add genotypes to anndata.
    if include_samples is not None:
        adata.uns["samples"] = include_samples

    # TODO: Allow for other features besides 'region' using include_features.
    # Could be case of no features, or feature is gene.
    if "region" in df.columns:
        # Map every region to the row index of each SNP it contains.
        adata.uns["feature"] = (
            df.join(snp_df, on=["chrom", "pos", "ref", "alt"], how="left")
            .select(["region", "index"])
            .to_pandas()
        )

    # Write out count stats.
    adata.uns["count_stats"] = sc_counts.stats_to_df()

    return adata
def count_bc_snp_alleles(
    bam: AlignmentFile,
    bc_dict: dict[str, int],
    chrom: str,
    snp_list: Iterator[tuple[int, int, str, str]],
    sc_counts: CountStatsSC,
) -> None:
    """Tally ref/alt/other bases per (SNP, barcode) on one chromosome.

    For every SNP, reads overlapping its position are fetched and the base
    at that position is compared with the SNP's ref and alt alleles. A read
    name contributes at most once per call; every read that is not counted
    increments the matching ``reads_skipped_*`` counter.

    Parameters
    ----------
    bam : AlignmentFile
        Open BAM file handle.
    bc_dict : dict[str, int]
        Mapping of cell barcodes to indices.
    chrom : str
        Chromosome to process.
    snp_list : Iterator[tuple[int, int, str, str]]
        Iterator of (index, pos, ref, alt) tuples. ``pos`` is treated as
        1-based: reads are fetched over ``[pos - 1, pos)``.
    sc_counts : CountStatsSC
        Statistics container to update with counts (mutated in place).
    """
    counted_reads = set()  # Names of reads that already contributed a count
    seen_barcodes = set()  # Barcodes with at least one counted read

    for snp_idx, snp_pos, ref_base, alt_base in snp_list:
        zero_pos = snp_pos - 1

        for aln in bam.fetch(chrom, zero_pos, snp_pos):
            # A read (or its mate) is only allowed to contribute once.
            if aln.query_name in counted_reads:
                sc_counts.reads_skipped_prev_counted[chrom] += 1
                continue

            # Reads without a cell-barcode tag cannot be assigned to a cell.
            try:
                barcode = str(aln.get_tag("CB"))
            except KeyError:
                sc_counts.reads_skipped_no_barcode[chrom] += 1
                continue

            # Barcode is present but is not one of the requested cells.
            if barcode not in bc_dict:
                sc_counts.reads_skipped_barcode_no_index[chrom] += 1
                continue

            bases = aln.query_sequence
            if bases is None:
                sc_counts.reads_skipped_no_sequence[chrom] += 1
                continue

            # Index within the read sequence for the SNP's reference position.
            read_offset = find_read_aln_pos(aln, zero_pos)
            if read_offset is None:
                sc_counts.reads_skipped_no_aln_pos[chrom] += 1
                continue

            try:
                observed = bases[read_offset]
                if observed == ref_base:
                    tally = sc_counts.ref_count
                elif observed == alt_base:
                    tally = sc_counts.alt_count
                else:
                    tally = sc_counts.other_count
                tally[(snp_idx, bc_dict[barcode])] += 1
            except (TypeError, IndexError) as err:
                # Only sequence access problems are swallowed; log the
                # details so unexpected cases can be investigated.
                sc_counts.reads_skipped_seq_error[chrom] += 1
                logger.debug(
                    "Skipping read %s: sequence access error at %s:%d (qpos=%s): %s",
                    aln.query_name,
                    chrom,
                    snp_pos,
                    read_offset,
                    err,
                )
                continue

            counted_reads.add(aln.query_name)
            seen_barcodes.add(barcode)
            sc_counts.reads_counted[chrom] += 1

        # Incremented after the fetch loop so a failing fetch does not
        # register the SNP.
        sc_counts.num_snps[chrom] += 1

    # Number of unique barcodes that had a read counted on this chromosome.
    sc_counts.num_barcodes[chrom] = len(seen_barcodes)