from __future__ import annotations
import functools
import re
import tempfile
from collections.abc import Callable
from pathlib import Path
from typing import ParamSpec, TypeVar
from .count_alleles import make_count_df
# local imports
from .filter_variant_data import intersect_vcf_region, parse_intersect_region_new, vcf_to_bed
from .parse_gene_data import make_gene_data, parse_intersect_genes_new
# Typing helpers for tempdir_decorator below: P captures the wrapped
# function's parameter signature, T its return type.
P = ParamSpec("P")
T = TypeVar("T")
class WaspCountFiles:
"""Container for WASP counting pipeline file paths and configuration.
Manages input/output file paths and parsing logic for the variant counting pipeline.
Attributes:
bam_file: Path to the BAM alignment file.
variant_file: Path to the variant file (VCF, BCF, or PGEN).
region_file: Optional path to a region file (BED, GTF, or GFF3).
samples: List of sample IDs to process, or None for all samples.
use_region_names: Whether to use region names from the region file.
out_file: Output file path for count results.
temp_loc: Directory for temporary files.
is_gene_file: Whether the region file is a gene annotation file.
gtf_bed: Path to converted GTF/GFF3 BED file, if applicable.
variant_prefix: Prefix extracted from variant filename.
vcf_bed: Path to variant BED file.
skip_vcf_to_bed: Whether to skip VCF-to-BED conversion.
region_type: Type of regions ('regions' or 'genes').
intersect_file: Path to intersected variant-region file.
skip_intersect: Whether to skip intersection step.
"""
# Class attribute type hints
bam_file: str
variant_file: str
region_file: str | None
samples: list[str] | None
use_region_names: bool
out_file: str
temp_loc: str
is_gene_file: bool
gtf_bed: str | None
variant_prefix: str
vcf_bed: str
skip_vcf_to_bed: bool
region_type: str | None
intersect_file: str
skip_intersect: bool
[docs]
def __init__(
self,
bam_file: str,
variant_file: str,
region_file: str | None = None,
samples: str | list[str] | None = None,
use_region_names: bool = False,
out_file: str | None = None,
temp_loc: str | None = None,
precomputed_vcf_bed: str | None = None,
precomputed_intersect: str | None = None,
) -> None:
# User input files
self.bam_file = bam_file
self.variant_file = variant_file
self.region_file = region_file
self.use_region_names = use_region_names
# gtf and gff specific
self.is_gene_file = False # check if using gff3/gtf
self.gtf_bed = None
# Make sure samples turned into str list
if isinstance(samples, str):
# Check if sample file or comma delim string
if Path(samples).is_file():
with open(samples) as sample_file:
self.samples = [l.strip() for l in sample_file]
else:
self.samples = [s.strip() for s in samples.split(",")]
else:
self.samples = samples
# parse output?
self.out_file: str = out_file if out_file is not None else str(Path.cwd() / "counts.tsv")
# Failsafe if decorator doesnt create temp_loc
self.temp_loc: str = temp_loc if temp_loc is not None else str(Path.cwd())
# Parse variant file prefix (handle VCF, BCF, PGEN)
variant_name = Path(self.variant_file).name
if variant_name.endswith(".vcf.gz"):
variant_prefix = variant_name[:-7] # Remove .vcf.gz
elif variant_name.endswith(".pgen"):
variant_prefix = variant_name[:-5] # Remove .pgen
else:
variant_prefix = re.split(r"\.vcf|\.bcf", variant_name)[0]
self.variant_prefix = variant_prefix
# Filtered variant output (or precomputed)
self.vcf_bed = (
precomputed_vcf_bed
if precomputed_vcf_bed is not None
else str(Path(self.temp_loc) / f"{variant_prefix}.bed")
)
self.skip_vcf_to_bed = precomputed_vcf_bed is not None
# Parse region file
self.region_type = None # maybe use a boolean flag instead
if self.region_file is not None:
f_ext = "".join(Path(self.region_file).suffixes)
if re.search(r"\.(.*Peak|bed)(?:\.gz)?$", f_ext, re.I):
self.region_type = "regions"
self.intersect_file = (
precomputed_intersect
if precomputed_intersect is not None
else str(Path(self.temp_loc) / f"{variant_prefix}_intersect_regions.bed")
)
self.is_gene_file = False
elif re.search(r"\.g[tf]f(?:\.gz)?$", f_ext, re.I):
self.region_type = "genes"
self.intersect_file = (
precomputed_intersect
if precomputed_intersect is not None
else str(Path(self.temp_loc) / f"{variant_prefix}_intersect_genes.bed")
)
self.is_gene_file = True
gtf_prefix = re.split(r".g[tf]f", Path(self.region_file).name)[0]
self.gtf_bed = str(Path(self.temp_loc) / f"{gtf_prefix}.bed")
self.use_region_names = True # Use feature attributes as region names
elif re.search(r"\.gff3(?:\.gz)?$", f_ext, re.I):
self.region_type = "genes"
self.intersect_file = (
precomputed_intersect
if precomputed_intersect is not None
else str(Path(self.temp_loc) / f"{variant_prefix}_intersect_genes.bed")
)
self.is_gene_file = True
gtf_prefix = re.split(r".gff3", Path(self.region_file).name)[0]
self.gtf_bed = str(Path(self.temp_loc) / f"{gtf_prefix}.bed")
self.use_region_names = True # Use feature attributes as region names
else:
raise ValueError(
f"Invalid region file type. Expected .bed, .gtf, or .gff3, got: {self.region_file}"
)
else:
# No region file: intersect file defaults to vcf_bed (or provided precomputed)
self.intersect_file = (
precomputed_intersect if precomputed_intersect is not None else self.vcf_bed
)
self.skip_intersect = precomputed_intersect is not None
# TODO UPDATE THIS WHEN I ADD AUTOPARSERS
if self.is_gene_file:
# Possible edge case of vcf and gtf prefix conflict
if self.vcf_bed == self.gtf_bed:
self.gtf_bed = str(Path(self.temp_loc) / "genes.bed")
def tempdir_decorator(func: Callable[P, T]) -> Callable[P, T]:
"""Decorator that creates a temporary directory for the wrapped function.
If 'temp_loc' is not provided in kwargs, creates a temporary directory
and passes it to the function. The directory is cleaned up after execution.
Args:
func: The function to wrap.
Returns:
Wrapped function with automatic temporary directory management.
"""
@functools.wraps(func)
def tempdir_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
if kwargs.get("temp_loc") is not None:
return func(*args, **kwargs)
else:
with tempfile.TemporaryDirectory() as tmpdir:
kwargs["temp_loc"] = tmpdir
return func(*args, **kwargs)
return tempdir_wrapper
@tempdir_decorator
def run_count_variants(
    bam_file: str,
    variant_file: str,
    region_file: str | None = None,
    samples: str | list[str] | None = None,
    use_region_names: bool = False,
    out_file: str | None = None,
    temp_loc: str | None = None,
    gene_feature: str | None = None,
    gene_attribute: str | None = None,
    gene_parent: str | None = None,
    use_rust: bool = True,
    precomputed_vcf_bed: str | None = None,
    precomputed_intersect: str | None = None,
    include_indels: bool = False,
) -> None:
    """Run the WASP variant counting pipeline.

    Counts allele-specific reads at heterozygous variant positions within
    optional genomic regions.

    Args:
        bam_file: Path to the BAM alignment file.
        variant_file: Path to the variant file (VCF, BCF, or PGEN).
        region_file: Optional path to a region file (BED, GTF, or GFF3).
        samples: Sample ID(s) to process. Can be a single ID, comma-separated
            string, path to a file with one sample per line, or list of IDs.
        use_region_names: Whether to use region names from the region file.
        out_file: Output file path. Defaults to 'counts.tsv' in current directory.
        temp_loc: Directory for temporary files. Auto-created if not provided.
        gene_feature: GTF/GFF3 feature type to extract (e.g., 'gene', 'exon').
        gene_attribute: GTF/GFF3 attribute for region names (e.g., 'gene_name').
        gene_parent: GTF/GFF3 parent attribute for hierarchical features.
        use_rust: Whether to use the Rust backend for counting (faster).
        precomputed_vcf_bed: Path to pre-computed variant BED file (skips conversion).
        precomputed_intersect: Path to pre-computed intersection file.
        include_indels: Whether to include indels in variant counting.

    Returns:
        None. Results are written to out_file.
    """
    # Resolve every input/output path up front.
    files = WaspCountFiles(
        bam_file,
        variant_file,
        region_file=region_file,
        samples=samples,
        use_region_names=use_region_names,
        out_file=out_file,
        temp_loc=temp_loc,
        precomputed_vcf_bed=precomputed_vcf_bed,
        precomputed_intersect=precomputed_intersect,
    )
    # Genotype column is only carried for a single-sample run.
    single_sample = files.samples is not None and len(files.samples) == 1
    # Convert variants to BED unless a precomputed BED was supplied.
    if not files.skip_vcf_to_bed:
        vcf_to_bed(
            vcf_file=files.variant_file,
            out_bed=files.vcf_bed,
            samples=files.samples,
            include_gt=single_sample,
            include_indels=include_indels,
        )
    gene_data = None  # set only when intersecting against a gene annotation
    region_col = None  # None -> downstream default column name 'region'
    if files.region_file is not None:
        if files.gtf_bed is not None:
            # Gene annotation (GTF/GFF3): convert to BED first.
            gene_data = make_gene_data(
                gene_file=files.region_file,
                out_bed=files.gtf_bed,
                feature=gene_feature,
                attribute=gene_attribute,
                parent_attribute=gene_parent,
            )
            bed_for_intersect = files.gtf_bed
            region_col = gene_data.feature
        else:
            # Plain BED/Peak region file is intersected as-is.
            bed_for_intersect = files.region_file
        if not files.skip_intersect:
            intersect_vcf_region(
                vcf_file=files.vcf_bed,
                region_file=bed_for_intersect,
                out_file=files.intersect_file,
            )
    # Build the variant dataframe from the (possibly precomputed) intersect.
    if gene_data is not None:
        df = parse_intersect_genes_new(
            intersect_file=files.intersect_file,
            attribute=gene_data.attribute,
            parent_attribute=gene_data.parent_attribute,
        )
    else:
        df = parse_intersect_region_new(
            intersect_file=files.intersect_file,
            samples=["GT"] if single_sample else None,
            use_region_names=files.use_region_names,
            region_col=region_col,
        )
    # Count allele-specific reads and write the TSV result.
    count_df = make_count_df(bam_file=files.bam_file, df=df, use_rust=use_rust)
    count_df.write_csv(files.out_file, include_header=True, separator="\t")