configfile: "config.yaml"

import os
import glob

# ---------- parameters from config.yaml ----------
# Input VCF, e.g. /input/Sample.vcf.gz
VCF_INPUT = config["vcf_input"]
# Directory that holds the reference, e.g. /input/reference/
REF_DIR = config["reference_dir"]
# Reference FASTA file name, e.g. Homo_sapiens.GRCh38.dna.primary_assembly.fa
REF_FASTA = config["reference_fasta"]
# Query file, e.g. /input/input.txt
QUERY_INPUT = config["query_input"]
# Device selector: C for CPU, G for GPU; falls back to "C" when the key is absent
DEVICE_ID = config.get("device_id", "C")
# List of chromosome names
CHROMOSOMES = config["chromosomes"]

# Reference FASTA path: REF_FASTA joined onto REF_DIR
REF_PATH = os.path.join(REF_DIR, REF_FASTA)

# ---------- rule all ----------
rule all:
    """Default target: request the combined off-target result file."""
    input: "/output/combined_off_target_result.txt"

# ---------- Step 1: Re-compress plain gzip VCF to bgzip + tabix index ----------
rule recompress_vcf:
    """Handle plain gzip VCF: zcat | bgzip, then tabix index.

    Input:  vcf - the VCF given by config["vcf_input"].
    Output: bgz - the same content re-compressed with bgzip;
            tbi - the tabix index built from bgz.
    Log:    stderr of zcat, bgzip and tabix.
    """
    input:
        vcf=VCF_INPUT
    output:
        bgz="/output/work/sample.vcf.gz",
        tbi="/output/work/sample.vcf.gz.tbi"
    log:
        "/output/logs/recompress_vcf.log"
    shell:
        """
        mkdir -p /output/work /output/logs
        # The subshell makes the redirection cover the whole pipeline; a bare
        # trailing "2>" would capture bgzip's stderr only and drop zcat's.
        ( zcat {input.vcf} | bgzip -c > {output.bgz} ) 2> {log}
        tabix -p vcf {output.bgz} 2>> {log}
        """

# ---------- Step 2: Normalize VCF ----------
# vcfallelicprimitives | bcftools norm -m- | vcfcreatemulti
rule normalize_vcf:
    """Normalization chain matching CLI: vcfallelicprimitives | bcftools norm -m- | vcfcreatemulti.

    Input:  bgz/tbi - the bgzip-compressed VCF and its tabix index.
    Output: norm - the normalized VCF, bgzip-compressed;
            tbi  - the tabix index built from norm.
    Log:    stderr of every stage of the chain and of tabix.
    """
    input:
        bgz="/output/work/sample.vcf.gz",
        tbi="/output/work/sample.vcf.gz.tbi"
    output:
        norm="/output/work/sample_normalized.vcf.gz",
        tbi="/output/work/sample_normalized.vcf.gz.tbi"
    log:
        "/output/logs/normalize_vcf.log"
    shell:
        """
        # The subshell makes the redirection cover all four stages; a bare
        # trailing "2>" would capture bgzip's stderr only, hiding errors from
        # vcfallelicprimitives, bcftools and vcfcreatemulti.
        (
            vcfallelicprimitives {input.bgz} |
                bcftools norm -m- |
                vcfcreatemulti |
                bgzip -c > {output.norm}
        ) 2> {log}
        tabix -p vcf {output.norm} 2>> {log}
        """

# ---------- Step 3: Split VCF by chromosome ----------
rule split_vcf_by_chrom:
    """Write the records of one chromosome from the normalized VCF to their own
    uncompressed VCF (done per chromosome to avoid vcf2fasta segfaults).

    The {chrom} wildcard is passed to bcftools as the region to extract.
    """
    input:
        vcf="/output/work/sample_normalized.vcf.gz",
        tbi="/output/work/sample_normalized.vcf.gz.tbi"
    output:
        chrom_vcf="/output/work/chroms/{chrom}.vcf"
    log:
        "/output/logs/split_{chrom}.log"
    shell:
        """
        mkdir -p /output/work/chroms
        bcftools view --regions {wildcards.chrom} {input.vcf} 2> {log} > {output.chrom_vcf}
        """

# ---------- Step 4: Generate FASTA per chromosome ----------
# vcf2fasta -f <ref> -p <prefix> -n NAN <per-chrom.vcf>
rule vcf2fasta:
    """Generate allelic FASTA files per chromosome using vcf2fasta from vcflib.

    Input:  chrom_vcf - the per-chromosome VCF.
    Output: done - a marker file; Snakemake creates it through the touch()
            flag once the shell block finishes, so no explicit touch is needed.
    Params: ref    - path of the reference FASTA;
            prefix - prefix handed to vcf2fasta for the files it writes.
    Log:    stderr of vcf2fasta, plus a warning line when it exits non-zero.

    The step is best-effort: a failing vcf2fasta does not fail the rule, but
    its exit status is written to the log instead of being discarded.
    """
    input:
        chrom_vcf="/output/work/chroms/{chrom}.vcf"
    output:
        done=touch("/output/work/fasta/{chrom}.done")
    params:
        ref=REF_PATH,
        prefix="/output/work/fasta/{chrom}"
    log:
        "/output/logs/vcf2fasta_{chrom}.log"
    shell:
        """
        mkdir -p /output/work/fasta
        vcf2fasta -f {params.ref} -p {params.prefix} -n NAN {input.chrom_vcf} 2> {log} ||
            echo "WARNING: vcf2fasta exited with status $? (ignored)" >> {log}
        """

# ---------- Step 5: Run cas-offinder per individual FASTA file ----------
rule cas_offinder_per_fasta:
    """Run cas-offinder on each individual allelic FASTA file for a chromosome.

    Input:  done  - marker showing the FASTA generation step ran for {chrom};
            query - cas-offinder query template; its first line is replaced
                    by the path of the FASTA file being searched.
    Output: result - concatenation of the per-FASTA cas-offinder outputs, in
            sorted FASTA-path order (empty when no FASTA file is found).
    Params: fasta_dir - directory searched for FASTA files;
            device    - device argument passed to cas-offinder.
    Log:    the FASTA files found and, per run, the command line, return code
            and stderr.

    A non-zero cas-offinder exit is logged, not raised.
    """
    input:
        done="/output/work/fasta/{chrom}.done",
        query=QUERY_INPUT
    output:
        result="/output/work/results/{chrom}_offinder.txt"
    params:
        fasta_dir="/output/work/fasta",
        device=DEVICE_ID
    log:
        "/output/logs/cas_offinder_{chrom}.log"
    run:
        import glob
        import subprocess

        chrom = wildcards.chrom
        device = params.device

        # vcf2fasta names its files <prefix><sample>_<seq>:<copy>.fa, and the
        # prefix used upstream ends with the chromosome name. Anchoring on
        # "_<chrom>:" keeps e.g. chr1 from also picking up the files of
        # chr10..chr19, which a plain "<chrom>*.fa" pattern would match.
        # glob.escape() protects chromosome names containing *, ? or [.
        chrom_glob = glob.escape(chrom)
        fasta_pattern = os.path.join(
            params.fasta_dir, f"{chrom_glob}*_{chrom_glob}:*.fa"
        )
        fasta_files = sorted(glob.glob(fasta_pattern))

        # The template is identical for every FASTA file, so read it once.
        with open(input.query, "r") as qf:
            template_lines = qf.readlines()

        with open(str(log), "w") as logf, open(str(output.result), "w") as out:
            logf.write(f"Chromosome: {chrom}\n")
            logf.write(f"FASTA files found: {fasta_files}\n")

            for fa in fasta_files:
                tmp_query = fa + "_query.txt"
                fa_result = fa + "_result.txt"

                # Drop a result left over from an earlier run so that a failed
                # cas-offinder call cannot contribute stale hits.
                if os.path.exists(fa_result):
                    os.remove(fa_result)

                # First line of cas-offinder input is the FASTA directory/file
                # path; point it at this individual FASTA file.
                query_lines = list(template_lines)
                if query_lines:
                    query_lines[0] = fa + "\n"

                cmd = ["cas-offinder", tmp_query, device, fa_result]
                try:
                    with open(tmp_query, "w") as tq:
                        tq.writelines(query_lines)

                    logf.write(f"Running: {' '.join(cmd)}\n")
                    proc = subprocess.run(
                        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
                    )
                finally:
                    # Remove the temporary query even if the run raised.
                    if os.path.exists(tmp_query):
                        os.remove(tmp_query)

                logf.write(f"Return code: {proc.returncode}\n")
                if proc.stderr:
                    # errors="replace" keeps logging from failing on bytes
                    # that are not valid UTF-8.
                    logf.write(f"stderr: {proc.stderr.decode(errors='replace')}\n")

                # Append this FASTA file's hits to the chromosome result.
                if os.path.exists(fa_result):
                    with open(fa_result, "r") as rf:
                        out.write(rf.read())

# ---------- Step 6: Combine all chromosome results ----------
rule combine_results:
    """Concatenate the per-chromosome off-target files, in the order given by
    CHROMOSOMES, into the final result and log the resulting line count.
    """
    input:
        expand("/output/work/results/{chrom}_offinder.txt", chrom=CHROMOSOMES)
    output:
        "/output/combined_off_target_result.txt"
    log:
        "/output/logs/combine_results.log"
    shell:
        """
        cat {input} 2> {log} > {output}
        n_sites=$(wc -l < {output})
        echo "Combined $n_sites off-target sites." >> {log}
        """