<?php
declare(strict_types=1);

namespace Modules\PocockSimon\Service;

use Atlas\RandomisationBundle\Contract\AlgorithmInterface;
use Atlas\RandomisationBundle\Contract\SpecificationInterface;
use Atlas\RandomisationBundle\Dto\RandomisationResultDto;
use Atlas\RandomisationBundle\Entity\Participant\Factor as ParticipantFactor;
use Atlas\RandomisationBundle\Exception\RandomisationException;
use Atlas\RandomisationBundle\Repository\Randomisation\InactiveArmRepository;
use Atlas\RandomisationBundle\Repository\Randomisation\InactiveRandomisationRepository;
use Atlas\RandomisationBundle\Service\Randomisation\RandomisationAuditor;
use Atlas\RandomisationBundle\Repository\Randomisation\AllocationRepository;
use Atlas\RandomisationBundle\Repository\Participant\FactorRepository;
use DateTimeInterface;
use Doctrine\DBAL\Exception\UniqueConstraintViolationException;
use Doctrine\ORM\EntityManagerInterface;
use Modules\PocockSimon\Domain\Factor;
use Modules\PocockSimon\Domain\ImbalanceMetricEnum;
use Modules\PocockSimon\Domain\NoBestDistributionEnum;
use Modules\PocockSimon\Domain\Specification;
use Modules\PocockSimon\Entity\Hypothetical;
use Modules\PocockSimon\Entity\Total;
use Random\RandomException;
use Symfony\Component\DependencyInjection\Attribute\AutoconfigureTag;
use Symfony\Component\Uid\Uuid;
use Symfony\Component\Uid\UuidV7;

#[AutoconfigureTag('atlas.randomisation.algorithm', attributes: ['algorithm' => 'pocock_simon'])]
final readonly class Algorithm implements AlgorithmInterface
{

    /**
     * Wire up the collaborators used while randomising.
     *
     * @param RandomisationAuditor $logger receives one audit entry per decision step (via log())
     * @param InactiveRandomisationRepository $inactive_randomisation queried with check() for study/randomisation/location
     * @param InactiveArmRepository $inactive_arms queried with check() for each arm of the specification
     * @param AllocationRepository $allocations source of initial-random and per-arm allocation counts
     * @param FactorRepository $factors source of the per factor/level/arm counts
     * @param EntityManagerInterface $entity_manager persists and flushes hypothetical, total and participant-factor rows
     */
    public function __construct(
        private RandomisationAuditor $logger,
        private InactiveRandomisationRepository $inactive_randomisation,
        private InactiveArmRepository $inactive_arms,
        private AllocationRepository $allocations,
        private FactorRepository $factors,
        private EntityManagerInterface $entity_manager
    )
    {
    }

    /**
     * Identifier of this algorithm.
     *
     * The value is the same string as the 'algorithm' attribute of the class's
     * AutoconfigureTag, and it is embedded in the first audit entry of randomise().
     *
     * @return string always 'pocock_simon'
     */
    public function getType(): string
    {
        return 'pocock_simon';
    }

    /**
     * Allocate a participant to an arm, auditing every decision under a fresh run id.
     *
     * Flow:
     *  1. Resolve the participant's level for every factor in the specification.
     *  2. Stop (false) when the randomisation is inactive for the location.
     *  3. Discard inactive arms: none left => false; exactly one left => allocate it directly.
     *  4. While fewer than `simple_for_first` initial random allocations exist, draw purely by target weight.
     *  5. Otherwise score each active arm's hypothetical imbalance (persisting Hypothetical and Total rows),
     *     take the minimum-imbalance set, pass it through the ratio gate, optionally short-circuit on the
     *     deterministic gap rule, build (and optionally cap) the assignment probabilities, and draw.
     *
     * Participant factor rows are written via addFactors() on every path that returns an allocation.
     *
     * @param string $studyCode upper-cased before use
     * @param string $randomisationName not read by this implementation; the code is taken from $specification
     * @param Specification $specification
     * @param string $participantIdentifier
     * @param string $location passed to the inactive-randomisation and inactive-arm checks
     * @param string $actionBy recorded on audit entries and persisted rows
     * @param array $variables participant values, looked up by factor name and then by the factor's `from` key
     * @param DateTimeInterface|null $simulate forwarded to audit entries and persisted rows
     * @param Uuid|null $simulateId forwarded to repositories, audit entries and persisted rows
     * @return RandomisationResultDto|false false when the randomisation is inactive or no arm is available
     * @throws RandomException
     * @throws RandomisationException when a factor level is missing, or addFactors() reports duplicates
     */
    public function randomise(
        string $studyCode,
        string $randomisationName,
        SpecificationInterface $specification,
        string $participantIdentifier,
        string $location,
        string $actionBy,
        array $variables = [],
        ?DateTimeInterface $simulate = null,
        ?Uuid $simulateId = null
    ): RandomisationResultDto|false
    {
        $studyCode = mb_strtoupper($studyCode, encoding: 'utf-8');
        $randomisationCode = mb_strtoupper($specification->code, encoding: 'utf-8');

        $uri = Uuid::v7();

        // Every audit entry of this run shares the same context; only the message varies.
        $log = function (string $message) use (
            $studyCode, $randomisationCode, $participantIdentifier, $actionBy, $uri, $simulate, $simulateId
        ): void {
            $this->logger->log(
                $studyCode, $randomisationCode, $participantIdentifier,
                $message,
                $actionBy, $uri, simulate: $simulate, simulateId: $simulateId
            );
        };

        $log(sprintf('Randomisation started using algorithm %s', $this->getType()));

        // 1) Participant's factor levels (audit the reason before re-throwing)
        try {
            $participantLevels = $this->resolveParticipantLevels($specification->factors, $variables);
        } catch (RandomisationException $e) {
            $log($e->getMessage());

            throw $e;
        }

        $log('Collected and checked all factors');

        // Writes the participant's factor rows; called once on every allocating path.
        $recordFactors = function () use (
            $specification, $participantLevels, $studyCode, $randomisationCode,
            $participantIdentifier, $uri, $actionBy, $simulate, $simulateId
        ): void {
            $this->addFactors(
                $specification, $participantLevels, $studyCode, $randomisationCode,
                $participantIdentifier, $uri, $actionBy, $simulate, $simulateId
            );
        };

        if ($this->inactive_randomisation->check($studyCode, $randomisationCode, $location)) {
            $log('Randomisation inactive for all or current locations');

            return false;
        }

        // 2) Check arms: keep the target weight of every arm that is not inactive
        $targetWeights = [];

        foreach ($specification->arms as $name => $arm) {
            if ($this->inactive_arms->check($studyCode, $randomisationCode, $name, $location)) {
                $log(sprintf('Arm %s inactive (removed)', $name));
                continue;
            }

            $log(sprintf('Arm %s active with weight: %s', $name, $arm->weight));
            $targetWeights[$name] = $arm->weight;
        }

        $activeArms = array_keys($targetWeights);

        if (count($activeArms) < 1) {
            $log('No available arms');

            return false;
        }

        if (count($activeArms) === 1) {
            $only = $activeArms[0];
            $log(sprintf('Only 1 arm available and allocated %s', $only));

            $recordFactors();

            return new RandomisationResultDto($only, $uri);
        }

        // 3) Safeguarding: simple randomisation for the first N allocations
        if ($specification->simple_for_first !== null) {

            $simpleForFirstCount = $this->allocations->countInitialRandomAllocations($studyCode, $randomisationCode, simulationId: $simulateId);

            if ($simpleForFirstCount < $specification->simple_for_first) {
                $log('Safeguarding: pure simple randomisation');

                $weightSum = array_sum($targetWeights);

                $probs = [];
                if ($weightSum <= 0) {
                    // No usable weights: equal split across the active arms
                    $each = 1.0 / max(count($targetWeights), 1);
                    foreach ($targetWeights as $armName => $_w) {
                        $probs[$armName] = $each;
                    }
                } else {
                    foreach ($targetWeights as $armName => $w) {
                        $probs[$armName] = $w / $weightSum;
                    }
                }

                [$allocated, $u] = $this->drawFromProbabilities($probs);

                $log(sprintf('Draw: u=%.6f allocated=%s', $u, $allocated));

                $recordFactors();

                return new RandomisationResultDto($allocated, $uri, ['type' => sprintf('initial-random{%s}', $simpleForFirstCount + 1)]);
            }
        }

        // 4) Current counts per factor/level/arm
        $counts = $this->getCurrentCounts($studyCode, $randomisationCode, $specification, $simulateId);

        // 5) Hypothetical imbalance per arm (persist per-factor and total rows)
        $imbalanceByArm = []; // [arm => float]
        foreach ($activeArms as $armName) {
            [$totalScore, $perFactor] = $this->computeImbalanceIfAdded(
                $armName, $counts, $specification, $participantLevels
            );
            $imbalanceByArm[$armName] = $totalScore;

            foreach ($perFactor as $factorName => $vals) {
                $this->entity_manager->persist(new Hypothetical(
                    studyCode: $studyCode,
                    randomisationCode: $randomisationCode,
                    participantIdentifier: $participantIdentifier,
                    arm: $armName,
                    factor: $factorName,
                    value: (string)$vals['value'],
                    weight: (float)$vals['weight'],
                    count: (int)$vals['count'],
                    hypothetical: (int)$vals['hypothetical'],
                    total: (int)$vals['factor_score'],
                    runId: $uri,
                    actionBy: $actionBy,
                    simulation: $simulate,
                    simulationId: $simulateId
                ));
            }

            // The total row stores the score rounded to an integer
            $this->entity_manager->persist(new Total(
                studyCode: $studyCode,
                randomisationCode: $randomisationCode,
                participantIdentifier: $participantIdentifier,
                arm: $armName,
                total: (int)round($totalScore),
                runId: $uri,
                actionBy: $actionBy,
                simulation: $simulate,
                simulationId: $simulateId
            ));

            $log(sprintf('Imbalance if allocated to %s is %.6f', $armName, $totalScore));
        }

        $this->entity_manager->flush();

        // 6) Best set (minimum imbalance). min() returns one of the stored floats, so identity is exact.
        $min = min($imbalanceByArm);
        $bestSet = array_keys(array_filter($imbalanceByArm, static fn ($score) => $score === $min));

        $log(sprintf(
            'Best set: [%s] with min imbalance=%.6f',
            implode(',', $bestSet),
            $min
        ));

        // --- BEGIN RATIO GATE (keeps e.g. 2:2:1 on track) ---
        $marginals = $this->allocations->countAllocationsByArm($studyCode, $randomisationCode, simulationId: $simulateId);
        $totalN = array_sum($marginals);

        // Shortfall of each active arm *if* the next participant went there:
        // target share of (N + 1) minus the arm's count after the hypothetical add.
        $twSum = array_sum($targetWeights);
        $shortfall = [];
        foreach ($targetWeights as $a => $w) {
            $targetShare = $twSum > 0 ? ($w / $twSum) : 0.0; // e.g. 2:2:1 -> 0.4,0.4,0.2
            $shortfall[$a] = $targetShare * ($totalN + 1) - (($marginals[$a] ?? 0) + 1);
        }

        // Arms with the largest shortfall (within a tiny tolerance) are the most under target.
        $eps = 1e-12;
        $maxShortAll = max($shortfall);
        $underTarget = [];
        foreach ($shortfall as $a => $s) {
            if ($s >= $maxShortAll - $eps) {
                $underTarget[] = $a;
            }
        }

        // Keep the best-set arms that are also most under target; when there are none,
        // the under-target arms replace the best set entirely.
        $intersect = array_values(array_intersect($bestSet, $underTarget));
        $finalBest = $intersect ?: $underTarget;

        $log(sprintf(
            'Ratio gate: totals=%s; shortfall=%s; under_target=[%s]; best_set_before=[%s]; best_set_after=[%s]',
            json_encode($marginals, JSON_UNESCAPED_SLASHES|JSON_UNESCAPED_UNICODE),
            json_encode($shortfall, JSON_UNESCAPED_SLASHES|JSON_UNESCAPED_UNICODE),
            implode(',', $underTarget),
            implode(',', $bestSet),
            implode(',', $finalBest)
        ));

        $bestSet = $finalBest;
        // --- END RATIO GATE ---

        // Optional deterministic gap rule: if the two lowest imbalance scores differ by at least
        // the configured gap, allocate the lowest-imbalance arm without a random draw.
        if ($specification->force_deterministic_if_gap !== null) {
            $sorted = $imbalanceByArm;
            asort($sorted, SORT_NUMERIC);
            $ordered = array_values($sorted);
            if (count($ordered) >= 2) {
                $gap = $ordered[1] - $ordered[0];
                if ($gap >= $specification->force_deterministic_if_gap) {
                    $chosen = array_key_first($sorted);

                    $log(sprintf('Deterministic safeguard: gap %.6f >= %.6f, choose %s', $gap, $specification->force_deterministic_if_gap, $chosen));

                    $recordFactors();

                    return new RandomisationResultDto($chosen, $uri, ['type' => 'deterministic-gap']);
                }
            }
        }

        $log(sprintf(
            'Probability inputs: best_set=[%s]; prob_best=%.3f; nonbest=%s',
            implode(',', $bestSet),
            $specification->prob_best,
            $specification->no_best_distribution->name
        ));

        // 7) Probabilities (biased coin + ratio-aware splits)
        $probabilities = $this->buildAssignmentProbabilities(
            bestSet: $bestSet,
            targetWeights: $targetWeights,
            probBest: $specification->prob_best,
            nonBestDistribution: $specification->no_best_distribution
        );

        if (count($bestSet) === count($targetWeights)) {
            $nonBestMassPct = (1.0 - $specification->prob_best) * 100.0;
            $log(sprintf(
                'All arms are in the best set; non-best share (%.1f%%) has no “others”; probabilities re-normalised over best set.',
                $nonBestMassPct
            ));
        }

        if (count($bestSet) > 1) {
            $log(sprintf(
                'Tie detected across %d arms; biased-coin will split that %.1f%% mass by target ratio',
                count($bestSet),
                $specification->prob_best * 100
            ));
        }

        // Optional cap on best-set probabilities
        if ($specification->cap_prob_best !== null) {
            $beforeCap = $probabilities;
            $probabilities = $this->capBestProbabilities($probabilities, $bestSet, $specification->cap_prob_best);

            // Report the best-set arms whose pre-cap probability exceeded the cap
            $cappedArms = [];
            foreach ($bestSet as $a) {
                if (isset($beforeCap[$a]) && $beforeCap[$a] > $specification->cap_prob_best) {
                    $cappedArms[] = $a;
                }
            }

            if ($cappedArms) {
                $log(sprintf('Cap applied (%.3f) to arms: [%s]', $specification->cap_prob_best, implode(',', $cappedArms)));
            } else {
                $log(sprintf('Cap configured (%.3f) — no best-set arm exceeded it', $specification->cap_prob_best));
            }
        }

        foreach ($probabilities as $arm => $p) {
            $log(sprintf('Arm %s assignment probability %.2f%%', $arm, $p * 100));
        }

        $log(
            'DEBUG: marginals=' . json_encode($marginals)
            . '; bestSet=' . implode(',', $bestSet)
            . '; probs=' . json_encode($probabilities)
            . '; sum=' . sprintf('%.6f', array_sum($probabilities))
        );

        // 8) Draw allocation (the allocation itself is not written here, only the factor rows below)
        [$allocated, $u] = $this->drawFromProbabilities($probabilities);

        $log(sprintf('Draw: u=%.6f allocated=%s', $u, $allocated));

        $recordFactors();

        $log('Summary: ' . json_encode([
            'allocated' => $allocated,
            'best_set'  => array_values($bestSet),
            'u' => $u,
            'probs' => $probabilities,
        ], JSON_UNESCAPED_SLASHES | JSON_UNESCAPED_UNICODE));

        return new RandomisationResultDto($allocated, $uri, ['type' => 'algorithmic']);
    }

    // ----------------------- helpers -----------------------

    /**
     * Assemble the existing allocation counts as a nested map [factor][level][arm] => n.
     *
     * Every factor of the specification gets a key. Levels appear only when the
     * repository returned at least one row for them, and each such level is then
     * padded with a zero for every arm of the specification that had no row.
     *
     * @return array<string, array<string, array<string, int>>>
     */
    private function getCurrentCounts(
        string $studyCode,
        string $randomisationCode,
        SpecificationInterface $spec,
        ?Uuid $simulateId
    ): array
    {
        // Start with one empty bucket per configured factor.
        $tally = [];
        /** @var Factor $configured */
        foreach ($spec->factors as $configured) {
            $tally[$configured->name] = [];
        }

        // Overlay the stored counts; each row carries factor, level, arm and n.
        $stored = $this->factors->countsByFactorLevelArm($studyCode, $randomisationCode, $simulateId);
        foreach ($stored as $row) {
            $factorName = (string)$row['factor'];
            $levelValue = (string)$row['level']; // raw value as stored
            $armName = (string)$row['arm'];
            $number = (int)$row['n'];

            $tally[$factorName][$levelValue] ??= [];
            $tally[$factorName][$levelValue][$armName] = $number;
        }

        // Pad every observed level so that each arm of the specification has a count.
        foreach ($tally as $factorName => $levels) {
            foreach (array_keys($levels) as $levelValue) {
                foreach ($spec->arms as $armName => $_) {
                    $tally[$factorName][$levelValue][$armName] ??= 0;
                }
            }
        }

        return $tally;
    }

    /**
     * Map each factor to the participant's level, as a string.
     *
     * The value is looked up under the factor's name first and under its `from`
     * key second. A value that is null, or a string that is empty after trimming,
     * counts as missing; all missing factors are reported together.
     *
     * @param array $factors factor definitions exposing `name` and `from`
     * @param array $variables participant values keyed by variable name
     * @return array<string, string> [factor name => level]
     * @throws RandomisationException when at least one factor has no usable value
     */
    private function resolveParticipantLevels(array $factors, array $variables): array
    {
        $levels = [];
        $absent = [];

        /** @var Factor $factor */
        foreach ($factors as $factor) {
            $raw = $variables[$factor->name] ?? $variables[$factor->from] ?? null;

            $isBlank = $raw === null || (is_string($raw) && trim($raw) === '');
            if ($isBlank) {
                $absent[] = $factor->name;
                continue;
            }

            // Stored untrimmed: trimming is only used to detect blank input.
            $levels[$factor->name] = (string)$raw;
        }

        if ($absent !== []) {
            throw new RandomisationException('Missing factor level(s): '.implode(', ', $absent));
        }

        return $levels;
    }

    /**
     * Score the imbalance that would result from adding the participant to one arm.
     *
     * For each factor, the counts at the participant's level are taken across all
     * arms of the specification (zeros when the level has no counts yet), the
     * candidate arm is incremented by one, and the configured metric is applied.
     * The weighted per-factor scores are summed.
     *
     * @param string $armToAdd arm receiving the hypothetical participant
     * @param array $counts [factor][level][arm] => n
     * @param array $participantLevels [factor name => level]
     * @return array{0: float, 1: array<string, array{value: string, weight: mixed, count: int, hypothetical: int, factor_score: float}>}
     *         the weighted total and the per-factor breakdown
     */
    private function computeImbalanceIfAdded(
        string $armToAdd,
        array $counts,
        SpecificationInterface $spec,
        array $participantLevels
    ): array {
        $weightedSum = 0.0;
        $perFactor = [];

        /** @var Factor $factorSpec */
        foreach ($spec->factors as $factorSpec) {
            $name = $factorSpec->name;
            $factorWeight = $factorSpec->weight;
            $participantLevel = (string)($participantLevels[$name] ?? '');

            // null means this level has not been seen yet: every arm then starts from zero.
            $seen = $counts[$name][$participantLevel] ?? null;

            $tallies = [];
            foreach ($spec->arms as $armName => $_) {
                $tallies[$armName] = $seen === null ? 0 : (int)$seen[$armName];
            }

            // Hypothetically place the participant in the candidate arm.
            $before = $tallies[$armToAdd] ?? 0;
            $tallies[$armToAdd] = $before + 1;

            $score = $this->factorImbalance($tallies, $spec->imbalance_metric);
            $weightedSum += $factorWeight * $score;

            $perFactor[$name] = [
                'value'        => $participantLevel,
                'weight'       => $factorWeight,
                'count'        => (int)$before,
                'hypothetical' => (int)$before + 1,
                'factor_score' => $score,
            ];
        }

        return [$weightedSum, $perFactor];
    }

    /**
     * Score how unevenly the counts of one factor level are spread across arms.
     *
     * @param array $armTotals per-arm counts; only the values are used
     * @param ImbalanceMetricEnum $metric selects the formula:
     *        Range => largest count minus smallest count,
     *        AbsDiffSum => absDiffSum(), PairwiseSum => pairwiseSum()
     * @return float the imbalance score for this factor level
     */
    private function factorImbalance(array $armTotals, ImbalanceMetricEnum $metric): float
    {
        // Drop the arm keys: the metrics below work on a plain list of counts.
        $vals = array_values($armTotals);

        return match ($metric) {
            ImbalanceMetricEnum::Range => (float)(max($vals) - min($vals)),
            ImbalanceMetricEnum::AbsDiffSum => $this->absDiffSum($vals),
            ImbalanceMetricEnum::PairwiseSum => $this->pairwiseSum($vals)
        };
    }

    /**
     * Sum of absolute deviations from the mean of the given counts.
     *
     * @param array $vals numeric counts
     * @return float 0.0 for an empty list, otherwise sum(|v - mean|)
     */
    private function absDiffSum(array $vals): float
    {
        if ($vals === []) {
            return 0.0;
        }

        $mean = array_sum($vals) / count($vals);

        // Mean-centred absolute deviation of every count, added up in input order.
        $deviations = array_map(static fn ($value) => abs($value - $mean), $vals);

        return (float)array_sum($deviations);
    }

    /**
     * Sum of absolute differences over every unordered pair of counts.
     *
     * @param array $vals numeric counts, indexed 0..n-1
     * @return float 0.0 when there are fewer than two counts
     */
    private function pairwiseSum(array $vals): float
    {
        $size = count($vals);
        $distance = 0.0;

        // Pair each element with every element that follows it.
        for ($first = 0; $first < $size - 1; ++$first) {
            $anchor = $vals[$first];
            for ($second = $first + 1; $second < $size; ++$second) {
                $distance += abs($anchor - $vals[$second]);
            }
        }

        return $distance;
    }

    /**
     * Turn the best set into per-arm assignment probabilities (biased coin).
     *
     * The best set shares $probBest in proportion to the arms' target ratios.
     * The remaining arms share 1 - $probBest, either by target ratio or evenly,
     * depending on $nonBestDistribution. The result is then normalised to sum
     * to 1 whenever its sum is positive.
     *
     * @param array $bestSet arm names preferred for this allocation
     * @param array $targetWeights [arm => weight] for every candidate arm
     * @param float $probBest probability mass given to the best set
     * @param NoBestDistributionEnum $nonBestDistribution how the other arms split the rest
     * @return array<string, float> [arm => probability], one entry per key of $targetWeights
     */
    private function buildAssignmentProbabilities(
        array $bestSet,
        array $targetWeights, // [arm => int]
        float $probBest,
        NoBestDistributionEnum $nonBestDistribution
    ): array
    {
        $allArms = array_keys($targetWeights);
        $weightTotal = array_sum($targetWeights);

        // Target ratios normalised to sum to 1 (all zero when the weights do not add up to a positive number).
        $ratio = array_map(
            static fn ($weight) => $weightTotal > 0 ? $weight / $weightTotal : 0.0,
            $targetWeights
        );

        $result = array_fill_keys($allArms, 0.0);

        // Gives $mass to $members in proportion to their ratios; evenly when those ratios sum to zero.
        $shareByRatio = static function (array $members, float $mass) use ($ratio, &$result): void {
            $groupRatio = 0.0;
            foreach ($members as $arm) {
                $groupRatio += $ratio[$arm] ?? 0.0;
            }
            foreach ($members as $arm) {
                $portion = $groupRatio > 0 ? ($ratio[$arm] / $groupRatio) : (1.0 / count($members));
                $result[$arm] += $mass * $portion;
            }
        };

        // Best share
        if ($bestSet !== []) {
            $shareByRatio($bestSet, $probBest);
        }

        // Non-best share
        $remainder = array_values(array_diff($allArms, $bestSet));
        if ($remainder !== []) {
            $leftover = 1.0 - $probBest;

            if ($nonBestDistribution === NoBestDistributionEnum::ByTargetRatio) {
                $shareByRatio($remainder, $leftover);
            } else {
                // Uniform split
                $flat = $leftover / count($remainder);
                foreach ($remainder as $arm) {
                    $result[$arm] += $flat;
                }
            }
        }

        // Defensive normalisation
        $mass = array_sum($result);
        if ($mass > 0.0) {
            return array_map(static fn ($value) => $value / $mass, $result);
        }

        return $result;
    }


    /**
     * Limit the probability of best-set arms to $cap and hand the excess to the other arms.
     *
     * 1. Best-set arms above the cap are set to the cap; the removed mass is collected.
     * 2. That mass goes to the arms that were not capped, in proportion to their current
     *    probability (evenly when they sum to zero). When every arm was capped, it is
     *    added back evenly to all arms.
     * 3. Any remaining difference from a total of 1.0 (beyond 1e-12) is distributed the same way.
     * 4. Values are clipped into [0, 1]; no re-normalisation follows the clipping.
     *
     * @param array $probs [arm => probability]
     * @param array $bestSet arm names subject to the cap
     * @param float $cap maximum probability for a best-set arm
     * @return array<string, float> adjusted probabilities, same keys as $probs
     */
    private function capBestProbabilities(array $probs, array $bestSet, float $cap): array
    {
        // 1) Hard cap on the best set
        $cappedArms = [];
        $surplus = 0.0;

        foreach ($bestSet as $arm) {
            if (!isset($probs[$arm]) || !($probs[$arm] > $cap)) {
                continue;
            }
            $surplus += $probs[$arm] - $cap;
            $probs[$arm] = $cap;
            $cappedArms[$arm] = true;
        }

        // Adds $amount to the non-capped arms, or to every arm when all were capped.
        $spread = static function (float $amount) use (&$probs, $cappedArms): void {
            $recipients = array_diff(array_keys($probs), array_keys($cappedArms));

            if (!$recipients) {
                // Infeasible cap: relax it minimally by spreading over all arms.
                $portion = $amount / max(count($probs), 1);
                foreach ($probs as $arm => $current) {
                    $probs[$arm] = $current + $portion;
                }
                return;
            }

            $recipientMass = 0.0;
            foreach ($recipients as $arm) {
                $recipientMass += $probs[$arm];
            }

            if ($recipientMass > 0.0) {
                foreach ($recipients as $arm) {
                    $probs[$arm] += $amount * ($probs[$arm] / $recipientMass);
                }
                return;
            }

            $portion = $amount / count($recipients);
            foreach ($recipients as $arm) {
                $probs[$arm] += $portion;
            }
        };

        // 2) Redistribute what the cap removed
        if ($surplus > 0.0) {
            $spread($surplus);
        }

        // 3) Close any remaining gap to a total of exactly 1.0
        $shortBy = 1.0 - array_sum($probs);
        if (abs($shortBy) > 1e-12) {
            $spread($shortBy);
        }

        // 4) Numeric safety: clip into [0, 1]
        foreach ($probs as $arm => $value) {
            if ($value < 0) {
                $probs[$arm] = 0.0;
            } elseif ($value > 1) {
                $probs[$arm] = 1.0;
            }
        }

        return $probs;
    }


    /**
     * Draw one arm from a discrete probability distribution.
     *
     * Arms are walked in key order (so a given u always maps to the same arm) and the
     * first arm whose cumulative probability reaches u is returned. u is uniform over
     * {0.000001, ..., 1.000000}, so an arm with zero probability is never hit in the walk.
     *
     * If the cumulative sum ends below u (floating-point rounding, or a vector that sums
     * to slightly less than 1), u has fallen past the end of the distribution; the draw
     * then resolves to the last arm with positive probability rather than to the first
     * arm, which may have probability zero.
     *
     * @param array $probs [arm => probability]
     * @return array{0: int|string|null, 1: float} the allocated arm (null only for empty input) and u
     * @throws RandomException when no random source is available
     */
    private function drawFromProbabilities(array $probs): array
    {
        // Deterministic order by arm name keeps tests stable
        ksort($probs);
        $u = random_int(1, 1_000_000) / 1_000_000;

        $cumulative = 0.0;
        $lastPositive = null;
        foreach ($probs as $arm => $p) {
            $cumulative += $p;
            if ($p > 0) {
                $lastPositive = $arm;
            }
            if ($u <= $cumulative) {
                return [$arm, $u];
            }
        }

        // Only reached when the probabilities sum to less than u.
        return [$lastPositive ?? array_key_first($probs), $u];
    }

    /**
     * Persist the participant's level for every factor of the specification.
     *
     * Factors whose level is absent, not scalar, or the empty string are skipped.
     * All rows are flushed together and an audit entry is written afterwards.
     *
     * @param SpecificationInterface|Specification $specification supplies the factor list
     * @param array $participantLevels [factor name => level]
     * @param string $studyCode
     * @param string $randomisationCode
     * @param string $participantIdentifier
     * @param UuidV7 $uri run id attached to each row and to the audit entry
     * @param string $actionBy
     * @param DateTimeInterface|null $simulate
     * @param Uuid|null $simulateId
     * @return void
     * @throws RandomisationException when the flush hits a unique-constraint violation
     */
    public function addFactors(SpecificationInterface|Specification $specification, array $participantLevels, string $studyCode, string $randomisationCode, string $participantIdentifier, UuidV7 $uri, string $actionBy, ?DateTimeInterface $simulate, ?Uuid $simulateId): void
    {
        foreach ($specification->factors as $factor) {
            $level = $participantLevels[$factor->name] ?? null;

            // Nothing to store for a missing, non-scalar or empty level.
            if (!is_scalar($level) || $level === '') {
                continue;
            }

            $this->entity_manager->persist(new ParticipantFactor(
                $studyCode,
                $randomisationCode,
                $participantIdentifier,
                $factor->name,
                $level,
                $uri,
                $actionBy,
                simulation: $simulate,
                simulationId: $simulateId
            ));
        }

        try {
            $this->entity_manager->flush();
        } catch (UniqueConstraintViolationException $duplicate) {
            throw new RandomisationException('Duplicate participant factor(s) detected; aborting randomisation.', previous: $duplicate);
        }

        $this->logger->log(
            $studyCode, $randomisationCode, $participantIdentifier,
            'Participant factors added to database',
            $actionBy, $uri, simulate: $simulate, simulateId: $simulateId
        );
    }
}
