/*
 * Decompiled with CFR 0.152.
 */
package com.atlassian.stash.internal.mesh;

import com.atlassian.bitbucket.dmz.mesh.DmzMeshPartitionMigrationService;
import com.atlassian.bitbucket.dmz.mesh.MeshPartition;
import com.atlassian.bitbucket.dmz.mesh.MeshPartitionMigrationRequest;
import com.atlassian.bitbucket.dmz.mesh.MeshPartitionReplica;
import com.atlassian.bitbucket.dmz.mesh.MeshPartitionReplicaAlreadyExistsException;
import com.atlassian.bitbucket.mesh.MeshNode;
import com.atlassian.stash.internal.mesh.AvailabilityZoneRequirement;
import com.atlassian.stash.internal.mesh.BalancedDistributionRequirement;
import com.atlassian.stash.internal.mesh.DefaultRingHasher;
import com.atlassian.stash.internal.mesh.DefaultRingMap;
import com.atlassian.stash.internal.mesh.InternalMeshNode;
import com.atlassian.stash.internal.mesh.InternalMeshPartitionRegistry;
import com.atlassian.stash.internal.mesh.MeshNodeDao;
import com.atlassian.stash.internal.mesh.MeshPartitionReplicaDao;
import com.atlassian.stash.internal.mesh.MeshPartitionReplicaService;
import com.atlassian.stash.internal.mesh.PartitionAllocationStrategy;
import com.atlassian.stash.internal.mesh.ReplicaRequirement;
import com.atlassian.stash.internal.mesh.RingHasher;
import com.atlassian.stash.internal.mesh.RingMap;
import com.atlassian.stash.internal.mesh.SimpleMeshPartition;
import com.atlassian.stash.internal.mesh.SingleReplicaPerNodeRequirement;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.hash.Hashing;
import jakarta.annotation.Nonnull;
import jakarta.annotation.Nullable;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.mutable.MutableInt;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

/**
 * Allocates mesh partition replicas to mesh nodes using a consistent-hash ring.
 * <p>
 * Every available node (any node whose internal state is not {@link MeshNode.State#DELETING}) occupies
 * {@code VIRTUAL_NODE_COUNT} points ("virtual nodes") on the ring, keyed by
 * {@code hasher.hash(nodeId, virtualNodeId)}. To place a replica, a starting point on the ring is derived from a
 * hash, and the ring is walked forward until a node is found that passes every {@link ReplicaRequirement} returned
 * by {@code getRequirements}. If no node passes them all, requirements are dropped one at a time, in list order,
 * and the walk is retried.
 */
@Component
public class DefaultPartitionAllocationStrategy
implements PartitionAllocationStrategy {
    /** Factor passed to {@link BalancedDistributionRequirement}. */
    private static final float DISTRIBUTION_FACTOR = 1.2f;
    /** Property placeholder (with default) for the number of partitions each node contributes. */
    private static final String PROP_PARTITIONS_PER_NODE = "${plugin.bitbucket-git.mesh.partitions-per-node:64}";
    /** Property placeholder (with default) for the desired number of replicas per partition. */
    private static final String PROP_REPLICATION_FACTOR = "${plugin.bitbucket-git.mesh.replication.factor:3}";
    /** Number of points each mesh node occupies on the hash ring. */
    private static final int VIRTUAL_NODE_COUNT = 8;
    private static final Logger log = LoggerFactory.getLogger(DefaultPartitionAllocationStrategy.class);
    private final RingHasher hasher;
    private final DmzMeshPartitionMigrationService migrationService;
    private final MeshNodeDao nodeDao;
    private final InternalMeshPartitionRegistry partitionRegistry;
    private final MeshPartitionReplicaDao partitionReplicaDao;
    private final MeshPartitionReplicaService partitionReplicaService;
    private final int partitionsPerNode;
    private final int replicationFactor;

    @Autowired
    public DefaultPartitionAllocationStrategy(MeshNodeDao nodeDao, MeshPartitionReplicaDao partitionReplicaDao,
                                              DmzMeshPartitionMigrationService migrationService,
                                              InternalMeshPartitionRegistry partitionRegistry,
                                              MeshPartitionReplicaService partitionReplicaService,
                                              @Value(PROP_PARTITIONS_PER_NODE) int partitionsPerNode,
                                              @Value(PROP_REPLICATION_FACTOR) int replicationFactor) {
        this.nodeDao = nodeDao;
        this.partitionReplicaDao = partitionReplicaDao;
        this.migrationService = migrationService;
        this.partitionRegistry = partitionRegistry;
        this.partitionReplicaService = partitionReplicaService;
        this.partitionsPerNode = partitionsPerNode;
        this.replicationFactor = replicationFactor;
        this.hasher = new DefaultRingHasher(Hashing.sha256());
    }

    /**
     * Ensures enough partitions exist for the current set of available nodes, and gives each partition its
     * desired number of replicas.
     * <p>
     * The target partition count is the larger of the next unassigned partition ID and the sum of
     * {@code getPartitionsToAdd} over every available node. Each partition in {@code [0, target)} then receives
     * new replicas until it has {@code min(replicationFactor, availableNodeCount)} of them. All assignments are
     * made inside a single {@code partitionReplicaService.runBatch(...)} call.
     */
    @Override
    public void ensurePartitionCount() {
        List<InternalMeshNode> availableNodes = this.getAvailableNodes();
        int availableNodeCount = availableNodes.size();
        int nextPartitionId = this.partitionReplicaDao.getNextUnassignedPartition();
        int desiredPartitionCount = IntStream.rangeClosed(1, availableNodeCount).map(this::getPartitionsToAdd).sum();
        int targetPartitionCount = Math.max(nextPartitionId, desiredPartitionCount);
        this.partitionRegistry.refresh();
        Collection<MeshPartition> allPartitions = this.partitionRegistry.getAllPartitions();
        int desiredReplicaCount = Math.min(this.replicationFactor, availableNodeCount);
        RingMap<VirtualMeshNode> ring = this.calculateRing(availableNodes);
        Map<Integer, MeshPartition> partitionsById = allPartitions.stream()
                .collect(Collectors.toMap(MeshPartition::getId, Function.identity()));
        Map<MeshNode, Set<MeshPartitionReplica>> replicasByNode = this.getReplicasByNode(allPartitions);
        List<ReplicaRequirement> requirements = this.getRequirements(replicasByNode, flattenReplicas(allPartitions));
        this.partitionReplicaService.runBatch(batchOperation -> {
            IntStream.range(0, targetPartitionCount).forEach(partitionId -> {
                MeshPartition partition = partitionsById.get(partitionId);
                if (partition == null) {
                    // The registry does not know this partition yet; build a placeholder to read its replicas from.
                    partition = new SimpleMeshPartition.Builder(partitionId).build();
                }
                int currentReplicaCount = partition.getReplicas().size();
                IntStream.range(currentReplicaCount, desiredReplicaCount).forEach(replicaOrdinal -> {
                    // Each (partition, replica ordinal) pair hashes to its own starting point on the ring.
                    RingMap.Entry<VirtualMeshNode> potentialTargetNode =
                            ring.nextEntry(this.hasher.hash(partitionId, replicaOrdinal));
                    if (potentialTargetNode == null) {
                        log.debug("No target node found for partition {}", partitionId);
                        return;
                    }
                    VirtualMeshNode targetNode = this.findNodeForPartitionReplica(null,
                            potentialTargetNode.getValue(), partitionId, ring, requirements);
                    if (targetNode == null) {
                        log.debug("No target node found for partition {}, will have 1 less replica", partitionId);
                        return;
                    }
                    log.debug("Assigning partition {} replica {} to node [{}]",
                            new Object[]{partitionId, replicaOrdinal, targetNode.backingNode});
                    MeshPartitionReplica partitionReplica =
                            batchOperation.assignPartitionReplica(targetNode.backingNode, partitionId);
                    // Keep the requirements' view of the cluster in step with the assignment just made.
                    requirements.forEach(requirement -> requirement.move(partitionReplica, null, targetNode.backingNode));
                });
            });
            return null;
        });
    }

    /**
     * Moves every replica off {@code nodeToEvacuate} onto other available nodes.
     * <p>
     * Returns an already-completed stage, without migrating anything, when fewer available nodes exist than the
     * replication factor, or when the node holds no replicas. Otherwise, for each replica a target is chosen on a
     * ring that excludes the evacuated node. If no target can be found, the replica is unassigned from the
     * evacuated node. Otherwise a migration is submitted. All migrations (and unassignments) are issued inside the
     * {@code batchMigrations} callback.
     *
     * @param nodeToEvacuate the node whose replicas should be moved elsewhere
     * @return a stage that completes once every submitted migration has completed
     */
    @Override
    public CompletionStage<Void> evacuateNode(@Nonnull MeshNode nodeToEvacuate) {
        List<InternalMeshNode> availableNodes = this.getAvailableNodes();
        if (availableNodes.size() < this.replicationFactor) {
            log.debug("Evacuate node [{}]: No nodes available to evacuate to, skipping.", (Object)nodeToEvacuate);
            return CompletableFuture.completedFuture(null);
        }
        RingMap<VirtualMeshNode> ring = this.calculateRing(availableNodes.stream()
                .filter(node -> !Objects.equals(node, nodeToEvacuate))
                .toList());
        this.partitionRegistry.refresh();
        Collection<MeshPartition> allPartitions = this.partitionRegistry.getAllPartitions();
        Map<MeshNode, Set<MeshPartitionReplica>> replicasByNode = this.getReplicasByNode(allPartitions);
        Set<MeshPartitionReplica> replicasToMigrate = replicasByNode.get(nodeToEvacuate);
        if (replicasToMigrate == null || replicasToMigrate.isEmpty()) {
            log.debug("Evacuate node [{}]: No partitions to migrate, skipping.", (Object)nodeToEvacuate);
            return CompletableFuture.completedFuture(null);
        }
        List<ReplicaRequirement> requirements = this.getRequirements(replicasByNode, flattenReplicas(allPartitions));
        MutableInt completedMigrations = new MutableInt();
        // The stream is collected inside the callback so that every migrate()/unassign call happens while the
        // batch callback is running. A lazily returned Stream would only be evaluated after the callback returned.
        List<CompletableFuture<?>> pendingMigrations = this.migrationService.batchMigrations(migrationBatch ->
                replicasToMigrate.stream().<CompletableFuture<?>>map(replica -> {
                    // NOTE(review): unlike getPartitionsByVirtualNode, this uses hash(..) rather than
                    // hashWithUpperBound(..), so the value is not bounded to the virtual node count; confirm intended.
                    long virtualNodeId = this.hasher.hash(replica.getId(), (long)VIRTUAL_NODE_COUNT);
                    long virtualNodeHash = this.hasher.hash(nodeToEvacuate.getId(), virtualNodeId);
                    RingMap.Entry<VirtualMeshNode> nextEntry = ring.nextEntry(virtualNodeHash);
                    if (nextEntry == null) {
                        log.debug("Evacuate node [{}]: No target node found for partition {}", (Object)nodeToEvacuate, (Object)replica.getPartition());
                        return CompletableFuture.completedFuture(null);
                    }
                    int partitionId = replica.getPartition();
                    VirtualMeshNode redirectedTargetNode = this.findNodeForPartitionReplica(null,
                            nextEntry.getValue(), partitionId, ring, requirements);
                    if (redirectedTargetNode == null) {
                        log.debug("Evacuate node [{}]: No target node found for partition {}, will have 1 less replica when the node is evacuated", (Object)nodeToEvacuate, (Object)partitionId);
                        this.partitionReplicaService.unassignPartitionReplica(nodeToEvacuate, partitionId);
                        return CompletableFuture.completedFuture(null);
                    }
                    log.debug("Evacuate node [{}]: Starting migration of partition {} to node [{}]",
                            new Object[]{nodeToEvacuate, partitionId, redirectedTargetNode.backingNode});
                    try {
                        requirements.forEach(requirement ->
                                requirement.move(replica, nodeToEvacuate, redirectedTargetNode.backingNode));
                        MeshPartitionMigrationRequest request = new MeshPartitionMigrationRequest.Builder(
                                partitionId, nodeToEvacuate, redirectedTargetNode.backingNode).build();
                        return migrationBatch.migrate(request)
                                .whenComplete((ignored, throwable) -> log.debug("Evacuate node [{}]: Finished migrating {}. Migrated {}/{} partitions.",
                                        new Object[]{nodeToEvacuate, partitionId, completedMigrations.incrementAndGet(), replicasToMigrate.size()}))
                                .toCompletableFuture();
                    }
                    catch (MeshPartitionReplicaAlreadyExistsException e) {
                        log.error("Failed to evacuate node [{}]", (Object)nodeToEvacuate, (Object)e);
                        return CompletableFuture.failedFuture(e);
                    }
                }).toList());
        return CompletableFuture.allOf(pendingMigrations.toArray(CompletableFuture[]::new))
                .whenComplete((ignored, throwable) -> {
                    if (throwable != null) {
                        log.error("Exception occurred while evacuating node [{}]", (Object)nodeToEvacuate, throwable);
                    } else {
                        log.info("Evacuate node [{}]: Migrated {} partitions.", (Object)nodeToEvacuate, (Object)replicasToMigrate.size());
                    }
                });
    }

    /**
     * Computes and submits migrations that move replicas onto {@code targetNode}.
     * <p>
     * Does nothing when {@code replicationFactor >= availableNodes}. Otherwise, for each of the target's virtual
     * nodes, the nearest preceding and following ring entries owned by other nodes are located. Replicas on the
     * following entry whose {@code hashWithBounds(id, previousHash, nextHash)} is below the target's hash start
     * their placement search at the target's virtual node. All other replicas start at their current virtual
     * node. Each replica is relocated at most once. The collected migrations are collapsed with
     * {@link #effectiveMigrations(List)} and submitted. Completion of those migrations is only logged, not awaited.
     *
     * @param targetNode the node to rebalance partitions onto
     */
    @Override
    public void rebalancePartitionsToNode(@Nonnull MeshNode targetNode) {
        List<InternalMeshNode> availableNodes = this.getAvailableNodes();
        if (this.replicationFactor >= availableNodes.size()) {
            log.debug("Skipping rebalancing to node [{}]; not enough nodes to rebalance", (Object)targetNode);
            return;
        }
        RingMap<VirtualMeshNode> ring = this.calculateRing(availableNodes);
        // NOTE(review): unlike ensurePartitionCount() and evacuateNode(), this does not call
        // partitionRegistry.refresh() before reading partitions; confirm a possibly stale snapshot is acceptable.
        Collection<MeshPartition> allPartitions = this.partitionRegistry.getAllPartitions();
        Map<MeshNode, Set<MeshPartitionReplica>> replicasByNode = this.getReplicasByNode(allPartitions);
        Map<VirtualMeshNode, Set<MeshPartitionReplica>> partitionsByVirtualNode = this.getPartitionsByVirtualNode(allPartitions);
        List<ReplicaRequirement> requirements = this.getRequirements(replicasByNode, flattenReplicas(allPartitions));
        List<MeshPartitionMigrationRequest> migrations = new ArrayList<>();
        Set<Long> migratedReplicaIds = new HashSet<>();
        for (int virtualNodeId = 0; virtualNodeId < VIRTUAL_NODE_COUNT; ++virtualNodeId) {
            long targetNodeHash = this.hasher.hash(targetNode.getId(), virtualNodeId);
            VirtualMeshNode virtualTargetNode = new VirtualMeshNode(targetNode, virtualNodeId);
            // Skip over the target's own virtual nodes to find the closest neighbours owned by other nodes.
            RingMap.Entry<VirtualMeshNode> previousNodeEntry = ring.previousEntry(targetNodeHash);
            while (previousNodeEntry != null && previousNodeEntry.getValue().backingNode.equals(targetNode)) {
                previousNodeEntry = ring.previousEntry(previousNodeEntry.getKey());
            }
            if (previousNodeEntry == null) {
                log.error("Couldn't find previous node for target node [{}]", (Object)targetNode);
                continue;
            }
            long previousNodeHash = previousNodeEntry.getKey();
            RingMap.Entry<VirtualMeshNode> nextNodeEntry = ring.nextEntry(targetNodeHash);
            while (nextNodeEntry != null && nextNodeEntry.getValue().backingNode.equals(targetNode)) {
                nextNodeEntry = ring.nextEntry(nextNodeEntry.getKey());
            }
            if (nextNodeEntry == null) {
                log.error("Couldn't find next node for target node [{}]", (Object)targetNode);
                continue;
            }
            long nextNodeHash = nextNodeEntry.getKey();
            Set<MeshPartitionReplica> nextNodeReplicas = partitionsByVirtualNode.get(nextNodeEntry.getValue());
            Set<MeshPartitionReplica> replicasToMigrate = nextNodeReplicas == null
                    ? Collections.emptySet()
                    : nextNodeReplicas.stream()
                            .filter(replica -> this.hasher.hashWithBounds(replica.getId(), previousNodeHash, nextNodeHash) < targetNodeHash)
                            .collect(Collectors.toSet());
            partitionsByVirtualNode.forEach((virtualNode, replicas) -> replicas.forEach(replica -> {
                if (migratedReplicaIds.contains(replica.getId())) {
                    return;
                }
                VirtualMeshNode startingNode = replicasToMigrate.contains(replica) ? virtualTargetNode : virtualNode;
                VirtualMeshNode redirectedTargetNode = this.findNodeForPartitionReplica(replica, startingNode,
                        replica.getPartition(), ring, requirements);
                if (redirectedTargetNode == null || Objects.equals(redirectedTargetNode.backingNode, virtualNode.backingNode)) {
                    return;
                }
                migrations.add(new MeshPartitionMigrationRequest.Builder(replica.getPartition(),
                        virtualNode.backingNode, redirectedTargetNode.backingNode).build());
                requirements.forEach(requirement ->
                        requirement.move(replica, virtualNode.backingNode, redirectedTargetNode.backingNode));
                migratedReplicaIds.add(replica.getId());
            }));
        }
        log.trace("Computed migrations: {}", (Object)StringUtils.join(migrations, "\n"));
        List<MeshPartitionMigrationRequest> effectiveMigrations = DefaultPartitionAllocationStrategy.effectiveMigrations(migrations);
        log.trace("Effective migrations: {}", (Object)StringUtils.join(effectiveMigrations, "\n"));
        log.debug("Rebalance to node [{}]: {} partitions will be migrated from other nodes", (Object)targetNode, (Object)effectiveMigrations.size());
        MutableInt completedMigrations = new MutableInt();
        this.migrationService.batchMigrations(migrationBatch -> {
            effectiveMigrations.forEach(request -> {
                log.debug("Rebalance to node [{}]: Partition {} will be migrated from node [{}] -> [{}]",
                        new Object[]{targetNode, request.getPartitionId(), request.getSourceNode(), request.getTargetNode()});
                migrationBatch.migrate(request).whenComplete((ignored, throwable) -> log.debug("Rebalance to node [{}]: Finished migrating partition {}. Migrated {}/{} partitions.",
                        new Object[]{targetNode, request.getPartitionId(), completedMigrations.incrementAndGet(), effectiveMigrations.size()}));
            });
            return null;
        });
    }

    /**
     * Collapses, per partition, chains of migrations into their net effect.
     * <p>
     * For example, {@code A -> B} and {@code B -> C} become {@code A -> C}. A chain that leads back to its own
     * source (for example {@code A -> B}, {@code B -> A}) is dropped entirely. If several requests for the same
     * partition share a source node, the last one wins.
     *
     * @param requests the raw migration requests
     * @return the collapsed requests, grouped by partition
     */
    @VisibleForTesting
    static List<MeshPartitionMigrationRequest> effectiveMigrations(List<MeshPartitionMigrationRequest> requests) {
        Map<Integer, List<MeshPartitionMigrationRequest>> requestsPerPartition = requests.stream()
                .collect(Collectors.groupingBy(MeshPartitionMigrationRequest::getPartitionId));
        List<MeshPartitionMigrationRequest> effectiveRequests = new ArrayList<>();
        requestsPerPartition.forEach((partitionId, partitionRequests) -> {
            Map<MeshNode, MeshNode> migrationTargets = new LinkedHashMap<>();
            partitionRequests.forEach(request -> migrationTargets.put(request.getSourceNode(), request.getTargetNode()));
            collapseMigrationChains(migrationTargets);
            migrationTargets.forEach((source, target) -> effectiveRequests.add(
                    new MeshPartitionMigrationRequest.Builder(partitionId.intValue(), source, target).build()));
        });
        return effectiveRequests;
    }

    /**
     * Rewrites a source-to-target map in place until no target is also a source, removing mappings that end up
     * pointing back at their own source.
     *
     * @param migrationTargets insertion-ordered map of source node to target node for one partition
     */
    private static void collapseMigrationChains(Map<MeshNode, MeshNode> migrationTargets) {
        boolean changed;
        do {
            changed = false;
            for (MeshNode source : new ArrayList<>(migrationTargets.keySet())) {
                MeshNode effectiveTarget = migrationTargets.remove(source);
                MeshNode nextTarget;
                // Follow the chain, consuming each hop, until the target is not itself a migration source.
                while ((nextTarget = migrationTargets.remove(effectiveTarget)) != null) {
                    effectiveTarget = nextTarget;
                    changed = true;
                }
                if (!Objects.equals(source, effectiveTarget)) {
                    migrationTargets.put(source, effectiveTarget);
                }
                if (changed) {
                    // Entries later in the snapshot may have been consumed; restart from a fresh snapshot.
                    break;
                }
            }
        } while (changed);
    }

    /**
     * Builds a ring containing {@code VIRTUAL_NODE_COUNT} virtual nodes per given node.
     *
     * @param targetNodes the nodes to place on the ring
     * @return the populated ring
     */
    private RingMap<VirtualMeshNode> calculateRing(List<InternalMeshNode> targetNodes) {
        DefaultRingMap<VirtualMeshNode> ring = new DefaultRingMap<>();
        targetNodes.forEach(node -> IntStream.range(0, VIRTUAL_NODE_COUNT).forEach(vnodeId ->
                ring.put(this.hasher.hash(node.getId(), vnodeId), new VirtualMeshNode((MeshNode)node, vnodeId))));
        return ring;
    }

    /**
     * Walks the ring forward from {@code node}, looking for a virtual node whose backing node satisfies every
     * requirement for the given partition.
     * <p>
     * The starting node is tested first, then each following ring entry, for at most one lap of the ring. If no
     * candidate satisfies all requirements, the first remaining requirement is dropped and the walk restarts from
     * {@code node}. This repeats until a candidate is found or no requirements remain.
     *
     * @param existingReplica the replica being relocated, or {@code null} when placing a new replica
     * @param node            the virtual node to start from; it does not have to be present on {@code ring}
     * @param partitionId     the partition being placed
     * @param ring            the ring to walk
     * @param requirements    the requirements to satisfy, in the order in which they may be dropped
     * @return the first satisfying virtual node, or {@code null} if the ring has no next entry or every
     *         requirement had to be dropped
     */
    private VirtualMeshNode findNodeForPartitionReplica(@Nullable MeshPartitionReplica existingReplica, VirtualMeshNode node,
                                                        int partitionId, RingMap<VirtualMeshNode> ring,
                                                        List<ReplicaRequirement> requirements) {
        ArrayDeque<ReplicaRequirement> localRequirements = new ArrayDeque<>(requirements);
        while (!localRequirements.isEmpty()) {
            VirtualMeshNode candidate = node;
            long candidateHash = this.hasher.hash(node.backingNode.getId(), node.virtualNodeId);
            // Stop on the first repeated entry rather than waiting to return to `node`. Otherwise the walk never
            // ends when `node` itself is not on the ring (e.g. its backing node was excluded from the ring).
            Set<VirtualMeshNode> visited = new HashSet<>();
            while (visited.add(candidate)) {
                if (satisfiesAll(localRequirements, existingReplica, candidate.backingNode, partitionId)) {
                    return candidate;
                }
                RingMap.Entry<VirtualMeshNode> nextNodeEntry = ring.nextEntry(candidateHash);
                if (nextNodeEntry == null) {
                    log.error("Couldn't find the next node for node [{}]", (Object)candidate.backingNode);
                    return null;
                }
                candidate = nextNodeEntry.getValue();
                candidateHash = nextNodeEntry.getKey();
            }
            log.trace("Couldn't find a node that satisfies all requirements for partition {}, dropping one and retrying. Number of requirements left: {}", (Object)partitionId, (Object)(localRequirements.size() - 1));
            localRequirements.poll();
        }
        log.debug("Couldn't find a node that satisfies requirements for partition {}. Started at node [{}]", (Object)partitionId, (Object)node.backingNode);
        return null;
    }

    /**
     * Returns whether every requirement accepts {@code node}, checking in iteration order and stopping at the
     * first rejection.
     */
    private static boolean satisfiesAll(Collection<ReplicaRequirement> requirements,
                                        @Nullable MeshPartitionReplica existingReplica, MeshNode node, int partitionId) {
        return requirements.stream().allMatch(requirement -> requirement.test(existingReplica, node, partitionId));
    }

    /** Returns all known nodes except those in the {@link MeshNode.State#DELETING} state. */
    private List<InternalMeshNode> getAvailableNodes() {
        List<InternalMeshNode> allNodes = this.nodeDao.getAll();
        return allNodes.stream().filter(node -> node.getInternalState() != MeshNode.State.DELETING).toList();
    }

    /** Returns the replicas of every given partition as a single list. */
    private static List<MeshPartitionReplica> flattenReplicas(Collection<MeshPartition> allPartitions) {
        return allPartitions.stream()
                .flatMap(meshPartition -> meshPartition.getReplicas().stream())
                .collect(Collectors.toList());
    }

    /**
     * Groups replicas by the virtual node they are considered to live on. The virtual node's id is
     * {@code hashWithUpperBound(replicaId, VIRTUAL_NODE_COUNT)}. Each group is sorted by replica id.
     */
    private Map<VirtualMeshNode, Set<MeshPartitionReplica>> getPartitionsByVirtualNode(Collection<MeshPartition> allPartitions) {
        return Collections.unmodifiableMap(allPartitions.stream()
                .flatMap(meshPartition -> meshPartition.getReplicas().stream())
                .collect(Collectors.groupingBy(
                        replica -> new VirtualMeshNode(replica.getNode(), this.hasher.hashWithUpperBound(replica.getId(), (long)VIRTUAL_NODE_COUNT)),
                        Collectors.collectingAndThen(
                                Collectors.toCollection(() -> new TreeSet<MeshPartitionReplica>(Comparator.comparingLong(MeshPartitionReplica::getId))),
                                Collections::unmodifiableSet))));
    }

    /**
     * Returns how many partitions the {@code nodeOrdinal}-th available node (1-based) contributes to the desired
     * partition count.
     * <p>
     * Nodes beyond the first {@code replicationFactor} each contribute {@code partitionsPerNode}. Each of the
     * first {@code replicationFactor} nodes contributes {@code partitionsPerNode} minus a correction. The
     * correction is the sum of {@code (replicationFactor - i) * partitionsPerNode} for {@code i} in
     * {@code [1, replicationFactor)}, integer-divided by {@code replicationFactor * replicationFactor}.
     */
    private int getPartitionsToAdd(long nodeOrdinal) {
        if (nodeOrdinal > (long)this.replicationFactor) {
            return this.partitionsPerNode;
        }
        int extraPartitions = 0;
        for (int i = 1; i < this.replicationFactor; ++i) {
            extraPartitions += (this.replicationFactor - i) * this.partitionsPerNode;
        }
        int correction = extraPartitions / (this.replicationFactor * this.replicationFactor);
        return this.partitionsPerNode - correction;
    }

    /** Groups replicas by the node hosting them; each group is sorted by replica id. */
    private Map<MeshNode, Set<MeshPartitionReplica>> getReplicasByNode(Collection<MeshPartition> allPartitions) {
        return Collections.unmodifiableMap(allPartitions.stream()
                .flatMap(meshPartition -> meshPartition.getReplicas().stream())
                .collect(Collectors.groupingBy(MeshPartitionReplica::getNode,
                        Collectors.collectingAndThen(
                                Collectors.toCollection(() -> new TreeSet<MeshPartitionReplica>(Comparator.comparingLong(MeshPartitionReplica::getId))),
                                Collections::unmodifiableSet))));
    }

    /**
     * Builds the placement requirements. Their order matters: when no node satisfies all of them,
     * {@code findNodeForPartitionReplica} drops them from the front of this list first.
     */
    private List<ReplicaRequirement> getRequirements(Map<MeshNode, Set<MeshPartitionReplica>> replicasByNode, List<MeshPartitionReplica> allReplicas) {
        return List.of(new BalancedDistributionRequirement(replicasByNode, DISTRIBUTION_FACTOR),
                new AvailabilityZoneRequirement(allReplicas),
                new SingleReplicaPerNodeRequirement(replicasByNode));
    }

    /**
     * One of the points a mesh node occupies on the hash ring.
     *
     * @param backingNode   the mesh node this virtual node belongs to
     * @param virtualNodeId this virtual node's index for its backing node ({@code calculateRing} uses
     *                      {@code 0} to {@code VIRTUAL_NODE_COUNT - 1})
     */
    record VirtualMeshNode(MeshNode backingNode, long virtualNodeId) {
    }
}

